In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import typeguard
from architectures import Discriminator, Encoder, Generator
from PIL import Image
from torch import nn, optim
from torch.utils import tensorboard
from torch.utils.data import DataLoader
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure  # MSE, PSNR
from torchvision import datasets, transforms
from typeguard import CollectionCheckStrategy

# Have typeguard validate every element of checked collections rather than
# only the first one (typeguard's default strategy).
typeguard.config.collection_check_strategy = typeguard.CollectionCheckStrategy.ALL_ITEMS

In [None]:
# ------------------------------
# Initialize Networks, Loss, and Optimizers
# ------------------------------

# Both networks are compiled with identical torch.compile settings.
_compile_options = {"fullgraph": True, "mode": "max-autotune"}

# Construction order (D first, then G) is kept so parameter initialization
# draws from the RNG in the same sequence.
D: Discriminator = Discriminator()
D.compile(**_compile_options)
G: Generator = Generator()
G.compile(**_compile_options)

# Hyperparameters
lr: float = 2e-4
beta1: float = 0.01  # AdamW first-moment coefficient (second is fixed at 0.99 below)
gamma: float = 1  # weight of the discriminator's gradient penalty

# D and G share the same AdamW configuration.
_adamw_betas = (beta1, 0.99)
optimizer_D: optim.Optimizer = optim.AdamW(D.parameters(), lr=lr, betas=_adamw_betas)
optimizer_G: optim.Optimizer = optim.AdamW(G.parameters(), lr=lr, betas=_adamw_betas)

# Cosine-annealing horizon, in scheduler steps: the learning rate decays from
# `lr` to eta_min over T_max calls to scheduler.step() (the training loop
# steps the schedulers once per epoch).
T_max = 100

_cosine_options = {"T_max": T_max, "eta_min": 1e-6}
scheduler_D = optim.lr_scheduler.CosineAnnealingLR(optimizer_D, **_cosine_options)
scheduler_G = optim.lr_scheduler.CosineAnnealingLR(optimizer_G, **_cosine_options)


# ------------------------------
# Data Loader for MNIST
# ------------------------------

# MNIST samples are single-channel, so Normalize must receive exactly one
# mean/std entry. Three entries make Normalize's in-place subtraction try to
# broadcast a 1-channel tensor to 3 channels, which raises a RuntimeError.
# (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # converts image to [0,1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # scales to [-1,1]
    ]
)

dataset: datasets.MNIST = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader: DataLoader = DataLoader(dataset, batch_size=64, shuffle=True)


# ------------------------------
# TensorBoard Logging
# ------------------------------

# Only the `torch.utils.tensorboard` module is imported in this file (as
# `tensorboard`), so SummaryWriter must be reached through it; a bare
# `SummaryWriter` name is undefined here. Change log_dir to separate runs.
writer = tensorboard.SummaryWriter(log_dir="./runs/experiment_1")
# Step counter used as the x-axis for logged scalars.
global_step = 0

# ------------------------------
# Training Loop
# ------------------------------

num_epochs: int = 100

# Gradient-penalty annealing: after each epoch gamma drops by `gamma_decay`
# until it reaches `gamma_floor`. The max() clamp matters: subtracting 0.1
# repeatedly in floating point yields 0.1000000000000002 after nine steps,
# which still passes a plain `gamma > 0.1` check and would push gamma to
# ~2e-16, silently switching the penalty off.
gamma_decay: float = 0.1
gamma_floor: float = 0.1

try:
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size: int = real_images.size(0)

            ### === Discriminator Update ===
            optimizer_D.zero_grad()

            # Sample fakes without recording G's autograd graph: this step only updates D.
            with torch.no_grad():
                fake_images: torch.Tensor = G(batch_size)

            # Relativistic LS loss for the Discriminator with R1 + R2 zero-mean gradient
            # penalties, weighted by gamma (computed by D.calculate_relativistic_loss).
            loss_D: torch.Tensor = D.calculate_relativistic_loss(real_images, fake_images, gamma)
            loss_D.backward()
            optimizer_D.step()

            ### === Generator Update ===
            optimizer_G.zero_grad()

            # Re-score the same real batch with the freshly updated D. This term does not
            # depend on G, so detaching it leaves G's gradients unchanged while letting
            # backward skip D's real-branch graph. (D's .grad from this step is cleared by
            # optimizer_D.zero_grad() before D's next update either way.)
            D_real_updated: torch.Tensor = D(real_images).detach()
            # Fresh fakes, this time with gradients so loss_G reaches G's parameters.
            fake_images = G(batch_size)
            D_fake_updated: torch.Tensor = D(fake_images)

            # Relativistic least-squares loss for G, (1 - v)**2: push D_fake - D_real toward 1.
            v = D_fake_updated - D_real_updated
            loss_G: torch.Tensor = torch.square(1 - v).mean()

            loss_G.backward()
            optimizer_G.step()

            # Read each scalar loss once and reuse it for logging and printing.
            loss_D_value: float = loss_D.item()
            loss_G_value: float = loss_G.item()
            writer.add_scalar("Loss/Discriminator", loss_D_value, global_step)
            writer.add_scalar("Loss/Generator", loss_G_value, global_step)
            global_step += 1
            if (i + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "
                    f"D Loss: {loss_D_value:.4f}, G Loss: {loss_G_value:.4f}"
                )
        # Step the learning-rate schedulers once per epoch.
        scheduler_D.step()
        scheduler_G.step()
        # Anneal the gradient-penalty weight, never going below the floor.
        if gamma > gamma_floor:
            gamma = max(gamma - gamma_decay, gamma_floor)
finally:
    # Flush pending events and release the event file even if training is
    # interrupted (e.g. a KeyboardInterrupt in the notebook).
    writer.close()