# Conditional GAN

**Purpose**: This notebook walks through the process of training a Conditional GAN to generate MNIST digits. Refer to the paper https://arxiv.org/abs/1411.1784 for a detailed explanation.

## Package import

In [None]:
!nvidia-smi

In [None]:
import numpy as np
import os
import torch

torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import torchvision.datasets as datasets
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torch.utils.data import DataLoader

## Global variable declaration

In [None]:
TRAIN_PARAMETERS = {'batch_size': 128,
                    'num_classes': 10,
                    'img_shape': (1,28,28),
                    'epochs': 200,
                    'learning_rate': 0.0002}

MODEL_HYPERPARAMETERS = {'generator_latent_dim': 100}

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data loading

In [None]:
def get_dataloader(batch_size,
                   img_shape):
    
    img_size = img_shape[1:]
    
    dataset = datasets.MNIST(root='./data/MNIST',
                             train=True,
                             download=True,
                             transform=transforms.Compose([transforms.Resize(img_size),
                                                           transforms.ToTensor(),
                                                           transforms.Normalize([0.5], [0.5])]))
    
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=True)

In [None]:
dataloader = get_dataloader(TRAIN_PARAMETERS['batch_size'],
                            TRAIN_PARAMETERS['img_shape'])
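`transforms.ToTensor()` scales 8-bit pixel values to [0, 1], and `transforms.Normalize([0.5], [0.5])` then computes `(x - 0.5) / 0.5`. The training images therefore end up in [-1, 1], the same range as the `nn.Tanh()` output of the Generator defined below. The plain-Python sketch below reproduces that arithmetic on single values. The helper names `to_tensor_scale`, `normalize` and `denormalize` are hypothetical and only mirror what the torchvision transforms compute.

```python
def to_tensor_scale(pixel):
    """Mimic ToTensor on one uint8 pixel: map 0..255 to 0.0..1.0."""
    return pixel / 255.0


def normalize(x, mean=0.5, std=0.5):
    """Mimic Normalize([mean], [std]) on a single value."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Invert normalize, e.g. to map Tanh outputs back to [0, 1] for display."""
    return y * std + mean


# Black, mid-gray and white pixels end up at roughly -1, 0 and 1
for pixel in (0, 128, 255):
    y = normalize(to_tensor_scale(pixel))
    print(pixel, round(y, 3))
```

In the notebook itself no explicit denormalization is needed for the saved samples, because `save_image(..., normalize=True)` rescales the generated grid to [0, 1] using its own minimum and maximum.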

## Model definition

Generative Adversarial Nets consist of two components competing against each other in a minimax game:

- The **Generator:** Captures the data distribution and tries to generate realistic samples accordingly.
- The **Discriminator:** Estimates the probability that a data sample is real rather than created by the Generator.

The objective of the Generator is to fool the Discriminator by generating increasingly realistic samples, while the objective of the Discriminator is to identify the samples created by the Generator.

The input of the **Generator** is typically a noise vector **z**, which introduces variety into the generated samples.

The input of the **Discriminator** is a sample of data **x**, which can be a real sample or a generated sample.

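Translated into mathematics, the Discriminator's parameters are adjusted to maximize $\log D(x) + \log(1 - D(G(z)))$, that is, to assign a high probability to real samples and a low probability to generated ones, while the Generator's parameters are adjusted to minimize $\log(1 - D(G(z)))$. This results in the following minimax objective: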

![GAN minimax objective](imgs/gan_training.png)

The **Conditional GAN** framework feeds some extra information **y** to both the Generator and the Discriminator (so the objective above must be rewritten to condition on **y**).

In this notebook, the extra information is the class of the sample (the digit 0 to 9). It guides the Generator in its creation and helps the Discriminator in its prediction.
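Written out, the standard GAN minimax objective is the first line below; conditioning both networks on **y** gives the second line, which is the form given in the Conditional GAN paper linked at the top of this notebook.

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]$$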

The model architecture is depicted in the next figure. The Generator concatenates the noise sample **z** with the condition **y**, and the Discriminator likewise concatenates **x** with **y**. In the code below, **y** is first mapped to a learned vector by an `nn.Embedding` layer before the concatenation.

![Conditional GAN architecture](imgs/conditional_gan_architecture.png)
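As a shape check for the concatenations in the architecture above, the sketch below builds both network inputs with NumPy, using the notebook's sizes (`latent_dim = 100`, 10 classes, images of shape (1, 28, 28)) and a small batch of 4. One assumption to flag: the notebook maps **y** through a learned `nn.Embedding(n_classes, n_classes)`, whereas here a fixed one-hot code of the same width (10) stands in for it. Because the widths match, the resulting sizes line up with the first `nn.Linear` layer of each model (`latent_dim + n_classes` and `n_classes + 28 * 28`).

```python
import numpy as np

batch_size, latent_dim, n_classes = 4, 100, 10
img_shape = (1, 28, 28)

rng = np.random.default_rng(0)
z = rng.normal(0.0, 1.0, size=(batch_size, latent_dim))    # noise z
y = rng.integers(0, n_classes, size=batch_size)            # class labels y
x = rng.uniform(-1.0, 1.0, size=(batch_size, *img_shape))  # stand-in images x

# One-hot code for y; a stand-in for the learned nn.Embedding used in the notebook
y_onehot = np.eye(n_classes)[y]                            # shape (4, 10)

# Generator input: [label code, noise], same order as Generator.forward
gen_input = np.concatenate([y_onehot, z], axis=-1)         # shape (4, 110)

# Discriminator input: [flattened image, label code], same order as Discriminator.forward
disc_input = np.concatenate([x.reshape(batch_size, -1), y_onehot], axis=-1)  # shape (4, 794)

print(gen_input.shape, disc_input.shape)
```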

In [None]:
class Generator(nn.Module):
    def __init__(self,
                 n_classes,
                 img_shape,
                 latent_dim):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.img_shape = img_shape
        
        def block(in_feat, out_feat, normalize=False):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 256, normalize=False),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and noise vector to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [None]:
class Discriminator(nn.Module):
    def __init__(self,
                 n_classes,
                 img_shape):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(n_classes,
                                            n_classes)

        self.model = nn.Sequential(
            nn.Linear(n_classes + int(np.prod(img_shape)), 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

In [None]:
def get_generator_model(num_classes,
                        img_shape,
                        latent_dim,
                        lr,
                        device):
    ''' Returns the generator model and its optimizer '''
    
    generator = Generator(num_classes,
                          img_shape,
                          latent_dim)
    
    optimizer = torch.optim.Adam(generator.parameters(),
                                 lr=lr)
    
    return generator.to(device), optimizer

In [None]:
def get_discriminator_model(num_classes,
                            img_shape,
                            lr,
                            device):
    ''' Returns the discriminator model and its optimizer '''
    
    discriminator = Discriminator(num_classes,
                                  img_shape)
    
    optimizer = torch.optim.Adam(discriminator.parameters(),
                                 lr=lr)
    
    return discriminator.to(device), optimizer

In [None]:
generator, optimizer_G = get_generator_model(TRAIN_PARAMETERS['num_classes'],
                                             TRAIN_PARAMETERS['img_shape'],
                                             MODEL_HYPERPARAMETERS['generator_latent_dim'],
                                             TRAIN_PARAMETERS['learning_rate'],
                                             DEVICE)
discriminator, optimizer_D = get_discriminator_model(TRAIN_PARAMETERS['num_classes'],
                                                     TRAIN_PARAMETERS['img_shape'],
                                                     TRAIN_PARAMETERS['learning_rate'],
                                                     DEVICE)

In [None]:
def get_adversarial_loss():
    ''' Returns the adversarial loss '''
    
    return torch.nn.BCEWithLogitsLoss()

In [None]:
adversarial_loss = get_adversarial_loss()
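The Discriminator ends in `nn.Linear(256, 1)` without a sigmoid, so it outputs a raw logit $l$, and `BCEWithLogitsLoss` applies the sigmoid internally. PyTorch documents this combined form as more numerically stable than a separate `Sigmoid` followed by `BCELoss`. The pure-Python sketch below illustrates why, comparing a naive sigmoid-then-BCE with the algebraically equivalent form $\max(l, 0) - l\,t + \log(1 + e^{-|l|})$. Both function names are hypothetical, and this is an illustration of the idea rather than PyTorch's actual implementation.

```python
import math


def bce_naive(logit, target):
    """Sigmoid followed by binary cross-entropy, computed directly."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def bce_with_logits(logit, target):
    """Same loss rearranged so exp() never receives a positive argument."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Both forms agree for moderate logits ...
for logit in (-3.0, 0.0, 2.5):
    for target in (0.0, 1.0):
        assert abs(bce_naive(logit, target) - bce_with_logits(logit, target)) < 1e-9

# ... but for a confident logit, 1 - sigmoid(40) rounds to 0.0 and the naive
# form would call log(0.0); the rearranged form stays finite (about 40)
print(bce_with_logits(40.0, 0.0))
```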

## Model training

As mentioned earlier, the training of the model is an **adversarial process** in which the Generator and the Discriminator are trained in alternation: each iteration first updates the Discriminator and then the Generator.

The process is adversarial because the Discriminator is trained to distinguish real data from fake data, and therefore to recognize the Generator's flaws, while the Generator is trained to fool the Discriminator with its generated samples.

Nonetheless, this process is **highly unstable**. Some of the problems that may occur are:

- **Vanishing gradients**: When the Discriminator becomes too good, it rejects generated samples with near certainty, the Generator's loss saturates, and the gradients flowing back to the Generator become too small for it to learn.
- **Mode collapse**: The Generator may find specific samples that consistently fool the Discriminator. When this happens, it starts producing only these samples and loses its variety.
- **Failure to converge**: If the loss function is poorly designed or the weight updates are too large, the two models may never converge to a good solution.
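The vanishing-gradient problem can be seen on a single number. Let $l$ be the Discriminator's logit on a generated sample, so $D(G(z)) = \sigma(l)$. The minimax Generator loss $\log(1 - \sigma(l))$ has derivative $-\sigma(l)$ with respect to $l$, which goes to 0 when the Discriminator confidently rejects the sample ($l \ll 0$). The non-saturating alternative proposed in the original GAN paper, minimizing $-\log \sigma(l)$, has derivative $-(1 - \sigma(l))$, which stays close to $-1$ in that regime. Binary cross-entropy against a target of 1 is exactly this non-saturating loss. The sketch below checks both derivatives numerically with plain Python; the function names are illustrative only.

```python
import math


def sigmoid(l):
    return 1.0 / (1.0 + math.exp(-l))


def saturating_loss(l):
    """Minimax Generator loss log(1 - D(G(z)))."""
    return math.log(1.0 - sigmoid(l))


def non_saturating_loss(l):
    """Non-saturating Generator loss -log D(G(z)), i.e. BCE with target 1."""
    return -math.log(sigmoid(l))


def numerical_grad(f, l, h=1e-5):
    """Central finite-difference estimate of df/dl."""
    return (f(l + h) - f(l - h)) / (2.0 * h)


# More negative logit = Discriminator is more certain the sample is fake
for l in (0.0, -5.0, -10.0):
    g_sat = numerical_grad(saturating_loss, l)
    g_ns = numerical_grad(non_saturating_loss, l)
    print(f"logit={l:6.1f}  saturating grad={g_sat:.2e}  non-saturating grad={g_ns:.2e}")
```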

In [None]:
@torch.no_grad()
def save_reconstruction(generator,
                        generator_latent_dim,
                        n_row,
                        epoch_n,
                        device):
    """Saves an n_row x n_row grid of generated digits; each row holds the labels 0 to n_row - 1"""

    # Create the saving directory
    os.makedirs('results', exist_ok=True)
    # Sample noise
    z = torch.FloatTensor(np.random.normal(0, 1, (n_row ** 2, generator_latent_dim))).to(device)
    # Get labels ranging from 0 to n_row - 1 for each of the n_row rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = torch.LongTensor(labels).to(device)
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "results/%d.png" % epoch_n, nrow=n_row, normalize=True)
    
    return plt.imread("results/%d.png" % epoch_n)

def init_figure():
    """Init interactive figure with 3 subplots for the notebook graphics"""
    fig = plt.figure(figsize=(10, 12))
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(212)

    return fig, (ax1, ax2, ax3)


def plot_current_results(image, model_loss, fig, axes):
    """Use the visuals dict to fill the fig and axes"""
    ax1, ax2, ax3 = axes

    ax1.clear()
    ax2.clear()
    ax3.clear()

    epochs = range(1, len(model_loss['generator'])+1)
    
    ax1.plot(epochs, model_loss['generator'])
    ax1.set_title('Generator Loss')

    ax2.plot(epochs, model_loss['discriminator'])
    ax2.set_title('Discriminator Loss')

    ax3.imshow(image)
    ax3.set_xticks([])
    ax3.set_yticks([])
    plt.show()
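The label layout built in `save_reconstruction` determines what the sample grid shows. Since `save_image` places `nrow` images in each row, every row of the grid contains the digits 0 to `n_row - 1`, and every column repeats a single digit with different noise vectors. Reading down a column therefore shows the variety the Generator produces for one class. The pure-Python sketch below mimics that layout with a small hypothetical `n_row = 3` (the notebook passes `n_classes`, i.e. 10).

```python
n_row = 3  # hypothetical small value; the notebook uses n_classes (10)

# Same comprehension as in save_reconstruction: the inner loop varies fastest,
# so each run of n_row consecutive labels is 0, 1, ..., n_row - 1
labels = [num for _ in range(n_row) for num in range(n_row)]

# save_image(..., nrow=n_row) lays the images out row by row, n_row per row
grid = [labels[r * n_row:(r + 1) * n_row] for r in range(n_row)]
for row in grid:
    print(row)
```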

In [None]:
def plot_loss(loss_dict):
    ''' Plots the loss evolution of the discriminator and generator '''
    
    for component_name, component_loss in loss_dict.items():
        
        plt.plot(component_loss, label=component_name)
    
    plt.title('Loss plot')
    plt.legend()
    plt.show()

## Exercise

- As an exercise, complete the code that computes the losses of both the Discriminator and the Generator.

In [None]:
def train_model(epochs,
                dataloader,
                generator,
                optimizer_G,
                discriminator,
                optimizer_D,
                adv_loss,
                generator_latent_dim,
                n_classes,
                device):
    ''' Trains the GAN model '''
    
    model_loss = {'generator': [], 'discriminator': []}
    
    for epoch in range(epochs):
        
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
        
        running_g_loss = running_d_loss = 0
        
        for x, y in dataloader:
            
            x, y = x.to(device), y.to(device)
            
            batch_size = x.shape[0]
            
            # Targets for the adversarial loss: 1 for real samples, 0 for generated ones
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)
            
            ## -------------------
            ## Train discriminator
            ## -------------------
            
            optimizer_D.zero_grad()
            
            # Sample noise and labels, then generate a batch of fake images
            z = torch.FloatTensor(np.random.normal(0, 1, (batch_size, generator_latent_dim))).to(device)
            gen_y = torch.LongTensor(np.random.randint(0, n_classes, batch_size)).to(device)
            gen_x = generator(z, gen_y)
            
            validity_fake = ... # Implement me!
            validity_real = ... # Implement me!
            d_real_loss = ... # Implement me!
            d_fake_loss = ... # Implement me!
            
            d_loss = ... # Implement me!
            
            d_loss.backward()
            optimizer_D.step()
            
            ## ---------------
            ## Train generator
            ## ---------------
            
            optimizer_G.zero_grad()
            
            # Sample noise and labels
            z = torch.FloatTensor(np.random.normal(0, 1, (batch_size, generator_latent_dim))).to(device)
            gen_y = torch.LongTensor(np.random.randint(0, n_classes, batch_size)).to(device)
            gen_x = generator(z, gen_y)

            validity = ... # Implement me!
            g_loss = ... # Implement me!

            g_loss.backward()
            optimizer_G.step()
            
            running_g_loss += g_loss.item()
            
            running_d_loss += d_loss.item()
            
        epoch_g_loss = running_g_loss / len(dataloader)
        epoch_d_loss = running_d_loss / len(dataloader)
        
        print('G Loss: {:.4f} D Loss: {:.4f}'.format(epoch_g_loss,
                                                     epoch_d_loss))
        
        model_loss['generator'].append(epoch_g_loss)
        model_loss['discriminator'].append(epoch_d_loss)
        
        if epoch % 1 == 0:
            generator.eval()

            image = save_reconstruction(generator,
                                        generator_latent_dim,
                                        n_classes,
                                        epoch,
                                        device)
            fig, axs = init_figure()
            plot_current_results(image, model_loss, fig, axs)
            generator.train()
    
    plot_loss(model_loss)

In [None]:
train_model(TRAIN_PARAMETERS['epochs'],
            dataloader,
            generator,
            optimizer_G,
            discriminator,
            optimizer_D,
            adversarial_loss,
            MODEL_HYPERPARAMETERS['generator_latent_dim'],
            TRAIN_PARAMETERS['num_classes'],
            DEVICE)

We just implemented a Conditional GAN using the Binary Cross Entropy loss! 

GANs are an intensely studied topic, and researchers have designed many different variants. One of them, [LSGAN](https://arxiv.org/abs/1611.04076), will be of interest later in the workshop.

Without going into details for now, the distinctive feature of LSGAN is that it uses a Mean Squared Error (least-squares) loss instead of the Binary Cross Entropy loss.
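For reference, the LSGAN paper writes the two objectives as follows. Here $a$ and $b$ are the target values for fake and real samples, and $c$ is the value the Generator wants the Discriminator to output for fakes; with $a = 0$ and $b = c = 1$ these are the same 0/1 targets (`fake` and `valid`) used in `train_model`. For the conditional model, $D(x)$ and $G(z)$ become $D(x \mid y)$ and $G(z \mid y)$. Because the Discriminator in this notebook has no final sigmoid, its raw output can be plugged directly into these squared errors.

$$\min_D V_{\text{LSGAN}}(D) = \frac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}(x)}\big[(D(x) - b)^2\big] + \frac{1}{2}\,\mathbb{E}_{z \sim p_z(z)}\big[(D(G(z)) - a)^2\big]$$

$$\min_G V_{\text{LSGAN}}(G) = \frac{1}{2}\,\mathbb{E}_{z \sim p_z(z)}\big[(D(G(z)) - c)^2\big]$$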

### Exercise

- As an exercise, implement LSGAN by replacing the Binary Cross Entropy loss with the Mean Squared Error loss.

In [None]:
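# Re-create the models so that LSGAN training starts from scratch
generator, optimizer_G = get_generator_model(TRAIN_PARAMETERS['num_classes'],
                                             TRAIN_PARAMETERS['img_shape'],
                                             MODEL_HYPERPARAMETERS['generator_latent_dim'],
                                             TRAIN_PARAMETERS['learning_rate'],
                                             DEVICE)
discriminator, optimizer_D = get_discriminator_model(TRAIN_PARAMETERS['num_classes'],
                                                     TRAIN_PARAMETERS['img_shape'],
                                                     TRAIN_PARAMETERS['learning_rate'],
                                                     DEVICE)
adversarial_loss = ... #Implement me!
train_model(TRAIN_PARAMETERS['epochs'],
            dataloader,
            generator,
            optimizer_G,
            discriminator,
            optimizer_D,
            adversarial_loss,
            MODEL_HYPERPARAMETERS['generator_latent_dim'],
            TRAIN_PARAMETERS['num_classes'],
            DEVICE)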