# StyleGAN2 + CLIP (StyleGAN-NADA) — Inference Only

**Purpose:** Load a trained generator and generate samples. No training.

**Flow:** Choose **VERSION** and **EPOCH**. Each checkpoint is in its own Drive folder; we download only that folder. Local layout: `checkpoints/v1/50/checkpoint_50.pt`, `checkpoints/v1/200/checkpoint_200.pt`, etc. StyleGAN2 code from GitHub.

## User settings

Choose **VERSION** and **EPOCH** of the checkpoint to load. Epochs: 50, 100, 150, 200, 250, 300.

## Step 1: Checkpoint folder (one per version) + download

Paste a **direct download link** for each checkpoint (version, epoch). Use Google Drive “Get link” → “Anyone with the link” and paste the link or the `https://drive.google.com/uc?id=FILE_ID` URL. Then run the cell below to download the chosen checkpoint if needed.

In [51]:
# One folder link per (version, epoch). Each folder contains one file: checkpoint_EPOCH.pt
# Structure: checkpoints/VERSION/EPOCH/checkpoint_EPOCH.pt
CHECKPOINT_FOLDER_LINKS = {
    "1": {
        50: "https://drive.google.com/drive/folders/1MvSVn62B_Hz21V3-5cTFktSB-CY-toma?usp=drive_link",
        100: "https://drive.google.com/drive/folders/1omeDE_8Anl9IjV-JzRN4es_h_S5ECh-R?usp=drive_link",
        150: "https://drive.google.com/drive/folders/PASTE_FOLDER_ID",
        200: "https://drive.google.com/drive/folders/1omeDE_8Anl9IjV-JzRN4es_h_S5ECh-R?usp=drive_link",
        250: "https://drive.google.com/drive/folders/PASTE_FOLDER_ID",
        300: "https://drive.google.com/drive/folders/PASTE_FOLDER_ID",
    },
    "2": {50: "...", 100: "...", 150: "...", 200: "...", 250: "...", 300: "..."},
}

In [52]:
import os

WORK_DIR = "/content" if os.path.exists("/content") else "."
CHECKPOINTS_ROOT = os.path.join(WORK_DIR, "checkpoints")
REPOS_DIR = os.path.join(WORK_DIR, "repos")
CHECKPOINT_NAME = f"checkpoint_{EPOCH}.pt"
OUTPUT_DIR = os.path.join(WORK_DIR, "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

version_dir = os.path.join(CHECKPOINTS_ROOT, VERSION)
epoch_dir = os.path.join(version_dir, str(EPOCH))
CKPT_PATH = os.path.join(epoch_dir, CHECKPOINT_NAME)
if VERSION not in CHECKPOINT_FOLDER_LINKS or EPOCH not in CHECKPOINT_FOLDER_LINKS.get(VERSION, {}):
    raise ValueError(f"No folder link for (VERSION={VERSION}, EPOCH={EPOCH}). Edit CHECKPOINT_FOLDER_LINKS in the cell above.")
folder_url = CHECKPOINT_FOLDER_LINKS[VERSION][EPOCH]
if not isinstance(folder_url, str) or "PASTE_FOLDER" in folder_url or "/folders/" not in folder_url:
    raise ValueError(
        f"Invalid folder link for (VERSION={VERSION}, EPOCH={EPOCH}). "
        "Replace the placeholder in CHECKPOINT_FOLDER_LINKS (cell above) with a Google Drive folder link "
        "(share: Anyone with the link). Or use EPOCH 50/100 if you have not set links for this epoch yet."
    )
folder_id = folder_url.strip().split("/folders/")[-1].split("/")[0].split("?")[0]

need_download = not os.path.isfile(CKPT_PATH)
if need_download:
    import gdown
    import shutil
    os.makedirs(version_dir, exist_ok=True)
    print("Downloading checkpoint folder for", VERSION, "/", EPOCH, "...")
    gdown.download_folder(id=folder_id, output=version_dir, quiet=False)
    src = None
    for root, _, files in os.walk(version_dir):
        if CHECKPOINT_NAME in files:
            src = os.path.join(root, CHECKPOINT_NAME)
            break
    if not src or not os.path.isfile(src):
        raise FileNotFoundError(f"{CHECKPOINT_NAME} not found in downloaded folder. Ensure the Drive folder contains that file.")
    os.makedirs(epoch_dir, exist_ok=True)
    shutil.copy2(src, CKPT_PATH)
    print("Saved to", CKPT_PATH)
else:
    print("Checkpoint already on disk:", CKPT_PATH)

CHECKPOINTS_DIR = epoch_dir
if os.path.isdir(CHECKPOINTS_DIR):
    print("Available:", sorted([f for f in os.listdir(CHECKPOINTS_DIR) if f.endswith(".pt")]) or "(none)")
print("Using:", CKPT_PATH)

Downloading checkpoint folder for 1 / 200 ...


Retrieving folder contents


Processing file 1LA89XKMwRCTZqLeacGuU4JSCtxOPMQIH Copy of checkpoint_200.pt


Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From (original): https://drive.google.com/uc?id=1LA89XKMwRCTZqLeacGuU4JSCtxOPMQIH
From (redirected): https://drive.google.com/uc?id=1LA89XKMwRCTZqLeacGuU4JSCtxOPMQIH&confirm=t&uuid=242b0fee-778f-43a0-835a-7114a02e69ee
To: /content/checkpoints/1/Copy of checkpoint_200.pt
100%|██████████| 133M/133M [00:01<00:00, 129MB/s]  
Download completed


Saved to /content/checkpoints/1/200/checkpoint_200.pt
Available: ['checkpoint_200.pt']
Using: /content/checkpoints/1/200/checkpoint_200.pt


## Step 2: Install deps and StyleGAN2 repo (GitHub)

Clone StyleGAN2 repo from GitHub if not already present. No Drive, no CLIP, no base weights.

In [None]:
!pip install -q ftfy regex tqdm gdown Ninja

In [None]:
import torch
import sys
import random
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"CUDA: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none'}")

