In [1]:
# ============================================================
# üì¶ SETUP
# ============================================================

!git clone https://github.com/NVlabs/edm2.git /kaggle/working/edm2
%cd /kaggle/working/edm2
!pip install click tqdm psutil scipy pillow matplotlib pandas --quiet

import sys
sys.path.append("/kaggle/working/edm2")


Cloning into '/kaggle/working/edm2'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 60 (delta 13), reused 10 (delta 10), pack-reused 33 (from 1)[K
Receiving objects: 100% (60/60), 1.27 MiB | 9.87 MiB/s, done.
Resolving deltas: 100% (24/24), done.
/kaggle/working/edm2


In [2]:
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path

import sys
sys.path.append("/kaggle/working/edm2")
from training import networks_edm2
from training import training_loop


In [3]:
# ============================================================
# üì¶ LOAD CONDITIONAL EDM2 MODEL (40 ATTR, KARRAS)
# ============================================================

import torch, pickle

MODEL_PKL = "/kaggle/input/network-snapshot-0001572-0-100-kr-40attr/pytorch/default/1/network-snapshot-0001572-0.100_kr_40attr.pkl"

device = torch.device("cuda")

with open(MODEL_PKL, "rb") as f:
    net = pickle.load(f)["ema"].to(device).eval()

print("‚úÖ Model loaded")
print("Image resolution:", net.img_resolution)
print("Label dim:", net.label_dim)


‚úÖ Model loaded
Image resolution: 64
Label dim: 40


In [4]:
# ============================================================
# üì¶ LOAD CONDITIONAL EDM2 MODEL
# ============================================================

import pickle
import torch

MODEL_PKL = "/kaggle/input/network-snapshot-0001572-0-100-kr-40attr/pytorch/default/1/network-snapshot-0001572-0.100_kr_40attr.pkl"

device = torch.device("cuda")

with open(MODEL_PKL, "rb") as f:
    net = pickle.load(f)["ema"].to(device).eval()

print("‚úÖ Conditional model loaded")
print("Condition dim:", net.label_dim)


‚úÖ Conditional model loaded
Condition dim: 40


In [5]:
# ============================================================
# üé≤ PURE NOISE + CONDITION
# ============================================================

C = net.img_channels
H = W = net.img_resolution

# Pure noise
noise = torch.randn(1, C, H, W, device=device)

# Condition (1 √ó 40)
labels = torch.tensor(attr_vec, device=device).unsqueeze(0)

print("Noise shape:", noise.shape)
print("Labels shape:", labels.shape)


NameError: name 'attr_vec' is not defined

In [6]:
from pathlib import Path
import textwrap

SAMPLER_PATH = Path("/kaggle/working/edm2/edm_step_sampler.py")

SAMPLER_PATH.write_text(textwrap.dedent("""
import os
import torch
import numpy as np
import PIL.Image

def edm_sampler_with_steps(
    net,
    noise,
    labels=None,
    gnet=None,
    randn_like=None,
    num_steps=32,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
    schedule_name="karras",
    guidance=1,
    save_dir=None,
    dtype=torch.float32,
    **_ignored,
):
    assert save_dir is not None
    os.makedirs(save_dir, exist_ok=True)
    device = noise.device

    # ------------------ Karras rho ------------------
    step_indices = torch.arange(num_steps, device=device)
    sigmas = (
        sigma_max ** (1 / rho)
        + step_indices / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho

    sigmas = torch.cat([sigmas, torch.zeros_like(sigmas[:1])])

    def denoise(x, t):
        Dx = net(x, t, labels).to(dtype)
        if guidance == 1 or gnet is None:
            return Dx
        ref = gnet(x, t, labels).to(dtype)
        return ref.lerp(Dx, guidance)

    x = noise * sigmas[0]

    for i in range(num_steps):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        d = (x - denoise(x, t_cur)) / t_cur
        x = x + (t_next - t_cur) * d

        img = x[0].detach().cpu().clamp(-1, 1)
        img = ((img + 1) / 2 * 255).permute(1, 2, 0).numpy().astype(np.uint8)

        PIL.Image.fromarray(img).save(
            os.path.join(save_dir, f"step_{i:03d}.png")
        )

    return x
"""))

print("‚úÖ Conditional-ready edm_step_sampler.py written")


‚úÖ Conditional-ready edm_step_sampler.py written


In [7]:
# ============================================================
# üß™ CONDITIONAL EDM SAMPLER WITH STEP SAVING
# ============================================================

import os
import PIL.Image

def edm_sampler_with_steps(
    net,
    noise,
    labels,
    save_dir,
    num_steps=32,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
):
    os.makedirs(save_dir, exist_ok=True)
    device = noise.device

    # Karras rho schedule
    i = torch.arange(num_steps, device=device)
    sigmas = (
        sigma_max ** (1 / rho)
        + i / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    sigmas = torch.cat([sigmas, torch.zeros_like(sigmas[:1])])

    def denoise(x, t):
        return net(x, t, labels)

    x = noise * sigmas[0]

    for i in range(num_steps):
        t_cur, t_next = sigmas[i], sigmas[i+1]
        d = (x - denoise(x, t_cur)) / t_cur
        x = x + (t_next - t_cur) * d

        img = x[0].detach().cpu().clamp(-1, 1)
        img = ((img + 1) / 2 * 255).permute(1, 2, 0).numpy().astype(np.uint8)
        PIL.Image.fromarray(img).save(f"{save_dir}/step_{i:03d}.png")

    return sigmas[:-1]


In [9]:
# ============================================================
# üß© DEFINE 3 CONDITIONAL EXPERIMENTS (TEST SET)
# ============================================================

import pandas as pd
import numpy as np

ATTR_CSV = "/kaggle/input/celeva-64x64-dataset/celeba64/list_attr_celeba.csv"

df = pd.read_csv(ATTR_CSV)
attr_cols = df.columns[1:41]

# CelebA test split (0-indexed)
TEST_START = 182637
TEST_END   = 202599

test_df = df.iloc[TEST_START:TEST_END].reset_index(drop=True)

# Convert {-1, +1} ‚Üí {0, 1}
attr_matrix = (test_df[attr_cols].values > 0).astype(np.float32)
attr_counts = attr_matrix.sum(axis=1)

# ------------------------------------------------------------
# 1Ô∏è‚É£ Max +1 attributes
# ------------------------------------------------------------
idx_max = np.argmax(attr_counts)
attrs_max = attr_matrix[idx_max]

# ------------------------------------------------------------
# 2Ô∏è‚É£ Min +1 attributes
# ------------------------------------------------------------
idx_min = np.argmin(attr_counts)
attrs_min = attr_matrix[idx_min]

# ------------------------------------------------------------
# 3Ô∏è‚É£ Bald + Mustache
# ------------------------------------------------------------
mask = (test_df["Bald"] == 1) & (test_df["Mustache"] == 1)
idx_bm = np.where(mask)[0][0]
attrs_bm = (test_df.loc[idx_bm, attr_cols].values > 0).astype(np.float32)

# ------------------------------------------------------------
# FINAL EXPERIMENT DICT ‚úÖ
# ------------------------------------------------------------
EXPERIMENTS = {
    "max_attrs": attrs_max,
    "min_attrs": attrs_min,
    "bald_mustache": attrs_bm,
}

print("‚úÖ Experiments ready:")
for k, v in EXPERIMENTS.items():
    print(f"{k:15s} ‚Üí active attributes = {int(v.sum())}")


‚úÖ Experiments ready:
max_attrs       ‚Üí active attributes = 20
min_attrs       ‚Üí active attributes = 1
bald_mustache   ‚Üí active attributes = 15


In [10]:
# ============================================================
# üöÄ RUN 3 CONDITIONAL EXPERIMENTS
# ============================================================

C = net.img_channels
H = W = net.img_resolution

for name, attr_vec in EXPERIMENTS.items():
    print(f"\nüöÄ Running experiment: {name}")

    # Pure noise
    noise = torch.randn(1, C, H, W, device=device)

    # Conditional attributes (1 √ó 40)
    labels = torch.tensor(attr_vec, device=device).unsqueeze(0)

    outdir = f"/kaggle/working/cond_{name}_karras_steps"

    sigmas = edm_sampler_with_steps(
        net=net,
        noise=noise,
        labels=labels,
        save_dir=outdir,
        num_steps=32,
    )

    print(f"‚úÖ Saved steps to: {outdir}")

print("\nüéâ All conditional trajectories generated successfully")



üöÄ Running experiment: max_attrs
‚úÖ Saved steps to: /kaggle/working/cond_max_attrs_karras_steps

üöÄ Running experiment: min_attrs
‚úÖ Saved steps to: /kaggle/working/cond_min_attrs_karras_steps

üöÄ Running experiment: bald_mustache
‚úÖ Saved steps to: /kaggle/working/cond_bald_mustache_karras_steps

üéâ All conditional trajectories generated successfully


In [11]:
# ============================================================
# üß© STITCH CONDITIONAL TRAJECTORIES (4√ó8, œÉ-ANNOTATED)
# ============================================================

from PIL import Image, ImageDraw, ImageFont
import torch, os, math

# ------------------------------------------------------------
# Experiment folders
# ------------------------------------------------------------
EXPERIMENT_DIRS = {
    "bald_mustache": "/kaggle/working/cond_bald_mustache_karras_steps",
    "max_attrs": "/kaggle/working/cond_max_attrs_karras_steps",
    "min_attrs": "/kaggle/working/cond_min_attrs_karras_steps",
}

OUT_ROOT = "/kaggle/working"

# ------------------------------------------------------------
# Layout
# ------------------------------------------------------------
ROWS = 4
COLS = 8
TOP_PAD = 32
ROW_GAP = 18

# ------------------------------------------------------------
# Karras œÉ parameters (MUST match sampler)
# ------------------------------------------------------------
sigma_min = 0.002
sigma_max = 80
rho = 7
NUM_STEPS = ROWS * COLS

# ------------------------------------------------------------
# Compute Karras sigmas
# ------------------------------------------------------------
i = torch.arange(NUM_STEPS)
sigmas = (
    sigma_max ** (1 / rho)
    + i / (NUM_STEPS - 1)
    * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
) ** rho

# ------------------------------------------------------------
# Safe font (Kaggle-compatible)
# ------------------------------------------------------------
try:
    font = ImageFont.truetype("DejaVuSans.ttf", 13)
except OSError:
    font = ImageFont.load_default()

# ------------------------------------------------------------
# Process each experiment
# ------------------------------------------------------------
for name, step_dir in EXPERIMENT_DIRS.items():
    print(f"üîß Stitching: {name}")

    # Load step images
    imgs = [
        Image.open(os.path.join(step_dir, f"step_{i:03d}.png"))
        for i in range(NUM_STEPS)
    ]

    w, h = imgs[0].size

    canvas_w = COLS * w
    canvas_h = ROWS * (h + TOP_PAD + ROW_GAP)

    canvas = Image.new("RGB", (canvas_w, canvas_h), "white")
    draw = ImageDraw.Draw(canvas)

    # --------------------------------------------------------
    # Paste images + œÉ labels
    # --------------------------------------------------------
    for idx, img in enumerate(imgs):
        r = idx // COLS
        c = idx % COLS

        x = c * w
        y = r * (h + TOP_PAD + ROW_GAP)

        # œÉ label
        draw.text(
            (x + w // 2, y + 10),
            f"œÉ = {sigmas[idx]:.2f}",
            anchor="mm",
            fill="black",
            font=font,
        )

        # image
        canvas.paste(img, (x, y + TOP_PAD))

    # --------------------------------------------------------
    # Save
    # --------------------------------------------------------
    out_path = f"{OUT_ROOT}/{name}_karras_4x8_sigma.png"
    canvas.save(out_path)
    print(f"‚úÖ Saved: {out_path}")

print("\nüéâ All stitched images generated successfully")


üîß Stitching: bald_mustache
‚úÖ Saved: /kaggle/working/bald_mustache_karras_4x8_sigma.png
üîß Stitching: max_attrs
‚úÖ Saved: /kaggle/working/max_attrs_karras_4x8_sigma.png
üîß Stitching: min_attrs
‚úÖ Saved: /kaggle/working/min_attrs_karras_4x8_sigma.png

üéâ All stitched images generated successfully


In [12]:
# ============================================================
# üßæ PRINT ATTRIBUTES FOR 3 CONDITIONAL EXPERIMENTS (CelebA)
# ============================================================

import pandas as pd
import numpy as np

# ------------------------------------------------------------
# Load CelebA attributes
# ------------------------------------------------------------
ATTR_CSV = "/kaggle/input/celeva-64x64-dataset/celeba64/list_attr_celeba.csv"
df = pd.read_csv(ATTR_CSV)

img_col = df.columns[0]
attr_cols = list(df.columns[1:41])  # 40 attributes

# ------------------------------------------------------------
# Test split (CelebA official)
# Rows 182638‚Äì202599 (1-indexed)
# ------------------------------------------------------------
TEST_START = 182637  # zero-indexed
TEST_END   = 202599

test_df = df.iloc[TEST_START:TEST_END].copy()

# Convert {-1, +1} ‚Üí {0, 1}
attr_matrix = (test_df[attr_cols].values > 0).astype(np.int32)

# ------------------------------------------------------------
# 1Ô∏è‚É£ Max +1 attributes
# ------------------------------------------------------------
max_idx = np.argmax(attr_matrix.sum(axis=1))
max_row = test_df.iloc[max_idx]
max_attrs = [a for a, v in zip(attr_cols, attr_matrix[max_idx]) if v == 1]

# ------------------------------------------------------------
# 2Ô∏è‚É£ Min +1 attributes (max ‚àí1)
# ------------------------------------------------------------
min_idx = np.argmin(attr_matrix.sum(axis=1))
min_row = test_df.iloc[min_idx]
min_attrs = [a for a, v in zip(attr_cols, attr_matrix[min_idx]) if v == 1]

# ------------------------------------------------------------
# 3Ô∏è‚É£ Bald + Mustache = 1
# ------------------------------------------------------------
bald_i = attr_cols.index("Bald")
mustache_i = attr_cols.index("Mustache")

mask = (attr_matrix[:, bald_i] == 1) & (attr_matrix[:, mustache_i] == 1)
bm_idx = np.where(mask)[0][0]   # deterministic first match
bm_row = test_df.iloc[bm_idx]
bm_attrs = [a for a, v in zip(attr_cols, attr_matrix[bm_idx]) if v == 1]

# ------------------------------------------------------------
# Pretty print
# ------------------------------------------------------------
def print_block(title, row, attrs):
    print("\n" + "="*60)
    print(title)
    print("="*60)
    print("Image ID :", row[img_col])
    print("Num +1 attrs:", len(attrs))
    print("Attributes:")
    for a in attrs:
        print("  ‚Ä¢", a)

print_block("üî∫ MAX ATTRIBUTES EXPERIMENT", max_row, max_attrs)
print_block("üîª MIN ATTRIBUTES EXPERIMENT", min_row, min_attrs)
print_block("üßî BALD + MUSTACHE EXPERIMENT", bm_row, bm_attrs)



üî∫ MAX ATTRIBUTES EXPERIMENT
Image ID : 184620.jpg
Num +1 attrs: 20
Attributes:
  ‚Ä¢ Arched_Eyebrows
  ‚Ä¢ Bags_Under_Eyes
  ‚Ä¢ Big_Lips
  ‚Ä¢ Big_Nose
  ‚Ä¢ Black_Hair
  ‚Ä¢ Chubby
  ‚Ä¢ Double_Chin
  ‚Ä¢ Heavy_Makeup
  ‚Ä¢ High_Cheekbones
  ‚Ä¢ Mouth_Slightly_Open
  ‚Ä¢ Narrow_Eyes
  ‚Ä¢ No_Beard
  ‚Ä¢ Receding_Hairline
  ‚Ä¢ Rosy_Cheeks
  ‚Ä¢ Smiling
  ‚Ä¢ Wavy_Hair
  ‚Ä¢ Wearing_Earrings
  ‚Ä¢ Wearing_Lipstick
  ‚Ä¢ Wearing_Necklace
  ‚Ä¢ Young

üîª MIN ATTRIBUTES EXPERIMENT
Image ID : 183075.jpg
Num +1 attrs: 1
Attributes:
  ‚Ä¢ No_Beard

üßî BALD + MUSTACHE EXPERIMENT
Image ID : 182797.jpg
Num +1 attrs: 15
Attributes:
  ‚Ä¢ 5_o_Clock_Shadow
  ‚Ä¢ Bald
  ‚Ä¢ Big_Lips
  ‚Ä¢ Big_Nose
  ‚Ä¢ Double_Chin
  ‚Ä¢ Eyeglasses
  ‚Ä¢ Goatee
  ‚Ä¢ High_Cheekbones
  ‚Ä¢ Male
  ‚Ä¢ Mouth_Slightly_Open
  ‚Ä¢ Mustache
  ‚Ä¢ Receding_Hairline
  ‚Ä¢ Smiling
  ‚Ä¢ Wearing_Necktie
  ‚Ä¢ Young
