# Normalized Attention Guidance (NAG) — Reproducing Paper Experiments

This notebook is a **reproducible experiment driver** for the paper **"Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models"** using the authors' official implementation:

- Repo: https://github.com/ChenDarYen/Normalized-Attention-Guidance

It covers:
- Environment setup + installing the repo
- Loading NAG pipelines (Flux / SD3.5 / SDXL)
- Running **baseline vs NAG** generation on a prompt set (e.g., COCO-5K)
- Computing the paper's metrics: **CLIP Score, FID, Patch-FID (PFID), ImageReward**
- Saving outputs + aggregating results into a table

> Notes  
> - Some model IDs are **gated** on Hugging Face (Flux, etc.). You may need an HF token with access.  
> - **FID/PFID** require a reference set of real images (COCO val images). If you don't download COCO images, you can still compute CLIP Score & ImageReward.

---


## 0) (Optional) GPU sanity check

In [30]:
import torch, platform
print("Python:", platform.python_version())
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("Capability:", torch.cuda.get_device_capability(0))
    print("VRAM (GB):", round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2))


Python: 3.11.0
Torch: 2.9.1+cu128
CUDA available: True
GPU: NVIDIA A100-SXM4-40GB
Capability: (8, 0)
VRAM (GB): 39.49


In [31]:
# DEVICE CONFIGURATION - Set your device here
DEVICE = "cuda:2"

## 2) Imports + utilities

In [34]:
import os
import json
import math
import time
import random
import shutil
import hashlib
from dataclasses import dataclass, replace
from collections import defaultdict
from pathlib import Path
from typing import List, Dict, Optional, Tuple

from dotenv import load_dotenv
load_dotenv()

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
from PIL import Image

# NAG pipelines (from the repo we installed)
from nag import (
    NAGFluxPipeline,
    NAGFluxTransformer2DModel,
    NAGStableDiffusion3Pipeline,
    NAGStableDiffusionXLPipeline,
)

# SDXL helper deps (optional for some experiments)
from diffusers import UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download

def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def ensure_dir(p: str | Path) -> Path:
    p = Path(p)
    p.mkdir(parents=True, exist_ok=True)
    return p

def save_image(img: Image.Image, path: str | Path):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    img.save(path)

def sha1_of_file(path: str | Path, chunk_size: int = 1 << 20) -> str:
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


In [None]:
HF_TOKEN = os.environ.get("HF_TOKEN", None)
print("HF token found in env:", bool(HF_TOKEN))

HF token found in env: True


## 3) Prompt set: COCO-5K loader (captions)

In [None]:
import requests

COCO_ANN_URL = "http://images.cocodataset.org/annotations/annotations_trainval2014.zip"
COCO_IMG_URL = "http://images.cocodataset.org/zips/val2014.zip"

DATA_DIR = ensure_dir("data")
COCO_DIR = ensure_dir(DATA_DIR / "coco2014")
PROMPTS_PATH = COCO_DIR / "coco5k_prompts.txt"

def _download(url: str, out_path: Path):
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists():
        print(f"Already exists: {out_path}")
        return
    print(f"Downloading: {url}")
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        total = int(r.headers.get("content-length", 0))
        with open(out_path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True) as pbar:
            for chunk in r.iter_content(chunk_size=1<<20):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))

def _unzip(zip_path: Path, dest: Path):
    import zipfile
    print(f"Extracting {zip_path} -> {dest}")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(dest)

def prepare_coco_annotations():
    """Downloads COCO 2014 annotations and extracts captions_val2014.json."""
    zip_path = COCO_DIR / "annotations_trainval2014.zip"
    _download(COCO_ANN_URL, zip_path)
    _unzip(zip_path, COCO_DIR)
    ann_path = COCO_DIR / "annotations" / "captions_val2014.json"
    if not ann_path.exists():
        raise FileNotFoundError(f"Missing {ann_path}. Extraction may have failed.")
    return ann_path

def make_coco5k_prompts(seed: int = 0, n: int = 5000) -> List[str]:
    """Creates a fixed prompt list from COCO val2014 captions."""
    ann_path = prepare_coco_annotations()
    with open(ann_path, "r", encoding="utf-8") as f:
        d = json.load(f)
    captions = [a["caption"].strip() for a in d["annotations"]]
    # Paper uses COCO-5K; we sample deterministically.
    rng = random.Random(seed)
    rng.shuffle(captions)
    prompts = captions[:n]
    PROMPTS_PATH.write_text("\n".join(prompts), encoding="utf-8")
    print(f"Wrote {len(prompts)} prompts -> {PROMPTS_PATH}")
    return prompts

