"""
Fréchet Frontier — A Didactic DCGAN + FID Evaluation Suite (PyTorch)
====================================================================
An end-to-end project that *teaches* how to compute and use **FID (Fréchet Inception
Distance)** to evaluate a generative model. It includes a clean DCGAN, a numerically
stable FID implementation using pretrained InceptionV3 features, cached real-data
statistics, and rich visualizations.

• Dataset: CIFAR-10 (32×32 RGB)
• Model: DCGAN (configurable) — generator & discriminator
• Evaluation: FID over training (per-epoch or every k epochs)
• Visuals: Loss curves, FID curve, sample grids per epoch, t-SNE of features
• Epochs: configurable (5/10/20) for pedagogy

Run
----
python fid_gan_project.py \
  --data_root ./data \
  --outdir ./outputs \
  --epochs 10 \
  --batch_size 128 \
  --fid_every 1 \
  --num_fid_samples 5000

Optional flags: --epochs N (e.g. 5, 10 or 20), --lr 2e-4, --nz 128, --ngf 64, --ndf 64, --device cuda

Notes
-----
• FID is computed between *Inception features* of real vs generated images.
• We cache real CIFAR-10 stats to speed up repeated runs.
• For teaching: code is thoroughly structured into classes & utilities.
"""

# In [1]:
import os
import math
import argparse
import random
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, utils as vutils, models
from torchvision.models import Inception_V3_Weights
import matplotlib.pyplot as plt

try:
    from sklearn.manifold import TSNE
    _HAVE_SK = True
except Exception:
    _HAVE_SK = False

# In [2]:
# ------------------------------
# Reproducibility
# ------------------------------

def set_seed(seed: int = 42):
    """Seed every random number generator the project relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU plus all CUDA
    devices) with the same value, then switches cuDNN to deterministic mode
    and turns off its benchmark autotuner.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# In [3]:
# ------------------------------
# DCGAN Models
# ------------------------------

class DCGANGenerator(nn.Module):
    """DCGAN generator mapping a latent code to a 32x32 image.

    The network is a stack of transposed convolutions. With an input of shape
    ``(N, nz, 1, 1)`` the spatial size grows 1 -> 4 -> 8 -> 16 -> 32, and a
    final 3x3 stride-1 transposed convolution followed by ``Tanh`` produces
    ``(N, nc, 32, 32)`` values in ``[-1, 1]``.

    Args:
        nz: Number of latent channels.
        ngf: Base width; hidden blocks use ``8*ngf``, ``4*ngf``, ``2*ngf``, ``ngf``.
        nc: Number of output image channels.
    """

    def __init__(self, nz=128, ngf=64, nc=3):
        super().__init__()
        # (in_channels, out_channels, kernel, stride, padding) of each
        # ConvTranspose -> BatchNorm -> ReLU block, in forward order.
        block_specs = (
            (nz, ngf * 8, 4, 1, 0),
            (ngf * 8, ngf * 4, 4, 2, 1),
            (ngf * 4, ngf * 2, 4, 2, 1),
            (ngf * 2, ngf, 4, 2, 1),
        )
        layers = []
        for c_in, c_out, kernel, stride, pad in block_specs:
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))

        # Output head: keeps the 32x32 resolution and squashes to [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 3, 1, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent codes ``z`` of shape ``(N, nz, 1, 1)``."""
        images = self.main(z)
        return images


