In [9]:

import os

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm.auto import tqdm

from wgan import *
from utils import *

torch.manual_seed(0)


# ----------------------------------------------------------------------------
# Hyperparameters
# ----------------------------------------------------------------------------
n_epochs = 100
z_dim = 64
display_step = 50  # Print losses and plot images every `display_step` generator steps
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10  # Gradient-penalty coefficient, passed to get_crit_loss together with the penalty

# Number of critic updates performed for each generator update
crit_repeats = 5

device = "cpu"

# Where MNIST is downloaded to / read from. Set the MNIST_DATA_DIR environment variable
# to run on another machine; the fallback is the original author's local path.
data_root = os.environ.get(
    "MNIST_DATA_DIR", "/Users/abhinaybelde/Desktop/Learning/datasets/raw"
)


# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataloader = DataLoader(
    MNIST(data_root, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)


# Build both networks and apply the custom weight initialisation *before* creating the
# optimizers, so each optimizer is handed the final, initialised parameters.
generator = Generator(z_dim).to(device)
critic = Critic().to(device)
generator = generator.apply(weights_init)
critic = critic.apply(weights_init)

gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta_1, beta_2))
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=lr, betas=(beta_1, beta_2))

#######################################################################
# Gradient Penalty Calculation -  Calculate Gradient of Critic Score
#######################################################################
def gradient_of_critic_score(critic, real, fake, epsilon):
    """
    Compute the gradient of the critic's scores w.r.t. images interpolated between real and fake.

    This is the first half of the WGAN-GP (Wasserstein GAN with Gradient Penalty) penalty.
    The critic is evaluated at points on the straight line between each real/fake pair, and
    the gradient of its scores with respect to those points is returned. The penalty itself
    (how far the per-sample gradient norms are from 1) is computed separately. Penalising
    that deviation softly encourages the critic to be 1-Lipschitz.

    Args:
        critic (nn.Module): The critic model.
        real (torch.Tensor): Batch of real images.
        fake (torch.Tensor): Batch of fake images, same shape as ``real`` (usually detached
            from the generator's graph).
        epsilon (float or torch.Tensor): Interpolation weight for ``real``; ``fake`` is
            weighted by ``1 - epsilon``. A tensor of shape (batch, 1, 1, 1) gives one weight
            per sample. It does not need to have ``requires_grad`` set.

    Returns:
        torch.Tensor: Gradient of the summed critic scores w.r.t. the interpolated images
        (broadcast shape of ``real * epsilon``). It is built with ``create_graph=True`` so a
        penalty computed from it can be backpropagated into the critic's parameters.
    """
    # Weighted combination: epsilon = 1 gives the real image, epsilon = 0 the fake one.
    interpolated_images = real * epsilon + fake * (1 - epsilon)

    # autograd.grad needs its inputs to be part of the graph. If nothing upstream tracks
    # gradients (e.g. a float epsilon with detached images), turn the interpolated batch into
    # a grad-tracking leaf instead of failing with "does not require grad".
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_()

    mixed_scores = critic(interpolated_images)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

""" Whats the significane of the below line in above

interpolated_images = real * epsilon + fake * (1 - epsilon)

 this line is used to enforce a constraint on the gradients of the critic (also called discriminator). This is known as the "gradient penalty".

The idea behind the gradient penalty is to prevent the critic from becoming too powerful, which could cause the generator to fail to learn. This is done by encouraging the gradients of the critic to have a norm of 1 across the image space. To enforce this, we don't just consider the gradients at the real images or the fake images, but also at random points along the straight line between a pair of real and fake images. These points are the "interpolated images".

By computing the gradients at these interpolated images and adding a penalty to the critic's loss function if these gradients deviate from 1, we can ensure that the critic is a 1-Lipschitz function, which is a key property needed for the theoretical guarantees of the Wasserstein distance to hold. This results in more stable training dynamics for the GAN.

So, in essence, the line is generating the interpolated images at which the critic's gradients will be evaluated.

"""
pass