def load_prompts(path: Path = PROMPTS_PATH) -> List[str]:
    if path.exists():
        return [line.strip() for line in path.read_text(encoding='utf-8').splitlines() if line.strip()]
    print("Prompt file not found. Creating COCO-5K prompt list now...")
    return make_coco5k_prompts(seed=0, n=5000)

# Smoke test: load first 5 prompts
prompts = load_prompts()[:5]
prompts


['A MAN STANDING ON A BED PLAYING A GUITAR',
 'A man stands on the platform with his back turned to the train.',
 'A bus stopped near the sidewalk at a bus stop.',
 'A man carrying a surf board on a sandy beach next to the ocean.',
 'A large jetliner flying through a cloudy gray sky.']

## 4) (Optional) Reference images for FID/PFID (COCO val2014)

In [None]:
# WARNING: This is a large download (~6GB) and requires disk space.
# If you don't need FID/PFID, you can skip this.

VAL_IMG_DIR = COCO_DIR / "val2014"

def prepare_coco_val_images():
    zip_path = COCO_DIR / "val2014.zip"
    _download(COCO_IMG_URL, zip_path)
    _unzip(zip_path, COCO_DIR)
    if not VAL_IMG_DIR.exists():
        raise FileNotFoundError(f"Expected {VAL_IMG_DIR} after extraction.")
    print("COCO val images ready:", VAL_IMG_DIR)

# Uncomment to download:
# prepare_coco_val_images()

## 5) Experiment configs (baseline vs NAG)

In [None]:

# Universal negative prompt used in the paper's quantitative comparison.
UNIVERSAL_NEG_PROMPT = "Low resolution, blurry"

# Paper Table-5 hyperparameters (ϕ = nag_scale, τ = nag_tau, α = nag_alpha).
# You can adjust if you want to match the repo demos (they often use nag_scale=5 for Flux).
PAPER_DEFAULTS = {
    # Few-step DiT
    "flux.1-schnell": dict(nag_scale=4.0, nag_tau=2.5, nag_alpha=0.25, steps=4, guidance_scale=0.0),
    "flux.1-dev":     dict(nag_scale=4.0, nag_tau=2.5, nag_alpha=0.25, steps=25, guidance_scale=0.0),
    # Multi-step UNet (SDXL)
    "sdxl":           dict(nag_scale=2.0, nag_tau=2.5, nag_alpha=0.5, steps=8, guidance_scale=2.0),
    # SD3.5 turbo (few-step-ish)
    "sd3.5-large-turbo": dict(nag_scale=3.0, nag_tau=2.5, nag_alpha=0.25, steps=8, guidance_scale=0.0),
}

@dataclass
class RunSpec:
    name: str
    model_type: str               # "flux", "sdxl", "sd3"
    model_id: str
    steps: int
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 0.0   # CFG; many distilled models use 0 or 1
    max_sequence_length: int = 256
    dtype: torch.dtype = torch.bfloat16

# Suggested experiment set (edit as needed)
RUNS: Dict[str, RunSpec] = {
    "flux_schnell_4step": RunSpec(
        name="flux_schnell_4step",
        model_type="flux",
        model_id="black-forest-labs/FLUX.1-schnell",
        steps=PAPER_DEFAULTS["flux.1-schnell"]["steps"],
        width=1024, height=1024,
        guidance_scale=PAPER_DEFAULTS["flux.1-schnell"]["guidance_scale"],
        dtype=torch.bfloat16,
    ),
    "sd3_5_turbo_8step": RunSpec(
        name="sd3_5_turbo_8step",
        model_type="sd3",
        model_id="stabilityai/stable-diffusion-3.5-large-turbo",
        steps=PAPER_DEFAULTS["sd3.5-large-turbo"]["steps"],
        width=1024, height=1024,
        guidance_scale=PAPER_DEFAULTS["sd3.5-large-turbo"]["guidance_scale"],
        dtype=torch.bfloat16,
    ),
    # SDXL example (can be heavy). In the paper they test SDXL on multi-step and also DMD2 4-step.
    # Below is a MULTI-STEP SDXL run on the base model; it may not exactly match paper's distilled checkpoints unless you load them.
    "sdxl_base_8step": RunSpec(
        name="sdxl_base_8step",
        model_type="sdxl",
        model_id="stabilityai/stable-diffusion-xl-base-1.0",
        steps=PAPER_DEFAULTS["sdxl"]["steps"],
        width=1024, height=1024,
        guidance_scale=PAPER_DEFAULTS["sdxl"]["guidance_scale"],
        dtype=torch.bfloat16,
    ),
}

