In [1]:
# Start every run with empty output folders: show_samples writes image grids to Samples/
# and the training loop writes checkpoints to Models/.
# WARNING: `rm -rf` permanently deletes any samples/checkpoints left from a previous run.
! rm -rf Models Samples
! mkdir Models Samples


# Imports Required

* Torch - For coding neural networks
* Torchvision - Datasets and transforms
* Typing - Python type annotations
* Tqdm - Progress bar
* Matplotlib - Data plotting and visualization


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as T
import torchvision.datasets as datasets
from torchvision.utils import save_image, make_grid

from tqdm.auto import tqdm
from typing import Union
import matplotlib.pyplot as plt

plt.style.use('ggplot')
%matplotlib inline


# Utility functions

* Save-Network - For saving a network by creating a checkpoint
* Load-Network - Loads a network given a checkpoint
* Show-Samples - Displays / Saves image samples
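* Conditional-Noise - Samples noise vectors from a standard Gaussian (`mean=0`, `variance=1`) and appends each sample's one-hot class label
* Channelize - Appends each image's one-hot class label as constant, image-sized channels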
* Weights-Init - Weight initialisation for the neural networks


In [9]:
def save_network(filename: str, network: nn.Module, optimizer: optim.Optimizer, **kwargs):
    """Write a training checkpoint to `filename` with torch.save.

    The saved dict holds `network.state_dict()` under 'network',
    `optimizer.state_dict()` under 'optimizer', and every extra keyword
    argument (e.g. epoch=..., loss=...) under its own name.
    """
    checkpoint = dict(network=network.state_dict(), optimizer=optimizer.state_dict())
    checkpoint.update(kwargs)  # extra metadata, in the order it was passed
    print(f'-> Saving Model at {filename}')
    torch.save(checkpoint, filename)


def load_network(filename: str, network: nn.Module, optimizer: optim.Optimizer, lr: float, **kwargs):
    """Restore `network` and `optimizer` in place from a checkpoint written by `save_network`.

    Args:
        filename: path of the checkpoint file. torch.load unpickles it, so only load
            checkpoints from trusted sources.
        network: module whose state is replaced by checkpoint['network'].
        optimizer: optimizer whose state is replaced by checkpoint['optimizer'].
        lr: learning rate written into every optimizer param group after loading,
            overriding the value stored in the checkpoint.
        **kwargs: names of extra metadata entries to fetch (e.g. epoch=None, loss=None).
            Only the keyword names are used; their values are ignored.

    Returns:
        dict mapping each requested name to its checkpoint value. Names that are absent
        from the checkpoint, or stored as None, are omitted.
    """
    # Load tensors onto the device the network already lives on, so a checkpoint saved
    # on a GPU can still be restored on a CPU-only machine (and vice versa).
    first_param = next(network.parameters(), None)
    map_location = first_param.device if first_param is not None else torch.device('cpu')
    checkpoint = torch.load(filename, map_location=map_location)

    network.load_state_dict(checkpoint['network'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return {name: checkpoint[name] for name in kwargs if checkpoint.get(name) is not None}


def show_samples(image_tensors: torch.Tensor, shape: tuple, n_samples: int, n_row: int,
                 normalize: bool = True, save: bool = False, factor=None):
    """Tile the first `n_samples` images into a grid, then either display it or save it.

    Args:
        image_tensors: batch of images; it is detached, moved to CPU and reshaped
            to (-1, *shape) before tiling.
        shape: per-image shape, e.g. (1, 28, 28).
        n_samples: number of leading images to put in the grid.
        n_row: passed to make_grid as its second positional argument (nrow).
        normalize: forwarded to make_grid's `normalize`.
        save: when True, write the grid to 'Samples/Sample-{factor}.png' instead of showing it.
        factor: suffix for the saved filename (the training loop passes the epoch).
    """
    batch = image_tensors.detach().cpu().view(-1, *shape)
    grid = make_grid(batch[:n_samples], n_row, padding=1, normalize=normalize)

    if not save:
        plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC, the layout matplotlib expects
        plt.axis('off')
        plt.show()
        return

    save_image(grid, f'Samples/Sample-{factor}.png')
    
    
def conditional_noise(n_samples: int, noise_dim: int, labels: torch.Tensor,
                      n_classes: int = -1, device: Union[str, torch.device] = 'cpu'):
    """Sample standard-normal noise and append each sample's one-hot class label.

    Args:
        n_samples: number of noise vectors; must equal len(labels) so the two can be concatenated.
        noise_dim: length of each noise vector.
        labels: integer class indices (one_hot needs an int64 tensor).
        n_classes: forwarded to one_hot's `num_classes` (-1 lets one_hot infer it from the labels).
        device: device on which the result is created.

    Returns:
        Tensor of shape (n_samples, noise_dim + n_classes) with the noise dtype (float).
    """
    noise = torch.randn(n_samples, noise_dim, device=device)
    # `one_hot` was never imported in this notebook; use the torch.nn.functional version explicitly.
    # Its int64 output is cast to the noise dtype so the concatenation is a float tensor.
    label_vectors = nn.functional.one_hot(labels.to(device), num_classes=n_classes).to(noise.dtype)
    return torch.cat([noise, label_vectors], dim=-1)


def channelize(images: torch.Tensor, labels: torch.Tensor,
               device: Union[str, torch.device] = 'cpu', n_classes: int = -1, image_shape: tuple = (28, 28)):
    """Append each image's one-hot class label as constant, image-sized channels.

    Args:
        images: image batch of shape (N, C, *image_shape).
        labels: N integer class indices (one_hot needs an int64 tensor).
        device: device both tensors are moved to.
        n_classes: forwarded to one_hot's `num_classes` (-1 lets one_hot infer it from the labels).
        image_shape: spatial size (H, W) of each label plane; must match the images.

    Returns:
        Tensor of shape (N, C + n_classes, *image_shape). Channel C + k is all ones
        when the sample's label is k, and all zeros otherwise.
    """
    images = images.to(device)
    # `one_hot` was never imported in this notebook; use the torch.nn.functional version explicitly,
    # cast to the image dtype so the concatenation stays floating point.
    label_vectors = nn.functional.one_hot(labels.to(device), num_classes=n_classes).to(images.dtype)
    # (N, n_classes) -> (N, n_classes, H, W) by broadcasting each label value over the image plane.
    label_planes = label_vectors[:, :, None, None].expand(-1, -1, *image_shape)
    return torch.cat((images, label_planes), dim=1)


def weights_init(m):
    """Re-initialise one module's weights in place; intended for `model.apply(weights_init)`.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); the bias is left untouched.
    - BatchNorm2d: weight ~ N(0, 0.02), bias set to 0.
    - Every other module type is left unchanged.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # NOTE(review): the PyTorch DCGAN tutorial draws BatchNorm scales from N(1, 0.02);
        # mean 0.0 is kept here as originally written — confirm that is intended.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.zeros_(m.bias)


# Utility `nn.Module` modules
 
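* Lambda - An `nn.Module` that applies a user-supplied function to its input, so ad-hoc operations (e.g. a reshape) can sit inside `nn.Sequential`
* GeneratorBlock - An `nn.Module` wrapping an `nn.Sequential` of ConvTranspose2d → BatchNorm2d → ReLU (ConvTranspose2d → Tanh for the final layer) that upsamples its input
* DiscriminatorBlock - An `nn.Module` wrapping an `nn.Sequential` of Conv2d → BatchNorm2d → LeakyReLU(0.2) (Conv2d → Sigmoid for the final layer) that downsamples its input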


In [None]:
class Lambda(nn.Module):
    """Wrap a plain callable as an `nn.Module` so it can be placed inside `nn.Sequential`.

    The callable is stored as-is; the layer itself has no parameters.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func  # applied to the input in forward()

    def forward(self, x):
        """Return `self.func(x)`."""
        transform = self.func
        return transform(x)


class GeneratorBlock(nn.Module):
    """One stage of the generator: a transposed convolution followed by its activation.

    Layers inside `self.block` (an nn.Sequential):
      * intermediate stage: ConvTranspose2d -> BatchNorm2d -> ReLU(inplace=True)
      * final stage (`final_layer=True`): ConvTranspose2d -> Tanh
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[tuple, int] = 3,
                 stride: Union[tuple, int] = 2, final_layer: bool = False):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
        if final_layer:
            tail = [nn.Tanh()]
        else:
            tail = [nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        self.block = nn.Sequential(upsample, *tail)

    def forward(self, x):
        return self.block(x)


class DiscriminatorBlock(nn.Module):
    """One stage of the discriminator: a strided convolution followed by its activation.

    Layers inside `self.block` (an nn.Sequential):
      * intermediate stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2)
      * final stage (`final_layer=True`): Conv2d -> Sigmoid
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple] = 4,
                 stride: Union[int, tuple] = 2, final_layer: bool = False):
        super().__init__()
        downsample = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        if final_layer:
            tail = [nn.Sigmoid()]
        else:
            tail = [nn.BatchNorm2d(out_channels), nn.LeakyReLU(.2)]
        self.block = nn.Sequential(downsample, *tail)

    def forward(self, x):
        return self.block(x)
    


# Architecture guidelines for stable Deep Convolutional GANs

* Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
* Use batchnorm in both the generator and the discriminator.
* Remove fully connected hidden layers for deeper architectures.
* Use ReLU activation in generator for all layers except for the output, which uses Tanh.
* Use LeakyReLU activation in the discriminator for all layers.



# Generator

* Takes in noise (N, 100) concatenated with a one-hot class label (N, 10), i.e. an input of size (N, 110)
* Outputs images of size (N, 1, 28, 28)
* Uses Tanh activation in the last layer


In [4]:
class Generator(nn.Module):
    """DCGAN-style generator that turns a flat input vector into an image.

    Args:
        in_channels: length of each input vector. It must match what the caller feeds in;
            conditional_noise appends a one-hot label to the noise.
        out_channels: number of channels of the generated image.
        hidden_dim: base width; the three hidden stages use 4x, 2x and 1x this many channels.

    Given PyTorch's default padding of 0, the spatial size grows 1 -> 3 -> 6 -> 13 -> 28.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_dim: int = 64):
        super().__init__()
        wide, mid, narrow = hidden_dim * 4, hidden_dim * 2, hidden_dim
        self.generator = nn.Sequential(
            # (N, in_channels) -> (N, in_channels, 1, 1): start from a 1x1 feature map
            Lambda(lambda x: x.view(-1, in_channels, 1, 1)),
            GeneratorBlock(in_channels, wide),                                   # 1x1   -> 3x3
            GeneratorBlock(wide, mid, kernel_size=4, stride=1),                  # 3x3   -> 6x6
            GeneratorBlock(mid, narrow),                                         # 6x6   -> 13x13
            GeneratorBlock(narrow, out_channels, kernel_size=4, final_layer=True),  # 13x13 -> 28x28
        )

    def forward(self, noise):
        return self.generator(noise)


# Discriminator

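* Receives an image stacked with its one-hot class label as 10 constant channels, i.e. an input of size (N, 1 + 10, 28, 28)
* Outputs a sigmoid-activated unit estimating the probability that the input is real
* 0 - Fake, 1 - Real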


In [6]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores each input image in [0, 1] (Sigmoid output).

    Args:
        in_channels: channels of the input; it must match what the caller feeds in
            (channelize appends one-hot label planes to the image).
        hidden_dim: base width; the hidden stages use 1x and 2x this many channels.

    Given PyTorch's default padding of 0, the spatial size shrinks 28 -> 13 -> 5 -> 1.
    """

    def __init__(self, in_channels: int, hidden_dim: int = 64):
        super().__init__()
        narrow, wide = hidden_dim, hidden_dim * 2
        self.discriminator = nn.Sequential(
            DiscriminatorBlock(in_channels, narrow),            # 28x28 -> 13x13
            DiscriminatorBlock(narrow, wide),                   # 13x13 -> 5x5
            DiscriminatorBlock(wide, 1, final_layer=True),      # 5x5   -> 1x1
        )

    def forward(self, images: torch.Tensor):
        """Return predictions flattened to shape (len(images), -1)."""
        return self.discriminator(images).view(len(images), -1)


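# Train-Fn - Trains the GAN for one epoch (one pass over the dataloader) and returns the mean generator/discriminator losses plus the last batch of generated samples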


In [None]:
def train_fn(generator: Generator, discriminator: Discriminator,
             optimD: optim.Optimizer, optimG: optim.Optimizer, criterion: nn.Module,
             dataloader: DataLoader, device: Union[str, torch.device], noise_dim: int = 100,
             n_classes: int = 10, image_shape: tuple = (28, 28)):
    """Train a conditional GAN for one epoch (one pass over `dataloader`).

    Each batch first updates the discriminator on real and (detached) fake inputs.
    It then updates the generator so that the discriminator labels its fakes as real.
    Both networks are put in train mode first, because a later cell switches the
    generator to eval mode.

    Args:
        generator, discriminator: the two networks, already on `device`.
        optimD, optimG: optimizers for the discriminator and the generator.
        criterion: loss applied to the discriminator's predictions (BCE in this notebook).
        dataloader: yields (images, labels) batches.
        device: device the batches are moved to.
        noise_dim: length of the noise part of the generator input.
        n_classes: number of classes used for one-hot conditioning.
        image_shape: spatial size (H, W) of the images / label planes.

    Returns:
        (mean generator loss, mean discriminator loss, last batch of generated images).
        The images are the raw generator output, without the appended label channels.

    Raises:
        ValueError: if `dataloader` has no batches.
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError('dataloader is empty; nothing to train on')

    generator.train()
    discriminator.train()

    loop = tqdm(enumerate(dataloader), total=n_batches, leave=True)
    gen_loss, disc_loss = 0.0, 0.0
    for batch_idx, (images, labels) in loop:
        real_inputs = channelize(images, labels, device, n_classes, image_shape)
        bs = len(real_inputs)
        noise = conditional_noise(bs, noise_dim, labels, n_classes, device)

        # Training the discriminator
        fake_images = generator(noise)  # raw fakes, kept for the return value
        fake_inputs = channelize(fake_images, labels, device, n_classes, image_shape)
        real_predictions = discriminator(real_inputs)
        real_loss = criterion(real_predictions, torch.ones_like(real_predictions))
        # detach(): the discriminator step must not backpropagate into the generator
        fake_predictions = discriminator(fake_inputs.detach())
        fake_loss = criterion(fake_predictions, torch.zeros_like(fake_predictions))

        lossD = (real_loss + fake_loss) / 2  # Complete loss of the discriminator
        optimD.zero_grad()
        lossD.backward()
        optimD.step()

        # Training the generator: re-score the (non-detached) fakes with the updated discriminator
        fake_predictions = discriminator(fake_inputs)
        lossG = criterion(fake_predictions, torch.ones_like(fake_predictions))
        optimG.zero_grad()
        lossG.backward()
        optimG.step()

        gen_loss += lossG.item()
        disc_loss += lossD.item()

        loop.set_description(f'Step: [{batch_idx + 1}/{n_batches}]')
        loop.set_postfix(generator_loss=gen_loss / (batch_idx + 1), discriminator_loss=disc_loss / (batch_idx + 1))

    return gen_loss / n_batches, disc_loss / n_batches, fake_images



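# Data loading and transforms

Images are converted with `ToTensor` and then `Normalize(mean=.5, std=.5)`, mapping pixel values from [0, 1] to [-1, 1] to match the range of the generator's final Tanh.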


In [8]:
# MNIST training split. ToTensor yields [0, 1] pixels, and Normalize(.5, .5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
bs = 128  # batch size
tsfm = T.Compose([T.ToTensor(), T.Normalize(mean=.5, std=.5)])
dataset = datasets.MNIST('./Train', train=True, download=True, transform=tsfm)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./Train/MNIST/raw/train-images-idx3-ubyte.gz



KeyboardInterrupt: 


# Hyperparameters, model instantiation, optimizer instantiation and Loss Function

In [None]:
# Hyperparameters
n_epochs = 100
lr = 2e-4
noise_dim = 100
hidden_dim = 64
im_channels = 1
n_classes = 10  # MNIST digits 0-9; must match the class count train_fn one-hot encodes with (10)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Conditional GAN input sizes:
#  - generator: noise with a one-hot label appended -> noise_dim + n_classes values
#  - discriminator: image plus n_classes constant label planes -> im_channels + n_classes channels
# Building them with only noise_dim / im_channels makes the first batch fail on a shape mismatch.
generator = Generator(noise_dim + n_classes, im_channels, hidden_dim).apply(weights_init).to(device)
discriminator = Discriminator(im_channels + n_classes, hidden_dim).apply(weights_init).to(device)

optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr)
optimizer_gen = optim.Adam(generator.parameters(), lr=lr)
loss_fn = nn.BCELoss()  # the discriminator already ends in Sigmoid, so plain BCE on probabilities


# The Training Loop

## Training Configuration

* BatchSize = 128
* Learning Rate = 2e-4
* Noise Dimension = 100
* Hidden Dim = Multiples of 64
* Optimizers = Adam for both Generator and Discriminator
* Loss Function = Binary Cross Entropy
* Saves fake samples for every epoch (Samples/)
* Periodically saves model checkpoints (Models/); see the training cell for the exact schedule


In [None]:
# Per-epoch mean losses, plotted in the next section
gen_loss, disc_loss = [], []
save_every = 50  # checkpoint period in epochs; the final epoch is always checkpointed as well
print(f'Running on device: {device}')
for epoch in range(n_epochs):
    gen, disc, samples = train_fn(generator, discriminator, optimizer_disc, optimizer_gen, loss_fn, dataloader, device, noise_dim)
    gen_loss.append(gen)
    disc_loss.append(disc)

    # Keep only the image channels, in case the returned batch still carries appended label planes.
    samples = samples.detach().cpu()[:, :im_channels]
    show_samples(samples, (im_channels, 28, 28), 25, 5, save=True, factor=epoch)

    # Save BOTH networks (the discriminator is needed to resume training) and never skip the last epoch.
    if epoch % save_every == 0 or epoch == n_epochs - 1:
        save_network(f'Models/Generator-{epoch}.pth.tar', generator, optimizer_gen, loss=gen, epoch=epoch)
        save_network(f'Models/Discriminator-{epoch}.pth.tar', discriminator, optimizer_disc, loss=disc, epoch=epoch)


# Loss plot


In [None]:
# Mean loss per epoch, as collected by the training loop
fig, ax = plt.subplots()
ax.plot(gen_loss, label='Generator Loss')
ax.plot(disc_loss, label='Discriminator Loss')
ax.set(title='Generator Vs. Discriminator Loss', xlabel='Epoch', ylabel='Mean loss per batch')
ax.legend()
plt.show()

In [None]:
from PIL import Image

# The training loop saves one grid per epoch as Samples/Sample-{epoch}.png; show the last one
# (derived from n_epochs instead of a hard-coded 99).
final_epoch = n_epochs - 1
image = Image.open(f'Samples/Sample-{final_epoch}.png')
plt.grid(False)
plt.axis('off')
plt.title(f'Fake image - {n_epochs} Epochs')
plt.imshow(image)
plt.show()

In [None]:
# Generate 25 digits from an all-zero noise vector where only the one-hot class label varies
# (labels cycle 0-9). The generator is conditional, so its input is noise_dim + n_label_classes
# values; a bare noise_dim-sized vector cannot be reshaped to that input.
n_label_classes = 10  # MNIST; must match the class count the generator was trained with
generator.eval()
with torch.no_grad():
    zero_noise = torch.zeros(25, noise_dim, device=device)
    digit_labels = torch.arange(25, device=device) % n_label_classes
    label_vectors = nn.functional.one_hot(digit_labels, num_classes=n_label_classes).float()
    inputs = torch.cat([zero_noise, label_vectors], dim=1)
    images = generator(inputs).cpu()
plt.title('Noise Vector - Zeros')
show_samples(images, (1, 28, 28), 25, 5, True)