In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.linalg import sqrtm

# =============================
# Configuration (Jupyter-safe)
# =============================
class Opt:
    """Hyperparameters for the conditional GAN run, stored as plain class attributes."""

    n_epochs = 200  # number of passes over the training DataLoader
    batch_size = 64  # mini-batch size handed to the DataLoader
    lr = 0.0002  # Adam learning rate, shared by both optimizers
    b1 = 0.5  # first Adam beta
    b2 = 0.999  # second Adam beta
    latent_dim = 100  # length of the noise vector fed to the generator
    n_classes = 10  # number of label classes; also the label-embedding width
    img_size = 28  # image height and width in pixels
    channels = 1  # image channels
    sample_interval = 500  # not referenced anywhere in the visible code

# One shared settings object, read by the models and the training loop.
opt = Opt()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# (channels, height, width) of a single image; height and width share img_size.
img_shape = (opt.channels,) + (opt.img_size,) * 2

# Output folders for the final plots and the per-epoch sample grids.
os.makedirs("results/plots", exist_ok=True)
os.makedirs("results/images", exist_ok=True)

# =============================
# Generator (Conditional)
# =============================
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The label is embedded, concatenated with the noise vector and pushed
    through a stack of fully connected layers.  The final ``Tanh`` keeps
    every output value inside [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Each class index is looked up as a learned n_classes-wide vector.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        # Input width followed by the hidden widths of the MLP.
        widths = [opt.latent_dim + opt.n_classes, 128, 256, 512, 1024]

        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            # The first hidden layer is the only one without batch norm.
            if idx > 0:
                stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Project to one value per pixel and squash with Tanh.
        stack.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, z, labels):
        """Generate one image per row of ``z``, conditioned on ``labels``.

        ``z`` is concatenated with the label embeddings along dim 1; the
        flat network output is reshaped to ``(batch, *img_shape)``.
        """
        conditioned = torch.cat((z, self.label_emb(labels)), dim=1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), *img_shape)

# =============================
# Discriminator (Conditional)
# =============================
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring an (image, class label) pair.

    The image is flattened, concatenated with the label embedding and reduced
    to a single sigmoid output per sample.
    """

    def __init__(self):
        super().__init__()
        # Each class index is looked up as a learned n_classes-wide vector.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        n_pixels = int(np.prod(img_shape))
        widths = [opt.n_classes + n_pixels, 512, 256]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Single output unit squashed into (0, 1).
        stack.append(nn.Linear(widths[-1], 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Return one sigmoid score per sample, shape ``(batch, 1)``."""
        flat_img = img.view(img.size(0), -1)
        label_vec = self.label_emb(labels)
        d_in = torch.cat((flat_img, label_vec), dim=1)
        return self.model(d_in)

# =============================
# Models & Loss
# =============================
# Instantiate both networks and move their parameters to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Binary cross-entropy on probabilities; the discriminator ends in a Sigmoid.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing the learning rate and betas from opt.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# =============================
# Data
# =============================
# Shuffled mini-batches over the MNIST training split (downloaded to ./data
# if it is not already there).
dataloader = DataLoader(
    datasets.MNIST(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(opt.img_size),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5: rescales ToTensor's [0, 1] output to [-1, 1],
            # the same range the generator's Tanh produces.
            transforms.Normalize([0.5], [0.5])
        ])
    ),
    batch_size=opt.batch_size,
    shuffle=True
)

# =============================
# Feature Extractor (Inception)
# =============================
# Pretrained Inception v3 used as a frozen feature extractor for the metrics.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version.
inception = models.inception_v3(pretrained=True, transform_input=False).to(device)
# Inference mode: dropout disabled, batch norm uses its running statistics.
inception.eval()
# Replace the classification head with a pass-through, so a forward pass
# returns the activations that would have fed `fc` instead of class logits.
inception.fc = nn.Identity()

def get_features(images, batch_size=100):
    """Return Inception features for a batch of images as a NumPy array.

    Args:
        images: tensor of shape (N, C, H, W), already on the same device as
            ``inception``.  Single-channel input is tiled to three channels;
            input that already has three channels is passed through as is.
        batch_size: number of images pushed through the network at once.
            The images are upsampled to 299x299 first, so processing the
            whole set in one go can exhaust memory.

    Returns:
        ``numpy.ndarray`` with one row of features per input image, in the
        same order as ``images``.
    """
    chunks = []
    # Everything runs without autograd: the features are only used for
    # metrics, and the input may still be attached to the generator graph.
    with torch.no_grad():
        for chunk in images.split(batch_size):
            if chunk.size(1) == 1:
                # Grayscale -> 3 identical channels, as the network expects RGB.
                chunk = chunk.repeat(1, 3, 1, 1)
            # Default interpolation mode is kept so feature values stay
            # comparable with earlier runs.
            chunk = nn.functional.interpolate(chunk, size=299)
            chunks.append(inception(chunk).cpu().numpy())
    return np.concatenate(chunks, axis=0)

