## Implementation of GAN for MNIST digit generation

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor, Normalize, Resize, Compose
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import numpy as np

# Training hyperparameters shared by the cells below
batch_size = 64
lr = 0.0002
# Fill values used when building the "real" / "fake" target tensors for the loss
real_val = 1
fake_val = 0

# Prefer the Apple Metal backend (the original target), but fall back to CUDA
# or CPU so the notebook does not fail on machines without MPS support.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Normalize with mean 0.5 / std 0.5 rescales pixel values around zero
mnist = MNIST(root="data", train=True, download=True, 
              transform=Compose([ToTensor(), Normalize(mean=(0.5,), std=(0.5,))]))

dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Preview one batch of training digits as a single image grid
examples = next(iter(dataloader))
preview_batch = examples[0].to(device)[:64]
preview_grid = torchvision.utils.make_grid(preview_batch, padding=2, normalize=True).cpu()

plt.figure(figsize=(8, 8))
plt.axis("off")
# make_grid returns (C, H, W); imshow expects (H, W, C)
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

In [None]:
class Generator(nn.Module):
    """Generator that turns a latent vector into a 28x28 image.

    Args:
        hidden_dim: Base number of feature maps; the first conv stage uses
            ``hidden_dim * 4`` channels.
        latent_dim: Size of the input noise vector.
        img_channel: Number of channels in the generated image.
    """

    def __init__(self, hidden_dim, latent_dim, img_channel):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.img_channel = img_channel

        base_channels = hidden_dim * 4
        flat_features = base_channels * 7 * 7

        # Project the latent vector and reshape it to a 7x7 feature map
        project = [
            nn.Linear(latent_dim, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (base_channels, 7, 7)),
        ]
        # 7x7 -> 14x14
        upsample_14 = [
            nn.ConvTranspose2d(base_channels, hidden_dim * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.ReLU(inplace=True),
        ]
        # 14x14 -> 28x28
        upsample_28 = [
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        # Collapse to img_channel at the same resolution; Tanh bounds output to [-1, 1]
        to_image = [
            nn.ConvTranspose2d(hidden_dim, img_channel, kernel_size=5, stride=1, padding=2, bias=False),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*project, *upsample_14, *upsample_28, *to_image)

    def forward(self, z):
        """Map a batch of latent vectors ``z`` (batch, latent_dim) to images."""
        return self.model(z)

In [None]:
class Discriminator(nn.Module):
    """Convolutional classifier scoring 28x28 images as real or fake.

    Args:
        hidden_dim: Number of feature maps after the first convolution; the
            second convolution doubles it.
        img_channel: Number of channels in the input image.
    """

    def __init__(self, hidden_dim, img_channel):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.img_channel = img_channel

        wide = hidden_dim * 2

        # 28x28 -> 14x14
        down_14 = [
            nn.Conv2d(img_channel, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 14x14 -> 7x7
        down_7 = [
            nn.Conv2d(hidden_dim, wide, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(wide),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Flatten the 7x7 maps and reduce to one probability per sample
        head = [
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(wide * 7 * 7, 1),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*down_14, *down_7, *head)

    def forward(self, x):
        """Return one sigmoid score per image, shaped (batch, 1)."""
        return self.model(x)

In [None]:
# Build both networks (64 base feature maps, 100-dim latent vector, 1 image channel)
generator = Generator(hidden_dim=64, latent_dim=100, img_channel=1)
generator = generator.to(device)
print(generator)

discriminator = Discriminator(hidden_dim=64, img_channel=1)
discriminator = discriminator.to(device)
print(discriminator)

In [None]:
# Binary cross-entropy, applied to the discriminator's sigmoid output
loss = nn.BCELoss()

# Fixed latent batch, reused during training to snapshot the generator's progress
test_noise = torch.randn((batch_size, 100), device=device)

# One Adam optimiser per network, both with the same learning rate and betas
disOptimiser = torch.optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
genOptimiser = torch.optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

In [None]:
# Generator snapshots and per-iteration losses collected for the plots below
img_list = []
G_losses = []
D_losses = []
iters = 0

num_epochs = 5

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        # (1) Update the discriminator with real data
        discriminator.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        # Size labels and noise from the batch actually received: the last
        # batch of an epoch can be smaller than `batch_size`, and the loss
        # needs the target to have the same shape as the discriminator output.
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_val, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = discriminator(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = loss(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        # (2) Update the discriminator with fake data
        # Generate batch of latent vectors
        noise = torch.randn(b_size, 100, device=device)
        # Generate fake image batch with G
        fake = generator(noise)
        label.fill_(fake_val)
        # Classify all fake batch with D; detach so no gradients flow into G here
        output = discriminator(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = loss(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        disOptimiser.step()

        # (3) Update the generator with fake data
        generator.zero_grad()
        label.fill_(real_val)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = discriminator(fake).view(-1)
        # Calculate G's loss based on this output
        errG = loss(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        genOptimiser.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Snapshot G's output on the fixed test_noise every 500 iterations and on the very last one
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                snapshot = generator(test_noise).cpu()
            img_list.append(torchvision.utils.make_grid(snapshot, padding=2, normalize=True))

        iters += 1

In [None]:



# Determine the grid size (at least 1x1 so plt.subplots always gets a valid shape)
grid_size = max(1, int(np.ceil(np.sqrt(len(img_list)))))

# Create a new figure; squeeze=False keeps `axs` two-dimensional even for a
# 1x1 grid, so the [row, col] indexing below is always valid
fig, axs = plt.subplots(grid_size, grid_size, figsize=(12, 12), squeeze=False)

# Iterate over all images and plot them
for i, img in enumerate(img_list):
    ax = axs[i // grid_size, i % grid_size]
    ax.imshow(np.transpose(img, (1, 2, 0)))
    # Snapshots are taken at iterations 0, 500, 1000, ... plus the final
    # iteration (index iters - 1), so cap the label at the last iteration
    ax.set_title(f'Step {min(i * 500, iters - 1)}')
    ax.axis('off')

# If there are less images than grid cells, remove the empty plots
for unused_ax in axs.flatten()[len(img_list):]:
    fig.delaxes(unused_ax)

# Show the plot
plt.show()

In [None]:
import pandas as pd

# Size of the moving-average window, in iterations
window_size = 50

# Wrap the raw loss lists in pandas Series so they can be smoothed
G_losses_pd = pd.Series(G_losses)
D_losses_pd = pd.Series(D_losses)

# Rolling means; the first window_size - 1 entries are NaN and are not drawn
G_losses_smooth = G_losses_pd.rolling(window=window_size).mean()
D_losses_smooth = D_losses_pd.rolling(window=window_size).mean()

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")

# Raw losses first in faded colours, then the smoothed curves in solid colours
for curve, curve_label, curve_color in (
    (G_losses, "G", 'skyblue'),
    (D_losses, "D", 'lightcoral'),
    (G_losses_smooth, "G Smooth", 'blue'),
    (D_losses_smooth, "D Smooth", 'red'),
):
    plt.plot(curve, label=curve_label, color=curve_color)

plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

## Conditional GAN (cGAN)

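We encode each class label with an embedding layer and reshape the embedding into one extra feature-map channel: the generator unflattens the 49-dimensional embedding into a 7x7 map, while the discriminator first projects it through a linear layer to a 28x28 map that is concatenated with the image.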

In [None]:
class ConditionalGenerator(nn.Module):
    """Class-conditional generator producing 28x28 images.

    The label is embedded into a single 7x7 map and concatenated, as one
    extra channel, with the feature maps derived from the noise vector.

    Args:
        hidden_dim: Base number of feature maps; the noise branch produces
            ``hidden_dim * 4`` channels.
        latent_dim: Size of the input noise vector.
        img_channel: Number of channels in the generated image.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, hidden_dim, latent_dim, img_channel, num_classes):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.img_channel = img_channel
        self.num_classes = num_classes

        noise_channels = hidden_dim * 4
        noise_features = noise_channels * 7 * 7

        # Label index -> 49-dim embedding -> one 7x7 channel
        self.label_embedding = nn.Sequential(
            nn.Embedding(num_classes, 49),
            nn.Unflatten(1, (1, 7, 7)),
        )

        # Noise [batch, latent_dim] -> [batch, hidden_dim * 4, 7, 7]
        self.reshape_noise = nn.Sequential(
            nn.Linear(latent_dim, noise_features),
            nn.BatchNorm1d(noise_features),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (noise_channels, 7, 7)),
        )

        self.upsample = nn.Sequential(
            # 7x7 -> 14x14; the +1 input channel is the label map
            nn.ConvTranspose2d(noise_channels + 1, hidden_dim * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.ReLU(inplace=True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            # Collapse to img_channel at the same resolution
            nn.ConvTranspose2d(hidden_dim, img_channel, kernel_size=5, stride=1, padding=2, bias=False),
            nn.Tanh(),
        )

    def forward(self, z, y):
        """Generate images from noise ``z`` (batch, latent_dim) and integer labels ``y`` (batch,)."""
        label_map = self.label_embedding(y)
        noise_maps = self.reshape_noise(z)
        stacked = torch.cat((noise_maps, label_map), dim=1)
        return self.upsample(stacked)

In [None]:
class ConditionalDiscriminator(nn.Module):
    """Class-conditional discriminator for 28x28 images.

    The label is embedded, projected to 784 values and reshaped into a
    28x28 map that is concatenated with the image as an extra channel.

    Args:
        hidden_dim: Number of feature maps after the first convolution; the
            second convolution doubles it.
        img_channel: Number of channels in the input image.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, hidden_dim, img_channel, num_classes):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.img_channel = img_channel
        self.num_classes = num_classes

        wide = hidden_dim * 2

        # Label index -> 49-dim embedding -> 784 values -> one 28x28 channel
        self.label_embedding = nn.Sequential(
            nn.Embedding(num_classes, 49),
            nn.Linear(49, 784),
            nn.Unflatten(1, (1, 28, 28)),
        )

        self.downsample = nn.Sequential(
            # 28x28 -> 14x14; the +1 input channel is the label map
            nn.Conv2d(img_channel + 1, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(hidden_dim, wide, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(wide),
            nn.LeakyReLU(0.2, inplace=True),
            # Flatten and reduce to one probability per sample
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(wide * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, y):
        """Score images ``x`` given integer labels ``y``; returns shape (batch, 1)."""
        label_map = self.label_embedding(y)
        stacked = torch.cat((x, label_map), dim=1)
        return self.downsample(stacked)

In [None]:
# Build the conditional networks (64 base feature maps, 100-dim latent vector,
# 1 image channel, 10 classes)
cgenerator = ConditionalGenerator(hidden_dim=64, latent_dim=100, img_channel=1, num_classes=10)
cgenerator = cgenerator.to(device)
print(cgenerator)

cdiscriminator = ConditionalDiscriminator(hidden_dim=64, img_channel=1, num_classes=10)
cdiscriminator = cdiscriminator.to(device)
print(cdiscriminator)

# One Adam optimiser per network, both with the same learning rate and betas
cdis_optimiser = torch.optim.Adam(
    cdiscriminator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
cgen_optimiser = torch.optim.Adam(
    cgenerator.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)


In [None]:
def cgen_train_step(batch_size, cgen, cdis, cgenoptimiser, loss):
    """Run one optimisation step of the conditional generator.

    Random noise and random labels are drawn, passed through ``cgen`` and
    scored by ``cdis``; the generator is updated so that these scores move
    towards the "real" target value.

    Args:
        batch_size: Number of fake samples to generate.
        cgen: Conditional generator, called as ``cgen(noise, labels)``.
        cdis: Conditional discriminator, called as ``cdis(images, labels)``.
        cgenoptimiser: Optimiser holding the generator's parameters.
        loss: Criterion called as ``loss(prediction, target)``.

    Returns:
        The generator loss for this step as a Python float.
    """
    # Read the sizes from the model so the step stays consistent with how the
    # generator was built; fall back to the previously hard-coded values.
    latent_dim = getattr(cgen, "latent_dim", 100)
    num_classes = getattr(cgen, "num_classes", 10)
    # Sample directly on whatever device the generator lives on
    model_device = next(cgen.parameters()).device

    cgenoptimiser.zero_grad()  # Clear the generator gradients
    noise = torch.randn(batch_size, latent_dim, device=model_device)
    labels = torch.randint(0, num_classes, (batch_size,), device=model_device)
    fake = cgen(noise, labels)
    validity = cdis(fake, labels)
    # The generator wants its fakes to be classified as real; build the
    # target with the discriminator output's own shape and device
    target = torch.full(validity.shape, real_val, dtype=torch.float, device=validity.device)
    cgloss = loss(validity, target)
    cgloss.backward()
    cgenoptimiser.step()
    return cgloss.item()

def cdis_train_step(batch_size, cgen, cdis, cdisoptimiser, loss, real_img, labels):
    """Run one optimisation step of the conditional discriminator.

    The discriminator is scored on a real batch (target "real") and on a
    freshly generated fake batch with random labels (target "fake"); the sum
    of both losses is back-propagated.

    Args:
        batch_size: Number of fake samples to generate.
        cgen: Conditional generator, called as ``cgen(noise, labels)``.
        cdis: Conditional discriminator, called as ``cdis(images, labels)``.
        cdisoptimiser: Optimiser holding the discriminator's parameters.
        loss: Criterion called as ``loss(prediction, target)``.
        real_img: Batch of real images.
        labels: Class labels belonging to ``real_img``.

    Returns:
        The summed real + fake discriminator loss as a Python float.
    """
    # Read the sizes from the model so the step stays consistent with how the
    # generator was built; fall back to the previously hard-coded values.
    latent_dim = getattr(cgen, "latent_dim", 100)
    num_classes = getattr(cgen, "num_classes", 10)
    # Generate the fake batch on the same device as the real batch
    model_device = real_img.device

    cdisoptimiser.zero_grad()  # Clear the discriminator gradients

    # Real batch: target is the "real" value
    real_validity = cdis(real_img, labels)
    real_target = torch.full(real_validity.shape, real_val, dtype=torch.float, device=real_validity.device)
    real_loss = loss(real_validity, real_target)

    # Fake batch: target is the "fake" value
    noise = torch.randn(batch_size, latent_dim, device=model_device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=model_device)
    # Only the discriminator is updated here, so skip building the
    # generator's autograd graph instead of detaching it afterwards
    with torch.no_grad():
        fake = cgen(noise, fake_labels)
    fake_validity = cdis(fake, fake_labels)
    fake_target = torch.full(fake_validity.shape, fake_val, dtype=torch.float, device=fake_validity.device)
    fake_loss = loss(fake_validity, fake_target)

    cdloss = real_loss + fake_loss
    cdloss.backward()
    cdisoptimiser.step()
    return cdloss.item()

In [None]:
num_epochs = 10
# Fixed noise and one label per digit class, reused after every epoch so the
# sample grids are comparable across epochs
test_noise = torch.randn(10, 100, device=device)
test_labels = torch.LongTensor(np.arange(10)).to(device)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    # Sampling at the end of the previous epoch left the generator in eval
    # mode, so switch back once per epoch rather than on every batch
    cgenerator.train()
    for i, (images, label) in enumerate(dataloader):
        real_img = images.to(device)
        labels = label.to(device)
        # Use a separate name for the actual batch size: rebinding the global
        # `batch_size` would leave it at the size of the last batch seen
        cur_batch_size = real_img.shape[0]
        cdis_loss = cdis_train_step(cur_batch_size, cgenerator, cdiscriminator, cdis_optimiser, loss, real_img, labels)
        cgen_loss = cgen_train_step(cur_batch_size, cgenerator, cdiscriminator, cgen_optimiser, loss)
        if i % 50 == 0:
            print('Step: {}/{}, g_loss: {}, d_loss: {}'.format(i, len(dataloader), cgen_loss, cdis_loss), end='\x1b[1K\r')

    cgenerator.eval()
    print('Final: g_loss: {}, d_loss: {}'.format(cgen_loss, cdis_loss)) # Print the loss values
    # Sampling is inference only, so do not track gradients for it
    with torch.no_grad():
        sample_images = cgenerator(test_noise, test_labels).cpu()
    grid = torchvision.utils.make_grid(sample_images, nrow=10, padding=2, normalize=True).permute(1,2,0).numpy()
    plt.imshow(grid)
    plt.axis("off")
    plt.show()