<a href="https://colab.research.google.com/github/Himagination/GANs/blob/main/Conditional_GANs_Handwritten_digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Objective

- Build a Conditional Generative Adversarial Network (cGAN) that generates images of handwritten digits, conditioned on the digit to be generated (the class vector).
- The Generator and Discriminator architectures are unchanged from an unconditional GAN. Only their inputs change, so each network's input size grows to make room for the class information.
- The Generator no longer takes `z_dim` as an argument but `input_dim` instead, since it now receives the noise vector and the one-hot class vector concatenated together.

## Import required packages

In [1]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

## Utility function

In [2]:
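def show_tensor_images(image_tensor, 
                       num_images=25, 
                       size=(1, 28, 28), 
                       nrow=5, 
                       show=True):
  '''Plot the first num_images images of a batch in a grid with nrow images per row.
  Images are expected in [-1, 1] (normalized MNIST or Tanh output); `size` is unused.'''
  # Rescale from [-1, 1] to [0, 1] for display
  image_tensor = (image_tensor + 1) / 2
  image_unflat = image_tensor.detach().cpu()
  image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
  # make_grid returns (C, H, W); matplotlib expects (H, W, C)
  plt.imshow(image_grid.permute(1, 2, 0).squeeze())
  if show:
    plt.show()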

## Generator and Noise

In [3]:
class Generator(nn.Module):
  def __init__(self, 
               input_dim=10, 
               im_chan=1, 
               hidden_dim=64):
    super(Generator, self).__init__()
    # Store input_dim: forward() needs it to reshape the input vector
    self.input_dim = input_dim
    self.gen = nn.Sequential(
        self.make_gen_block(input_dim, hidden_dim * 4), 
        self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1), 
        self.make_gen_block(hidden_dim * 2, hidden_dim), 
        self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True)
    )
  def make_gen_block(self, 
                     input_channels, 
                     output_channels, 
                     kernel_size=3,  # 3 (not 4) so the final output is 28x28
                     stride=2, 
                     final_layer=False):
    if not final_layer:
      return nn.Sequential(
          nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride), 
          nn.BatchNorm2d(output_channels), 
          nn.ReLU(inplace=True)
      )
    else:
      return nn.Sequential(
          nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride), 
          nn.Tanh()
      )
  def forward(self, noise):
    x = noise.view(len(noise), self.input_dim, 1, 1)
    return self.gen(x)

def get_noise(n_samples, input_dim, device="cpu"):
  return torch.randn(n_samples, input_dim, device=device)
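A quick way to check the generator's layer settings is the output-size formula for `nn.ConvTranspose2d` with no padding: `out = (in - 1) * stride + kernel_size`. The sketch below (plain Python; `conv_transpose_out` and `generator_sizes` are illustrative helpers, not part of the notebook) walks the 1×1 input through the four blocks. It assumes that blocks which don't set `kernel_size` use 3, which is what makes the chain end at 28×28. With a kernel of 4 in every block the chain ends at 34×34, which no longer lines up with the 28×28 label channels concatenated to the images during training.

```python
def conv_transpose_out(size, kernel_size, stride):
    # Output size of ConvTranspose2d with padding=0, output_padding=0, dilation=1
    return (size - 1) * stride + kernel_size

def generator_sizes(kernels, strides, start=1):
    # Spatial size after each block, starting from the 1x1 reshaped input
    sizes = [start]
    for k, s in zip(kernels, strides):
        sizes.append(conv_transpose_out(sizes[-1], k, s))
    return sizes

# (kernel_size, stride) per block, assuming a default kernel_size of 3
print(generator_sizes([3, 4, 3, 4], [2, 1, 2, 2]))  # [1, 3, 6, 13, 28]
# Same strides with kernel_size=4 everywhere
print(generator_sizes([4, 4, 4, 4], [2, 1, 2, 2]))  # [1, 4, 7, 16, 34]
```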

## Discriminator

In [4]:
class Discriminator(nn.Module):
  def __init__(self, im_chan=1, hidden_dim=64):
    super(Discriminator, self).__init__()
    self.disc = nn.Sequential(
        self.make_disc_block(im_chan, hidden_dim), 
        self.make_disc_block(hidden_dim, hidden_dim * 2), 
        self.make_disc_block(hidden_dim * 2, 1, final_layer=True)
    )
  def make_disc_block(self, 
                      input_channels, 
                      output_channels, 
                      kernel_size=4, 
                      stride=2, 
                      final_layer=False):
    if not final_layer:
      return nn.Sequential(
          nn.Conv2d(input_channels, output_channels, kernel_size, stride), 
          nn.BatchNorm2d(output_channels), 
          nn.LeakyReLU(0.2, inplace=True)
      )
    else:
      return nn.Sequential(
          nn.Conv2d(input_channels, output_channels, kernel_size, stride)
      )
  def forward(self, image):
    disc_pred = self.disc(image)
    return disc_pred.view(len(disc_pred), -1)
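The discriminator can be checked the same way. For `nn.Conv2d` with no padding, `out = floor((in - kernel_size) / stride) + 1`, so three blocks with kernel 4 and stride 2 shrink a 28×28 input to a single value, which `view(len(disc_pred), -1)` turns into one logit per image. A minimal sketch (the `conv_out` helper is illustrative):

```python
def conv_out(size, kernel_size=4, stride=2):
    # Output size of Conv2d with padding=0, dilation=1
    return (size - kernel_size) // stride + 1

sizes = [28]
for _ in range(3):  # three discriminator blocks, each with kernel 4 and stride 2
    sizes.append(conv_out(sizes[-1]))
print(sizes)  # [28, 13, 5, 1]
```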

## Class Input

- In a Conditional GAN, the generator's input vector also needs to include the class information.
- The class is represented as a one-hot vector whose length is the number of classes, with one index per class.
- The vector is all zeros except for a 1 at the index of the chosen class.


In [5]:
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
  return F.one_hot(labels, num_classes=n_classes)
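For example, `F.one_hot` (which `get_one_hot_labels` wraps) turns a batch of integer labels into rows containing a single 1. Note that it returns an integer (`int64`) tensor, which is why `combine_vectors` casts its inputs to float. A small sketch:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 0, 9])
one_hot = F.one_hot(labels, num_classes=10)

print(one_hot.shape)  # torch.Size([3, 10])
print(one_hot[0])     # tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
print(one_hot.dtype)  # torch.int64
```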

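**Concatenate the one-hot class vector to the noise vector before passing it to the generator.** The same helper is reused to stack images and label channels for the discriminator; in both cases dimension 1 is the feature (or channel) dimension.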

In [6]:
def combine_vectors(x, y):
  combined = torch.cat((x.float(), y.float()), 1)
  return combined
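The sketch below shows the shape of the generator input for a batch of 4, using the notebook's `z_dim = 64` and 10 classes (the helper is repeated so the block runs on its own): each row is 64 noise values followed by the 10-entry one-hot code, as a single float32 vector.

```python
import torch
import torch.nn.functional as F

def combine_vectors(x, y):
    # As defined above: cast both to float and concatenate along dimension 1
    return torch.cat((x.float(), y.float()), 1)

noise = torch.randn(4, 64)                                        # (batch, z_dim)
one_hot = F.one_hot(torch.tensor([0, 1, 2, 3]), num_classes=10)  # (batch, 10), int64
gen_input = combine_vectors(noise, one_hot)

print(gen_input.shape)    # torch.Size([4, 74])
print(gen_input.dtype)    # torch.float32
print(gen_input[2, 64:])  # last 10 entries: the one-hot code for digit 2
```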

## Training

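- **Parameters**
  - criterion: the loss function (`BCEWithLogitsLoss`, since the discriminator outputs raw logits)
  - n_epochs: number of passes over the training set
  - z_dim: dimension of the noise vector
  - display_step: how often (in training steps) to print the average losses and show images
  - batch_size: number of images per forward/backward pass
  - lr: learning rate for both Adam optimizers
  - device: `cuda` if a GPU is available, otherwise `cpu`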

In [7]:
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
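These settings also determine how long training runs. Assuming the standard 60,000-image MNIST training split and the `DataLoader` default `drop_last=False`, each epoch has `ceil(60000 / 128) = 469` batches, the last one holding only 96 images. A short calculation:

```python
import math

n_train = 60_000   # MNIST training images
batch_size = 128
n_epochs = 200
display_step = 500

steps_per_epoch = math.ceil(n_train / batch_size)             # partial last batch is kept
last_batch = n_train - (steps_per_epoch - 1) * batch_size     # size of that partial batch
total_steps = steps_per_epoch * n_epochs

print(steps_per_epoch)                # 469
print(last_batch)                     # 96
print(total_steps)                    # 93800
print(total_steps // display_step)    # 187 progress reports
```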

In [8]:
transform = transforms.Compose([
                                transforms.ToTensor(), 
                                transforms.Normalize((0.5, ), (0.5, ))
])
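`Normalize((0.5,), (0.5,))` computes `(x - 0.5) / 0.5` per pixel, mapping the `[0, 1]` output of `ToTensor()` to `[-1, 1]`, the same range as the generator's final `Tanh`. `show_tensor_images` undoes this with `(x + 1) / 2` before plotting. A plain-Python sketch of both directions (the function names are illustrative):

```python
def normalize(x, mean=0.5, std=0.5):
    # Per-pixel effect of transforms.Normalize((0.5,), (0.5,))
    return (x - mean) / std

def denormalize(x):
    # What show_tensor_images does before plotting: map [-1, 1] back to [0, 1]
    return (x + 1) / 2

for pixel in (0.0, 0.5, 1.0):  # ToTensor() produces values in [0, 1]
    print(pixel, normalize(pixel), denormalize(normalize(pixel)))
# 0.0 -1.0 0.0
# 0.5 0.0 0.5
# 1.0 1.0 1.0
```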

In [9]:
dataloader = DataLoader(
    MNIST('.', download=True, transform=transform), 
    batch_size=batch_size, 
    shuffle=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw


In [10]:
mnist_shape = (1, 28, 28)
n_classes = 10

In [11]:
def get_input_dimensions(z_dim, mnist_shape, n_classes):
  generator_input_dim = z_dim + n_classes
  discriminator_im_chan = mnist_shape[0] + n_classes
  return generator_input_dim, discriminator_im_chan
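With the values above, the generator receives 64 + 10 = 74 input features and the discriminator 1 + 10 = 11 input channels. The same arithmetic carries over to other datasets; the second call below uses a hypothetical 3-channel, 10-class dataset purely for illustration. Only the input sizes change there: the convolution stacks in this notebook are tuned for 28×28 images and would need adjusting for other resolutions.

```python
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    # Same logic as the cell above
    generator_input_dim = z_dim + n_classes             # noise + one-hot vector
    discriminator_im_chan = mnist_shape[0] + n_classes  # image channels + one channel per class
    return generator_input_dim, discriminator_im_chan

print(get_input_dimensions(64, (1, 28, 28), 10))  # (74, 11) for this notebook
print(get_input_dimensions(64, (3, 32, 32), 10))  # (74, 13) for a hypothetical RGB dataset
```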

## Initialization

In [12]:
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

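def weights_init(m):
  # Conv and transposed-conv weights drawn from a normal distribution N(0, 0.02)
  if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
  # BatchNorm: scale drawn from N(0, 0.02), shift (bias) set to 0
  if isinstance(m, nn.BatchNorm2d):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
    torch.nn.init.constant_(m.bias, 0)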
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

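**For training, both the generator and the discriminator must know which class each image belongs to: the generator receives the one-hot vector appended to its noise, and the discriminator receives it as `n_classes` extra image channels, one per class, where the channel of the true class is filled with ones.**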
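Concretely, the training loop expands each one-hot vector from shape `(batch, n_classes)` to `(batch, n_classes, 28, 28)` by adding two singleton dimensions and repeating along them, then stacks those maps onto the image channel. A self-contained sketch with two stand-in images labelled 7 and 2 (random values, not real MNIST data):

```python
import torch
import torch.nn.functional as F

n_classes, height, width = 10, 28, 28
labels = torch.tensor([7, 2])
one_hot = F.one_hot(labels, num_classes=n_classes)                  # (2, 10), int64
label_maps = one_hot[:, :, None, None].repeat(1, 1, height, width)  # (2, 10, 28, 28)

images = torch.rand(2, 1, height, width) * 2 - 1                    # stand-in images in [-1, 1)
disc_input = torch.cat((images.float(), label_maps.float()), 1)     # channel 0 = image, 1..10 = classes

print(disc_input.shape)               # torch.Size([2, 11, 28, 28])
print(label_maps[0, 7].min().item())  # 1: the channel for digit 7 is all ones
print(label_maps[0].sum().item())     # 784: every other channel is zero
```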

In [14]:
cur_step = 0
generator_losses = []
discriminator_losses = []

for epoch in range(n_epochs):
  for real, labels in tqdm(dataloader):
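    cur_batch_size = len(real)
    # Move the batch of real images to the same device as the models
    real = real.to(device)
    # One-hot encode the labels: (batch, n_classes)
    one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
    # Expand to one 28x28 channel per class: (batch, n_classes, 28, 28)
    image_one_hot_labels = one_hot_labels[:, :, None, None]
    image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])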
    ## Update Discriminator
    # Zero out the discriminator gradients
    disc_opt.zero_grad()
    # Get noise corresponding to the current batch_size
    fake_noise = get_noise(cur_batch_size, z_dim, device=device)
    # Combine the noise vectors and the one-hot labels
    # Generate the conditioned fake images
    noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
    fake = gen(noise_and_labels)
    # Make sure that enough images were generated
    assert len(fake) == len(real)
    fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
    real_image_and_labels = combine_vectors(real, image_one_hot_labels)
    disc_fake_pred = disc(fake_image_and_labels.detach())
    disc_real_pred = disc(real_image_and_labels)
    # Make sure that enough predictions were made
    assert len(disc_real_pred) == len(real)
    # Make sure that the inputs are different
    assert torch.any(fake_image_and_labels != real_image_and_labels)
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    disc_loss.backward(retain_graph=True)
    disc_opt.step()
    # Keep track of the average discriminator loss
    discriminator_losses += [disc_loss.item()]

    ## Update generator
    # Zero out the generator gradients
    gen_opt.zero_grad()

    # Re-score the fake images (not detached this time) so gradients flow back into the generator
    fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
    disc_fake_pred = disc(fake_image_and_labels)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    gen_loss.backward()
    gen_opt.step()

    # Keep track of the generator losses
    generator_losses += [gen_loss.item()]
    if cur_step % display_step == 0 and cur_step > 0:

      gen_mean = sum(generator_losses[-display_step:]) / display_step
      disc_mean = sum(discriminator_losses[-display_step:]) / display_step
      print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
      show_tensor_images(fake)
      show_tensor_images(real)
      # Plot both losses averaged over bins of step_bins steps
      step_bins = 20
      num_examples = (len(generator_losses) // step_bins) * step_bins
      plt.plot(
          range(num_examples // step_bins), 
          torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
          label="Generator Loss"
      )
      plt.plot(
          range(num_examples // step_bins), 
          torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
          label="Discriminator Loss"
      )
      plt.legend()
      plt.show()
    elif cur_step == 0:
        print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
    cur_step += 1
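Once the generated digits look good, the generator can be asked for a specific digit by fixing the one-hot part of its input and varying only the noise. `sample_digit` below is a hypothetical helper, not part of the original notebook; it assumes the generator takes the concatenated `(noise, one_hot)` vector as in this notebook. It switches to `eval()` so BatchNorm uses its running statistics, and disables gradient tracking since nothing is being trained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sample_digit(gen, digit, n_samples=25, z_dim=64, n_classes=10, device="cpu"):
    """Hypothetical helper: generate n_samples images all conditioned on `digit`."""
    was_training = gen.training
    gen.eval()  # BatchNorm uses running statistics instead of batch statistics
    with torch.no_grad():
        noise = torch.randn(n_samples, z_dim, device=device)
        labels = torch.full((n_samples,), digit, dtype=torch.long, device=device)
        one_hot = F.one_hot(labels, num_classes=n_classes).float()
        images = gen(torch.cat((noise, one_hot), dim=1))
    gen.train(was_training)  # restore the previous mode
    return images

# Usage with the trained generator from this notebook:
# show_tensor_images(sample_digit(gen, digit=7, device=device))
```

Keeping the noise fixed and changing only `digit` is a simple way to see whether the model has learned to separate class (from the one-hot code) from style (from the noise).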

