# **Wasserstein GAN with Gradient Penalty (WGAN-GP)**

## Introduction

Wasserstein GAN with Gradient Penalty (**WGAN-GP**) ([WGAN paper](https://arxiv.org/pdf/1701.07875), [WGAN-GP paper](https://arxiv.org/pdf/1704.00028)) improves on traditional GAN training by addressing training instability and mode collapse. It replaces the standard GAN loss with the **Wasserstein distance**, which gives more stable training dynamics and better-behaved gradients.

A key innovation in WGAN-GP is the **Gradient Penalty (GP)** mechanism. GP replaces the less effective weight clipping method for enforcing the crucial **1-Lipschitz constraint**.

---

## **1️⃣ Why Wasserstein Distance?**

Conventional GANs often rely on the **Jensen-Shannon (JS) divergence** to quantify the difference between the real and generated data distributions. However, this approach can lead to vanishing gradients and mode collapse.

WGAN addresses these issues by minimizing the **Wasserstein-1 distance (Earth Mover’s Distance)**:

$$
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [\|x - y\|]
$$

This distance metric offers:

- ✅ **More stable training processes**
- ✅ **Enhanced gradient flow during training**
- ✅ **Reduced susceptibility to mode collapse**
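
For intuition, in one dimension the Wasserstein-1 distance between two equally sized, equally weighted samples reduces to the mean absolute difference between the sorted values. The sketch below relies on that special case; the function name `wasserstein_1d` and the toy data are illustrative and not part of this notebook.

```python
def wasserstein_1d(xs, ys):
    """Wasserstein-1 distance between two equally sized 1-D samples.

    Assumes uniform weights; in 1-D the optimal transport plan
    matches the i-th smallest x with the i-th smallest y.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equally sized samples")
    pairs = zip(sorted(xs), sorted(ys))
    return sum(abs(x - y) for x, y in pairs) / len(xs)


real = [0.0, 1.0, 2.0, 3.0]
shifted = [x + 5.0 for x in real]   # same shape, moved 5 units away

print(wasserstein_1d(real, real))      # 0.0
print(wasserstein_1d(real, shifted))   # 5.0
```

The distance keeps growing as the two samples move apart, so it still carries a useful signal when their supports do not overlap. The JS divergence, by contrast, saturates at $\log 2$ for distributions with disjoint supports.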

---

## **2️⃣ Enforcing the 1-Lipschitz Constraint**

For the WGAN objective to approximate the Wasserstein distance, the critic function $D(x)$ must satisfy the **1-Lipschitz condition**:

$$
| D(x_1) - D(x_2) | \leq \| x_1 - x_2 \| \quad \text{for all } x_1, x_2
$$
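
The constraint comes from the Kantorovich-Rubinstein duality, which the WGAN paper uses to rewrite the Wasserstein-1 distance as a supremum over 1-Lipschitz functions $f$. The critic $D$ plays the role of $f$:

$$
W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \left( \mathbb{E}_{x \sim P_r} [f(x)] - \mathbb{E}_{x \sim P_g} [f(x)] \right)
$$

Without the Lipschitz bound, $f$ could be scaled arbitrarily and the supremum would be unbounded, so the critic's score difference would no longer estimate a distance.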

The original WGAN implementation employed **weight clipping** to enforce this constraint. However, weight clipping has drawbacks:

- ❌ Leads to suboptimal gradients
- ❌ Reduces the learning capacity of the critic
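
A minimal sketch of weight clipping, assuming a hypothetical linear critic $D(x) = w \cdot x$ whose Lipschitz constant (with respect to the Euclidean norm) is $\|w\|_2$. The helper names are illustrative; `c = 0.01` is the clipping value used in the WGAN paper.

```python
import math


def clip_weights(weights, c=0.01):
    """Clamp every weight into [-c, c], as in the original WGAN."""
    return [max(-c, min(c, w)) for w in weights]


def l2_norm(vector):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(v * v for v in vector))


# Hypothetical linear critic D(x) = w . x; its Lipschitz constant is ||w||_2.
w = [0.8, -1.5, 0.3, 2.0]
w_clipped = clip_weights(w, c=0.01)

print(w_clipped)            # [0.01, -0.01, 0.01, 0.01]
print(l2_norm(w_clipped))   # approximately 0.02
```

Clipping does bound the Lipschitz constant, but in this toy case every weight ends up pinned at $\pm c$, which illustrates how the critic's capacity can shrink.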

**WGAN-GP overcomes these limitations by introducing the Gradient Penalty (GP)**.

---

## **3️⃣ Gradient Penalty (GP)**

Instead of directly constraining critic weights, **Gradient Penalty (GP) puts a penalty on the gradient norm** of the critic's output with respect to its input:

$$
\lambda \cdot (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2
$$

Where:

- $ \hat{x} $ represents an **interpolated sample**, created between a real data point $x_r$ and a generated data point $x_g$:

  $$
  \hat{x} = \alpha x_r + (1 - \alpha) x_g, \quad \alpha \sim U(0,1)
  $$

- $ \lambda $ is the **gradient penalty coefficient**, typically set to **10**.

- The penalty pushes the gradient norm toward 1 at the interpolated points, so the **1-Lipschitz constraint is enforced softly**, as a regularizer rather than a hard constraint.
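
The penalty can be checked by hand on a toy problem. The sketch below assumes a hypothetical linear critic $D(x) = w \cdot x$ (so the gradient is simply $w$) and approximates the gradient with finite differences instead of autograd. All function names here are illustrative.

```python
import math
import random


def critic(x, w):
    """Hypothetical linear critic D(x) = w . x."""
    return sum(wi * xi for wi, xi in zip(w, x))


def numerical_gradient(f, x, eps=1e-6):
    """Central finite-difference gradient of f at x."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def gradient_penalty(w, x_real, x_fake, lam=10.0, rng=random):
    """lambda * (||grad D(x_hat)||_2 - 1)^2 at one interpolated point."""
    alpha = rng.random()  # alpha ~ U(0, 1)
    x_hat = [alpha * r + (1 - alpha) * g for r, g in zip(x_real, x_fake)]
    grad = numerical_gradient(lambda x: critic(x, w), x_hat)
    grad_norm = math.sqrt(sum(g * g for g in grad))
    return lam * (grad_norm - 1.0) ** 2


x_real, x_fake = [1.0, 2.0], [-1.0, 0.5]
print(gradient_penalty([0.6, 0.8], x_real, x_fake))  # close to 0, since ||w|| = 1
print(gradient_penalty([3.0, 4.0], x_real, x_fake))  # close to 160 = 10 * (5 - 1)^2
```

A critic whose gradient norm is already 1 pays almost nothing, while a steeper critic is penalized quadratically in the excess.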

---

## **4️⃣ WGAN-GP Loss Functions**

### **Critic (Discriminator) Loss**

The critic tries to maximize the difference in scores between real and generated samples. Written as a loss to minimize, this becomes the negated score difference plus the gradient penalty:

$$
L_D = \mathbb{E}_{x_g \sim P_g} [D(x_g)] - \mathbb{E}_{x_r \sim P_r} [D(x_r)] + \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]
$$

Where:

- $ D(x_r) $ is the critic's output for **real data**.
- $ D(x_g) $ is the critic's output for **generated data**.
- The last term is the **gradient penalty**, averaged over interpolated samples $ \hat{x} $ and weighted by $ \lambda $.

### **Generator Loss**

The generator tries to raise the critic's scores on generated samples, which is equivalent to minimizing the **negative critic scores**:

$$
L_G = - \mathbb{E}_{x_g \sim P_g} [D(x_g)]
$$

Where $ x_g $ is the output from the generator network.

**Note:** The critic outputs an unbounded real-valued score rather than a probability, so WGAN-GP uses neither a sigmoid activation on the critic nor Binary Cross-Entropy (BCE) loss.
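
As a small numeric illustration, the two losses can be computed from plain lists of critic scores. The sketch follows the sign convention of the training loop later in this notebook (`-mean(real) + mean(fake) + lambda * penalty`). The helper names and the toy scores are made up for the example.

```python
def mean(values):
    return sum(values) / len(values)


def critic_loss(real_scores, fake_scores, penalties, lam=10.0):
    """L_D = E[D(x_g)] - E[D(x_r)] + lambda * E[(||grad|| - 1)^2]."""
    return mean(fake_scores) - mean(real_scores) + lam * mean(penalties)


def generator_loss(fake_scores):
    """L_G = -E[D(x_g)]."""
    return -mean(fake_scores)


real_scores = [2.0, 1.0]    # critic scores on real samples
fake_scores = [-1.0, 0.0]   # critic scores on generated samples
penalties = [0.0, 0.02]     # per-sample (||grad|| - 1)^2 values

print(critic_loss(real_scores, fake_scores, penalties))  # about -1.9
print(generator_loss(fake_scores))                        # 0.5
```

The critic loss falls as real samples score higher than generated ones, and the generator loss falls as the critic scores generated samples higher.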

---

## **5️⃣ Training WGAN-GP**

Training WGAN-GP involves:

1. **Frequent Critic Updates:** Train the critic network more often than the generator (for example, 5 critic updates for every generator update).
2. **Adam Optimizer with Low Momentum:** Use the Adam optimizer with reduced momentum values (e.g., $\beta_1 = 0.0$, $\beta_2 = 0.9$).
3. **Gradient Penalty Application:** Implement gradient penalty as the method for enforcing the Lipschitz constraint, replacing weight clipping.

### **Pseudocode for Training**

1. **Critic Training Steps:**
   - Sample real data $x_r$ and random noise $z$.
   - Generate fake data: $x_g = G(z)$.
   - Calculate critic outputs: $D(x_r)$ and $D(x_g)$.
   - Generate interpolated samples $\hat{x}$.
   - Compute the **gradient penalty** term.
   - Calculate the **critic loss** $L_D$ and update critic network $D$.

2. **Generator Training Steps:**
   - Generate fake data: $x_g = G(z)$.
   - Calculate generator loss $L_G = -D(x_g)$.
   - Update generator network $G$.
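
The update schedule in the steps above can be sketched without any model. The generator name `training_schedule` is hypothetical. The rule it encodes (critic on every batch, generator on every `critic_steps`-th batch) mirrors the training loop used later in this notebook.

```python
def training_schedule(num_batches, critic_steps=5):
    """Yield (batch_index, network) pairs in update order.

    The critic is updated on every batch and the generator on
    every `critic_steps`-th batch.
    """
    for i in range(num_batches):
        yield i, "critic"
        if i % critic_steps == 0:
            yield i, "generator"


updates = list(training_schedule(num_batches=20, critic_steps=5))
n_critic = sum(1 for _, net in updates if net == "critic")
n_gen = sum(1 for _, net in updates if net == "generator")
print(n_critic, n_gen)  # 20 4
```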

---

You can obtain the dataset through torch, Kaggle, or the following link: [Dataset](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg)

You will need the `img_align_celeba.zip` file; the label file is `list_attr_celeba.txt`.

In [4]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import grad
import matplotlib.pyplot as plt
import numpy as np


In [11]:
batch_size = 64               # Number of images per batch
image_size = 64               # Image dimensions (64x64)
num_epochs = 50               # Number of training epochs
learning_rate = 0.0001        # Learning rate for the optimizer
beta1 = 0.0                   # Adam optimizer beta1 hyperparameter
beta2 = 0.9                   # Adam optimizer beta2 hyperparameter
latent_dim = 100              # Dimensionality of the latent vector (input to the generator)
channels = 3                  # Number of color channels (3 for RGB)
gradient_penalty_coefficient = 10  # Weight for the gradient penalty term
critic_steps = 5              # Number of critic updates per generator update
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available

In [6]:
import zipfile
import os

# Path to the zip file and extraction folder
zip_file_path = './img_align_celeba.zip'
extracted_folder_path = './img_align_celeba'  # This is the folder where images will be extracted

# Extract images from zip file if not already extracted
if not os.path.exists(extracted_folder_path):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extracted_folder_path)
    print("Extraction completed.")
else:
    print("Images are already extracted.")

Extraction completed.


In [None]:
# ---------------------------
# Dataset and Dataloader
# ---------------------------

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize(image_size),         
    transforms.CenterCrop(image_size),     
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

dataset = dset.ImageFolder(root=extracted_folder_path, transform=transform)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [34]:
# ---------------------------
# (Optional) Weight Initialization Function
# ---------------------------
def weights_init(m):
    """
    Custom weight initialization for Conv and BatchNorm layers.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:  # Applies to both Conv2d and ConvTranspose2d
        nn.init.normal_(m.weight.data, 0.0, 0.02)  # Normal distribution with mean 0 and std 0.02
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)  # Constant zero initialization for bias
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)  # Initialize weights with a normal distribution
        nn.init.constant_(m.bias.data, 0.0)  # Bias initialized to zero

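# Example of applying weight initialization to generator and discriminator
# (assumes `generator` and `discriminator` have already been instantiated; see the instantiation cell below)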
generator.apply(weights_init)
discriminator.apply(weights_init)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (fc): Linear(in_features=8192, out_features=1, bias=True)
)

In [35]:
# Import necessary libraries
import torch
import torch.nn as nn

# ----------------------------------------------
# Generator: Define your generator here, 20 Marks
# ----------------------------------------------
class Generator(nn.Module):
    def __init__(self, latent_dim, channels, image_size):
        super(Generator, self).__init__()

        self.init_size = image_size // 16 
        self.l1 = nn.Linear(latent_dim, 128 * self.init_size * self.init_size)

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # Note: the second positional argument of BatchNorm2d is `eps`, not momentum
            nn.BatchNorm2d(128, eps=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.size(0), 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
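
The generator's first layer is `nn.Linear(latent_dim, 128 * init_size * init_size)`, so it expects noise of shape `(batch, latent_dim)`. A quick arithmetic check of the sizes involved, using the hyperparameter values defined earlier (plain Python, no torch required):

```python
image_size = 64
latent_dim = 100

init_size = image_size // 16         # spatial size before upsampling
linear_out = 128 * init_size ** 2    # output features of the first Linear layer

size = init_size
for _ in range(4):                   # four nn.Upsample(scale_factor=2) layers
    size *= 2

print(init_size, linear_out, size)   # 4 2048 64
```

The four upsampling stages take the 4x4 feature map back to the 64x64 target resolution.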

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------
# Discriminator (Critic): Define your discriminator here, 20 Marks
# ---------------------------
class Discriminator(nn.Module):
    def __init__(self, channels, image_size):
        super(Discriminator, self).__init__()

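        def discriminator_block(in_filters, out_filters, bn=True):
            # Note: the WGAN-GP paper recommends against batch normalization in the
            # critic (it suggests layer normalization instead), because the gradient
            # penalty is defined per sample rather than per batch.
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]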
            if bn:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, bn=False), 
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
        )

        self.fc = nn.Linear(512 * (image_size // 16) ** 2, 1)

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.size(0), -1) 
        validity = self.fc(out)
        return validity

In [37]:
import torch
from torch.autograd import grad

# ---------------------------
# Gradient Penalty Function: Implement the gradient penalty calculation here, 30 Marks
# ---------------------------
def compute_gradient_penalty(D, real_samples, fake_samples, device):
    """Return the batch mean of (||grad D(x_hat)||_2 - 1)^2 at interpolated points."""
    # One interpolation coefficient per sample, broadcast over channels, height, width
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)

    # Random points on the straight lines between real and fake samples
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates = interpolates.detach().requires_grad_(True)

    d_interpolates = D(interpolates)

    # Gradient of the critic output with respect to the interpolated inputs
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,   # keep the graph so the penalty itself can be backpropagated
        retain_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient
    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)

    gradient_penalty = ((gradient_norm - 1) ** 2).mean()

    return gradient_penalty

In [38]:
# ---------------------------
# Instantiate Models (Generator and Discriminator) & Initialize Weights
# ---------------------------

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)

# Instantiate generator and discriminator
generator = Generator(latent_dim=latent_dim, channels=channels, image_size=image_size).to(device)
discriminator = Discriminator(channels=channels, image_size=image_size).to(device)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# ---------------------------
# Optimizers
# ---------------------------
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, beta2))

# ---------------------------
# Fixed Noise for Visualization
# ---------------------------
fixed_noise = torch.randn(64, latent_dim, device=device)  # Fixed noise of shape (64, latent_dim), as expected by the generator's Linear layer

In [39]:
# ---------------------------
# WGAN-GP Training Loop: Implement the training loop here, 30 Marks
# ---------------------------

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # Move data to the appropriate device
        real_images = real_images.to(device)
        current_batch_size = real_images.size(0)  # the last batch may be smaller

        # --------------------
        # Train Discriminator (Critic)
        # --------------------

        # Generate fake images. The generator's first layer is nn.Linear, so the
        # noise must have shape (batch, latent_dim), not (batch, latent_dim, 1, 1).
        z = torch.randn(current_batch_size, latent_dim, device=device)
        fake_images = generator(z).detach()  # detach to avoid backprop through G

        # Get discriminator outputs
        real_validity = discriminator(real_images)
        fake_validity = discriminator(fake_images)

        # Compute Gradient Penalty
        gradient_penalty = compute_gradient_penalty(discriminator, real_images, fake_images, device)

        # Calculate Discriminator (Critic) loss (including GP)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty_coefficient * gradient_penalty

        # Backward and optimize critic
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator every `critic_steps`
        if i % critic_steps == 0:
            # --------------------
            # Train Generator
            # --------------------

            # Generate fake images
            z = torch.randn(current_batch_size, latent_dim, device=device)
            fake_images = generator(z)

            # Calculate Generator loss (negative critic output for fake images)
            g_loss = -torch.mean(discriminator(fake_images))

            # Backward and optimize generator
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        # Optional: Print training progress
        if i % 50 == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # Save/Visualize generated images (optional, e.g., every epoch)
    with torch.no_grad():
        fake_images = generator(fixed_noise).cpu()
    grid = vutils.make_grid(fake_images, padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title(f"Epoch {epoch+1}")
    plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW tensor -> HWC array for matplotlib
    plt.show()