## In this notebook:
- Training a vanilla GAN; for more details, see the paper: https://arxiv.org/abs/1406.2661
- FID metric: measure how good the final generator is with the Fréchet Inception Distance; for more details, see the paper: https://arxiv.org/abs/1706.08500

# Import the libraries used in this notebook

In [None]:
import matplotlib.pyplot as plt

import numpy as np
import struct
import pandas as pd
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# Define global constants

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Length of the noise vector fed into the generator
NOISE_SHAPE = 128
BATCH_SIZE = 60
# Image height, width and number of channels
H, W, C = (28, 28, 1)
# Seed the Python, NumPy and PyTorch random number generators so that
# "Restart & Run All" reproduces the same results
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is not available

# It is better when the number of images in the dataset is divisible by the batch
# size without remainder, so that every batch has the same size.
# Suitable batch sizes for the 60 000 training images can be listed with:
#     [d for d in range(3, 100) if 60_000 % d == 0]
# The default batch size of 60 satisfies this.


# Read data from CSV file

### Load the data into a NumPy array for easier handling

In [None]:
# Load the training split: each CSV row holds a label followed by the pixel values
df = pd.read_csv('../input/mnist-in-csv/mnist_train.csv')
df_as_np = np.asarray(df)
# Column 0 holds the digit labels; the remaining columns hold the flattened pixels,
# which are reshaped into (N, H, W, C) images
labels_mnist = df_as_np[:, 0]
data = df_as_np[:, 1:].reshape(-1, H, W, C)

# Define the image data loader

### Define a Dataset class so that a DataLoader can produce batches of images

In [None]:
# Dataset wrapper over an in-memory image array (used with DataLoader below)
class FashionDataset(Dataset):
    """In-memory image dataset usable with ``torch.utils.data.DataLoader``.

    Despite its name, it holds MNIST digits in this notebook. A DataLoader needs
    ``__getitem__`` (indexing a single sample) and ``__len__`` (dataset size).

    Parameters
    ----------
    data : array-like
        Pixel data; cast to float32 and reshaped to (N, H, W, C).
    transform : callable, optional
        Applied to each image when it is accessed (e.g. torchvision transforms).
    H, W, C : int
        Height, width and number of channels of a single image.
    """

    def __init__(self, data, transform=None, H=28, W=28, C=1):
        as_float = np.asarray(data, dtype=np.float32)
        self._images = as_float.reshape(-1, H, W, C)
        self._transform = transform

    def __getitem__(self, index):
        sample = self._images[index]
        return sample if self._transform is None else self._transform(sample)

    def __len__(self):
        return self._images.shape[0]

# Wrap the raw pixel array in a Dataset and batch it with a shuffling DataLoader.
# DataLoader also accepts `num_workers` to prepare batches in parallel; the default
# is kept here (see the torch.utils.data.DataLoader docs for details).
train_set = FashionDataset(
    data,
    transform=transforms.Compose([
        # ToTensor: (H, W, C) float32 array -> (C, H, W) tensor. Float input is not
        # rescaled, so Normalize(128, 128) maps pixels in [0, 255] to about [-1, 1].
        transforms.ToTensor(),
        transforms.Normalize(128, 128),
    ]),
    H=H, W=W, C=C,
)
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)

## Test the train loader: plot one batch of images

In [None]:
# Test the loader: show one batch as a grid with 10 images per row.
# The loader yields images normalized to [-1, 1]; map them back to [0, 1] before
# plotting (casting the normalized values to uint8 would garble them).
batch_d = next(iter(train_loader))
grid = torchvision.utils.make_grid(batch_d, nrow=10)

plt.figure(figsize=(15, 20))
plt.imshow(((grid + 1.0) / 2.0).clamp(0.0, 1.0).permute(1, 2, 0).numpy())
plt.axis('off');

# Define Models

### Define some utilities for layers/models

