In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [19]:
class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images: close to 1 = real, close to 0 = fake."""

    def __init__(self, img_dim, hidden_dim=128, negative_slope=0.1):
        """
        Initializes the Discriminator model.

        Args:
            img_dim (int): The total number of pixels in the image (28*28 = 784 for MNIST).
            hidden_dim (int): Width of the single hidden layer. Defaults to 128.
            negative_slope (float): LeakyReLU slope for negative inputs. Defaults to 0.1.
        """
        super().__init__()

        # Linear -> LeakyReLU -> Linear -> Sigmoid, kept in one nn.Sequential named `disc`
        # so parameter names (disc.0.weight, disc.0.bias, disc.2.weight, disc.2.bias)
        # and therefore saved state_dicts do not depend on the new keyword arguments.
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),  # flattened image -> hidden features
            nn.LeakyReLU(negative_slope),  # small non-zero gradient for negative inputs
            nn.Linear(hidden_dim, 1),  # hidden features -> single score
            nn.Sigmoid(),  # squash the score to a probability in [0, 1] (used with nn.BCELoss)
        )

    def forward(self, x):
        """
        Forward pass of the Discriminator.

        Args:
            x (Tensor): Flattened image(s); the last dimension must equal `img_dim`.

        Returns:
            Tensor: Probability of each input being real (closer to 1) or fake (closer to 0),
            with a trailing dimension of size 1.
        """
        return self.disc(x)


In [20]:
class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to flattened images with values in [-1, 1]."""

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.1):
        """
        Initializes the Generator model.

        Args:
            z_dim (int): The dimension of the latent noise vector (random input).
            img_dim (int): The total number of pixels in the output image (28*28 = 784 for MNIST).
            hidden_dim (int): Width of the single hidden layer. Defaults to 256.
            negative_slope (float): LeakyReLU slope for negative inputs. Defaults to 0.1.
        """
        super().__init__()

        # Linear -> LeakyReLU -> Linear -> Tanh, kept in one nn.Sequential named `gen`
        # so parameter names (gen.0.weight, ..., gen.2.bias) and saved state_dicts
        # do not depend on the new keyword arguments.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),  # latent noise -> hidden features
            nn.LeakyReLU(negative_slope),  # small non-zero gradient for negative inputs
            nn.Linear(hidden_dim, img_dim),  # hidden features -> one value per pixel
            nn.Tanh(),  # values in [-1, 1], the same range as the normalized training images
        )

    def forward(self, x):
        """
        Forward pass of the Generator.

        Args:
            x (Tensor): Random noise; the last dimension must equal `z_dim`.

        Returns:
            Tensor: Generated image(s), flattened to `img_dim` values in [-1, 1].
        """
        return self.gen(x)


In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters
lr = 3e-4  # Learning rate for both the Generator and Discriminator
z_dim = 64  # Latent space dimension (random noise input to the Generator)
image_dim = 28 * 28 * 1  # Flattened MNIST image size (28x28 grayscale)
batch_size = 32  # Batch size for training
num_epochs = 100  # Number of epochs for training
seed = 42  # Global RNG seed so weight init, fixed_noise and shuffling repeat across runs

# Seed before anything random happens (model init, fixed_noise, DataLoader shuffling).
# torch.manual_seed seeds the CPU and CUDA generators; GPU kernels may still be nondeterministic.
torch.manual_seed(seed)

# Initialize models (Discriminator and Generator)
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One fixed batch of latent vectors, reused every epoch to track generator progress.
# Created directly on `device` to avoid a CPU allocation plus copy.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# Data preprocessing for MNIST. Named `mnist_transform` rather than `transforms` so the
# torchvision.transforms module is not shadowed and this cell can be re-run safely.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),  # Convert image to PyTorch tensor
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: normalize to [-1, 1] for GAN training
    ]
)

# Download MNIST dataset and apply transformations
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)

# Load data in batches and shuffle the dataset for training
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Optimizers for both Discriminator and Generator using Adam
# Parameters refer to the entire layers (weights, bias)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# List the Discriminator's parameter tensors (each Linear layer contributes a weight and a bias)
for name, param in disc.named_parameters():
    print(f"Layer: {name} | Shape: {param.shape} | Requires Grad: {param.requires_grad}")