RUNS


{'flux_schnell_4step': RunSpec(name='flux_schnell_4step', model_type='flux', model_id='black-forest-labs/FLUX.1-schnell', steps=4, width=1024, height=1024, guidance_scale=0.0, max_sequence_length=256, dtype=torch.bfloat16),
 'sd3_5_turbo_8step': RunSpec(name='sd3_5_turbo_8step', model_type='sd3', model_id='stabilityai/stable-diffusion-3.5-large-turbo', steps=8, width=1024, height=1024, guidance_scale=0.0, max_sequence_length=256, dtype=torch.bfloat16),
 'sdxl_base_8step': RunSpec(name='sdxl_base_8step', model_type='sdxl', model_id='stabilityai/stable-diffusion-xl-base-1.0', steps=8, width=1024, height=1024, guidance_scale=2.0, max_sequence_length=256, dtype=torch.bfloat16)}

## 6) Load pipelines

In [None]:
def load_pipeline(run: RunSpec, hf_token: Optional[str] = None, device: str = DEVICE):
    """Loads a pipeline for the given run spec."""
    if run.model_type == "flux":
        transformer = NAGFluxTransformer2DModel.from_pretrained(
            run.model_id,
            subfolder="transformer",
            torch_dtype=run.dtype,
            token=hf_token,
        )
        pipe = NAGFluxPipeline.from_pretrained(
            run.model_id,
            transformer=transformer,
            torch_dtype=run.dtype,
            token=hf_token,
        )
        pipe.to(device)
        return pipe

    if run.model_type == "sd3":
        pipe = NAGStableDiffusion3Pipeline.from_pretrained(
            run.model_id,
            torch_dtype=run.dtype,
            token=hf_token,
        )
        pipe.to(device)
        return pipe

    if run.model_type == "sdxl":
        pipe = NAGStableDiffusionXLPipeline.from_pretrained(
            run.model_id,
            torch_dtype=run.dtype,
            variant="fp16" if run.dtype in (torch.float16, torch.bfloat16) else None,
        ).to(device)
        return pipe

    raise ValueError(f"Unknown model_type: {run.model_type}")

## 7) Image generation (baseline vs NAG)

In [None]:
@torch.inference_mode()
def generate_batch(
    pipe,
    run: RunSpec,
    prompts: List[str],
    *,
    out_dir: Path,
    seed: int,
    mode: str,  # "baseline" or "nag"
    nag_negative_prompt: str = UNIVERSAL_NEG_PROMPT,
    nag_scale: float = 4.0,
    nag_tau: float = 2.5,
    nag_alpha: float = 0.25,
    batch_size: int = 1,
    device: str = DEVICE,
):
    """Generate images and save them. Returns list of saved file paths."""
    assert mode in {"baseline", "nag"}
    out_dir = ensure_dir(out_dir)

    saved = []
    set_seed(seed)

    gen = torch.Generator(device=device).manual_seed(seed)

    for i in tqdm(range(0, len(prompts), batch_size), desc=f"gen[{mode}]"):
        batch_prompts = prompts[i:i+batch_size]

        kwargs = dict(
            prompt=batch_prompts,
            num_inference_steps=run.steps,
            guidance_scale=run.guidance_scale,
            generator=gen,
        )

        # Some models (Flux) expose max_sequence_length
        if "max_sequence_length" in pipe.__call__.__code__.co_varnames:
            kwargs["max_sequence_length"] = run.max_sequence_length

        # Image size (some pipelines accept width/height; others might infer)
        if "width" in pipe.__call__.__code__.co_varnames:
            kwargs["width"] = run.width
        if "height" in pipe.__call__.__code__.co_varnames:
            kwargs["height"] = run.height

        if mode == "nag":
            kwargs.update(
                nag_negative_prompt=nag_negative_prompt,
                nag_scale=nag_scale,
                nag_tau=nag_tau,
                nag_alpha=nag_alpha,
            )

        out = pipe(**kwargs)
        images = out.images if hasattr(out, "images") else out.frames[0]

        for j, img in enumerate(images):
            idx = i + j
            fn = out_dir / f"{idx:05d}.png"
            save_image(img, fn)
            saved.append(fn)

    return saved