# =============================
# Metrics
# =============================
def inception_score(features, splits=10):
    """Compute an Inception-Score-style statistic from per-sample scores.

    Each row of ``features`` is turned into a distribution p(y|x) with a
    softmax.  For every split the score is ``exp(mean_x KL(p(y|x) || p(y)))``
    where p(y) is the mean of p(y|x) over that split; the result is the mean
    over splits.

    NOTE(review): in this file the caller passes the output of
    ``get_features`` while ``inception.fc`` is ``nn.Identity()``, so the rows
    are the activations fed to the classifier head rather than class logits.
    A conventional Inception Score needs the class logits — confirm which one
    is intended.

    Args:
        features: array-like of shape (N, C) with N >= 1.
        splits: number of groups the samples are divided into.  Capped at N
            so that no group is empty.

    Returns:
        The mean score over all splits.

    Raises:
        ValueError: if ``features`` is not a non-empty 2-D array.
    """
    features = np.asarray(features, dtype=np.float64)
    if features.ndim != 2 or features.shape[0] == 0:
        raise ValueError("features must be a non-empty 2-D array of shape (N, C)")

    n_splits = max(1, min(int(splits), features.shape[0]))

    scores = []
    # array_split keeps every sample; when N is divisible by n_splits the
    # groups are the same equal-sized slices as plain integer slicing.
    for part in np.array_split(features, n_splits):
        # Subtract the row maximum before exponentiating so large inputs
        # cannot overflow to inf (which would turn the softmax into NaN).
        shifted = part - part.max(axis=1, keepdims=True)
        p_yx = np.exp(shifted)
        p_yx /= p_yx.sum(axis=1, keepdims=True)
        p_y = p_yx.mean(axis=0)
        # The small epsilon guards log(0) for probabilities that underflow.
        kl = p_yx * (np.log(p_yx + 1e-10) - np.log(p_y + 1e-10))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return np.mean(scores)

