<img src="https://drive.google.com/uc?export=view&id=1x-QAgitB-S5rxGGDqxsJ299ZQTfYtOhb" width="180" align="center"/>

Master's degree in Intelligent Systems

Subject: 11754 - Deep Learning

Year: 2023-2024

Professor: Miguel Ángel Calafat Torrens

# LAB 5 - Generative Adversarial Networks

**In this lab you have to deliver only this file.**

As you will see, this notebook contains essentially the same cells as `LSS5-GAN.ipynb`, or at least the important ones related to the WGAN-GP, without the chatter.

Modify whichever cells you need in order to meet the following requirements:
* Define the generator and the critic with at least 5 blocks of convolutions.
* The latent space (input) of the generator must be of size 200 instead of 100.
* Train for at least 40 epochs on the dataset (you can use the small version, a bigger one, or even the full version, as you wish).
* Save checkpoints at two different epochs (for example, if you train for 40 epochs, you could save checkpoints at epochs 20 and 40).
* Load the checkpoints of these two epochs and, for each checkpoint, show a random generation of a batch of 25 images.

**Remember that you have helper functions in `helper_PR5.py` to load checkpoints and to show batches of tensors**, so you don't have to invent anything but the models.

**You can modify, add or remove any cell that you want to fulfill the requirements.**


# WGAN-GP

## Setting up

In this notebook we will use the CelebA dataset. It is available in many places on the internet, but a good place to start is its own [website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html): scroll down to the "**Download**" section, select the link "**In the wild images**", go into the folder "**Img**" and download the file "**img_align_celeba.zip**". It may be even easier to download it from this [Kaggle link](https://www.kaggle.com/jessicali9530/celeba-dataset) if you have an account (accounts are free).

In any case, remember that **the CelebA dataset is available for non-commercial research purposes only**.

In [None]:
# Connect to your drive
from google.colab import drive, files
drive.mount('/content/gdrive')
%cd '/content/gdrive/MyDrive/LABS2024/LAB5'
%ls -l

# Insert the project folder (where this file lives) into the Python path
# before importing the helper module, so that the import can find it.
import pathlib
import sys
import os

PROJECT_DIR = str(pathlib.Path().resolve())
sys.path.append(PROJECT_DIR)

import helper_PR5 as hp

In [None]:
# Ensure you write down the correct path to your zip file with the dataset
dataset_zip_fullpath = '/content/gdrive/MyDrive/datasets/img_align_celeba_small.zip'

**Before running the next cell, make sure the zip file is in the location you want and that the variable above points to it.**

In [None]:
# This cell will take several minutes the first time you run it.

# You can execute this cell every time you run the code. It won't unzip
# anything if the dataset is already unzipped.
DATA_DIR = hp.extract_dataset(dataset_zip_fullpath,
                              remove_zip=True)

The first time you run this notebook you may prefer the CPU environment, because you will need some time to understand what is being done (remember the GPU usage limits set by Colab). The GPU is only necessary when you run the trainings, that is, once you know how the whole notebook works.

In [None]:
# Import the necessary libraries
import PIL
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## The dataset

In [None]:
class CustomDataset(Dataset):
    def __init__(self, img_folder, lim=2000, transforms=None):
        # lim is the max number of images that we want to use. Its default is
        # meant for a proof of concept, where we don't want to use all the
        # images of the dataset
        self.img_folder = img_folder
        self.lim = lim

        # Initialize empty lists for items and labels
        self.items = []
        self.labels = []

        # Walk through all files in img_folder and its subfolders
        for root, _, files in os.walk(img_folder):
            for file in files:
                # Check if the file is an image
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    full_path = os.path.join(root, file)
                    self.items.append(full_path)
                    self.labels.append(file)

                    # Stop if the limit is reached
                    if lim > 0 and len(self.items) >= lim:
                        break

            if lim > 0 and len(self.items) >= lim:
                break

        # Define the transformation pipeline
        self.transforms = transforms

    def __getitem__(self, idx):
        # Load the image
        data = PIL.Image.open(self.items[idx]).convert('RGB')

        # Apply the transformations if provided
        if self.transforms:
            data = self.transforms(data)

        # Return the processed data and its corresponding label
        return data, self.labels[idx]

    def __len__(self):
        return len(self.items)

In [None]:
# Define the transformations to be done
class ToScaledTensor(transforms.ToTensor):
    def __init__(self, low=-1, high=1):
        super().__init__()
        self.low = low
        self.high = high

    def __call__(self, img):
        # Convert image to tensor (same as ToTensor)
        tensor = super().__call__(img)
        # Scale the tensor to the desired range
        tensor = tensor * (self.high - self.low) + self.low
        return tensor
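
As a quick check of the arithmetic in `ToScaledTensor`, the sketch below applies the same affine map to plain Python floats and also shows its inverse, which can be handy if you need values back in [0, 1] for plotting. The function names `scale` and `unscale` are hypothetical; they are not part of `helper_PR5.py`.

```python
def scale(x, low=-1.0, high=1.0):
    """Map a value from [0, 1] to [low, high] (same formula as ToScaledTensor)."""
    return x * (high - low) + low


def unscale(y, low=-1.0, high=1.0):
    """Inverse map: bring a value from [low, high] back to [0, 1]."""
    return (y - low) / (high - low)


pixels = [0.0, 0.25, 0.5, 1.0]
scaled = [scale(p) for p in pixels]
restored = [unscale(s) for s in scaled]

print(scaled)    # [-1.0, -0.5, 0.0, 1.0]
print(restored)  # [0.0, 0.25, 0.5, 1.0]
```

The [-1, 1] range matches the output range of a `tanh` activation, which is a common (but not mandatory) choice for the last layer of a generator.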

In [None]:
# Size of the images (height and width)
img_size = 64

# Transformations
transform = transforms.Compose([
    # Center crop the images so that they become square
    transforms.CenterCrop((178, 178)),
    # Resize the image to the specified size (h, w)
    transforms.Resize((img_size, img_size)),
    # Convert the image to a PyTorch tensor and scale to [-1, 1]
    ToScaledTensor(),
    ])

In [None]:
# limit is the max number of images to use (-1 if all)

# Maybe you want to start working with fewer images to see how it works
limit = -1

# Define the custom dataset object.
dataset = CustomDataset(DATA_DIR,
                        transforms=transform,
                        lim=limit)

## The dataloader

In [None]:
# Batch size (number of images per batch)
batch_size = 128

# Define the DataLoader
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True)

In [None]:
# And now let's see images of the first batch
hp.show(next(iter(dataloader))[0])

## The networks

### The Critic

In [None]:
# Here define your version of the Critic
class Critic(nn.Module):

    pass
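
Before writing the critic, it can help to check the spatial size of the feature map after each convolution block. The sketch below uses the standard output-size formula for a convolution, `floor((n + 2*padding - kernel) / stride) + 1`. The settings (kernel 4, stride 2, padding 1 in every block) are an assumption made for illustration, not the required architecture, and the helper names are hypothetical.

```python
def conv_out_size(n, kernel, stride, padding):
    """Spatial output size of a convolution applied to an n x n input."""
    return (n + 2 * padding - kernel) // stride + 1


def critic_sizes(img_size=64, n_blocks=5, kernel=4, stride=2, padding=1):
    """Feature-map sizes after each of n_blocks identical convolution blocks."""
    sizes = [img_size]
    for _ in range(n_blocks):
        sizes.append(conv_out_size(sizes[-1], kernel, stride, padding))
    return sizes


print(critic_sizes())  # [64, 32, 16, 8, 4, 2]
```

With these assumed settings, five blocks leave a 2x2 feature map, so one more step (for example a convolution with kernel 2, or a flatten followed by a linear layer) is still needed to obtain a single score per image.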

### The Generator

In [None]:
# Latent space size of the generator
z_dim = 200

In [None]:
# Here define your version of the Generator
class Generator(nn.Module):

    pass
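
The generator has to grow a latent code into a 64x64 image, so the same kind of size check is useful for transposed convolutions. The sketch below uses the standard formula `(n - 1) * stride - 2 * padding + kernel` (dilation 1, no output padding). It assumes the latent vector is reshaped to a 1x1 map and that the blocks use the (kernel, stride, padding) values listed in `blocks`; these values and the helper name are illustrative assumptions only.

```python
def tconv_out_size(n, kernel, stride, padding):
    """Spatial output size of a transposed convolution on an n x n input."""
    return (n - 1) * stride - 2 * padding + kernel


# Assumed (kernel, stride, padding) per block: one block that expands the
# 1x1 latent map to 4x4, then four blocks that double the size each time.
blocks = [(4, 1, 0)] + [(4, 2, 1)] * 4

sizes = [1]
for kernel, stride, padding in blocks:
    sizes.append(tconv_out_size(sizes[-1], kernel, stride, padding))

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

Under these assumptions, five transposed-convolution blocks are exactly enough to reach the 64x64 resolution set by `img_size`.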

In [None]:
# Instantiate the models. Of course you can initialize them with
# default values or not.
C = Critic()
G = Generator()

## Optimizers

In [None]:
# Critic optimizer

# Parameters
lr_c = 0.0002
beta1 = 0.5
beta2 = 0.9

# Create optimizer for the critic C
c_optimizer = torch.optim.Adam(C.parameters(), lr=lr_c, betas=(beta1, beta2))

In [None]:
# Generator optimizer

# Parameters
lr_g = 0.0002
beta1 = 0.5
beta2 = 0.9

# Create optimizers for generator G
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr_g, betas=(beta1, beta2))

## Losses

In [None]:
def Wasserstein_loss_fcn(input_tensor, **kwargs):

    # Extract label_type from **kwargs with defaults
    label_type = kwargs.get("label_type", "real")

    # Ensure label_type is valid (real or fake)
    label_type = "real" if label_type.lower() not in ("real", "fake") \
        else label_type.lower()

    # Get sign of the losses
    sign = -1 if label_type == 'real' else 1

    return input_tensor.mean() * sign

In [None]:
def penalty_fcn(real_img, fake_img, critic, gamma=10):
    """
    Calculate gradient penalty
    """

    # Interpolation between real and fake images

    # Define alpha with shape bs x 1 x 1 x 1
    alpha = torch.rand(real_img.shape[0], 1, 1, 1,
                       device=DEVICE,
                       requires_grad=True)

    # Calculate interpolated images
    # The explicit form `mix_images = real_img * alpha + fake_img * (1 - alpha)`
    # is easier to read; torch.lerp computes the same thing and is faster.
    mix_images = torch.lerp(fake_img, real_img, alpha)

    # Get the critic prediction for the mixed images
    mix_pred = critic(mix_images)

    # Calculate the gradients of the output with respect to the input.
    # I.e. the gradients of the prediction with respect to the mixed images.
    gradients = torch.autograd.grad(
        inputs=mix_images,
        outputs=mix_pred,
        grad_outputs=torch.ones_like(mix_pred),
        retain_graph=True,
        create_graph=True,
        only_inputs=True)[0]

    # Reshape gradients to calculate norm
    gradients = gradients.view(len(gradients), -1)  # Flatten the gradients

    # Use torch.norm to calculate the L2 norm of the gradients
    gradient_norm = torch.norm(gradients, p=2, dim=1)

    # Calculate the gradient penalty
    gp = ((gradient_norm - 1) ** 2).mean()

    return gp * gamma
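
For reference, the two functions above can be summarised in formulas. The notation is introduced here and is not taken from the notebook: $C$ is the critic, $G$ the generator, $x$ a real image, $z$ a latent code, $\hat{x}$ the interpolated image and $\gamma$ the penalty weight (`gamma`). The signs follow the code as written, with `label_type='fake'` giving $+$ and `label_type='real'` giving $-$.

$$\hat{x} = \alpha\, x + (1 - \alpha)\, G(z), \qquad \alpha \sim \mathcal{U}[0, 1)$$

$$L_C = \mathbb{E}_{z}\big[C(G(z))\big] - \mathbb{E}_{x}\big[C(x)\big] + \gamma\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} C(\hat{x}) \rVert_2 - 1\big)^2\Big]$$

$$L_G = -\,\mathbb{E}_{z}\big[C(G(z))\big]$$

In the training loop the expectations are approximated by batch means, and the generator loss is obtained by calling the loss function on fake predictions with `label_type='real'`.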

## Checkpoints

Review the criteria that decide whether a checkpoint is saved.

In [None]:
def checkpointer(epoch,
                 epoch_gener_loss,
                 best_gen_loss,
                 config,
                 save_step,
                 starting_from=20):
    """
    Do checkpoints if required.

    Evaluates the performance of the generator at the current epoch and decides
    whether to save a checkpoint. Checkpoints are saved if the generator's loss
    is improved past a specified epoch or at regular intervals defined by the
    save_step parameter.

    The function logs the reason for saving a checkpoint and calls a separate
    function to actually save it. Any update of best_gen_loss made here is
    local to this function and is not visible to the caller.

    Parameters:
    - epoch (int): The current training epoch.
    - epoch_gener_loss (float): The generator's loss for the current epoch.
    - best_gen_loss (float): The best generator loss observed so far in the
        training.
    - config (dict): A dictionary containing the configuration of the model,
        including the model itself and its optimizer.
    - save_step (int): The interval of epochs at which checkpoints are saved
        regardless of performance improvement.
    - starting_from (int, optional): The epoch number from which to start
        considering saving checkpoints based on loss improvement. Default: 20.

    Returns:
    - None
    """

    # Check if this epoch's generator loss is an improvement
    if epoch_gener_loss < best_gen_loss and epoch > starting_from:
        # Save new best score
        best_gen_loss = epoch_gener_loss

        # Log
        print(f"New best generator loss {best_gen_loss} at epoch "
              f"{epoch}. Saving checkpoint.")

        # Save checkpoint
        hp.save_checkpoint(f"best_{epoch}",
                           epoch,
                           config,
                           PROJECT_DIR)

    # Or maybe the checkpoint is saved due to the number of epochs
    elif epoch % save_step == 0:
        # Log
        print(f"Epoch {epoch}. "
              "Not best losses, but saving checkpoint anyway.")

        # Save checkpoint
        hp.save_checkpoint(f"epoch_{epoch}",
                           epoch,
                           config,
                           PROJECT_DIR)
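
To see which checkpoint names these rules produce, the sketch below dry-runs the same two conditions without saving anything. `checkpoint_decision` is a hypothetical stand-in for `checkpointer` that returns the checkpoint name together with the updated best loss, and the loss values are made up for illustration.

```python
def checkpoint_decision(epoch, loss, best_loss, save_step, starting_from=20):
    """Return (checkpoint name or None, updated best loss)."""
    # Same first condition as checkpointer: improvement past starting_from
    if loss < best_loss and epoch > starting_from:
        return f"best_{epoch}", loss
    # Same second condition: periodic checkpoint every save_step epochs
    if epoch % save_step == 0:
        return f"epoch_{epoch}", best_loss
    return None, best_loss


# Made-up generator losses for a few epochs
losses = {19: 0.9, 20: 0.8, 21: 0.7, 22: 0.75, 25: 0.72, 30: 0.6}

best = float("inf")
saved = []
for epoch, loss in losses.items():
    name, best = checkpoint_decision(epoch, loss, best, save_step=5)
    if name is not None:
        saved.append(name)

print(saved)  # ['epoch_20', 'best_21', 'epoch_25', 'best_30']
```

Note that epoch 30 is stored as `best_30`, not `epoch_30`, because the improvement branch is checked first. The name you later pass to `hp.load_checkpoint` therefore depends on which branch fired at that epoch.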

## Training

In [None]:
def train(config, verbose=True):
    """
    Training loop
    """

    # Initialize a variable to track the best generator loss seen so far
    best_gen_loss = float('inf')

    # Initialize lists of generator and discriminator losses per epoch
    gener_losses_epoch_list = []
    discr_losses_epoch_list = []

    # Initialize configuration values from the config variable
    n_epochs = config.get('n_epochs', 100)
    crit_cycles = config.get('crit_cycles', 1)
    z_dim = config.get('z_dim', 100)
    show_step = config.get('show_step', 25)
    save_step = config.get('save_step', 5)
    last_epoch = config.get('epoch', 0)
    save_starting = config.get('save_starting', 20)
    penalty_fcn = config.get('penalty_fcn', lambda *x: 0)

    # Required entries of the config dict (a KeyError is raised if any of
    # them is missing)
    dataloader = config['dataloader']
    gener = config['generator'].to(DEVICE)
    discr = config['discriminator'].to(DEVICE)
    gener_opt = config['g_optimizer']
    discr_opt = config['d_optimizer']
    loss_fcn = config['loss_fcn']

    # Loop through epochs
    for epoch in range(last_epoch + 1, n_epochs + last_epoch + 1):
        # Initialize accumulators for losses
        epoch_gener_loss = 0.0
        epoch_discr_loss = 0.0

        # Number of batches, used for averaging losses later
        num_batches = 0

        # Loop through images
        for real_imgs, _ in tqdm(dataloader):
            num_batches += 1

            # Current batch size
            current_bs = len(real_imgs)

            # Pass real images to the gpu if available
            real_imgs = real_imgs.to(DEVICE)

            # Train the discriminator (or critic) on real and fake images for
            # the number of cycles proposed.
            discr_loss_for_cycles = 0
            for _ in range(crit_cycles):
                # Zero gradients
                discr_opt.zero_grad()

                # Generate noise (initial latent code of the generator)
                noise = torch.randn(current_bs, z_dim, device=DEVICE)

                # Generate fake image from noise
                fake_imgs = gener(noise)

                # Prediction of Discriminator on fake images (tensor must be
                # detached to avoid backpropagation on the generator weights)
                discr_fake_pred = discr(fake_imgs.detach())

                # Losses of the Discriminator on fake images
                discr_fake_loss = loss_fcn(discr_fake_pred, label_type='fake')

                # Prediction of Discriminator on real images
                discr_real_pred = discr(real_imgs)

                # Losses of the Discriminator on real images
                discr_real_loss = loss_fcn(discr_real_pred, label_type='real')

                # Calculate gradient penalty (not used in classic GAN)
                penalty = penalty_fcn(real_imgs,
                                      fake_imgs.detach(),
                                      discr)

                # Discriminator (or critic) losses for the current cycle.
                discr_loss = discr_fake_loss + discr_real_loss + penalty

                # Calculate losses of all the cycles so far
                discr_loss_for_cycles += discr_loss.item() / crit_cycles

                # Backpropagation and weights update
                discr_loss.backward()
                discr_opt.step()

            # Get final losses of the discriminator in the current epoch
            epoch_discr_loss += discr_loss_for_cycles

            # Train the generator

            # Zero gradients
            gener_opt.zero_grad()

            # Generate noise (initial latent code of the generator)
            noise = torch.randn(current_bs, z_dim, device=DEVICE)

            # Generate fake images from noise
            fake_imgs = gener(noise)

            # Prediction of discriminator on fake images
            discr_fake_pred = discr(fake_imgs)

            # Losses of the generator in the current batch
            gener_loss = loss_fcn(discr_fake_pred, label_type='real')
            epoch_gener_loss += gener_loss.item()

            # Backpropagation and weights update
            gener_loss.backward()
            gener_opt.step()

        # Average the losses for the current epoch
        epoch_gener_loss /= num_batches
        epoch_discr_loss /= num_batches
        gener_losses_epoch_list.append(epoch_gener_loss)
        discr_losses_epoch_list.append(epoch_discr_loss)

        # Log of the epoch
        if verbose:
            print({'Epoch': epoch,
                   'Critic loss': epoch_discr_loss,
                   'Gen loss': epoch_gener_loss})

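        # Do checkpoint if required
        checkpointer(epoch=epoch,
                     epoch_gener_loss=epoch_gener_loss,
                     best_gen_loss=best_gen_loss,
                     config=config,
                     save_step=save_step,
                     starting_from=save_starting)

        # checkpointer cannot rebind this local variable, so the best
        # generator loss is tracked here using the same criterion
        if epoch_gener_loss < best_gen_loss and epoch > save_starting:
            best_gen_loss = epoch_gener_loss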

        # Conditional visualization at the end of an epoch
        if epoch % show_step == 0:
            hp.visual_epoch(fake_imgs,
                            real_imgs,
                            gener_losses_epoch_list,
                            discr_losses_epoch_list)

    return gener, discr, [gener_losses_epoch_list, discr_losses_epoch_list]

Define the training parameters.

In [None]:
# Number of epochs to train.
n_epochs = 40

# Define config dict
config = {}
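
The empty dict above has to be filled before calling `train`, which reads the keys listed in its first lines. The sketch below is one possible way to assemble it. `build_config` is a hypothetical helper, and the values chosen for `crit_cycles`, `show_step`, `save_step` and `save_starting` are assumptions, not prescribed settings. `hp.save_checkpoint` may expect further entries that are not visible in this notebook.

```python
def build_config(generator, critic, g_optimizer, c_optimizer, dataloader,
                 loss_fcn, penalty_fcn, n_epochs=40, z_dim=200):
    """Collect the entries that train() reads from the config dict."""
    return {
        # Optional entries (train() has defaults for these)
        'n_epochs': n_epochs,
        'crit_cycles': 5,            # assumed: critic updates per batch
        'z_dim': z_dim,
        'show_step': 10,             # assumed: plot every 10 epochs
        'save_step': n_epochs // 2,  # periodic checkpoints, e.g. 20 and 40
        'save_starting': n_epochs,   # assumed: disables the best-loss branch
        'penalty_fcn': penalty_fcn,
        # Required entries (train() indexes them directly)
        'dataloader': dataloader,
        'generator': generator,
        'discriminator': critic,
        'g_optimizer': g_optimizer,
        'd_optimizer': c_optimizer,
        'loss_fcn': loss_fcn,
    }


# Stand-in strings replace the real objects so the sketch runs on its own
demo = build_config('G', 'C', 'g_opt', 'c_opt', 'loader', 'loss', 'penalty')
print(demo['save_step'], demo['save_starting'])  # 20 40
```

In the notebook the call would look like `config = build_config(G, C, g_optimizer, c_optimizer, dataloader, Wasserstein_loss_fcn, penalty_fcn)`. Setting `save_starting` to `n_epochs` means the condition `epoch > starting_from` is never true during a first 40-epoch run, so only the periodic names `epoch_20` and `epoch_40` are produced.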

Let's do the training.

In [None]:
# Train
_ = train(config)

# Generate images

In [None]:
# Generate 25 latent codes of size 200
noise = torch.randn(25, z_dim, device=DEVICE)

In [None]:
# Load the checkpoint indicated at 'name'
name = 'epoch_20'
hp.load_checkpoint(name, config, PROJECT_DIR)

# Generate fake images (no gradients are needed for inference)
with torch.no_grad():
    fake_imgs = config['generator'].eval()(noise)

# Show fake images
hp.show(fake_imgs)

In [None]:
# Load the checkpoint indicated at 'name'
name = 'epoch_40'
hp.load_checkpoint(name, config, PROJECT_DIR)

# Generate fake images (no gradients are needed for inference)
with torch.no_grad():
    fake_imgs = config['generator'].eval()(noise)

# Show fake images
hp.show(fake_imgs)