In [None]:
SEED = 3456
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
print(f"Seeds fixed to {SEED}.")

In [None]:
# StyleGAN2 repo
STYLEGAN2_ROOT = os.path.join(REPOS_DIR, "stylegan2-pytorch")
if not os.path.isfile(os.path.join(STYLEGAN2_ROOT, "model.py")):
    os.makedirs(REPOS_DIR, exist_ok=True)
    print("Cloning StyleGAN2 repo...")
    !git clone https://github.com/rosinality/stylegan2-pytorch.git "{STYLEGAN2_ROOT}"
else:
    print("StyleGAN2 repo at", STYLEGAN2_ROOT)
sys.path.insert(0, STYLEGAN2_ROOT)

In [None]:
# Inference needs only checkpoint (G_train). No base weights, no CLIP.

## Step 3: Load trained generator

Create Generator, load trained weights from checkpoint (key `G_train`). Checkpoint contains full trained generator.

In [None]:
from model import Generator

latent_dim = 512
generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)

if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {CKPT_PATH}. Set VERSION and EPOCH in User settings, then run Step 1 to download.")
with open(CKPT_PATH, "rb") as f:
    if f.read(100).lstrip().startswith(b"<"):
        raise RuntimeError(
            f"{CKPT_PATH} is an HTML file (Drive returned a page, not the .pt). "
            "Re-run Step 1 with a direct download link (https://drive.google.com/uc?id=FILE_ID), or download the .pt manually."
        )
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)
generator.load_state_dict(ckpt["G_train"])
generator.eval()

print("Trained generator loaded from", CKPT_PATH)
if "source" in ckpt and "target" in ckpt:
    print("  source:", ckpt.get("source", "—"))
    print("  target:", ckpt.get("target", "—"))
    print("  iter:", ckpt.get("iter", "—"))

## Step 4: Generate samples

Sample random z, generate images with the trained generator, display.

In [None]:
import matplotlib.pyplot as plt

def to_np(tensor):
    """Convert (B, 3, H, W) tensor in [-1,1] to numpy for display."""
    x = (tensor.clamp(-1, 1) + 1) / 2
    x = x.permute(0, 2, 3, 1).cpu().numpy()
    return x

N_SAMPLES = 4  # number of images to generate

with torch.no_grad():
    z = torch.randn(N_SAMPLES, latent_dim, device=device)
    imgs, _ = generator([z], input_is_latent=False)

imgs_np = to_np(imgs)

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()
for i in range(N_SAMPLES):
    axes[i].imshow(imgs_np[i])
    axes[i].axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Optional: save generated images
from PIL import Image

ckpt_basename = os.path.splitext(os.path.basename(CKPT_PATH))[0]
for i in range(N_SAMPLES):
    arr = (imgs_np[i] * 255).astype(np.uint8)
    img = Image.fromarray(arr)
    out_path = os.path.join(OUTPUT_DIR, f"{ckpt_basename}_sample_{i}.png")
    img.save(out_path)
    print("Saved:", out_path)