class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator producing one raw logit per 32x32 image.

    Three stride-2 convolutions reduce 32x32 inputs to 4x4, and a final 4x4
    convolution collapses that to a single value. No sigmoid is applied, so
    the output is a logit.

    Args:
        ndf: Base width; hidden blocks use ``ndf``, ``2*ndf``, ``4*ndf`` channels.
        nc: Number of input image channels.
    """

    def __init__(self, ndf=64, nc=3):
        super().__init__()
        # First block has no BatchNorm.
        stem = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Two downsampling blocks: Conv -> BatchNorm -> LeakyReLU.
        body = []
        for c_in, c_out in ((ndf, ndf * 2), (ndf * 2, ndf * 4)):
            body.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            body.append(nn.BatchNorm2d(c_out))
            body.append(nn.LeakyReLU(0.2, inplace=True))

        # 4x4 valid convolution -> (N, 1, 1, 1), flattened to (N, 1).
        head = [
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Flatten(),
        ]

        self.main = nn.Sequential(*stem, *body, *head)

    def forward(self, x):
        """Return logits of shape ``(N,)`` for images ``x`` of shape ``(N, nc, 32, 32)``."""
        logits = self.main(x)
        return logits.squeeze(1)  # (N, 1) -> (N,)


# In [4]:
# ------------------------------
# InceptionV3 Feature Extractor & FID
# ------------------------------

class InceptionV3Pool(nn.Module):
    """Frozen, pretrained InceptionV3 that outputs 2048-d pooled features.

    The classification head (``fc``) is replaced by ``nn.Identity`` so the
    forward pass returns the features that feed the classifier. Inputs are
    preprocessed with the inference transforms bundled with the
    ``IMAGENET1K_V1`` weights (per the torchvision docs: resize, center-crop
    to 299 and ImageNet normalisation).

    Args:
        device: Device on which the network runs.
    """

    def __init__(self, device: torch.device):
        super().__init__()
        weights = Inception_V3_Weights.IMAGENET1K_V1
        self.transforms = weights.transforms()
        # torchvision rejects ``aux_logits=False`` together with pretrained
        # weights (ValueError: "expected value True but got False"), because
        # the checkpoint contains the auxiliary head. Build the model with the
        # default and drop the auxiliary classifier afterwards instead.
        model = models.inception_v3(weights=weights)
        model.aux_logits = False
        model.AuxLogits = None
        model.fc = nn.Identity()  # output 2048-d embeddings
        self.model = model.to(device).eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.device = device

    @torch.no_grad()
    def features(self, imgs: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
        """Extract features for a stack of images.

        Args:
            imgs: Float tensor in ``[0, 1]`` with shape ``(N, 3, H, W)``; may
                live on any device.
            batch_size: Number of images pushed through the network at once.

        Returns:
            CPU tensor of shape ``(N, 2048)``.
        """
        self.model.eval()
        n = imgs.size(0)
        if n == 0:
            # Nothing to embed; torch.cat on an empty list would raise.
            return torch.empty(0, 2048)

        feats = []
        for start in range(0, n, batch_size):
            batch = imgs[start:start + batch_size].to(self.device)
            # The weights' inference preset accepts batched (B, C, H, W)
            # tensors, so the whole mini-batch is preprocessed in one call
            # instead of a per-image Python loop.
            proc = self.transforms(batch)
            f = self.model(proc)  # (B, 2048)
            feats.append(f.detach().cpu())
        return torch.cat(feats, dim=0)


def cov_mean(feats: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the mean vector and unbiased covariance of a feature matrix.

    The statistics are accumulated in float64 and returned as float32.

    Args:
        feats: Tensor of shape ``(N, D)`` with one feature vector per row.

    Returns:
        ``(mu, cov)`` where ``mu`` has shape ``(D,)`` and ``cov`` has shape
        ``(D, D)`` and uses the ``N - 1`` denominator.

    Raises:
        ValueError: If ``feats`` is not 2-D or holds fewer than two rows
            (the unbiased covariance is undefined and would otherwise come
            out as NaN/inf).
    """
    if feats.dim() != 2:
        raise ValueError(f"feats must have shape (N, D), got {tuple(feats.shape)}")
    n = feats.shape[0]
    if n < 2:
        raise ValueError(f"need at least 2 samples to estimate a covariance, got {n}")

    x = feats.double()
    mu = x.mean(dim=0)
    xc = x - mu
    # (D, D) covariance with unbiased denominator (N - 1)
    cov = (xc.t() @ xc) / (n - 1)
    return mu.float(), cov.float()


def _matrix_sqrt_psd(A: torch.Tensor, eps=1e-10) -> torch.Tensor:
    """Matrix square-root for symmetric PSD matrix using eigen-decomposition."""
    # ensure symmetry
    A = (A + A.t()) * 0.5
    evals, evecs = torch.linalg.eigh(A.double())
    evals = torch.clamp(evals, min=0.0)
    sqrt_evals = torch.sqrt(evals + eps)
    sqrtA = (evecs * sqrt_evals.unsqueeze(0)) @ evecs.t()
    return sqrtA.float()


