## Deep Convolutional Generative Adversarial Network (DCGAN)

Paper: [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434v2)

Helpful Resources:
- [Aladdin Persson's playlist on GANs](https://youtube.com/playlist?list=PLhhyoLH6IjfwIp8bZnzX8QR30TRcHO8Va&si=8ooImkbbXhCUC1xB)
- [GANs specialization on Coursera](https://www.coursera.org/specializations/generative-adversarial-networks-gans)
- [Stanford's Deep Generative Models playlist](https://youtube.com/playlist?list=PLoROMvodv4rPOWA-omMM6STXaWW4FvJT8&si=N_TpTe1bPIhte-t8)
- [AssemblyAI's GAN tutorial](https://youtu.be/_pIMdDWK5sc?si=Mtx2oWh1ZO9tqWYg)

This notebook includes only the implementation of the DCGAN model and its training loop; the results are not shown here.

Feel free to check the results on my Kaggle notebook: https://www.kaggle.com/code/aryamanbansal/dcgan

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

torch.manual_seed(0)

print("Imports done!")

Imports done!


In [2]:
def plot_images(img_tensor, num_imgs=25, size=(1,28,28)):
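    """
    Given a tensor of images, the number of images to show, and the size of
    each image, this function plots the images in a uniform grid.
    Pixel values are assumed to be in [-1, 1] (the Tanh / Normalize range)
    and are mapped back to [0, 1] before plotting.
    """
    img_tensor = (img_tensor + 1) / 2
    img_unflat = img_tensor.detach().cpu().view(-1, *size)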
    img_grid = make_grid(img_unflat[:num_imgs], nrow=5)
    plt.imshow(img_grid.permute(1,2,0).squeeze())
    plt.show()


In [3]:
def plot_results(results):
    """
    results is a dictionary that may contain the keys "gen_train_loss",
        "gen_test_loss", "disc_train_loss", "disc_test_loss", "gen_train_acc",
        "gen_test_acc", "disc_train_acc", and "disc_test_acc".
    This function is meant to plot the train and test losses and accuracies.

    For now, it plots only the generator and discriminator train losses,
    which are the only keys the training loop below records.
    """
    plt.plot(results["gen_train_loss"], label="Generator train loss")
    plt.plot(results["disc_train_loss"], label="Discriminator train loss")
    plt.legend()
    plt.xlabel("Logging interval (one point every display_step steps)")
    plt.ylabel("Loss")
    plt.show()
    

Quoting from the DCGAN paper:

> Architecture guidelines for stable Deep Convolutional GANs:
> - Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
> - Use batchnorm in both the generator and the discriminator.
> - Remove fully connected hidden layers for deeper architectures.
> - Use ReLU activation in generator for all layers except for the output, which uses Tanh.
> - Use LeakyReLU activation in the discriminator for all layers.
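The activation guidelines can be illustrated with a minimal pure-Python sketch (scalar functions only, not the PyTorch modules). It shows ReLU zeroing negative inputs, LeakyReLU with a slope of 0.2 (the slope used in the discriminator below) letting a small negative signal through, and Tanh squashing any input into (-1, 1), the same range the real images are normalized to later on.

```python
import math

def relu(x):
    # Generator hidden layers: negative inputs become 0
    return max(0.0, x)

def leaky_relu(x, negative_slope=0.2):
    # Discriminator layers: negative inputs are scaled down, not zeroed
    return x if x > 0 else negative_slope * x

def tanh(x):
    # Generator output layer: bounded to (-1, 1)
    return math.tanh(x)

for value in (-3.0, -0.5, 0.0, 2.0):
    print(f"x={value:+.1f}  relu={relu(value):+.3f}  "
          f"leaky_relu={leaky_relu(value):+.3f}  tanh={tanh(value):+.3f}")
```

Keeping a nonzero slope for negative inputs means the discriminator still passes some gradient back through units that would be "dead" under ReLU.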

In [4]:
class Generator(nn.Module):
    def __init__(self, z_dim=10, img_channel=1, hidden_dim=64):
        """
        Parameters:
        - z_dim: the dimension of the noise vector, a scalar
        - img_channel: the number of channels of the output image, a scalar
            (MNIST is grayscale, so default value is img_channel=1)
        - hidden_dim: the inner dimension, a scalar
        """
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.gen_block(z_dim, hidden_dim*4),
            self.gen_block(hidden_dim*4, hidden_dim*2, kernel_size=4, stride=1),
            self.gen_block(hidden_dim*2, hidden_dim),
            self.gen_block(hidden_dim, img_channel, kernel_size=4, final_layer=True)
        )

    def gen_block(self, in_channel, out_channel, kernel_size=3, stride=2, 
                  final_layer=False):
        """
        Returns the layers of a generator block.

        Parameters:
        - in_channel: the number of channels in the input, a scalar
        - out_channel: the number of channels in the output, a scalar
        - kernel_size: the size of the kernel, a scalar
        - stride: the stride of the kernel, a scalar
        - final_layer: a boolean, True if this is the final layer and False otherwise
        """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(in_channel, out_channel, 
                                   kernel_size=kernel_size, stride=stride),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(inplace=True)
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(in_channel, out_channel, 
                                   kernel_size=kernel_size, stride=stride),
                nn.Tanh()
            )
        
    def forward(self, noise):
        """
        Given a noise tensor, returns the generated image.
        """
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
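Since every block uses padding=0, the spatial size after each ConvTranspose2d follows the output-size formula in the PyTorch documentation, H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. The pure-Python sketch below traces the generator's (kernel_size, stride) settings, copied by hand from the class above, to check that the 1x1 noise "image" grows to 28x28, the MNIST size.

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output height/width of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# (kernel_size, stride) of the four gen_block calls in Generator
gen_blocks = [(3, 2), (4, 1), (3, 2), (4, 2)]

size = 1  # forward() reshapes the noise to (N, z_dim, 1, 1)
sizes = [size]
for kernel_size, stride in gen_blocks:
    size = conv_transpose2d_out(size, kernel_size, stride)
    sizes.append(size)

print(sizes)  # [1, 3, 6, 13, 28]
```

The second block uses stride 1, so it only grows the feature map by kernel_size - 1; the stride-2 blocks roughly double it.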
    

In [5]:
class Discriminator(nn.Module):
    def __init__(self, img_channel=1, hidden_dim=16):
        """
        Parameters:
        - img_channel: the number of channels of the input image, a scalar
            (MNIST is grayscale, so default value is img_channel=1)
        - hidden_dim: the inner dimension, a scalar
        """
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.disc_block(img_channel, hidden_dim),
            self.disc_block(hidden_dim, hidden_dim*2),
            self.disc_block(hidden_dim*2, 1, final_layer=True)
        )

    def disc_block(self, in_channel, out_channel, kernel_size=4, stride=2,
                   final_layer=False):
        """
        Returns the layers of a discriminator block.

        Parameters:
        - in_channel: the number of channels in the input, a scalar
        - out_channel: the number of channels in the output, a scalar
        - kernel_size: the size of the kernel, a scalar
        - stride: the stride of the kernel, a scalar
        - final_layer: a boolean, True if this is the final layer and False otherwise
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                          stride=stride),
                nn.BatchNorm2d(out_channel),
                nn.LeakyReLU(0.2)
            )
        else:
            return nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                          stride=stride),
                nn.Sigmoid()
            )

    def forward(self, image):
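        """
        Given an image tensor of shape (N, C, H, W), returns a tensor of
        shape (N, 1) holding the predicted probability that each image is real.
        Parameters:
            image: a batch of images (not flattened, since the layers are convolutional)
        """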
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)
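The discriminator mirrors this with ordinary strided convolutions. Using the Conv2d output-size formula from the PyTorch documentation, H_out = floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1, with padding=0 and dilation=1 as in the class above, a small pure-Python sketch confirms that a 28x28 input shrinks to a single 1x1 prediction, which forward() then reshapes to (N, 1).

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# All three disc_block calls use the defaults kernel_size=4, stride=2
disc_blocks = [(4, 2), (4, 2), (4, 2)]

size = 28  # MNIST images are 28x28
sizes = [size]
for kernel_size, stride in disc_blocks:
    size = conv2d_out(size, kernel_size, stride)
    sizes.append(size)

print(sizes)  # [28, 13, 5, 1]
```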


So far, we have covered the changes DCGAN makes to the architectures of the generator and the discriminator.

Let's now move on to the training loop and the loss functions.

In [6]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
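lr = 2e-4           # learning rate (the DCGAN paper uses 0.0002)
z_dim = 64          # latent noise dimension
img_dim = 1         # number of image channels (1 means grayscale)
batch_size = 128
num_epochs = 50
display_step = 500  # number of training steps between loss/image logs

# beta_1 and beta_2 are Adam's decay rates for its running averages of the
# gradient (momentum) and of the squared gradient. The DCGAN paper lowers
# beta_1 from the default 0.9 to 0.5 to stabilize training.
# Background on momentum: https://distill.pub/2017/momentum/
beta_1 = 0.5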
beta_2 = 0.999
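To show what beta_1 and beta_2 do, here is a minimal single-scalar sketch of the Adam update rule (Kingma and Ba). It is an illustration, not PyTorch's implementation, although optim.Adam with its default options should follow the same update. beta_1 sets how long the running average of gradients (m) remembers past steps: with 0.5 it reacts to a gradient sign flip almost immediately, while the default 0.9 keeps pushing in the old direction for longer.

```python
import math

def adam_step(param, grad, m, v, t, lr=2e-4, beta_1=0.5, beta_2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t is the 1-based step count."""
    m = beta_1 * m + (1 - beta_1) * grad          # running average of gradients
    v = beta_2 * v + (1 - beta_2) * grad * grad   # running average of squared gradients
    m_hat = m / (1 - beta_1 ** t)                 # bias correction for the zero init
    v_hat = v / (1 - beta_2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

def first_moment(grads, beta_1):
    """Running average m after a sequence of gradients (no bias correction)."""
    m = 0.0
    for g in grads:
        m = beta_1 * m + (1 - beta_1) * g
    return m

# The first Adam step moves the parameter by about lr against the gradient
param, m, v = adam_step(param=0.0, grad=1.0, m=0.0, v=0.0, t=1)
print(param)

# Three positive gradients followed by one negative gradient
grads = [1.0, 1.0, 1.0, -1.0]
print(first_moment(grads, beta_1=0.5))  # already negative
print(first_moment(grads, beta_1=0.9))  # still positive
```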

In [10]:
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = MNIST(root="dataset/", transform=transformations, download=True, train=True)
test_dataset = MNIST(root="dataset/", transform=transformations, download=True, train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
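ToTensor() converts 8-bit pixel values in [0, 255] to floats in [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1], the same range as the generator's Tanh output. A pure-Python sketch of both steps for a single pixel (the helper names are illustrative, not torchvision functions):

```python
def to_unit_range(pixel):
    # What ToTensor() does to a uint8 pixel value: scale [0, 255] to [0, 1]
    return pixel / 255.0

def normalize(x, mean=0.5, std=0.5):
    # What Normalize((0.5,), (0.5,)) does per channel
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # Inverse mapping, useful before plotting images in [0, 1]
    return y * std + mean

for pixel in (0, 128, 255):
    y = normalize(to_unit_range(pixel))
    print(pixel, round(y, 4))
```

Matching the ranges means the discriminator cannot tell real from fake simply by looking at the value range of the pixels.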

In [8]:
def fn():
    # Sanity check: inspect the structure and shapes of one training batch
    for item in train_loader:
        print(len(item))
        print(item[0].shape, item[1].shape)
        break

fn()

2
torch.Size([128, 1, 28, 28]) torch.Size([128])
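This batch has 128 images, but not every batch does. MNIST has 60,000 training images, and since the DataLoader is created without drop_last=True, the last batch of each epoch holds the remainder. A quick arithmetic check shows why the training loop reads the batch size from len(real_img) instead of assuming batch_size:

```python
import math

num_train_images = 60_000  # size of the MNIST training split
batch_size = 128

batches_per_epoch = math.ceil(num_train_images / batch_size)
last_batch_size = num_train_images - (batches_per_epoch - 1) * batch_size

print(batches_per_epoch)  # 469
print(last_batch_size)    # 96
```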


In [9]:
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

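# fixed_noise is a fixed batch of latent noise vectors (z), drawn from a
# standard normal distribution with torch.randn. Reusing it lets you compare
# the generator's outputs on the same inputs over time; the training loop
# below samples fresh noise for its visualizations instead.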
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# separate optimizers for generator and discriminator
optim_disc = optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))
optim_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))

criterion = nn.BCELoss()  # binary cross entropy loss
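nn.BCELoss compares a predicted probability p with a target y using -[y * log(p) + (1 - y) * log(1 - p)], averaged over the batch by default. According to the PyTorch documentation, it clamps the log terms to be at least -100, so a prediction of exactly 0 or 1 gives a large but finite loss. The pure-Python sketch below mirrors that behavior for lists of scalars; it is an illustration, not the library code.

```python
import math

def safe_log(p):
    # BCELoss clamps log outputs to >= -100 (log(0) would be -inf)
    return -100.0 if p <= 0.0 else max(math.log(p), -100.0)

def bce(preds, targets):
    """Mean binary cross-entropy, mimicking nn.BCELoss(reduction='mean')."""
    losses = [-(y * safe_log(p) + (1 - y) * safe_log(1 - p))
              for p, y in zip(preds, targets)]
    return sum(losses) / len(losses)

print(bce([0.9], [1.0]))  # confident and correct: small loss
print(bce([0.1], [1.0]))  # confident and wrong: larger loss
print(bce([0.0], [1.0]))  # clamped: 100.0
```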

In [10]:
# Initialize convolution, transposed-convolution, and batch-norm weights from a
# normal distribution with mean 0 and standard deviation 0.02 (batch-norm biases are set to 0)
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [11]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """
    Returns the loss of the discriminator.
    Parameters:
        - gen: the generator model, which returns an image given 
               z-dimensional noise
        - disc: the discriminator model, which returns a single-dimensional 
                prediction of real/fake
        - criterion: the loss function, which should be used to compare 
                     the discriminator's predictions to the ground truth 
                     reality of the images (e.g. fake = 0, real = 1)
        - real: a batch of real images
        - num_images: the number of images the generator should produce, 
                      which is also the length of the real images
        - z_dim: the dimension of the noise vector, a scalar
        - device: the device type (eg: cuda or cpu)
    Returns:
        disc_loss: a torch scalar loss value for the current batch

    The discriminator maximizes
        log(D(x)) + log(1 - D(G(z))),
    which is the same as minimizing the BCE losses computed below.
    """
    
    # 1) Create a noise vector and generate a batch (i.e., num_images) of fake images.
    noise_vector = torch.randn(num_images, z_dim).to(device)  # z
    fake_images = gen(noise_vector)                           # G(z)

    # 2) Get the discriminator's prediction on the fake images and compute the loss.
    #    The fake images are detached so that this loss does not backpropagate
    #    into the generator. The ground truth for fake images is all zeros.
    disc_fake_preds = disc(fake_images.detach())                   # D(G(z))
    disc_fake_loss = criterion(disc_fake_preds,
                               torch.zeros_like(disc_fake_preds))  # -log(1 - D(G(z)))

    # 3) Get the discriminator's prediction on the real images and compute the loss.
    #    The ground truth for real images is all ones.
    disc_real_preds = disc(real)                                   # D(x)
    disc_real_loss = criterion(disc_real_preds,
                               torch.ones_like(disc_real_preds))   # -log(D(x))

    # 4) The discriminator's loss is the average of the fake and real losses.
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    
    return disc_loss
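For intuition, here is the same computation for one real and one fake image, with made-up discriminator outputs D(x) = 0.8 and D(G(z)) = 0.3 (illustrative numbers, not training results). With targets 1 and 0, BCE reduces to -log(D(x)) and -log(1 - D(G(z))), and the discriminator's loss is their average:

```python
import math

d_real = 0.8  # hypothetical D(x): probability the discriminator assigns to "real" for a real image
d_fake = 0.3  # hypothetical D(G(z)): probability it assigns to "real" for a fake image

disc_real_loss = -math.log(d_real)      # BCE with target 1
disc_fake_loss = -math.log(1 - d_fake)  # BCE with target 0
disc_loss = (disc_fake_loss + disc_real_loss) / 2

print(round(disc_real_loss, 4), round(disc_fake_loss, 4), round(disc_loss, 4))
```

Pushing d_real toward 1 and d_fake toward 0 drives both terms, and therefore disc_loss, toward 0.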


In [12]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """
    Returns the loss of the generator.
    Parameters:
        - gen: the generator model, which returns an image given 
               z-dimensional noise
        - disc: the discriminator model, which returns a single-dimensional 
                prediction of real/fake
        - criterion: the loss function, which should be used to compare 
                     the discriminator's predictions to the ground truth 
                     reality of the images (e.g. fake = 0, real = 1)
        - num_images: the number of images the generator should produce, 
                      which is also the length of the real images
        - z_dim: the dimension of the noise vector, a scalar
        - device: the device type (eg: cuda or cpu)
    Returns:
        gen_loss: a torch scalar loss value for the current batch

    The generator maximizes log(D(G(z))) (the "non-saturating" generator loss),
    which is the same as minimizing BCE with a target of 1 for the fake images.
    """

    # 1) Create noise vectors and generate a batch of fake images.
    noise_vector = torch.randn(num_images, z_dim).to(device)  # z
    fake_images = gen(noise_vector)                           # G(z)

    # 2) Get the discriminator's prediction of the fake image.
    disc_fake_preds = disc(fake_images)                       # D(G(z))

    # 3) Calculate the generator's loss. The generator wants the discriminator
    #    to classify its fake images as real, so the target is all ones.
    gen_loss = criterion(disc_fake_preds,
                         torch.ones_like(disc_fake_preds))    # -log(D(G(z)))

    return gen_loss
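Why use a target of 1 instead of minimizing log(1 - D(G(z))) directly? Early in training the discriminator easily rejects fakes, so D(G(z)) is close to 0. Writing D(G(z)) = sigmoid(a), where a is the discriminator's pre-sigmoid logit, the pure-Python sketch below compares the gradient with respect to a of the "saturating" minimax term log(1 - D(G(z))) with that of the BCE-with-target-1 loss -log(D(G(z))) used above (the non-saturating heuristic from the original GAN paper). The logit value is an arbitrary illustration.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da of log(1 - sigmoid(a)), the term the generator would minimize in the minimax game
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da of -log(sigmoid(a)), i.e. BCE with target 1
    return -(1.0 - sigmoid(a))

a = -6.0  # discriminator is confident the fake is fake: D(G(z)) is about 0.0025
print(sigmoid(a))
print(grad_saturating(a))      # close to 0: almost no learning signal
print(grad_non_saturating(a))  # close to -1: strong learning signal
```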


In [11]:
current_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
results = {
    "gen_train_loss": [],
    "disc_train_loss": [],
}

for epoch in tqdm(range(num_epochs)):
    
    # we iterate over the training dataloader
    # we only need the images, and not the labels
    for real_img, _ in train_loader:
        
        curr_batch_size = len(real_img)
        # No need to flatten the batch of real images,
        # as we're using DCGAN which uses convolutional layers
        real_img = real_img.to(device)

        # Update the discriminator first
        # Zero out the gradients before backpropagation
        optim_disc.zero_grad()
        # Calculate the discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real_img, curr_batch_size, z_dim, device)
        # Compute gradients. retain_graph is not needed: get_disc_loss detaches
        # the fake images, and get_gen_loss builds a fresh graph from new noise.
        disc_loss.backward()
        # Update the discriminator's weights
        optim_disc.step()

        # Update the generator
        # Zero out the gradients before backpropagation
        optim_gen.zero_grad()
        # Calculate the generator loss
        gen_loss = get_gen_loss(gen, disc, criterion, curr_batch_size, z_dim, device)
        # Compute gradients (these also reach the discriminator's parameters,
        # but optim_disc.zero_grad() clears them before the next discriminator step)
        gen_loss.backward()
        # Update the generator's weights
        optim_gen.step()

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item()
        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item()

        # Visualization code
        if current_step % display_step == 0 and current_step > 0:
            mean_discriminator_loss = mean_discriminator_loss / display_step
            mean_generator_loss = mean_generator_loss / display_step
            results["gen_train_loss"].append(mean_generator_loss)
            results["disc_train_loss"].append(mean_discriminator_loss)
            print(f"Step {current_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
            with torch.no_grad():  # no gradients are needed for visualization
                fake_noise = torch.randn(curr_batch_size, z_dim).to(device)
                fake_img = gen(fake_noise)
            plot_images(fake_img)
            plot_images(real_img)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        current_step += 1
        

**A note about image generation using DCGAN**

During training, two grids of MNIST digits are displayed every display_step (500) training steps, which is roughly once per epoch: the first shows fake digits produced by the generator and the second shows real images from the current batch. It is tempting to think that the generator is trying to *copy* those real images, but it is not. The generator never sees real images directly; it receives a random noise vector and learns, through the discriminator's feedback, to map noise to images that resemble the MNIST distribution as a whole. So the fake digits are not expected to match the real digits shown next to them.

Note: The GAN used here is a very basic one compared with the many GAN variants available today. It is unconditional: it cannot generate an image from a prompt, or even a requested digit class. It can only map an input noise vector to a random image.

In [12]:
plot_results(results)
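Because results is appended to every display_step steps rather than once per epoch, each point on this plot is an average over roughly 500 training steps. Assuming the settings above (60,000 MNIST training images, batch_size = 128, num_epochs = 50, display_step = 500), a quick count of how many points to expect:

```python
import math

num_train_images = 60_000
batch_size = 128
num_epochs = 50
display_step = 500

steps_per_epoch = math.ceil(num_train_images / batch_size)  # 469
total_steps = steps_per_epoch * num_epochs                   # 23,450
# current_step runs from 0 to total_steps - 1; a point is logged whenever
# current_step is a positive multiple of display_step
num_points = (total_steps - 1) // display_step

print(steps_per_epoch, total_steps, num_points)  # 469 23450 46
```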