def get_nag_params_for_run(run: RunSpec):
    if run.model_type == "flux":
        d = PAPER_DEFAULTS["flux.1-schnell"] if "schnell" in run.model_id.lower() else PAPER_DEFAULTS["flux.1-dev"]
    elif run.model_type == "sd3":
        d = PAPER_DEFAULTS["sd3.5-large-turbo"]
    elif run.model_type == "sdxl":
        d = PAPER_DEFAULTS["sdxl"]
    else:
        raise ValueError(run.model_type)
    return d["nag_scale"], d["nag_tau"], d["nag_alpha"]


## 8) Quick smoke test (generate a few images)

In [None]:
# Choose a run:
run_key = "sdxl_base_8step"
run = RUNS[run_key]

# Small subset for smoke test
test_prompts = load_prompts()[:8]
seed = 0

pipe = load_pipeline(run, hf_token=HF_TOKEN)

nag_scale, nag_tau, nag_alpha = get_nag_params_for_run(run)

OUT_ROOT = ensure_dir("outputs") / run.name
baseline_dir = OUT_ROOT / "baseline"
nag_dir = OUT_ROOT / "nag"

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


In [None]:
baseline_paths = generate_batch(
    pipe, run, test_prompts,
    out_dir=baseline_dir,
    seed=seed,
    mode="baseline",
    batch_size=1,
)

nag_paths = generate_batch(
    pipe, run, test_prompts,
    out_dir=nag_dir,
    seed=seed,
    mode="nag",
    nag_negative_prompt=UNIVERSAL_NEG_PROMPT,
    nag_scale=nag_scale,
    nag_tau=nag_tau,
    nag_alpha=nag_alpha,
    batch_size=1,
)

baseline_paths[:2], nag_paths[:2]


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

gen[baseline]:   0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

gen[nag]:   0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

([PosixPath('outputs/sdxl_base_8step/baseline/00000.png'),
  PosixPath('outputs/sdxl_base_8step/baseline/00001.png')],
 [PosixPath('outputs/sdxl_base_8step/nag/00000.png'),
  PosixPath('outputs/sdxl_base_8step/nag/00001.png')])

## 9) Metrics: CLIP Score, FID, Patch-FID (PFID), ImageReward

In [None]:
from torchvision import transforms
from torchmetrics.multimodal import CLIPScore
from torchmetrics.image.fid import FrechetInceptionDistance

# --- CLIP Score ---
def compute_clip_score(image_paths: List[Path], prompts: List[str], device=DEVICE, clip_model="openai/clip-vit-base-patch32"):
    metric = CLIPScore(model_name_or_path=clip_model).to(device)
    # torchmetrics expects uint8 images [0, 255] as a tensor Bx3xHxW
    to_tensor = transforms.Compose([
        transforms.ToTensor(),  # 0..1 float
        transforms.Lambda(lambda x: (x * 255).to(torch.uint8)),
    ])

    scores = []
    for p, txt in tqdm(list(zip(image_paths, prompts)), desc="CLIPScore"):
        img = Image.open(p).convert("RGB")
        img_t = to_tensor(img).unsqueeze(0).to(device)
        s = metric(img_t, [txt]).item()
        scores.append(s)
    return float(np.mean(scores)), scores

# --- FID ---
def _iter_images_as_uint8_tensor(image_paths: List[Path], size=(299, 299)):
    tfm = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (x * 255).to(torch.uint8)),
    ])
    for p in image_paths:
        img = Image.open(p).convert("RGB")
        yield tfm(img)

def compute_fid(gen_paths: List[Path], real_paths: List[Path], device=DEVICE):
    fid = FrechetInceptionDistance(feature=2048, reset_real_features=True).to(device)

    for t in tqdm(_iter_images_as_uint8_tensor(real_paths), total=len(real_paths), desc="FID(real)"):
        fid.update(t.unsqueeze(0).to(device), real=True)

    for t in tqdm(_iter_images_as_uint8_tensor(gen_paths), total=len(gen_paths), desc="FID(gen)"):
        fid.update(t.unsqueeze(0).to(device), real=False)

    return float(fid.compute().item())

