<img src=https://storage.googleapis.com/kaggle-media/tpu/tpu_cores_and_chips.png >

# <center>Training a GAN on 8 TPU Cores🔥 using PyTorch/XLA</center>

# Table of contents <a id='0.1'></a>

1. [Introduction](#1)
2. [Import Packages](#2)
3. [Loading Data](#3)
4. [Build the GAN](#4)
5. [Train the GAN](#5)
6. [References](#6)

# 1. <a id='1'>Introduction</a>
[Table of contents](#0.1)

Hey folks. This notebook shows how to train a Generative Adversarial Network (GAN) on **all 8 cores of a TPU v3**.


> GANs are difficult to train.
> 
> The reason they are difficult to train is that both the generator model and the discriminator model are trained simultaneously in a zero sum game. This means that improvements to one model come at the expense of the other model.
> 
> The goal of training two models involves finding a point of equilibrium between the two competing concerns.
> 
> It also means that every time the parameters of one of the models are updated, the nature of the optimization problem that is being solved is changed. This has the effect of creating a dynamic system. In neural network terms, the technical challenge of training two competing neural networks at the same time is that they can fail to converge.

[source](https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/)
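The zero-sum game described in the quote is usually written as the minimax objective of the original GAN formulation, where $x$ is a real image, $z$ is the noise vector (100-dimensional in this notebook), $D$ tries to maximize $V$ and $G$ tries to minimize it. The `loss_D` computed later in this notebook is the mini-batch estimate of $-V/2$, so minimizing it maximizes $V$. For the Generator, the training code uses the common non-saturating variant $\mathcal{L}_G$ (BCE against "real" labels) rather than minimizing $\log(1 - D(G(z)))$ directly, because it gives stronger gradients early in training.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

$$
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big]
$$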


Let's dive in.

# 2. <a id='2'>Import Packages</a>
[Table of contents](#0.1)

In [None]:
import os
import gc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import PIL.Image as Image

import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.distributed import DistributedSampler

import torchvision
from torchvision import transforms
torch.manual_seed(42)

In [None]:
#Download and install PyTorch/XLA 1.7, then delete the setup script and the downloaded wheel files.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
!rm ./torch_xla-1.7-cp37-cp37m-linux_x86_64.whl
!rm ./torch-1.7-cp37-cp37m-linux_x86_64.whl
!rm ./torchvision-1.7-cp37-cp37m-linux_x86_64.whl
!rm ./pytorch-xla-env-setup.py

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

os.environ['XLA_USE_BF16'] = '1' #Makes PyTorch/XLA store torch.float and torch.double tensors as bfloat16 on the TPU.
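
Setting `XLA_USE_BF16` trades precision for speed and memory: bfloat16 keeps float32's 8-bit exponent (so roughly the same range) but only 7 of its 23 mantissa bits, i.e. about 2 to 3 significant decimal digits. The sketch below is a pure-Python emulation of float32-to-bfloat16 rounding, assuming round-to-nearest-even; the helper `to_bfloat16` is hypothetical and is not how PyTorch/XLA performs the conversion on the TPU.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x (representable in float32) to the nearest bfloat16 value.

    Emulation only: x is packed as float32 and the upper 16 bits
    (sign, 8 exponent bits, 7 mantissa bits) are kept after rounding.
    """
    if math.isnan(x):
        return x
    # float32 bit pattern of x as an unsigned 32-bit integer.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # Add just under half of the dropped range, plus 1 if the kept LSB is odd,
    # so that exact ties round to the even neighbour.
    bits += 0x7FFF + ((bits >> 16) & 1)
    # Zero the 16 low mantissa bits that bfloat16 does not store.
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


print(to_bfloat16(1 / 3))          # 0.333984375
print(to_bfloat16(1.0 + 2 ** -8))  # 1.0: the increment is below bfloat16's resolution near 1
print(to_bfloat16(1e38))           # still finite: same exponent range as float32
```

Small increments like `2 ** -8` relative to the value are rounded away, which is one reason bfloat16 training can look slightly noisier than float32 training.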

# 3. <a id='3'>Loading Data</a>
[Table of contents](#0.1)

In [None]:
path = '../input/celeba-dataset/img_align_celeba/img_align_celeba'
df = pd.read_csv('../input/celeba-dataset/list_attr_celeba.csv').iloc[:,0]

In [None]:
class ImageDataset(Dataset):
    
    def __init__(self,path,df,image_transforms):
        super().__init__()
        self.path = path
        self.df   = df
        self.image_transforms = image_transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        image = Image.open(os.path.join(self.path,self.df.iloc[idx]))
        if self.image_transforms:
            image = self.image_transforms(image)
            
        return image #Only the image is returned: this DCGAN is unconditional, so it needs no labels.

In [None]:
image_transforms = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #Maps each channel from [0, 1] to [-1, 1], matching the Generator's Tanh output range.
])
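
The size and value arithmetic of this pipeline can be checked by hand. The sketch below mirrors, as we understand it, how torchvision's `Resize` with an int size (shorter side becomes 64, aspect ratio kept) and `CenterCrop` compute their output, plus the per-channel `Normalize` formula `(v - mean) / std`. The helper names are hypothetical, and the 178 × 218 input size assumes the aligned CelebA images used here.

```python
def resized_size(width: int, height: int, size: int) -> tuple:
    """(width, height) after Resize(size): the shorter side becomes `size`."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_offsets(width: int, height: int, crop: int) -> tuple:
    """(left, top) corner of a centered crop x crop window."""
    return int(round((width - crop) / 2.0)), int(round((height - crop) / 2.0))


def normalize(v: float, mean: float = 0.5, std: float = 0.5) -> float:
    # What transforms.Normalize does to each channel value.
    return (v - mean) / std


def denormalize(v: float, mean: float = 0.5, std: float = 0.5) -> float:
    # Inverse mapping, e.g. to turn Generator outputs back into [0, 1].
    return v * std + mean


w, h = resized_size(178, 218, 64)
print((w, h))                          # (64, 78)
print(center_crop_offsets(w, h, 64))   # (0, 7): 7 rows trimmed top and bottom
print(normalize(0.0), normalize(1.0))  # -1.0 1.0
```

The inverse mapping is also why `save_image` is later called with `normalize=True`: it rescales the Tanh outputs in [-1, 1] back to a displayable range.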

# 4. <a id='4'>Build the GAN</a>
[Table of contents](#0.1)

In [None]:
#Generator outputs fake images given an input noise vector.
class Generator(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
                             
                    nn.ConvTranspose2d(100,64*8,4,1,0,bias=False),
                    nn.BatchNorm2d(64*8),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.ConvTranspose2d(64*8,64*4,4,2,1,bias=False),
                    nn.BatchNorm2d(64*4),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.ConvTranspose2d(64*4,64*2,4,2,1,bias=False),
                    nn.BatchNorm2d(64*2),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.ConvTranspose2d(64*2,64,4,2,1,bias=False),
                    nn.BatchNorm2d(64),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.ConvTranspose2d(64,3,4,2,1,bias=False),
                    nn.Tanh(),      
                )
    
    def forward(self,input): #Input is the Noise vector.
        return self.main(input)
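
Each `ConvTranspose2d` layer above upsamples the spatial size, taking the (N, 100, 1, 1) noise to a (N, 3, 64, 64) image while the channels go 100 → 512 → 256 → 128 → 64 → 3. The sketch below traces the spatial size using the output-size formula from the PyTorch `ConvTranspose2d` documentation; the helper `conv_transpose2d_out` is a hypothetical name for illustration.

```python
def conv_transpose2d_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    # Output size per spatial dimension, as given in the ConvTranspose2d docs.
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# (kernel, stride, padding) of the five ConvTranspose2d layers in Generator.
g_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector is shaped (N, 100, 1, 1)
g_sizes = []
for kernel, stride, padding in g_layers:
    size = conv_transpose2d_out(size, kernel, stride, padding)
    g_sizes.append(size)

print(g_sizes)  # [4, 8, 16, 32, 64]
```

The first layer (stride 1, padding 0) projects the 1 × 1 noise to a 4 × 4 map; every later layer with kernel 4, stride 2, padding 1 exactly doubles the size.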

In [None]:
#Discriminator classifies whether the given input images are Real or Fake.
class Discriminator(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
                    
                    nn.Conv2d(3,64,4,2,1,bias=False),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.Conv2d(64,64*2,4,2,1,bias=False),
                    nn.BatchNorm2d(64*2),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.Conv2d(64*2,64*4,4,2,1,bias=False),
                    nn.BatchNorm2d(64*4),
                    nn.LeakyReLU(0.01,inplace=True),
            
                    nn.Conv2d(64*4,64*8,4,2,1,bias=False),
                    nn.BatchNorm2d(64*8),
                    nn.LeakyReLU(0.01,inplace=True),
                     
                    nn.Conv2d(64*8,1,4,1,0,bias=False),
                    nn.Sigmoid()
               )
    
    def forward(self,input): #Input is the images.
        return self.main(input)
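
The Discriminator mirrors the Generator: each strided `Conv2d` halves the spatial size and the last one collapses 4 × 4 to a single score per image, so `D(imgs)` has shape (N, 1, 1, 1) and `.view(-1)` in the training loop flattens it to (N,) to match the labels. The sketch below traces this with the output-size formula from the PyTorch `Conv2d` documentation; `conv2d_out` is a hypothetical helper name.

```python
def conv2d_out(size, kernel, stride, padding, dilation=1):
    # Output size per spatial dimension, as given in the Conv2d docs.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (kernel, stride, padding) of the five Conv2d layers in Discriminator.
d_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are 64 x 64
d_sizes = []
for kernel, stride, padding in d_layers:
    size = conv2d_out(size, kernel, stride, padding)
    d_sizes.append(size)

print(d_sizes)  # [32, 16, 8, 4, 1]
```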

# 5. <a id='5'>Train the GAN</a>
[Table of contents](#0.1)

In [None]:
def reduce(values):
    '''    
    Returns the average of the values.
    Args:
        values : list of values, one calculated on each core.
    '''
    return sum(values) / len(values)
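
`reduce` averages one loss value per core. Because every core uses the same batch size, the average of the per-core batch means equals the mean over all samples of that step across the cores. The sketch below checks this with made-up loss values and 4 simulated cores instead of 8; it only models the arithmetic, not `xm.mesh_reduce` itself.

```python
import math


def reduce(values):
    # Same helper as above: plain average of one value per core.
    return sum(values) / len(values)


# Made-up per-sample losses for one step: 4 simulated cores x 2 samples each.
per_sample_losses = [0.9, 0.7, 0.4, 0.6, 0.2, 0.8, 0.5, 0.3]
num_cores = 4

# Equal-sized shards, as every core gets the same batch size.
shards = [per_sample_losses[rank::num_cores] for rank in range(num_cores)]

# With reduction='mean', each core's loss is the mean over its own batch.
per_core_means = [sum(shard) / len(shard) for shard in shards]

global_mean = sum(per_sample_losses) / len(per_sample_losses)
print(reduce(per_core_means), global_mean)  # equal up to floating-point rounding
```

If the shards had different sizes, the mean of means would weight small batches too heavily, so this equality relies on the equal per-core batch sizes.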

In [None]:
def train_one_epoch(dataloader,G,D,loss_fn,optimizer_G,optimizer_D,epoch_no,nb_epochs,fixed_noise,device):
    '''
    This function trains both the Generator and the Discriminator for one epoch.
    Args :
         dataloader - Per-device loader for iterating over this core's share of the data.
         G - Generator.
         D - Discriminator.
         loss_fn - Loss to optimize.
         optimizer_G - Optimizer for the Generator.
         optimizer_D - Optimizer for the Discriminator.
         epoch_no - Index of the current epoch (0-based).
         nb_epochs - Total number of epochs to run.
         fixed_noise - Generator input used to save sample images every 10 epochs.
         device - Device to train on.

    Returns : Nothing.
    '''

    for batch,imgs in enumerate(dataloader):

        real_imgs = imgs.to(device)
        bs = len(imgs) #Batch size.
        real_labels = torch.ones((bs,),dtype=torch.float,device=device)
        fake_labels = torch.zeros((bs,),dtype=torch.float,device=device)

        #Update the Discriminator.
        noise = torch.rand(bs,100,1,1,device=device) #Noise for the Generator's input.
        fake_imgs = G(noise) #Generated images.

        real_loss = loss_fn(D(real_imgs).view(-1),real_labels)
        fake_loss = loss_fn(D(fake_imgs.detach()).view(-1),fake_labels) #detach() keeps this loss from backpropagating into G.
        loss_D = (real_loss+fake_loss)/2 #Discriminator loss.

        optimizer_D.zero_grad()
        loss_D.backward()
        xm.optimizer_step(optimizer_D) #All-reduces the gradients across the cores, then steps the optimizer.

        #Update the Generator.
        loss_G = loss_fn(D(fake_imgs).view(-1),real_labels) #Generator loss: G wants D to label its fakes as real.

        optimizer_G.zero_grad()
        loss_G.backward()
        xm.optimizer_step(optimizer_G)

        if ((batch+1)%50 == 0):
            #Every core must reach these calls, and each reduction needs its own tag.
            reduced_loss_D = xm.mesh_reduce('reduced_loss_D',loss_D.item(),reduce)
            reduced_loss_G = xm.mesh_reduce('reduced_loss_G',loss_G.item(),reduce)
            xm.master_print(f'Epoch[{epoch_no+1}/{nb_epochs}] Batch[{(batch+1)}/{len(dataloader)}] Discriminator Loss:{reduced_loss_D:.7f} Generator Loss:{reduced_loss_G:.7f}')

    if((epoch_no+1)%10 == 0 and xm.is_master_ordinal()):
        #Only the master process writes the file, so the 8 processes don't race on the same path.
        with torch.no_grad():
            torchvision.utils.save_image(G(fixed_noise),f'fake_images_after_{epoch_no+1}.jpg',nrow=8,normalize=True)
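
To read the printed losses, it helps to evaluate the same binary cross-entropy by hand. The sketch below reproduces, for a single sample, how `loss_D` and `loss_G` are formed from `nn.BCELoss`, including BCELoss's clamping of log terms at -100. The helper names are hypothetical. As a rough guide, both losses near ln 2 ≈ 0.693 mean D outputs about 0.5 for everything, while `loss_D` heading towards 0 with a growing `loss_G` often means D is winning.

```python
import math


def clamped_log(x):
    # nn.BCELoss clamps its log terms at -100 so that log(0) stays finite.
    return max(math.log(x), -100.0) if x > 0 else -100.0


def bce(p, target):
    # Binary cross-entropy of one prediction p in [0, 1] against target 0 or 1.
    return -(target * clamped_log(p) + (1 - target) * clamped_log(1 - p))


def discriminator_loss(d_real, d_fake):
    # Mirrors loss_D = (real_loss + fake_loss) / 2 for one real and one fake image.
    return (bce(d_real, 1.0) + bce(d_fake, 0.0)) / 2


def generator_loss(d_fake):
    # Mirrors loss_G: fakes scored against real labels, i.e. -log D(G(z)).
    return bce(d_fake, 1.0)


print(discriminator_loss(0.5, 0.5))  # 0.693... = ln 2
print(generator_loss(0.5))           # 0.693... = ln 2
print(generator_loss(0.01))          # about 4.6: D confidently rejects the fake
```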

In [None]:
def _mp_fn(rank,flags):
    '''
    This function is executed on all the devices when it is spawned.
    Args :
        rank  - Index of the process.
        flags - Arguments you need to pass to each process.
    '''
    device = xm.xla_device()
    #Creates the (distributed) train sampler, which lets this process access only its portion of the training data.
    data_sampler = DistributedSampler(dataset=flags['DS'],
                                      num_replicas=xm.xrt_world_size(),
                                      rank=xm.get_ordinal(),
                                      shuffle=True)
    data_loader = DataLoader(dataset=flags['DS'],
                             batch_size=flags['BS'],
                             sampler=data_sampler,
                             num_workers=0)
    del data_sampler
    gc.collect()
    
    G = flags['Generator'].to(device)
    D = flags['Discriminator'].to(device)
    fixed_noise = flags['fixed_noise'].to(device)
    loss_fn = nn.BCELoss()
    optimizer_G = torch.optim.Adam(G.parameters(),lr=flags['lr'],betas=[0.5,0.999]) #Optimizer for Generator.
    optimizer_D = torch.optim.Adam(D.parameters(),lr=flags['lr'],betas=[0.5,0.999]) #Optimizer for Discriminator.
    
    xm.master_print('Training has started\n')
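    for epoch in range(flags['nb_epochs']):
        data_loader.sampler.set_epoch(epoch) #Reseeds the DistributedSampler; without this, every epoch uses the same shuffle order.
        # ParallelLoader preloads batches onto this process's TPU core; per_device_loader(device) returns the iterator for that core.
        parallel_loader = pl.ParallelLoader(data_loader,[device]).per_device_loader(device)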
        train_one_epoch(parallel_loader,G,D,loss_fn,optimizer_G,optimizer_D,epoch,flags['nb_epochs'],fixed_noise,device)
        xm.master_print(f'Epoch[{epoch+1}/{flags["nb_epochs"]}] has completed\n')
        del parallel_loader
        gc.collect()
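
`DistributedSampler` decides which images each core sees: it shuffles the indices with a seed derived from the epoch, pads the list so it divides evenly by the number of replicas, and gives rank `r` every `num_replicas`-th index starting at `r`. It only reshuffles when its epoch changes via `set_epoch(epoch)`. The sketch below emulates that logic for `drop_last=False`; `shard_indices` is a hypothetical helper, and it uses `random.Random` instead of `torch.randperm`, so the actual permutation differs from PyTorch's.

```python
import math
import random


def shard_indices(n, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Indices one rank would receive (emulation; assumes n >= num_replicas)."""
    indices = list(range(n))
    if shuffle:
        # The shuffle depends on seed + epoch, so a fixed epoch repeats the order.
        random.Random(seed + epoch).shuffle(indices)
    num_samples = math.ceil(n / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by repeating the first indices so every rank gets the same count.
    indices += indices[: total_size - len(indices)]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


shards = [shard_indices(10, 4, r, epoch=0) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 3, 3]: 10 samples padded to 12
```

For CelebA's 202,599 images on 8 cores, `math.ceil(202599 / 8)` gives 25,325 samples per core per epoch, i.e. 198 batches of 128 with the last one partial.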

In [None]:
dataset = ImageDataset(path,df,image_transforms)

G = Generator()
G.float()
D = Discriminator()
D.float()

fixed_noise = torch.rand((32, 100, 1, 1),dtype=torch.float)
flags = {'DS':dataset,
         'Generator': G,
         'Discriminator':D,
         'BS':128, #Per-core batch size; the 8 cores together process 8*128 = 1024 images per step.
         'lr':0.0002,
         'nb_epochs':50,
         'fixed_noise':fixed_noise
        }

xmp.spawn(fn=_mp_fn,args=(flags,),nprocs=8,start_method='fork')

In [None]:
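#To save the trained Generator, add the line below at the end of _mp_fn (xm.save writes only from the master process by default).
#Running it here would save untrained weights: with start_method='fork', each process trains its own copy of G, so this G is never updated.
#xm.save(G.state_dict(),'generator_weights.pth')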
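
Why the parent's `G` stays untrained can be seen without a TPU: a forked child works on its own copy-on-write copy of the parent's memory, so updates made in the child never reach the parent. The sketch below is a POSIX-only illustration using `os.fork` directly; the `weights` dict is a hypothetical stand-in for `G.state_dict()`, not how `xmp.spawn` is implemented.

```python
import os

# Stand-in for the Generator weights held by the parent (notebook) process.
weights = {"conv1": 0.0}

pid = os.fork()  # POSIX only, like start_method='fork'
if pid == 0:
    # Child: "train" by updating its own copy-on-write copy of the dict.
    weights["conv1"] = 1.23
    os._exit(0)  # exit the child immediately, skipping the parent's code below

os.waitpid(pid, 0)  # parent waits for the child to finish
print(weights["conv1"])  # 0.0: the parent's copy never saw the update
```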

In the Output section, you can view the images generated after every 10 epochs.
Even though this is a plain DCGAN, the Generator produces some decent images that resemble the training data.

If you want to generate high-resolution images, take a look at other GAN variants such as StyleGAN.

# 6. <a id='6'>References</a>
[Table of contents](#0.1)

1. Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks - https://arxiv.org/pdf/1511.06434.pdf 
1. DCGAN Tutorial - https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
1. PyTorch Implementations of Generative Adversarial Networks - https://github.com/eriklindernoren/PyTorch-GAN
1. [FoldTraining] PyTorch-TPU🔥-8-Cores - https://www.kaggle.com/joshi98kishan/foldtraining-pytorch-tpu-8-cores