<a href="https://colab.research.google.com/github/AnanyaTyagi/VAE-GAN-Diffusion-Benchmark/blob/main/GANImage_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# --- GPU + Drive ---
!nvidia-smi -L || true

from google.colab import drive
import os, random, json, math, time
import numpy as np
drive_ok = True
try:
    drive.mount('/content/drive', force_remount=True)
    OUT_DIR = "/content/drive/MyDrive/gan_cifar10_runs"   # change if you like
    print("✅ Drive mounted. Saving to:", OUT_DIR)
except Exception as e:
    print("⚠️ Drive mount failed, saving locally. Error:", e)
    OUT_DIR = "/content/gan_cifar10_runs"
    drive_ok = False
os.makedirs(OUT_DIR, exist_ok=True)

# --- Repro/Device ---
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cudnn.benchmark = True

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
set_seed(42)

# --- Small helpers ---
from torchvision.utils import save_image, make_grid
def denorm(x):                      # [-1,1] -> [0,1]
    return (x.clamp(-1,1) + 1)/2

print("Device:", device)
print("OUT_DIR:", OUT_DIR)


GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-75e0568e-7b4d-d12b-eb13-7094963b77e5)
Mounted at /content/drive
✅ Drive mounted. Saving to: /content/drive/MyDrive/gan_cifar10_runs
Device: cuda
OUT_DIR: /content/drive/MyDrive/gan_cifar10_runs


In [1]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---------- Generator ----------
class Generator(nn.Module):
    def __init__(self, latent_dim=128, fm=128):
        super().__init__()
        self.net = nn.Sequential(
            # (N, latent_dim, 1, 1) -> (N, fm*8, 4, 4)
            nn.ConvTranspose2d(latent_dim, fm*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(fm*8), nn.ReLU(True),

            # 4->8
            nn.ConvTranspose2d(fm*8, fm*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm*4), nn.ReLU(True),

            # 8->16
            nn.ConvTranspose2d(fm*4, fm*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm*2), nn.ReLU(True),

            # 16->32
            nn.ConvTranspose2d(fm*2, fm, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm), nn.ReLU(True),

            # (N, fm, 32, 32) -> (N, 3, 32, 32)
            nn.ConvTranspose2d(fm, 3, 3, 1, 1, bias=False),
            nn.Tanh(),  # [-1,1]
        )

    def forward(self, z):
        return self.net(z)

# ---------- Discriminator ----------
class Discriminator(nn.Module):
    def __init__(self, fm=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, fm, 3, 1, 1, bias=False),   # 32x32
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(fm, fm*2, 4, 2, 1, bias=False),  # 16x16
            nn.BatchNorm2d(fm*2), nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(fm*2, fm*4, 4, 2, 1, bias=False), # 8x8
            nn.BatchNorm2d(fm*4), nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(fm*4, fm*8, 4, 2, 1, bias=False), # 4x4
            nn.BatchNorm2d(fm*8), nn.LeakyReLU(0.2, inplace=True),
        )
        self.head = nn.Conv2d(fm*8, 1, 4, 1, 0, bias=False)  # NO sigmoid

    def forward(self, x):
        h = self.features(x)
        out = self.head(h)        # shape (N,1,1,1)
        return out.view(-1)       # shape (N,)


# ---------- Weights init (DCGAN) ----------
def weights_init(m):
    name = m.__class__.__name__
    if name.find('Conv') != -1 or name.find('ConvTranspose') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif name.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# ---------- Data ----------
BATCH_SIZE = 128
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),  # [-1,1]
])
trainset = datasets.CIFAR10("./data", train=True, download=True, transform=tfm)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

print("Train batches:", len(trainloader))


100%|██████████| 170M/170M [00:15<00:00, 10.7MB/s]


Train batches: 391


In [4]:
import tqdm

# --- Hyperparams ---
LATENT_DIM = 128
GEN_FM = 128
DISC_FM = 64
EPOCHS = 100
LR_G = 2e-4
LR_D = 2e-4
BETA1 = 0.5; BETA2 = 0.999

# --- Models/Opt/AMP ---
G = Generator(LATENT_DIM, GEN_FM).to(device)
D = Discriminator(DISC_FM).to(device)
G.apply(weights_init); D.apply(weights_init)

optG = torch.optim.Adam(G.parameters(), lr=LR_G, betas=(BETA1, BETA2))
optD = torch.optim.Adam(D.parameters(), lr=LR_D, betas=(BETA1, BETA2))

