In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: seed torch's global RNG. It drives weight init, latent noise and
# DataLoader shuffling. torch.manual_seed seeds every device, CUDA included.
# Some GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

# hyperparameters
batch_size = 128
latent_dim = 100
num_epochs = 25
lr_generator = 2e-4
lr_critic = 2e-4
n_critic = 5          # critic updates per generator update
lambda_gp = 0         # gradient-penalty weight; 0 disables the penalty
log_interval = 100    # print losses every `log_interval` batches

# Create directory for sample outputs
os.makedirs("samples", exist_ok=True)

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# FashionMNIST training split. download=True fetches the files into ./data if
# they are missing. Batches are reshuffled every epoch.
train_dataset = torchvision.datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)


In [3]:
# Generator: maps a latent vector to a 28x28 single-channel image
class Generator(nn.Module):
    """Generator network: latent vector -> single-channel 28x28 image.

    The attribute names and layer order below determine the ``state_dict`` keys:
      * ``fc``: Linear(latent_dim -> 128*7*7), then BatchNorm1d, then ReLU.
      * ``deconv``: two stride-2 ConvTranspose2d stages (7x7 -> 14x14 -> 28x28).
        It ends in Tanh, so pixels lie in [-1, 1].

    Args:
        latent_dim: length of each noise vector fed to ``forward``.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        projected = 7 * 7 * 128

        self.fc = nn.Sequential(
            nn.Linear(latent_dim, projected),
            nn.BatchNorm1d(projected),
            nn.ReLU(inplace=True),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (128, 7, 7) -> (64, 14, 14)
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),    # (64, 14, 14) -> (1, 28, 28)
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` (expected shape (N, latent_dim)) to images of shape (N, 1, 28, 28).

        In training mode BatchNorm1d needs N > 1.
        """
        feature_map = self.fc(z).view(-1, 128, 7, 7)
        return self.deconv(feature_map)


# Critic: maps a 28x28 image to a 1024-dim feature vector; the MMD loss compares these features for real vs. generated batches
class Critic(nn.Module):
    """Critic / feature extractor: (N, 1, 28, 28) image batch -> (N, 1024) features.

    The training loop compares these features for real and generated batches
    with ``mmd_loss``, rather than reading a single scalar score.

    NOTE(review): in training mode, BatchNorm2d normalises with batch statistics,
    so each sample's features depend on the rest of its batch. WGAN-GP recommends
    a critic without batch normalisation when a gradient penalty is active.
    """

    def __init__(self):
        super().__init__()
        # 1 x 28 x 28 -> 64 x 14 x 14 -> 128 x 7 x 7
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        # Flatten the 128*7*7 activations and project them to 1024 features.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, img):
        """Return the (N, 1024) feature vectors for ``img``."""
        return self.fc(self.conv(img))

In [4]:
# Helper functions: pairwise squared distances and kernel matrices used by the MMD loss

def compute_pairwise_distance(x, y):
    """Return the matrix of squared Euclidean distances between the rows of x and y.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
    It is cheap, but floating-point cancellation can push entries (notably
    self-distances) slightly below zero. The result is therefore clamped at 0,
    so the downstream kernels never see a negative "distance".

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.

    Returns:
        (n, m) tensor whose (i, j) entry is ||x[i] - y[j]||^2.
    """
    x_sq = (x ** 2).sum(dim=1, keepdim=True)       # (n, 1)
    y_sq = (y ** 2).sum(dim=1, keepdim=True).t()   # (1, m)
    sq_dists = x_sq + y_sq - 2 * torch.mm(x, y.t())
    return sq_dists.clamp_min(0)

def gaussian_kernel_matrix(x, y, sigmas=None):
    """Sum of Gaussian (RBF) kernels over several bandwidths.

    k(a, b) = sum_s exp(-||a - b||^2 / (2 * s^2))

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        sigmas: bandwidths s; None means [2, 5, 10, 20, 40, 80].

    Returns:
        (n, m) kernel matrix.
    """
    bandwidths = [2, 5, 10, 20, 40, 80] if sigmas is None else sigmas
    bandwidths = torch.tensor(bandwidths, device=x.device, dtype=x.dtype)
    # One coefficient per bandwidth, shaped (k, 1, 1) to broadcast over the (1, n, m) distances.
    gamma = (1. / (2. * bandwidths.pow(2))).reshape(-1, 1, 1)
    sq_dists = compute_pairwise_distance(x, y)[None]  # (1, n, m)
    return torch.exp(-gamma * sq_dists).sum(0)

def linear_kernel_matrix(x, y):
    """Linear (dot-product) kernel: entry (i, j) is <x[i], y[j]>.

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.

    Returns:
        (n, m) kernel matrix.
    """
    return x.mm(y.t())

