In [None]:
# Cell 1 — Setup & Folders
import os, time, random, math, json, glob
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

import matplotlib.pyplot as plt
from PIL import Image

# Reproducibility: push the same seed into every RNG this notebook draws from
# (Python's `random`, torch CPU, and all CUDA devices).
seed = 42
for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Output layout: sample grids go to images/, model weights to checkpoints/.
OUT_DIR = Path("./gan_outputs")
for subdir in ("images", "checkpoints"):
    (OUT_DIR / subdir).mkdir(parents=True, exist_ok=True)





In [None]:
# Cell 2 — CIFAR-10 Data
DATA_ROOT    = "./data"
img_size     = 32     # CIFAR-10 is 32x32
img_channels = 3      # RGB
batch_size   = 128

# ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 per channel
# rescales them to [-1, 1], the same range as the generator's Tanh output.
# Resize/CenterCrop leave native 32x32 CIFAR images unchanged at img_size=32.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5],
                         [0.5, 0.5, 0.5]),
])

trainset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)

# Pinned host memory only speeds up host -> GPU copies. On a CPU-only machine
# it is useless and recent PyTorch versions emit a warning, so enable it only
# when the notebook actually selected a CUDA device.
loader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)

print(f"Train images: {len(trainset)} | Batches/epoch: {len(loader)}")


In [None]:
# Cell 3 — Hyper-parameters
# Model sizes: latent vector length and base channel widths of G and D.
z_dim = 128
g_width = d_width = 64

# Optimisation settings (Adam learning rate / betas) and training length.
lr = 2e-4
beta1, beta2 = 0.5, 0.999
epochs = 30

# One fixed batch of 64 latent vectors, reused every epoch so the saved sample
# grids are directly comparable across epochs.
fixed_z = torch.randn(64, z_dim, 1, 1, device=device)


