### Check TPU is available

The cell below makes sure you have access to a TPU on Kaggle.

In [None]:
import tensorflow as tf

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()


REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')
AUTO = tf.data.experimental.AUTOTUNE

### Generative Adversarial Networks (GANs)
In their landmark 2014 paper, [Goodfellow et al.](https://arxiv.org/abs/1406.2661) introduced a novel paradigm for generative models. The fundamental idea is to train a generator network in an adversarial setup, in which a downstream discriminator network critiques the generated samples.

Simply put, the generator network produces a sample and the discriminator network classifies it as real or fake. The discriminator is also shown real samples. The objective function takes the following form:

$$\underset{G}{\text{minimize}}\; \underset{D}{\text{maximize}}\; \mathbb{E}_{x \sim p_\text{data}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log \left(1-D(G(z))\right)\right]$$
where $x \sim p_\text{data}$ are samples from the input data, $z \sim p(z)$ are random noise samples, $G(z)$ are the images generated by the generator network $G$, and $D$ is the output of the discriminator, which gives the probability that an input is real.
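
To make the objective concrete, the sketch below evaluates it on a handful of made-up discriminator outputs. The function name `gan_value` and the sample probabilities are illustrative assumptions, not part of this notebook's training code, and the expectations are approximated by simple averages.

```python
import math

def gan_value(d_real, d_fake):
    """Estimate E[log D(x)] + E[log(1 - D(G(z)))] from sample averages.

    d_real: discriminator outputs D(x) on real samples (probabilities).
    d_fake: discriminator outputs D(G(z)) on generated samples.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
confused = gan_value([0.5, 0.5], [0.5, 0.5])

# A confident, correct discriminator pushes the value toward 0, its upper bound.
confident = gan_value([0.99, 0.98], [0.01, 0.02])

print(f"confused:  {confused:.4f}")   # equals -log(4)
print(f"confident: {confident:.4f}")
```

The discriminator tries to move this value up toward 0, while the generator tries to pull it back down by making its samples harder to reject.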

### Training Setup

This example is **part 1 of a 3-part series** on **training a GAN on TPU using the Torch XLA package**. This notebook illustrates distributed (data-parallel) training of a DC-GAN model on the MNIST dataset using a TPU device. **The notebook will lay the foundations for our Cycle GAN training on TPU and more**. A TPU device consists of 4 chips (8 cores; 2 cores/chip). A replica of both the discriminator and the generator is created on each of the 8 cores. The dataset is split across the 8 cores.
At every training step, each core performs the forward pass (loss computation) and the backward pass (gradient computation) on its minibatch, and then an all-reduce is performed across the TPU cores to update the parameters. Notice the `xm.optimizer_step` call in the discriminator and generator train steps.
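
The data-parallel step can be illustrated without a TPU. The sketch below simulates 8 replicas of a one-parameter model, each computing a gradient on its own shard, and then averages the gradients the way an all-reduce would. It is a plain-Python illustration that assumes gradients are averaged across replicas; all names are hypothetical and none of this is the `torch_xla` API.

```python
def local_gradient(w, shard):
    """Gradient of the mean of 0.5 * (w*x - y)^2 over one shard."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for an all-reduce: every replica receives the same average."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

num_cores = 8
w = 0.0
# Toy dataset y = 3x, split evenly so that each core holds 4 samples.
data = [(float(x), 3.0 * x) for x in range(1, 33)]
shards = [data[i::num_cores] for i in range(num_cores)]

# Each replica runs forward/backward on its own minibatch.
grads = [local_gradient(w, shard) for shard in shards]
# After the all-reduce, every replica holds the identical averaged gradient.
synced = all_reduce_mean(grads)

# With equal shard sizes, the average of the per-shard gradients equals the
# gradient over the full batch, so all replicas apply the same update.
full_batch_grad = local_gradient(w, data)
w_per_core = [w - 0.001 * g for g in synced]
print(synced[0], full_batch_grad)
```

Because every replica applies the same averaged gradient, the model copies on the 8 cores stay identical after each step.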

General GAN training looks like:

* update the **generator** ($G$) to minimize the probability of the **discriminator making the correct choice**.
* update the **discriminator** ($D$) to maximize the probability of the **discriminator making the correct choice**.

We will use a different objective when we update the generator: maximize the probability of the **discriminator making the incorrect choice**. This small change helps to alleviate the problem of the generator gradient vanishing when the discriminator is confident. This is the standard update used in most GAN papers, and it was used in the original paper from [Goodfellow et al.](https://arxiv.org/abs/1406.2661)
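
The vanishing-gradient argument can be checked numerically. In the sketch below, `a` is assumed to be the discriminator's pre-sigmoid output (logit) on a generated sample, so that $D(G(z)) = \sigma(a)$. The function names are illustrative, and the derivatives are taken with respect to that logit.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    """d/da of log(1 - sigmoid(a)), the original minimax generator loss."""
    return -sigmoid(a)

def non_saturating_grad(a):
    """d/da of -log(sigmoid(a)), the alternative generator loss."""
    return -(1.0 - sigmoid(a))

# A very negative logit means the discriminator is confident the sample is fake.
for a in (-6.0, -2.0, 0.0):
    print(f"logit={a:+.1f}  D(G(z))={sigmoid(a):.4f}  "
          f"saturating={saturating_grad(a):+.4f}  "
          f"non-saturating={non_saturating_grad(a):+.4f}")
```

When the discriminator confidently rejects a sample, the minimax loss gives the generator almost no signal, whereas the alternative loss keeps the gradient magnitude close to 1.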

Therefore, the training loop in this notebook will:

* Update the generator ($G$) to maximize the probability of the discriminator making the incorrect choice on generated data: $$\underset{G}{\text{maximize}}\;  \mathbb{E}_{z \sim p(z)}\left[\log D(G(z))\right]$$

* Update the discriminator ($D$) to maximize the probability of the discriminator making the correct choice on real and generated data: $$\underset{D}{\text{maximize}}\; \mathbb{E}_{x \sim p_\text{data}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log \left(1-D(G(z))\right)\right]$$
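
In code, both updates are commonly written as binary cross-entropy (BCE) minimization, which is what the training cell later in this notebook does with `nn.BCELoss`. The plain-Python sketch below shows that BCE with the right targets is the negative of each objective. The helper `bce` is a hypothetical stand-in rather than the PyTorch implementation, and the probabilities are made up.

```python
import math

def bce(predictions, targets):
    """Mean binary cross-entropy: -[t*log(p) + (1-t)*log(1-p)]."""
    total = 0.0
    for p, t in zip(predictions, targets):
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(predictions)

d_real = [0.9, 0.8, 0.7]   # hypothetical D(x) on real samples
d_fake = [0.2, 0.1, 0.4]   # hypothetical D(G(z)) on generated samples

# Generator step: label generated samples as real (target 1).
g_loss = bce(d_fake, [1, 1, 1])
g_objective = sum(math.log(p) for p in d_fake) / len(d_fake)

# Discriminator step: real samples get target 1, generated ones target 0.
d_loss = bce(d_real, [1, 1, 1]) + bce(d_fake, [0, 0, 0])
d_objective = (sum(math.log(p) for p in d_real) / len(d_real)
               + sum(math.log(1 - p) for p in d_fake) / len(d_fake))

print(g_loss, -g_objective)   # the two numbers agree
print(d_loss, -d_objective)   # the two numbers agree
```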

### Deep Convolutional Generative Adversarial Networks

DCGAN is one of the most popular and successful network designs for GANs. It consists mainly of convolution layers, without max pooling or fully connected layers. It uses strided convolutions for downsampling and transposed convolutions for upsampling. The figure below shows the network design for the generator and discriminator:

![DC-GAN](https://gluon.mxnet.io/_images/dcgan.png)

Here is the summary of DCGAN:
   * Replace all `max pooling with convolutional stride`
   * Use `transposed convolution` for upsampling.
   * `Eliminate fully connected` layers.
   * Use `Batch normalization` everywhere except the output layer of the generator and the input layer of the discriminator.
   * Use `ReLU in the generator` except for the output which uses tanh.
   * Use `LeakyReLU in the discriminator`.
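
The first two guidelines come down to simple size arithmetic. The sketch below applies the standard output-size formulas for a convolution and a transposed convolution (without dilation or output padding). The helper names are hypothetical; the kernel, stride and padding values are the ones used by the generator defined later in this notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output length of a transposed convolution (no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# A stride-2 convolution halves the resolution, taking the place of max pooling.
down = conv_out(28, kernel=4, stride=2, padding=1)

# Transposed convolutions with the same settings double it again.
up1 = conv_transpose_out(7, kernel=4, stride=2, padding=1)
up2 = conv_transpose_out(up1, kernel=4, stride=2, padding=1)
print(down, up1, up2)  # 14 14 28
```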
   
Here are the tuning tips, quoted directly from the paper.
> All models were trained with mini-batch stochastic gradient descent (SGD) with a mini-batch size of 128. All weights were initialized from a zero-centered Normal distribution with standard deviation 0.02. In the LeakyReLU, the slope of the leak was set to 0.2 in all models. While previous GAN work has used momentum to accelerate training, we used the Adam optimizer with tuned hyperparameters. We found the suggested learning rate of 0.001, to be too high, using 0.0002 instead. Additionally, we found leaving the momentum term β1 at the suggested value of 0.9 resulted in training oscillation and instability while reducing it to 0.5 helped stabilize training.

The simplicity of DCGAN contributes to its success. Beyond a certain point, increasing the complexity of the generator does not necessarily improve image quality. Until that bottleneck is identified and we know how to train GANs more effectively, DCGAN remains a good starting point for a new project.
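
The tuning tips quoted above translate into PyTorch roughly as follows. This is only a sketch: `toy_net` is a hypothetical stand-in model, and this notebook itself uses different learning rates and a different leak slope in its own cells.

```python
import torch
from torch import nn

# Hypothetical stand-in network; it exists only so the optimizer has parameters.
toy_net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),  # leak slope of 0.2, as in the quote
)

# Zero-centered Normal initialization with standard deviation 0.02.
for module in toy_net.modules():
    if isinstance(module, nn.Conv2d):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

# Adam with learning rate 0.0002 and beta1 lowered from 0.9 to 0.5.
optimizer = torch.optim.Adam(toy_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
print(optimizer.defaults["lr"], optimizer.defaults["betas"])
```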

### Setup Dependencies

#### Download Torch XLA nightly release

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly  --apt-packages libomp5 libopenblas-dev

### Imports

In [None]:
import torch
from torch import nn, optim
from torchvision import transforms, datasets
from torch.optim import Adam
import torch.nn.functional as F

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

### Setting up the Global Flags

In the current setup, the discriminator network was chosen to have a smaller capacity than the generator. Even with networks of similar capacity, the generator's update path is deeper than the discriminator's, because generator gradients must flow back through the discriminator. The uneven learning rates chosen here therefore seem to yield better convergence.

In [None]:
# Define Parameters
FLAGS = {}
FLAGS['datadir'] = "/tmp/mnist"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['gen_learning_rate'] = 0.005
FLAGS['disc_learning_rate'] = 0.001
FLAGS['num_epochs'] = 30
FLAGS['num_cores'] = 8  
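
As a quick sanity check on these flags, the arithmetic below estimates how much data each core sees per epoch. It assumes the 60,000-image MNIST training split, an even split across cores, and `drop_last=True`, as in the data loader defined later. The variable names are illustrative.

```python
import math

batch_size = 128           # FLAGS['batch_size'], applied per core
num_cores = 8              # FLAGS['num_cores']
num_train_images = 60000   # size of the MNIST training split

# The distributed sampler gives each replica an equal share of the dataset.
samples_per_core = math.ceil(num_train_images / num_cores)
# With drop_last=True, the final partial batch on each core is discarded.
steps_per_epoch = samples_per_core // batch_size
# The batch size is per core, so one step consumes this many images overall.
global_batch_size = batch_size * num_cores

print(samples_per_core, steps_per_epoch, global_batch_size)  # 7500 58 1024
```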

### Data & Image Utilities

In [None]:
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt
from IPython import display 
import cv2
    
RESULT_IMG_PATH = '/tmp/test_result.png'

def plot_results(*images):
    """Plot the given images on a 4-row grid and save the figure to disk."""
    num_images = len(images)
    n_rows = 4
    # Round up so every image gets a cell, even if the count is not a multiple of 4.
    n_columns = max(1, -(-num_images // n_rows))
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(30, 18))

    for i, ax in enumerate(fig.axes):
        ax.axis('off')
        if i >= num_images:
            continue
        img = images[i]
        img = img.squeeze()  # [1,Y,X] -> [Y,X]
        ax.imshow(img)
    plt.savefig(RESULT_IMG_PATH, transparent=True)

def display_results():
    img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)
    plt.figure(figsize=(30,18))
    plt.imshow(img)

In [None]:
def mnist_data():
    compose = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    out_dir = '{}/dataset'.format(FLAGS['datadir'])
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)
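
The `Normalize([0.5], [0.5])` transform rescales the pixels so that real images share the $[-1, 1]$ range of the generator's `tanh` output. A minimal sketch of the per-pixel arithmetic, using a hypothetical `normalize` helper:

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel operation applied by transforms.Normalize: (x - mean) / std."""
    return (pixel - mean) / std

# ToTensor scales pixel intensities to [0, 1]; Normalize then maps them to [-1, 1].
print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```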

### Discriminator Model

In [None]:
class DiscriminativeNet(torch.nn.Module):
    
    def __init__(self):
        super(DiscriminativeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(4*4*64, 1)

    def forward(self, x):
        x = F.leaky_relu(F.max_pool2d(self.conv1(x), 2), 0.01)
        x = self.bn1(x)
        x = F.leaky_relu(F.max_pool2d(self.conv2(x), 2), 0.01)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = F.leaky_relu(self.fc1(x), 0.01)
        return torch.sigmoid(x)            
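
The `4*4*64` input size of `fc1` follows from the layer arithmetic. The sketch below traces the spatial size of a 28x28 MNIST image through the two convolution and pooling stages of `DiscriminativeNet`. The helper functions are hypothetical and only mirror the standard output-size formulas.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def max_pool_out(size, window):
    """Non-overlapping max pooling, where the stride equals the window size."""
    return size // window

size = 28                         # MNIST images are 28x28
size = conv_out(size, kernel=5)   # conv1 -> 24
size = max_pool_out(size, 2)      # pool  -> 12
size = conv_out(size, kernel=5)   # conv2 -> 8
size = max_pool_out(size, 2)      # pool  -> 4
flattened = size * size * 64      # 64 channels after conv2
print(size, flattened)            # 4 1024
```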
        

### Generator Model

In [None]:
class GenerativeNet(torch.nn.Module):
    
    def __init__(self):
        super(GenerativeNet, self).__init__()
        self.input_size = 100
        self.linear1 = nn.Linear(self.input_size, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.linear2 = nn.Linear(1024, 7*7*128)
        self.bn2 = nn.BatchNorm1d(7*7*128)
        self.conv1 = nn.ConvTranspose2d(
            in_channels=128, 
            out_channels=64, 
            kernel_size=4,
            stride=2, 
            padding=1, 
            bias=False
        )
        self.bn3 = nn.BatchNorm2d(64)
        self.conv2 = nn.ConvTranspose2d(
            in_channels=64, 
            out_channels=1, 
            kernel_size=4,
            stride=2, 
            padding=1, 
            bias=False
        )

    # Noise
    def generate_noise(self, size):
        n = torch.randn(size, self.input_size)
        return n 
              
    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn3(x)
        x = self.conv2(x)
        x = torch.tanh(x)
        return x

In [None]:
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('BatchNorm') != -1:
        m.weight.data.normal_(0.00, 0.02)

In [None]:
def real_data_target(size, device):
    data = torch.ones(size, 1)
    return data.to(device)

def fake_data_target(size, device):
    data = torch.zeros(size, 1)
    return data.to(device)

### Note on the use of .detach() function

You will notice in the following code that the generator output `fake_data` is passed through `.detach()` for the discriminator training step. **The `.detach()` call creates a new view of the `fake_data` tensor for which operations are not recorded for gradient computation.**

* Since `fake_data` is the output of an `nn.Module`, PyTorch by default records all operations performed on this tensor during the forward pass as a DAG. After the backward pass, this DAG and its saved buffers are freed (unless `retain_graph=True`). Such a tensor can therefore take part in only one forward and backward pass. If the tensor is used in two loss functions and `backward()` is called on each of them in turn, the second backward pass will not find the graph, leading to an error. Detaching `fake_data` in the discriminator step also keeps the discriminator loss from backpropagating into the generator.

* The second place where `.detach()` is used is before a `numpy()` call on a tensor (for plotting purposes). PyTorch requires that `requires_grad` is not true on such a tensor. (Ref: `RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.`)
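
The effect of `.detach()` can be seen on a toy example. The sketch below uses two hypothetical one-layer stand-ins (`toy_generator`, `toy_discriminator`) instead of the real networks, and a plain sum instead of the BCE loss.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical one-layer stand-ins for the generator and discriminator.
toy_generator = nn.Linear(4, 3)
toy_discriminator = nn.Linear(3, 1)

fake_data = toy_generator(torch.randn(2, 4))

# Discriminator step: detach() cuts the graph, so no gradient reaches the generator.
toy_discriminator(fake_data.detach()).sum().backward()
generator_grad_after_d_step = toy_generator.weight.grad   # still None

# Generator step: without detach(), gradients flow back through the discriminator.
toy_discriminator(fake_data).sum().backward()
generator_grad_after_g_step = toy_generator.weight.grad   # now a tensor

# The detached tensor no longer requires grad, which is what numpy() needs.
print(generator_grad_after_d_step is None,
      generator_grad_after_g_step is not None,
      fake_data.requires_grad,
      fake_data.detach().requires_grad)
```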


### Training 

In [None]:
# Runs a function in one process at a time (used below for the dataset download).
SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
generator = GenerativeNet()
generator.apply(init_weights)
discriminator = DiscriminativeNet()
discriminator.apply(init_weights)
WRAPPED_GENERATOR = xmp.MpModelWrapper(generator)
WRAPPED_DISCRIMINATOR = xmp.MpModelWrapper(discriminator)

In [None]:
def train_gan(rank):
    torch.manual_seed(1) 
    data = SERIAL_EXEC.run(lambda: mnist_data())
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    

    # Create loader with data, so that we can iterate over it
    train_loader = torch.utils.data.DataLoader(
      data,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

    # Num batches
    num_batches = len(train_loader)
    
    device = xm.xla_device()
    
    generator = WRAPPED_GENERATOR.to(device)
    discriminator = WRAPPED_DISCRIMINATOR.to(device)
   
    
    # Optimizers
    d_optimizer = Adam(discriminator.parameters(), lr=FLAGS['disc_learning_rate'], betas=(0.5, 0.999))
    g_optimizer = Adam(generator.parameters(), lr=FLAGS['gen_learning_rate'], betas=(0.5, 0.999))

    # Number of epochs
    num_epochs = FLAGS['num_epochs'] 
    # Loss function
    loss = nn.BCELoss()
    

    def train_step_discriminator(optimizer, real_data, fake_data, device):         
        # Reset gradients
        optimizer.zero_grad()

        # 1. Score the real data
        prediction_real = discriminator(real_data)
        # Error against the "real" target (all ones)
        error_real = loss(prediction_real, real_data_target(real_data.size(0), device))

        # 2. Score the fake data
        prediction_fake = discriminator(fake_data)
        # Error against the "fake" target (all zeros)
        error_fake = loss(prediction_fake, fake_data_target(fake_data.size(0), device))

        # Backpropagate the combined error in a single pass
        total_error = error_real + error_fake
        total_error.backward()

        # Update weights with gradients
        xm.optimizer_step(optimizer)

        return total_error, prediction_real, prediction_fake

    def train_step_generator(optimizer, fake_data, device):
        # Reset gradients
        optimizer.zero_grad()
        prediction = discriminator(fake_data)
        # Calculate error and backpropagate
        error = loss(prediction, real_data_target(prediction.size(0), device))
        error.backward()
        # Update weights with gradients
        xm.optimizer_step(optimizer)

        # Return error
        return error

    # Notice the use of .detach() when fake_data is passed into the discriminator
    def train_loop_fn(loader):
        for n_batch, (real_batch, _) in enumerate(loader):
            # Train step: discriminator
            real_data = real_batch.to(device)
            # Sample noise and generate fake data
            noise = generator.generate_noise(real_data.size(0)).to(device)
            fake_data = generator(noise)
            d_error, d_pred_real, d_pred_fake = train_step_discriminator(
                d_optimizer, real_data, fake_data.detach(), device)

            # Train step: generator
            noise = generator.generate_noise(real_data.size(0)).to(device)
            fake_data = generator(noise)
            g_error = train_step_generator(g_optimizer, fake_data, device)
        # Report the errors from the last batch of the epoch
        return d_error.item(), g_error.item()


    for epoch in range(1, FLAGS['num_epochs'] + 1):
        # Reshuffle the sampler so each epoch sees a different ordering.
        train_sampler.set_epoch(epoch)
        d_error, g_error = train_loop_fn(pl.MpDeviceLoader(train_loader, device))
        xm.master_print("Finished training epoch {}: D_error: {}, G_error: {}".format(epoch, d_error, g_error))
        
        if epoch == FLAGS['num_epochs']:
            xm.master_print('Saving Model ..')
            xm.save(generator.state_dict(), "generator.bin")
            xm.save(discriminator.state_dict(), "discriminator.bin")
            xm.master_print('Model Saved.')
          
    num_test_samples = 100
    test_noise = generator.generate_noise(num_test_samples).to(device)
    xm.do_on_ordinals(plot_results, generator(test_noise).detach(), (0,))

In [None]:
# Start training processes
def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    train_gan(rank)

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')

In [None]:
display_results()

### References

* This notebook is inspired by [Training DC-GAN using Colab Cloud TPU](https://github.com/pytorch/xla/blob/master/contrib/colab/DC-GAN.ipynb).

* The [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/pdf/1511.06434.pdf) paper.

* [[TPU Training] PyTorch nlp XLMRoberta](https://www.kaggle.com/rhtsingh/tpu-training-pytorch-nlp-xlmroberta)

##### More to come, stay tuned.