def frechet_distance(mu1: torch.Tensor, C1: torch.Tensor,
                     mu2: torch.Tensor, C2: torch.Tensor, eps=1e-6) -> float:
    """Compute FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1^{1/2} C2 C1^{1/2}))."""
    mu1 = mu1.float(); mu2 = mu2.float()
    C1 = C1.float(); C2 = C2.float()

    diff = mu1 - mu2
    # Compute sqrt( C1^{1/2} C2 C1^{1/2} ) via symmetric PSD trick
    C1_sqrt = _matrix_sqrt_psd(C1 + eps * torch.eye(C1.shape[0]))
    inner = C1_sqrt @ C2 @ C1_sqrt
    inner = (inner + inner.t()) * 0.5  # symmetrize
    covmean_sqrt = _matrix_sqrt_psd(inner + eps * torch.eye(inner.shape[0]))

    tr = torch.trace(C1 + C2 - 2.0 * covmean_sqrt)
    fid = diff.dot(diff) + tr
    return float(fid.item())


# In [5]:
# ------------------------------
# Dataset & Utilities
# ------------------------------

@dataclass
class TrainConfig:
    """Paths and hyper-parameters for one DCGAN training + FID evaluation run.

    The field order is part of the interface: the dataclass-generated
    ``__init__`` accepts the fields positionally in the order declared here.
    """
    # Root directory passed to the CIFAR-10 dataset (downloaded there if missing).
    data_root: str = "./data"
    # Directory that receives sample grids, curves and cached statistics.
    outdir: str = "./outputs"
    # Mini-batch size of the training DataLoader.
    batch_size: int = 128
    # Number of worker processes of the training DataLoader.
    workers: int = 2
    # Number of training epochs.
    epochs: int = 10
    # Adam learning rate, shared by generator and discriminator optimizers.
    lr: float = 2e-4
    # First Adam beta coefficient (the second is fixed to 0.999 by the trainer).
    beta1: float = 0.5
    # Number of latent channels fed to the generator.
    nz: int = 128
    # Base channel width of the generator.
    ngf: int = 64
    # Base channel width of the discriminator.
    ndf: int = 64
    # FID is evaluated on epochs where ``epoch % fid_every == 0``.
    fid_every: int = 1
    # Number of generated images used for each FID estimate.
    num_fid_samples: int = 5000
    # Seed handed to ``set_seed``.
    seed: int = 42
    # Device string; the default is evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


def get_cifar10_loaders(cfg: TrainConfig):
    """Build the CIFAR-10 training loader and the dataset used for real FID stats.

    Both views read the CIFAR-10 *training* split from ``cfg.data_root``
    (downloading it if necessary) and differ only in value range.

    Args:
        cfg: Run configuration; ``data_root``, ``batch_size`` and ``workers`` are used.

    Returns:
        ``(train_loader, real_eval_set)``: a shuffled, drop-last DataLoader
        yielding images normalised to ``[-1, 1]`` for GAN training, and a
        dataset yielding images in ``[0, 1]`` for Inception feature extraction.
    """
    def base_steps():
        # Fresh transform objects for each pipeline.
        return [
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
        ]

    # GAN training view: map [0, 1] -> [-1, 1] to match the generator's Tanh.
    gan_pipeline = transforms.Compose(
        base_steps() + [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    # FID view: stay in [0, 1]; Inception preprocessing is applied later.
    fid_pipeline = transforms.Compose(base_steps())

    train_set = datasets.CIFAR10(
        root=cfg.data_root, train=True, download=True, transform=gan_pipeline
    )
    train_loader = DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.workers,
        drop_last=True,
    )

    real_eval_set = datasets.CIFAR10(
        root=cfg.data_root, train=True, download=True, transform=fid_pipeline
    )
    return train_loader, real_eval_set


def denorm_to_01(x: torch.Tensor) -> torch.Tensor:
    """Map values from ``[-1, 1]`` to ``[0, 1]``, clipping anything outside.

    The input tensor is not modified.
    """
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, 0, 1)


