In [None]:
import warnings
warnings.filterwarnings('ignore')

TESTING NEW HYPOTHESIS

SGD WITH EXPONENTIAL MOVING AVERAGE

In [None]:
# Launch train_hypothesis.py (with --resume) using the --sgd_ema option;
# logs go to ./logs/GN-BigGAN_CIFAR10_sgdema.
!python train_hypothesis.py --resume \
    --flagfile=./config/GN-GAN_CIFAR10_BIGGAN.txt \
    --sgd_ema \
    --lr_D=0.001 \
    --lr_G=0.0005 \
    --lr_decay_start=60000 \
    --ema_decay=0.999 \
    --seed=0 \
    --logdir=./logs/GN-BigGAN_CIFAR10_sgdema


SGD WITH EXPONENTIAL MOVING AVERAGE AND DYNAMIC LEARNING RATE

In [None]:
# Launch train_hypothesis.py (with --resume) using the --sgd_ema_dlr option;
# logs go to ./logs/GN-BigGAN_CIFAR10_sgdema_dlr.
!python train_hypothesis.py --resume \
    --flagfile=./config/GN-GAN_CIFAR10_BIGGAN.txt \
    --sgd_ema_dlr \
    --lr_D=0.001 --lr_G=0.0005 \
    --lr_warmup_steps=10000 --lr_min_mult=0.1 \
    --ema_decay=0.999 \
    --seed=0 \
    --logdir=./logs/GN-BigGAN_CIFAR10_sgdema_dlr

EVALUATING THE RUNS (INCEPTION SCORE AND FID) AND PRINTING THE RESULTS

In [5]:
import os
import sys
import types
import json
import torch
import numpy as np
# Compatibility shim for `torchvision.models.utils.load_state_dict_from_url`.
# If that import path is unavailable, the same-named helper from torch.hub is
# placed in a stand-in module registered under the old dotted name.
# NOTE(review): presumably this exists for pytorch_gan_metrics, which is
# imported right after the shim — confirm against that package's imports.
try:
    from torchvision.models.utils import load_state_dict_from_url  # type: ignore
except ImportError:
    from torch.hub import load_state_dict_from_url
    # Build an empty module object, attach the helper, and publish it in
    # sys.modules under the legacy name so later imports of that name find it.
    utils_mod = types.ModuleType("torchvision.models.utils")
    utils_mod.load_state_dict_from_url = load_state_dict_from_url
    sys.modules["torchvision.models.utils"] = utils_mod
# Must stay below the shim: the shim has to be registered before this import runs.
from pytorch_gan_metrics import get_inception_score_and_fid
from models import biggan  
# --- Evaluation settings ----------------------------------------------------
# Statistics file handed to the IS/FID routine further down.
stats_path = "./stats/cifar10.test.npz"

# Generator input sizes and the sampling budget per run.
z_dim, n_classes = 128, 10
num_images, batch_size = 50000, 128

# Run label -> checkpoint file for every training run being compared.
runs = dict(
    adam_seed0="logs/GN-BigGAN_CIFAR10_seed0/best_model.pt",
    sgd_ema="logs/GN-BigGAN_CIFAR10_sgdema/best_model.pt",
    sgd_ema_dlr="logs/GN-BigGAN_CIFAR10_sgdema_dlr/best_model.pt",
)

# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def generate_images(net_G, num_images, batch_size, z_dim, n_classes, device):
    """Sample ``num_images`` class-conditional images from ``net_G``.

    Latent vectors are drawn from a standard normal and labels uniformly from
    ``[0, n_classes)``. The generator output is mapped with ``(x + 1) / 2`` and
    clamped to ``[0, 1]`` (i.e. the generator is assumed to emit values in
    ``[-1, 1]``). Batches are moved to the CPU as they are produced.

    Args:
        net_G: Generator module, called as ``net_G(z, y)``. It is switched to
            eval mode and left in eval mode on return.
        num_images: Total number of images to generate; must be positive.
        batch_size: Maximum number of images per forward pass; must be positive.
        z_dim: Size of each latent vector.
        n_classes: Number of class labels to sample from.
        device: Device on which the latents and labels are created.

    Returns:
        A CPU tensor with ``num_images`` entries along dimension 0.

    Raises:
        ValueError: If ``num_images`` or ``batch_size`` is not positive. Without
            this check a zero batch size would loop forever and a zero image
            count would fail inside ``torch.cat`` on an empty list.
    """
    if num_images <= 0:
        raise ValueError(f"num_images must be positive, got {num_images}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    batches = []
    net_G.eval()
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            # The final batch may be smaller so the total is exactly num_images.
            bs = min(batch_size, num_images - start)
            z = torch.randn(bs, z_dim, device=device)
            y = torch.randint(0, n_classes, (bs,), device=device)
            fake = ((net_G(z, y) + 1) / 2.0).clamp(0.0, 1.0)  # [-1,1] -> [0,1]
            batches.append(fake.cpu())
    return torch.cat(batches, dim=0)

# ---------------------------------------------------------------------------
# Evaluate every checkpoint listed in `runs`: load the generator weights,
# sample `num_images` images, compute Inception Score and FID against
# `stats_path`, print a summary, and save the results as JSON.
# ---------------------------------------------------------------------------
print("Evaluating runs:")
for name, path in runs.items():
    print(f"  {name}: {path}")

# Fail fast: without the statistics file the metric call would only fail
# after a full (expensive) round of image generation.
if not os.path.isfile(stats_path):
    raise FileNotFoundError(f"FID statistics file not found: {stats_path}")

# The RNG is re-seeded with this value before sampling for each run, so the
# reported numbers are reproducible and every run is scored on the same
# latent/label draws. Note this resets torch's global RNG state.
eval_seed = 0

results = {}       # run name -> metrics dict
results_list = []  # same dicts, in evaluation order (this is what gets saved)

for name, ckpt_path in runs.items():
    if not os.path.isfile(ckpt_path):
        print(f"\n[WARN] Skipping {name}, checkpoint not found: {ckpt_path}")
        continue
    print(f"\n=== Evaluating {name} ===")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # The checkpoint is mapped to the CPU so that only the generator weights
    # (copied by load_state_dict) end up on `device`.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    net_G = biggan.Generator32(z_dim, n_classes).to(device)
    if "ema_G" in ckpt:
        net_G.load_state_dict(ckpt["ema_G"])
        print("  Loaded EMA generator weights.")
    elif "net_G" in ckpt:
        net_G.load_state_dict(ckpt["net_G"])
        print("  Loaded net_G weights (no EMA found).")
    else:
        raise KeyError(
            f"Checkpoint {ckpt_path} has neither 'ema_G' nor 'net_G'; "
            f"available keys: {sorted(map(str, ckpt.keys()))}"
        )
    del ckpt  # the rest of the checkpoint is not needed for evaluation

    torch.manual_seed(eval_seed)
    images = generate_images(
        net_G=net_G,
        num_images=num_images,
        batch_size=batch_size,
        z_dim=z_dim,
        n_classes=n_classes,
        device=device,
    )
    print(f"  Generated {images.shape[0]} images for metrics.")

    (IS, IS_std), FID = get_inception_score_and_fid(
        images, stats_path, verbose=True
    )

    print(f"  {name}: IS={IS:.3f} (±{IS_std:.3f}), FID={FID:.3f}")

    # Cast to plain floats so the record is JSON-serializable.
    info = {
        "name": name,
        "ckpt_path": ckpt_path,
        "IS": float(IS),
        "IS_std": float(IS_std),
        "FID": float(FID),
    }
    results[name] = info
    results_list.append(info)

if results_list:
    print("\nPer-run results:")
    for r in results_list:
        print(
            f"  {r['name']}: IS={r['IS']:.3f} (±{r['IS_std']:.3f}), "
            f"FID={r['FID']:.3f}"
        )

    out_json = "gn_biggan_optimizer_eval_results.json"
    payload = {
        "stats_path": stats_path,
        "num_images": num_images,
        "batch_size": batch_size,
        "z_dim": z_dim,
        "n_classes": n_classes,
        "results": results_list,
    }
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    print(f"\nSaved detailed results to: {out_json}")
else:
    print("\nNo checkpoints were successfully evaluated.")

Evaluating runs:
  adam_seed0: logs/GN-BigGAN_CIFAR10_seed0/best_model.pt
  sgd_ema: logs/GN-BigGAN_CIFAR10_sgdema/best_model.pt
  sgd_ema_dlr: logs/GN-BigGAN_CIFAR10_sgdema_dlr/best_model.pt

=== Evaluating adam_seed0 ===
  Loaded EMA generator weights.
  Generated 50000 images for metrics.




  adam_seed0: IS=9.222 (±0.140), FID=9.056

=== Evaluating sgd_ema ===
  Loaded EMA generator weights.
  Generated 50000 images for metrics.




  sgd_ema: IS=7.637 (±0.091), FID=21.462

=== Evaluating sgd_ema_dlr ===
  Loaded EMA generator weights.
  Generated 50000 images for metrics.




  sgd_ema_dlr: IS=7.335 (±0.074), FID=24.836

Per-run results:
  adam_seed0: IS=9.222 (±0.140), FID=9.056
  sgd_ema: IS=7.637 (±0.091), FID=21.462
  sgd_ema_dlr: IS=7.335 (±0.074), FID=24.836

Saved detailed results to: gn_biggan_optimizer_eval_results.json