# ✅ use logits-safe loss
criterion = nn.BCEWithLogitsLoss()

# ✅ new AMP API (no FutureWarning)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

fixed_noise = torch.randn(64, LATENT_DIM, 1, 1, device=device)

for epoch in range(1, EPOCHS+1):
    G.train(); D.train()
    pbar = tqdm.tqdm(trainloader, desc=f"Epoch {epoch}", unit="batch")
    for real, _ in pbar:
        real = real.to(device)
        N = real.size(0)
        real_label = torch.ones(N, device=device)
        fake_label = torch.zeros(N, device=device)

        # ----- Train D -----
        optD.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(device=="cuda")):

            out_real = D(real)                        # logits
            loss_real = criterion(out_real, real_label)

            z = torch.randn(N, LATENT_DIM, 1, 1, device=device)
            fake = G(z).detach()
            out_fake = D(fake)                        # logits
            loss_fake = criterion(out_fake, fake_label)

            loss_D = loss_real + loss_fake

        scaler.scale(loss_D).backward()
        scaler.step(optD)

        # ----- Train G -----
        optG.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=(device=="cuda")):
            z = torch.randn(N, LATENT_DIM, 1, 1, device=device)
            fake = G(z)
            out = D(fake)                             # logits
            loss_G = criterion(out, real_label)       # wants "real"

        scaler.scale(loss_G).backward()
        scaler.step(optG)
        scaler.update()

        pbar.set_postfix(loss_D=float(loss_D), loss_G=float(loss_G))

    # ----- Save epoch artifacts -----
    G.eval(); D.eval()
    with torch.no_grad():
        fake_fixed = G(fixed_noise)
    grid_real = make_grid(denorm(real[:64].cpu()), nrow=8)
    grid_fake = make_grid(denorm(fake_fixed.cpu()), nrow=8)
    combo = torch.cat([grid_real, grid_fake], dim=1)
    save_image(combo, os.path.join(OUT_DIR, f"real_vs_fake_epoch_{epoch:03d}.png"))

    ckpt = {
        "epoch": epoch,
        "G": G.state_dict(),
        "D": D.state_dict(),
        "LATENT_DIM": LATENT_DIM,
        "GEN_FM": GEN_FM,
        "DISC_FM": DISC_FM,
    }
    torch.save(ckpt, os.path.join(OUT_DIR, f"dcgan_epoch_{epoch:03d}.pth"))

