In [None]:
import os
import random
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ------------------------------
# 1. Hyperparameters and Device
# ------------------------------
SEED = 42
batch_size = 128
image_size = 32
latent_dim = 100
num_epochs = 50
lr_G = 0.0002
lr_D = 0.00005   # 4x lower than lr_G
beta1 = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from (Python `random`, NumPy, PyTorch) so
# runs are repeatable. Some GPU kernels may still be nondeterministic.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs('generated_images', exist_ok=True)
os.makedirs('fid_fake_images', exist_ok=True)

# ------------------------------
# 2. Augmentation (gentle to start)
# ------------------------------
# Adaptive discriminator augmentation (ADA) state. The training loop nudges
# ada_aug_p up or down by ada_speed every ada_interval batches, depending on
# whether the discriminator's accuracy on real images exceeds ada_target.
ada_aug_p = 0.0      # current augmentation probability, clamped to [0, 1]
ada_target = 0.6     # accuracy threshold on real images
ada_interval = 4     # batches between controller updates
ada_speed = 0.01     # step size applied to ada_aug_p per update
ada_history = []     # real-image accuracy logged at every controller update

# Torchvision pipeline operating on PIL images: a random horizontal flip only.
ada_aug = transforms.Compose([transforms.RandomHorizontalFlip()])

def ada_augment(x, p):
    """Randomly flip each image of a batch horizontally with probability ``p``.

    The flip is applied directly on the tensor (no PIL round trip), so:
      * the value range of ``x`` is preserved (e.g. [-1, 1] after Normalize),
        keeping augmented and non-augmented images on the same scale;
      * the result stays on ``x.device``;
      * gradients flow through the augmentation, which the generator update
        relies on when it feeds augmented fakes to the discriminator.

    Args:
        x: image batch of shape [batch, channels, height, width].
        p: per-image probability of flipping, expected in [0, 1].

    Returns:
        Tensor with the same shape, dtype and device as ``x`` (``x`` itself
        when ``p <= 0`` or the batch is empty).
    """
    if p <= 0 or x.size(0) == 0:
        return x
    # One Bernoulli(p) draw per image, broadcast over channels/height/width.
    flip_mask = torch.rand(x.size(0), 1, 1, 1, device=x.device) < p
    return torch.where(flip_mask, x.flip(-1), x)

# ------------------------------
# 3. DCGAN Architectures
# ------------------------------
class Generator(nn.Module):
    """DCGAN generator mapping latent noise to 3x32x32 images in [-1, 1].

    Args:
        nz: length of the latent vector. Defaults to the notebook-level
            ``latent_dim`` so ``Generator()`` behaves as before.

    Input:  noise of shape [batch, nz, 1, 1].
    Output: images of shape [batch, 3, 32, 32] (Tanh activations).
    """
    def __init__(self, nz=None):
        super().__init__()
        if nz is None:
            nz = latent_dim
        self.nz = nz
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, 512, 4, 1, 0, bias=False),           # [batch, 512, 4, 4]
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),          # [batch, 256, 8, 8]
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),          # [batch, 128, 16, 16]
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),           # [batch, 64, 32, 32]
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 3x3, stride 1, padding 1: maps channels to RGB without changing size
            nn.ConvTranspose2d(64, 3, 3, 1, 1, bias=False),             # [batch, 3, 32, 32]
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