def save_image_grid(tensor: torch.Tensor, path: str, nrow: int = 8):
    """Tile a batch of images into a grid and write it to ``path``.

    Args:
        tensor: Batch of images passed to ``vutils.make_grid``.
        path: Output image file. Its parent directory is created if needed.
        nrow: Number of images per grid row.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    grid = vutils.make_grid(tensor, nrow=nrow, padding=2)
    vutils.save_image(grid, path)


def plot_curves(curves: dict, path: str, title: str, ylabel: str):
    """Plot one line per entry of ``curves`` and save the figure to ``path``.

    Each series is drawn against 1-based positions (1, 2, ..., len(series))
    on an x-axis labelled "Epoch".

    Args:
        curves: Mapping from legend label to a sequence of y-values.
        path: Output image file. Its parent directory is created if needed.
        title: Figure title.
        ylabel: Label of the y-axis.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(7, 4))
    try:
        for label, values in curves.items():
            xs = list(range(1, len(values) + 1))
            plt.plot(xs, values, label=label)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        if curves:
            # A legend without any labelled line only triggers a warning.
            plt.legend()
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


# In [6]:
# ------------------------------
# FID computer (caches real stats)
# ------------------------------

class FIDEvaluator:
    """Computes Inception statistics for real/generated images and their FID.

    Real-data statistics are cached on disk as an ``.npz`` file holding
    ``mu`` and ``cov``. The cache is keyed by file path only: if the file
    exists it is returned as-is, regardless of the dataset or ``max_items``
    passed later, so delete it (or use another ``real_stats_path``) when the
    reference data changes.

    Args:
        device: Device used for the generator noise and the Inception network.
        outdir: Output directory; created if missing.
        real_stats_path: Optional cache file for the real statistics. Defaults
            to ``<outdir>/cifar10_train_inception_stats.npz``.
    """

    def __init__(self, device: torch.device, outdir: str, real_stats_path: str = None):
        self.device = device
        self.extractor = InceptionV3Pool(device)
        self.outdir = outdir
        os.makedirs(self.outdir, exist_ok=True)
        self.real_stats_path = real_stats_path or os.path.join(self.outdir, "cifar10_train_inception_stats.npz")

    def compute_real_stats(self, real_dataset: datasets.VisionDataset, max_items: int = None, batch: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, cov)`` of Inception features for the real dataset.

        Loads the cached statistics when the cache file exists; otherwise
        extracts features, computes the statistics and writes the cache.

        Args:
            real_dataset: Dataset yielding ``(image, label)`` with images in ``[0, 1]``.
            max_items: If given, use at most this many images.
            batch: DataLoader batch size for feature extraction.

        Raises:
            ValueError: If no image was processed (e.g. ``max_items=0``).
        """
        if os.path.exists(self.real_stats_path):
            # Context manager closes the underlying zip file handle.
            with np.load(self.real_stats_path) as arr:
                mu = torch.from_numpy(arr["mu"]).float()
                cov = torch.from_numpy(arr["cov"]).float()
            return mu, cov

        loader = DataLoader(real_dataset, batch_size=batch, shuffle=False, num_workers=2)
        feats = []
        total = 0
        for img, _ in loader:
            if max_items is not None and total >= max_items:
                break
            # Images are already in [0, 1]; the extractor handles device placement.
            feats.append(self.extractor.features(img))
            total += img.size(0)
        if not feats:
            raise ValueError("no real images were processed; cannot compute statistics")
        feats = torch.cat(feats, dim=0)
        if max_items is not None:
            # The loop stops on batch boundaries; trim the overshoot.
            feats = feats[:max_items]
        mu, cov = cov_mean(feats)

        stats_dir = os.path.dirname(self.real_stats_path)
        if stats_dir:
            os.makedirs(stats_dir, exist_ok=True)
        # Writing through a file object keeps the exact file name; given a
        # path string, np.savez appends ".npz" when it is missing, and the
        # cache lookup above would then never find the file.
        with open(self.real_stats_path, "wb") as fh:
            np.savez(fh, mu=mu.numpy(), cov=cov.numpy())
        return mu, cov

    def compute_fake_stats(self, generator: nn.Module, nz: int, num_samples: int = 5000, batch: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, cov)`` of Inception features for generated images.

        The generator is switched to eval mode while sampling, so BatchNorm
        uses its running statistics and is not updated by evaluation batches;
        its previous train/eval mode is restored afterwards.

        Args:
            generator: Module mapping ``(B, nz, 1, 1)`` noise to images in ``[-1, 1]``.
            nz: Number of latent channels.
            num_samples: Number of images to generate.
            batch: Number of images generated per forward pass.

        Raises:
            ValueError: If ``num_samples`` is smaller than 2.
        """
        if num_samples < 2:
            raise ValueError(f"num_samples must be >= 2 to estimate a covariance, got {num_samples}")

        feats = []
        was_training = generator.training
        generator.eval()
        try:
            with torch.no_grad():
                for start in range(0, num_samples, batch):
                    bs = min(batch, num_samples - start)  # last batch may be short
                    z = torch.randn(bs, nz, 1, 1, device=self.device)
                    fake01 = denorm_to_01(generator(z))  # [-1, 1] -> [0, 1]
                    feats.append(self.extractor.features(fake01))
        finally:
            generator.train(was_training)
        feats = torch.cat(feats, dim=0)[:num_samples]
        mu, cov = cov_mean(feats)
        return mu, cov

    def fid(self, mu_r: torch.Tensor, cov_r: torch.Tensor, mu_f: torch.Tensor, cov_f: torch.Tensor) -> float:
        """Return the Fréchet distance between real and fake statistics."""
        return frechet_distance(mu_r, cov_r, mu_f, cov_f)

    def tsne_plot(self, real_dataset: datasets.VisionDataset, generator: nn.Module, nz: int, path: str, num_per_class: int = 300):
        """Save a 2-D t-SNE scatter of real vs generated Inception features.

        Silently does nothing when scikit-learn is not installed.

        Args:
            real_dataset: Dataset yielding ``(image, label)`` with images in ``[0, 1]``.
            generator: Module mapping ``(B, nz, 1, 1)`` noise to images in ``[-1, 1]``.
            nz: Number of latent channels.
            path: Output image file. Its parent directory is created if needed.
            num_per_class: Number of real and of generated images to embed
                ("class" here means real vs fake).
        """
        if not _HAVE_SK:
            return  # skip if sklearn not present

        # Random subset of real images.
        loader = DataLoader(real_dataset, batch_size=256, shuffle=True, num_workers=2)
        imgs_real = []
        collected = 0
        for img, _ in loader:
            imgs_real.append(img)
            collected += img.size(0)
            if collected >= num_per_class:
                break
        Xr = torch.cat(imgs_real, dim=0)[:num_per_class]
        Xr_f = self.extractor.features(Xr)

        # Generated images, sampled in eval mode (mode restored afterwards).
        was_training = generator.training
        generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(num_per_class, nz, 1, 1, device=self.device)
                Xf_f = self.extractor.features(denorm_to_01(generator(z)))
        finally:
            generator.train(was_training)

        X = torch.cat([Xr_f, Xf_f], dim=0).cpu().numpy()
        # Label by the number of rows actually collected: the dataset may hold
        # fewer than num_per_class images.
        y = np.array([0] * Xr_f.size(0) + [1] * Xf_f.size(0))
        tsne = TSNE(n_components=2, init="random", random_state=42, perplexity=30)
        emb = tsne.fit_transform(X)

        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig = plt.figure(figsize=(6, 6))
        try:
            plt.scatter(emb[y == 0, 0], emb[y == 0, 1], alpha=0.6, label="Real")
            plt.scatter(emb[y == 1, 0], emb[y == 1, 1], alpha=0.6, label="Fake")
            plt.legend()
            plt.title("t-SNE of Inception Features (Real vs Fake)")
            plt.tight_layout()
            plt.savefig(path)
        finally:
            plt.close(fig)