In [None]:
# Cell 4 — Generator
class Generator(nn.Module):
    """
    DCGAN-style generator built from transposed convolutions.

    Input:  latent tensor of shape (N, z_dim, 1, 1)
    Output: image tensor of shape (N, img_channels, 32, 32), squashed to
            [-1, 1] by the final Tanh.

    Spatial size per stage: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32
    """

    def __init__(self, z_dim=128, img_channels=3, g_width=64):
        super().__init__()
        wide, mid, narrow = g_width * 4, g_width * 2, g_width

        # Layers are created in forward order so parameter initialisation and
        # the `net.<index>` state-dict keys stay stable.
        stages = [
            # project the latent vector: 1x1 -> 4x4
            nn.ConvTranspose2d(z_dim, wide, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(wide),
            nn.ReLU(inplace=True),

            # upsample: 4x4 -> 8x8
            nn.ConvTranspose2d(wide, mid, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),

            # upsample: 8x8 -> 16x16
            nn.ConvTranspose2d(mid, narrow, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(narrow),
            nn.ReLU(inplace=True),

            # upsample to image resolution: 16x16 -> 32x32, no BatchNorm before Tanh
            nn.ConvTranspose2d(narrow, img_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Map a batch of latent vectors to a batch of images."""
        return self.net(z)

# Build the generator from the notebook's hyper-parameters, move it to the
# selected device, and leave it as the cell's last expression to show its layout.
G = Generator(z_dim=z_dim, img_channels=img_channels, g_width=g_width)
G = G.to(device)
G


In [None]:
# Cell 5 — Discriminator
class Discriminator(nn.Module):
    """
    DCGAN-style discriminator built from strided convolutions.

    Input:  image tensor of shape (N, img_channels, 32, 32)
    Output: (N, 1) raw logits — no sigmoid is applied here, so pair it with
            BCEWithLogitsLoss.

    Spatial size per stage: 32 -> 16 -> 8 -> 4 -> 1
    """

    def __init__(self, img_channels=3, d_width=64):
        super().__init__()
        c1, c2, c3 = d_width, d_width * 2, d_width * 4

        # Stem (32 -> 16): DCGAN convention is to skip BatchNorm on the first block.
        self.conv_in = nn.Sequential(
            nn.Conv2d(img_channels, c1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Two downsampling blocks (16 -> 8 -> 4), each conv + BatchNorm + LeakyReLU.
        self.conv_mid = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(c2, c3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c3),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Head (4 -> 1): a 4x4 valid convolution collapses the map to one logit.
        self.conv_out = nn.Conv2d(c3, 1, kernel_size=4, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Score a batch of images; returns one logit per row."""
        for stage in (self.conv_in, self.conv_mid, self.conv_out):
            x = stage(x)
        # (N, 1, 1, 1) for 32x32 inputs -> (N, 1)
        return x.view(-1, 1)

# Build the discriminator from the notebook's hyper-parameters, move it to the
# selected device, and leave it as the cell's last expression to show its layout.
D = Discriminator(img_channels=img_channels, d_width=d_width)
D = D.to(device)
D


In [None]:
# Cell 6 — Losses & Optimizers
# D returns raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

# G and D get separate Adam optimizers with identical settings.
adam_kwargs = dict(lr=lr, betas=(beta1, beta2))
opt_G = torch.optim.Adam(G.parameters(), **adam_kwargs)
opt_D = torch.optim.Adam(D.parameters(), **adam_kwargs)


In [None]:
# Cell 7 — Training Loop
def train_gan(G, D, loader, epochs, fixed_z, log_interval=100):
    """
    Alternate discriminator and generator updates for `epochs` passes over `loader`.

    Args:
        G: generator module; called on (N, z_dim, 1, 1) noise.
        D: discriminator module; its outputs are fed to `criterion` as logits.
        loader: iterable of (images, labels) batches; labels are ignored.
        epochs: number of full passes over `loader`.
        fixed_z: latent batch rendered to a sample grid after every epoch.
        log_interval: print the current batch losses every this many batches.

    Side effects: steps the module-level `opt_D` / `opt_G`, prints progress,
    and writes `epoch_XXX.png` grids under OUT_DIR / "images".

    NOTE(review): `criterion`, `opt_G`, `opt_D`, `z_dim` and `device` are read
    from module scope. The optimizers were built from the module-level G and D,
    so passing different models here would not update them.
    """
    G.train()
    D.train()

    for epoch in range(1, epochs + 1):
        g_total = 0.0
        d_total = 0.0
        epoch_start = time.time()

        for step, (real, _) in enumerate(loader, start=1):
            real = real.to(device)
            n = real.size(0)
            ones = torch.ones(n, 1, device=device)     # target for "real"
            zeros = torch.zeros(n, 1, device=device)   # target for "fake"

            # ---- Discriminator step ----
            loss_on_real = criterion(D(real), ones)

            noise = torch.randn(n, z_dim, 1, 1, device=device)
            fake = G(noise)
            # detach: the D update must not push gradients into G
            loss_on_fake = criterion(D(fake.detach()), zeros)

            d_loss = loss_on_real + loss_on_fake
            opt_D.zero_grad()
            d_loss.backward()
            opt_D.step()

            # ---- Generator step ----
            # Re-score the same fakes with the just-updated D (no detach) and
            # train G to make D label them as real.
            g_loss = criterion(D(fake), ones)

            opt_G.zero_grad()
            g_loss.backward()
            opt_G.step()

            g_val = g_loss.item()
            d_val = d_loss.item()
            g_total += g_val
            d_total += d_val

            if step % log_interval == 0:
                print(f"Epoch {epoch:02d} [{step:04d}/{len(loader)}] "
                      f"D_loss={d_val:.3f} G_loss={g_val:.3f}")

        # Render the fixed latent batch in eval mode, then switch back to training.
        G.eval()
        with torch.no_grad():
            fake_fixed = G(fixed_z).cpu()
        G.train()

        grid_path = OUT_DIR / "images" / f"epoch_{epoch:03d}.png"
        utils.save_image(fake_fixed, grid_path, nrow=8, normalize=True, value_range=(-1, 1))

        print(f"Epoch {epoch:02d} | time={time.time()-epoch_start:.1f}s | "
              f"mean D={d_total/len(loader):.3f} | mean G={g_total/len(loader):.3f} | "
              f"saved {grid_path}")

train_gan(G, D, loader, epochs=epochs, fixed_z=fixed_z)


In [None]:
# Cell 8 — View the latest saved grids
def show_latest(n=4):
    """
    Display the `n` most recent sample grids saved under OUT_DIR / "images".

    Files are matched with the pattern `epoch_*.png` and ordered by name, so
    the zero-padded epoch number in the filename determines "most recent".

    Args:
        n: how many of the latest grids to show. Values <= 0 show nothing.
    """
    paths = sorted(glob.glob(str(OUT_DIR / "images" / "epoch_*.png")))

    # paths[-0:] is the whole list, so n <= 0 has to be handled explicitly.
    latest = paths[-n:] if n > 0 else []
    if not latest:
        print("No sample grids to show.")
        return

    for p in latest:
        fig, ax = plt.subplots(figsize=(4, 4))
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once imshow has copied the pixel data.
        with Image.open(p) as img:
            ax.imshow(img)
        ax.axis("off")
        ax.set_title(os.path.basename(p))

    # Render explicitly so the figures appear regardless of backend.
    plt.show()

show_latest(4)


In [None]:
# Cell 9 — Save/Load checkpoints (optional)
def save_ckpt(epoch):
    """
    Save the module-level G and D weights, the epoch number, and the run's
    hyper-parameters to OUT_DIR / "checkpoints" / gan_epoch_XXX.pt.
    """
    hparams = {
        "z_dim": z_dim,
        "g_width": g_width,
        "d_width": d_width,
        "lr": lr,
        "betas": (beta1, beta2),
        "img_size": img_size,
        "img_channels": img_channels,
    }
    payload = {
        "G": G.state_dict(),
        "D": D.state_dict(),
        "epoch": epoch,
        "hparams": hparams,
    }
    target = OUT_DIR / "checkpoints" / f"gan_epoch_{epoch:03d}.pt"
    torch.save(payload, target)

def load_ckpt(path):
    """
    Load a checkpoint file and copy its "G" / "D" weights into the module-level
    G and D in place, then print the epoch stored in the file.
    """
    # NOTE(review): torch.load is pickle-based — only open checkpoints that
    # come from a trusted source.
    state = torch.load(path, map_location=device)
    for key, model in (("G", G), ("D", D)):
        model.load_state_dict(state[key])
    print("Loaded epoch:", state["epoch"])


In [None]:
# Cell 10 — Simple experiment logger
# One JSON object per line; every call to log_experiment appends a record.
LOG_PATH = OUT_DIR / "experiment_log.jsonl"

def log_experiment(name, params, notes):
    """
    Append a single run record to LOG_PATH as one line of JSON.

    Args:
        name: label for the run.
        params: hyper-parameters of the run; must be JSON-serialisable.
        notes: free-text remarks about the run.

    The record also carries a local-time timestamp and the image / checkpoint
    directory paths under OUT_DIR.
    """
    rec = {
        "name": name,
        "params": params,
        "notes": notes,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "images_dir": str(OUT_DIR / "images"),
        "ckpt_dir": str(OUT_DIR / "checkpoints"),
    }
    with open(LOG_PATH, "a", encoding="utf-8") as f:
        f.write(f"{json.dumps(rec)}\n")
    print("Logged:", name)

# Example: log the baseline run (edit notes after you view images)
baseline_params = {
    "dataset": "CIFAR10",
    "img_size": img_size,
    "img_channels": img_channels,
    "z_dim": z_dim,
    "g_width": g_width,
    "d_width": d_width,
    "lr": lr,
    "betas": (beta1, beta2),
    "batch_size": batch_size,
    "epochs": epochs,
    "loss": "BCEWithLogits",
}
log_experiment(
    "Run-A (Baseline CIFAR10, DCGAN)",
    baseline_params,
    notes="Baseline. Check grids at epoch_005/010/020/030 to discuss convergence.",
)
