In [1]:
"""
Inception Insight — A Didactic DCGAN + Inception Score (IS) Evaluation Suite (PyTorch)
=====================================================================================
An end-to-end teaching project that trains a DCGAN on CIFAR-10 and evaluates the
**Inception Score (IS)** across epochs. Includes:

• Dataset: CIFAR-10 (32×32 RGB)
• Model: DCGAN (configurable) — generator & discriminator
• Evaluation: Inception Score (splits, mean ± std) with InceptionV3 logits
• Visuals: Loss curves, IS curve, per-epoch sample grids, optional t-SNE of features
• Artifacts: CSV logs, animated GIF timeline of samples
• Epochs: configurable (5/10/20) for pedagogy

Run
----
python sample/inception_score_gan.py \
  --data_root ./data \
  --outdir ./outputs_is \
  --epochs 10 \
  --batch_size 128 \
  --is_every 1 \
  --num_is_samples 5000 \
  --is_splits 10

Optional flags: --epochs N (any integer; 5, 10 or 20 work well for teaching), --lr 2e-4, --nz 128, --ngf 64, --ndf 64, --device cuda

Notes
-----
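• Inception Score feeds generated images to an ImageNet-pretrained InceptionV3
  and turns its logits into class probabilities p(y|x) with a softmax:
  IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ), where p(y) is the average of p(y|x)
  over the samples. Confident per-image predictions (quality) together with a
  spread-out marginal p(y) (diversity) give a high score. Higher IS is better;
  the minimum possible value is 1.
• We compute IS over generated samples in splits, reporting mean ± std.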
"""




In [4]:
import os
# exist_ok=True keeps this cell idempotent: re-running it (e.g. Restart & Run All)
# must not fail just because sample/ already exists from a previous run.
os.makedirs("sample", exist_ok=True)

In [9]:
%%writefile sample/inception_score_gan.py
import os
import math
import argparse
import random
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils, models
from torchvision.models import Inception_V3_Weights
import matplotlib.pyplot as plt
import imageio

try:
    from sklearn.manifold import TSNE
    _HAVE_SK = True
except Exception:
    _HAVE_SK = False

# ------------------------------
# Reproducibility
# ------------------------------

def set_seed(seed: int = 42):
    """Seed every RNG used by the script and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG, the PyTorch CPU RNG and all
    CUDA devices, then disables cuDNN autotuning so convolution algorithms are
    chosen deterministically.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ------------------------------
# DCGAN Models
# ------------------------------

class DCGANGenerator(nn.Module):
    """DCGAN generator: latent noise (N, nz, 1, 1) -> image (N, nc, 32, 32) in [-1, 1].

    Spatial path: 1 -> 4 (stride-1 transposed conv) -> 8 -> 16 -> 32 (three
    stride-2 transposed convs, each halving the channel count), followed by a
    3x3 stride-1 projection to `nc` channels and a Tanh.
    """
    def __init__(self, nz=128, ngf=64, nc=3):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Upsampling stages: channel multiplier goes 8 -> 4 -> 2 -> 1, H/W doubles each time.
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers.extend([
                nn.ConvTranspose2d(ngf * mult_in, ngf * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * mult_out),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(ngf, nc, 3, 1, 1, bias=False),
            nn.Tanh(),  # output range [-1, 1]
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        return self.main(z)

class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator returning one raw (pre-sigmoid) logit per image.

    Built for 32x32 inputs: three stride-2 convs take 32 -> 16 -> 8 -> 4 and a
    final 4x4 valid conv collapses the map to a single score, giving shape (N,).
    """
    def __init__(self, ndf=64, nc=3):
        super().__init__()
        stem = [nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True)]
        body = []
        # Downsampling stages with BatchNorm: channel multiplier 1 -> 2 -> 4.
        for mult_in, mult_out in ((1, 2), (2, 4)):
            body.extend([
                nn.Conv2d(ndf * mult_in, ndf * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        head = [nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), nn.Flatten()]
        self.main = nn.Sequential(*stem, *body, *head)

    def forward(self, x):
        scores = self.main(x)  # (N, 1) for 32x32 input
        return scores.squeeze(1)


# ------------------------------
# InceptionV3 logits extractor for IS
# ------------------------------

class InceptionV3Logits(nn.Module):
    """Frozen ImageNet-pretrained InceptionV3 that maps images to 1000-d logits.

    Preprocessing is torchvision's bundled ``Inception_V3_Weights.IMAGENET1K_V1``
    inference transform. Per torchvision's documentation, that transform resizes
    to 342, center-crops to 299 and normalises with ImageNet mean/std. It also
    accepts batched (B, C, H, W) tensors, so whole batches are preprocessed on
    `device` here.

    NOTE(review): the center-crop trims a thin border from every upsampled 32x32
    sample, whereas the reference IS code resizes straight to 299. Absolute scores
    may therefore not match published CIFAR-10 numbers.
    """
    def __init__(self, device: torch.device):
        super().__init__()
        weights = Inception_V3_Weights.IMAGENET1K_V1
        self.transforms = weights.transforms()
        # torchvision's builder requires aux_logits=True whenever pretrained weights
        # are given (the checkpoint contains the auxiliary head); passing
        # aux_logits=False raises ValueError. Build with defaults, then drop the
        # aux head ourselves. In eval mode forward() returns plain logits either way.
        model = models.inception_v3(weights=weights)
        model.aux_logits = False
        model.AuxLogits = None
        self.model = model.to(device).eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.device = device

    @torch.no_grad()
    def logits(self, imgs01: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
        """Return InceptionV3 logits for a batch of images.

        Args:
            imgs01: images with values in [0, 1], shape (N, 3, H, W), on any device.
            batch_size: number of images pushed through Inception at once.

        Returns:
            (N, 1000) logits on the CPU.
        """
        outs = []
        for start in range(0, imgs01.size(0), batch_size):
            chunk = imgs01[start:start + batch_size].to(self.device)
            proc = self.transforms(chunk)  # batched resize/crop/normalise on `device`
            outs.append(self.model(proc).cpu())
        return torch.cat(outs, dim=0)


def inception_score_from_logits(logits: torch.Tensor, splits: int = 10) -> Tuple[float, float]:
    """Compute the Inception Score from classifier logits.

    IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ) is evaluated independently on
    `splits` contiguous row chunks. Within each chunk, p(y) is the mean of
    p(y|x). Every chunk gets N // splits rows, and the last chunk also absorbs
    the remainder.

    Args:
        logits: (N, C) unnormalised class scores.
        splits: number of chunks; requires 1 <= splits <= N.

    Returns:
        (mean, std) of the per-chunk scores (population std, like ``np.std``).
    """
    eps = 1e-10  # keeps log() finite for zero probabilities
    probs = torch.softmax(logits, dim=1)
    total = probs.size(0)
    assert splits >= 1 and total >= splits, "Invalid splits/N for IS"
    chunk = total // splits
    bounds = [(k * chunk, (k + 1) * chunk) for k in range(splits - 1)]
    bounds.append(((splits - 1) * chunk, total))

    per_split = []
    for lo, hi in bounds:
        p_cond = probs[lo:hi]                           # p(y|x) for this chunk
        p_marg = p_cond.mean(dim=0, keepdim=True)       # p(y), shape (1, C)
        log_ratio = torch.log(p_cond + eps) - torch.log(p_marg + eps)
        mean_kl = (p_cond * log_ratio).sum(dim=1).mean()
        per_split.append(torch.exp(mean_kl).item())
    per_split = np.asarray(per_split)
    return float(per_split.mean()), float(per_split.std())


# ------------------------------
# Dataset & Utilities
# ------------------------------

@dataclass
class TrainConfig:
    """All hyperparameters and paths for one training run (field names mirror the CLI flags in parse_args)."""
    data_root: str = "./data"  # where CIFAR-10 is stored / downloaded to
    outdir: str = "./outputs_is"  # root directory for every generated artifact
    batch_size: int = 128  # training batch size; the loader drops the incomplete last batch
    workers: int = 2  # DataLoader worker processes
    epochs: int = 10  # number of training epochs
    lr: float = 2e-4  # Adam learning rate, shared by G and D
    beta1: float = 0.5  # Adam beta1 (beta2 is fixed at 0.999 by the trainer)
    nz: int = 128  # latent vector size fed to the generator
    ngf: int = 64  # generator base channel width
    ndf: int = 64  # discriminator base channel width
    is_every: int = 1  # evaluate Inception Score every `is_every` epochs
    num_is_samples: int = 5000  # generated images per IS evaluation
    is_splits: int = 10  # number of splits for IS mean ± std
    seed: int = 42  # value passed to set_seed
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # evaluated once, when the class is defined


def get_cifar10_loader(cfg: TrainConfig):
    """Return a shuffled CIFAR-10 training DataLoader with pixels scaled to [-1, 1].

    The dataset is downloaded into ``cfg.data_root`` if needed (download=True).
    drop_last=True guarantees every batch holds exactly ``cfg.batch_size`` images.
    """
    pipeline = [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        # (v - 0.5) / 0.5 maps ToTensor's [0, 1] onto the generator's Tanh range [-1, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    cifar_train = datasets.CIFAR10(
        root=cfg.data_root,
        train=True,
        download=True,
        transform=transforms.Compose(pipeline),
    )
    return DataLoader(
        cifar_train,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.workers,
        drop_last=True,
    )


def denorm_to_01(x: torch.Tensor) -> torch.Tensor:
    """Map values from the Tanh range [-1, 1] to [0, 1], clamping anything outside."""
    shifted = (x + 1) / 2
    return torch.clamp(shifted, min=0, max=1)


def save_image_grid(tensor: torch.Tensor, path: str, nrow: int = 8):
    """Tile a batch of images into one grid image and write it to `path`.

    Args:
        tensor: image batch; callers pass denormalised generator output in [0, 1]
            (presumably (B, C, H, W) — TODO confirm for other callers).
        path: destination image file. Its parent directory is created when
            `path` has one.
        nrow: number of images per grid row.
    """
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory part; os.makedirs('') would raise
        os.makedirs(parent, exist_ok=True)
    grid = vutils.make_grid(tensor, nrow=nrow, padding=2)
    vutils.save_image(grid, path)


def plot_curves(curves: dict, path: str, title: str, ylabel: str):
    """Plot one labelled line per entry of `curves` against a 1-based epoch index.

    Args:
        curves: mapping label -> sequence of per-epoch values (x = 1..len(values)).
        path: destination image file. Its parent directory is created when
            `path` has one.
        title: figure title.
        ylabel: y-axis label (the x-axis is labelled "Epoch").
    """
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory part; os.makedirs('') would raise
        os.makedirs(parent, exist_ok=True)
    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        for label, values in curves.items():
            ax.plot(list(range(1, len(values) + 1)), values, label=label)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if saving fails, so repeated calls don't leak memory.
        plt.close(fig)


def save_csv(path: str, header: List[str], rows: List[List]):
    """Write `header` and `rows` to `path` as simple comma-separated text.

    Values are converted with ``str`` and joined with commas without quoting,
    so they must not contain commas or newlines themselves (numbers are fine).
    The parent directory is created when `path` has one.
    """
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory part; os.makedirs('') would raise
        os.makedirs(parent, exist_ok=True)
    lines = [','.join(header)] + [','.join(map(str, r)) for r in rows]
    with open(path, 'w') as f:
        f.write(''.join(line + '\n' for line in lines))


def make_gif_from_folder(folder: str, out_path: str, fps: int = 4):
    """Assemble every ``.png`` in `folder`, in filename order, into an animated GIF.

    Does nothing when `folder` does not exist or contains no PNG files.

    Args:
        folder: directory holding the frames. Filenames must sort into display
            order (e.g. zero-padded ``epoch_001.png``).
        out_path: destination GIF. Its parent directory is created when
            `out_path` has one.
        fps: playback speed in frames per second.
    """
    if not os.path.isdir(folder):
        return
    frames = [
        imageio.v2.imread(os.path.join(folder, name))
        for name in sorted(os.listdir(folder))
        if name.lower().endswith('.png')
    ]
    if not frames:
        return
    parent = os.path.dirname(out_path)
    if parent:  # a bare filename has no directory part; os.makedirs('') would raise
        os.makedirs(parent, exist_ok=True)
    try:
        imageio.mimsave(out_path, frames, fps=fps)
    except TypeError:
        # Newer imageio writes GIFs through its Pillow plugin. That plugin rejects
        # `fps` with a TypeError and expects a per-frame `duration` in milliseconds.
        # loop=0 keeps the animation repeating.
        imageio.mimsave(out_path, frames, duration=1000.0 / fps, loop=0)


# ------------------------------
# IS Evaluator
# ------------------------------

class ISEvaluator:
    """Sample from a generator and score the samples with a frozen InceptionV3.

    Provides the Inception Score (mean ± std over splits) and an optional t-SNE
    scatter of the Inception pre-logit features of generated images.
    """
    def __init__(self, device: torch.device, outdir: str):
        self.device = device
        self.extractor = InceptionV3Logits(device)
        self.outdir = outdir
        os.makedirs(self.outdir, exist_ok=True)

    @torch.no_grad()
    def compute_is(self, generator: nn.Module, nz: int, num_samples: int = 5000, splits: int = 10, batch: int = 128) -> Tuple[float, float]:
        """Generate `num_samples` images and return their Inception Score as (mean, std).

        The generator is put in eval mode while sampling, so BatchNorm layers use
        their running statistics. This has two effects:
        - each sample no longer depends on the other samples in its batch,
          including a small final batch;
        - sampling does not update the running buffers as a side effect.
        The caller's train/eval mode is restored afterwards.

        Args:
            generator: maps noise of shape (B, nz, 1, 1) to images in [-1, 1].
            nz: latent size.
            num_samples: number of generated images to score (must be >= splits).
            splits: number of IS splits.
            batch: generator batch size.

        Raises:
            ValueError: if `num_samples` or `batch` is not positive.
        """
        if num_samples <= 0 or batch <= 0:
            raise ValueError("num_samples and batch must be positive")
        was_training = generator.training
        generator.eval()
        try:
            chunks = []
            remaining = num_samples
            while remaining > 0:
                bs = min(batch, remaining)
                z = torch.randn(bs, nz, 1, 1, device=self.extractor.device)
                fake01 = denorm_to_01(generator(z))           # [-1, 1] -> [0, 1]
                chunks.append(self.extractor.logits(fake01))  # (bs, 1000) on CPU
                remaining -= bs
        finally:
            generator.train(was_training)
        logits_all = torch.cat(chunks, dim=0)
        mean, std = inception_score_from_logits(logits_all, splits=splits)
        return mean, std

    def tsne_plot(self, generator: nn.Module, nz: int, path: str, num_fake: int = 1000, batch: int = 128):
        """Save a 2-D t-SNE scatter of Inception features of `num_fake` generated images.

        Returns silently when scikit-learn is unavailable or `num_fake` < 2.
        The features are the input to the already-loaded InceptionV3's final
        ``fc`` layer, which is equivalent to replacing ``fc`` with ``nn.Identity``.
        They are captured with a forward pre-hook, so no second copy of the
        network is built, and the generator runs in batches of `batch`. Like
        `compute_is`, sampling happens in eval mode and the caller's mode is
        restored afterwards.
        """
        if not _HAVE_SK or num_fake < 2:
            return
        captured = []

        def _grab_fc_input(_module, inputs):
            # inputs[0] is the feature batch about to enter the classifier layer
            captured.append(inputs[0].detach().cpu())

        hook = self.extractor.model.fc.register_forward_pre_hook(_grab_fc_input)
        was_training = generator.training
        generator.eval()
        try:
            with torch.no_grad():
                for start in range(0, num_fake, batch):
                    bs = min(batch, num_fake - start)
                    z = torch.randn(bs, nz, 1, 1, device=self.extractor.device)
                    # The logits are discarded; the hook collects the features.
                    self.extractor.logits(denorm_to_01(generator(z)))
        finally:
            hook.remove()
            generator.train(was_training)
        feats = torch.cat(captured, dim=0).numpy()

        # scikit-learn requires perplexity < n_samples
        tsne = TSNE(n_components=2, init="random", random_state=42, perplexity=min(30, num_fake - 1))
        emb = tsne.fit_transform(feats)
        parent = os.path.dirname(path)
        if parent:  # a bare filename has no directory part; os.makedirs('') would raise
            os.makedirs(parent, exist_ok=True)
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.scatter(emb[:, 0], emb[:, 1], s=6, alpha=0.6)
        ax.set_title("t-SNE of Inception Features (Fake Only)")
        fig.tight_layout(); fig.savefig(path); plt.close(fig)


# ------------------------------
# Trainer
# ------------------------------

class GANTrainer:
    """Train a DCGAN on CIFAR-10 while tracking losses and Inception Score.

    Every artifact is written under ``cfg.outdir``: per-epoch sample grids,
    loss/IS curves, CSV logs, a GIF timeline and an optional t-SNE plot.
    """
    def __init__(self, cfg: TrainConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        set_seed(cfg.seed)
        os.makedirs(cfg.outdir, exist_ok=True)

        # Data
        self.train_loader = get_cifar10_loader(cfg)

        # Models
        self.netG = DCGANGenerator(cfg.nz, cfg.ngf).to(self.device)
        self.netD = DCGANDiscriminator(cfg.ndf).to(self.device)
        self._init_weights()

        # Loss/Opt: D returns raw logits, so the sigmoid lives inside BCEWithLogitsLoss
        self.criterion = nn.BCEWithLogitsLoss()
        self.optG = optim.Adam(self.netG.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        self.optD = optim.Adam(self.netD.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))

        # Fixed noise, so every epoch's grid shows the same 64 latent points
        self.fixed_z = torch.randn(64, cfg.nz, 1, 1, device=self.device)

        # IS evaluator
        self.is_eval = ISEvaluator(self.device, cfg.outdir)

        # Logs: per-epoch mean losses, plus IS mean/std and the epoch each IS was
        # measured at. The epoch is needed for correct x-values when is_every > 1.
        self.hist_G: List[float] = []
        self.hist_D: List[float] = []
        self.hist_IS: List[float] = []
        self.hist_IS_std: List[float] = []
        self.hist_IS_epoch: List[int] = []

    def _init_weights(self):
        """Apply DCGAN init: conv weights ~ N(0, 0.02); BatchNorm gamma ~ N(1, 0.02), beta = 0."""
        def weights_init(m):
            cname = m.__class__.__name__
            if cname.find('Conv') != -1:  # matches Conv2d and ConvTranspose2d
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif cname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

    def train_epoch(self, epoch: int):
        """Run one pass over the training set, record mean G/D losses and save a sample grid.

        Each iteration does one D step (real -> 1, detached fake -> 0), then one
        G step that labels the fakes as real (non-saturating generator loss).
        """
        self.netG.train(); self.netD.train()
        running_G, running_D = 0.0, 0.0
        for i, (imgs, _) in enumerate(self.train_loader):
            b = imgs.size(0)
            real = imgs.to(self.device)
            z = torch.randn(b, self.cfg.nz, 1, 1, device=self.device)
            fake = self.netG(z)

            # D update (fake is detached so no gradient reaches G here)
            self.optD.zero_grad()
            pred_real = self.netD(real)
            pred_fake = self.netD(fake.detach())
            loss_D = self.criterion(pred_real, torch.ones_like(pred_real)) + \
                     self.criterion(pred_fake, torch.zeros_like(pred_fake))
            loss_D.backward(); self.optD.step()

            # G update
            self.optG.zero_grad()
            pred_fake2 = self.netD(fake)
            loss_G = self.criterion(pred_fake2, torch.ones_like(pred_fake2))
            loss_G.backward(); self.optG.step()

            running_G += loss_G.item(); running_D += loss_D.item()
            if (i+1) % 100 == 0:
                print(f"Epoch {epoch:03d} [{i+1:04d}/{len(self.train_loader)}]  D: {loss_D.item():.3f}  G: {loss_G.item():.3f}")

        self.hist_G.append(running_G / len(self.train_loader))
        self.hist_D.append(running_D / len(self.train_loader))

        # Save sample grid for the epoch
        with torch.no_grad():
            sample = self.netG(self.fixed_z)
            sample01 = denorm_to_01(sample)
            save_image_grid(sample01, os.path.join(self.cfg.outdir, f"samples/epoch_{epoch:03d}.png"), nrow=8)

    def evaluate_is(self, epoch: int):
        """Compute IS for the current generator and record it against `epoch`."""
        mean, std = self.is_eval.compute_is(self.netG, self.cfg.nz, num_samples=self.cfg.num_is_samples,
                                            splits=self.cfg.is_splits)
        self.hist_IS.append(mean); self.hist_IS_std.append(std)
        self.hist_IS_epoch.append(epoch)
        print(f"[IS] Epoch {epoch:03d}  IS = {mean:.2f} ± {std:.2f}  (splits={self.cfg.is_splits})")

    def post_plots(self):
        """Write curves, CSV logs, the sample GIF and (best effort) the t-SNE plot."""
        out = self.cfg.outdir
        # Loss curves
        plot_curves({"G": self.hist_G, "D": self.hist_D}, os.path.join(out, "curves/loss_curves.png"),
                    title="GAN Loss Curves", ylabel="Loss")
        # Inception Score curve, plotted at the epochs where IS was actually measured
        if self.hist_IS:
            os.makedirs(os.path.join(out, "curves"), exist_ok=True)
            fig, ax = plt.subplots(figsize=(7, 4))
            ax.errorbar(self.hist_IS_epoch, self.hist_IS, yerr=self.hist_IS_std, fmt='-o')
            ax.set_title("Inception Score over Epochs")
            ax.set_xlabel("Epoch"); ax.set_ylabel("IS (higher is better)")
            fig.tight_layout(); fig.savefig(os.path.join(out, "curves/is_curve.png")); plt.close(fig)
        # Save numeric logs
        save_csv(os.path.join(out, "logs/losses.csv"), ["epoch", "G", "D"],
                 [[i + 1, g, d] for i, (g, d) in enumerate(zip(self.hist_G, self.hist_D))])
        if self.hist_IS:
            save_csv(os.path.join(out, "logs/is.csv"), ["epoch", "is_mean", "is_std"],
                     [[e, m, s] for e, m, s in zip(self.hist_IS_epoch, self.hist_IS, self.hist_IS_std)])
        # Animated GIF
        make_gif_from_folder(os.path.join(out, "samples"), os.path.join(out, "samples/timeline.gif"), fps=3)
        # Optional t-SNE of generated features (best effort: failures are reported, not raised)
        try:
            self.is_eval.tsne_plot(self.netG, self.cfg.nz, os.path.join(out, "curves/tsne_fake_features.png"))
        except Exception as e:
            print(f"t-SNE plot skipped: {e}")

    def run(self):
        """Train for cfg.epochs epochs, evaluate IS every cfg.is_every epochs, then write artifacts.

        is_every <= 0 disables IS evaluation (it previously caused a modulo-by-zero).
        """
        eval_is = self.cfg.is_every > 0
        for epoch in range(1, self.cfg.epochs + 1):
            self.train_epoch(epoch)
            if eval_is and epoch % self.cfg.is_every == 0:
                self.evaluate_is(epoch)
        self.post_plots()
        print("\nArtifacts saved:")
        print(" - Sample grids:", os.path.join(self.cfg.outdir, "samples/epoch_###.png"))
        print(" - Sample animation:", os.path.join(self.cfg.outdir, "samples/timeline.gif"))
        print(" - Loss curves:", os.path.join(self.cfg.outdir, "curves/loss_curves.png"))
        print(" - IS curve:", os.path.join(self.cfg.outdir, "curves/is_curve.png"))
        print(" - Logs (CSV):", os.path.join(self.cfg.outdir, "logs/"))
        print(" - t-SNE (if available):", os.path.join(self.cfg.outdir, "curves/tsne_fake_features.png"))
        print("Training complete. Outputs saved to:", self.cfg.outdir)


# ------------------------------
# Main
# ------------------------------

def parse_args() -> TrainConfig:
    """Build a TrainConfig from command-line flags.

    When the module is imported rather than run as a script
    (``__name__ != '__main__'``), argv is ignored and every field takes its
    default value.
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    flag_specs = (
        ('--data_root', str, './data'),
        ('--outdir', str, './outputs_is'),
        ('--batch_size', int, 128),
        ('--workers', int, 2),
        ('--epochs', int, 10),
        ('--lr', float, 2e-4),
        ('--beta1', float, 0.5),
        ('--nz', int, 128),
        ('--ngf', int, 64),
        ('--ndf', int, 64),
        ('--is_every', int, 1),
        ('--num_is_samples', int, 5000),
        ('--is_splits', int, 10),
        ('--seed', int, 42),
        ('--device', str, default_device),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in flag_specs:
        parser.add_argument(flag, type=value_type, default=default)
    cli_argv = None if __name__ == '__main__' else []
    namespace = parser.parse_args(args=cli_argv)
    return TrainConfig(**vars(namespace))


def main():
    """Parse CLI flags, train the DCGAN with IS tracking, and write all artifacts."""
    cfg = parse_args()
    print("\nInception Insight — DCGAN + Inception Score Evaluation Suite (PyTorch)")
    print("Config:", cfg)
    trainer = GANTrainer(cfg)
    trainer.run()


# The entry-point guard must sit at module level. Nested inside main() it is
# never reached, so running the script would define main() but never call it.
if __name__ == '__main__':
    main()


Writing sample/inception_score_gan.py


In [10]:
# Train for 10 epochs and score IS (5000 samples, 10 splits) after every epoch.
# Unlisted flags keep the script defaults, e.g. CIFAR-10 goes to --data_root ./data.
!python sample/inception_score_gan.py \
  --epochs 10 \
  --is_every 1 \
  --num_is_samples 5000 \
  --is_splits 10 \
  --outdir ./outputs_is