In [11]:
import warnings
warnings.filterwarnings('ignore')

TRAINING THE GAN FOR 3 SEEDS

In [None]:
!python train.py --flagfile ./config/GN-GAN_CIFAR10_BIGGAN.txt --seed=0 --logdir=./logs/GN-BigGAN_CIFAR10_seed0


In [None]:
!python train.py --flagfile ./config/GN-GAN_CIFAR10_BIGGAN.txt --seed=1 --logdir=./logs/GN-BigGAN_CIFAR10_seed1


In [None]:
!python train.py --flagfile ./config/GN-GAN_CIFAR10_BIGGAN.txt  --seed=2 --logdir=./logs/GN-BigGAN_CIFAR10_seed2


TESTING THE MODEL

In [None]:
import os
import glob
import sys
import types
import torch
import numpy as np
try:
    from torchvision.models.utils import load_state_dict_from_url  
except ImportError:
    from torch.hub import load_state_dict_from_url
    utils_mod = types.ModuleType("torchvision.models.utils")
    utils_mod.load_state_dict_from_url = load_state_dict_from_url
    sys.modules["torchvision.models.utils"] = utils_mod
from pytorch_gan_metrics import get_inception_score_and_fid
from models import biggan  
stats_path = "./stats/cifar10.test.npz"
ckpt_pattern = "logs/GN-BigGAN_CIFAR10_seed*/best_model.pt"
z_dim = 128
n_classes = 10
num_images = 50000     
batch_size = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
paths = sorted(glob.glob(ckpt_pattern))
if not paths:
    raise SystemExit(f"No checkpoints found with pattern: {ckpt_pattern}")
print("Found checkpoints:")
for p in paths:
    print("  ", p)
results = []
for path in paths:
    seed_name = os.path.basename(os.path.dirname(path))
    print(f"\nEvaluating {seed_name} ...")
    ckpt = torch.load(path, map_location=device)
    net_G = biggan.Generator32(z_dim, n_classes).to(device)
    if "ema_G" in ckpt:
        net_G.load_state_dict(ckpt["ema_G"])
    else:
        net_G.load_state_dict(ckpt["net_G"])
    net_G.eval()
    images = []
    remaining = num_images
    with torch.no_grad():
        while remaining > 0:
            bs = min(batch_size, remaining)
            z = torch.randn(bs, z_dim, device=device)
            y = torch.randint(0, n_classes, (bs,), device=device)
            fake = (net_G(z, y) + 1) / 2          
            fake = fake.clamp(0, 1)                
            images.append(fake.cpu())             
            remaining -= bs
    images = torch.cat(images, dim=0)[:num_images]   
    print(f"  Generated {images.shape[0]} images.")
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, stats_path, verbose=True
    )
    print(f"  {seed_name} -> IS={IS:.3f} (±{IS_std:.3f}), FID(test)={FID:.3f}")
    results.append((seed_name, IS, IS_std, FID))
IS_vals = np.array([r[1] for r in results])
IS_std_vals = np.array([r[2] for r in results])
FID_vals = np.array([r[3] for r in results])
print("\nPer-seed results (EMA, FID(test)):")
for seed_name, IS, IS_std, FID in results:
    print(f"  {seed_name}: IS={IS:.3f} (±{IS_std:.3f}), FID(test)={FID:.3f}")
print("\n=== Aggregated over seeds (GN-BigGAN, FID(test)) ===")
print(f"Inception Score: {IS_vals.mean():.3f} ± {IS_vals.std(ddof=1):.3f}")
print(f"FID(test):       {FID_vals.mean():.3f} ± {FID_vals.std(ddof=1):.3f}")

Found checkpoints:
   logs/GN-BigGAN_CIFAR10_seed0/best_model.pt
   logs/GN-BigGAN_CIFAR10_seed1/best_model.pt
   logs/GN-BigGAN_CIFAR10_seed2/best_model.pt

Evaluating GN-BigGAN_CIFAR10_seed0 ...
  Generated 50000 images.


                                                                

  GN-BigGAN_CIFAR10_seed0 -> IS=9.249 (±0.095), FID(test)=9.110

Evaluating GN-BigGAN_CIFAR10_seed1 ...
  Generated 50000 images.


                                                                              

  GN-BigGAN_CIFAR10_seed1 -> IS=8.979 (±0.097), FID(test)=9.833

Evaluating GN-BigGAN_CIFAR10_seed2 ...
  Generated 50000 images.


                                                                              

  GN-BigGAN_CIFAR10_seed2 -> IS=9.244 (±0.080), FID(test)=8.382

Per-seed results (EMA, FID(test)):
  GN-BigGAN_CIFAR10_seed0: IS=9.249 (±0.095), FID(test)=9.110
  GN-BigGAN_CIFAR10_seed1: IS=8.979 (±0.097), FID(test)=9.833
  GN-BigGAN_CIFAR10_seed2: IS=9.244 (±0.080), FID(test)=8.382

=== Aggregated over seeds (GN-BigGAN, FID(test)) ===
Inception Score: 9.157 ± 0.154
FID(test):       9.109 ± 0.726


MAKING THE GAN TO GENERATE AEROPLANE IMAGE

In [None]:
import torch
from torchvision.utils import make_grid, save_image
from models import biggan
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
z_dim = 128      
n_classes = 10   
ckpt_path = "logs/GN-BigGAN_CIFAR10_seed0/best_model.pt"
ckpt = torch.load(ckpt_path, map_location=device)
net_G = biggan.Generator32(z_dim, n_classes).to(device)
if "ema_G" in ckpt:
    net_G.load_state_dict(ckpt["ema_G"])
else:
    net_G.load_state_dict(ckpt["net_G"])
net_G.eval()
class_id = 0     
num_images = 64
z = torch.randn(num_images, z_dim, device=device)
y = torch.full((num_images,), class_id, dtype=torch.long, device=device)
with torch.no_grad():
    imgs = (net_G(z, y) + 1) / 2   
grid = make_grid(imgs, nrow=8)
save_image(grid, "samples_airplane.png")
print("Saved samples_airplane.png")

Saved samples_airplane.png