# --- Patch-FID (PFID) ---
def _extract_patches(img: Image.Image, grid: int = 4) -> List[Image.Image]:
    """Split image into grid x grid patches."""
    w, h = img.size
    pw, ph = w // grid, h // grid
    patches = []
    for gy in range(grid):
        for gx in range(grid):
            left, top = gx * pw, gy * ph
            patches.append(img.crop((left, top, left + pw, top + ph)))
    return patches

def compute_patch_fid(gen_paths: List[Path], real_paths: List[Path], device=DEVICE, grid: int = 4):
    fid = FrechetInceptionDistance(feature=2048, reset_real_features=True).to(device)

    # update real patches
    for p in tqdm(real_paths, desc="PFID(real)", total=len(real_paths)):
        img = Image.open(p).convert("RGB")
        for patch in _extract_patches(img, grid=grid):
            t = transforms.ToTensor()(patch)
            t = (t * 255).to(torch.uint8)
            fid.update(t.unsqueeze(0).to(device), real=True)

    # update gen patches
    for p in tqdm(gen_paths, desc="PFID(gen)", total=len(gen_paths)):
        img = Image.open(p).convert("RGB")
        for patch in _extract_patches(img, grid=grid):
            t = transforms.ToTensor()(patch)
            t = (t * 255).to(torch.uint8)
            fid.update(t.unsqueeze(0).to(device), real=False)

    return float(fid.compute().item())

# --- ImageReward ---
# def compute_imagereward(image_paths: List[Path], prompts: List[str], device=DEVICE):
#     import ImageReward as IR
#     model = IR.load("ImageReward-v1.0", device=device)
#     scores = []
#     for p, txt in tqdm(list(zip(image_paths, prompts)), desc="ImageReward"):
#         img = Image.open(p).convert("RGB")
#         s = model.score(txt, img)
#         scores.append(float(s))
#     return float(np.mean(scores)), scores


## 10) Run evaluation on your generated folders

In [None]:
def list_images_sorted(folder: Path) -> List[Path]:
    return sorted([p for p in folder.glob("*.png")])

# Use the smoke test output folders by default:
gen_baseline = list_images_sorted(baseline_dir)
gen_nag = list_images_sorted(nag_dir)

eval_prompts = test_prompts  # must match image ordering

print("n(baseline):", len(gen_baseline), "n(nag):", len(gen_nag), "n(prompts):", len(eval_prompts))

# ---- CLIP score ----
clip_mean_base, _ = compute_clip_score(gen_baseline, eval_prompts)
clip_mean_nag, _ = compute_clip_score(gen_nag, eval_prompts)

# ---- ImageReward ----
# ir_mean_base, _ = compute_imagereward(gen_baseline, eval_prompts)
# ir_mean_nag, _ = compute_imagereward(gen_nag, eval_prompts)

results = pd.DataFrame([
    dict(run=run.name, mode="baseline", clip_score=clip_mean_base,
    # image_reward=ir_mean_base
),
    dict(run=run.name, mode="nag",      clip_score=clip_mean_nag,
    # image_reward=ir_mean_nag
),
])
results


n(baseline): 8 n(nag): 8 n(prompts): 8


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


CLIPScore:   0%|          | 0/8 [00:00<?, ?it/s]

CLIPScore:   0%|          | 0/8 [00:00<?, ?it/s]

Unnamed: 0,run,mode,clip_score
0,sdxl_base_8step,baseline,29.505333
1,sdxl_base_8step,nag,30.482667


## 11) (Optional) FID / PFID evaluation (requires COCO real images)

In [None]:

# This section requires COCO validation images downloaded in section 4.
# We align on the *same number of images* as generated (e.g., 5k).
# NOTE: COCO val2014 has 40k images; paper uses COCO-5K subset. You should define which 5k images you use.

from glob import glob

def list_coco_real_images(val_dir: Path, n: int) -> List[Path]:
    all_imgs = sorted(val_dir.glob("*.jpg"))
    if len(all_imgs) < n:
        raise ValueError(f"Not enough COCO images: {len(all_imgs)} < {n}")
    return all_imgs[:n]  # deterministic prefix; change if you want a different fixed subset


