<a href="https://colab.research.google.com/github/Mainakdeb/digit-GAN/blob/main/digit-dcgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Made using guidelines from the DCGAN paper - https://arxiv.org/abs/1511.06434

### Define the Discriminator net :


In [1]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
  """Convolutional real/fake classifier for N x C x H x W image batches.

  A plain strided convolution is followed by three conv + batch-norm stages
  that double the channel count each time (d -> 2d -> 4d -> 8d), and a final
  convolution + sigmoid that emits one probability per image. With 64x64
  inputs the spatial size shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1, so the output
  is N x 1 x 1 x 1.

  Args:
    channels_img: number of channels in the input images.
    features_d: channel count produced by the first convolution.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    # Stem: strided conv without batch-norm.
    layers = [
        nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Three down-sampling stages, each doubling the channel width.
    width = features_d
    for _ in range(3):
      layers.append(self._block(width, width * 2, 4, 2, 1))
      width = width * 2
    # Head: collapse the remaining 4x4 map to a single sigmoid score.
    layers.append(nn.Conv2d(width, 1, kernel_size=4, stride=2, padding=0))
    layers.append(nn.Sigmoid())
    self.disc = nn.Sequential(*layers)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one conv -> batch-norm -> LeakyReLU(0.2) stage."""
    conv = nn.Conv2d(
        in_channels, out_channels, kernel_size, stride, padding, bias=False
    )  # no bias: the batch-norm that follows has its own shift
    norm = nn.BatchNorm2d(out_channels)
    act = nn.LeakyReLU(0.2)  # slope taken from the paper
    return nn.Sequential(conv, norm, act)

  def forward(self, x):
    """Score a batch of images; values lie in [0, 1]."""
    return self.disc(x)



### Define the Generator net :

In [2]:
class Generator(nn.Module):
  """Transposed-convolution generator mapping latent noise to images.

  Input is N x channels_noise x 1 x 1; four up-sampling stages followed by a
  final transposed convolution and Tanh produce N x channels_img x 64 x 64
  with values in [-1, 1].

  Args:
    channels_noise: number of channels of the latent input.
    channels_img: number of channels of the generated images.
    features_g: base channel width; the stages use 16x, 8x, 4x and 2x of it.
    use_batchnorm: when True, insert BatchNorm2d between each hidden
      transposed convolution and its ReLU. Defaults to False, which keeps the
      layer layout (and therefore the state_dict keys) exactly as before, so
      existing checkpoints still load.
  """

  def __init__(self, channels_noise, channels_img, features_g, use_batchnorm=False):
    super(Generator, self).__init__()
    # Must be set before the stages are built: _block reads it.
    self.use_batchnorm = use_batchnorm
    self.net = nn.Sequential(
        self._block(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
        self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
        self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
        self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
        nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        ),
        # Output: N x channels_img x 64 x 64
        nn.Tanh(),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return one up-sampling stage: ConvTranspose2d [-> BatchNorm2d] -> ReLU."""
    layers = [
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,  # kept off in both modes so the default parameter set is unchanged
        )
    ]
    if self.use_batchnorm:
      layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

  def forward(self, x):
    """Generate a batch of images from a batch of latent vectors."""
    return self.net(x)

### Define a function to initialise model weights. 


In [3]:
def initialise_weights(model):
  """Re-initialise a model's weights in place, DCGAN style.

  Conv2d / ConvTranspose2d kernels are drawn from N(0, 0.02) as in the paper.
  BatchNorm2d scales are drawn from N(1, 0.02) and shifts are zeroed: a scale
  centred on 0 would shrink every normalised activation to almost nothing at
  the start of training. Convolution biases are left untouched.

  Args:
    model: any nn.Module; all of its sub-modules are visited.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      # weight/bias are None when the layer was built with affine=False
      nn.init.normal_(m.weight, 1.0, 0.02)
      nn.init.zeros_(m.bias)

### Make sure everything works so far :

In [4]:
def test():
  """Smoke-test both networks by checking their output shapes on random input."""
  batch, channels, height, width = 8, 3, 64, 64
  latent_dim = 100

  # Discriminator: one score per image.
  images = torch.randn((batch, channels, height, width))
  critic = Discriminator(channels, 8)
  initialise_weights(critic)
  scores = critic(images)
  assert scores.shape == (batch, 1, 1, 1)

  # Generator: latent vectors in, full-size images out.
  generator = Generator(latent_dim, channels, 8)
  initialise_weights(generator)
  latent = torch.randn(batch, latent_dim, 1, 1)
  samples = generator(latent)
  assert samples.shape == (batch, channels, height, width)

  print("********works*******")


test()

********works*******


### Define training hyperparameters :

In [5]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LEARNING_RATE = 2e-4   # step size for both Adam optimizers
BATCH_SIZE = 128       # images per DataLoader batch
IMAGE_SIZE = 64        # training images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 1       # MNIST images are single-channel
NOISE_DIM = 100        # length of the generator's latent input vector
# Z_DIM is the name the sampling code uses for the same quantity. It is an
# alias rather than a second literal so the two can never drift apart.
Z_DIM = NOISE_DIM
NUM_EPOCHS = 5         # full passes over the training set
FEATURES_DISC = 64     # base channel width of the discriminator
FEATURES_GEN = 64      # base channel width of the generator

### All training images are resized to 64x64

In [6]:
# Preprocessing for every training image: resize to IMAGE_SIZE, convert to a
# float tensor in [0, 1], then normalise each channel with mean 0.5 / std 0.5,
# which maps the values to [-1, 1] (the range of the generator's Tanh output).
#
# The pipeline is still bound to the name `transforms` because the dataset
# cell reads it under that name. That rebinding shadows the imported module,
# so the stages are looked up through `torchvision.transforms` instead; the
# cell can then be re-run without an AttributeError.
transforms = torchvision.transforms.Compose(
    [
     torchvision.transforms.Resize(IMAGE_SIZE),
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize([0.5] * CHANNELS_IMG,
                                      [0.5] * CHANNELS_IMG)
    ]
)

### Define the dataset and the dataloader :

In [7]:
# MNIST training split, downloaded on first use and preprocessed by `transforms`.
dataset = datasets.MNIST(
    root="/dataset/",
    train=True,
    transform=transforms,
    download=True,
)

# Reshuffled every epoch and served in batches of BATCH_SIZE images.
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /dataset/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /dataset/MNIST/raw/train-images-idx3-ubyte.gz to /dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### Instantiate the models, optimizers, loss and TensorBoard writers :

In [8]:
# Build both networks on the chosen device and apply the custom weight init.
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialise_weights(gen)
initialise_weights(disc)

# One Adam optimizer per network, sharing the learning rate and betas;
# binary cross-entropy is the loss for both.
opt_gen = optim.Adam(
    gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999)
)
opt_disc = optim.Adam(
    disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999)
)
criterion = nn.BCELoss()

