# Generative Adversarial Networks

In this lab, we are going to train a GAN that mimics the images from the MNIST dataset. Most of the code is written for you already, and your task will be to implement some of the techniques that help GANs converge better, as discussed in the lecture material.

As always, let's start by importing all the required libraries:

In [None]:
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
from IPython.display import clear_output

In [None]:
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning
                            # installed by default.
                            # Hence, we do it here if necessary
    !pip install pytorch-lightning==1.3.4
    import pytorch_lightning as pl

And download the MNIST dataset:

In [None]:
transform = transforms.ToTensor()
ds_train = MNIST(".", train=True, download=True, transform=transform)
ds_test = MNIST(".", train=False, download=True, transform=transform)

Here's a helper function to glue together a bunch of images and plot them in a matrix:

In [None]:
def join_images(images):
    # We expect `images` to be a 2D array of 2D images represented
    # as uint8 numbers (from 0 to 255)
    images = np.array(images)
    assert images.ndim == 4, "This function expects the input " \
                             "to be a 2D matrix of 2D images"

    # Let's pad our images with the white color to separate the
    # nearby images once we glue them together
    joined_image = np.pad(
        images, ((0, 0), (0, 0), (0, 1), (0, 1)), constant_values=255
    )

    # Here we transpose and reshape our 4D array to look at it as if
    # it was a single 2D image with all the sub-images placed in a
    # rectangular grid
    joined_image = np.transpose(joined_image, (0, 2, 1, 3))
    joined_image = joined_image.reshape(
        joined_image.shape[0] * joined_image.shape[1],
        joined_image.shape[2] * joined_image.shape[3],
    )

    # Finally, we pad the top and left sides of our resulting image with
    # the white color
    joined_image = np.pad(
        joined_image, ((1, 0), (1, 0)), constant_values=255
    )

    return joined_image

# A simple function to quickly plot a matrix of images
def plot_images(images):
    plt.imshow(join_images(images), cmap="Greys")
    plt.axis("off")

Let's test our functions:

In [None]:
plot_images(
    ds_train.data.numpy()[:20].reshape(4, 5, 28, 28)
)
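
To see how the transpose-and-reshape trick in `join_images` lays out the grid, here is a minimal sketch on a tiny 2x3 grid of 2x2 "images". The padding steps are left out, and the function name `tile` is made up for this illustration:

```python
import numpy as np

def tile(images):
    # images: (rows, cols, height, width) -> (rows * height, cols * width)
    rows, cols, height, width = images.shape
    # Move the image-height axis next to the grid-row axis, so that
    # reshape merges (rows, height) and (cols, width) pairwise.
    return images.transpose(0, 2, 1, 3).reshape(rows * height, cols * width)

# Image number k (counting row by row) is a 2x2 block filled with k.
grid = np.arange(6).reshape(2, 3, 1, 1) * np.ones((1, 1, 2, 2), dtype=int)
tiled = tile(grid)
print(tiled)
```

Each 2x2 block of the printed 4x6 matrix holds a single image index, in row-by-row order. With the padding included, `join_images` turns the `(4, 5, 28, 28)` array used above into a `(117, 146)` image: each sub-image grows to 29x29, and one extra row and column are added at the top and left.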

Here's a brief reminder about GANs.

- They consist of two networks: **the generator** ($G$) and **the discriminator** ($D$).
- The generator takes in a latent vector $z$ sampled from some known fixed prior distribution $p_z$ (typically, a multivariate standard normal) and outputs the generated object.
- The discriminator takes in an object (either generated or a real one) and tries to predict whether it is real or not.
- The whole thing is trained in steps:
    - first, a bunch of discriminator update steps minimizing the negative log-likelihood of its predictions,
    - then, a generator update step dragging the discriminator's loss in the opposite direction.
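
As a rough illustration of this schedule, the sketch below simulates the order of updates for a given number of discriminator steps per generator step. The function `update_schedule` is a hypothetical helper written for this example. It mirrors the counter logic used in `training_step` further below, but it is not part of the lab code.

```python
def update_schedule(num_disc_steps, num_batches):
    """Return the order of updates as a string of "D" and "G", one per batch."""
    schedule = []
    counter = 0
    for _ in range(num_batches):
        if counter < num_disc_steps:
            # Discriminator step: count it and keep going.
            counter += 1
            schedule.append("D")
        else:
            # Generator step: reset the counter.
            counter = 0
            schedule.append("G")
    return "".join(schedule)

print(update_schedule(num_disc_steps=1, num_batches=6))  # DGDGDG
print(update_schedule(num_disc_steps=3, num_batches=8))  # DDDGDDDG
```

Note that every batch is used for exactly one update, either of the discriminator or of the generator.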

Let $D(x)$ be the logit output of our classifier (i.e. the class score without any activation), and let $y\in\{0,1\}$ be the true class of the given object $x$. Then, the negative log-likelihood loss (also known as the binary cross-entropy loss, or just BCE) looks like this:
$$l(D(x), y)=-\left[y\cdot\log\sigma(D(x))+(1-y)\cdot\log(1-\sigma(D(x)))\right].$$
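
For a concrete feel of this formula, here is a small sketch in plain Python that evaluates it for a single logit, next to an algebraically equivalent form that stays finite for large $|D(x)|$. The function names are made up for this example.

```python
import math

def sigmoid(d):
    return 1.0 / (1.0 + math.exp(-d))

def bce_from_logit(d, y):
    # Direct transcription of the formula above.
    p = sigmoid(d)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_from_logit_stable(d, y):
    # Same quantity rewritten as max(d, 0) - d * y + log(1 + exp(-|d|)),
    # which never takes the logarithm of a number rounded to zero.
    return max(d, 0.0) - d * y + math.log1p(math.exp(-abs(d)))

for d, y in [(2.0, 1), (2.0, 0), (-3.0, 1), (0.0, 0)]:
    print(d, y, bce_from_logit(d, y), bce_from_logit_stable(d, y))
```

A confident correct prediction ($D(x)=2$, $y=1$) costs about 0.13, the same logit with the wrong label costs about 2.13, and an undecided logit of 0 costs $\log 2\approx 0.69$. The direct version breaks down for very large logits, which is one reason to work with logits and a combined loss such as `torch.nn.BCEWithLogitsLoss`, used later in this lab.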

In the case of a GAN, the object-target pairs $(x, y)$ will be $(x, 1)$ and $(G(z), 0)$, where $x$ is taken from the training set and $z$ is sampled from the latent distribution. The loss above will be minimized by the discriminator and maximized by the generator.
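
Written out with the BCE loss $l$ defined above, the two objectives might look as follows (the symbols $L_D$ and $L_G$ are introduced here for convenience and are not used elsewhere in the lab):
$$L_D = \mathbb{E}_{x}\left[l(D(x), 1)\right] + \mathbb{E}_{z\sim p_z}\left[l(D(G(z)), 0)\right] = -\mathbb{E}_{x}\left[\log\sigma(D(x))\right] - \mathbb{E}_{z\sim p_z}\left[\log(1-\sigma(D(G(z))))\right],$$
$$L_G = -\mathbb{E}_{z\sim p_z}\left[l(D(G(z)), 0)\right] = \mathbb{E}_{z\sim p_z}\left[\log(1-\sigma(D(G(z))))\right].$$

The discriminator minimizes $L_D$. The generator only influences the second term of $L_D$, so maximizing $L_D$ with respect to the generator is the same as minimizing $L_G$.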

As mentioned in the lectures, GANs often suffer from problems like vanishing gradients or mode collapse. Here are some techniques that help mitigate these problems:

1. **Non-saturating loss.** Early in training, the generator samples are very different from the training data, and it's very easy for the discriminator to separate them. As a result, the generator loss, $\left[\log(1-\sigma(D(G(z))))\right]$, may saturate and provide no meaningful gradients for the generator. An alternative, *non-saturating* loss for the generator is $\left[-\log\sigma(D(G(z)))\right]$. Note that it's equivalent to using the BCE loss above with the label set to 1: $l(D(G(z)), y=1)$.
2. **Additional noise.** One reason these problems occur is that the supports of the real and generated distributions may not overlap. One can artificially make them overlap by smearing the discriminator input with additional noise, e.g. samples from a multivariate standard normal distribution scaled by a given magnitude.
3. **Label swapping.** If the discriminator gets too strong, there are no good gradients for the generator. One way to prevent this is to randomly swap the labels for a small fraction of inputs from "real" to "fake", or vice versa, or both.
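
To make items 1 and 3 more concrete, here is a small numerical sketch in plain Python. It is an illustration only, not the solution to the task below, and all function names are made up for this example. It relies on two facts that follow from differentiating the formulas above: the derivative of $\log(1-\sigma(d))$ with respect to the logit $d$ is $-\sigma(d)$, and the derivative of $-\log\sigma(d)$ is $\sigma(d)-1$.

```python
import math

def sigmoid(d):
    return 1.0 / (1.0 + math.exp(-d))

def saturating_grad(d):
    # d/dd of log(1 - sigmoid(d)), the original generator loss
    return -sigmoid(d)

def non_saturating_grad(d):
    # d/dd of -log(sigmoid(d)), the non-saturating generator loss
    return sigmoid(d) - 1.0

# A confident discriminator assigns very negative logits to fakes.
for d in [-10.0, -5.0, 0.0]:
    print(d, saturating_grad(d), non_saturating_grad(d))

def optimal_logit_with_swaps(p):
    # If a "real" label is flipped to "fake" with probability p, the expected
    # BCE for that sample is minimized at sigmoid(d) = 1 - p, so the
    # best logit is finite: d = log((1 - p) / p).
    return math.log((1.0 - p) / p)

print(optimal_logit_with_swaps(0.02))
```

At $d=-10$ the saturating loss passes a gradient of about $-4.5\cdot10^{-5}$ to the generator, while the non-saturating one passes a gradient close to $-1$. With a swap probability of 0.02, the best logit for a sample that only ever appears as real is about 3.9, rather than growing without bound.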

OK, now we're ready to build our GAN! We'll put everything into a single Lightning module. In the code below, **we'll leave blanks for you to implement techniques 1-3 from above**.

# Task (6 points)

In the code cell below, fill in the blanks to implement the techniques 1-3 mentioned in the previous text cell:
1. **Non-saturating loss** (2 points)
2. **Additional noise** (2 points)
3. **Label swapping** (2 points)

In [None]:
class GAN(pl.LightningModule):
    def __init__(
        self,
        generator, # the generator network (torch.nn.Module)
        discriminator, # the discriminator network (torch.nn.Module)
        latent_size, # the dimensionality of z
        num_disc_steps=1, # number of discriminator updates per single generator update
        initial_lr=0.001, # initial learning rate
        lr_decay_rate=0.95, # a factor, by which the learning rates will be multiplied after each epoch
        non_saturating_loss=False, # boolean, whether to use the non-saturating loss
        additional_noise_power=None, # float, the magnitude of the additional noise
        label_swap_prob=None, # swap each label in the discriminator's loss with this probability
        **kwargs
    ):
        super().__init__(**kwargs)

        self.generator = generator
        self.discriminator = discriminator
        self.latent_size = latent_size
        self.num_disc_steps = num_disc_steps
        self.initial_lr = initial_lr
        self.lr_decay_rate = lr_decay_rate
        self.non_saturating_loss = non_saturating_loss
        self.additional_noise_power = additional_noise_power
        self.label_swap_prob = label_swap_prob

        # We'll use this counter to switch between generator and discriminator steps
        self._gan_step_counter = 0

        # Our loss function
        self.criterion = torch.nn.BCEWithLogitsLoss()

        # Important: This property activates manual optimization in lightning.
        # Since GAN training steps are quite non-standard, we opt for manual
        # optimization that we'll implement on our own.
        self.automatic_optimization = False

    # A function to sample a batch of latent vectors z
    def generate_z(self, N):
        return torch.randn(N, self.latent_size).to(self.device)

    # A function to sample a batch of generated objects G(z)
    def generate(self, N):
        return self.generator(self.generate_z(N))

    # We'll use the function below to calculate both generator's and
    # discriminator's losses
    def _shared_losses_calculation(self, real_img_batch):
        """Calculate the loss value on a given batch"""

        batch_size = len(real_img_batch)

        # generate a batch of fakes:
        fake_img_batch = self.generate(batch_size)

        if self.additional_noise_power is not None:
            # add random normal noise with the magnitude of
            # `self.additional_noise_power` to both fake and real batches

            # <YOUR CODE>
            raise NotImplementedError # <= remove this


        # calculate the discriminator output on real and fake batches:
        d_real = self.discriminator(real_img_batch)
        d_fake = self.discriminator(fake_img_batch)

        # for y=0 and y=1, let's create arrays of labels like this:
        labels_1 = torch.ones(
            batch_size, 1, dtype=real_img_batch.dtype
        ).to(self.device)
        labels_0 = torch.zeros(
            batch_size, 1, dtype=real_img_batch.dtype
        ).to(self.device)

        # In the discriminator loss, we'll pass the labels
        # after a modification done by the `self.swap_labels`
        # function defined below.
        d_loss = (
            self.criterion(d_real, self.swap_labels(labels_1))
            + self.criterion(d_fake, self.swap_labels(labels_0))
        )
        if not self.non_saturating_loss:
            g_loss = -self.criterion(d_fake, labels_0)
        else:
            # Implement the non-saturating version of the loss for the generator

            # g_loss = <YOUR CODE>
            raise NotImplementedError # <= remove this


        return d_loss, g_loss

    # A function to swap labels (0 <=> 1) with a given probability (`self.label_swap_prob`)
    def swap_labels(self, labels):
        if self.label_swap_prob is None:
            return labels

        # For each entry in `labels`, randomly swap 0 to 1 and 1 to 0 with
        # probability `self.label_swap_prob`, return the result

        # <YOUR CODE>
        # return <YOUR CODE>
        raise NotImplementedError # <= remove this


    # This function will be iteratively called by lightning, automatically.
    # Here we make our optimization steps.
    def training_step(self, batch, batch_idx):
        # extract the objects, ignore the MNIST labels (digit indices):
        batch, _ = batch

        # get the optimizers (see the `configure_optimizers` method below):
        d_opt, g_opt = self.optimizers()

        # calculate both losses:
        d_loss, g_loss = (
            self._shared_losses_calculation(batch)
        )

        # Choose which update step to make:
        if self._gan_step_counter < self.num_disc_steps:
            # Making a discriminator step
            self._gan_step_counter += 1
            d_opt.zero_grad()
            self.manual_backward(d_loss) # https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#manual-backward
            d_opt.step()
        else:
            # Making a generator step
            self._gan_step_counter = 0
            g_opt.zero_grad()
            self.manual_backward(g_loss) # https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#manual-backward
            g_opt.step()

        # Logging our losses
        self.log("train_loss_discriminator", d_loss)
        self.log("train_loss_generator", g_loss)

    # This function will be automatically called by lightning
    # at each training epoch end. Inside, we are going to schedule our
    # learning rates (see the `configure_optimizers` method below).
    def training_epoch_end(self, outputs):
        for sch in self.lr_schedulers():
            sch.step()

    # At the validation step, we'll just calculate the losses and log them.
    def validation_step(self, batch, batch_idx):
        batch, _ = batch
        d_loss, g_loss = (
            self._shared_losses_calculation(batch)
        )
        self.log("val_loss_discriminator", d_loss)
        self.log("val_loss_generator", g_loss)

    # Here, we configure our optimizers for the discriminator and generator,
    # along with learning rate schedulers.
    def configure_optimizers(self):
        d_opt = torch.optim.RMSprop(self.discriminator.parameters(), lr=self.initial_lr)
        g_opt = torch.optim.RMSprop(self.generator.parameters(), lr=self.initial_lr)

        # A learning rate scheduler is an object that changes the learning rate
        # during training. Below, we create ExponentialLR instances to
        # exponentially decay the learning rates (multiply the learning rate by
        # a factor of `self.lr_decay_rate`, 0.95 by default, after each epoch).
        d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_opt, self.lr_decay_rate)
        g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_opt, self.lr_decay_rate)

        return [
            {"optimizer" : d_opt, "lr_scheduler" : d_scheduler},
            {"optimizer" : g_opt, "lr_scheduler" : g_scheduler},
        ]
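
As a quick sanity check of what the exponential schedule in `configure_optimizers` does over a full run, the sketch below computes the learning rate after a given number of epochs in plain Python. It assumes the default `initial_lr=0.001` and `lr_decay_rate=0.95` and one scheduler step per epoch. The helper name `lr_after` is made up for this example.

```python
def lr_after(epochs, initial_lr=0.001, decay_rate=0.95):
    # ExponentialLR multiplies the learning rate by `decay_rate`
    # on every scheduler step, i.e. once per epoch here.
    return initial_lr * decay_rate ** epochs

for epoch in [0, 1, 5, 10, 20]:
    print(epoch, lr_after(epoch))
```

Under these assumptions, the learning rate after 20 epochs is roughly $3.6\cdot10^{-4}$, a bit more than a third of its initial value.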


The code cell below defines the generator and discriminator architectures. Since we work with images, the discriminator follows a regular deep convolutional architecture. The generator combines regular convolutions with an upsampling block that adds the output of a transposed convolution to a nearest-neighbor upsampling of the same input (a residual connection).

In [None]:
# This will be our upsampling block with a residual connection:
# we'll upsample the input in two ways and then add the results together.
# The two upsampling ways are:
#   1) with a transposed convolution with a trainable kernel
#   2) with a fixed nearest-neighbor interpolation upsampling
class UpsampleWithRes(torch.nn.Module):
    def __init__(self, upconv, activation, factor, **kwargs):
        super().__init__(**kwargs)

        self.path_a = torch.nn.Sequential(
            upconv, activation()
        )
        self.factor = factor

    def forward(self, x):
        x_a = self.path_a(x)
        x_b = torch.nn.functional.interpolate(x, scale_factor=self.factor)
        return x_a + x_b

# Define the generator architecture
class ConvGenerator(torch.nn.Module):
    def __init__(self, activation, latent_size, **kwargs):
        super().__init__(**kwargs)

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(latent_size, 1024),
            activation(),
        )
        self.backbone = torch.nn.Sequential( # 8x8
            torch.nn.Conv2d(
                in_channels=16, out_channels=128, kernel_size=3, padding=1,
            ), # -> 8x8
            activation(),
            UpsampleWithRes(
                upconv=torch.nn.ConvTranspose2d(
                    in_channels=128, out_channels=128, kernel_size=4, stride=4, padding=0,
                ),
                activation=activation,
                factor=4,
            ), # -> 32x32
            torch.nn.Conv2d(
                in_channels=128, out_channels=64, kernel_size=3, padding=0,
            ), # -> 30x30
            activation(),
            torch.nn.Conv2d(
                in_channels=64, out_channels=32, kernel_size=3, padding=0,
            ), # -> 28x28
            activation(),
            torch.nn.Conv2d(
                in_channels=32, out_channels=1, kernel_size=1, padding=0,
            ),
        )

    def forward(self, x):
        return self.backbone(
            self.fc(x).view(-1, 16, 8, 8)
        ).view(-1, 28, 28)


# Define the discriminator architecture
class ConvDiscriminator(torch.nn.Module):
    def __init__(self, activation, **kwargs):
        super().__init__(**kwargs)

        self.backbone = torch.nn.Sequential( # 28x28
            torch.nn.Conv2d(
                in_channels=1, out_channels=64, kernel_size=3, padding=1,
            ), # -> 28x28
            activation(),
            torch.nn.Conv2d(
                in_channels=64, out_channels=128, kernel_size=4, stride=4, padding=0,
            ), # -> 7x7
            activation(),
            torch.nn.Conv2d(
                in_channels=128, out_channels=128, kernel_size=3, padding=0,
            ), # -> 5x5
            activation(),
            torch.nn.Conv2d(
                in_channels=128, out_channels=128, kernel_size=3, padding=0,
            ), # -> 3x3
            activation(),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(9 * 128, 32),
            activation(),
            torch.nn.Linear(32, 1),
        )

    def forward(self, x):
        return self.fc(
            self.backbone(x.view(-1, 1, 28, 28)).view(-1, 9 * 128)
        )
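
The spatial sizes noted in the comments above can be double-checked with the standard output-size formulas for convolutions and transposed convolutions (no dilation, no output padding). The two helper functions below are written for this example and are not part of PyTorch.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    # Output size of a convolution along one spatial dimension.
    return (size + 2 * padding - kernel_size) // stride + 1

def conv_transpose_out(size, kernel_size, stride=1, padding=0):
    # Output size of a transposed convolution along one spatial dimension.
    return (size - 1) * stride - 2 * padding + kernel_size

# Generator: 8x8 -> 8x8 -> 32x32 -> 30x30 -> 28x28
g = conv_out(8, kernel_size=3, padding=1)
g = conv_transpose_out(g, kernel_size=4, stride=4)
g = conv_out(g, kernel_size=3)
g = conv_out(g, kernel_size=3)
print(g)  # 28

# Discriminator: 28x28 -> 28x28 -> 7x7 -> 5x5 -> 3x3
d = conv_out(28, kernel_size=3, padding=1)
d = conv_out(d, kernel_size=4, stride=4)
d = conv_out(d, kernel_size=3)
d = conv_out(d, kernel_size=3)
print(d)  # 3
```

The final 3x3 feature map with 128 channels is what gives the `9 * 128` input size of the discriminator's first linear layer.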

Here, we'll create a callback that generates a grid of images at the end of each training epoch and stores it in TensorBoard.

In [None]:
class PlotDigitsCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, module):
        tensorboard = module.logger.experiment
        image = join_images(
            (
                module.generate(30).detach().cpu().numpy().clip(0, 1).reshape(
                    5, 6, 28, 28
                ) * 255
            ).astype("uint8")
        )
        tensorboard.add_image(
            "generated_images",
            image,
            trainer.global_step, dataformats="HW"
        )

Now let's create our networks, the GAN module, and a Lightning trainer object:

In [None]:
LATENT_SIZE = 128
model = GAN(
    generator=ConvGenerator(
        activation=torch.nn.ELU, latent_size=LATENT_SIZE
    ),
    discriminator=ConvDiscriminator(
        activation=torch.nn.ELU,
    ),
    latent_size=LATENT_SIZE,
    num_disc_steps=1,
    initial_lr=0.001,
    non_saturating_loss=True,
    additional_noise_power=0.05,
    label_swap_prob=0.02,
)

train_loader = torch.utils.data.DataLoader(
    ds_train,
    batch_size=200,
    shuffle=True,
    num_workers=multiprocessing.cpu_count(),
)
test_loader = torch.utils.data.DataLoader(
    ds_test, batch_size=2048
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=20,
    log_every_n_steps=5,
    flush_logs_every_n_steps=10,
    callbacks=[
        PlotDigitsCallback(),
        pl.callbacks.LearningRateMonitor(),
    ],
)

Initialize tensorboard to monitor progress.

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Finally, we may train our model. While it's training, keep an eye on the tensorboard interface above. You can monitor the losses in the "SCALARS" tab, but you should also be able to look at the images produced by the generator in the "IMAGES" tab.

In [None]:
trainer.fit(model, train_loader, test_loader)

Let's plot some images that our generator produces:

In [None]:
plot_images(
    (model.generate(30).detach().cpu().numpy().clip(0, 1).reshape(5, 6, 28, 28) * 255).astype("uint8")
)