<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/gan/gan-mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# Generative Adversarial Networks for MNIST

In [None]:
%%capture
!pip install pytorch-lightning==1.3.8 torchviz wandb
!git clone https://github.com/wandb/lit_utils
!cd "/content/lit_utils" && git pull

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn
import torch.nn.functional as F
import torchvision
import wandb

import lit_utils as lu

# remove slow mirror from list of MNIST mirrors
torchvision.datasets.MNIST.mirrors = lu.datamodules.ClassificationMNIST.mirrors

lu.utils.filter_warnings()
lu.datamodules.mnist.reverse_palette = lambda img: img  # work-around to turn off palette-reversion

# GAN Module


In [None]:
class LitGAN(lu.nn.modules.LoggedLitModule):
  """A basic image GAN in PyTorch Lightning.

  Also includes some under-the-hood Weights & Biases logging.

  Instantiates a generator and discriminator based on the provided config dictionary,
  defines separate optimizers for each, and defines logging, loss, and forward
  pass logic.

  NOTE: training_step is defined a few cells down, rather than inside the class.
  """

  def __init__(self, config): 
    super().__init__()

    # hyperparameters
    self.image_size = config["image_size"]
    self.latent_dim = config["latent_dim"]
    self.loss = config["loss"]

    # networks
    self.generator = Generator(config)
    self.generator_optim_config = config["generator.optim"]
    self.discriminator = Discriminator(config)
    self.discriminator_optim_config = config["discriminator.optim"]

    # for logging purposes
    self.log_interval = config["log_interval"]
    self.logged_images = 8
    self.logged_metadata = False
    # keep a fixed set of locations in latent space around for logging
    self.validation_z = torch.randn(self.logged_images, self.latent_dim)

    # metrics for GAN training
    #   on fake, how often is discriminator fooled?
    self.generator_win_percentage = pl.metrics.Accuracy()  
    #   on even mix of fake+real, how often is discriminator correct?
    self.discriminator_win_percentage = pl.metrics.Accuracy()

  def forward(self, z):
    return self.discriminator(self.generator(z))

  def adversarial_loss(self, y_hat, y):
    return self.loss(y_hat, y)

  # for a GAN, we need two optimizers: one for generator, one for discriminator
  def configure_optimizers(self):
    generator_optimizer = self.optim_from_config(self.generator_optim_config, self.generator.parameters())
    discriminator_optimizer = self.optim_from_config(self.discriminator_optim_config, self.discriminator.parameters())
    return [generator_optimizer, discriminator_optimizer], []

  # defined below in a different cell
  def training_step(self):
    pass

  ##
  # Logging code
  ##

  # on each epoch, log images and some image statistics
  def on_epoch_end(self):
    # Grab some images from the set sampled during training to log
    sample_imgs = self.generated_images[:self.logged_images].detach()
    # Turn them into a nice grid of images for logging
    sampled_grid = torchvision.utils.make_grid(sample_imgs, nrow=1, value_range=(0, 1), pad_value=0.5)
    # Across the sample, what are the means and variances at each pixel?
    sampled_mean, sampled_var = torch.mean(sample_imgs, dim=0), torch.var(sample_imgs, dim=0)

    # Check the outputs at a fixed set of positions in the latent space
    z = self.validation_z.type_as(self.generator.layers[0].weight)
    valid_imgs = self.generator(z).detach()
    # Turn them into a nice grid of images for logging
    valid_grid = torchvision.utils.make_grid(valid_imgs, nrow=1, value_range=(0, 1), pad_value=0.5)
    # Across the sample, what are the means and variances at each pixel?
    valid_mean, valid_var = torch.mean(valid_imgs, dim=0), torch.var(valid_imgs, dim=0)

    # Log everything to W&B
    self.logger.experiment.log({"image/sampled_images": wandb.Image(sampled_grid),
                                "image/validation_images": wandb.Image(valid_grid),
                                "image/sampled_image_mean": wandb.Image(sampled_mean),
                                "image/sampled_image_var": wandb.Histogram(sampled_var.cpu()),
                                "image/valid_image_mean": wandb.Image(valid_mean),
                                "image/valid_image_var": wandb.Histogram(valid_var.cpu()),
                                "trainer/epoch": self.current_epoch})

    # Log metadata to W&B
    if not self.logged_metadata:
      self.max_logged_images = 0  # deactivate automated logging
      self.do_logging(sample_imgs, None, 0, self.discriminator_outputs, {}, step="training")
      self.logged_metadata = True
      # use self rather than the notebook-level `gan` variable, so the module works under any name
      wandb.run.config["generator_nparams"] = lu.callbacks.count_params(self.generator)
      wandb.run.config["discriminator_nparams"] = lu.callbacks.count_params(self.discriminator)

  # on each training step (defined below),
  #  log these quantities and report them to W&B, with averages over epochs
  def training_step_end(self, metrics):
    if "g_loss" in metrics.keys():
      loss = metrics["g_loss"]
      win_perc = metrics["generator_win_percentage"]
      prefix = "train/generator"
    else:
      loss = metrics["d_loss"]
      win_perc = metrics["discriminator_win_percentage"]
      prefix = "train/discriminator"
  
    batch_idx = metrics["batch_idx"]
    if not batch_idx % self.log_interval:
      self.log_dict({prefix + "/loss/batch": loss, prefix + "/win_perc/batch": win_perc},
                    on_epoch=False, on_step=True)
    else:
      self.log_dict({prefix + "/loss/epoch": loss, prefix + "/win_perc/epoch": win_perc},
                    on_epoch=True, on_step=False)
  
    return metrics

  @staticmethod
  def optim_from_config(config, parameters):
    optimizer = config["optimizer"](parameters, **config["optimizer.params"])
    return optimizer

# minor implementation detail, due to Python inheritance
try:
  # we don't have a separate validation_step, so need to remove
  del lu.nn.modules.LoggedLitModule.validation_step
  del lu.nn.modules.LoggedLitModule.test_step
except AttributeError:
  pass

# Generator


In [None]:
class Generator(pl.LightningModule):
  """Generator module for an image GAN in PyTorch Lightning.

  .forward takes a batch of vectors with dimension config["latent_dim"]
  as input and returns images of size config["image_size"] as output.

  Try defining different .block methods or changing the hyperparameters
  of the blocks.
  """

  def __init__(self, config):
    super().__init__()

    self.latent_dim = config["latent_dim"]
    self.image_size = config["image_size"]
    self.activation = config["activation"]
    self.normalize = config["normalize"]


    self.layers = torch.nn.Sequential(
      *self.block(self.latent_dim, 128, self.activation),
      *self.block(128, 256, self.activation, normalize=self.normalize),
      *self.block(256, 512, self.activation, normalize=self.normalize),
      *self.block(512, 1024, self.activation, normalize=self.normalize),
      torch.nn.Linear(1024, get_flat_size(self.image_size)),
      torch.nn.Sigmoid()  # image pixels are in [0, 1]
    )

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)

    img = x.view(x.shape[0], *self.image_size)
    return img

  @staticmethod
  def block(in_dims, out_dims, activation=torch.nn.ReLU, normalize=False):
      layers = [torch.nn.Linear(in_dims, out_dims)]
      if normalize:
        layers.append(torch.nn.BatchNorm1d(out_dims))
      layers.append(activation())
      return layers


def get_flat_size(image_size):
  return int(np.prod(image_size))  # plain Python int, e.g. 1 * 28 * 28 = 784

# Discriminator


In [None]:
class Discriminator(pl.LightningModule):
  """Discriminator module for an image GAN in PyTorch Lightning.

  .forward takes a batch of images with size config["image_size"]
  as input and returns scalars in [0, 1].

  Try adding convolutional components at the beginning of self.layers.
  """
  def __init__(self, config):
    super().__init__()
    self.image_size = config["image_size"]
    self.activation = config["activation"]

    self.layers = torch.nn.Sequential(
      torch.nn.Linear(get_flat_size(self.image_size), 512),
      self.activation(),
      torch.nn.Linear(512, 256),
      self.activation(),
      torch.nn.Linear(256, 1),
      torch.nn.Sigmoid(),
      )

  def forward(self, img):
    x = torch.flatten(img, start_dim=1) # flatten all except batch dimension

    for layer in self.layers:
      x = layer(x)

    return x

# Training Step

Training a GAN is more complex than training other types of networks:
we have to alternate between training the discriminator and the generator,
and each trains on slightly different data.

So for clarity, we've split the definition of the `training_step` out
from the rest of the module code.
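
To make the alternation concrete, the two objectives can be written out. This is a sketch that assumes the `torch.nn.BCELoss` used in the `config` further down, with its default mean reduction. The symbols are introduced here for illustration and do not appear in the code: $G$ is the generator, $D$ the discriminator, $x_i$ the real images in a batch of size $N$, and $z_i$ the sampled latent vectors.

$$
\mathcal{L}_D = -\frac{1}{2N} \sum_{i=1}^{N} \Big[ \log D(x_i) + \log\big(1 - D(G(z_i))\big) \Big]
\qquad
\mathcal{L}_G = -\frac{1}{N} \sum_{i=1}^{N} \log D(G(z_i))
$$

The discriminator is scored on real and generated images together, while the generator is scored only on its own outputs, with the target flipped to "real".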

In [None]:
def training_step(self, batch, batch_idx, optimizer_idx):
  imgs, _ = batch  # ignore labels
  batch_sz = imgs.shape[0]

  # implementation detail: we use two optimizers, stored in a list
  training_generator = optimizer_idx == 0
  training_discriminator = not training_generator

  if training_generator:

    # 1. Sample random input for the generator
    z = torch.randn(batch_sz, self.latent_dim).type_as(imgs)

    # 2. Generator makes images from random input
    self.generated_images = self.generator(z)

    # 3. We pass those images through the discriminator
    self.discriminator_outputs =  self.discriminator(self.generated_images)

    # For the generator, the "target" on a fake input is a 1,
    #  indicating that the discriminator classifies it as real,
    #  even though the ground truth label is 0.
    is_fake = torch.ones(batch_sz, 1).type_as(imgs)
    g_loss = self.adversarial_loss(self.discriminator_outputs, is_fake)

    # Return a dictionary of outputs for logging and automated backward pass by Lightning
    output = {"loss": g_loss, "g_loss": g_loss, "batch_idx": batch_idx,
              "generator_win_percentage": self.generator_win_percentage(self.discriminator_outputs, is_fake.int())}

  if training_discriminator:

    # 1. Obtain the discriminator outputs on real and fake images
    outputs_on_real = self.discriminator(imgs)
    outputs_on_fake = self.discriminator(self.generated_images.detach())

    # For the discriminator, the "target" on a real input is a 1
    targets_real = torch.ones(batch_sz, 1).type_as(imgs)
    # and the "target" on a fake input is a 0
    targets_fake = torch.zeros(batch_sz, 1).type_as(imgs)

    # 2. Combine (concatenate) the outputs/targets in the two cases
    outputs = torch.cat([outputs_on_real, outputs_on_fake])
    targets = torch.cat([targets_real, targets_fake])

    d_loss = self.adversarial_loss(outputs, targets)

    # Return a dictionary of outputs for logging and automated backward pass by Lightning
    output = {"loss": d_loss, "d_loss": d_loss, "batch_idx": batch_idx,
              "discriminator_win_percentage": self.discriminator_win_percentage(outputs, targets.int())}

  return output

# add the training step code from above to the LitGAN class
LitGAN.training_step = training_step

# Training

To run training, execute the cell below.
You can configure the network and training procedure
by changing the values of the `config` dictionary.

In between training runs,
especially runs that crashed,
you may wish to restart the notebook
and re-run the preceding cells
to get rid of accumulated state
(`Runtime > Restart runtime`).

In [None]:
###
# Setup Hyperparameters, Data, and Model
###


config = {  # dictionary of configuration hyperparameters
  "batch_size": 256,  # number of examples in a single batch
  "max_epochs": 32,  # number of times to pass over the whole dataset
  "image_size": (1, 28, 28),  # size of images in this dataset
  "latent_dim": 128,  # size of the generator's random input vector
  "loss": torch.nn.BCELoss(),  # loss function for adversarial loss
  "activation": torch.nn.ReLU,  # activation function class (instantiated later)
  "normalize": True,  # whether to use BatchNorm in Generator
  "discriminator.optim" : {
    "optimizer": torch.optim.Adam,  # optimizer class (instantiated later)
    "optimizer.params":  # dict of hyperparameters for optimizer
      {"lr": 0.002,  # learning rate to scale gradients
      "betas": (0.5, 0.999),  # momentum parameters
      "weight_decay": 0}  # if non-zero, reduce weights each batch
  },
  "generator.optim" : {
    "optimizer": torch.optim.Adam,  # optimizer class (instantiated later)
    "optimizer.params":  # dict of hyperparameters for optimizer
      {"lr": 0.0002,  # learning rate to scale gradients
      "betas": (0.5, 0.999),  # momentum parameters
      "weight_decay": 0}  # if non-zero, reduce weights each batch
  }
}

config["log_interval"] = max(int((50000 // config["batch_size"]) / 10), 1)

# 📸 set up the dataset of images
dmodule = lu.datamodules.mnist.MNISTDataModule(
    batch_size=config["batch_size"])
dmodule.prepare_data()
dmodule.setup()

# 🥅 instantiate the network
gan = LitGAN(config)

###
# Train the model
###


with wandb.init(project="lit-gan", entity="wandb", config=config) as run:
  # 👀 watch the gradients, log to Weights & Biases
  wandb.watch(gan)

  # 👟 configure Trainer 
  trainer = pl.Trainer(gpus=1,  # use the GPU for .forward
                      logger=pl.loggers.WandbLogger(
                        log_model=True, save_code=True),  # log to Weights & Biases
                      max_epochs=config["max_epochs"], log_every_n_steps=1,
                      progress_bar_refresh_rate=50)
                      
  # 🏃‍♀️ run the Trainer on the model
  trainer.fit(gan, dmodule)

### Exercises

The cell above will output links to Weights & Biases dashboards where you can review the training process and the final model.

These dashboards will be useful in working through the exercises below.

#### 1. Balancing Act: Speed

Though our eventual goal is for the generator to "win" the competition
between the two networks,
we need the discriminator to also learn effectively
if the generator is to learn to make realistic images.
But if the discriminator is too good,
training can also fail.
This balancing act makes GAN training difficult.

Let's see this phenomenon in action.
First, reduce the `lr` of the discriminator by a factor of at least 1000
so that it learns much more
slowly than the generator. What happens?
Then, return the `lr` of the discriminator to its original value
and decrease the `lr` of the generator
by a factor of at least 100
so that it learns much more
slowly than the discriminator. What happens?

Make sure to return the `lr`s to their original values
(`0.002` for the discriminator and `0.0002` for the generator)
before proceeding!
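
One way to run these experiments without editing the original `config` in place is to build modified copies. The sketch below is only an illustration: `with_scaled_lr` is a hypothetical helper, not part of `lit_utils`, and `base_config` is a stand-in that holds only the keys the helper touches.

```python
import copy


def with_scaled_lr(config, network, factor):
    """Return a copy of config with one network's learning rate multiplied by factor.

    network is "generator" or "discriminator", matching the "<network>.optim" keys.
    """
    new_config = copy.deepcopy(config)  # leave the original config untouched
    new_config[f"{network}.optim"]["optimizer.params"]["lr"] *= factor
    return new_config


# stand-in for the notebook's config, with only the keys used here
base_config = {
    "discriminator.optim": {"optimizer.params": {"lr": 0.002}},
    "generator.optim": {"optimizer.params": {"lr": 0.0002}},
}

# first experiment: discriminator learns 1000x more slowly
slow_discriminator = with_scaled_lr(base_config, "discriminator", 1 / 1000)
# second experiment: generator learns 100x more slowly
slow_generator = with_scaled_lr(base_config, "generator", 1 / 100)
```

Because each experiment works on its own copy, the original learning rates never need to be restored by hand.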

#### 2. Balancing Act: Capacity

Another balancing act involves the _capacity_ or _expressivity_
of the two networks --
how powerful and flexible the networks are,
or how much they can learn from the data.
Loosely, we want a generator with as high a capacity
as we can comfortably execute and train,
and we want a discriminator with sufficient capacity
to prevent the generator from winning with a "cheap trick",
like always returning the same image.
But we don't want the discriminator to have
such high capacity that it's impossible to fool
(see the **Challenge** section for GAN training tricks
that enable the use of higher-capacity discriminators).

Capacity is hard to quantify and even harder to measure.
As a first pass at the capacity,
we just count the total number of parameters
(mostly, the weights and biases of the linear layers)
in each network.
More parameters generally result in greater capacity.

First, decrease the size, depth, and `latent_dim` of the generator 
until the results at the end of training drop in quality.

Then, return to the default hyperparameters and 
increase the size and depth of the discriminator
until the results at the end of training drop in quality.
Compare the outputs
(in qualitative terms and in terms of the image pixel statistics)
and the training dynamics (e.g. win percentages)
with those from the previous set of experiments.

_Note_: the parameter counts are logged to W&B as
`generator_nparams` and `discriminator_nparams`.
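
If you want to check a parameter count locally before launching a run, a plain PyTorch count is enough. This is a minimal sketch: `count_trainable_params` is a hypothetical helper, and it is an assumption that it agrees with whatever `lu.callbacks.count_params` reports. The example network copies the layer sizes of the default discriminator.

```python
import torch


def count_trainable_params(module):
    """Sum the number of elements across all trainable parameter tensors."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# same layer sizes as the default Discriminator: 784 -> 512 -> 256 -> 1
mlp = torch.nn.Sequential(
    torch.nn.Linear(784, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 1),
    torch.nn.Sigmoid(),
)

n_params = count_trainable_params(mlp)
print(n_params)  # weights plus biases of the three linear layers
```

Each `Linear(in, out)` layer contributes `in * out` weights and `out` biases, so widening a layer grows the count much faster than adding a narrow one.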

#### 3. Just Use More Compute

In addition to GAN training being finicky,
it is notoriously difficult to tell when GAN training is finished.
The goal is for the output images to fool a _human_ discriminator --
to have a high "perceptual quality" --
but we can't cheaply measure this,
let alone backpropagate through the procedure.
Values of the generator/discriminator loss have little meaning in terms
of this perceptual quality.

One common, if inelegant, solution is to simply train for much longer
than would otherwise seem reasonable.
Increase the number of epochs (`config["max_epochs"]`)
and decrease the batch size
until training takes at least 20 minutes --
with the default hyperparameters and a batch size of `32`,
this would be about `100` epochs.
The total run time should be roughly linear in the number of epochs.
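
To get a feel for how much more work a smaller batch size implies, you can count batches. The sketch below assumes `50000` training examples, the figure used in the `log_interval` computation above, and that the final partial batch is kept; `batches_per_epoch` is a hypothetical helper.

```python
import math


def batches_per_epoch(num_examples, batch_size):
    """Number of batches in one pass over the data, keeping the last partial batch."""
    return math.ceil(num_examples / batch_size)


TRAIN_EXAMPLES = 50000  # assumption: taken from the log_interval computation

default_batches = batches_per_epoch(TRAIN_EXAMPLES, 256)  # 196 batches per epoch
small_batches = batches_per_epoch(TRAIN_EXAMPLES, 32)     # 1563 batches per epoch
total_batches = small_batches * 100                       # 156300 batches over 100 epochs
```

Going from a batch size of `256` to `32` means roughly eight times as many batches per epoch, each of which triggers both a generator and a discriminator update.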

Does the "perceptual quality" of the images seem to increase?
What happens to the generator and discriminator losses?
Do they appear to converge?

> _Note_: You can increase the runtime further if you'd like, but
[Google Colab places limits on GPU usage](https://research.google.com/colaboratory/faq.html),
so if you run for multiple hours, you may find your access temporarily curtailed!

#### **Challenge**: Stability Tricks for GANs

[Generative modeling is hard](https://wandb.ai/ayush-thakur/keras-gan/reports/Towards-Deep-Generative-Modeling-with-W-B--Vmlldzo4MDI4Mw),
and GANs are a notoriously challenging type of generative model.

There are [many tricks for making GAN training easier](https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b).

A few are listed below. Try implementing them!

1. _One-Sided Label Smoothing_. In
[label smoothing](https://towardsdatascience.com/what-is-label-smoothing-108debd7ef06),
we train networks on "not-quite-one-hot" vectors,
whose entries are close to, but not quite, `0` and `1`.
In the one-sided variant, only the targets for real images are smoothed
(for example, `0.9` instead of `1`), while the targets for fake images stay at `0`.
In GANs, this can prevent the discriminator from over-fitting
to the current generator.
Add this to the `training_step` by adjusting the values of the
`targets` for the discriminator.
2. _Noisy Inputs_. A clever discriminator could memorize
every single digit in the dataset and prevent the generator
from learning to generate new digits.
One way around this is to add noise
to both real and fake inputs so that the discriminator
cannot simply memorize exact pixel values.
Add this to the `training_step`.
See if this allows you to use a bigger discriminator
and generate better digits.

_Hint_: Generate random tensors with
`torch.rand*_like` methods, 
like [this one](https://pytorch.org/docs/stable/generated/torch.randn_like.html#torch-randn_like),
as in this snippet:
```
random_noise = torch.randn_like(??)
```
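
As a starting point for both tricks, here is one possible sketch. The helper names, the `smoothing` and `noise_std` values, and the choice to clamp noisy pixels back into `[0, 1]` are all assumptions made for illustration; none of them come from the notebook or from `lit_utils`.

```python
import torch


def smooth_real_targets(targets_real, smoothing=0.1):
    """One-sided label smoothing: lower the real targets from 1 to 1 - smoothing."""
    return targets_real * (1.0 - smoothing)


def add_input_noise(imgs, noise_std=0.1):
    """Add Gaussian noise to a batch of images, keeping pixels in [0, 1]."""
    noisy = imgs + noise_std * torch.randn_like(imgs)
    return noisy.clamp(0.0, 1.0)  # assumption: stay in the range the networks expect


targets_real = smooth_real_targets(torch.ones(4, 1))    # real targets become 0.9
targets_fake = torch.zeros(4, 1)                        # unchanged: smoothing is one-sided
noisy_imgs = add_input_noise(torch.rand(4, 1, 28, 28))  # same shape as the input batch
```

In the discriminator branch of `training_step`, these would be applied to `targets_real` and to the images passed to `self.discriminator`, respectively. Whether to also use noisy inputs when training the generator is a design choice worth experimenting with.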


#### **Challenge**: Convolutional GANs

As with classification and auto-encoding tasks,
GANs for images benefit from the use of convolutional layers.
Rework the discriminator and generator to use
`torch.nn.Conv2d` and `torch.nn.ConvTranspose2d` layers,
respectively.
Check out the
[MNIST Autoencoder notebook](https://colab.research.google.com/github/wandb/edu/blob/main/lightning/autoencoder/autoencoder-mnist.ipynb)
for examples of convolutional layers in a generative model.

_Hint_: max-pooling in the discriminator can make it too easy to fool.
Use [strided convolutions](https://www.reddit.com/r/MachineLearning/comments/5x4jbt/d_strided_convolutions_vs_pooling_layers_pros_and/)
instead.
For more tips, see
[this article](https://www.kdnuggets.com/2017/11/generative-adversarial-networks-part2.html).
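
To check that the shapes work out before wiring anything into `LitGAN`, it can help to prototype the two networks on their own. The sketch below is one possible layout, not a reference solution: the class names, channel counts, and kernel sizes are assumptions, and plain `torch.nn.Module` is used to keep it self-contained.

```python
import torch


class ConvDiscriminator(torch.nn.Module):
    """Hypothetical convolutional discriminator for 1x28x28 images."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            # strided convolutions downsample in place of max-pooling
            torch.nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),   # 28x28 -> 14x14
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
            torch.nn.ReLU(),
            torch.nn.Flatten(),  # (N, 64, 7, 7) -> (N, 3136)
            torch.nn.Linear(64 * 7 * 7, 1),
            torch.nn.Sigmoid(),  # scalar in [0, 1] per image
        )

    def forward(self, img):
        return self.layers(img)


class ConvGenerator(torch.nn.Module):
    """Hypothetical convolutional generator producing 1x28x28 images."""

    def __init__(self, latent_dim=128):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 64 * 7 * 7),
            torch.nn.ReLU(),
            torch.nn.Unflatten(1, (64, 7, 7)),  # (N, 3136) -> (N, 64, 7, 7)
            torch.nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 14x14 -> 28x28
            torch.nn.Sigmoid(),  # image pixels are in [0, 1]
        )

    def forward(self, z):
        return self.layers(z)


z = torch.randn(8, 128)                  # batch of 8 latent vectors
fake_imgs = ConvGenerator()(z)           # shape (8, 1, 28, 28)
scores = ConvDiscriminator()(fake_imgs)  # shape (8, 1)
```

With a kernel size of `4`, stride of `2`, and padding of `1`, each convolution halves the spatial size and each transposed convolution doubles it, which is why two of each connect `28x28` images to `7x7` feature maps.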