<a href="https://colab.research.google.com/github/Mainakdeb/digit-GAN/blob/main/digit-dcgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional discriminator that scores images with a sigmoid output.

    The network is a stack of five kernel-4, stride-2 convolutions. The first
    one has no batch norm; the three middle ones are conv + batch norm +
    LeakyReLU blocks; the last one reduces to a single channel and is followed
    by a sigmoid. A 64x64 input therefore yields an N x 1 x 1 x 1 output.

    Args:
        channels_img: number of channels in the input images.
        features_d: channel count produced by the first convolution; the
            following blocks use 2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Channel widths after the first conv and after each of the 3 blocks.
        widths = [features_d * factor for factor in (1, 2, 4, 8)]

        # Input layout is N x C x H x W. The first layer has no batch norm.
        layers = [
            nn.Conv2d(
                channels_img, widths[0], kernel_size=4, stride=2, padding=1
            ),
            nn.LeakyReLU(0.2),
        ]
        for narrow, wide in zip(widths, widths[1:]):
            layers.append(self._block(narrow, wide, 4, 2, 1))
        # Collapse to one channel, then squash to (0, 1).
        layers.append(
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0)
        )
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free conv -> batch norm -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)  # slope taken from the paper
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid scores."""
        return self.disc(x)



In [2]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise to tanh-bounded images.

    Four ConvTranspose2d + ReLU stages followed by a final ConvTranspose2d
    and a tanh. With an N x channels_noise x 1 x 1 input the spatial size
    grows 1 -> 4 -> 8 -> 16 -> 32 -> 64.

    Args:
        channels_noise: number of channels of the input noise tensor.
        channels_img: number of channels of the generated images.
        features_g: base channel width; the stages use 16x, 8x, 4x and 2x
            this value.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # One row per upsampling stage: (in, out, stride, padding).
        stage_specs = (
            (channels_noise, features_g * 16, 1, 0),   # img: 4x4
            (features_g * 16, features_g * 8, 2, 1),   # img: 8x8
            (features_g * 8, features_g * 4, 2, 1),    # img: 16x16
            (features_g * 4, features_g * 2, 2, 1),    # img: 32x32
        )
        stages = [
            self._block(c_in, c_out, 4, stride, pad)
            for c_in, c_out, stride, pad in stage_specs
        ]
        # Final layer keeps its bias and produces N x channels_img x 64 x 64.
        to_image = nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        )
        self.net = nn.Sequential(*stages, to_image, nn.Tanh())

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free transposed conv followed by ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        # NOTE(review): batch norm is disabled in this block (DCGAN normally
        # places nn.BatchNorm2d(out_channels) here) - confirm this is intended.
        return nn.Sequential(upsample, nn.ReLU())

    def forward(self, x):
        """Run ``x`` through the layer stack and return the tanh output."""
        return self.net(x)

In [3]:
def initialise_weights(model):
    """Re-initialise conv and batch-norm parameters of ``model`` in place.

    Convolution and transposed-convolution weights are drawn from
    N(0, 0.02), as in the DCGAN paper. Batch-norm scale weights are drawn
    from N(1, 0.02) and their biases are zeroed: a zero-centred scale would
    multiply every normalised activation by a value close to zero.

    Batch-norm layers created with ``affine=False`` have no learnable
    parameters and are skipped. Convolution biases are left untouched.

    Args:
        model: an ``nn.Module`` whose sub-modules are visited recursively.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [4]:
def test():
    """Smoke-test the output shapes of Discriminator and Generator.

    Builds both networks with a base width of 8, re-initialises their
    weights, and checks that 64x64 three-channel images map to an
    N x 1 x 1 x 1 score and that N x 100 x 1 x 1 noise maps back to
    N x 3 x 64 x 64 images. Prints a confirmation line on success.

    Raises:
        AssertionError: if either network produces an unexpected shape.
    """
    batch, channels, height, width = 8, 3, 64, 64
    noise_channels = 100

    # Discriminator: images -> one score per sample.
    images = torch.randn((batch, channels, height, width))
    critic = Discriminator(channels, 8)
    initialise_weights(critic)
    scores = critic(images)
    assert scores.shape == (batch, 1, 1, 1)

    # Generator: noise -> images of the same shape as above.
    generator = Generator(noise_channels, channels, 8)
    initialise_weights(generator)
    noise = torch.randn(batch, noise_channels, 1, 1)
    fakes = generator(noise)
    assert fakes.shape == (batch, channels, height, width)

    print("********works*******")


test()

********works*******


In [4]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training schedule.
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
NUM_EPOCHS = 5

# Image geometry.
IMAGE_SIZE = 64
CHANNELS_IMG = 1  # MNIST images are black and white (single channel)

# Noise size; both names are defined and hold the same value.
NOISE_DIM = 100
Z_DIM = 100

# Base channel widths for the two networks.
FEATURES_DISC = 64
FEATURES_GEN = 64

In [None]:
# Preprocessing pipeline: resize to IMAGE_SIZE, convert to a tensor, then
# normalise each of the CHANNELS_IMG channels with mean 0.5 and std 0.5.
#
# The pipeline is built through `torchvision.transforms` instead of the bare
# `transforms` alias because this statement rebinds `transforms` to the
# resulting Compose object. Reading the alias here would make the cell fail
# with AttributeError the second time it is executed.
# NOTE(review): after this statement `transforms` is the Compose instance,
# not the torchvision.transforms module - later code must not call
# transforms.<Something>.
_tv_transforms = torchvision.transforms
transforms = _tv_transforms.Compose(
    [
        _tv_transforms.Resize(IMAGE_SIZE),
        _tv_transforms.ToTensor(),
        _tv_transforms.Normalize(
            [0.5] * CHANNELS_IMG,
            [0.5] * CHANNELS_IMG,
        ),
    ]
)