def fid_score(real_f, fake_f, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature sets.

    Computes ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))`` where
    ``mu`` and ``S`` are the per-column mean and covariance of each set.

    Args:
        real_f: array-like of shape (N, D), one feature row per sample.
        fake_f: array-like of shape (M, D), same D as ``real_f``.
        eps: value added to the covariance diagonals when the matrix square
            root of their product comes back non-finite.

    Returns:
        The distance as a real scalar.
    """
    real_f = np.asarray(real_f, dtype=np.float64)
    fake_f = np.asarray(fake_f, dtype=np.float64)

    mu1, mu2 = real_f.mean(axis=0), fake_f.mean(axis=0)
    # np.cov returns a 0-d array when D == 1; promote it so `@` and sqrtm work.
    sigma1 = np.atleast_2d(np.cov(real_f, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_f, rowvar=False))

    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # A (near-)singular product can break sqrtm; retry with a small
        # ridge on both covariances.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # sqrtm may return a complex matrix whose imaginary part is numerical
    # noise; only the real part enters the distance.
    covmean = np.real(covmean)

    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)

# =============================
# Training
# =============================
# Per-epoch history, consumed by the plots after training.
g_losses, d_losses, fid_scores, is_scores = [], [], [], []

# Number of real / generated images kept per epoch for IS, FID and the grid.
eval_samples = 1000

for epoch in range(opt.n_epochs):

    g_epoch, d_epoch = 0, 0
    real_imgs_all, fake_imgs_all = [], []
    kept = 0  # images collected so far this epoch for the metrics

    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        batch = imgs.size(0)

        # BCE targets: 1 for "real", 0 for "generated".
        valid = torch.ones(batch, 1, device=device)
        fake = torch.zeros(batch, 1, device=device)

        # ---- Generator ----
        # Sample noise and random labels, then push the discriminator's
        # verdict on the generated images towards "real".
        optimizer_G.zero_grad()
        z = torch.randn(batch, opt.latent_dim, device=device)
        gen_labels = torch.randint(0, opt.n_classes, (batch,), device=device)
        gen_imgs = generator(z, gen_labels)
        g_loss = criterion(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator ----
        # zero_grad also clears the gradients the generator step left on the
        # discriminator; detach() keeps this step from reaching the generator.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(imgs, labels), valid)
        fake_loss = criterion(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        g_epoch += g_loss.item()
        d_epoch += d_loss.item()

        # Keep only the leading batches needed for the metrics instead of the
        # whole epoch, and detach the fakes so their autograd graphs can be
        # freed right away.
        if kept < eval_samples:
            real_imgs_all.append(imgs)
            fake_imgs_all.append(gen_imgs.detach())
            kept += batch

    # ---- Epoch metrics ----
    real_imgs_all = torch.cat(real_imgs_all)[:eval_samples]
    fake_imgs_all = torch.cat(fake_imgs_all)[:eval_samples]

    real_feat = get_features(real_imgs_all)
    fake_feat = get_features(fake_imgs_all)

    fid = fid_score(real_feat, fake_feat)
    inc = inception_score(fake_feat)

    # Average the summed batch losses over the number of batches.
    g_losses.append(g_epoch / len(dataloader))
    d_losses.append(d_epoch / len(dataloader))
    fid_scores.append(fid)
    is_scores.append(inc)

    # ---- Save image comparison ----
    # 25 real images followed by 25 generated ones, five per row.
    comparison = torch.cat([real_imgs_all[:25], fake_imgs_all[:25]])
    save_image(comparison, f"results/images/epoch_{epoch+1}.png", nrow=5, normalize=True)

    print(f"Epoch {epoch+1}/{opt.n_epochs} | G: {g_losses[-1]:.4f} | D: {d_losses[-1]:.4f} | IS: {inc:.2f} | FID: {fid:.2f}")

# =============================
# Visualizations
# =============================
# Generator and discriminator loss per epoch on one set of axes.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(g_losses, label="G Loss")
loss_ax.plot(d_losses, label="D Loss")
loss_ax.legend()
loss_ax.set_title("Loss Curves")
loss_fig.savefig("results/plots/loss_curve.png")

# Inception Score per epoch.
is_fig, is_ax = plt.subplots()
is_ax.plot(is_scores)
is_ax.set_title("Inception Score vs Epoch")
is_fig.savefig("results/plots/is_curve.png")

# FID per epoch.
fid_fig, fid_ax = plt.subplots()
fid_ax.plot(fid_scores)
fid_ax.set_title("FID vs Epoch")
fid_fig.savefig("results/plots/fid_curve.png")

# PCA visualization
# The 2-D projection is fitted on the real features of the last epoch; the
# generated features are then projected onto the same two components.
pca = PCA(n_components=2)
real_pca = pca.fit_transform(real_feat)
fake_pca = pca.transform(fake_feat)

pca_fig, pca_ax = plt.subplots()
pca_ax.scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.5, label="Real")
pca_ax.scatter(fake_pca[:, 0], fake_pca[:, 1], alpha=0.5, label="Fake")
pca_ax.legend()
pca_ax.set_title("PCA Feature Space")
pca_fig.savefig("results/plots/pca_real_fake.png")

print("Training and evaluation completed.")


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to C:\Users\SIVAKUMAR S/.cache\torch\hub\checkpoints\inception_v3_google-0cc3c7bd.pth
100%|███████████████████████████████████████████████████████████████████████████████| 104M/104M [01:29<00:00, 1.21MB/s]


Epoch 1/200 | G: 0.9232 | D: 0.5451 | IS: 1.03 | FID: 399.04
Epoch 2/200 | G: 1.0297 | D: 0.5442 | IS: 1.06 | FID: 272.56
Epoch 3/200 | G: 1.0691 | D: 0.5369 | IS: 1.04 | FID: 224.45
Epoch 4/200 | G: 1.2064 | D: 0.5045 | IS: 1.04 | FID: 214.42
Epoch 5/200 | G: 1.2714 | D: 0.4838 | IS: 1.04 | FID: 202.01
Epoch 6/200 | G: 1.2769 | D: 0.4830 | IS: 1.04 | FID: 214.04
Epoch 7/200 | G: 1.2815 | D: 0.4824 | IS: 1.04 | FID: 206.08
Epoch 8/200 | G: 1.3081 | D: 0.4767 | IS: 1.03 | FID: 200.24
Epoch 9/200 | G: 1.3840 | D: 0.4600 | IS: 1.04 | FID: 207.66
Epoch 10/200 | G: 1.4197 | D: 0.4544 | IS: 1.06 | FID: 210.45
Epoch 11/200 | G: 1.5013 | D: 0.4260 | IS: 1.04 | FID: 194.64
Epoch 12/200 | G: 1.5301 | D: 0.4181 | IS: 1.05 | FID: 182.38
Epoch 13/200 | G: 1.5650 | D: 0.4078 | IS: 1.05 | FID: 182.87
Epoch 14/200 | G: 1.6054 | D: 0.4108 | IS: 1.05 | FID: 193.61
Epoch 15/200 | G: 1.6792 | D: 0.3859 | IS: 1.05 | FID: 177.77
Epoch 16/200 | G: 1.6341 | D: 0.3986 | IS: 1.04 | FID: 178.01
Epoch 17/200 | G: