In [None]:
# ==========================================
# Part 1 — Setup & Library Installation
# ==========================================
!pip install torch torchvision pytorch-fid --quiet

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from pytorch_fid import fid_score
import numpy as np
import os

# Pick the compute device (CUDA GPU when available, else CPU) and seed every
# RNG this notebook draws from, so Restart & Run All is repeatable.
# (cuDNN kernels may still introduce small run-to-run differences on GPU.)
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [None]:
# ==========================================
# Part 2 — Load Dataset (CIFAR-10)
# ==========================================
# Per-channel statistics that map ToTensor's [0, 1] pixels onto [-1, 1],
# the same range as the generator's Tanh output.
CHANNEL_MEAN = (0.5, 0.5, 0.5)
CHANNEL_STD = (0.5, 0.5, 0.5)

# CIFAR-10 images are natively 32x32, so Resize(32) only guards the size.
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
    ]
)

# Training split, downloaded into ./data on first use.
dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

print("Dataset loaded successfully!")


100%|██████████| 170M/170M [04:16<00:00, 665kB/s]


Dataset loaded successfully!


In [None]:
# ==========================================
# Part 3 — Define Generator & Discriminator
# ==========================================
latent_dim = 100  # Random noise vector size


class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 32x32 image.

    Args:
        nz: Length of the input noise vector. When omitted, the notebook-level
            ``latent_dim`` is used (the original zero-argument behaviour).
        ngf: Base feature width. The transposed-conv stack produces
            ``ngf * 4 -> ngf * 2 -> ngf`` feature maps (512 -> 256 -> 128 with
            the default of 128).
        nc: Number of output image channels (3 for RGB).

    Input ``z`` is expected with shape ``(N, nz, 1, 1)``. The output has shape
    ``(N, nc, 32, 32)`` with values in ``[-1, 1]`` (Tanh), matching the
    Normalize(0.5, 0.5) scaling applied to the real data.
    """

    def __init__(self, nz=None, ngf=128, nc=3):
        super().__init__()
        if nz is None:
            nz = latent_dim
        self.model = nn.Sequential(
            # (N, nz, 1, 1) -> (N, ngf*4, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # -> (N, ngf*2, 8, 8)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # -> (N, ngf, 16, 16)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # -> (N, nc, 32, 32), squashed into [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` of shape ``(N, nz, 1, 1)`` to images ``(N, nc, 32, 32)``."""
        return self.model(z)


class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 32x32 images as real (1) or fake (0).

    Args:
        ndf: Base feature width. The conv stack produces
            ``ndf -> ndf * 2 -> ndf * 4`` feature maps (128 -> 256 -> 512 with
            the default of 128).
        nc: Number of input image channels (3 for RGB).

    Input ``img`` must have shape ``(N, nc, 32, 32)``. The output has shape
    ``(N,)`` and holds Sigmoid probabilities, as expected by ``nn.BCELoss``.
    """

    def __init__(self, ndf=128, nc=3):
        super().__init__()
        self.model = nn.Sequential(
            # (N, nc, 32, 32) -> (N, ndf, 16, 16); no BatchNorm on the first layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, ndf*2, 8, 8)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, ndf*4, 4, 4)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, 1, 1, 1): one probability per image
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return one real/fake probability per image, shape ``(N,)``.

        Raises:
            RuntimeError: if the input size does not reduce to a single score
                per image (e.g. images that are not 32x32). The earlier
                ``view(-1, 1)`` silently returned ``N * k`` scores instead.
        """
        return self.model(img).view(img.size(0))


# Build both players on the selected device.
G = Generator().to(device)
D = Discriminator().to(device)

# Adam settings shared by both networks (lr=2e-4, beta1=0.5, the usual DCGAN choice).
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)

# D ends in a Sigmoid, so binary cross-entropy is applied to probabilities.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_D = optim.Adam(D.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)


