# Conditional GAN

![alt text](https://i.imgur.com/jgtlRHS.png)

In [None]:
!pip install torch torchvision matplotlib

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In this cell we initialize the data loader for the MNIST dataset, which will provide data to the training loop later.

Here, we also specify the size of the minibatch.

NOTE: To keep our network simple and the training time short, we keep only the samples of classes 0, 1 and 2.

In [None]:
batch_size = 50

train_data = datasets.MNIST('data/mnist', train=True, download=True,
                            transform=transforms.ToTensor())

# Keep only the samples whose label is in `classes`.
# `train_data.data` / `train_data.targets` are the current attribute names:
# the old `train_data` / `train_labels` aliases are read-only properties in
# recent torchvision releases, so assigning to them raises AttributeError.
# A boolean mask is also far faster than a Python loop over every label tensor.
classes = [0, 1, 2]
keep_mask = torch.zeros_like(train_data.targets, dtype=torch.bool)
for cls in classes:
    keep_mask |= train_data.targets == cls

train_data.data = train_data.data[keep_mask]
train_data.targets = train_data.targets[keep_mask]

# drop_last=True guarantees every minibatch has exactly `batch_size` samples.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True, drop_last=True)

In this cell we specify dimensions of vectors for:
* `X_len` - linearized images (MNIST contains images of size $28 \times 28 = 784$)
* `z_len` - latent (noise) vector. Here, 64 is used, but you can try smaller or larger vectors
* `c_len` - code vector, the same length as the number of classes in our dataset

In [None]:
# Length of a flattened MNIST image (28 x 28 pixels).
X_len = 28 ** 2
# Length of the latent noise vector z fed to the generator.
z_len = 64
# Length of the one-hot class code c (one entry per class kept in the dataset).
c_len = 3

Model of a Generator $G: (z, c) \rightarrow X$ - takes a latent vector $z$ sampled from the unit Gaussian distribution, $z \sim \mathcal{N}(0, I) \in \mathbb{R}^{z_{\text{len}}}$, and a one-hot code vector $c \in \{0, 1\}^{c_{\text{len}}}$ (one-hot: has value `1` at exactly one position). It produces a linearized image $X \in [0, 1]^{784}$.







In [None]:
class Generator(nn.Module):
  """Map a latent vector ``z`` and a one-hot class code ``c`` to a flat image.

  Args:
    X_dim: length of the produced (flattened) image.
    c_dim: length of the one-hot class code.
    z_dim: length of the latent noise vector.
  """

  def __init__(self, X_dim, c_dim, z_dim):
    super().__init__()
    hidden = 128
    # z and c are concatenated, so the first layer sees z_dim + c_dim inputs;
    # the final Sigmoid squashes every output value into [0, 1].
    self.model = nn.Sequential(
      nn.Linear(z_dim + c_dim, hidden),
      nn.ReLU(),
      nn.Linear(hidden, X_dim),
      nn.Sigmoid(),
    )

  def forward(self, z, c):
    """Return generated images, one row of length X_dim per input row.

    Assumes z is (batch, z_dim) and c is (batch, c_dim) — they are joined
    along dim 1.
    """
    return self.model(torch.cat((z, c), dim=1))

Model of a Discriminator $D: (X, c) \rightarrow [0, 1]$ - a module that (given the code vector) should assign a high probability $p$ to samples from the training dataset and a low probability $p$ to generated samples.


In [None]:
class Discriminator(nn.Module):
  """Score how likely a flattened image ``X`` is real, given its class code ``c``.

  Args:
    X_dim: length of the flattened input image.
    c_dim: length of the one-hot class code.
  """

  def __init__(self, X_dim, c_dim):
    super().__init__()
    hidden = 128
    layers = [
      nn.Linear(X_dim + c_dim, hidden),
      nn.ReLU(),
      nn.Linear(hidden, 1),
      nn.Sigmoid(),  # a single probability in [0, 1] per row
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, X, c):
    """Return one probability per input row.

    Assumes X is (batch, X_dim) and c is (batch, c_dim) — they are joined
    along dim 1.
    """
    joint = torch.cat((X, c), dim=1)
    return self.model(joint)

In [None]:
# Build both networks (G first, then D) and move their parameters to `device`.
G = Generator(X_dim=X_len, c_dim=c_len, z_dim=z_len)
D = Discriminator(X_dim=X_len, c_dim=c_len)
G, D = G.to(device), D.to(device)

Weight initialization - we use the weight initialization from "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" by He et al.

In [None]:
def weights_init(m):
    """Apply He (Kaiming) normal initialization to Linear and Conv1d layers.

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    submodule. The weights are redrawn with ``kaiming_normal_`` using the ReLU
    gain, and biases (when present) are set to zero. Any other module type is
    left untouched.

    Args:
        m: the module currently visited by ``apply``.
    """
    # isinstance (rather than comparing class-name strings) also covers
    # subclasses of these layers and cannot match an unrelated class that
    # merely happens to be named 'Linear' (which may have no `.weight`).
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

# Run weights_init on every submodule of both networks (G first, then D).
# Module.apply returns the module itself, so G and D stay bound to the same objects.
G, D = G.apply(weights_init), D.apply(weights_init)

Define the optimizers that will compute the optimization steps for our weights. Note that the Generator and the Discriminator each have their own optimizer, and each optimizer updates only its own network's parameters. Here we use Adam from "Adam: A Method for Stochastic Optimization" by Kingma et al.

In [None]:
learning_rate = 3e-4
betas = (0.9, 0.999)

# G and D are trained by separate Adam instances that share hyperparameters.
adam_kwargs = dict(lr=learning_rate, betas=betas)
G_optimizer = optim.Adam(G.parameters(), **adam_kwargs)
D_optimizer = optim.Adam(D.parameters(), **adam_kwargs)

Labels for Discriminator: 

1 - true sample

0 - generated samples

Labels for Generator:

1 - generated samples

In [None]:
# Constant targets for one minibatch, created directly on `device`:
# `ones` marks "real" and `zeros` marks "generated".
ones = torch.ones(batch_size, 1, device=device)
zeros = torch.zeros(batch_size, 1, device=device)

Losses:

$L_D = \text{bce}(D(X, c), 1) + \text{bce}(D(G(z, c), c), 0)$

$L_G = \text{bce}(D(G(z, c), c), 1)$

where `bce` is the binary cross-entropy (available in the `torch.nn` module), `1` is a vector of ones, and `0` is a vector of zeros.

In [None]:
def loss_fn_g(p_gen):
  """Generator loss: L_G = bce(D(G(z, c), c), 1).

  The loss is low when the discriminator scores generated samples as real.

  Args:
    p_gen: discriminator probabilities for generated samples, values in [0, 1].

  Returns:
    Scalar tensor: mean binary cross-entropy of ``p_gen`` against all-ones targets.
  """
  # ones_like keeps the target's shape/device/dtype in sync with the input,
  # so the loss works for any batch size.
  return nn.functional.binary_cross_entropy(p_gen, torch.ones_like(p_gen))


def loss_fn_d(p_real, p_gen):
  """Discriminator loss: L_D = bce(D(X, c), 1) + bce(D(G(z, c), c), 0).

  Args:
    p_real: discriminator probabilities for real training samples.
    p_gen: discriminator probabilities for generated samples.

  Returns:
    Scalar tensor: the sum of the mean BCE on real samples (target 1) and
    the mean BCE on generated samples (target 0).
  """
  bce = nn.functional.binary_cross_entropy
  real_term = bce(p_real, torch.ones_like(p_real))
  gen_term = bce(p_gen, torch.zeros_like(p_gen))
  return real_term + gen_term

Training procedure for Conditional GAN - fill the training steps, that are currently `None`

In [None]:
max_epochs = 300
n_samples = 10  # generated samples per class shown after each epoch

for epoch_n in range(1, max_epochs + 1):

  D.train()
  G.train()

  d_losses = 0.0
  g_losses = 0.0
  n_batches = 0

  start = datetime.now()
  for X, c in train_dataloader:
    X = X.to(device)
    c = c.to(device)
    batch = X.size(0)

    # Flatten each image into a vector of length X_len.
    X = X.view(batch, -1)

    # convert c to one-hot representation (raises if a label >= c_len)
    c_one_hot = nn.functional.one_hot(c, num_classes=c_len).float()

    # Sample latent codes and generate images conditioned on the same labels.
    z = torch.randn(batch, z_len, device=device)
    X_ = G(z, c_one_hot)

    # --- Discriminator step ---
    # Detach the fakes so the D loss does not backpropagate into G.
    p_real = D(X, c_one_hot)
    p_gen = D(X_.detach(), c_one_hot)

    D_optimizer.zero_grad()
    loss_d = loss_fn_d(p_real, p_gen)
    loss_d.backward()
    D_optimizer.step()
    d_losses += loss_d.item()

    # --- Generator step ---
    # Re-score the fakes with the updated D. Reusing the earlier p_gen graph
    # (via retain_graph=True) fails because D_optimizer.step() modified D's
    # weights in place, which autograd detects and rejects.
    p_gen = D(X_, c_one_hot)

    G_optimizer.zero_grad()
    loss_g = loss_fn_g(p_gen)
    loss_g.backward()
    G_optimizer.step()
    g_losses += loss_g.item()

    n_batches += 1

  # Guard against an empty loader so the average never divides by zero.
  denom = max(n_batches, 1)
  print(f'Epoch {epoch_n:03d}: Loss_G: {g_losses / denom:.4f} Loss_D: {d_losses / denom:.4f}  Time: {datetime.now() - start}')

  # Visualize learning: one row per class, n_samples generated images per row.
  G.eval()
  with torch.no_grad():
    # squeeze=False keeps `ax` 2-D even when c_len or n_samples is 1.
    fig, ax = plt.subplots(c_len, n_samples, figsize=(10, c_len), squeeze=False)
    fig.suptitle(f'Conditional samples: {epoch_n}')
    for label in range(c_len):
      c_one_hot = torch.zeros((n_samples, c_len), device=device)
      c_one_hot[:, label] = 1

      z = torch.randn(n_samples, z_len, device=device)

      samples = G(z, c_one_hot).view(n_samples, 28, 28).cpu().numpy()
      for col, sample in enumerate(samples):
        ax[label][col].imshow(sample, cmap='gray')
        ax[label][col].axis('off')
    plt.show()
    # Release the figure so 300 epochs do not accumulate open figures.
    plt.close(fig)