In [None]:
import torch
from torchvision.utils import save_image
from pathlib import Path
from tqdm import tqdm

from config import Config
from utils import ensure_dirs, device
from models.ie_gan import Generator

def generate(num_images=500, batch_size=64):
    """Sample images from the trained IE-GAN generator and save them as PNGs.

    Loads the generator weights stored under the ``"G"`` key of the checkpoint
    at ``cfg.gan_ckpt``, draws latent vectors from a standard normal
    distribution and writes each generated image to
    ``cfg.synth_dir/synth_{index:05d}.png``.

    Args:
        num_images: Total number of images to generate. Values <= 0 produce
            no files.
        batch_size: Number of latent vectors pushed through the generator per
            forward pass. Larger batches are faster but need more device
            memory.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    cfg = Config()
    dev = device()
    ensure_dirs(cfg.synth_dir)
    out_dir = Path(cfg.synth_dir)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(cfg.gan_ckpt, map_location=dev)
    G = Generator(z_dim=cfg.z_dim).to(dev)
    G.load_state_dict(ckpt["G"])
    G.eval()

    total = max(num_images, 0)
    # Pure inference: disabling autograd avoids building a graph through G
    # on every forward pass, saving memory and time.
    with torch.no_grad(), tqdm(total=total, desc="Generating synthetic") as pbar:
        for start in range(0, total, batch_size):
            count = min(batch_size, total - start)
            z = torch.randn(count, cfg.z_dim, device=dev)
            # convert generator output from [-1,1] to [0,1] for save_image
            imgs = (G(z) + 1) / 2
            for offset, img in enumerate(imgs):
                save_image(img, out_dir / f"synth_{start + offset:05d}.png")
            pbar.update(count)

    print("Synthetic images saved to:", cfg.synth_dir)

if __name__ == "__main__":
    # Script entry point: write the default set of 500 synthetic images.
    generate(500)