In [17]:
###############################################################################
# Unit Test for above Method
###############################################################################
def test_gradient_of_critic_score(image_shape):
    """
    Smoke-test gradient_of_critic_score with synthetic real/fake batches.

    "Real" images are standard-normal noise shifted by +1 and "fake" images are noise
    shifted by -1. One interpolation weight (with requires_grad enabled) is drawn per sample.
    The gradient is then computed with the global ``critic``. The result must have the same
    shape as the images and contain both positive and negative entries.

    Args:
    image_shape (tuple): The shape of the images in the form of (batch_size, channels, height, width).

    Returns:
    gradient (tensor): The gradient calculated by the gradient_of_critic_score function.
    """
    real_batch = torch.randn(*image_shape, device=device) + 1
    fake_batch = torch.randn(*image_shape, device=device) - 1

    # Keep the batch dimension and collapse every other one to 1, e.g. (256, 1, 1, 1),
    # so each sample gets a single weight that broadcasts over all of its pixels.
    per_sample_shape = [image_shape[0]] + [1] * (len(image_shape) - 1)
    epsilon = torch.rand(per_sample_shape, device=device).requires_grad_()

    gradient = gradient_of_critic_score(critic, real_batch, fake_batch, epsilon)

    # Shape must match the input images, and the gradient must not be one-signed.
    assert tuple(gradient.shape) == image_shape
    assert gradient.max() > 0
    assert gradient.min() < 0

    return gradient


gradient = test_gradient_of_critic_score((256, 1, 28, 28))
print("Success!")


Success!


In [19]:

###############################################################################
# Gradient Penalty Calculation - Calculate the Penalty on The Norm of Gradient
###############################################################################
def gradient_penalty_l2_norm(gradient):
    """
    Compute the WGAN-GP gradient penalty from per-sample critic gradients.

    Each sample's gradient is flattened to a vector and its L2 norm is taken. The penalty is
    the batch mean of (norm - 1)^2. Driving it toward 0 pushes the critic's gradient norms
    toward 1, the soft 1-Lipschitz constraint used by WGAN-GP.

    Args:
    gradient (torch.Tensor): Gradients of the critic's scores with respect to the interpolated
        images. The first dimension is the batch; all remaining dimensions are flattened.

    Returns:
    torch.Tensor: Scalar tensor holding the gradient penalty.
    """
    # Flatten each sample into a vector. ``reshape`` (unlike ``view``) also accepts
    # non-contiguous gradients, copying only when a view is impossible.
    flat_gradient = gradient.reshape(gradient.shape[0], -1)

    gradient_norm = flat_gradient.norm(2, dim=1)

    # Mean squared distance of the per-sample norms from 1.
    penalty = torch.mean((gradient_norm - 1) ** 2)

    return penalty



In [21]:

###############################################################################
# Unit Test for above Method
###############################################################################
def test_gradient_penalty_l2_norm(image_shape):
    """
    Check gradient_penalty_l2_norm against gradients whose penalty is known in advance.

    * All-zero gradient: every per-sample norm is 0, so the penalty must be (0 - 1)^2 = 1.
    * Constant gradient with entries 1/sqrt(D), where D is the number of values per image:
      every norm is exactly 1, so the penalty must be ~0.
    * Gradient produced from the critic by test_gradient_of_critic_score: the penalty must
      lie within 0.1 of 1.

    Args:
    image_shape (tuple): The shape of the images in the form of (batch_size, channels, height, width).
    """
    # --- zero gradient -> penalty 1 ---
    zero_penalty = gradient_penalty_l2_norm(torch.zeros(*image_shape))
    print(zero_penalty)
    assert torch.isclose(zero_penalty, torch.tensor(1.0))

    # --- unit-norm gradient -> penalty 0 ---
    # D is the product of the non-batch dimensions (1 * 28 * 28 = 784 for MNIST), so sqrt(D) = 28.
    values_per_image = torch.Tensor(image_shape[1:]).prod()
    sqrt_values = torch.sqrt(values_per_image)
    print("torch.sqrt(image_size) ", sqrt_values)
    unit_penalty = gradient_penalty_l2_norm(torch.ones(*image_shape) / sqrt_values)
    print(unit_penalty)
    assert torch.isclose(unit_penalty, torch.tensor(0.0))

    # --- gradient from the actual critic -> penalty close to 1 ---
    critic_gradient = test_gradient_of_critic_score(image_shape)
    critic_penalty = gradient_penalty_l2_norm(critic_gradient)
    assert torch.abs(critic_penalty - 1) < 0.1