def rational_quadratic_kernel_matrix(x, y, alphas=None):
    """Sum of rational-quadratic kernels over several shape parameters.

    k(a, b) = sum_alpha (1 + ||a - b||^2 / (2 * alpha)) ** (-alpha)

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        alphas: shape parameters; None means [0.2, 0.5, 1, 2, 5].

    Returns:
        (n, m) kernel matrix (the int 0 if ``alphas`` is empty).
    """
    shapes = [0.2, 0.5, 1, 2, 5] if alphas is None else alphas
    sq_dists = compute_pairwise_distance(x, y)
    return sum((1 + sq_dists / (2 * a)) ** (-a) for a in shapes)

def mixed_rq_dot_kernel_matrix(x, y, alphas=None):
    """Mixed kernel: rational-quadratic mixture plus a linear (dot-product) term.

    k(a, b) = sum_alpha (1 + ||a - b||^2 / (2 * alpha)) ** (-alpha) + <a, b>

    Delegates to ``rational_quadratic_kernel_matrix`` and ``linear_kernel_matrix``
    instead of duplicating their code, so the kernels cannot drift apart.

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        alphas: RQ shape parameters; None uses the RQ kernel's default
            [0.2, 0.5, 1, 2, 5].

    Returns:
        (n, m) kernel matrix.
    """
    return rational_quadratic_kernel_matrix(x, y, alphas) + linear_kernel_matrix(x, y)


In [7]:
def mmd_loss(real_features, fake_features, kernel=rational_quadratic_kernel_matrix, sigmas=None):
    """
    Computes the (biased) squared-MMD estimate between real and fake feature batches.

        MMD^2_b = mean(K_XX) + mean(K_YY) - 2 * mean(K_XY)

    Args:
        real_features: (m, d) tensor of features for real samples.
        fake_features: (n, d) tensor of features for generated samples.
        kernel: callable ``kernel(x, y[, params])`` returning a
            (rows(x), rows(y)) kernel matrix.
        sigmas: optional kernel hyper-parameters, forwarded as the kernel's third
            positional argument (bandwidths for ``gaussian_kernel_matrix``,
            alphas for the RQ kernels). When None, nothing extra is forwarded.
            Each kernel then uses its own defaults, and two-argument kernels
            such as ``linear_kernel_matrix`` work too.

    Returns:
        Scalar tensor.
    """
    # Forward hyper-parameters only when provided. Always passing None
    # positionally raises TypeError for kernels that take only (x, y).
    kernel_params = () if sigmas is None else (sigmas,)

    # Kernel matrices for real-real, fake-fake and real-fake similarities.
    K_XX = kernel(real_features, real_features, *kernel_params)
    K_YY = kernel(fake_features, fake_features, *kernel_params)
    K_XY = kernel(real_features, fake_features, *kernel_params)

    m = real_features.size(0)
    n = fake_features.size(0)

    # Biased estimator: includes the diagonal terms of K_XX and K_YY.
    loss = K_XX.sum() / (m * m) + K_YY.sum() / (n * n) - 2 * K_XY.sum() / (m * n)
    return loss


In [5]:
def gradient_penalty(critic, real_imgs, fake_imgs, device, lambda_gp=10):
    """
    WGAN-GP gradient penalty computed on random interpolations between real and fake samples.

    For x_hat = a * real + (1 - a) * fake, with a ~ U[0, 1) drawn once per sample, returns

        lambda_gp * mean_i (||d s_i / d x_hat_i||_2 - 1) ** 2

    Here s_i is the sum of the critic's outputs for sample i, because autograd
    is seeded with ones.

    Args:
        critic: module/callable mapping a batch to per-sample outputs.
        real_imgs: real batch; the first dimension is the batch size.
        fake_imgs: generated batch with the same shape as ``real_imgs``.
        device: device on which the mixing factors are drawn.
        lambda_gp: penalty weight. When 0, a zero scalar is returned without
            running the critic or the double-backward.

    Returns:
        Scalar tensor, differentiable w.r.t. the critic's parameters (create_graph=True).

    NOTE(review): the critic in this notebook returns a 1024-dim feature vector,
    so this penalises the gradient of the *sum* of features, not of a scalar
    witness/score as in WGAN-GP.
    """
    if lambda_gp == 0:
        # The penalty would be multiplied by zero anyway; skip the expensive graph.
        return torch.zeros((), device=device, dtype=real_imgs.dtype)

    # One mixing factor per sample, broadcast over the remaining dims (any rank).
    alpha_shape = (real_imgs.size(0),) + (1,) * (real_imgs.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs).requires_grad_(True)

    scores = critic(interpolates)

    # create_graph=True keeps the penalty differentiable (retain_graph defaults
    # to create_graph). The deprecated, ignored only_inputs flag is not passed.
    (gradients,) = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )

    # Per-sample L2 norm of the flattened input gradient.
    grad_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()


