# InfoGAN

![alt text](https://i.imgur.com/zAopCOi.png)

In [None]:
!pip install matplotlib numpy torch torchvision

In [None]:
from datetime import datetime
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In this cell we initialize the data loader for the MNIST dataset, which will later provide data to the training loop.

Here we also specify the size of the minibatch.

In [None]:
# Number of images per minibatch.
batch_size = 50

# MNIST training split, stored under data/mnist (downloaded if missing).
train_data = datasets.MNIST(
  root='data/mnist',
  train=True,
  download=True,
  transform=transforms.ToTensor(),
)

# Shuffled loader over the training split. drop_last=True discards a final
# incomplete batch, so every batch holds exactly `batch_size` images.
train_dataloader = DataLoader(
  dataset=train_data,
  batch_size=batch_size,
  shuffle=True,
  num_workers=4,
  pin_memory=True,
  drop_last=True,
)

In this cell we specify dimensions of vectors for:
* `X_len` - linearized images (MNIST contains images of size $28 \times 28 = 784$)
* `z_len` - encoding vector. Here, 16 is used, but you can try smaller or larger vectors (especially if the images are of poor quality)
* `c_len` - code vector, which we chose to be a categorical variable represented as a one-hot vector of length 10

In [None]:
# Vector sizes shared by the models below.
X_len = 28 ** 2  # pixels in one linearized 28 x 28 image
z_len = 16       # length of the encoding vector z
c_len = 10       # length of the one-hot code vector c

Generator

Model of a Generator $G: (z, c) \rightarrow X$ - takes a feature vector $z$ sampled from a unit Gaussian distribution $z \sim \mathcal{N}(0, I)  \in \mathbb{R}^{\text{z_len}}$ and a one-hot code vector $c \in \{0, 1\}^{\text{c_len}}$ (one-hot - has value `1` at exactly one position). It produces a linearized image $X \in [0, 1]^{784}$.







In [None]:
class Generator(nn.Module):
  """Two-layer MLP generator producing a linearized image from ``(z, c)``.

  Args:
    X_dim: length of the generated (linearized) image vector.
    z_dim: length of the feature vector ``z``.
    c_dim: length of the code vector ``c``.
  """

  def __init__(self, X_dim, z_dim, c_dim):
    super().__init__()

    hidden_dim = 128
    layers = [
      nn.Linear(z_dim + c_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, X_dim),
      nn.Sigmoid(),  # squashes every output value into the unit interval
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, z, c):
    """Concatenate ``z`` and ``c`` along dim 1 and map the result to an image."""
    return self.model(torch.cat((z, c), dim=1))

Model of a Discriminator $D: X \rightarrow [0, 1]$ - module that should give high probability $p$ for samples from training dataset and low probability $p$ for generated samples.


In [None]:
class Discriminator(nn.Module):
  """Two-layer MLP discriminator scoring a linearized image with a probability.

  Args:
    X_dim: length of the (linearized) input image vector.
  """

  def __init__(self, X_dim):
    super().__init__()

    hidden_dim = 128
    layers = [
      nn.Linear(X_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, 1),
      nn.Sigmoid(),  # turns the single logit into a probability
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, X):
    """Return one probability per input row."""
    return self.model(X)

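Model of a mutual information enforcer $Q : X \rightarrow c$. Given an image $X \in [0, 1]^{784}$, it predicts a probability distribution over the code values, $Q(c \mid X) \in [0, 1]^{\text{c_len}}$ (the softmax output). It is trained to recover the one-hot code $c$ that was used to generate $X$, which maximizes the mutual information between the code and the generated image.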

In [None]:
class Q(nn.Module):
  """Two-layer MLP that predicts a distribution over codes from an image.

  Args:
    X_dim: length of the (linearized) input image vector.
    c_dim: number of code categories, i.e. length of the output vector.
  """

  def __init__(self, X_dim, c_dim):
    super().__init__()

    self.model = nn.Sequential(
      nn.Linear(X_dim, 128),
      nn.ReLU(),
      nn.Linear(128, c_dim),
      # Normalize over the code dimension explicitly; leaving `dim` unset
      # relies on a deprecated implicit choice and triggers a warning.
      nn.Softmax(dim=-1)
    )

  def forward(self, X):
    """Return code probabilities for ``X``; they sum to 1 along the last dim."""
    c = self.model(X)
    return c

In [None]:
# Build the three networks and move each to the selected device.
# The construction order (G, D, Q_) is kept so that random-number consumption is unchanged.
G = Generator(X_len, z_len, c_len)
G = G.to(device)
D = Discriminator(X_len)
D = D.to(device)
Q_ = Q(X_len, c_len)
Q_ = Q_.to(device)

Weight initialization - we use the weight initialization from "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" by He et al.

In [None]:
def weights_init(m):
    """Initialize a ``Conv1d`` or ``Linear`` module in place; ignore others.

    The module is matched by its exact class name. Matching modules get
    Kaiming-normal weights (ReLU gain) and, if a bias exists, a zero bias.
    Intended to be passed to ``nn.Module.apply``.
    """
    if type(m).__name__ not in ('Conv1d', 'Linear'):
        return
    torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    if m.bias is None:
        return
    torch.nn.init.zeros_(m.bias)

# Re-initialize every submodule of the three networks. `apply` works in place
# and returns the module itself, so no rebinding is required.
for _net in (G, D, Q_):
    _net.apply(weights_init)

Define the optimizers that compute the optimization steps for our weights. Note that `Q_optimizer` updates the parameters of both $Q$ and the Generator. Here we use Adam from "Adam: A Method for Stochastic Optimization" by Kingma et al.

In [None]:
# Adam hyper-parameters shared by all three optimizers.
learning_rate = 3e-4
betas = (0.9, 0.999)

# One optimizer each for G and D; Q_optimizer updates Q_ and G together.
G_optimizer, D_optimizer, Q_optimizer = (
  optim.Adam(params, lr=learning_rate, betas=betas)
  for params in (
    G.parameters(),
    D.parameters(),
    chain(Q_.parameters(), G.parameters()),
  )
)

Losses:

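$L_Q = -\sum_i c_i \log Q(c_i \mid X) \approx H(c \mid X)$ (minimized; since $I(c; X) = H(c) - H(c \mid X)$ and $H(c)$ is constant, minimizing $L_Q$ maximizes a lower bound on the mutual information)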

$L_G = \log(D(G(z, c)))$

$L_D = \log(D(X)) + \log(1 - D(G(z, c)))$

In [None]:
def loss_fn_d(p_real, p_gen):
  """Per-sample discriminator objective ``log D(X) + log(1 - D(G(z, c)))``.

  Args:
    p_real: discriminator probabilities for real samples.
    p_gen: discriminator probabilities for generated samples.

  Returns:
    Tensor with the element-wise objective. The objective is to be
    maximized, so callers negate its mean to obtain a loss.
  """
  eps = 1e-8  # keeps log() finite when a probability saturates at 0 or 1
  return torch.log(p_real + eps) + torch.log(1 - p_gen + eps)

def loss_fn_g(p_gen):
  """Per-sample generator objective ``log D(G(z, c))``.

  Args:
    p_gen: discriminator probabilities for generated samples.

  Returns:
    Tensor with the element-wise objective. The objective is to be
    maximized, so callers negate its mean to obtain a loss.
  """
  eps = 1e-8  # keeps log() finite when the probability is exactly 0
  return torch.log(p_gen + eps)
  
def loss_fn_q(c, q_c_given_x):
  """Per-sample cross-entropy ``-sum_i c_i * log Q(c_i | X)``.

  Args:
    c: code vectors (one-hot), with categories along the last dimension.
    q_c_given_x: predicted code probabilities, same shape as ``c``.

  Returns:
    Tensor with one loss value per sample (the last dimension is summed
    out). The value is to be minimized.
  """
  eps = 1e-8  # keeps log() finite when a predicted probability is exactly 0
  return -torch.sum(c * torch.log(q_c_given_x + eps), dim=-1)

In [None]:
# InfoGAN training loop. For every minibatch it updates the discriminator,
# then the generator, then the mutual-information term (Q_ together with G).
# After each epoch it prints the mean losses and plots generated samples
# for every code value.
max_epochs = 300
n_samples = 3  # generated images per code value in the per-epoch preview

for epoch_n in range(1, max_epochs+1):

  D.train()
  G.train()
  Q_.train()

  d_losses = 0.0
  g_losses = 0.0
  q_losses = 0.0

  start = datetime.now()
  for n_batches, (X, _) in enumerate(train_dataloader, 1):
    # The dataset labels are not used: the code c is sampled below.
    X = X.to(device)
    X = X.view(X.size(0), -1)  # linearize every image into one row
    n = X.size(0)

    # Sample the generator inputs: z ~ N(0, I) and a uniformly random
    # one-hot code c (a random index scattered into a zero matrix).
    z = torch.randn(n, z_len, device=device)
    c_idx = torch.randint(0, c_len, (n, 1), device=device)
    c = torch.zeros(n, c_len, device=device).scatter_(1, c_idx, 1.0)

    # Train discriminator
    D_optimizer.zero_grad()

    # detach(): the discriminator update must not backpropagate into G.
    X_ = G(z, c).detach()

    p_real = D(X)
    p_gen = D(X_)

    loss_d = -torch.mean(loss_fn_d(p_real, p_gen))

    # Each step below runs its own forward pass, so no graph has to be
    # retained between the backward calls.
    loss_d.backward()
    d_losses += loss_d.item()
    D_optimizer.step()

    # Train generator
    G_optimizer.zero_grad()

    X_ = G(z, c)
    p_gen = D(X_)

    loss_g = -torch.mean(loss_fn_g(p_gen))
    loss_g.backward()
    g_losses += loss_g.item()
    G_optimizer.step()

    # Train mutual information regularization
    # (Q_optimizer holds the parameters of both Q_ and G.)
    Q_optimizer.zero_grad()

    X_ = G(z, c)
    q_c_given_x = Q_(X_)

    loss_q = torch.mean(loss_fn_q(c, q_c_given_x))

    loss_q.backward()
    q_losses += loss_q.item()
    Q_optimizer.step()

  print(f'Epoch {epoch_n:03d}: Loss_G: {g_losses / n_batches:.4f} Loss_Q: {q_losses / n_batches:.4f} Loss_D: {d_losses / n_batches:.4f}  Time: {datetime.now() - start}')

  # Preview grid: one column per code value, n_samples rows of fresh noise.
  # Separate names are used so the training variables are not overwritten.
  with torch.no_grad():
    fig, ax = plt.subplots(n_samples, c_len, figsize=(c_len, n_samples))
    fig.suptitle(f'Feature - epoch: {epoch_n}')
    for code in range(c_len):
      c_vis = torch.zeros(1, c_len, device=device)
      c_vis[0, code] = 1.0
      for row in range(n_samples):
        z_vis = torch.randn(1, z_len, device=device)
        sample = G(z_vis, c_vis).view(28, 28).cpu().numpy()
        ax[row][code].imshow(sample)
        ax[row][code].axis('off')
    plt.show()
    plt.close(fig)  # release the figure so they do not pile up over epochs