# InfoGAN - MNIST Dataset

In this notebook we implement a fairly complicated GAN called InfoGAN. The architecture is based on the paper [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657) by Chen et al.

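The main problem this architecture tries to solve is disentanglement: in a plain GAN, a single dimension of the noise vector rarely controls a single recognizable feature of the generated image. InfoGAN is one of the more widely used and better-known approaches to this problem.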

At a high level, InfoGAN works like this: we separate the generator's input into two parts. The first part is truly random noise; the second part is something we call the "latent code".

The latent code can be thought of as a "hidden" condition in a conditional generator, and we want it to have an interpretable meaning.
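As a minimal sketch of this split, the snippet below builds the generator input from a noise part and a latent-code part. The sizes (64 and 2) follow the architecture listed in this notebook; sampling the code from a standard normal and the variable names are assumptions made for illustration.

```python
import torch

n_samples, z_dim, c_dim = 8, 64, 2

z = torch.randn(n_samples, z_dim)  # unstructured random noise
c = torch.randn(n_samples, c_dim)  # latent code (assumed here to be standard normal)

# The generator receives one vector per sample: noise and code concatenated
# along the feature axis.
generator_input = torch.cat([z, c], dim=1)
print(generator_input.shape)  # torch.Size([8, 66])
```

Only the last `c_dim` entries of each row are expected to become interpretable; the first `z_dim` entries stay ordinary noise.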

You are probably wondering how the latent code, which is just a set of random numbers, can become more interpretable than any other input dimension of a GAN. The answer is "mutual information": we want each dimension of the latent code to be as obvious a function as possible of the generated image, i.e. easy to recover from it. We won't go any deeper than this; if you want to read more about it, I suggest reading the reference paper and an introduction to information entropy.
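For reference, the objective from the InfoGAN paper can be written as follows. Here $V(D, G)$ is the usual GAN objective, $\lambda$ is a weighting hyperparameter, $Q$ is an auxiliary distribution that predicts the code from an image, and $H(c)$ is the entropy of the code prior. This is a summary of the paper's formulation, not something derived in this notebook.

```latex
\begin{aligned}
\min_{G}\max_{D}\; V_I(D, G) &= V(D, G) - \lambda\, I\big(c;\, G(z, c)\big) \\
L_I(G, Q) &= \mathbb{E}_{c \sim P(c),\, x \sim G(z, c)}\big[\log Q(c \mid x)\big] + H(c) \;\le\; I\big(c;\, G(z, c)\big) \\
\min_{G, Q}\max_{D}\; V_{\text{InfoGAN}}(D, G, Q) &= V(D, G) - \lambda\, L_I(G, Q)
\end{aligned}
```

The mutual information $I$ is hard to compute directly, so in practice the lower bound $L_I$ is maximized instead.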

The implementation of InfoGAN is much like before: the generator is the same as in previous models, but the discriminator undergoes some changes. To be more specific, it is modified so that its output has more dimensions.

The architecture of this network is:

Generator:
1. Input Vector: (64 + 2), i.e. 64 noise dimensions plus 2 latent-code dimensions
2. Block 1: [ConvTranspose(66, 256)] -> Batch Normalization (256) -> ReLU
3. Block 2: [ConvTranspose(256, 128), Filter: 4, Stride: 1] -> Batch Normalization (128) -> ReLU
4. Block 3: [ConvTranspose(128, 64)] -> Batch Normalization (64) -> ReLU
5. Block 4: [ConvTranspose(64, 1), Filter: 4] -> Tanh
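The spatial size after each generator block can be checked with the transposed-convolution size formula. In the sketch below, Block 2's filter and stride come from the list above and Block 4's kernel size of 4 from the Generator code later in the notebook; the remaining kernel sizes and strides are assumptions, chosen so that the final image is 28x28.

```python
def conv_transpose_out(size, kernel_size, stride, padding=0):
    """Output side length of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel_size

# (kernel_size, stride) for Blocks 1-4; values other than Block 2 are partly assumed.
blocks = [(3, 2), (4, 1), (3, 2), (4, 2)]

size = 1  # the input vector is reshaped to a 1x1 map with 66 channels
sizes = [size]
for kernel_size, stride in blocks:
    size = conv_transpose_out(size, kernel_size, stride)
    sizes.append(size)

print(sizes)  # [1, 3, 6, 13, 28]
```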

Discriminator:
1. Image: (28, 28, 1)*
2. Block 1: [Conv2D(1, 64)] -> Batch Normalization (64) -> LeakyReLU (0.2)
3. Block 2: [Conv2D(64, 128)] -> Batch Normalization (128) -> LeakyReLU (0.2)
4. D Layer: [Conv2D(128, 1)]
5. Q Layer 1: [Conv2D(128, 128)] -> Batch Normalization (128) -> LeakyReLU (0.2)
6. Q Layer 2: [Conv2D(128, 4), Filter: 1]
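The same kind of size check works for the discriminator. The list above only fixes the filter size of Q Layer 2; the kernel size of 4 and stride of 2 used for the other layers in this sketch are assumptions, picked so that a 28x28 image is reduced to a single spatial position.

```python
def conv_out(size, kernel_size, stride, padding=0):
    """Output side length of a standard convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

image_size = 28
block1 = conv_out(image_size, kernel_size=4, stride=2)  # assumed kernel/stride
block2 = conv_out(block1, kernel_size=4, stride=2)      # assumed kernel/stride
d_layer = conv_out(block2, kernel_size=4, stride=2)     # discriminator head
q_layer1 = conv_out(block2, kernel_size=4, stride=2)    # first Q layer
q_layer2 = conv_out(q_layer1, kernel_size=1, stride=1)  # 1x1 filter, as listed above

print(block1, block2, d_layer, q_layer1, q_layer2)  # 13 5 1 1 1
```

With these sizes both heads end at 1x1, so flattening gives one real/fake score and four Q values per image.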


*: In mathematical notation the channel axis is usually written as the third dimension, but PyTorch stores it first, i.e. (1, 28, 28).
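A small illustration of the two layouts, using a dummy all-zero image (the variable names are only for this example):

```python
import torch

image_hwc = torch.zeros(28, 28, 1)      # (height, width, channels)
image_chw = image_hwc.permute(2, 0, 1)  # (channels, height, width), the PyTorch layout
batch = image_chw.unsqueeze(0)          # add a batch axis: (N, C, H, W)

print(image_chw.shape)  # torch.Size([1, 28, 28])
print(batch.shape)      # torch.Size([1, 1, 28, 28])
```

Going the other way, `show_tensor_images` in this notebook calls `permute(1, 2, 0)` because matplotlib expects the channel axis last.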

In [2]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    '''
    Function for visualizing images: given a tensor of images, the number of images, and
    the size per image, plots the images in a uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()

def generate_noise(n_samples, noise_dim, device = 'cpu'):
    '''
    A helper function for creating random noise vectors with shape (n_samples, noise_dim).
    Values are drawn from a standard normal distribution.
    Input ->
        n_samples: number of samples to generate (rows)
        noise_dim: dimension of the noise vector (columns)
        device: device type
    '''
    return torch.randn(n_samples, noise_dim, device=device)

def combine_vectors(x, y):
    '''
    Function for combining two vectors with shapes (n_samples, ?) and (n_samples, ?).
    Input ->
      x: (n_samples, ?) the first vector.
        This will be the noise vector of shape (n_samples, z_dim).
      y: (n_samples, ?) the second vector.
        In this notebook this will be the latent code
        with the shape (n_samples, c_dim).
    '''
    combined = torch.cat([x.float(), y.float()], 1)
    return combined

## Generator

In [3]:
class Generator(nn.Module):
    '''
    This class is for the generator.
    Inputs ->
        noise_dim: dimension of the input vector (noise plus latent code).
        image_channel: number of channels in the images (MNIST is grayscale, so images have 1 channel).
        hidden_dim: inner dimension of the network.
    '''
    def __init__(self, noise_dim=10, image_channel=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
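        # nn.ConvTranspose2d requires an explicit kernel_size. The values kernel_size=3 and
        # stride=2 (blocks 1 and 3) and stride=2 (block 4) are assumptions, chosen so that
        # the spatial size grows 1 -> 3 -> 6 -> 13 -> 28.
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, hidden_dim * 4, kernel_size=3, stride=2),
            nn.BatchNorm2d(hidden_dim * 4),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.ReLU(inplace=True)
        )
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=3, stride=2),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, image_channel, kernel_size=4, stride=2),
            nn.Tanh()
        )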
    
    def forward(self, noise):
        '''
        forward pass of generator.
        Input ->
            noise: noise tensor with shape of (number of samples, noise_dim)
        Output ->
            generated image
        '''
        out = noise.view(len(noise), self.noise_dim, 1, 1)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        return out
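
A quick way to sanity-check a generator like the one above is to push a small batch through it and look at the output shape. The sketch below uses a compact `nn.Sequential` stand-in (the helper name `make_generator` is hypothetical); the kernel sizes and strides not given in the architecture list are assumptions chosen to reach 28x28.

```python
import torch
from torch import nn

def make_generator(noise_dim=66, image_channel=1, hidden_dim=64):
    # Compact stand-in for the Generator class; kernel_size=3/stride=2 in blocks 1 and 3
    # and stride=2 in block 4 are assumed values.
    return nn.Sequential(
        nn.ConvTranspose2d(noise_dim, hidden_dim * 4, kernel_size=3, stride=2),
        nn.BatchNorm2d(hidden_dim * 4),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
        nn.BatchNorm2d(hidden_dim * 2),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=3, stride=2),
        nn.BatchNorm2d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(hidden_dim, image_channel, kernel_size=4, stride=2),
        nn.Tanh(),
    )

gen = make_generator()
noise = torch.randn(5, 66).view(5, 66, 1, 1)  # reshape each vector to a 1x1 map
fake = gen(noise)
print(fake.shape)  # torch.Size([5, 1, 28, 28])
```

Because the last layer is a Tanh, pixel values lie in [-1, 1], which is why `show_tensor_images` rescales with `(image_tensor + 1) / 2` before plotting.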

## Discriminator

As listed at the beginning of this notebook, the final Q layer of the discriminator has 4 output channels instead of one. The reason is that we have to predict a distribution for $c$ from $x$. Since we are assuming a normal prior, we can output a mean and a log-variance prediction for each dimension of the latent code, which gives 2 * `c_dim` values (4 when `c_dim` is 2).
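To make the mean and log-variance idea concrete, the sketch below splits a Q output into its two halves and scores the true code under the predicted Gaussian. The layout (means first, log-variances second) and the random stand-in for the Q output are assumptions for illustration; the notebook's own training code may organize this differently.

```python
import torch
from torch.distributions.normal import Normal

n_samples, c_dim = 8, 2
c_true = torch.randn(n_samples, c_dim)      # the code that was fed to the generator
q_pred = torch.randn(n_samples, 2 * c_dim)  # stand-in for the Q head's output

# Assumed layout: the first c_dim entries are means, the last c_dim are log-variances.
q_mean, q_logvar = q_pred[:, :c_dim], q_pred[:, c_dim:]
q_std = torch.exp(0.5 * q_logvar)           # log-variance -> standard deviation

# Average log-likelihood of the true code under the predicted Gaussian.
# Raising this value raises the variational lower bound on the mutual information.
log_likelihood = Normal(q_mean, q_std).log_prob(c_true).mean()
print(log_likelihood.item())
```

Predicting a log-variance rather than a variance keeps the output unconstrained, since exponentiating always yields a positive value.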

In [4]:
class Discriminator(nn.Module):
    '''
    This class is for the discriminator.
    Input ->
        image_channel: number of channels in the images (MNIST is grayscale, so images have 1 channel).
        hidden_dim: inner dimension of the network.
        c_dim: number of latent code dimensions
    '''
    def __init__(self, image_channel=1, hidden_dim=64, c_dim=10):
        super(Discriminator, self).__init__()
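        # nn.Conv2d requires an explicit kernel_size. kernel_size=4 and stride=2 are
        # assumptions, chosen so that a 28x28 image shrinks 28 -> 13 -> 5 -> 1.
        self.block1 = nn.Sequential(
            nn.Conv2d(image_channel, hidden_dim, kernel_size=4, stride=2),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(.2, inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=4, stride=2),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(.2, inplace=True)
        )
        # Discriminator head: one real/fake score per image.
        self.d_layer = nn.Conv2d(hidden_dim * 2, 1, kernel_size=4, stride=2)
        # Q head: predicts a mean and a log-variance for each latent code dimension.
        self.q_layer1 = nn.Sequential(
            nn.Conv2d(hidden_dim * 2, hidden_dim * 2, kernel_size=4, stride=2),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(.2, inplace=True)
        )
        self.q_layer2 = nn.Conv2d(hidden_dim * 2, c_dim * 2, kernel_size=1)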
    
    def forward(self, image):
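        '''
        Forward pass of the discriminator
        Input ->
            image: image tensor with shape (n_samples, image_channel, height, width)
        Output ->
            disc_pred: flattened per-image prediction of whether the input image is generated or original.
            q_pred: flattened per-image prediction of the latent code distribution (2 * c_dim channels).
        '''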
        out = self.block1(image)
        intermed_pred = self.block2(out)
        disc_pred = self.d_layer(intermed_pred)
        q_pred = self.q_layer1(intermed_pred)
        q_pred = self.q_layer2(q_pred)
        return disc_pred.view(len(disc_pred), -1), q_pred.view(len(q_pred), -1)
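
As a shape check for the two heads, the sketch below runs a small batch through a compact stand-in for the discriminator. The class name `TinyDiscriminator` is hypothetical, and the kernel size of 4 with stride 2 is an assumption (only the 1x1 filter of the last Q layer is given in this notebook).

```python
import torch
from torch import nn

class TinyDiscriminator(nn.Module):
    """Compact stand-in for the Discriminator above (assumed kernel_size=4, stride=2)."""
    def __init__(self, image_channel=1, hidden_dim=64, c_dim=2):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(image_channel, hidden_dim, kernel_size=4, stride=2),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=4, stride=2),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.d_layer = nn.Conv2d(hidden_dim * 2, 1, kernel_size=4, stride=2)
        self.q_layer = nn.Sequential(
            nn.Conv2d(hidden_dim * 2, hidden_dim * 2, kernel_size=4, stride=2),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(hidden_dim * 2, c_dim * 2, kernel_size=1),
        )

    def forward(self, image):
        features = self.body(image)          # shared features for both heads
        disc_pred = self.d_layer(features)   # real/fake score
        q_pred = self.q_layer(features)      # mean and log-variance per code dimension
        return disc_pred.view(len(disc_pred), -1), q_pred.view(len(q_pred), -1)

disc = TinyDiscriminator(c_dim=2)
images = torch.randn(5, 1, 28, 28)
disc_pred, q_pred = disc(images)
print(disc_pred.shape, q_pred.shape)  # torch.Size([5, 1]) torch.Size([5, 4])
```

Both heads share the first two blocks, so the extra cost of predicting the latent code is limited to the two Q layers.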