# Loss function (Binary Cross-Entropy) for GAN; expects probabilities (Discriminator ends in Sigmoid)
criterion = nn.BCELoss()

# TensorBoard writers to visualize generated images and real images
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")

Layer: disc.0.weight | Shape: torch.Size([128, 784]) | Requires Grad: True
Layer: disc.0.bias | Shape: torch.Size([128]) | Requires Grad: True
Layer: disc.2.weight | Shape: torch.Size([1, 128]) | Requires Grad: True
Layer: disc.2.bias | Shape: torch.Size([1]) | Requires Grad: True


In [5]:
%tensorboard --logdir=runs --bind_all --port=6006
print("Tensorboard is running on port 6006")

# Initialize step count for TensorBoard logging (one image pair is logged per epoch)
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):  # Loop through the data batches
        # Flatten each 28x28 image to `image_dim` values and move the batch to the device
        real = real.view(-1, image_dim).to(device)
        # Size of *this* batch (a final partial batch can be smaller). Uses its own name so
        # the global `batch_size` hyperparameter from the setup cell is never overwritten.
        cur_batch_size = real.shape[0]

        # ---- Train the Discriminator: maximize log(D(real)) + log(1 - D(G(z))) ----
        noise = torch.randn(cur_batch_size, z_dim, device=device)  # latent vectors
        fake = gen(noise)  # Generate fake images using the Generator

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator update must not backpropagate into the generator, and
        # it leaves the generator's part of `fake`'s graph intact for the generator step below
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        # Total Discriminator loss (average of real and fake losses)
        lossD = (lossD_real + lossD_fake) / 2

        opt_disc.zero_grad()  # clear gradients left over from the previous iteration
        lossD.backward()  # gradients are stored in each parameter's .grad
        opt_disc.step()  # update Discriminator parameters from those gradients

        # ---- Train the Generator: maximize log(D(G(z))) ----
        # (the non-saturating replacement for minimizing log(1 - D(G(z)))): score the
        # non-detached fakes with the freshly updated Discriminator and label them "real".
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))

        opt_gen.zero_grad()
        # Also writes gradients into the Discriminator's parameters; those are cleared by
        # opt_disc.zero_grad() before the next Discriminator update.
        lossG.backward()
        opt_gen.step()

        # Once per epoch (first batch): report the losses and log image grids
        if batch_idx == 0:
            # Implicit string concatenation (not a backslash continuation) so no indentation
            # whitespace leaks into the printed line.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
            )

            with torch.no_grad():
                # The same fixed_noise every epoch shows how the generator's mapping evolves
                fake_fixed = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)  # Reshape real data to image format

                # normalize=True rescales pixel values into [0, 1] for display
                img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)
                writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)
                step += 1  # Increment the step count for TensorBoard

# Push buffered events to disk so TensorBoard shows every epoch once training finishes
writer_fake.flush()
writer_real.flush()


Tensorboard is running on port 6006
Epoch [0/100] Batch 0/1875                 Loss D: 0.7443, Loss G: 0.6950
Epoch [1/100] Batch 0/1875                 Loss D: 0.3259, Loss G: 1.4357
Epoch [2/100] Batch 0/1875                 Loss D: 0.5549, Loss G: 1.2190
Epoch [3/100] Batch 0/1875                 Loss D: 0.4574, Loss G: 1.1389
Epoch [4/100] Batch 0/1875                 Loss D: 0.6160, Loss G: 0.8249
Epoch [5/100] Batch 0/1875                 Loss D: 0.5893, Loss G: 1.0247
Epoch [6/100] Batch 0/1875                 Loss D: 0.5417, Loss G: 1.1014
Epoch [7/100] Batch 0/1875                 Loss D: 0.5296, Loss G: 1.3228
Epoch [8/100] Batch 0/1875                 Loss D: 0.9952, Loss G: 0.8284
Epoch [9/100] Batch 0/1875                 Loss D: 0.5422, Loss G: 1.3175
Epoch [10/100] Batch 0/1875                 Loss D: 0.6286, Loss G: 1.1357
Epoch [11/100] Batch 0/1875                 Loss D: 0.7654, Loss G: 0.9302
Epoch [12/100] Batch 0/1875                 Loss D: 0.7705, Loss G: 1.0299