<img src="https://drive.google.com/uc?export=view&id=1x-QAgitB-S5rxGGDqxsJ299ZQTfYtOhb" width=180, align="center"/>

Master's degree in Intelligent Systems

Subject: 11754 - Deep Learning

Year: 2023-2024

Professor: Miguel Ángel Calafat Torrens

# Lab 5 - GANs (Generative Adversarial Networks)

Generative Adversarial Networks (GANs) were a groundbreaking approach in deep learning that revolutionized the field of artificial intelligence. As you have already seen in the theory sessions of the subject, GANs consist of two neural networks, a generator and a discriminator, engaged in a competitive game. The generator aims to create realistic data samples, while the discriminator tries to tell real samples from fake ones. Through this adversarial training, the generator learns to produce increasingly realistic data.

This notebook covers GAN architecture and training, offering hands-on exercises to build and train models. You'll get to generate new data and grasp the core concepts of GANs. Let's get started.

# Classical GAN

## Setting up

In this notebook we will use the CelebA dataset. You can find it in many places on the internet, but a good place to start is its own [website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html): scroll down to the "**Download**" section, select the link "**In the wild images**", go into the folder "**Img**" and download the file "**img_align_celeba.zip**". It may be even easier to download it from this [link of Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset) if you have an account (accounts are free).

In any case, remember that **the CelebA dataset is available for non-commercial research purposes only**.

The dataset is 1.34 GB in size because it contains more than 200,000 images. Since this lab is a proof of concept, you do not need to upload all of them. A good option is to take a selection of, for example, 2000 photos and compress them into a file that you then upload to your Drive account.

In my case I've taken the first 2000 images of celebA and compressed them into a file called `img_align_celeba_small.zip` that I've uploaded to my GDrive.

In [None]:
# Connect to your drive
from google.colab import drive, files
drive.mount('/content/gdrive')
%cd '/content/gdrive/MyDrive/LABS2024/LAB5'
%ls -l

# Insert the path of the project folder (which is where this file is) into
# the Python path, so that the helper module can be imported from it.
import pathlib
import sys
import os

PROJECT_DIR = str(pathlib.Path().resolve())
sys.path.append(PROJECT_DIR)

import helper_PR5 as hp

In [None]:
# Ensure you write down the correct path to your zip file with the dataset

# CelebA full
# dataset_zip_fullpath = '/content/gdrive/MyDrive/datasets/img_align_celeba.zip'

# CelebA small
dataset_zip_fullpath = '/content/gdrive/MyDrive/datasets/img_align_celeba_small.zip'

**Before running the next cell, make sure you have left the zip file in the location you want and assign the variable accordingly.**

In [None]:
# This cell will take several minutes the first time you run it.

# You can execute this cell every time you run the code. It won't unzip
# anything if the dataset is already unzipped.
DATA_DIR = hp.extract_dataset(dataset_zip_fullpath,
                              remove_zip=True)

The first time you run this notebook, you may prefer to use the CPU runtime, because you will need some time to understand what is being done (remember the GPU usage limits set by Colab). The GPU is only needed when you want to run the training, that is, once you understand how the whole pipeline works.

In [None]:
# Import the necessary libraries
import PIL.Image  # "import PIL" alone does not guarantee PIL.Image is loaded
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import importlib

# If you have made changes to the hp module and want to reload it
importlib.reload(hp)

## The dataset

The dataset consists of face images of famous people. Since we want to keep training computationally cheap, we do not use the images at their original resolution (218 x 178, height x width); instead, we downscale them to 64 x 64.

Note that a custom dataset class is defined in the next cell, in a similar way to what you did in previous labs. Its `__getitem__()` method loads each image and applies the corresponding transformations. Besides the image, it returns a label, which is the file name and is not actually needed in this case.

In [None]:
class CustomDataset(Dataset):
    def __init__(self, img_folder, lim=2000, transforms=None):
        # lim is the maximum number of images that we want to use (0 or a
        # negative value means all of them). Its default is meant for a proof
        # of concept, where we don't want to use all the images of the dataset
        self.img_folder = img_folder
        self.lim = lim

        # Initialize empty lists for items and labels
        self.items = []
        self.labels = []

        # Walk through all files in img_folder and its subfolders. Files are
        # sorted so that the selected subset is the same on every run
        for root, _, files in os.walk(img_folder):
            for file in sorted(files):
                # Check if the file is an image
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    full_path = os.path.join(root, file)
                    self.items.append(full_path)
                    self.labels.append(file)

                    # Stop if the limit is reached
                    if lim > 0 and len(self.items) >= lim:
                        break

            if lim > 0 and len(self.items) >= lim:
                break

        # Define the transformation pipeline
        self.transforms = transforms

    def __getitem__(self, idx):
        # Load the image
        data = PIL.Image.open(self.items[idx]).convert('RGB')

        # Apply the transformations if provided
        if self.transforms:
            data = self.transforms(data)

        # Return the processed data and its corresponding label
        return data, self.labels[idx]

    def __len__(self):
        return len(self.items)

We now have the dataset class, but before instantiating it we need to decide which transformations to apply.
As you will see later, the generator typically ends with a `tanh` activation, whose output lies between -1 and 1. For this reason, the real images should also be scaled to this range, so that real and generated images share the same scale.

In [None]:
# Define the transformations to be done

# As you know, transforms.ToTensor transforms the image into a tensor and
# scales it in the range [0, 1]. In this case we use this class to define a new
# class that fulfills the purpose of scaling in the range [-1, 1] using the
# first one as the base class.

# Be sure you understand this code and why it works

class ToScaledTensor(transforms.ToTensor):
    def __init__(self, low=-1, high=1):
        super().__init__()
        self.low = low
        self.high = high

    def __call__(self, img):
        # Convert image to tensor (same as ToTensor)
        tensor = super().__call__(img)
        # Scale the tensor to the desired range
        tensor = tensor * (self.high - self.low) + self.low
        return tensor

And now we can define the transformations in the usual way. The first step is to center-crop the original image (218 x 178) into a square image (178 x 178), the largest square that fits. In the vast majority of cases this crop does not cut off any part of the face.
Then we resize the image to the required size (64 x 64).

In [None]:
# Size of the images (height and width)
img_size = 64

# Transformations
transform = transforms.Compose([
    # Center crop the images so that they become square
    transforms.CenterCrop((178, 178)),
    # Resize the image to the specified size (h, w)
    transforms.Resize((img_size, img_size)),
    # Convert the image to a PyTorch tensor and scale to [-1, 1]
    ToScaledTensor(),
    ])

And now, finally, we are ready to instantiate the dataset.

In [None]:
# Max number of images to use (-1 if all)

# Maybe you want to start working with fewer images to see how it works
limit = 2000

# Define the custom dataset object.
dataset = CustomDataset(DATA_DIR,
                        transforms=transform,
                        lim=limit)

## The dataloader

Now we are ready to instantiate the dataloader.

In [None]:
# Batch size (number of images per batch)
batch_size = 128

# Define the DataLoader
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True)
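
With 2000 images and `batch_size = 128` the batches are not all the same size. The following sketch uses a dummy `TensorDataset` (a stand-in for the real dataset) to show the arithmetic: 15 full batches plus a last batch of 2000 - 15 x 128 = 80 images. This is why the training loop reads the current batch size instead of assuming 128; if you prefer equal-sized batches, `DataLoader` also accepts `drop_last=True`.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

n_images, batch_size = 2000, 128

# Dummy dataset with the same number of items as the CelebA subset used here
dummy = TensorDataset(torch.zeros(n_images, 1))
loader = DataLoader(dummy, batch_size=batch_size, shuffle=True)

# Size of every batch produced in one pass over the data
sizes = [len(batch[0]) for batch in loader]
print(len(loader), math.ceil(n_images / batch_size))  # 16 16
print(sizes[-1])                                       # 80
```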

As you can see, the following cell uses a helper function. I recommend that you check these helper functions to fully understand how they work.

In [None]:
# And now let's see images of the first batch
hp.show(next(iter(dataloader))[0])

## The networks

Now the time has come to define the two competing models. On the one hand we have the generator, which will have a decoder configuration, and on the other hand we will have the discriminator, which will have an encoder configuration.

They can be considered symmetrical networks because one attempts to generate a credible image from a latent space, while the other analyzes an image to determine its authenticity by mapping it back to a latent space.

Below you'll find a representation of the tensor shape at each step. Note, for example, that in the first downsampling step (CONV in the image) the discriminator receives an image tensor of shape (3, 64, 64) (3 channels, 64 pixels high, 64 pixels wide) and outputs a tensor of shape (64, 32, 32): 64 channels because of the number of kernels applied, and 32 x 32 because the stride of 2 halves the height and width.

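In each of the following downsampling steps the depth (number of channels) is doubled while the height and width are halved. At the end of the convolutional layers the tensor of shape (512, 4, 4) is flattened into a vector of 8192 values, and a final fully connected layer maps it to a single score. Note that in this implementation the score is a logit, not a probability: passing it through a sigmoid would give a value between 0 and 1 indicating whether the image is real or fake, and that sigmoid is applied implicitly by the loss function (`BCEWithLogitsLoss`).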

<img src="https://drive.google.com/uc?export=view&id=1WXTHp5vY2rEzVZRAoGsfoAZ4pq6kCb6d" width=700, align="center"/>
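
The shapes in the figure follow from the output-size formula of `nn.Conv2d` (with dilation 1): out = floor((in + 2 * padding - kernel) / stride) + 1. The small sketch below applies it with the notebook's values (kernel 4, stride 2, padding 1, `d_dim = 64`) and cross-checks one layer against PyTorch; `conv_out` is a hypothetical helper.

```python
import torch
from torch import nn

def conv_out(size, kernel=4, stride=2, pad=1):
    # Output size of nn.Conv2d along one spatial dimension (dilation = 1)
    return (size + 2 * pad - kernel) // stride + 1

d_dim, size = 64, 64
channels = [3, d_dim, 2 * d_dim, 4 * d_dim, 8 * d_dim]
shapes = []
for c_out in channels[1:]:
    size = conv_out(size)
    shapes.append((c_out, size, size))
print(shapes)  # [(64, 32, 32), (128, 16, 16), (256, 8, 8), (512, 4, 4)]

c, h, w = shapes[-1]
print(c * h * w)  # 8192, the input size of the fully connected layer

# Cross-check the formula against a real layer
layer = nn.Conv2d(3, d_dim, kernel_size=4, stride=2, padding=1, bias=False)
print(layer(torch.zeros(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])
```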

In [None]:
# Let's define the discriminator with an encoder structure of 4 downscaling
# steps.
class Discriminator(nn.Module):
    def __init__(self, d_dim=64, img_size=64):

        # "d_dim" is the output dimension of the first convolutional layer;
        # that is, the number of filters/kernels you have in the first layer
        # (3 inputs, channels RGB, and d_dim outputs)
        super().__init__()

        # Configuration parameters
        kernel_size = 4
        n = 4  # Number of conv layers. Only used for definition of fc_in
        fc_in = int(d_dim * 2**(n-1) * (img_size/(2**n))**2)  # fc input dim
        pad = 1
        stride = 2
        bias = False

        # Helper function for convolutions
        def conv(in_channels, out_channels):
            return nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=pad,
                             bias=bias)

        # Dense or fully connected layer
        self.fc = nn.Linear(fc_in, 1)  # d_dim x 2^3 x (img_size / 2^4)^2 = 8192

        # Convolutional layers (doubling the depth with each step)
        self.conv1 = conv(in_channels=3,
                          out_channels=d_dim)
        self.conv2 = conv(in_channels=d_dim,
                          out_channels=2*d_dim)
        self.conv3 = conv(in_channels=2*d_dim,
                          out_channels=4*d_dim)
        self.conv4 = conv(in_channels=4*d_dim,
                          out_channels=8*d_dim)

        # Batchnorm layers. Not usually applied to the first convolutional layer
        self.bnorm2 = nn.BatchNorm2d(2 * d_dim)
        self.bnorm3 = nn.BatchNorm2d(4 * d_dim)
        self.bnorm4 = nn.BatchNorm2d(8 * d_dim)

        # Define a LeakyReLU activation function layer
        self.leaky_relu = nn.LeakyReLU(0.2)


    def forward(self, x):
        # Input shape: batch x 3 x img_size x img_size  -->  b x 3 x 64 x 64
        batch_size = x.size(0)

        # See output shape at each step
        out = self.leaky_relu(self.conv1(x))                 # b x 64 x 32 x 32
        out = self.leaky_relu(self.bnorm2(self.conv2(out)))  # b x 128 x 16 x 16
        out = self.leaky_relu(self.bnorm3(self.conv3(out)))  # b x 256 x 8 x 8
        out = self.leaky_relu(self.bnorm4(self.conv4(out)))  # b x 512 x 4 x 4

        # Flatten (b x 512 x 4 x 4  ==  b x 8192)
        out = out.contiguous().view(batch_size, -1)  # b x 8192

        # Final output layer without activation function
        scores = self.fc(out)  # b x 1
        return scores

The generator, as you can see in the image below, has the reverse configuration. In this case we start with a latent code of 100 values, which a fully connected layer converts into a tensor of 8192 values. This tensor is then reshaped into a (512, 4, 4) tensor and taken back to the shape of an image tensor using 4 transposed convolutional layers.

The typical generator configuration is that of a decoder; that is, it is the opposite of the discriminator.


<img src="https://drive.google.com/uc?export=view&id=1BMvNbC-OK0oINFG0Oz9nmfoN9lwppVw1" width=700, align="center"/>
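
Similarly, the generator's shapes follow from the output-size formula of `nn.ConvTranspose2d` (with dilation 1): out = (in - 1) * stride - 2 * padding + kernel + output_padding. With kernel 4, stride 2 and padding 1 this simplifies to out = 2 * in, so each transposed convolution doubles the height and width. A sketch with the notebook's default values (`tconv_out` is a hypothetical helper):

```python
import torch
from torch import nn

def tconv_out(size, kernel=4, stride=2, pad=1, output_padding=0):
    # Output size of nn.ConvTranspose2d along one spatial dimension (dilation = 1)
    return (size - 1) * stride - 2 * pad + kernel + output_padding

d_dim, img_size = 64, 64
start = img_size // 2**4               # 4: spatial size after the fc layer
fc_out = 8 * d_dim * start * start     # 512 * 4 * 4 = 8192
size, shapes = start, []
for c_out in [4 * d_dim, 2 * d_dim, d_dim, 3]:
    size = tconv_out(size)
    shapes.append((c_out, size, size))
print(fc_out, shapes)
# 8192 [(256, 8, 8), (128, 16, 16), (64, 32, 32), (3, 64, 64)]

# Cross-check the first transposed convolution against PyTorch
layer = nn.ConvTranspose2d(8 * d_dim, 4 * d_dim, kernel_size=4, stride=2,
                           padding=1, bias=False)
print(layer(torch.zeros(1, 8 * d_dim, start, start)).shape)
# torch.Size([1, 256, 8, 8])
```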

In [None]:
class Generator(nn.Module):

    # 'd_dim' is the input dimension of the last transposed convolutional layer;
    # that is, the output of this layer has to be of depth 3 due to RGB, but the
    # input is arbitrary. Well, here it will be 'd_dim', that is the counterpart
    # of the 'd_dim' defined in the discriminator.

    # 'z_dim' is the size of the input to the network.

    # You design the forward pass from the end to the beginning by calculating
    # the dimensions (see comments in the code)

    def __init__(self, z_dim=100, d_dim=64, img_size=64):
        super().__init__()

        # Configuration parameters
        kernel_size = 4
        self.d_dim = d_dim

        # Number of convolutional layers (only for size calculation purposes,
        # i.e. not used dynamically to define a different model)
        n = 4

        # Calculate dimensions of fully connected layer
        fc_out = int(d_dim * 2**(n-1) * (img_size/(2**n))**2)
        # Image size (height and width) on fully connected layer output
        self.i_s_fco = int(img_size / 2**n)

        # Configuration data for transposed convolutional layers
        # (kernel_size was already set above)
        pad = 1
        stride = 2
        bias = False

        # Helper function for transposed convolutions
        def tconv(in_channels, out_channels):
            return nn.ConvTranspose2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=pad,
                                      bias=bias)

        # Linear layer
        self.fc = nn.Linear(z_dim, fc_out)

        # Transpose convolutional layers
        self.tconv1 = tconv(in_channels=d_dim*8,
                            out_channels=d_dim*4)
        self.tconv2 = tconv(in_channels=d_dim*4,
                            out_channels=d_dim*2)
        self.tconv3 = tconv(in_channels=d_dim*2,
                            out_channels=d_dim)
        self.tconv4 = tconv(in_channels=d_dim,
                            out_channels=3)

        # Batchnorm layers
        self.bnorm1 = nn.BatchNorm2d(d_dim * 4)
        self.bnorm2 = nn.BatchNorm2d(d_dim * 2)
        self.bnorm3 = nn.BatchNorm2d(d_dim)

        # Define activation function layers
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

    def forward(self, x):

        # Input size: b x z_dim = b x 100
        b = x.size(0)

        # First of all transform the input tensor (the latent generator's code)
        # to dimensions equivalent to those of the latent space of the
        # discriminator, i.e: b x d_dim*8 x img_size/16 x img_size/16
        out = self.fc(x)                                # b x 8192
        out = out.contiguous().view(b, -1,
                                    self.i_s_fco,
                                    self.i_s_fco)       # b x 512 x 4 x 4

        # Now apply the transposed convolutions
        out = self.relu(self.bnorm1(self.tconv1(out)))  # b x 256 x 8 x 8
        out = self.relu(self.bnorm2(self.tconv2(out)))  # b x 128 x 16 x 16
        out = self.relu(self.bnorm3(self.tconv3(out)))  # b x 64 x 32 x 32
        out = self.tanh(self.tconv4(out))               # b x 3 x 64 x 64

        return out

And now you can instantiate the models. You can use the default parameters or, if you wish, change them to run some tests.

In [None]:
# Instantiate the models with default parameters
D = Discriminator()
G = Generator()

## Optimizers

In this section we configure both optimizers with the same learning rate and beta parameters. You may adjust these values, but it is not necessary: the provided values should work reasonably well for this lab.

In [None]:
# Discriminator optimizer

# Parameters
lr_d = 0.0004
beta1 = 0.6
beta2 = 0.999

# Create optimizer for the discriminator D
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr_d, betas=(beta1, beta2))

In [None]:
# Generator optimizer

# Parameters
lr_g = 0.0004
beta1 = 0.6
beta2 = 0.999

# Create optimizer for the generator G
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr_g, betas=(beta1, beta2))

## Losses

The issue of the loss function is especially controversial in the case of GANs. Until now, we have encountered loss functions that influence only a single model. However, in the scenario of GANs, a single loss function must govern two distinct models with fundamentally opposing goals.

For the classical GAN we use the Binary Cross Entropy loss. However, since the discriminator's output has no activation function, we use the variant that works directly on logits (raw scores), namely `BCEWithLogitsLoss`, which combines the sigmoid and the binary cross entropy in a single, numerically more stable operation.
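
A small check of what `BCEWithLogitsLoss` computes, using made-up logits and targets: it matches `BCELoss` applied to `sigmoid(logits)`, but stays exact for very confident wrong predictions. In that case the sigmoid underflows to 0 in float32 and `BCELoss` has to clamp its logarithm (PyTorch clamps it at -100).

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8)                      # made-up raw discriminator scores
targets = torch.tensor([1., 0., 1., 1., 0., 0., 1., 0.])

with_logits = nn.BCEWithLogitsLoss()(logits, targets)
manual = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(with_logits, manual, atol=1e-6))  # True

# A very confident wrong prediction: the true label is 1, the logit is -200
big = torch.tensor([-200.0])
stable = nn.BCEWithLogitsLoss()(big, torch.ones(1)).item()
clamped = nn.BCELoss()(torch.sigmoid(big), torch.ones(1)).item()
print(stable)   # 200.0, the exact value
print(clamped)  # 100.0, because log(0) is clamped to -100
```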

Note that the loss function below takes the output of the discriminator as input, but it does not receive the ground truth (i.e., the correct labels identifying whether an image is real or fake). Instead, it accepts a keyword argument called "label_type", which has only two possible values: 'real' or 'fake'. Thus, when training the discriminator with real images, we will assign the value 'real' to this parameter, and when training with fake images, we will assign 'fake'.

Please review the code below to ensure you understand it.

In [None]:
# Define the loss function
def gan_loss_fcn(discr_output, **kwargs):
    """
    Calculates the loss based on specified label type (real or fake).

    Args:
        discr_output: Tensor of discriminator logits.
        **kwargs: Keyword arguments, including:
            - label_type (str, optional): Label type, either "real" or "fake".
              Defaults to "real".
            - smooth (bool, optional): Apply label smoothing on real labels.
              Defaults to False.

    Returns:
        A Tensor containing the calculated discriminator loss.
    """

    batch_size = discr_output.size(0)

    # Extract label_type and smooth from **kwargs with defaults
    label_type = kwargs.get("label_type", "real")
    smooth = kwargs.get("smooth", False)

    # Ensure label_type is valid (real or fake)
    label_type = "real" if label_type.lower() not in ("real", "fake") \
        else label_type.lower()

    # Define labels based on label_type
    if label_type == "real":
        labels = torch.ones(batch_size) * (0.9 if smooth else 1.0)
    else:
        labels = torch.zeros(batch_size)

    # Move labels to GPU if available
    labels = labels.to(DEVICE)

    # Binary cross entropy with logits loss. view(-1) is used instead of
    # squeeze() so that the shape stays 1-D even for a batch of one image
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(discr_output.view(-1), labels)

    return loss

## Checkpoints

As usual, we are going to save checkpoints. The function below will be called from within the training function: it decides whether a checkpoint is required and, if so, saves it using the helper function `save_checkpoint` from `helper_PR5.py`.

Check the criteria used to decide whether or not to save a checkpoint.
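
To make the criteria explicit, here is a sketch of the same decision rules as a pure function (`checkpoint_reason` is hypothetical and does not save anything), run on made-up losses with `save_step=3` and `starting_from=2`. Note that the best loss returned by one call has to be fed into the next one; otherwise every epoch after `starting_from` would count as a new best.

```python
def checkpoint_reason(epoch, gen_loss, best_gen_loss, save_step,
                      starting_from=20):
    """Hypothetical pure version of the rules used by `checkpointer`.

    Returns (reason, new_best), where reason is "best", "periodic" or None.
    """
    if gen_loss < best_gen_loss and epoch > starting_from:
        return "best", gen_loss
    if epoch % save_step == 0:
        return "periodic", best_gen_loss
    return None, best_gen_loss

# Made-up generator losses for epochs 1..6, just to exercise the rules
losses = [3.0, 2.5, 2.7, 2.4, 2.6, 2.8]
best = float("inf")
decisions = []
for epoch, loss in enumerate(losses, start=1):
    reason, best = checkpoint_reason(epoch, loss, best,
                                     save_step=3, starting_from=2)
    decisions.append((epoch, reason))

print(decisions)
# [(1, None), (2, None), (3, 'best'), (4, 'best'), (5, None), (6, 'periodic')]
```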

In [None]:
def checkpointer(epoch,
                 epoch_gener_loss,
                 best_gen_loss,
                 config,
                 save_step,
                 starting_from=20):
    """
    Do checkpoints if required.

    Evaluates the performance of the generator at the current epoch and decides
    whether to save a checkpoint. Checkpoints are saved if the generator's loss
    is improved past a specified epoch or at regular intervals defined by the
    save_step parameter.

    The function logs the reason for saving a checkpoint, updates the best
    generator loss if necessary, and calls a separate function to actually save
    the checkpoint.

    Parameters:
    - epoch (int): The current training epoch.
    - epoch_gener_loss (float): The generator's loss for the current epoch.
    - best_gen_loss (float): The best generator loss observed so far in the
        training.
    - config (dict): A dictionary containing the configuration of the model,
        including the model itself and its optimizer.
    - save_step (int): The interval of epochs at which checkpoints are saved
        regardless of performance improvement.
    - starting_from (int, optional): The epoch number from which to start
        considering saving checkpoints based on loss improvement. Default: 20.

    Returns:
    - None
    """

    # Check if this epoch's generator loss is an improvement
    if epoch_gener_loss < best_gen_loss and epoch > starting_from:
        # Save new best score
        best_gen_loss = epoch_gener_loss

        # Log
        print(f"New best generator loss {best_gen_loss} at epoch {epoch}.",
              "Saving checkpoint.")

        # Save checkpoint
        hp.save_checkpoint(f"best_{epoch}",
                           epoch,
                           config,
                           PROJECT_DIR)

    # Or maybe the checkpoint is saved due to the number of epochs
    elif epoch % save_step == 0:
        # Log
        print(f"Epoch {epoch}.",
              "Not best losses, but saving checkpoint anyway.")

        # Save checkpoint
        hp.save_checkpoint(f"epoch_{epoch}",
                           epoch,
                           config,
                           PROJECT_DIR)

## Training

At this point, we reach the training phase, which, as you know, involves training the two models in a competition against each other.

Since the training function is quite long, it is worth analyzing it in detail. Notice that both the discriminator and the generator are trained within the loop over batches. Moreover, within that same loop, an inner loop (controlled by `crit_cycles`) can train the discriminator several times for each generator update.

<img src="https://drive.google.com/uc?export=view&id=1oZ6TVh1dkRIHIBmuovS2QM1yYgVSLgz3" width=900, align="center"/>

In short, the following pseudocode scheme gives an overview of what you will find in the training function.

<img src="https://drive.google.com/uc?export=view&id=1gtHLkmCqJRmpPwfY9PVJA3BZ33-k0qAz" width=500, align="center"/>
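
The pseudocode above can be condensed into a single training step. The sketch below is a toy version built on assumptions, not the notebook's `train` function: tiny MLPs on made-up 2-D data stand in for the convolutional generator and discriminator, and the loss is written inline with `BCEWithLogitsLoss`. It keeps the same structure: `crit_cycles` discriminator updates on detached fake samples, followed by one generator update in which fake samples are labelled as real.

```python
import torch
from torch import nn

torch.manual_seed(0)
z_dim, data_dim, batch = 8, 2, 16

# Tiny hypothetical stand-ins for the notebook's Generator and Discriminator
gen = nn.Sequential(nn.Linear(z_dim, 16), nn.ReLU(), nn.Linear(16, data_dim))
disc = nn.Sequential(nn.Linear(data_dim, 16), nn.LeakyReLU(0.2),
                     nn.Linear(16, 1))
g_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)
d_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

real = torch.randn(batch, data_dim) + 3.0   # made-up "real" data
crit_cycles = 2

def train_step(real):
    n = len(real)
    # 1) Discriminator: crit_cycles updates, real -> 1 and fake -> 0
    for _ in range(crit_cycles):
        d_opt.zero_grad()
        fake = gen(torch.randn(n, z_dim)).detach()  # no gradients into G
        d_loss = (bce(disc(fake).view(-1), torch.zeros(n))
                  + bce(disc(real).view(-1), torch.ones(n)))
        d_loss.backward()
        d_opt.step()
    # 2) Generator: one update, with fake samples labelled as real
    g_opt.zero_grad()
    fake = gen(torch.randn(n, z_dim))
    g_loss = bce(disc(fake).view(-1), torch.ones(n))
    g_loss.backward()   # also fills disc grads; the next d_opt.zero_grad() clears them
    g_opt.step()        # only the generator's weights are updated here
    return d_loss.item(), g_loss.item()

d_loss, g_loss = train_step(real)
print(d_loss, g_loss)
```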

In [None]:
def train(config, verbose=True):
    """
    Training loop
    """

    # Initialize a variable to track the best generator loss seen so far
    best_gen_loss = float('inf')

    # Initialize lists of generator and discriminator losses per epoch
    gener_losses_epoch_list = []
    discr_losses_epoch_list = []

    # Initialize configuration values from the config variable
    n_epochs = config.get('n_epochs', 100)
    crit_cycles = config.get('crit_cycles', 1)
    z_dim = config.get('z_dim', 100)
    show_step = config.get('show_step', 25)
    save_step = config.get('save_step', 5)
    last_epoch = config.get('epoch', 0)
    save_starting = config.get('save_starting', 20)
    penalty_fcn = config.get('penalty_fcn', lambda *x: 0)

    # Get the data loader, models, optimizers and loss function from config
    dataloader = config['dataloader']
    gener = config['generator'].to(DEVICE)
    discr = config['discriminator'].to(DEVICE)
    gener_opt = config['g_optimizer']
    discr_opt = config['d_optimizer']
    loss_fcn = config['loss_fcn']

    # Loop through epochs
    for epoch in range(last_epoch + 1, n_epochs + last_epoch + 1):
        # Initialize accumulators for losses
        epoch_gener_loss = 0.0
        epoch_discr_loss = 0.0

        # Number of batches, used for averaging losses later
        num_batches = 0

        # Loop through images
        for real_imgs, _ in tqdm(dataloader):
            num_batches += 1

            # Current batch size
            current_bs = len(real_imgs)

            # Pass real images to the gpu if available
            real_imgs = real_imgs.to(DEVICE)

            # Train the discriminator (or critic) on real and fake images for
            # the number of cycles proposed.
            discr_loss_for_cycles = 0
            for _ in range(crit_cycles):
                # Zero gradients
                discr_opt.zero_grad()

                # Generate noise (initial latent code of the generator)
                noise = torch.randn(current_bs, z_dim, device=DEVICE)

                # Generate fake image from noise
                fake_imgs = gener(noise)

                # Prediction of Discriminator on fake images (tensor must be
                # detached to avoid backpropagation on the generator weights)
                discr_fake_pred = discr(fake_imgs.detach())

                # Losses of the Discriminator on fake images
                discr_fake_loss = loss_fcn(discr_fake_pred, label_type='fake')

                # Prediction of Discriminator on real images
                discr_real_pred = discr(real_imgs)

                # Losses of the Discriminator on real images
                discr_real_loss = loss_fcn(discr_real_pred, label_type='real')

                # Calculate gradient penalty (not used in classic GAN)
                penalty = penalty_fcn(real_imgs,
                                      fake_imgs.detach(),
                                      discr)

                # Discriminator (or critic) losses for the current cycle.
                discr_loss = discr_fake_loss + discr_real_loss + penalty

                # Calculate losses of all the cycles so far
                discr_loss_for_cycles += discr_loss.item() / crit_cycles

                # Backpropagation and weights update
                discr_loss.backward()
                discr_opt.step()

            # Get final losses of the discriminator in the current epoch
            epoch_discr_loss += discr_loss_for_cycles

            # Train the generator

            # Zero gradients
            gener_opt.zero_grad()

            # Generate noise (initial latent code of the generator)
            noise = torch.randn(current_bs, z_dim, device=DEVICE)

            # Generate fake images from noise
            fake_imgs = gener(noise)

            # Prediction of discriminator on fake images
            discr_fake_pred = discr(fake_imgs)

            # Losses of the generator in the current batch
            gener_loss = loss_fcn(discr_fake_pred, label_type='real')
            epoch_gener_loss += gener_loss.item()

            # Backpropagation and weights update
            gener_loss.backward()
            gener_opt.step()

        # Average the losses for the current epoch
        epoch_gener_loss /= num_batches
        epoch_discr_loss /= num_batches
        gener_losses_epoch_list.append(epoch_gener_loss)
        discr_losses_epoch_list.append(epoch_discr_loss)

        # Log of the epoch
        if verbose:
            print({'Epoch': epoch,
                   'Critic loss': epoch_discr_loss,
                   'Gen loss': epoch_gener_loss})

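        # Do checkpoint if required
        checkpointer(epoch=epoch,
                     epoch_gener_loss=epoch_gener_loss,
                     best_gen_loss=best_gen_loss,
                     config=config,
                     save_step=save_step,
                     starting_from=save_starting)

        # checkpointer does not return the updated best loss, so track it here
        # with the same criterion; otherwise every epoch after save_starting
        # would be reported as a new best
        if epoch > save_starting:
            best_gen_loss = min(best_gen_loss, epoch_gener_loss)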

        # Conditional visualization at the end of an epoch
        if epoch % show_step == 0:
            hp.visual_epoch(fake_imgs,
                            real_imgs,
                            gener_losses_epoch_list,
                            discr_losses_epoch_list)

    return gener, discr, [gener_losses_epoch_list, discr_losses_epoch_list]

Let's do the first training.

In [None]:
# If you want to run some trial trainings you can tweak the number of images
# of the dataset (limit) or the epochs to reduce the time invested.

# Note that with this configuration we don't save checkpoints

# Number of epochs to train.
n_epochs = 40

# Let's put all the necessary data in a single dict variable
config = {'n_epochs': n_epochs,         # Number of epochs to train
          'dataloader': dataloader,
          'discriminator': D,
          'generator': G,
          'd_optimizer': d_optimizer,
          'g_optimizer': g_optimizer,
          'project_dir': PROJECT_DIR,   # The folder where checkpoints will be saved
          'loss_fcn': gan_loss_fcn,     # The loss function
          'show_step': 5,               # Show images every 'show_step' epochs
          'save_step': 1000,            # Save checkpoint every 'save_step' epochs
          'save_starting': 1000,        # First epoch to save checkpoints when improving
          }

In [None]:
_ = train(config)

# WGAN-GP

Training GANs often encounters problems like mode collapse, where the generator produces limited varieties of samples, and non-convergence, where the generator and discriminator keep outperforming each other without stabilization.

In this section we are going to implement one option that helps deal with non-convergence problems.

The Wasserstein GAN (WGAN) addresses non-convergence by using a different loss function, the Wasserstein loss, which is based on an estimate of the Wasserstein-1 distance (also known as the Earth Mover's distance) between the real and generated distributions. This approach improves training stability and provides more meaningful gradients, making it easier for the GAN to learn. It requires the discriminator (now called a critic) to be a 1-Lipschitz function, which was originally enforced through weight clipping and, more recently, through a gradient penalty (WGAN-GP), helping the model converge more reliably.

[Wasserstein GAN with gradient penalty](https://paperswithcode.com/method/wgan-gp)
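
For reference, the critic loss proposed in the WGAN-GP paper can be written as follows, where $\mathbb{P}_r$ and $\mathbb{P}_g$ are the real and generated distributions, $D$ is the critic and $\lambda$ is the penalty weight (10 in the paper). The penalty is evaluated at random points $\hat{x}$ on straight lines between real and generated samples:

$$
L_{D} = \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim\mathbb{P}_r}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \epsilon x + (1-\epsilon)\,\tilde{x},\quad \epsilon\sim U[0,1]
$$

$$
L_{G} = -\,\mathbb{E}_{z}\big[D(G(z))\big]
$$

The last term of $L_D$ plays the role of the `penalty` returned by `penalty_fcn` in the `train` function.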

Moreover, the gradient penalty enforces a smoother gradient behavior in the critic. This smoother behavior provides better training signals to the generator, encouraging it to produce more diverse samples and thereby also reducing the risk of mode collapse.
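
A minimal sketch of the gradient penalty term, assuming a critic that maps a batch of images to one score per image. Its argument order mirrors the `penalty_fcn(real_imgs, fake_imgs, discr)` call in `train`, but the name `gradient_penalty` and the default `lambda_gp=10.0` are our assumptions, not code from `helper_PR5`. The usage part relies on a tiny linear critic, for which the gradient with respect to every input is just its weight vector W, so the penalty reduces to lambda * (||W|| - 1)^2.

```python
import torch
from torch import nn

def gradient_penalty(real, fake, critic, lambda_gp=10.0):
    """Sketch of the WGAN-GP penalty (hypothetical helper)."""
    b = real.size(0)
    # One random interpolation weight per sample, broadcast over C, H, W
    eps = torch.rand(b, 1, 1, 1, device=real.device)
    mixed = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(mixed)
    # Gradient of the scores w.r.t. the interpolated images; create_graph=True
    # keeps the graph so the penalty itself can be backpropagated
    grads = torch.autograd.grad(outputs=scores, inputs=mixed,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    grad_norm = grads.reshape(b, -1).norm(2, dim=1)   # one norm per sample
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

# Usage with a tiny hypothetical critic on 3 x 8 x 8 "images"
torch.manual_seed(0)
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)
gp = gradient_penalty(real, fake, critic)
gp.backward()   # gradients flow into the critic's weights
print(gp.item())
```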


## The critic

Note that the model below is dimensionally identical to the Discriminator of the classical GAN defined in the first section of this notebook. The main change is the normalization layer: here we use instance normalization (note that an instance normalization layer, `inorm1`, is also defined for the first convolutional layer).

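We use instance normalization instead of batch normalization to stabilize the learning process without introducing dependencies between the examples in a batch. Instance normalization normalizes each channel of each example independently, which helps the Critic focus on the structural content of each input image without being influenced by the statistics of the rest of the batch. This independence matters especially in WGAN-GP, because the gradient penalty is computed with respect to each input independently, and batch normalization would make the critic's output for one image depend on the other images in the batch.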
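
The following sketch illustrates this independence with random tensors: after changing every sample except the first, the batch-normalized output of the first sample changes (in training mode the statistics are computed over the whole batch), whereas its instance-normalized output does not.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)
x_changed = x.clone()
x_changed[1:] = x_changed[1:] * 5 + 2   # alter every sample except the first

bn = nn.BatchNorm2d(3)        # training mode: statistics of the whole batch
inorm = nn.InstanceNorm2d(3)  # statistics per sample and per channel

# Output for the first sample after changing only the *other* samples
bn_same = torch.allclose(bn(x)[0], bn(x_changed)[0], atol=1e-5)
in_same = torch.allclose(inorm(x)[0], inorm(x_changed)[0], atol=1e-5)
print(bn_same, in_same)   # False True
```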

In [None]:
class Critic(nn.Module):
    def __init__(self, d_dim=64, img_size=64):
        # "d_dim" is the output dimension of the first convolutional layer;
        # that is, the number of filters/kernels you have in the first layer
        # (3 inputs, RGB, and d_dim outputs)
        super().__init__()

        # Configuration parameters
        kernel_size = 4
        n = 4  # Number of conv layers. Only used for definition of fc_in
        fc_in = int(d_dim * 2**(n-1) * (img_size/(2**n))**2)  # fc input dim
        pad = 1
        stride = 2
        bias = False

        # Helper function for convolutions
        def conv(in_channels, out_channels):
            return nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=pad,
                             bias=bias)

        # Dense or fully connected layer
        self.fc = nn.Linear(fc_in, 1)  # d_dim x 2^3 x (img_size / 2^4)^2 = 8192

        # Convolutional layers (doubling the depth with each step)
        self.conv1 = conv(in_channels=3,
                          out_channels=d_dim)
        self.conv2 = conv(in_channels=d_dim,
                          out_channels=2*d_dim)
        self.conv3 = conv(in_channels=2*d_dim,
                          out_channels=4*d_dim)
        self.conv4 = conv(in_channels=4*d_dim,
                          out_channels=8*d_dim)

        # Instance normalization layers.
        self.inorm1 = nn.InstanceNorm2d(d_dim)
        self.inorm2 = nn.InstanceNorm2d(2 * d_dim)
        self.inorm3 = nn.InstanceNorm2d(4 * d_dim)
        self.inorm4 = nn.InstanceNorm2d(8 * d_dim)

        # Define a LeakyReLU activation function layer
        self.leaky_relu = nn.LeakyReLU(0.2)


    def forward(self, x):
        # Input shape: batch x 3 x img_size x img_size  -->  b x 3 x 64 x 64
        batch_size = x.size(0)

        # Out shape: b x d_dim x img_size/2 x img_size/2
        out = self.leaky_relu(self.inorm1(self.conv1(x)))    # b x 64 x 32 x 32
        out = self.leaky_relu(self.inorm2(self.conv2(out)))  # b x 128 x 16 x 16
        out = self.leaky_relu(self.inorm3(self.conv3(out)))  # b x 256 x 8 x 8
        out = self.leaky_relu(self.inorm4(self.conv4(out)))  # b x 512 x 4 x 4

        # Flatten (b x 512 x 4 x 4  ==  b x 8192)
        out = out.contiguous().view(batch_size, -1)  # b x 8192

        # Final output layer without activation function
        scores = self.fc(out)  # b x 1
        return scores
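The value of `fc_in` follows from the convolution output-size formula. With `kernel_size=4`, `stride=2` and `padding=1`, each layer maps a spatial size $s$ to $\lfloor (s + 2 - 4)/2 \rfloor + 1 = s/2$ for even $s$, so the four layers divide the image size by $2^4 = 16$ while the depth grows to `8 * d_dim`. The sketch below checks this with a stack of `nn.Conv2d` layers configured like the critic's (normalization and activations are omitted because they do not change shapes). `critic_fc_in` and `conv_stack` are hypothetical helpers written for this check.

```python
import torch
import torch.nn as nn

def critic_fc_in(img_size: int, d_dim: int = 64, n: int = 4) -> int:
    # Same formula as in Critic.__init__: depth 8*d_dim, spatial size img_size/16
    return int(d_dim * 2**(n - 1) * (img_size / 2**n) ** 2)

def conv_stack(d_dim: int = 64) -> nn.Sequential:
    # Four convolutions with the critic's configuration (k=4, s=2, p=1),
    # doubling the depth at each step: 3 -> d -> 2d -> 4d -> 8d
    channels = [3, d_dim, 2 * d_dim, 4 * d_dim, 8 * d_dim]
    return nn.Sequential(*[
        nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)
        for c_in, c_out in zip(channels[:-1], channels[1:])
    ])

with torch.no_grad():
    for img_size in (64, 128):
        out = conv_stack()(torch.zeros(2, 3, img_size, img_size))
        print(img_size, tuple(out.shape), out[0].numel(), critic_fc_in(img_size))
# 64 (2, 512, 4, 4) 8192 8192
# 128 (2, 512, 8, 8) 32768 32768
```

It also suggests keeping `img_size` a multiple of 16: otherwise the flooring in the output-size formula makes the actual flattened size differ from `fc_in`, and `self.fc` would receive a tensor of the wrong size.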

Like the previous loss function, this one takes only two inputs: the critic's output tensor and a parameter indicating whether that output comes from real or fake images.

Note that this loss is not BCE: it is simply the mean of the input tensor, negated for real images and kept positive for fake ones. Minimizing the sum of both terms therefore pushes the critic to give real images higher scores than fake ones.
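The sign convention can be checked by hand on a few made-up critic scores. This sketch re-implements the sign logic in a hypothetical standalone helper `wasserstein_loss`, and it assumes that `train` adds the real and fake terms for the critic and calls the loss with `label_type="real"` for the generator, as is usual with this kind of interface; the notebook's `train` may organize this differently.

```python
import torch

def wasserstein_loss(scores: torch.Tensor, label_type: str = "real") -> torch.Tensor:
    # Same sign convention as Wasserstein_loss_fcn: -mean for real, +mean for fake
    sign = -1.0 if label_type == "real" else 1.0
    return sign * scores.mean()

# Hypothetical critic scores for a batch of real and a batch of fake images
real_scores = torch.tensor([2.0, 3.0, 4.0])  # mean 3.0
fake_scores = torch.tensor([0.0, 1.0, 2.0])  # mean 1.0

# Critic: minimize E[C(fake)] - E[C(real)], i.e. maximize the score gap
critic_loss = wasserstein_loss(real_scores, "real") + wasserstein_loss(fake_scores, "fake")
# Generator: wants its fakes to be scored like real images
gen_loss = wasserstein_loss(fake_scores, "real")

print(critic_loss.item(), gen_loss.item())  # -2.0 -1.0
```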

In [None]:
def Wasserstein_loss_fcn(input_tensor, **kwargs):

    # Extract label_type from **kwargs with defaults
    label_type = kwargs.get("label_type", "real")

    # Ensure label_type is valid (real or fake)
    label_type = "real" if label_type.lower() not in ("real", "fake") \
        else label_type.lower()

    # Get sign of the losses
    sign = -1 if label_type == 'real' else 1

    return input_tensor.mean() * sign

Next we define the penalty function. Since we use the same training function as in the previous experiments, the penalty must also have been handled somehow when we trained the classical GAN. Can you find out how? In particular, look at where it was defined and which values it took.

The penalty function below follows the implementation in this [GitHub repository](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py), which references the original paper, with one small modification: the penalty is multiplied by a weight `gamma` (called λ in the paper, which uses λ = 10).

In [None]:
def penalty_fcn(real_img, fake_img, critic, gamma=10):
    """
    Calculate the WGAN-GP gradient penalty.

    real_img, fake_img: batches of images with the same shape (b x 3 x H x W).
    critic: the critic network.
    gamma: weight of the penalty (lambda in the WGAN-GP paper).
    Returns gamma * mean((||grad C(x_hat)||_2 - 1)^2) over the batch.
    """

    # Interpolation between real and fake images

    # One random alpha per example, with shape b x 1 x 1 x 1 so that it
    # broadcasts over channels and pixels. requires_grad=True guarantees that
    # mix_images is part of the autograd graph even if both real_img and
    # fake_img are detached.
    alpha = torch.rand(real_img.shape[0], 1, 1, 1,
                       device=real_img.device,
                       requires_grad=True)

    # Calculate interpolated images.
    # torch.lerp(start, end, weight) computes start + weight * (end - start),
    # which is equivalent to the easier-to-read (but slower)
    # `mix_images = real_img * alpha + fake_img * (1 - alpha)`
    mix_images = torch.lerp(fake_img, real_img, alpha)

    # Get the critic prediction for the mixed images
    mix_pred = critic(mix_images)

    # Calculate the gradients of the output with respect to the input,
    # i.e. the gradients of the prediction with respect to the mixed images.
    # create_graph=True allows the penalty itself to be backpropagated.
    gradients = torch.autograd.grad(
        inputs=mix_images,
        outputs=mix_pred,
        grad_outputs=torch.ones_like(mix_pred),
        retain_graph=True,
        create_graph=True)[0]

    # Flatten the gradients to b x (3*H*W) to get one norm per example
    gradients = gradients.reshape(len(gradients), -1)

    # L2 norm of the gradients of each example
    gradient_norm = torch.norm(gradients, p=2, dim=1)

    # Penalize deviations of the gradient norm from 1
    gp = ((gradient_norm - 1) ** 2).mean()

    return gp * gamma
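One way to sanity-check a gradient penalty is to use a critic whose gradient is known in closed form. For a linear critic $C(x) = w^\top x$ the gradient with respect to every input is $w$, so the penalty must equal $\gamma(\lVert w\rVert_2 - 1)^2$ whatever the interpolation weights are. The sketch below uses flat 2-D inputs instead of images and a standalone helper `gradient_penalty` that mirrors the logic of `penalty_fcn` (alpha has shape b x 1 here because the inputs are vectors); it is an illustration, not part of the notebook's pipeline.

```python
import torch
import torch.nn as nn

def gradient_penalty(real, fake, critic, gamma=10.0):
    # One interpolation weight per example, broadcast over the features
    alpha = torch.rand(real.shape[0], 1, device=real.device, requires_grad=True)
    mix = torch.lerp(fake, real, alpha)  # alpha*real + (1-alpha)*fake
    pred = critic(mix)
    grads = torch.autograd.grad(outputs=pred, inputs=mix,
                                grad_outputs=torch.ones_like(pred),
                                create_graph=True)[0]
    norms = grads.reshape(len(grads), -1).norm(2, dim=1)
    return gamma * ((norms - 1) ** 2).mean()

torch.manual_seed(0)
critic = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 4.0]]))  # ||w||_2 = 5

real = torch.randn(8, 2)
fake = torch.randn(8, 2)
gp = gradient_penalty(real, fake, critic)
print(gp.item())  # 10 * (5 - 1)^2 = 160.0

# Thanks to create_graph=True the penalty also produces a gradient on the
# critic's weights, which pushes ||w|| back towards 1 during training
gp.backward()
print(critic.weight.grad)
```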

Now we instantiate the generator and the critic and define their optimizers.

In [None]:
torch.manual_seed(SEED)
G = Generator()
C = Critic()

# Generator optimizer

# Parameters
lr_g = 0.0002
beta1 = 0.65
beta2 = 0.999

# Create the optimizer for the generator G
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr_g, betas=(beta1, beta2))


# Critic optimizer

# Parameters
lr_d = 0.0002
beta1 = 0.65
beta2 = 0.999

# Create the optimizer for the critic C
c_optimizer = torch.optim.Adam(C.parameters(), lr=lr_d, betas=(beta1, beta2))
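For comparison, Algorithm 1 of the WGAN-GP paper (Gulrajani et al., 2017) lists as defaults λ = 10, 5 critic iterations per generator iteration, and Adam with learning rate 1e-4, β1 = 0 and β2 = 0.9. The values chosen in this notebook are different. The sketch below shows how optimizers with the paper's settings could be created; it uses small placeholder `nn.Linear` modules (`G_demo`, `C_demo`, hypothetical names) instead of `G` and `C` so that it runs on its own. Whether these settings work better on a reduced CelebA subset is something you would have to test.

```python
import torch
import torch.nn as nn

# Placeholder networks standing in for the notebook's Generator and Critic
G_demo = nn.Linear(8, 8)
C_demo = nn.Linear(8, 1)

# Defaults from Algorithm 1 of the WGAN-GP paper
paper_cfg = {"lr": 1e-4, "betas": (0.0, 0.9), "gamma": 10, "crit_cycles": 5}

g_opt_paper = torch.optim.Adam(G_demo.parameters(), lr=paper_cfg["lr"],
                               betas=paper_cfg["betas"])
c_opt_paper = torch.optim.Adam(C_demo.parameters(), lr=paper_cfg["lr"],
                               betas=paper_cfg["betas"])

print(c_opt_paper.param_groups[0]["lr"], c_opt_paper.param_groups[0]["betas"])
```

In this notebook, `gamma` corresponds to the argument of `penalty_fcn` and `crit_cycles` to the key of the same name in the training configuration.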

And finally we train again and see how it goes.

In [None]:
# Number of epochs to train.
n_epochs = 40

# save_starting (first epoch at which checkpoints can be saved) is larger
# than n_epochs, so with this configuration no checkpoints are saved

config = {'n_epochs': n_epochs,
          'dataloader': dataloader,
          'discriminator': C,                  # The critic
          'generator': G,
          'd_optimizer': c_optimizer,          # The critic's optimizer
          'g_optimizer': g_optimizer,
          'project_dir': PROJECT_DIR,
          'loss_fcn': Wasserstein_loss_fcn,    # Wasserstein loss function
          'show_step': 5,
          'penalty_fcn': penalty_fcn,          # Gradient penalty
          'crit_cycles': 5,                    # Critic updates per generator update
          'save_step': 1000,
          'save_starting': 1000,
          }

_ = train(config)
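The `train` function is defined earlier in the notebook and is not reproduced here. As a rough, hypothetical illustration of what one WGAN-GP iteration involves when `crit_cycles` is 5, the sketch below trains tiny MLPs on 2-D toy data: the critic takes `crit_cycles` steps on the Wasserstein loss plus the gradient penalty, then the generator takes one step. All names (`toy_G`, `toy_C`, `sample_real`, `train_step`) are made up for the sketch, and the structure of the notebook's `train` may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, crit_cycles, gamma = 4, 5, 10.0

# Tiny stand-ins for the generator and critic (2-D points instead of images)
toy_G = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, 2))
toy_C = nn.Sequential(nn.Linear(2, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
g_opt = torch.optim.Adam(toy_G.parameters(), lr=1e-4, betas=(0.0, 0.9))
c_opt = torch.optim.Adam(toy_C.parameters(), lr=1e-4, betas=(0.0, 0.9))

def sample_real(n):
    # Toy "real" data: a Gaussian blob centred at (2, 2)
    return torch.randn(n, 2) * 0.1 + 2.0

def gradient_penalty(real, fake, critic):
    # Penalty on random interpolations of real and (detached) fake points
    alpha = torch.rand(real.shape[0], 1)
    mix = torch.lerp(fake, real, alpha).requires_grad_(True)
    pred = critic(mix)
    grads = torch.autograd.grad(outputs=pred, inputs=mix,
                                grad_outputs=torch.ones_like(pred),
                                create_graph=True)[0]
    return gamma * ((grads.norm(2, dim=1) - 1) ** 2).mean()

def train_step(batch_size=64):
    # 1) Critic: crit_cycles updates with the generator frozen (detached fakes)
    for _ in range(crit_cycles):
        real = sample_real(batch_size)
        fake = toy_G(torch.randn(batch_size, z_dim)).detach()
        c_loss = (-toy_C(real).mean() + toy_C(fake).mean()
                  + gradient_penalty(real, fake, toy_C))
        c_opt.zero_grad()
        c_loss.backward()
        c_opt.step()
    # 2) Generator: one update that tries to raise the critic's score on fakes.
    # This backward also fills the critic's .grad, but c_opt.zero_grad()
    # clears it before the next critic update.
    fake = toy_G(torch.randn(batch_size, z_dim))
    g_loss = -toy_C(fake).mean()
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    return c_loss.item(), g_loss.item()

for step in range(100):
    c_loss, g_loss = train_step()
print(f"critic loss {c_loss:.3f}, generator loss {g_loss:.3f}")
```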