real_paths = list_coco_real_images(VAL_IMG_DIR, n=len(gen_baseline))

fid_base = compute_fid(gen_baseline, real_paths)
fid_nag = compute_fid(gen_nag, real_paths)

pfid_base = compute_patch_fid(gen_baseline, real_paths, grid=4)
pfid_nag = compute_patch_fid(gen_nag, real_paths, grid=4)

results_fid = pd.DataFrame([
    dict(run=run.name, mode="baseline", fid=fid_base, pfid=pfid_base),
    dict(run=run.name, mode="nag",      fid=fid_nag,  pfid=pfid_nag),
])
results_fid

FID(real):   0%|          | 0/8 [00:00<?, ?it/s]

FID(gen):   0%|          | 0/8 [00:00<?, ?it/s]

FID(real):   0%|          | 0/8 [00:00<?, ?it/s]

FID(gen):   0%|          | 0/8 [00:00<?, ?it/s]

PFID(real):   0%|          | 0/8 [00:00<?, ?it/s]

PFID(gen):   0%|          | 0/8 [00:00<?, ?it/s]

PFID(real):   0%|          | 0/8 [00:00<?, ?it/s]

PFID(gen):   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
results_fid

Unnamed: 0,run,mode,fid,pfid
0,sdxl_base_8step,baseline,353.235657,321.943604
1,sdxl_base_8step,nag,350.996185,324.41391


## 12) Full experiment runner (COCO-5K)

In [None]:

def run_experiment(
    run: RunSpec,
    *,
    prompts: List[str],
    out_root: Path,
    seed: int = 0,
    hf_token: Optional[str] = None,
    batch_size: int = 1,
    nag_negative_prompt: str = UNIVERSAL_NEG_PROMPT,
):
    out_root = ensure_dir(out_root)
    pipe = load_pipeline(run, hf_token=hf_token)

    nag_scale, nag_tau, nag_alpha = get_nag_params_for_run(run)

    baseline_dir = out_root / "baseline"
    nag_dir = out_root / "nag"

    base_paths = generate_batch(
        pipe, run, prompts,
        out_dir=baseline_dir,
        seed=seed,
        mode="baseline",
        batch_size=batch_size,
    )

    nag_paths = generate_batch(
        pipe, run, prompts,
        out_dir=nag_dir,
        seed=seed,
        mode="nag",
        nag_negative_prompt=nag_negative_prompt,
        nag_scale=nag_scale,
        nag_tau=nag_tau,
        nag_alpha=nag_alpha,
        batch_size=batch_size,
    )

    # Always compute CLIP + ImageReward (no ref needed)
    clip_base, _ = compute_clip_score(base_paths, prompts)
    clip_nag, _ = compute_clip_score(nag_paths, prompts)

    ir_base, _ = compute_imagereward(base_paths, prompts)
    ir_nag, _ = compute_imagereward(nag_paths, prompts)

    metrics = pd.DataFrame([
        dict(run=run.name, mode="baseline", clip_score=clip_base, image_reward=ir_base),
        dict(run=run.name, mode="nag",      clip_score=clip_nag,  image_reward=ir_nag),
    ])

    # Optionally compute FID/PFID if real images exist
    if VAL_IMG_DIR.exists():
        real_paths = list_coco_real_images(VAL_IMG_DIR, n=len(prompts))
        fid_base = compute_fid(base_paths, real_paths)
        fid_nag = compute_fid(nag_paths, real_paths)
        pfid_base = compute_patch_fid(base_paths, real_paths, grid=4)
        pfid_nag = compute_patch_fid(nag_paths, real_paths, grid=4)
        metrics.loc[metrics["mode"]=="baseline", "fid"] = fid_base
        metrics.loc[metrics["mode"]=="nag", "fid"] = fid_nag
        metrics.loc[metrics["mode"]=="baseline", "pfid"] = pfid_base
        metrics.loc[metrics["mode"]=="nag", "pfid"] = pfid_nag

    metrics.to_csv(out_root / "metrics.csv", index=False)
    return metrics

# Example: run on 128 prompts first, then scale up
# prompts_5k = load_prompts()[:5000]
# metrics = run_experiment(RUNS["flux_schnell_4step"], prompts=prompts_5k, out_root=Path("outputs")/RUNS["flux_schnell_4step"].name, hf_token=HF_TOKEN, batch_size=1)
# metrics


