# Conditional GAN

**Purpose**: This notebook provides a walk through the process of training a Conditional GAN to generate digits on the MNIST dataset.

## Package import

In [1]:
import numpy as np
import torch
import torchvision.datasets as datasets
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

## Global variable declaration

In [2]:
TRAIN_PARAMETERS = {'batch_size': 8,
                    'num_classes': 10,
                    'img_shape': (1,28,28),
                    'epochs': 200,
                    'learning_rate': 0.0002}

MODEL_HYPERPARAMETERS = {'generator_latent_dim': 100}

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data loading

In [3]:
def get_dataloader(batch_size,
                   img_shape):
    
    img_size = img_shape[1:]
    
    dataset = datasets.MNIST(root='./data/MNIST',
                             train=True,
                             download=True,
                             transform=transforms.Compose([transforms.Resize(img_size),
                                                           transforms.ToTensor(),
                                                           transforms.Normalize([0.5], [0.5])]))
    
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=True)

In [4]:
dataloader = get_dataloader(TRAIN_PARAMETERS['batch_size'],
                            TRAIN_PARAMETERS['img_shape'])

## Model definition

In [5]:
class Generator(nn.Module):
    def __init__(self,
                 n_classes,
                 img_shape,
                 latent_dim):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.img_shape = img_shape
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Concatenate label embedding and image to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [6]:
class Discriminator(nn.Module):
    def __init__(self,
                 n_classes,
                 img_shape):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(n_classes,
                                            n_classes)

        self.model = nn.Sequential(
            nn.Linear(n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

In [7]:
def get_generator_model(num_classes,
                        img_shape,
                        latent_dim,
                        lr,
                        device):
    ''' Returns the generator model and its optimizer '''
    
    generator = Generator(num_classes,
                          img_shape,
                          latent_dim)
    
    optimizer = torch.optim.Adam(generator.parameters(),
                                 lr=lr)
    
    return generator.to(device), optimizer

In [8]:
def get_discriminator_model(num_classes,
                            img_shape,
                            lr,
                            device):
    ''' Returns the discriminator model and its optimizer '''
    
    discriminator = Discriminator(num_classes,
                                  img_shape)
    
    optimizer = torch.optim.Adam(discriminator.parameters(),
                                 lr=lr)
    
    return discriminator.to(device), optimizer

In [9]:
generator, optimizer_G = get_generator_model(TRAIN_PARAMETERS['num_classes'],
                                             TRAIN_PARAMETERS['img_shape'],
                                             MODEL_HYPERPARAMETERS['generator_latent_dim'],
                                             TRAIN_PARAMETERS['learning_rate'],
                                             DEVICE)
discriminator, optimizer_D = get_discriminator_model(TRAIN_PARAMETERS['num_classes'],
                                                     TRAIN_PARAMETERS['img_shape'],
                                                     TRAIN_PARAMETERS['learning_rate'],
                                                     DEVICE)

In [10]:
def get_adversarial_loss():
    ''' Returns the adversarial loss '''
    
    return torch.nn.MSELoss()

In [11]:
adversarial_loss = get_adversarial_loss()

## Model training

In [12]:
def train_model(epochs,
                dataloader,
                generator,
                optimizer_G,
                discriminator,
                optimizer_D,
                adv_loss,
                generator_latent_dim,
                n_classes,
                device):
    ''' Trains the GAN model '''

    for epoch in range(epochs):
        
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
        
        running_g_loss = running_d_loss = 0
        
        for x, y in dataloader:
            
            x, y = x.to(device), y.to(device)
            
            batch_size = x.shape[0]
            
            valid = torch.FloatTensor(batch_size, 1).fill_(1.0).to(device)
            fake = torch.FloatTensor(batch_size, 1).fill_(0.0).to(device)
        
            ## ---------------
            ## Train generator
            ## ---------------
            
            optimizer_G.zero_grad()
            
            # Sample noise and labels
            z = torch.FloatTensor(np.random.normal(0, 1, (batch_size, generator_latent_dim))).to(device)
            gen_y = torch.LongTensor(np.random.randint(0, n_classes, batch_size)).to(device)
            gen_x = generator(z, gen_y)

            validity = discriminator(gen_x, gen_y)
            g_loss = adv_loss(validity, valid)
            

            g_loss.backward()
            optimizer_G.step()
            
            running_g_loss += g_loss.item()
            
            ## -------------------
            ## Train discriminator
            ## -------------------
            
            optimizer_D.zero_grad()
            
            validity_real = discriminator(x, y)
            d_real_loss = adv_loss(validity_real, valid)
            
            validity_fake = discriminator(gen_x.detach(), gen_y)
            d_fake_loss = adversarial_loss(validity_fake, fake)
            
            d_loss = (d_real_loss + d_fake_loss) / 2
            
            d_loss.backward()
            optimizer_D.step()
            
            running_d_loss += d_loss.item()
            
        epoch_g_loss = running_g_loss / len(dataloader)
        epoch_d_loss = running_d_loss / len(dataloader)

        print('G Loss: {:.4f} D Loss: {:.4f}'.format(epoch_g_loss,
                                                     epoch_d_loss))

In [13]:
train_model(TRAIN_PARAMETERS['epochs'],
            dataloader,
            generator,
            optimizer_G,
            discriminator,
            optimizer_D,
            adversarial_loss,
            MODEL_HYPERPARAMETERS['generator_latent_dim'],
            TRAIN_PARAMETERS['num_classes'],
            DEVICE)

Epoch 0/199
----------
G Loss: 0.7680 D Loss: 0.0767
Epoch 1/199
----------
G Loss: 0.5689 D Loss: 0.1437
Epoch 2/199
----------


KeyboardInterrupt: 