In [None]:
# Custom weight initialization, applied to netG and netD via `Module.apply`.
def weights_init(m):
    """Initialize the parameters of a single module in place.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02). Modules whose
    class name contains 'BatchNorm' get weights ~ N(1, 0.02) and zero bias. Every
    other module (e.g. ``nn.Linear``) is left unchanged.

    Parameters
    ----------
    m : nn.Module
        Module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    # Layers without learnable affine parameters (e.g. BatchNorm with affine=False)
    # have `weight`/`bias` set to None; skip them instead of crashing on `.data`.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

## Define Generator model

In [None]:
class GeneratorNN(nn.Module):
    """MLP generator mapping a noise vector to an image with values in [-1, 1] (Tanh).

    Parameters
    ----------
    noise_shape : int, optional
        Length of the input noise vector. Defaults to the global ``NOISE_SHAPE``.
    img_shape : tuple of int, optional
        (H, W, C) of the generated image. Defaults to the global ``(H, W, C)``.
    """

    def __init__(self, noise_shape=None, img_shape=None):
        super(GeneratorNN, self).__init__()
        # Resolve defaults lazily: the notebook globals keep working, and other
        # sizes can still be passed explicitly.
        noise_shape = NOISE_SHAPE if noise_shape is None else noise_shape
        self._h, self._w, self._c = (H, W, C) if img_shape is None else img_shape

        layers = []
        in_features = noise_shape
        for out_features in (256, 512, 1024, 2048):
            layers += [
                nn.Linear(in_features, out_features),
                # track_running_stats=False: batch statistics are used in train AND
                # eval mode, so `momentum` has no effect here.
                nn.BatchNorm1d(out_features, momentum=0.8, track_running_stats=False),
                nn.ReLU(inplace=False),
            ]
            in_features = out_features
        layers += [nn.Linear(in_features, self._h * self._w * self._c), nn.Tanh()]
        # Same attribute name and layer order as before, so saved state_dicts still load
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise of shape (B, noise_shape) to images of shape (B, C, H, W)."""
        x = self._model(x)
        return x.view(-1, self._c, self._h, self._w)

### Create an instance of the generator model and test it with noise

In [None]:
# Build the generator on the selected device
gen_nn = GeneratorNN().to(device=device)
# Apply the custom initialization, then switch to training mode (the cell displays the model)
gen_nn.apply(weights_init).train()

In [None]:
# Sanity check: feed one batch of Gaussian noise through the untrained generator
arr = np.random.randn(BATCH_SIZE, NOISE_SHAPE).astype(np.float32)
res = gen_nn(torch.tensor(arr).to(device=device))
print(res.shape)
# Rescale the first sample from [-1, 1] to [0, 1] and show its first channel
plt.imshow(((res[0] + 1.0) / 2.0).cpu().detach().numpy()[0])

## Define Discriminator model