print("Training complete. Files in:", OUT_DIR)


  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
  with torch.cuda.amp.autocast(enabled=(device=="cuda")):
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  pbar.set_postfix(loss_D=float(loss_D), loss_G=float(loss_G))
Epoch 1: 100%|██████████| 391/391 [00:13<00:00, 29.75batch/s, loss_D=0.512, loss_G=7.64]
Epoch 2: 100%|██████████| 391/391 [00:06<00:00, 61.12batch/s, loss_D=0.882, loss_G=2.95]
Epoch 3: 100%|██████████| 391/391 [00:06<00:00, 60.63batch/s, loss_D=0.336, loss_G=3.91]
Epoch 4: 100%|██████████| 391/391 [00:06<00:00, 60.63batch/s, loss_D=0.458, loss_G=3.93]
Epoch 5: 100%|██████████| 391/391 [00:06<00:00, 60.67batch/s, loss_D=0.513, loss_G=1.74]
Epoch 6: 100%|██████████| 391/391 [00:06<00:00, 58.53batch/s, loss_D=0.562, loss_G=2.67]
Epoch 7: 100%|██████████| 391/391 [00:06<00:00, 60.40batch/s, loss_D=0.958, loss_G=3.53]
Epoch 8: 100%|██████████| 391/391 [00:06<00:00, 60.59batc

Training complete. Files in: /content/drive/MyDrive/gan_cifar10_runs


In [None]:
from pathlib import Path

EXPORT_DIR = Path(OUT_DIR) / "samples_10k"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

G.eval()
total = 10_000
bs = 256
saved = 0
with torch.no_grad():
    while saved < total:
        cur = min(bs, total - saved)
        z = torch.randn(cur, LATENT_DIM, 1, 1, device=device)
        imgs = G(z).cpu()
        imgs = denorm(imgs)  # [0,1] for PNG
        for i in range(cur):
            save_image(imgs[i], EXPORT_DIR / f"{saved+i:05d}.png")
        saved += cur
print(f"Saved {saved} images to {EXPORT_DIR}")


Saved 10000 images to /content/drive/MyDrive/gan_cifar10_runs/samples_10k


In [None]:
!pip -q install "torch-fidelity>=0.3.0"

import os, torch
from torch_fidelity import calculate_metrics

EXPORT_DIR = str(EXPORT_DIR)  # folder with your 10k GAN samples
print("Using generated samples from:", EXPORT_DIR)

# --- Temporarily patch torch.load so torch-fidelity can read its cached stats safely ---
orig_torch_load = torch.load

def _compat_load(*args, **kwargs):
    kwargs.setdefault("weights_only", False)  # allow full unpickling for torch-fidelity cache
    return orig_torch_load(*args, **kwargs)

torch.load = _compat_load

try:
    # === 1) FID + KID vs CIFAR-10 train ===
    metrics_fid = calculate_metrics(
        input1=EXPORT_DIR,
        input2="cifar10-train",
        fid=True,
        kid=True,
        isc=False,             # ✅ use 'isc', not 'inception_score'
        cuda=(device == "cuda"),
        verbose=False,
    )

    # === 2) Inception Score ONLY on generated images ===
    metrics_is = calculate_metrics(
        input1=EXPORT_DIR,
        fid=False,
        kid=False,
        isc=True,              # ✅ enable IS here
        cuda=(device == "cuda"),
        verbose=False,
    )
finally:
    # restore original torch.load no matter what
    torch.load = orig_torch_load

# --- Extract metrics safely ---
fid  = metrics_fid.get("frechet_inception_distance", None)
kidm = metrics_fid.get("kernel_inception_distance_mean", metrics_fid.get("kid_mean", None))
kids = metrics_fid.get("kernel_inception_distance_std",  metrics_fid.get("kid_std",  None))

# IS field names can vary slightly by version
is_m = (metrics_is.get("inception_score_mean")
        or metrics_is.get("isc_mean")
        or metrics_is.get("inception_score"))
is_s = metrics_is.get("inception_score_std", metrics_is.get("isc_std", None))

print("\n=== GAN Metrics ===")
print(f"FID : {fid:.2f}" if fid is not None else "FID : None")

if is_m is None:
    print("IS  : Not returned")
else:
    if is_s is None:
        print(f"IS  : {is_m:.2f}")
    else:
        print(f"IS  : {is_m:.2f} ± {is_s:.2f}")

if kidm is None:
    print("KID : Not returned")
else:
    if kids is None:
        print(f"KID : {kidm:.6f}")
    else:
        print(f"KID : {kidm:.6f} ± {kids:.6f}")

# Save to Drive
with open(os.path.join(OUT_DIR, "metrics.txt"), "w") as f:
    f.write(f"FID: {fid}\n")
    f.write(f"IS_mean: {is_m}\n")
    f.write(f"IS_std: {is_s}\n")
    f.write(f"KID_mean: {kidm}\n")
    f.write(f"KID_std: {kids}\n")

print("\nSaved metrics to:", os.path.join(OUT_DIR, "metrics.txt"))


Using generated samples from: /content/drive/MyDrive/gan_cifar10_runs/samples_10k

=== GAN Metrics ===
FID : 29.29
IS  : 7.04 ± 0.17
KID : 0.022578 ± 0.001404

Saved metrics to: /content/drive/MyDrive/gan_cifar10_runs/metrics.txt


In [None]:
# Load a saved checkpoint from Drive and generate a grid of images
from torchvision.utils import save_image

CKPT_PATH = os.path.join(OUT_DIR, "dcgan_epoch_020.pth")  # pick any epoch you saved

# Rebuild the same model classes, then load weights
ckpt = torch.load(CKPT_PATH, map_location=device)
LATENT_DIM = ckpt.get("LATENT_DIM", 128)
GEN_FM     = ckpt.get("GEN_FM", 128)

G_loaded = Generator(LATENT_DIM, GEN_FM).to(device)
G_loaded.load_state_dict(ckpt["G"])
G_loaded.eval()

with torch.no_grad():
    z = torch.randn(64, LATENT_DIM, 1, 1, device=device)
    imgs = denorm(G_loaded(z)).cpu()
save_image(make_grid(imgs, nrow=8), os.path.join(OUT_DIR, "generated_grid_from_ckpt.png"))
print("Saved:", os.path.join(OUT_DIR, "generated_grid_from_ckpt.png"))