## 14) Tips for matching the paper more closely


- Use the **same prompt list** and **same real-image subset** when reporting FID/PFID.  
  In this notebook we create a deterministic prompt list (seeded shuffle) and select a deterministic prefix of COCO val images — but you should align these with the paper's exact protocol if you have it.

- Use the **paper hyperparameters** (ϕ, τ, α) from Table-5:
  - Flux: ϕ=4, τ=2.5, α=0.25  
  - SDXL: ϕ=2, τ=2.5, α=0.25  
  - SD3.5: ϕ=3, τ=2.5, α=0.25  

- Many few-step models are designed for **guidance_scale=0 or 1**; CFG can degrade them.  
  NAG is intended to restore negative prompting without requiring CFG.

- For large-scale runs (COCO-5K), consider:
  - `batch_size=1` for safety (VRAM)
  - saving intermediate progress
  - running on multiple GPUs / nodes with prompt sharding

---


In [36]:
@dataclass(frozen=True)
class PipelineConfig:
    name: str
    model_type: str
    model_id: str
    steps: int # Number of inference steps
    guidance_scale: float = 1 # Guidance scale
    nag_scale: float = 1 # NAG scale
    nag_alpha: float = 0.125 # NAG alpha
    nag_tau: float = 2.5 # NAG tau
    max_sequence_length: int = 256 # Maximum sequence length
    width: int = 1024 # Image width
    height: int = 1024 # Image height
    seed: int = 42 # Random seed
    dtype: torch.dtype = torch.bfloat16 # Data type

In [None]:
sd3_config = PipelineConfig(
    name="sd3_5_turbo_8step",
    model_type="sd3",
    model_id="stabilityai/stable-diffusion-3.5-large-turbo",
    steps=8,
    guidance_scale=3,
    nag_scale=4,
    nag_alpha=0.125,
    nag_tau=2.5,
)

sdxl_config = PipelineConfig(
    name="sdxl_base_8step",
    model_type="sdxl",
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    steps=8,
    guidance_scale=4,
    nag_scale=2,
    nag_alpha=0.5,
    nag_tau=2.5,
)

## Exp: changing nag scale $\Phi$

In [None]:
nag_scale_arr = [0, 2.5, 5, 7.5, 10, 15, 20]

n_prompts = 100
nag_neg_prompt = UNIVERSAL_NEG_PROMPT

run = replace(sdxl_config, name="sdxl_base_8step", guidance_scale=0.0)

test_prompts = load_prompts()[:n_prompts]

OUT_ROOT = ensure_dir("outputs") / run.name
exp_dir = OUT_ROOT / "exp_nag_scale"

real_paths = list_coco_real_images(VAL_IMG_DIR, n=n_prompts)

try:
    del pipe
    torch.cuda.empty_cache()
except NameError:
    pass

pipe = load_pipeline(run, hf_token=HF_TOKEN)

results_1 = defaultdict(list)
for i, nag_scale in enumerate(nag_scale_arr):
    print(f"Running with NAG scale: {nag_scale}")
    
    nag_paths = generate_batch(
        pipe, run, test_prompts,
        out_dir=exp_dir,
        seed=seed,
        mode="nag",
        nag_negative_prompt=nag_neg_prompt,
        nag_scale=nag_scale,
        nag_tau=nag_tau,
        nag_alpha=nag_alpha,
        batch_size=1,
    )

    gen_images = list_images_sorted(exp_dir)

    clip_mean, _ = compute_clip_score(gen_images, test_prompts)

    fid_base = compute_fid(gen_images, real_paths)
    pfid_base = compute_patch_fid(gen_images, real_paths, grid=4)

    results_1["nag_scale"].append(nag_scale)
    results_1["clip_mean"].append(clip_mean)
    results_1["fid_base"].append(fid_base)
    results_1["pfid_base"].append(pfid_base)

    print(f"NAG scale: {nag_scale}, CLIP score: {clip_mean}, FID: {fid_base}, PFID: {pfid_base}")



In [35]:
results