test_gradient_penalty_l2_norm((256, 1, 28, 28))
print("Success!")



tensor(1.)
torch.sqrt(image_size)  tensor(28.)
tensor(5.6843e-14)
Success!


In [None]:

##############################
# Final Training
##############################

def _plot_loss_curves(generator_losses, critic_losses, step_bins=20):
    """
    Plot generator and critic losses, each averaged over consecutive bins of ``step_bins`` steps.

    Only the largest prefix whose length is a multiple of ``step_bins`` is plotted, which
    guarantees that ``view(-1, step_bins)`` succeeds. Both lists are expected to have one
    entry per generator step (i.e. the same length).
    """
    num_examples = (len(generator_losses) // step_bins) * step_bins
    bin_positions = range(num_examples // step_bins)

    plt.plot(
        bin_positions,
        torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Generator Loss",
    )
    plt.plot(
        bin_positions,
        torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Critic Loss",
    )
    plt.legend()
    plt.show()


current_step = 0
generator_losses = []
# One entry per generator step: the critic loss averaged over that step's crit_repeats updates.
critic_losses_across_critic_repeats = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        #########################
        #  Train Critic (crit_repeats updates per generator update)
        #########################
        mean_critic_loss_for_this_iteration = 0
        for _ in range(crit_repeats):
            critic_optimizer.zero_grad()

            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # The critic update never backpropagates into the generator, so produce the
            # fakes without recording an autograd graph (the old code built one and then
            # threw it away via .detach()).
            with torch.no_grad():
                fake = generator(fake_noise)

            critic_fake_prediction = critic(fake)
            crit_real_pred = critic(real)

            # One interpolation weight per sample, e.g. shape (128, 1, 1, 1) for a batch of
            # 128; it broadcasts over channels/height/width.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)

            gradient = gradient_of_critic_score(critic, real, fake, epsilon)
            gp = gradient_penalty_l2_norm(gradient)

            crit_loss = get_crit_loss(
                critic_fake_prediction, crit_real_pred, gp, c_lambda
            )

            # Keep track of the average critic loss in this batch
            mean_critic_loss_for_this_iteration += crit_loss.item() / crit_repeats

            # Each repeat builds a fresh graph, so it does not need to be retained after backward.
            crit_loss.backward()
            critic_optimizer.step()
        critic_losses_across_critic_repeats += [mean_critic_loss_for_this_iteration]

        #########################
        #  Train Generator
        #########################
        gen_optimizer.zero_grad()

        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = generator(fake_noise_2)
        critic_fake_prediction = critic(fake_2)

        gen_loss = get_gen_loss(critic_fake_prediction)
        gen_loss.backward()
        gen_optimizer.step()

        generator_losses += [gen_loss.item()]

        ##################################
        #  Log Progress and Visualization
        ##################################
        # Every display_step generator steps (skipping step 0), report the mean losses over
        # the most recent display_step steps and plot images and loss curves.
        if current_step % display_step == 0 and current_step > 0:
            generator_mean_loss_display_step = (
                sum(generator_losses[-display_step:]) / display_step
            )
            critic_mean_loss_display_step = (
                sum(critic_losses_across_critic_repeats[-display_step:]) / display_step
            )
            print(
                f"Step {current_step}: Generator loss: {generator_mean_loss_display_step}, critic loss: {critic_mean_loss_display_step}"
            )

            # Plot both the fake images from the last critic repeat and the real batch
            plot_images_from_tensor(fake)
            plot_images_from_tensor(real)

            _plot_loss_curves(generator_losses, critic_losses_across_critic_repeats)

        current_step += 1