# In [9]:
# ------------------------------
# Trainer
# ------------------------------

class GANTrainer:
    """Trains a DCGAN on CIFAR-10 and tracks losses and FID over epochs.

    Construction seeds the RNGs, builds data loaders, networks and optimizers,
    and computes (or loads from cache) the real-data Inception statistics.

    Attributes:
        hist_G / hist_D: Mean generator / discriminator loss per epoch.
        hist_FID: FID values, one per evaluation.
        hist_FID_epochs: Epoch number at which each ``hist_FID`` entry was taken.

    Args:
        cfg: Run configuration.
    """

    def __init__(self, cfg: TrainConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        set_seed(cfg.seed)
        os.makedirs(cfg.outdir, exist_ok=True)

        # Data
        self.train_loader, self.real_eval_set = get_cifar10_loaders(cfg)

        # Models
        self.netG = DCGANGenerator(cfg.nz, cfg.ngf).to(self.device)
        self.netD = DCGANDiscriminator(cfg.ndf).to(self.device)
        self._init_weights()

        # Loss/Opt: the discriminator outputs logits, hence BCE-with-logits.
        self.criterion = nn.BCEWithLogitsLoss()
        self.optG = optim.Adam(self.netG.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        self.optD = optim.Adam(self.netD.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))

        # Fixed noise so per-epoch sample grids are comparable.
        self.fixed_z = torch.randn(64, cfg.nz, 1, 1, device=self.device)

        # FID evaluator & cached real stats
        self.fid_eval = FIDEvaluator(self.device, cfg.outdir)
        self.mu_r, self.cov_r = self.fid_eval.compute_real_stats(self.real_eval_set)

        # Logs
        self.hist_G: List[float] = []
        self.hist_D: List[float] = []
        self.hist_FID: List[float] = []
        self.hist_FID_epochs: List[int] = []

    def _init_weights(self):
        """Initialise conv weights ~ N(0, 0.02) and BatchNorm scales ~ N(1, 0.02)."""
        def weights_init(m):
            classname = m.__class__.__name__
            if 'Conv' in classname:
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif 'BatchNorm' in classname:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

    def _with_eval_generator(self, fn):
        """Run ``fn()`` with the generator in eval mode, then restore its mode.

        In train mode BatchNorm normalises with per-batch statistics and
        updates its running statistics even under ``no_grad``; evaluation
        passes should neither depend on the batch nor alter training state.
        """
        was_training = self.netG.training
        self.netG.eval()
        try:
            return fn()
        finally:
            self.netG.train(was_training)

    def train_epoch(self, epoch: int):
        """Run one pass over the training loader and save a sample grid.

        Appends the epoch's mean generator/discriminator losses to
        ``hist_G``/``hist_D`` and writes ``samples/epoch_XXX.png`` under the
        output directory.

        Args:
            epoch: 1-based epoch number, used for logging and file names.
        """
        self.netG.train(); self.netD.train()
        running_G, running_D = 0.0, 0.0
        for i, (imgs, _) in enumerate(self.train_loader):
            b = imgs.size(0)
            real = imgs.to(self.device)  # [-1,1]
            z = torch.randn(b, self.cfg.nz, 1, 1, device=self.device)
            fake = self.netG(z)

            # Train D: maximize log D(real) + log(1 - D(fake))
            self.optD.zero_grad()
            pred_real = self.netD(real)
            pred_fake = self.netD(fake.detach())  # detach: no gradient into G here
            loss_D = self.criterion(pred_real, torch.ones_like(pred_real)) + \
                     self.criterion(pred_fake, torch.zeros_like(pred_fake))
            loss_D.backward()
            self.optD.step()

            # Train G: maximize log D(fake)
            self.optG.zero_grad()
            pred_fake2 = self.netD(fake)
            loss_G = self.criterion(pred_fake2, torch.ones_like(pred_fake2))
            loss_G.backward()
            self.optG.step()

            running_G += loss_G.item()
            running_D += loss_D.item()

            if (i+1) % 100 == 0:
                print(f"Epoch {epoch:03d} [{i+1:04d}/{len(self.train_loader)}]  D: {loss_D.item():.3f}  G: {loss_G.item():.3f}")

        # Guard against an empty loader (dataset smaller than one batch).
        num_batches = max(len(self.train_loader), 1)
        self.hist_G.append(running_G / num_batches)
        self.hist_D.append(running_D / num_batches)

        # Save sample grid for the epoch
        def _sample_grid():
            with torch.no_grad():
                return denorm_to_01(self.netG(self.fixed_z))

        sample01 = self._with_eval_generator(_sample_grid)
        save_image_grid(sample01, os.path.join(self.cfg.outdir, f"samples/epoch_{epoch:03d}.png"), nrow=8)

    def evaluate_fid(self, epoch: int):
        """Compute FID for the current generator and record it with its epoch."""
        mu_f, cov_f = self._with_eval_generator(
            lambda: self.fid_eval.compute_fake_stats(self.netG, self.cfg.nz, self.cfg.num_fid_samples)
        )
        fid = self.fid_eval.fid(self.mu_r, self.cov_r, mu_f, cov_f)
        self.hist_FID.append(fid)
        self.hist_FID_epochs.append(epoch)
        print(f"[FID] Epoch {epoch:03d}  FID = {fid:.2f}")

    def _plot_fid_curve(self, path: str):
        """Plot FID against the epochs at which it was actually measured."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig = plt.figure(figsize=(7, 4))
        try:
            # Markers keep a single measurement visible.
            plt.plot(self.hist_FID_epochs, self.hist_FID, marker="o", label="FID")
            plt.title("FID over Epochs")
            plt.xlabel("Epoch")
            plt.ylabel("FID (lower is better)")
            plt.legend()
            plt.tight_layout()
            plt.savefig(path)
        finally:
            plt.close(fig)

    def post_plots(self):
        """Write loss curves, the FID curve and (best effort) a t-SNE plot."""
        # Loss curves
        plot_curves({"G": self.hist_G, "D": self.hist_D}, os.path.join(self.cfg.outdir, "curves/loss_curves.png"),
                    title="GAN Loss Curves", ylabel="Loss")
        # FID curve, drawn against real epoch numbers so fid_every > 1 is
        # not shown as consecutive epochs.
        if self.hist_FID:
            self._plot_fid_curve(os.path.join(self.cfg.outdir, "curves/fid_curve.png"))

        # t-SNE of features at the end (optional if sklearn available).
        # Deliberately best-effort: a plotting failure must not lose the run.
        try:
            self._with_eval_generator(
                lambda: self.fid_eval.tsne_plot(
                    self.real_eval_set, self.netG, self.cfg.nz,
                    os.path.join(self.cfg.outdir, "curves/tsne_features.png"))
            )
        except Exception as e:
            print(f"t-SNE plot skipped: {e}")

    def run(self):
        """Train for ``cfg.epochs`` epochs, evaluating FID every ``cfg.fid_every``.

        A non-positive ``fid_every`` disables FID evaluation instead of
        failing with a modulo-by-zero after the first epoch.
        """
        fid_every = self.cfg.fid_every
        for epoch in range(1, self.cfg.epochs + 1):
            self.train_epoch(epoch)
            if fid_every > 0 and epoch % fid_every == 0:
                self.evaluate_fid(epoch)
        self.post_plots()
        print("Training complete. Outputs saved to:", self.cfg.outdir)


# In [16]:
# ------------------------------
# Main
# ------------------------------

def parse_args() -> TrainConfig:
    """Parse command-line flags into a ``TrainConfig``.

    Every flag maps one-to-one onto a ``TrainConfig`` field. When this module
    is not executed as ``__main__`` an empty argument list is parsed, so the
    defaults are returned without touching ``sys.argv``.
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (flag, type, default) for each supported option.
    option_table = (
        ('--data_root', str, './data'),
        ('--outdir', str, './outputs'),
        ('--batch_size', int, 128),
        ('--workers', int, 2),
        ('--epochs', int, 10),
        ('--lr', float, 2e-4),
        ('--beta1', float, 0.5),
        ('--nz', int, 128),
        ('--ngf', int, 64),
        ('--ndf', int, 64),
        ('--fid_every', int, 1),
        ('--num_fid_samples', int, 5000),
        ('--seed', int, 42),
        ('--device', str, default_device),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default in option_table:
        parser.add_argument(flag, type=value_type, default=default)

    if __name__ == '__main__':
        argv = None  # let argparse read sys.argv
    else:
        argv = []
    namespace = parser.parse_args(args=argv)
    return TrainConfig(**vars(namespace))


def main():
    """Entry point: parse the configuration, print it, and run training."""
    config = parse_args()

    banner = "\nFréchet Frontier — DCGAN + FID Evaluation Suite (PyTorch)"
    print(banner)
    print("Config:", config)

    GANTrainer(config).run()
