In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

%load_ext tensorboard

# Model Architecture

## Discriminator Class
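- A convolutional neural network that acts as the WGAN **critic**: it scores how realistic an image looks for a given class label.
- Takes in an image together with its class label; the label is looked up in a learned embedding table and reshaped into an extra `img_size x img_size` channel that is concatenated with the image.
- Uses strided convolutional layers with **LeakyReLU** activations and **InstanceNorm** layers (per-sample normalization, which does not couple samples in a batch and so plays well with the per-sample gradient penalty).
- The final output is a single unbounded real-valued score per image (there is no sigmoid), not a probability between 0 and 1.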

## Generator Class
- A convolutional neural network that generates synthetic images from random noise vectors.
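- Takes in a latent noise vector and a class label; the label's learned embedding is concatenated with the noise along the channel axis, so the generated image is conditioned on the class.
- Uses transposed convolution layers to upsample the 1x1 conditioned latent into an image; the layer stack always produces 64x64 outputs.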
- Outputs images normalized between [-1, 1] using the **Tanh** activation function.

## Initialize Weights
- A function to initialize the weights of the model layers using **normal distribution** (mean = 0, std = 0.02).
- The **Convolutional layers** (`Conv2d` and `ConvTranspose2d`) and **Batch Normalization layers** (`BatchNorm2d`) are initialized to improve training stability.
- For **BatchNorm layers**, the scaling factor (`gamma`) is initialized using **normal distribution** with mean = 1 and std = 0.02 to allow slight flexibility in activation scaling.
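- The **bias** (`beta`) of BatchNorm layers is initialized to 0 so that normalization does not initially shift activations.
- Layers not listed above (e.g. the critic's `InstanceNorm2d` layers and the `Embedding` tables) keep PyTorch's default initialization.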

## Gradient Penalty
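- The gradient penalty softly enforces the 1-Lipschitz constraint that the Wasserstein critic requires (WGAN-GP, Gulrajani et al., 2017).
- Points are sampled on straight lines between real and generated images, $\hat{x} = \epsilon\, x_{\text{real}} + (1 - \epsilon)\, x_{\text{fake}}$ with $\epsilon \sim U[0, 1]$ drawn per sample, and the critic loss gains the term $\lambda\, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$.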

In [3]:
class Discriminator(nn.Module):
    """
    Conditional WGAN-GP critic.

    Every image is paired with a learned, image-sized "label plane" (one extra input
    channel looked up from a per-class embedding table). A stack of strided
    convolutions then reduces the pair to a single realness score. The network ends
    in a plain convolution with no sigmoid, so the score is an unbounded real number,
    not a probability.
    """

    def __init__(self, channels_img, features_d, num_classes, img_size):
        """
        Build the critic.

        Parameters:
        -----------
        channels_img : int
            Channels of the input images (1 for grayscale such as MNIST, 3 for RGB).
        features_d : int
            Width of the first convolution; each downsampling block doubles it
            (features_d -> 2x -> 4x -> 8x).
        num_classes : int
            Number of class labels the embedding table can look up.
        img_size : int
            Side length of the square input images. Each class embedding holds
            img_size * img_size values so it can be reshaped into one image plane.
            With this fixed 4x4 / stride-2 layer stack, a 64x64 input is reduced to a
            1x1 score map; other sizes give a different output spatial size.
        """
        super().__init__()

        # Per-class lookup table; each row becomes an img_size x img_size plane that is
        # stacked onto the image as one extra input channel.
        self.embed = nn.Embedding(num_classes, img_size * img_size)
        self.img_size = img_size

        # Stem: a plain strided conv + LeakyReLU (no normalization on the first layer).
        stem = [
            nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Three conv / InstanceNorm / LeakyReLU units, each halving H and W while
        # doubling the channel count: features_d -> 2x -> 4x -> 8x.
        downsampling = [
            self._convolutionBlock(features_d * 2 ** level, features_d * 2 ** (level + 1), 4, 2, 1)
            for level in range(3)
        ]

        # Head: collapses the remaining feature map (4x4 for 64x64 inputs) into one score.
        head = nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0)

        self.model = nn.Sequential(*stem, *downsampling, head)

    def _convolutionBlock(self, in_channels, out_channels, kernel_size, stride, padding):
        """
        Return one downsampling unit: Conv2d (no bias) -> InstanceNorm2d (affine) -> LeakyReLU(0.2).

        The convolution omits its bias because the affine InstanceNorm right after it
        removes the per-channel mean and supplies its own learnable shift. InstanceNorm
        normalizes each sample independently; presumably it was chosen over BatchNorm
        because WGAN-GP's per-sample gradient penalty should not couple samples.

        Returns:
        --------
        nn.Sequential : the three layers in that order.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, labels):
        """
        Score a batch of images under their class labels.

        Parameters:
        -----------
        x : torch.Tensor
            Images of shape (N, channels_img, img_size, img_size).
        labels : torch.Tensor
            Integer class indices of shape (N,).

        Returns:
        --------
        torch.Tensor
            Unbounded critic scores; shape (N, 1, 1, 1) for 64x64 inputs.
        """
        batch = labels.shape[0]
        label_plane = self.embed(labels).view(batch, 1, self.img_size, self.img_size)
        critic_input = torch.cat([x, label_plane], dim=1)  # (N, channels_img + 1, H, W)
        return self.model(critic_input)


In [4]:
class Generator(nn.Module):
    """
    Conditional generator for the WGAN-GP setup.

    A noise vector and a learned class embedding are concatenated along the channel
    axis and upsampled by transposed convolutions from 1x1 to a 64x64 image whose
    pixels lie in [-1, 1] (Tanh output).
    """

    def __init__(self, z_dim, channels_img, features_g, num_classes, img_size, embed_size):
        """
        Build the generator.

        Parameters:
        -----------
        z_dim : int
            Length of the latent noise vector.
        channels_img : int
            Channels of the generated images (1 for grayscale, 3 for RGB).
        features_g : int
            Width multiplier; the upsampling blocks use features_g * 16, 8, 4 and 2 channels.
        num_classes : int
            Number of class labels the embedding table can look up.
        img_size : int
            Stored as ``self.img_size``. The layer stack does not read it: the output
            is always 64x64.
        embed_size : int
            Length of each class embedding that is concatenated with the noise.
        """
        super().__init__()

        self.img_size = img_size
        self.embed = nn.Embedding(num_classes, embed_size)

        # Channel widths along the stack: conditioned latent, then features_g * 16 / 8 / 4 / 2.
        widths = [z_dim + embed_size] + [features_g * mult for mult in (16, 8, 4, 2)]

        blocks = []
        for depth, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth == 0:
                # 1x1 -> 4x4: stride 1 and no padding expand the latent "pixel" into a 4x4 map.
                blocks.append(self._convolutionBlock(c_in, c_out, kernel_size=4, stride=1, padding=0))
            else:
                # Every later block doubles H and W: 4 -> 8 -> 16 -> 32.
                blocks.append(self._convolutionBlock(c_in, c_out, kernel_size=4, stride=2, padding=1))

        # Output layer: 32x32 -> 64x64. No BatchNorm here (DCGAN convention); Tanh maps to [-1, 1]
        # to match images normalized to that range.
        to_image = nn.ConvTranspose2d(
            in_channels=features_g * 2,
            out_channels=channels_img,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.model = nn.Sequential(*blocks, to_image, nn.Tanh())

    def _convolutionBlock(self, in_channels, out_channels, kernel_size, stride, padding):
        """
        Return one upsampling unit: ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU.

        The transposed convolution skips its bias because BatchNorm2d immediately
        re-centres each channel and supplies its own learnable shift.

        Returns:
        --------
        nn.Sequential : the three layers in that order.
        """
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x, labels):
        """
        Generate images for the given noise and class labels.

        Parameters:
        -----------
        x : torch.Tensor
            Noise of shape (N, z_dim, 1, 1).
        labels : torch.Tensor
            Integer class indices of shape (N,).

        Returns:
        --------
        torch.Tensor
            Images of shape (N, channels_img, 64, 64) with values in [-1, 1].
        """
        cond = self.embed(labels)[:, :, None, None]  # (N, embed_size, 1, 1)
        latent = torch.cat([x, cond], dim=1)          # (N, z_dim + embed_size, 1, 1)
        return self.model(latent)


In [5]:
def initialize_weights(model):
    """
    Apply DCGAN-style initialization, in place, to every relevant submodule of ``model``.

    - Conv2d / ConvTranspose2d: weights ~ N(0, 0.02); their biases are left untouched.
    - BatchNorm2d: gamma (weight) ~ N(1, 0.02), beta (bias) = 0. Layers built with
      ``affine=False`` have no weight/bias and are skipped instead of crashing.
    Every other module type (e.g. Embedding, InstanceNorm2d) keeps PyTorch's defaults.

    Parameters
    ----------
    model : nn.Module
        Network whose submodules are re-initialized in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # Per the DCGAN paper, conv weights are drawn from N(0, 0.02).
            # nn.init functions already run under no_grad, so `.data` is not needed.
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # PyTorch's default gamma is exactly 1; DCGAN practice jitters it around 1.
            if m.weight is not None:
                nn.init.normal_(m.weight, mean=1.0, std=0.02)
            # beta = 0: normalization does not shift activations at initialization.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


In [6]:
def gradient_penalty(critic, labels, real, fake, device):
    """
    Compute the WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Random points are sampled on straight lines between paired real and fake samples,
    and the critic's input-gradient norm at those points is pushed towards 1:

        x_hat   = eps * real + (1 - eps) * fake,   eps ~ U[0, 1) drawn per sample
        penalty = mean_i (||d critic(x_hat_i, labels_i) / d x_hat_i||_2 - 1)^2

    The interpolates are built from *detached* copies of ``real`` and ``fake`` and made
    into a fresh autograd leaf. Consequences:
    - the penalty works whether or not the caller detached ``fake``;
    - its gradient reaches only the critic's parameters, never whatever produced ``fake``.
    Previously the interpolates only required grad because ``fake`` was still attached to
    the generator graph.

    Args:
        critic (nn.Module): Callable as ``critic(images, labels)``; returns scores per sample.
        labels (torch.Tensor): Conditioning labels, forwarded unchanged to the critic.
        real (torch.Tensor): Real batch of shape (N, ...), e.g. (N, C, H, W).
        fake (torch.Tensor): Generated batch with the same shape as ``real``.
        device (torch.device): Device on which the interpolation weights are drawn.

    Returns:
        torch.Tensor: Scalar penalty. It is built with ``create_graph=True``, so calling
        ``backward()`` on it produces gradients w.r.t. the critic's parameters.
    """
    batch_size = real.shape[0]

    # One eps per sample, shaped (N, 1, ..., 1) so it broadcasts over all other dims;
    # there is no need to materialize a full (N, C, H, W) copy with .repeat().
    epsilon = torch.rand((batch_size,) + (1,) * (real.dim() - 1), device=device)

    interpolated_images = real.detach() * epsilon + fake.detach() * (1 - epsilon)
    interpolated_images.requires_grad_(True)  # fresh leaf: the gradient is taken w.r.t. it

    mixed_scores = critic(interpolated_images, labels)

    # grad_outputs of ones back-propagates the sum of all scores. Assuming the critic has
    # no cross-sample coupling (e.g. no BatchNorm), row i is then d score_i / d x_hat_i.
    grad = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )[0]

    # reshape (not view) also copes with non-contiguous gradients, e.g. channels_last.
    grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)  # (N,)
    return torch.mean((grad_norm - 1) ** 2)


# Conditional Wasserstein GAN with Gradient Penalty on the MNIST dataset

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (CPU and CUDA) so weight init, noise sampling and
# DataLoader shuffling are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Hyperparameters for training
# (lr, critic_iterations, lambda_penalty and the Adam betas below follow the WGAN-GP defaults)
lr = 1e-4
batch_size = 64
img_size = 64
img_channels = 1
num_classes = 10
generator_embedding = 100
z_dim = 100
num_epochs = 5
features_disc = 64
features_gen = 64
critic_iterations = 5
lambda_penalty = 10

# Both networks are hard-wired for 64x64 images: the generator always emits 64x64, and the
# critic's layer stack only reduces a 64x64 input to a single (N, 1, 1, 1) score.
assert img_size == 64, "Generator/Discriminator layer stacks assume 64x64 images"

# Data transformations for preprocessing MNIST
transformer = transforms.Compose(
    [
        # Resize images to the specified size (img_size x img_size)
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Normalize the images so pixel values are between [-1, 1] (matches the generator's Tanh)
        transforms.Normalize([0.5] * img_channels, [0.5] * img_channels),
    ]
)

# Load the MNIST dataset with the specified transformations.
# drop_last=True keeps every batch at exactly batch_size, which the training loop relies on
# when it samples noise of shape (batch_size, z_dim, 1, 1) to pair with each label batch.
dataset = datasets.MNIST(root="dataset/", train=True, transform=transformer, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Initialize the generator and critic and move them to the selected device
gen_model = Generator(z_dim, img_channels, features_gen, num_classes, img_size, generator_embedding).to(device)
disc_model = Discriminator(img_channels, features_disc, num_classes, img_size).to(device)

# Re-initialize conv / BatchNorm weights with the custom DCGAN-style scheme
initialize_weights(gen_model)
initialize_weights(disc_model)

# Adam with beta1 = 0 and beta2 = 0.9, as recommended for WGAN-GP
optimizer_gen = optim.Adam(gen_model.parameters(), lr=lr, betas=(0.0, 0.9))
optimizer_disc = optim.Adam(disc_model.parameters(), lr=lr, betas=(0.0, 0.9))

# TensorBoard writers for real and generated image grids
# (directory name kept as-is so it matches the %tensorboard --logdir used below)
writer_fake = SummaryWriter("runs/DCGAN_MNIST/fake")
writer_real = SummaryWriter("runs/DCGAN_MNIST/real")

In [11]:
%tensorboard --logdir=runs/DCGAN_MNIST --bind_all --port=6006
print("Tensorboard is running on port 6006")

# Set the models to training mode
gen_model.train()
disc_model.train()
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, labels) in enumerate(loader):
        # Move the real images to the device (GPU/CPU)
        real = real.to(device)
        labels = labels.to(device)

        # --- Train the critic ---
        for _ in range(critic_iterations):
            # Generate random noise to feed into the generator
            noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
        
            # Generate fake images using the generator
            fake = gen_model(noise, labels)
        
            # Compute the discriminator's output on real images
            disc_real = disc_model(real, labels).reshape(-1)
            
            # Compute the discriminator's output on fake images
            disc_fake = disc_model(fake.detach(), labels).reshape(-1)

            # We want to maximize this, so negative the minimization along with gradient penalty
            penalty = gradient_penalty(disc_model, labels, real, fake, device)
            loss_disc = (-(torch.mean(disc_real) - torch.mean(disc_fake))) + (lambda_penalty * penalty)
        
            disc_model.zero_grad()
            loss_disc.backward(retain_graph=True)
            optimizer_disc.step()

        # --- Train the Generator ---
        # Compute the discriminator's output on the fake images
        output = disc_model(fake, labels).reshape(-1)
        
        # min -E[critic(gen_fake)]
        loss_gen = -torch.mean(output)
        
        gen_model.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

        # Print losses and log images to TensorBoard every 100 batches
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            # Log real and fake images to TensorBoard
            with torch.no_grad():  # No gradients are needed for this step
                fake = gen_model(noise, labels)
                # Create image grids for real and fake images (up to 32 images)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                # Add images to TensorBoard (real and fake images for visualization)
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Reusing TensorBoard on port 6006 (pid 5797), started 0:07:52 ago. (Use '!kill 5797' to kill it.)

Tensorboard is running on port 6006
Epoch [0/5] Batch 0/937                   Loss D: -12.6377, loss G: 9.1760
Epoch [0/5] Batch 100/937                   Loss D: -19.6038, loss G: 63.0470
Epoch [0/5] Batch 200/937                   Loss D: -15.8541, loss G: 78.5505
Epoch [0/5] Batch 300/937                   Loss D: -13.1848, loss G: 73.0322
Epoch [0/5] Batch 400/937                   Loss D: -12.3208, loss G: 83.5074
Epoch [0/5] Batch 500/937                   Loss D: -14.2987, loss G: 76.1848
Epoch [0/5] Batch 600/937                   Loss D: -12.0301, loss G: 71.2471
Epoch [0/5] Batch 700/937                   Loss D: -8.8279, loss G: 69.6128
Epoch [0/5] Batch 800/937                   Loss D: -8.0103, loss G: 65.7421
Epoch [0/5] Batch 900/937                   Loss D: -9.6818, loss G: 59.8363
Epoch [1/5] Batch 0/937                   Loss D: -8.6036, loss G: 58.8834
Epoch [1/5] Batch 100/937                   Loss D: -8.5183, loss G: 61.4852
Epoch [1/5] Batch 200/937             