In [None]:
from tqdm import tqdm
from IPython.display import display, clear_output

# networks
generator = Generator(latent_dim).to(device)
critic = Critic().to(device)

# optimizers (Adam with beta1 = 0.5 for both networks)
optimizer_G = optim.Adam(generator.parameters(), lr=lr_generator, betas=(0.5, 0.999))
optimizer_C = optim.Adam(critic.parameters(), lr=lr_critic, betas=(0.5, 0.999))

# fixed noise so that the per-epoch sample grids are directly comparable
fixed_noise = torch.randn(64, latent_dim, device=device)


def show_generated_grid(model, noise, title, save_path=None):
    """Plot ``model(noise)`` as an image grid with 8 images per row.

    Runs the model under no_grad and does not change its train/eval mode.
    When ``save_path`` is given, the figure is written there before being shown.
    The figure is closed afterwards, so repeated calls do not accumulate open figures.
    """
    with torch.no_grad():
        samples = model(noise).detach().cpu()
    grid = torchvision.utils.make_grid(samples, nrow=8, normalize=True)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
    ax.axis("off")
    ax.set_title(title)
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)


show_generated_grid(generator, fixed_noise, "Initial Generated Samples")

# training loop with sample display after each epoch
for epoch in tqdm(range(num_epochs)):
    for batch_idx, (real_imgs, _) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        batch_size_curr = real_imgs.size(0)

        # Critic: n_critic steps maximising MMD, i.e. minimising -MMD + GP.
        # NOTE(review): every critic step reuses the same real batch; WGAN-style
        # training usually draws a fresh real batch per critic iteration.
        for _ in range(n_critic):
            noise = torch.randn(batch_size_curr, latent_dim, device=device)
            # Fakes are constants for the critic update. no_grad gives the same
            # values as .detach() without building the generator's graph.
            with torch.no_grad():
                fake_imgs = generator(noise)

            real_features = critic(real_imgs)
            fake_features = critic(fake_imgs)

            mmd = mmd_loss(real_features, fake_features)
            gp = gradient_penalty(critic, real_imgs, fake_imgs, device, lambda_gp)
            critic_loss = -mmd + gp

            optimizer_C.zero_grad()
            critic_loss.backward()
            optimizer_C.step()

        # Generator: one step minimising MMD between critic features.
        noise = torch.randn(batch_size_curr, latent_dim, device=device)
        fake_imgs = generator(noise)
        fake_features = critic(fake_imgs)
        # Real features are only a target for the generator, so no graph is needed.
        with torch.no_grad():
            real_features = critic(real_imgs)
        generator_loss = mmd_loss(real_features, fake_features)

        optimizer_G.zero_grad()
        # This backward also fills the critic's .grad. Those are cleared by
        # optimizer_C.zero_grad() before the next critic step.
        generator_loss.backward()
        optimizer_G.step()

        # Training stats; critic_loss comes from the last critic step.
        # tqdm.write prints without breaking the progress bar.
        if batch_idx % log_interval == 0:
            tqdm.write(f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} | "
                       f"Critic Loss: {critic_loss.item():.4f} | Generator Loss: {generator_loss.item():.4f}")

    show_generated_grid(
        generator,
        fixed_noise,
        f"Epoch {epoch+1} Generated Samples",
        save_path=f"samples/epoch_{epoch+1:03d}.png",
    )
    clear_output(wait=True)

In [None]:
from PIL import Image

# Reload the grid that the training loop saved for the last epoch.
final_sample_path = f"samples/epoch_{num_epochs:03d}.png"
# Read the pixels inside a context manager so the file handle is closed promptly.
with Image.open(final_sample_path) as sample_img:
    sample_pixels = np.asarray(sample_img)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(sample_pixels)
ax.axis("off")
ax.set_title("Generated Samples - Final Epoch")
plt.show()

In [None]:
import shutil
from IPython.display import FileLink

# Zip the contents of ./samples (all saved epoch grids) into ./samples.zip,
# then render a clickable link to it as the cell output.
shutil.make_archive(base_name="samples", format="zip", root_dir="samples")
FileLink("samples.zip")

In [None]:
# Trigger a browser download of the archive when running in Google Colab.
# Elsewhere, google.colab is not installed. Skip with a message instead of
# raising, so Restart & Run All still completes.
try:
    from google.colab import files
except ImportError:
    files = None
    print("google.colab not available; skipping download (samples.zip is in the working directory).")

if files is not None:
    files.download('samples.zip')