defaultdict(list,
            {'nag_scale': [0, 2.5, 5, 7.5, 10, 15],
             'clip_mean': [24.902287969589235,
              25.268795852661132,
              25.284650869369507,
              25.30036262512207,
              25.309369144439696,
              25.285631084442137],
             'fid_base': [321.5397644042969,
              319.117431640625,
              318.7151184082031,
              318.9698791503906,
              318.3301696777344,
              318.1114196777344],
             'pfid_base': [237.57139587402344,
              234.54177856445312,
              234.5036163330078,
              234.3966522216797,
              234.47134399414062,
              234.61070251464844]})

## Exp: changing nag scale $\Phi$ w/o Refine

In [None]:
nag_scale_arr = [0, 2.5, 5, 7.5, 10, 15, 20]

n_prompts = 100
nag_neg_prompt = UNIVERSAL_NEG_PROMPT

run = replace(sdxl_config, name="sdxl_base_8step", guidance_scale=0.0, nag_alpha=0.0)

test_prompts = load_prompts()[:n_prompts]

OUT_ROOT = ensure_dir("outputs") / run.name
exp_dir = OUT_ROOT / "exp_nag_scale"

real_paths = list_coco_real_images(VAL_IMG_DIR, n=n_prompts)

try:
    del pipe
    torch.cuda.empty_cache()
except NameError:
    pass

pipe = load_pipeline(run, hf_token=HF_TOKEN)

results_2 = defaultdict(list)
for i, nag_scale in enumerate(nag_scale_arr):
    print(f"Running with NAG scale: {nag_scale}")
    
    nag_paths = generate_batch(
        pipe, run, test_prompts,
        out_dir=exp_dir,
        seed=seed,
        mode="nag",
        nag_negative_prompt=nag_neg_prompt,
        nag_scale=nag_scale,
        nag_tau=nag_tau,
        nag_alpha=nag_alpha,
        batch_size=1,
    )

    gen_images = list_images_sorted(exp_dir)

    clip_mean, _ = compute_clip_score(gen_images, test_prompts)

    fid_base = compute_fid(gen_images, real_paths)
    pfid_base = compute_patch_fid(gen_images, real_paths, grid=4)

    results_2["nag_scale"].append(nag_scale)
    results_2["clip_mean"].append(clip_mean)
    results_2["fid_base"].append(fid_base)
    results_2["pfid_base"].append(pfid_base)

    print(f"NAG scale: {nag_scale}, CLIP score: {clip_mean}, FID: {fid_base}, PFID: {pfid_base}")



## Exp: changing nag scale $\Phi$ w/o Refine & Norm

In [None]:
nag_scale_arr = [0, 2.5, 5, 7.5, 10, 15, 20]

n_prompts = 100
nag_neg_prompt = UNIVERSAL_NEG_PROMPT

run = replace(sdxl_config, name="sdxl_base_8step",
    guidance_scale=0.0, nag_alpha=0.0, nag_tau=1000)

test_prompts = load_prompts()[:n_prompts]

OUT_ROOT = ensure_dir("outputs") / run.name
exp_dir = OUT_ROOT / "exp_nag_scale"

real_paths = list_coco_real_images(VAL_IMG_DIR, n=n_prompts)

try:
    del pipe
    torch.cuda.empty_cache()
except NameError:
    pass

pipe = load_pipeline(run, hf_token=HF_TOKEN)

results_3 = defaultdict(list)
for i, nag_scale in enumerate(nag_scale_arr):
    print(f"Running with NAG scale: {nag_scale}")
    
    nag_paths = generate_batch(
        pipe, run, test_prompts,
        out_dir=exp_dir,
        seed=seed,
        mode="nag",
        nag_negative_prompt=nag_neg_prompt,
        nag_scale=nag_scale,
        nag_tau=nag_tau,
        nag_alpha=nag_alpha,
        batch_size=1,
    )

    gen_images = list_images_sorted(exp_dir)

    clip_mean, _ = compute_clip_score(gen_images, test_prompts)

    fid_base = compute_fid(gen_images, real_paths)
    pfid_base = compute_patch_fid(gen_images, real_paths, grid=4)

    results_3["nag_scale"].append(nag_scale)
    results_3["clip_mean"].append(clip_mean)
    results_3["fid_base"].append(fid_base)
    results_3["pfid_base"].append(pfid_base)

    print(f"NAG scale: {nag_scale}, CLIP score: {clip_mean}, FID: {fid_base}, PFID: {pfid_base}")