In [None]:
class DiscriminatorNN(nn.Module):
    """MLP discriminator: flattens each image and outputs the probability that it is real.

    Parameters
    ----------
    in_features : int, optional
        Number of values in one flattened image. Defaults to the global ``H * W * C``.
    """

    def __init__(self, in_features=None):
        super(DiscriminatorNN, self).__init__()
        # Resolve the default lazily so other image sizes can be passed explicitly
        in_features = H * W * C if in_features is None else in_features

        # Same attribute name and layer order as before, so saved state_dicts still load
        self._net = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(p=0.2),
            nn.Linear(512, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(p=0.3),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return Sigmoid scores of shape (B, 1) for a batch ``x`` of B images."""
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted tensors
        return self._net(x.reshape(x.shape[0], -1))

### Create an instance of the discriminator model and test it with noise

In [None]:
# Build the discriminator on the selected device
disc_nn = DiscriminatorNN().to(device=device)
# Apply the custom initialization, then switch to training mode (the cell displays the model)
disc_nn.apply(weights_init).train()

In [None]:
# Sanity check: score one batch of pure Gaussian noise with the untrained discriminator
arr = np.random.randn(BATCH_SIZE, C, H, W).astype(np.float32)
print(arr.shape)
res = disc_nn(torch.tensor(arr, device=device))
# Display the first five scores
res.detach().cpu().numpy()[:5]

In [None]:
# Check the discriminator on the real batch loaded in the loader test above
print(batch_d.shape)
# batch_d is already a tensor: move it with `.to` (torch.tensor(tensor) copies and warns)
res = disc_nn(batch_d.to(device=device))
res.cpu().detach().numpy()[:5]

# Training

### Define a class that controls GAN training
### Main method: `fit`, which starts training the GAN

In [None]:
class TrainGANController:
    """Alternating discriminator/generator training loop for a vanilla GAN.

    Usage: create the controller, call ``compile()`` to build the optimizers and the
    loss, then call ``fit()``.

    Parameters
    ----------
    disc_nn : nn.Module
        Discriminator; trained with BCELoss, so it must output probabilities.
    gen_nn : nn.Module
        Generator mapping noise of shape (B, noise_shape) to images.
    batch_size : int
        Default number of fake samples per generator step when none is given.
    device : torch.device, optional
        Device on which labels and noise are created.
    noise_shape : int, optional
        Length of the generator's noise vector. Defaults to the global ``NOISE_SHAPE``.
    """

    def __init__(self, disc_nn, gen_nn, batch_size, device = None, noise_shape = None):
        self._disc_nn = disc_nn
        self._gen_nn = gen_nn
        self._batch_size = batch_size
        self._noise_shape = NOISE_SHAPE if noise_shape is None else noise_shape

        self._is_compiled = False
        self._opt_disc = None
        self._opt_gen = None
        self._loss = None
        self._device = device

    def compile(
            self,
            lr_disc=2e-4, lr_gen=3e-4,
            beta_params_disc=(0.5, 0.999), beta_params_gen=(0.5, 0.999)):
        """Create one Adam optimizer per network and the BCE loss; required before training."""
        self._opt_disc = torch.optim.Adam(
            self._disc_nn.parameters(), lr=lr_disc, betas=beta_params_disc
        )
        self._opt_gen = torch.optim.Adam(
            self._gen_nn.parameters(), lr=lr_gen, betas=beta_params_gen
        )
        # BCELoss works on probabilities (DiscriminatorNN ends with a Sigmoid)
        self._loss = nn.BCELoss().to(device=self._device)
        self._is_compiled = True

    def _check_compiled(self):
        """Raise a clear error instead of failing later on the missing optimizers/loss."""
        if not self._is_compiled:
            raise RuntimeError('TrainGANController.compile() must be called before training')

    def _sample_noise(self, batch_size):
        """Draw a (batch_size, noise_shape) batch of standard normal noise."""
        return torch.randn(batch_size, self._noise_shape, device=self._device)

    def train_step_disc(self, real_data, real_label=0.9, fake_label=0.0):
        """Run one discriminator update on a real batch plus an equally sized fake batch.

        ``real_label`` defaults to 0.9 instead of 1.0 (one-sided label smoothing), which
        keeps the discriminator from becoming over-confident.

        Returns the discriminator loss (mean of the real and fake BCE) as a 0-d numpy array.
        """
        self._check_compiled()
        device = self._device
        # Size labels and the fake batch from the real batch, so a smaller last batch works too
        n = real_data.shape[0]
        self._disc_nn.zero_grad()
        label = torch.full((n,), real_label, dtype=torch.float, device=device)
        fake = torch.full((n,), fake_label, dtype=torch.float, device=device)
        # The fakes are only D's input here; no gradient has to flow back into G
        noise = self._sample_noise(n)
        with torch.no_grad():
            generated_imgs = self._gen_nn(noise)
        errD_real = self._loss(self._disc_nn(real_data).view(-1), label)
        errD_fake = self._loss(self._disc_nn(generated_imgs).view(-1), fake)
        errD = (errD_fake + errD_real) / 2.0
        errD.backward()
        self._opt_disc.step()
        return errD.cpu().detach().numpy()

    def train_step_gen(self, fake_label=1.0, batch_size=None):
        """Run one generator update; returns the generator loss as a 0-d numpy array.

        ``batch_size`` sets how many fakes are generated (defaults to the constructor's
        ``batch_size``).
        """
        self._check_compiled()
        device = self._device
        n = self._batch_size if batch_size is None else batch_size
        self._gen_nn.zero_grad()
        # Non-saturating GAN loss: fakes are labeled as real for the generator's cost
        label = torch.full((n,), fake_label, dtype=torch.float, device=device)
        # D was just updated, so run a fresh all-fake batch through it
        noise = self._sample_noise(n)
        generated_imgs = self._gen_nn(noise)
        errG = self._loss(self._disc_nn(generated_imgs).view(-1), label)
        # backward() also accumulates gradients in D's parameters; they are cleared by
        # the zero_grad() at the start of the next discriminator step.
        errG.backward()
        self._opt_gen.step()
        return errG.cpu().detach().numpy()

    def fit(self, data_gen, epoch: int, print_it: int = 400):
        """Train for ``epoch`` passes over ``data_gen``, an iterable of real image batches.

        Each iteration does one discriminator step and then one generator step, with
        the generator producing as many fakes as the current real batch holds. Losses
        are printed every ``print_it`` iterations.
        """
        self._check_compiled()
        for i_e in range(epoch):
            for ii_it, single_data in enumerate(data_gen):
                single_data = single_data.to(device=self._device)
                err_d = self.train_step_disc(single_data)
                err_g = self.train_step_gen(batch_size=single_data.shape[0])
                if ii_it % print_it == 0:
                    print(f'epoch: {i_e+1}/{epoch}, it: {ii_it}/{len(data_gen)}'
                          f'|| Loss G: {err_g}, Loss D: {err_d}'
                    )

### Create the controller instance and compile it

In [None]:
# Wrap both networks in a training controller, then build its optimizers and loss
t_gan_c = TrainGANController(disc_nn=disc_nn, gen_nn=gen_nn, batch_size=BATCH_SIZE, device=device)
t_gan_c.compile()

## Start training

In [None]:
# Train for 25 full passes over the training set
t_gan_c.fit(data_gen=train_loader, epoch=25)

# Generate digits with the trained model

In [None]:
def visualise_sheets_of_images(
    images, prefix_name, unique_index=0,
    show_images=False, subplot_size=(10, 10),
    figsize=(20, 20),use_BGR2RGB=False, use_grey=False):
    """
    Plot sheets of images and save them to ``{prefix_name}_{unique_index}.png``.
    Usually used for images produced by GANs.

    Parameters
    ----------
    images : list or np.ndarray
        Images to plot; at most ``subplot_size[0] * subplot_size[1]`` are shown.
    prefix_name : str
        Prefix of the file name for the saved sheet of images.
    unique_index : int
        Number appended to the file name, usually the epoch at which the images
        were produced.
    show_images : bool
        If true, the sheet of images is also displayed.
    subplot_size : tuple
        Number of rows and columns of the grid. See the plt.subplot docs.
    figsize : tuple
        Size of the figure. See the plt.figure docs.
    use_BGR2RGB : bool
        If true, `images` are converted from BGR to RGB before plotting (requires cv2).
    use_grey : bool
        If true, `images` are plotted with a grayscale colormap. A trailing channel
        axis of size 1, i.e. (H, W, 1), is dropped first.
    """
    fig = plt.figure(figsize=figsize)
    n_rows, n_cols = subplot_size
    for z in range(min(len(images), n_rows * n_cols)):
        plt.subplot(n_rows, n_cols, z + 1)
        image = images[z]
        if use_BGR2RGB:
            # Imported lazily: only this option needs it, and in this notebook cv2 is
            # otherwise imported in a later cell.
            import cv2
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        elif use_grey:
            if np.ndim(image) == 3 and np.shape(image)[-1] == 1:
                image = np.squeeze(image, axis=-1)  # (H, W, 1) -> (H, W)
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(f'{prefix_name}_{unique_index}.png')
    if show_images:
        plt.show()

    # Close only this figure, leaving any other open figures untouched
    plt.close(fig)

In [None]:
# Sample a batch of digits from the trained generator
arr = np.random.randn(BATCH_SIZE, NOISE_SHAPE).astype(np.float32)
gen_nn.eval()
res = gen_nn(torch.tensor(arr).to(device=device))
# Map the output from [-1, 1] to [0, 1] and reorder (B, C, H, W) -> (B, H, W, C)
res = ((res + 1.0) / 2.0).cpu().detach().numpy().transpose(0, 2, 3, 1)
visualise_sheets_of_images(
    res, "generated_digits",
    show_images=True, use_grey=True,
)

# Evaluate the generator with the FID metric

## Import libraries and define constants

In [None]:
from scipy.linalg import sqrtm
from sklearn.utils import shuffle
import cv2
from tqdm import tqdm
from torchvision.models import inception_v3

# How many real images are sampled, and how many fake images are generated,
# to estimate the generator's FID score
N_IMAGES = 10000

### Define helper functions

#### Resize a list of images to a given shape

In [None]:
def scale_images(images, new_shape):
    """
    Resize a batch of channel-first images to a new spatial size.

    Single-channel (grayscale) images are replicated to 3 channels.

    Parameters
    ----------
    images : iterable of np.ndarray
        Images of shape (C, H_old, W_old), where
            C - number of channels;
            H_old - height of the image;
            W_old - width of the image.
    new_shape : list or tuple
        (H, W): height and width of the resized images.

    Return
    ------
    np.ndarray
        Array of shape (N, C_out, H, W), where C_out is 3 for single-channel inputs
        and C otherwise.
    """
    new_h, new_w = new_shape
    images_list = list()
    for image in images:
        new_image = np.transpose(image, (1, 2, 0))  # (C, H, W) --> (H, W, C)
        # Nearest-neighbor resize; cv2 expects the target size as (width, height),
        # i.e. the reverse of `new_shape`
        new_image = cv2.resize(new_image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        # A single-channel image can come back from cv2.resize as 2-D (H, W)
        if new_image.ndim == 2 or new_image.shape[-1] == 1:
            new_image = cv2.cvtColor(new_image, cv2.COLOR_GRAY2BGR)
        images_list.append(np.transpose(new_image, (2, 0, 1)))  # (H, W, C) --> (C, H, W)
    return np.asarray(images_list)

#### Collect activations from InceptionV3
Activations are collected in batches of a fixed size to limit memory usage.

In [None]:
def calculate_fid_batched(model_inception, images1, images2, batch_size=128, device=None):
    """
    Run two equally sized image sets through InceptionV3 and collect its activations.

    Despite its name, this function does not compute the FID itself; pass the returned
    activations to ``calculate_fid``.

    Parameters
    ----------
    model_inception : nn.Module
        Feature extractor; its output is flattened to one feature vector per image.
    images1, images2 : np.ndarray
        Image sets of shape (N, C, H, W) with values in [0, 1]; both must have the same N.
    batch_size : int
        Number of images processed at once, to bound memory usage.
    device : torch.device, optional
        Device the model lives on. Defaults to the device of the model's parameters
        (CPU if the model has none).

    Returns
    -------
    tuple of np.ndarray
        Activations for ``images1`` and ``images2``, each of shape (N, n_features).
    """
    assert len(images1) == len(images2)
    if device is None:
        first_param = next(model_inception.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    # ImageNet mean/std; applied to BOTH image sets so their activations are comparable
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, -1, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, -1, 1, 1)

    def _activations(batch):
        # Resize to 299x299, normalize and extract features without building an autograd graph
        resized = (scale_images(batch, (299, 299)).astype(np.float32) - mean) / std
        with torch.no_grad():
            act = model_inception(torch.from_numpy(resized).to(device))
        # flatten(1) keeps the batch axis even for a batch with a single image
        return act.flatten(1).cpu().numpy()

    # Ceil division so that the last, possibly smaller batch is not dropped
    n_batches = (len(images1) + batch_size - 1) // batch_size
    preds1 = []
    preds2 = []
    for i in tqdm(range(n_batches)):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        preds1.append(_activations(images1[batch]))
        preds2.append(_activations(images2[batch]))
    return np.concatenate(preds1, axis=0), np.concatenate(preds2, axis=0)

#### Calculate FID from the InceptionV3 activations

In [None]:
def calculate_fid(act1, act2, eps=1e-6):
    """
    Fréchet Inception Distance between two sets of activations.

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2)),
    where mu and sigma are the mean and covariance of each activation set.

    Parameters
    ----------
    act1, act2 : np.ndarray
        Activations of shape (N, n_features); each row is one sample.
    eps : float
        Diagonal offset added to both covariances, only when the matrix square root
        is not finite (e.g. for singular covariances), to stabilise ``sqrtm``.

    Returns
    -------
    float
        The FID score: about 0 for identical statistics; larger means less similar.
    """
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    # Sum of squared differences between the means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    # Matrix square root of the product of the covariances
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # sqrtm can return tiny imaginary parts from rounding; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

#### Load InceptionV3 and calculate FID

In [None]:
# Prepare the InceptionV3 model used as a feature extractor
model = inception_v3(pretrained=True)
model.eval()
# Keep every top-level child except the auxiliary classifier, dropout and the final
# Linear layer, so the network ends before the classification head.
# NOTE(review): nn.Sequential only runs the children in order; any extra logic in
# Inception3.forward (e.g. its input transformation) is skipped — confirm this is intended.
layer_names = []
for layer in list(model.children()):
    if layer.__class__.__name__ not in ['InceptionAux', 'Linear', 'Dropout']:
        layer_names.append(layer)
model = nn.Sequential(*layer_names)
model.eval()
model.to(device=device)

# First set: N_IMAGES real images as (N, C, H, W) float32, scaled from [0, 255] to [0, 1]
images1 = shuffle(data)[:N_IMAGES]
images1 = np.transpose(np.asarray(images1), (0, 3, 1, 2)).astype(np.float32)
images1 /= 255.0

# Second set: N_IMAGES generated images. No autograd graph is needed for inference.
gen_nn.eval()
with torch.no_grad():
    images2_noise = torch.tensor(
        np.random.randn(N_IMAGES, NOISE_SHAPE).astype(np.float32),
        device=device
    )
    images2 = gen_nn(images2_noise).cpu().numpy()
# The generator outputs values in (-1, 1); map them to [0, 1] like the real images
images2 += 1.0
images2 /= 2.0
print('Prepared', images1.shape, images2.shape)

# Collect Inception activations batch by batch, then compute the FID scores
act1, act2 = calculate_fid_batched(model, images1, images2)
# Sanity check: FID between the real images and themselves should be about 0
fid_same = calculate_fid(act1, act1)
# FID between real and generated images
fid = calculate_fid(act1, act2)

print('FID (same): %.3f' % fid_same)
print('FID (different): %.3f' % fid)

# Save model

In [None]:
# Persist only the generator's learned parameters (its state_dict) to model.pth
torch.save(obj=gen_nn.state_dict(), f='model.pth')

To download the final model, click the link below.

<h1><a href="model.pth"> Download trained generator </a></h1>