class Discriminator(nn.Module):
    """DCGAN discriminator scoring 3x32x32 images.

    Returns one sigmoid probability per image, shape [batch].
    """
    def __init__(self):
        super().__init__()
        # Modules are created in a fixed order, which keeps state_dict keys
        # main.0 ... main.11 stable for saved checkpoints.
        layers = [
            # [batch, 3, 32, 32] -> [batch, 32, 16, 16]; no BatchNorm on the first block
            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
        ]
        layers += [
            # -> [batch, 64, 8, 8]
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
        ]
        layers += [
            # -> [batch, 128, 4, 4]
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        layers += [
            # 4x4 kernel with no padding collapses the map to [batch, 1, 1, 1]
            nn.Conv2d(128, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.main(x)             # [batch, 1, 1, 1]
        return scores.view(x.size(0))     # [batch]

# ------------------------------
# 4. Data
# ------------------------------
# CIFAR-10 training split. ToTensor yields values in [0, 1]; Normalize with
# mean = std = 0.5 per channel maps them to [-1, 1], the generator's Tanh range.
transform_train = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
trainset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
dataloader = DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# ------------------------------
# 5. Initialize Models and Optimizers
# ------------------------------
def weights_init(module):
    """Apply the DCGAN weight initialisation to a single module.

    Conv / transposed-conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    and shift = 0, as in Radford et al. (2015). Intended for ``net.apply``.
    Conv biases, if any, keep their default initialisation.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias)


netG = Generator().to(device)
netD = Discriminator().to(device)
netG.apply(weights_init)
netD.apply(weights_init)
optimizerD = optim.Adam(netD.parameters(), lr=lr_D, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr_G, betas=(beta1, 0.999))
# The discriminator ends in Sigmoid, so BCE is applied to probabilities.
criterion = nn.BCELoss()

# ------------------------------
# 6. Training Loop with ADA
# ------------------------------
# One-sided label smoothing: only the discriminator's "real" targets are
# softened to 0.9; the generator is still trained toward 1.0.
real_label = 0.9
gen_target_label = 1.0
fake_label = 0.0
g_losses, d_losses = [], []

# Draw the monitoring noise once, so the per-epoch grids show the same latent
# points evolving over training.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        b_size = real_images.size(0)
        ############################
        # (1) Update D network
        ###########################
        netD.zero_grad()
        real_images = real_images.to(device)
        real_images_aug = ada_augment(real_images, ada_aug_p)
        output_real = netD(real_images_aug)
        label_real = torch.full_like(output_real, real_label)
        errD_real = criterion(output_real, label_real)
        D_x = output_real.mean().item()
        errD_real.backward()

        noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
        fake_images = netG(noise)
        # detach: the D step must not backpropagate into G
        fake_images_aug = ada_augment(fake_images.detach(), ada_aug_p)
        output_fake = netD(fake_images_aug)
        label_fake = torch.full_like(output_fake, fake_label)
        errD_fake = criterion(output_fake, label_fake)
        D_G_z1 = output_fake.mean().item()
        errD_fake.backward()
        optimizerD.step()

        ############################
        # (2) Update G network twice per D step
        ###########################
        for _ in range(2):
            netG.zero_grad()
            noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
            fake_images = netG(noise)
            # Fakes are augmented the same way D sees them; the tensor-level
            # augmentation must keep the graph so G receives gradients.
            gen_imgs_aug = ada_augment(fake_images, ada_aug_p)
            output_gen = netD(gen_imgs_aug)
            label_gen = torch.full_like(output_gen, gen_target_label)
            errG = criterion(output_gen, label_gen)
            errG.backward()
            optimizerG.step()
        # netD also accumulated grads from errG.backward(); they are cleared by
        # netD.zero_grad() at the start of the next iteration.

        errD = errD_real + errD_fake
        # errG here is the loss of the last of the two generator steps.
        g_losses.append(errG.item())
        d_losses.append(errD.item())

        ############################
        # (3) ADA controller
        ###########################
        # Every `ada_interval` batches, measure how often D scored the real batch
        # it was just trained on as real, and raise/lower the augmentation
        # probability. Reusing `output_real` avoids an extra train-mode forward
        # pass, which would resample dropout and update BatchNorm running stats.
        # NOTE(review): the ADA paper's heuristic is r_t = E[sign(D logit)] with
        # target 0.6; for this fraction-above-0.5 metric that corresponds to
        # real_acc = (r_t + 1) / 2 = 0.8 — confirm the intended ada_target.
        if i % ada_interval == 0:
            real_acc = (output_real.detach() > 0.5).float().mean().item()
            ada_history.append(real_acc)
            if real_acc > ada_target:
                ada_aug_p = min(1.0, ada_aug_p + ada_speed)
            else:
                ada_aug_p = max(0.0, ada_aug_p - ada_speed)

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch {i}/{len(dataloader)} "
                  f"Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} "
                  f"D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} ADA_p: {ada_aug_p:.3f}")

    # Save images per epoch for monitoring (netG stays in train mode here, so
    # BatchNorm uses the statistics of this 64-sample batch).
    with torch.no_grad():
        fake = netG(fixed_noise).cpu()
        vutils.save_image(fake, f"generated_images/ada_dcgan_epoch_{epoch:03d}.png", normalize=True, nrow=8)

# Save model weights
torch.save(netG.state_dict(), "ada_dcgan_gen.pth")
torch.save(netD.state_dict(), "ada_dcgan_disc.pth")

# Plot loss curves: one point per training iteration. The generator value is
# the loss of the last of the two G steps taken in that iteration.
fig, ax = plt.subplots()
ax.plot(g_losses, label='Generator Loss')
ax.plot(d_losses, label='Discriminator Loss')
ax.set(title="GAN Losses", xlabel="Iteration", ylabel="BCE loss")
ax.legend()
plt.show()

# --------------- FID EVALUATION ---------------
def generate_samples_for_fid(generator, device, latent_dim=100, num_samples=5000, outdir='fid_fake_images', batch_size=64):
    """Write exactly ``num_samples`` generated images to ``outdir`` as PNGs.

    Images are rescaled from the generator's [-1, 1] range to [0, 1] and saved
    one per file (``fake_<index>.png``). Previously written ``fake_*.png`` files
    in ``outdir`` are removed first, so the folder only holds this run's samples.

    Args:
        generator: model mapping [n, latent_dim, 1, 1] noise to images.
        device: device on which the noise is sampled.
        latent_dim: length of the latent vector.
        num_samples: number of images to write.
        outdir: destination folder (created if missing).
        batch_size: number of images generated per forward pass.

    The generator is put in eval mode for sampling and restored to its
    previous mode afterwards.
    """
    os.makedirs(outdir, exist_ok=True)
    for name in os.listdir(outdir):
        if name.startswith("fake_") and name.endswith(".png"):
            os.remove(os.path.join(outdir, name))

    was_training = generator.training
    generator.eval()
    idx = 0
    try:
        with torch.no_grad():
            while idx < num_samples:
                # Shrink the final batch so exactly num_samples images are written.
                n = min(batch_size, num_samples - idx)
                z = torch.randn(n, latent_dim, 1, 1, device=device)
                imgs = (generator(z) + 1) / 2  # [-1,1] to [0,1]
                for k in range(n):
                    vutils.save_image(imgs[k], os.path.join(outdir, f"fake_{idx + k}.png"))
                idx += n
    finally:
        generator.train(was_training)
    print(f"{idx} FID-ready fake images saved to: {outdir}")


def export_real_images_for_fid(outdir='cifar10_test_real', root='./data', num_samples=None):
    """Export CIFAR-10 test images as PNGs, giving pytorch_fid a real-image folder.

    Does nothing if ``outdir`` already contains PNG files, so re-runs are cheap.

    Args:
        outdir: destination folder (created if missing).
        root: torchvision dataset root (same as the training set's).
        num_samples: optional cap on the number of images; ``None`` exports all.
    """
    os.makedirs(outdir, exist_ok=True)
    if any(name.endswith(".png") for name in os.listdir(outdir)):
        print(f"Real images already present in: {outdir}")
        return
    # No transform: items are (PIL.Image, label) pairs at native 32x32 resolution.
    testset = datasets.CIFAR10(root=root, train=False, download=True)
    count = len(testset) if num_samples is None else min(num_samples, len(testset))
    for k in range(count):
        img, _ = testset[k]
        img.save(os.path.join(outdir, f"real_{k}.png"))
    print(f"{count} real CIFAR-10 test images saved to: {outdir}")


generate_samples_for_fid(netG, device, latent_dim=latent_dim, num_samples=5000, outdir="fid_fake_images")
export_real_images_for_fid(outdir="cifar10_test_real")
print("To compute the FID, run: python -m pytorch_fid fid_fake_images cifar10_test_real")

# --------------- SINGLE GRID FOR VISUALIZATION ---------------
# Sample explicitly in eval mode (BatchNorm uses running statistics) instead of
# depending on whichever mode earlier code left netG in; restore it afterwards.
grid_was_training = netG.training
netG.eval()
try:
    with torch.no_grad():
        grid_noise = torch.randn(64, latent_dim, 1, 1, device=device)
        fake_images = netG(grid_noise).cpu()
finally:
    netG.train(grid_was_training)
vutils.save_image(
    fake_images,
    "ada_dcgan_sample_grid.png",
    normalize=True,
    nrow=8
)
print("Saved ADA-DCGAN visual sample grid to 'ada_dcgan_sample_grid.png'.")

# --- Show grid inline (if in notebook) ---
# Best-effort preview: any failure (missing file, PIL unavailable, no display)
# is reported rather than stopping the notebook.
try:
    from PIL import Image
    grid_img = Image.open("ada_dcgan_sample_grid.png")
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.set_axis_off()
    ax.set_title("ADA-DCGAN Sample Grid")
    ax.imshow(grid_img)
    plt.show()
except Exception as e:
    print("Image preview unavailable:", e)


100%|██████████| 170M/170M [00:06<00:00, 28.4MB/s]


Epoch [1/50] Batch 0/391 Loss_D: 1.3982 Loss_G: 0.5310 D(x): 0.5271 D(G(z)): 0.5101 ADA_p: 0.010
Epoch [1/50] Batch 100/391 Loss_D: 2.2769 Loss_G: 0.4956 D(x): 0.2862 D(G(z)): 0.6638 ADA_p: 0.000
Epoch [1/50] Batch 200/391 Loss_D: 1.8378 Loss_G: 0.6385 D(x): 0.3405 D(G(z)): 0.5557 ADA_p: 0.000
Epoch [1/50] Batch 300/391 Loss_D: 1.7484 Loss_G: 0.6984 D(x): 0.3400 D(G(z)): 0.5158 ADA_p: 0.000
Epoch [2/50] Batch 0/391 Loss_D: 1.6096 Loss_G: 0.7032 D(x): 0.4061 D(G(z)): 0.5159 ADA_p: 0.000
Epoch [2/50] Batch 100/391 Loss_D: 1.5289 Loss_G: 0.7469 D(x): 0.4104 D(G(z)): 0.4800 ADA_p: 0.000
Epoch [2/50] Batch 200/391 Loss_D: 1.5047 Loss_G: 0.7630 D(x): 0.4106 D(G(z)): 0.4720 ADA_p: 0.000
Epoch [2/50] Batch 300/391 Loss_D: 1.6013 Loss_G: 0.6935 D(x): 0.4170 D(G(z)): 0.5251 ADA_p: 0.000


In [None]:
!pip install pytorch-fid

Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Downloading pytorch_fid-0.3.0-py3-none-any.whl (15 kB)
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.3.0


In [None]:
# Compute FID between the generated samples and the real CIFAR-10 test images.
# Both folders are produced by the training cell. If that cell did not finish
# (or the runtime was reset), report which folder is missing instead of
# failing with pytorch_fid's bare "Invalid path" traceback.
fid_dirs = ["fid_fake_images", "cifar10_test_real"]
missing_dirs = [d for d in fid_dirs if not os.path.isdir(d) or not os.listdir(d)]
if missing_dirs:
    print("Cannot compute FID; missing or empty folder(s):", ", ".join(missing_dirs))
else:
    # Use the kernel's own interpreter (where %pip installed pytorch-fid) and
    # capture the output so it is printed in the notebook.
    fid_run = subprocess.run(
        [sys.executable, "-m", "pytorch_fid", *fid_dirs],
        capture_output=True,
        text=True,
    )
    print(fid_run.stdout)
    if fid_run.returncode != 0:
        print(fid_run.stderr)
        fid_run.check_returncode()

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.12/dist-packages/pytorch_fid/__main__.py", line 3, in <module>
    pytorch_fid.fid_score.main()
  File "/usr/local/lib/python3.12/dist-packages/pytorch_fid/fid_score.py", line 313, in main
    fid_value = calculate_fid_given_paths(args.path,
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pytorch_fid/fid_score.py", line 253, in calculate_fid_given_paths
    raise RuntimeError('Invalid path: %s' % p)
RuntimeError: Invalid path: fid_fake_images