# The same 32 latent vectors are rendered at every logging step so the
# TensorBoard image grids are comparable over time.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
step = 0  # global_step counter for the TensorBoard image logs

gen.train()
disc.train()

# Separate TensorBoard writers for the real and the generated image grids.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

### Initiate Tensorboard :

In [None]:
# Load the TensorBoard notebook extension, then open a TensorBoard panel
# reading from the `logs` directory that the SummaryWriters write into.
%load_ext tensorboard
%tensorboard --logdir logs

### Begin Training :

In [None]:
# Alternating GAN updates: for every real batch, take one discriminator step
# followed by one generator step.
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.to(device)
    # Size the fake batch from the real one: the last batch of an epoch can
    # hold fewer than BATCH_SIZE images.
    noise = torch.randn(real.size(0), Z_DIM, 1, 1).to(device)
    fake = gen(noise)

    #train discriminator: real images are labelled 1, generated images 0
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() cuts the graph at the generator output, so this backward pass
    # stops at the discriminator and no retain_graph=True is needed.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real+loss_disc_fake)/2
    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    #train generator: it is rewarded when the updated discriminator labels
    #its samples as real (target 1)
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Print losses occasionally and print to tensorboard
    if batch_idx % 100 == 0:
        print(
            f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
              Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
        )

        with torch.no_grad():
          # Render the fixed latent vectors so successive grids are comparable.
          fake = gen(fixed_noise)
          # take out (up to) 32 examples
          img_grid_real = torchvision.utils.make_grid(
              real[:32], normalize=True
          )
          img_grid_fake = torchvision.utils.make_grid(
              fake[:32], normalize=True
          )

          writer_real.add_image("Real", img_grid_real, global_step=step)
          writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        # step counts logging events (one per 100 batches), not batches
        step += 1

Epoch [0/5] Batch 0/469               Loss D: 0.5502, loss G: 0.2296
Epoch [0/5] Batch 100/469               Loss D: 0.7400, loss G: 0.6152
Epoch [0/5] Batch 200/469               Loss D: 0.3406, loss G: 3.2650
Epoch [0/5] Batch 300/469               Loss D: 0.1070, loss G: 3.3198
Epoch [0/5] Batch 400/469               Loss D: 0.1275, loss G: 2.5441
Epoch [1/5] Batch 0/469               Loss D: 0.0697, loss G: 2.7750
Epoch [1/5] Batch 100/469               Loss D: 0.0840, loss G: 3.4144
Epoch [1/5] Batch 200/469               Loss D: 0.0826, loss G: 2.9037
Epoch [1/5] Batch 300/469               Loss D: 0.0897, loss G: 3.6982
Epoch [1/5] Batch 400/469               Loss D: 0.0892, loss G: 3.4266
Epoch [2/5] Batch 0/469               Loss D: 0.0533, loss G: 3.9370
Epoch [2/5] Batch 100/469               Loss D: 0.0261, loss G: 4.0631
Epoch [2/5] Batch 200/469               Loss D: 0.0401, loss G: 4.3643
Epoch [2/5] Batch 300/469               Loss D: 0.0360, loss G: 3.9056
Epoch [2/5] 

### Save models :

In [11]:
# Persist only the learned parameters (state_dict) of each network,
# one file per network in the current working directory.
torch.save(obj=gen.state_dict(), f="generator_net.pt")
torch.save(obj=disc.state_dict(), f="discriminator_net.pt")

### Load existing model :

In [12]:
# Restore the saved parameters into the existing networks. map_location=device
# remaps the stored tensors onto the current device, so a checkpoint written
# on a GPU runtime also loads on a CPU-only one.
gen.load_state_dict(torch.load("generator_net.pt", map_location=device))
disc.load_state_dict(torch.load("discriminator_net.pt", map_location=device))

<All keys matched successfully>