In [None]:
# ==========================================
# Part 4 — Define Correlation-Aware Loss
# ==========================================
def correlation_loss(fake_imgs, eps=1e-6):
    """Penalty pushing the batch's pairwise sample correlations towards identity.

    Each sample is flattened and standardised to zero mean and unit
    (population) variance, so ``x @ x.T / D`` is the Pearson correlation
    matrix of the batch. The loss is the mean squared difference from the
    identity matrix. It is ~0 when every pair of samples is uncorrelated and
    grows as samples become similar to one another.

    Args:
        fake_imgs: Tensor of shape ``(B, ...)``. Everything after the batch
            dimension is flattened into one feature vector per sample.
        eps: Added to each sample's standard deviation so constant samples
            do not divide by zero.

    Returns:
        0-dim tensor: mean of ``(corr - I) ** 2`` over all ``B * B`` entries.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    x = fake_imgs.reshape(fake_imgs.size(0), -1)
    # Population std (unbiased=False) matches the 1/D normalisation below, so
    # the diagonal of ``corr`` is ~1 and uncorrelated samples cost ~0. The
    # unbiased std used previously left the diagonal at (D-1)/D.
    x = (x - x.mean(dim=1, keepdim=True)) / (
        x.std(dim=1, keepdim=True, unbiased=False) + eps
    )

    # Pairwise correlation between samples in the batch.
    corr = torch.mm(x, x.t()) / x.size(1)
    identity = torch.eye(corr.size(0), device=corr.device, dtype=corr.dtype)

    # Encourage the correlation matrix to be close to identity (diverse outputs).
    return ((corr - identity) ** 2).mean()


In [None]:
# ==========================================
# Part 5 — Training Loop
# ==========================================
# Training hyper-parameters.
epochs = 50
CORR_WEIGHT = 0.05   # weight of the correlation (diversity) penalty in the G loss
SAMPLE_EVERY = 10    # save a grid of fixed-noise samples every N epochs

# Fixed latent vectors: sampling the same noise at every checkpoint makes the
# saved grids directly comparable across epochs.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)
os.makedirs("images", exist_ok=True)

# The FID cell switches G to eval mode; force training mode so BatchNorm uses
# batch statistics even when this cell is re-run afterwards.
G.train()
D.train()

for epoch in range(epochs):
    g_loss_sum = torch.zeros((), device=device)
    d_loss_sum = torch.zeros((), device=device)
    n_batches = 0

    for imgs, _ in dataloader:
        real_imgs = imgs.to(device)
        valid = torch.ones(imgs.size(0), device=device)
        fake = torch.zeros(imgs.size(0), device=device)

        # ---------------------
        # Train Generator: fool D while keeping the batch decorrelated
        # ---------------------
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim, 1, 1, device=device)
        gen_imgs = G(z)

        g_loss_basic = criterion(D(gen_imgs), valid)
        g_corr_loss = correlation_loss(gen_imgs)
        g_loss = g_loss_basic + CORR_WEIGHT * g_corr_loss

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        # Train Discriminator (fakes detached so G receives no gradient here)
        # ---------------------
        # zero_grad also discards the gradients D accumulated during G's backward pass.
        optimizer_D.zero_grad()
        real_loss = criterion(D(real_imgs), valid)
        fake_loss = criterion(D(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Accumulate on-device to avoid a host sync on every batch.
        g_loss_sum += g_loss.detach()
        d_loss_sum += d_loss.detach()
        n_batches += 1

    # Report epoch means; a single last-batch value is very noisy for GANs.
    n_batches = max(n_batches, 1)
    print(f"[Epoch {epoch+1}/{epochs}] D_loss: {d_loss_sum.item() / n_batches:.4f} | G_loss: {g_loss_sum.item() / n_batches:.4f}")

    if (epoch + 1) % SAMPLE_EVERY == 0:
        G.eval()
        with torch.no_grad():
            samples = G(fixed_noise)
        G.train()
        # value_range=(-1, 1) undoes the Tanh/Normalize scaling instead of
        # min-max stretching the grid.
        save_image(samples, f"images/epoch_{epoch+1}.png", nrow=8, normalize=True, value_range=(-1, 1))


[Epoch 1/50] D_loss: 0.6455 | G_loss: 5.0553
[Epoch 2/50] D_loss: 0.2388 | G_loss: 3.5971
[Epoch 3/50] D_loss: 0.3389 | G_loss: 1.1971
[Epoch 4/50] D_loss: 0.4280 | G_loss: 6.6910
[Epoch 5/50] D_loss: 0.1395 | G_loss: 3.2892
[Epoch 6/50] D_loss: 0.2975 | G_loss: 1.9846
[Epoch 7/50] D_loss: 0.1725 | G_loss: 2.5716
[Epoch 8/50] D_loss: 0.4814 | G_loss: 0.9235
[Epoch 9/50] D_loss: 0.4283 | G_loss: 1.0941
[Epoch 10/50] D_loss: 0.1582 | G_loss: 2.7194
[Epoch 11/50] D_loss: 0.3171 | G_loss: 2.9810
[Epoch 12/50] D_loss: 0.2944 | G_loss: 1.3553
[Epoch 13/50] D_loss: 0.2620 | G_loss: 1.7591
[Epoch 14/50] D_loss: 0.4005 | G_loss: 2.2762
[Epoch 15/50] D_loss: 0.0816 | G_loss: 4.1221
[Epoch 16/50] D_loss: 0.3336 | G_loss: 3.2927
[Epoch 17/50] D_loss: 0.1588 | G_loss: 2.6374
[Epoch 18/50] D_loss: 0.1707 | G_loss: 2.2686
[Epoch 19/50] D_loss: 0.1714 | G_loss: 2.6263
[Epoch 20/50] D_loss: 0.1128 | G_loss: 2.7987
[Epoch 21/50] D_loss: 0.1154 | G_loss: 3.6339
[Epoch 22/50] D_loss: 0.0993 | G_loss: 3.29

In [None]:
# ==========================================
# Part 6 — Evaluate FID (Fixed Version)
# ==========================================
import glob
import shutil

from tqdm.auto import tqdm

num_images = 10000      # images per side (real / fake); ~10k keeps the FID estimate stable
batch_size_gen = 128
real_path = "./data/cifar10_real"
fake_path = "./generated"


def _save_batch(imgs, folder, start_idx):
    """Write each image of ``imgs`` (values in [-1, 1]) as ``folder/<index>.png``."""
    for offset, img in enumerate(imgs):
        # value_range maps [-1, 1] -> [0, 1] identically for every image.
        # Plain normalize=True would min-max stretch each image separately
        # and distort the pixel statistics that FID compares.
        save_image(img, os.path.join(folder, f"{start_idx + offset}.png"),
                   normalize=True, value_range=(-1, 1))


# -----------------------------------------------------
# Generate fake images into a clean folder
# -----------------------------------------------------
# Stale PNGs from earlier runs would otherwise be scored alongside the new ones.
shutil.rmtree(fake_path, ignore_errors=True)
os.makedirs(fake_path)

print("Generating fake images for FID evaluation...")
was_training = G.training
G.eval()
total = 0
with torch.no_grad():
    for start in tqdm(range(0, num_images, batch_size_gen)):
        # The last batch is shortened so exactly num_images are written.
        n = min(batch_size_gen, num_images - start)
        z = torch.randn(n, latent_dim, 1, 1, device=device)
        _save_batch(G(z), fake_path, start)
        total += n
G.train(was_training)  # restore the mode so re-running training is unaffected

print(f"✅ Generated {total} fake images in 'generated/'")

# -----------------------------------------------------
# Export the same number of real CIFAR-10 images
# -----------------------------------------------------
# Rebuild the cache whenever it does not hold exactly num_images files.
if len(glob.glob(os.path.join(real_path, "*.png"))) != num_images:
    print("Preparing real CIFAR-10 images...")
    shutil.rmtree(real_path, ignore_errors=True)
    os.makedirs(real_path)
    saved = 0
    for imgs, _ in dataloader:
        imgs = imgs[: num_images - saved]
        _save_batch(imgs, real_path, saved)
        saved += len(imgs)
        if saved >= num_images:
            break

# -----------------------------------------------------
# Compute FID between real and fake
# -----------------------------------------------------
fid_value = fid_score.calculate_fid_given_paths(
    [real_path, fake_path],
    batch_size=64,
    device=device,
    dims=2048
)

print(f"\n✅ Final FID Score: {fid_value:.2f}")


Generating fake images for FID evaluation...


100%|██████████| 78/78 [00:10<00:00,  7.47it/s]


✅ Generated 9984 fake images in 'generated/'
Preparing real CIFAR-10 images...


100%|██████████| 164/164 [00:41<00:00,  3.95it/s]
100%|██████████| 1000/1000 [04:01<00:00,  4.14it/s]



✅ Final FID Score: 44.61
