In [1]:
# ============================================================
# ðŸ“¦ BLOCK 1 â€” Install Dependencies & Clone EDM2 Repository
# ============================================================

!pip install click tqdm psutil scipy pillow pandas --quiet

# Clone EDM2 repo
!git clone https://github.com/NVlabs/edm2.git
%cd edm2

# Install any remaining minimal dependencies
!pip install -r requirements.txt --quiet || true

# Return to working directory
%cd /kaggle/working

print("EDM2 setup complete!")


Cloning into 'edm2'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 60 (delta 13), reused 10 (delta 10), pack-reused 33 (from 1)[K
Receiving objects: 100% (60/60), 1.27 MiB | 10.86 MiB/s, done.
Resolving deltas: 100% (24/24), done.
/kaggle/working/edm2
[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
[0m/kaggle/working
EDM2 setup complete!


In [2]:
# ============================================================
# ðŸ§© BLOCK 2 â€” Prepare FULL Test Labels + Filenames
# ============================================================

import pandas as pd
import numpy as np
import os

ATTR_CSV = "/kaggle/input/celeva-64x64-dataset/celeba64/list_attr_celeba.csv"
TEST_DIR = "/kaggle/input/celeva-64x64-dataset/celeba64/test"
OUT_DIR = "/kaggle/working/cond-test-labels"
os.makedirs(OUT_DIR, exist_ok=True)

# Load attributes
df = pd.read_csv(ATTR_CSV)
img_col = df.columns[0]
attr_cols = df.columns[1:41]

# Convert -1/+1 â†’ 0/1
attrs = (df[attr_cols].values > 0).astype(np.float32)

# All test images
test_files = sorted([f for f in os.listdir(TEST_DIR) if f.endswith(".jpg")])
print("Total test images:", len(test_files))

# Only keep images present in CSV
valid_test_files = [f for f in test_files if f in df[img_col].values]
print("Valid test images (in CSV):", len(valid_test_files))

# Build full array (NO random sampling)
labels_full = np.array([
    attrs[df[img_col].values.tolist().index(f)]
    for f in valid_test_files
])
filenames_full = np.array(valid_test_files)

# Save
np.save(f"{OUT_DIR}/labels_full.npy", labels_full)
np.save(f"{OUT_DIR}/filenames_full.npy", filenames_full)

print("Saved FULL test labels:")
print(f"{OUT_DIR}/labels_full.npy")
print(f"{OUT_DIR}/filenames_full.npy")


Total test images: 19962
Valid test images (in CSV): 19962
Saved FULL test labels:
/kaggle/working/cond-test-labels/labels_full.npy
/kaggle/working/cond-test-labels/filenames_full.npy


In [3]:
# ============================================================
# ðŸŽ¨ BLOCK 3 â€” Write Conditional Generator Script (NO %%writefile)
# ============================================================

generator_code = r'''
import os
import pickle
import torch
import numpy as np
import PIL.Image
import sys

# ------------------------------------------------------------
# Make sure EDM2 repo is importable
# ------------------------------------------------------------
sys.path.append("/kaggle/working/edm2")
import dnnlib
from edm2.generate_images import edm_sampler


def generate_conditional(model, labels, names, outdir,
                         steps=32, sigma_min=0.002, sigma_max=80, rho=7):

    os.makedirs(outdir, exist_ok=True)
    device = torch.device("cuda")

    print("Loading model:", model)
    with open(model, "rb") as f:
        net = pickle.load(f)["ema"].to(device).eval()

    labels_np = np.load(labels)
    filenames = np.load(names)

    N = labels_np.shape[0]
    print(f"Generating {N} images...")

    bs_max = 64

    for start in range(0, N, bs_max):
        bs = min(bs_max, N - start)
        cond = torch.tensor(labels_np[start:start+bs], device=device)
        noise = torch.randn(bs, net.img_channels, net.img_resolution, net.img_resolution, device=device)

        # EDM2 sampler (Karras rho)
        imgs = edm_sampler(
            net=net,
            noise=noise,
            labels=cond,
            num_steps=steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
        )

        imgs = imgs.clamp(-1, 1)
        imgs = (imgs * 127.5 + 127.5).to(torch.uint8)
        imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()

        for j in range(bs):
            outname = filenames[start + j].replace(".jpg", ".png")
            PIL.Image.fromarray(imgs[j]).save(os.path.join(outdir, outname))

    print("DONE! Saved images to", outdir)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--labels", type=str, required=True)
    parser.add_argument("--names", type=str, required=True)
    parser.add_argument("--outdir", type=str, required=True)
    parser.add_argument("--steps", type=int, default=32)

    args = parser.parse_args()

    generate_conditional(
        model=args.model,
        labels=args.labels,
        names=args.names,
        outdir=args.outdir,
        steps=args.steps,
    )
'''

# Write generator script to a file
with open("gen_conditional.py", "w") as f:
    f.write(generator_code)

print("âœ… gen_conditional.py written successfully")


âœ… gen_conditional.py written successfully


In [4]:
# ============================================================
# ðŸš€ BLOCK 4 â€” Run Conditional Generation (FULL test set)
# ============================================================

MODEL_PKL = "/kaggle/input/network-snapshot-0001572-0-100/pytorch/default/1/network-snapshot-0001572-0.100.pkl"

LABELS_FULL = "/kaggle/working/cond-test-labels/labels_full.npy"
FILENAMES_FULL = "/kaggle/working/cond-test-labels/filenames_full.npy"

OUT_IMAGES = "/kaggle/working/cond-generated-full"

!python gen_conditional.py \
    --model "$MODEL_PKL" \
    --labels "$LABELS_FULL" \
    --names "$FILENAMES_FULL" \
    --outdir "$OUT_IMAGES" \
    --steps 32


Loading model: /kaggle/input/network-snapshot-0001572-0-100/pytorch/default/1/network-snapshot-0001572-0.100.pkl
Generating 19962 images...
DONE! Saved images to /kaggle/working/cond-generated-full


In [5]:
# ============================================================
# ðŸ“Š BLOCK 5 â€” Compute CelebA64 Reference Stats (once)
# ============================================================

# Compute reference statistics on whole test set (19,962 images)
!python /kaggle/working/edm2/calculate_metrics.py ref \
    --data="/kaggle/input/celeva-64x64-dataset/celeba64/test" \
    --dest="/kaggle/working/celeba64_ref.pkl" \
    --metrics=fid \
    --batch=64 \
    --workers=2


Loading images from /kaggle/input/celeva-64x64-dataset/celeba64/test ...
[rank0]:[W1201 17:13:37.719583878 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
Setting up InceptionV3Detector...
Calculating feature statistics...
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 312/312 [01:36<00:00,  3.24batch/s]


In [6]:
# ============================================================
# ðŸ§® BLOCK 6 â€” Compute FID for FULL test set (2 GPUs)
# ============================================================

NUM_FULL = len(os.listdir("/kaggle/working/cond-generated-full"))

!torchrun --standalone --nproc_per_node=2 \
    /kaggle/working/edm2/calculate_metrics.py calc \
    --images="/kaggle/working/cond-generated-full" \
    --ref="/kaggle/working/celeba64_ref.pkl" \
    --metrics=fid \
    --num=$NUM_FULL \
    --batch=64 \
    --workers=2


W1201 17:15:18.235000 101 torch/distributed/run.py:792] 
W1201 17:15:18.235000 101 torch/distributed/run.py:792] *****************************************
W1201 17:15:18.235000 101 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1201 17:15:18.235000 101 torch/distributed/run.py:792] *****************************************
[W1201 17:15:18.722062623 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1201 17:15:18.722739527 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1201 17:15:20.577236791 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1201 17:15:20.577977499 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W1201 17:15:20.598248535 soc