# Evaluation Notebook: SD‑1.5 + ControlNet (Canny) + IP‑Adapter (Colour Head)

This notebook reproduces the `evaluate.py` functionality in Jupyter form.

**What it does**
1. Loads your manifest, CLIP embeddings `[N,1024]`, and colour histograms `[N,K]` (RGB512 / Lab514 / HCL514).
2. Builds SD‑1.5 + ControlNet(Canny) and your forked IP‑Adapter that accepts a **custom colour embedding**.
3. Runs the **2×2 ablation** (Text / +Edges / +Colour / +Both) and the **scale grid** (IP scale × CN scale).
4. Computes metrics: **Colour‑EMD (Sinkhorn)**, **neutral‑bin errors** (Lab/HCL), **Edge‑F1**, **Edge‑SSIM**, **CLIP‑Score**, **latency**.
5. Saves images, montages, and a per‑image **metrics CSV**.

> **Note:** This expects your **forked `ip_adapter`** module to be importable and to support `embedding_type='custom'` so you can pass the Colour Head embedding directly.


## 0) Prerequisites (install as needed)
Uncomment and run the lines below if your environment is missing dependencies.

```bash
# pip install --upgrade pip
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121  # adapt to your CUDA
# pip install diffusers accelerate transformers safetensors
# pip install open-clip-torch
# pip install geomloss scikit-image opencv-python pandas matplotlib tqdm
```

Make sure your project folder (with the forked `ip_adapter` and your `models` package) is on `sys.path`.


In [ ]:
import os, sys, time, math, json
from types import SimpleNamespace
from pathlib import Path
import itertools
import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageDraw

import torch
import torch.nn.functional as F

import cv2
from skimage.metrics import structural_similarity as ssim

# If your project repo is not in sys.path, add it here, e.g.:
# sys.path.append('/data/thesis/repo')

try:
    from ip_adapter import IPAdapter  # your fork
except Exception as e:
    print("⚠️ Could not import ip_adapter. Add your repo to sys.path.")
    raise

try:
    from models.color_heads import ColorHead
except Exception:
    # fallback to alternative path if your module layout differs
    try:
        from color_heads import ColorHead
    except Exception as e:
        print("⚠️ Could not import ColorHead. Adjust the import to your repo layout.")
        raise

try:
    from geomloss import SamplesLoss
    HAVE_GEOMLOSS = True
except Exception:
    HAVE_GEOMLOSS = False

import open_clip
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline


## 1) Configuration
Fill in your paths and knobs below. To compare configurations, re-run this cell with different settings and a different `outdir` so that each run writes to its own output folder.

In [ ]:
# Single configuration object read by the loading and evaluation cells below.
# Paths are machine-specific; edit them before running.
cfg = SimpleNamespace(
    # Data
    manifest_csv = "/data/thesis/laion_5m_manifest.csv",  # must have file_path column (or local_path)
    embeddings_npy = "/data/degis/embeddings.npy",        # [N,1024]
    hists_npy = "/data/degis/hcl514.npy",                 # [N,514] for HCL / Lab; [N,512] for RGB
    color_space = "hcl514",                                # one of: rgb512 | lab514 | hcl514
    color_head_ckpt = "/data/degis/best_color_head.pth",  # trained head matching hist dim
    prompts_csv = "/data/degis/prompts_25.csv",           # columns: idx,prompt

    # Models
    sd_id = "runwayml/stable-diffusion-v1-5",
    controlnet_id = "lllyasviel/control_v11p_sd15_canny",
    ip_ckpt = "/data/thesis/models/ip-adapter_sd15.bin",
    image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    hf_cache = None,  # or a path

    # Sweeps
    # The scale grid evaluates every (ip_scale, cn_scale) pair from these lists.
    ip_scales = [0.25, 0.5, 0.75, 1.0],
    cn_scales = [0.8, 1.2, 1.6, 2.0],
    cfg_scale = 7.5,
    steps = 50,
    seed = 123,

    # Edges
    # Canny thresholds and working resolution used when computing edge metrics.
    canny_low = 100,
    canny_high = 200,
    edge_size = 512,

    # Output
    outdir = "/data/degis/eval_hcl514_notebook",
    limit = 0,  # 0 => all rows in prompts_csv
)

# Use the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook echoes the active configuration and device.
cfg, device

## 2) Helpers: IO, histograms, metrics, montages

In [ ]:
def ensure_dir(p):
    """Create directory ``p`` together with any missing parents.

    Does nothing when the directory already exists.
    """
    target = Path(p)
    target.mkdir(parents=True, exist_ok=True)

def load_manifest(manifest_csv):
    """Read the manifest CSV and guarantee it exposes a ``file_path`` column.

    If the CSV has no ``file_path`` column but does have ``local_path``, the
    latter is renamed to ``file_path``.

    Args:
        manifest_csv: Path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        The manifest as a DataFrame containing a ``file_path`` column.

    Raises:
        AssertionError: If neither column is available.
    """
    manifest = pd.read_csv(manifest_csv)
    columns = manifest.columns
    needs_rename = "local_path" in columns and "file_path" not in columns
    if needs_rename:
        manifest = manifest.rename(columns={"local_path": "file_path"})
    assert "file_path" in manifest.columns, "manifest must have a 'file_path' column"
    return manifest

# --- histograms ---
def rgb_hist(img_np, bins=8):
    """Return a normalised joint RGB histogram as a flat vector.

    Args:
        img_np: Array whose elements can be reshaped to ``(-1, 3)``; each row
            is treated as one (R, G, B) sample in the range [0, 256).
        bins: Number of bins per channel.

    Returns:
        float32 vector of length ``bins**3`` that sums to ~1 (an epsilon in
        the denominator keeps an empty input from dividing by zero).
    """
    pixels = img_np.reshape(-1, 3)
    channel_range = ((0, 256),) * 3
    counts, _edges = np.histogramdd(pixels, bins=(bins,) * 3, range=channel_range)
    flat = counts.flatten().astype(np.float32)
    flat /= (flat.sum() + 1e-8)
    return flat

def lab_hist(img_np, bins=8):
    """Return a normalised Lab histogram with two extra neutral bins.

    The image is converted with ``cv2.COLOR_RGB2Lab`` and binned into a
    ``bins x bins x bins`` histogram over [0, 256) per channel. Two counters
    are appended: near-black pixels (L < 25) and near-white pixels (L > 230)
    whose a and b channels are both within 5 of the value 128. Such pixels are
    counted in the 3-D histogram as well as in their neutral bin, and the
    normaliser is the sum of all ``bins**3 + 2`` entries.

    Args:
        img_np: RGB image array accepted by ``cv2.cvtColor`` (assumed to be
            HxWx3 uint8, as produced by ``compute_histogram`` -- confirm for
            other callers).
        bins: Number of bins per Lab channel.

    Returns:
        Normalised vector of length ``bins**3 + 2``; the last two entries are
        the black and white neutral fractions.
    """
    lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2Lab)
    L, a, b = cv2.split(lab)
    neutral_th = 5
    # Work in a signed dtype: on the uint8 channels ``a - 128`` wraps around for
    # values below 128, so only a/b in [128, 132] would ever count as neutral.
    # NOTE(review): if stored target histograms were generated with the wrapped
    # arithmetic, regenerate them so both sides use the same definition.
    a_dev = np.abs(a.astype(np.int16) - 128)
    b_dev = np.abs(b.astype(np.int16) - 128)
    is_neutral = (a_dev < neutral_th) & (b_dev < neutral_th)
    is_black = (L < 25) & is_neutral
    is_white = (L > 230) & is_neutral
    hist = cv2.calcHist([lab],[0,1,2],None,[bins,bins,bins],[0,256,0,256,0,256]).flatten()
    black_count = int(is_black.sum()); white_count = int(is_white.sum())
    # Epsilon guards against division by zero for an empty image.
    total = hist.sum() + black_count + white_count + 1e-8
    h = np.append(hist, [black_count, white_count]).astype(np.float32) / total
    return h

def hcl_hist(img_np, bins=8, c_max=150.0):
    """Return a normalised lightness/chroma/hue histogram with neutral bins.

    The image is converted with ``cv2.COLOR_RGB2Lab``; a and b are re-centred
    by subtracting 128, chroma is their Euclidean norm and hue is their angle
    in degrees mapped to [0, 360). The three channels are binned jointly and
    two counters are appended for near-black (L < 25) and near-white (L > 230)
    pixels whose chroma is below 5.

    Args:
        img_np: RGB image array accepted by ``cv2.cvtColor``.
        bins: Number of bins per axis.
        c_max: Upper edge of the chroma axis.

    Returns:
        Normalised vector of length ``bins**3 + 2``; the last two entries are
        the black and white neutral fractions.
    """
    lab_f = cv2.cvtColor(img_np, cv2.COLOR_RGB2Lab).astype(np.float32)
    lightness, a_raw, b_raw = cv2.split(lab_f)
    a_ctr = a_raw - 128.0
    b_ctr = b_raw - 128.0
    chroma = np.sqrt(a_ctr**2 + b_ctr**2)
    hue_deg = (np.degrees(np.arctan2(b_ctr, a_ctr)) + 360.0) % 360.0

    neutral_th = 5.0
    low_chroma = chroma < neutral_th
    n_black = int(((lightness < 25) & low_chroma).sum())
    n_white = int(((lightness > 230) & low_chroma).sum())

    samples = np.stack(
        [lightness.flatten(), chroma.flatten(), hue_deg.flatten()], axis=-1
    )
    # Samples outside the given ranges (e.g. chroma above c_max) are not
    # counted by histogramdd; the normaliser below only sees counted samples.
    counts, _ = np.histogramdd(
        samples, bins=(bins, bins, bins), range=((0, 256), (0, c_max), (0, 360))
    )
    counts = counts.flatten()
    # Epsilon guards against division by zero for an empty image.
    denom = counts.sum() + n_black + n_white + 1e-8
    return np.append(counts, [n_black, n_white]).astype(np.float32) / denom

def compute_histogram(pil_img, color_space, bins=8):
    """Compute the colour histogram of a PIL image in the given colour space.

    The image is converted to RGB (so RGBA, palette or grayscale inputs do not
    break the 3-channel histogram code) and resized to 256x256 before binning.

    Args:
        pil_img: PIL image.
        color_space: One of ``"rgb512"``, ``"lab514"`` or ``"hcl514"``.
        bins: Number of bins per channel, forwarded to the histogram function.

    Returns:
        The normalised histogram vector produced by ``rgb_hist``, ``lab_hist``
        or ``hcl_hist``.

    Raises:
        ValueError: If ``color_space`` is not one of the supported names.
    """
    hist_fns = {"rgb512": rgb_hist, "lab514": lab_hist, "hcl514": hcl_hist}
    hist_fn = hist_fns.get(color_space)
    if hist_fn is None:
        raise ValueError(f"unknown color_space {color_space}")
    np_img = np.array(pil_img.convert("RGB").resize((256,256)))
    return hist_fn(np_img, bins=bins)

# --- metrics ---
def sinkhorn_emd(h1, h2, blur=0.05):
    """Sinkhorn divergence between two histograms over the same bins.

    Each histogram is treated as a weighted point cloud whose support points
    are the bin indices ``0 .. D-1`` on a line. When ``geomloss`` is not
    installed the function falls back to the mean absolute bin difference.

    Args:
        h1: 1-D histogram (array-like of length D).
        h2: 1-D histogram with the same length as ``h1``.
        blur: Blur parameter forwarded to ``SamplesLoss``.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: If the two histograms do not have the same shape.
    """
    h1 = np.asarray(h1)
    h2 = np.asarray(h2)
    if h1.shape != h2.shape:
        raise ValueError(f"histogram shapes differ: {h1.shape} vs {h2.shape}")
    if not HAVE_GEOMLOSS:
        return float(np.abs(h1 - h2).mean())
    D = h1.shape[0]
    # Shared support: one 1-D point per bin, shape (D, 1).
    x = torch.arange(D, dtype=torch.float32, device="cpu").view(D, 1)
    # Weights must be (D,) vectors when the points are given without a batch
    # dimension.
    a = torch.tensor(h1, dtype=torch.float32).view(D)
    b = torch.tensor(h2, dtype=torch.float32).view(D)
    loss = SamplesLoss("sinkhorn", p=2, blur=blur, backend="tensorized")
    # SamplesLoss takes (weights_a, points_a, weights_b, points_b); passing the
    # points first makes it treat them as weights and fail its shape checks.
    return float(loss(a, x, b, x).item())

def neutral_bin_errors(h_gen, h_tgt, color_space):
    """Absolute error of the black and white neutral bins.

    For ``"lab514"`` and ``"hcl514"`` the last two histogram entries are the
    black and white neutral bins; other colour spaces have none.

    Returns:
        ``(black_err, white_err)`` as floats, or ``(nan, nan)`` when the colour
        space has no neutral bins.
    """
    if color_space not in ("lab514", "hcl514"):
        return math.nan, math.nan
    black_err, white_err = (float(abs(h_gen[i] - h_tgt[i])) for i in (-2, -1))
    return black_err, white_err

def edge_map(pil_img, size=512, low=100, high=200):
    """Return a binary Canny edge mask for a PIL image.

    The image is resized to ``size x size``, converted to 8-bit grayscale and
    passed through ``cv2.Canny`` with thresholds ``low`` and ``high``.

    Returns:
        uint8 array with 1 on edge pixels and 0 elsewhere.
    """
    resized = pil_img.resize((size, size))
    gray = np.array(resized.convert("L"))
    canny = cv2.Canny(gray, low, high)
    return (canny > 0).astype(np.uint8)

def edge_f1(target_edges, pred_edges, tol=1):
    """Tolerance-aware F1 score between two binary edge masks.

    A predicted edge pixel counts as correct when a target edge lies within
    ``tol`` pixels of it (precision); a target edge pixel counts as recovered
    when a predicted edge lies within ``tol`` pixels of it (recall). Proximity
    is evaluated by dilating the other mask with a
    ``(2*tol+1) x (2*tol+1)`` square kernel.

    Args:
        target_edges: 2-D mask; non-zero entries are reference edges.
        pred_edges: 2-D mask of the same shape; non-zero entries are predicted
            edges.
        tol: Matching tolerance in pixels (0 means exact overlap).

    Returns:
        ``(f1, precision, recall)`` as floats. Precision is 0.0 when there are
        no predicted edges, recall is 0.0 when there are no target edges.
    """
    # Work on boolean masks: ``~`` on a uint8 0/1 array yields 255/254, which
    # are both truthy, so the complement would select every pixel.
    target = np.asarray(target_edges) > 0
    pred = np.asarray(pred_edges) > 0

    kernel = np.ones((2*tol+1, 2*tol+1), np.uint8)
    target_near = cv2.dilate(target.astype(np.uint8), kernel, iterations=1) > 0
    pred_near = cv2.dilate(pred.astype(np.uint8), kernel, iterations=1) > 0

    n_pred = int(pred.sum())
    n_target = int(target.sum())
    matched_pred = int(np.logical_and(pred, target_near).sum())
    matched_target = int(np.logical_and(target, pred_near).sum())

    prec = matched_pred / n_pred if n_pred else 0.0
    rec = matched_target / n_target if n_target else 0.0
    f1 = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
    return float(f1), float(prec), float(rec)

def edge_ssim(target_edges, pred_edges):
    """Structural similarity between two edge masks.

    Both masks are cast to float32 and compared with ``data_range=1.0``.

    Returns:
        The SSIM value as a Python float.
    """
    reference = target_edges.astype(np.float32)
    estimate = pred_edges.astype(np.float32)
    score = ssim(reference, estimate, data_range=1.0)
    return float(score)

def clip_score_openclip(model, preprocess, device, prompt, pil_img):
    """Cosine similarity between an image and a prompt under an OpenCLIP model.

    Args:
        model: Object exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning a PIL image into a tensor for ``model``.
        device: ``torch.device`` on which inputs are placed; autocast is
            enabled only when its type is ``"cuda"``.
        prompt: Text to compare against.
        pil_img: Image to score.

    Returns:
        Dot product of the L2-normalised image and text features, as a float.
    """
    use_amp = device.type == "cuda"
    pixels = preprocess(pil_img).unsqueeze(0).to(device)
    tokens = open_clip.tokenize([prompt]).to(device)

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=use_amp):
            img_feat = model.encode_image(pixels)
            txt_feat = model.encode_text(tokens)
        # Normalise in float32 regardless of the precision used for encoding.
        img_feat = F.normalize(img_feat.float(), dim=-1)
        txt_feat = F.normalize(txt_feat.float(), dim=-1)

    cosine = img_feat @ txt_feat.T
    return float(cosine.squeeze().item())

def control_edge_from_path(path, size=512):
    """Load an image and return the ControlNet conditioning image.

    The file is opened, converted to grayscale, auto-contrasted, resized to
    ``size x size`` with bilinear resampling and returned as a 3-channel RGB
    image.

    NOTE(review): despite the name, no edge detection happens here -- the
    result is a contrast-stretched grayscale picture. Confirm that this is the
    conditioning the Canny ControlNet checkpoint is expected to receive.

    Args:
        path: Path of the source image.
        size: Side length of the square output.

    Returns:
        RGB PIL image of size ``(size, size)``.
    """
    # The context manager closes the file handle as soon as the pixel data has
    # been copied by ``convert``; otherwise it stays open until collected.
    with Image.open(path) as src:
        gray = ImageOps.grayscale(src.convert("RGB"))
    gray = ImageOps.autocontrast(gray)
    return gray.resize((size,size), Image.BILINEAR).convert("RGB")

def palette_bar_figure(h, outpath, title=""):
    """Save a small bar chart of histogram ``h`` to ``outpath``.

    Args:
        h: Sequence of bin values; one bar is drawn per entry.
        outpath: Destination file for the figure.
        title: Optional title drawn above the bars.
    """
    import numpy as np
    import matplotlib
    # Select the Agg backend before pyplot is imported so the figure is
    # rendered to a file rather than to a window.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    bin_positions = np.arange(len(h))
    plt.figure(figsize=(4, 1.2))
    plt.bar(bin_positions, h)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(outpath, dpi=160)
    # Close the current figure so figures do not accumulate across calls.
    plt.close()

def save_ablation_montage(imgs, labels, bars_pngs, outpath, cell=512, pad=12):
    """Compose a 2x2 montage of images with captions and palette bars.

    Each cell shows the image resized to ``cell x cell``, its label underneath
    and, when a bar PNG path is given and exists, that bar chart resized to
    ``cell x 80`` below the label.

    Args:
        imgs: PIL images, laid out row-major in a 2x2 grid.
        labels: Caption text for each image.
        bars_pngs: Per-image path to a bar-chart PNG, or a falsy value to skip.
        outpath: Destination file for the montage.
        cell: Side length of each image cell in pixels.
        pad: Padding between cells and around the canvas in pixels.
    """
    cols = 2; rows = 2
    caption_h = 90  # vertical space reserved under each image
    W = cols*cell + (cols+1)*pad
    H = rows*(cell+caption_h) + (rows+1)*pad
    canvas = Image.new("RGB", (W,H), (245,245,245))
    draw = ImageDraw.Draw(canvas)
    for i,(im,lbl,barpng) in enumerate(zip(imgs, labels, bars_pngs)):
        r = i//cols; c = i%cols
        x = pad + c*(cell+pad); y = pad + r*(cell+caption_h+pad)
        canvas.paste(im.resize((cell,cell), Image.LANCZOS), (x,y))
        draw.text((x, y+cell+4), lbl, fill=(0,0,0))
        if barpng and os.path.exists(barpng):
            # Close the file handle once the pixels are copied instead of
            # leaving it open until garbage collection.
            with Image.open(barpng) as bar_file:
                bar = bar_file.convert("RGB").resize((cell,80), Image.BILINEAR)
            # NOTE(review): the bar is pasted 20px below the image, after the
            # label is drawn at 4px; a multi-line label may be partly covered
            # -- confirm the intended layout.
            canvas.paste(bar, (x, y+cell+20))
    canvas.save(outpath)


## 3) Build pipelines and models

In [ ]:
def build_pipelines(sd_id, controlnet_id, device, hf_cache=None):
    """Build the Stable Diffusion + ControlNet pipeline in float16.

    Args:
        sd_id: Model id or path of the Stable Diffusion checkpoint.
        controlnet_id: Model id or path of the ControlNet checkpoint.
        device: Device the pipeline is moved to.
        hf_cache: Optional cache directory passed to ``from_pretrained``.

    Returns:
        The pipeline, with the safety checker and feature extractor disabled.
    """
    half = torch.float16
    controlnet = ControlNetModel.from_pretrained(
        controlnet_id, torch_dtype=half, cache_dir=hf_cache
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        sd_id,
        controlnet=controlnet,
        torch_dtype=half,
        safety_checker=None,
        feature_extractor=None,
        cache_dir=hf_cache,
    )
    pipeline = pipeline.to(device)
    # Re-assert the ControlNet dtype after the move.
    pipeline.controlnet = pipeline.controlnet.to(dtype=half)
    return pipeline

def build_ip_adapter(pipe, ip_ckpt, image_encoder_path, device, num_tokens=4, embedding_type="custom"):
    """Wrap a diffusion pipeline in the project's ``IPAdapter``.

    All arguments are forwarded unchanged to the ``IPAdapter`` constructor
    (``pipe`` is passed as ``sd_pipe``).

    Returns:
        The constructed ``IPAdapter`` instance.
    """
    adapter_kwargs = {
        "sd_pipe": pipe,
        "image_encoder_path": image_encoder_path,
        "ip_ckpt": ip_ckpt,
        "device": device,
        "num_tokens": num_tokens,
        "embedding_type": embedding_type,
    }
    return IPAdapter(**adapter_kwargs)

def build_color_head(ckpt_path, clip_dim, hist_dim, device):
    """Instantiate ``ColorHead`` in eval mode and load its weights.

    Args:
        ckpt_path: Path of a checkpoint whose content is passed directly to
            ``load_state_dict``.
        clip_dim: Input embedding dimension of the head.
        hist_dim: Histogram dimension of the head.
        device: Device for the module and for ``map_location``.

    Returns:
        The loaded ``ColorHead`` on ``device``.
    """
    model = ColorHead(clip_dim=clip_dim, hist_dim=hist_dim)
    model = model.to(device).eval()
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's defaults; only load checkpoints you trust.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    return model

def controlnet_knobs(scale):
    """Return the ControlNet keyword arguments for a conditioning scale.

    The guidance window is fixed to the full range (start 0.0, end 1.0).
    """
    return {
        "controlnet_conditioning_scale": scale,
        "control_guidance_start": 0.0,
        "control_guidance_end": 1.0,
    }

## 4) Load data & models

In [ ]:
# Load the manifest and prompt table; cfg.limit (when non-zero) keeps only the
# first rows of the prompt table.
df_manifest = load_manifest(cfg.manifest_csv)
df_prompts = pd.read_csv(cfg.prompts_csv)
if cfg.limit and len(df_prompts) > cfg.limit:
    df_prompts = df_prompts.head(cfg.limit).copy()

# Memory-map the arrays so only the rows that are actually indexed get read.
emb = np.load(cfg.embeddings_npy, mmap_mode="r")
hists = np.load(cfg.hists_npy, mmap_mode="r")
N, clip_dim = emb.shape
hist_dim = hists.shape[1]

# Row i of the manifest, embeddings and histograms must describe the same image.
if not (N == hists.shape[0] == len(df_manifest)):
    raise ValueError(
        "embeddings/hists must align with manifest rows: "
        f"emb={N}, hists={hists.shape[0]}, manifest={len(df_manifest)}"
    )

# The histogram file must match the colour space used to score generations.
expected_hist_dim = {"rgb512": 512, "lab514": 514, "hcl514": 514}.get(cfg.color_space)
if expected_hist_dim is None:
    raise ValueError(f"unknown color_space {cfg.color_space}")
if hist_dim != expected_hist_dim:
    raise ValueError(
        f"hists have dim {hist_dim} but color_space {cfg.color_space} "
        f"expects {expected_hist_dim}"
    )

# Prompt indices address manifest rows; a negative index would silently wrap.
prompt_idx = df_prompts["idx"].astype(int)
out_of_range = prompt_idx[(prompt_idx < 0) | (prompt_idx >= N)]
if len(out_of_range):
    raise ValueError(
        f"prompts_csv has idx values outside [0, {N}): {out_of_range.tolist()[:10]}"
    )

pipe = build_pipelines(cfg.sd_id, cfg.controlnet_id, device, hf_cache=cfg.hf_cache)
ip_adapter = build_ip_adapter(pipe, cfg.ip_ckpt, cfg.image_encoder_path, device, embedding_type="custom")
color_head = build_color_head(cfg.color_head_ckpt, clip_dim=clip_dim, hist_dim=hist_dim, device=device)

# CLIPScore model
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k", cache_dir=cfg.hf_cache
)
clip_model = clip_model.to(device)
if device.type == "cuda":
    # Half precision is only used on the GPU.
    clip_model = clip_model.half()
clip_model.eval()

print("Loaded:")
print("  manifest:", len(df_manifest), "rows")
print("  prompts:", len(df_prompts), "rows")
print("  emb shape:", emb.shape, " hist shape:", hists.shape)

## 5) Evaluation loop (ablation + scale grid)

In [ ]:
def _label_slug(label):
    """Make an ablation label safe for use inside a filename."""
    return label.replace('+','plus').replace(' ','_')


def _generate_and_score(cfg, device, prompt, c_emb, control_pil, targ_edges, tgt_hist, ip_s, cn_s):
    """Generate one image at the given scales and compute all its metrics.

    Uses the module-level ``ip_adapter``, ``clip_model`` and
    ``clip_preprocess`` objects.

    Returns:
        ``(image, generated_histogram, metrics)`` where ``metrics`` is a dict
        of the per-image CSV columns from ``ip_scale`` through ``clip_score``.
    """
    ip_adapter.set_scale(ip_s)
    # The control image is always supplied; a conditioning scale of 0.0 is
    # used to switch edge guidance off.
    kw = dict(image=control_pil)
    kw.update(controlnet_knobs(cn_s))
    t0 = time.time()
    imgs = ip_adapter.generate_from_embeddings(
        clip_image_embeds=c_emb.half(),
        prompt=prompt, negative_prompt=None,
        scale=ip_s, num_samples=1, seed=cfg.seed,
        guidance_scale=cfg.cfg_scale, num_inference_steps=cfg.steps,
        **kw
    )
    latency = time.time() - t0
    gen = imgs[0].convert("RGB")

    gen_hist = compute_histogram(gen, cfg.color_space)
    emd = sinkhorn_emd(gen_hist, tgt_hist)
    b_err, w_err = neutral_bin_errors(gen_hist, tgt_hist, cfg.color_space)
    pred_edges = edge_map(gen, size=cfg.edge_size, low=cfg.canny_low, high=cfg.canny_high)
    f1, prec, rec = edge_f1(targ_edges, pred_edges, tol=1)
    ess = edge_ssim(targ_edges, pred_edges)
    cs = clip_score_openclip(clip_model, clip_preprocess, device, prompt, gen)

    metrics = dict(
        ip_scale=ip_s, cn_scale=cn_s,
        cfg=cfg.cfg_scale, steps=cfg.steps, latency_s=latency,
        color_emd=emd, neutral_black_err=b_err, neutral_white_err=w_err,
        edge_f1=f1, edge_prec=prec, edge_rec=rec, edge_ssim=ess,
        clip_score=cs,
    )
    return gen, gen_hist, metrics


def evaluate(cfg, df_manifest, df_prompts, emb, hists, device):
    """Run the 2x2 ablation and the scale grid for every prompt row.

    For each row of ``df_prompts`` the function looks up the source image,
    target histogram and CLIP embedding by ``idx``, derives the colour
    embedding with the module-level ``color_head``, then generates and scores:

    * four ablation images (Text only / +Edges / +Colour / +Both), saved with
      palette bar charts and a montage;
    * one image per ``(ip_scale, cn_scale)`` pair from ``cfg.ip_scales`` and
      ``cfg.cn_scales``.

    Images go to ``<outdir>/imgs``, figures to ``<outdir>/figs`` and the
    per-image metrics to ``<outdir>/csv/metrics.csv``.

    Relies on module-level ``color_head``, ``ip_adapter``, ``clip_model`` and
    ``clip_preprocess`` created by the loading cell.

    Args:
        cfg: Configuration namespace (paths, scales, seed, Canny settings).
        df_manifest: Manifest with a ``file_path`` column.
        df_prompts: Table with ``idx`` and ``prompt`` columns.
        emb: Array of CLIP embeddings indexed by manifest row.
        hists: Array of target histograms indexed by manifest row.
        device: Torch device for tensors created here.

    Returns:
        Path of the written metrics CSV.
    """
    outdir = Path(cfg.outdir)
    img_dir = outdir / "imgs"; fig_dir = outdir / "figs"; csv_dir = outdir / "csv"
    for d in (img_dir, fig_dir, csv_dir):
        ensure_dir(d)

    rows = []

    for ridx, row in df_prompts.iterrows():
        idx = int(row["idx"])  # must map to manifest row
        prompt = str(row["prompt"])

        src_path = df_manifest.iloc[idx]["file_path"]
        tgt_hist = hists[idx].astype(np.float32)

        z = torch.tensor(emb[idx], dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            _, _, c_emb = color_head(z)
        c_emb = c_emb.float()

        # NOTE(review): control_pil is an auto-contrasted grayscale image, not
        # a Canny edge map -- confirm this is the intended conditioning for the
        # Canny ControlNet checkpoint.
        control_pil = control_edge_from_path(src_path, size=cfg.edge_size)
        # The reference edges depend only on the source image: compute once.
        targ_edges = edge_map(control_pil, size=cfg.edge_size, low=cfg.canny_low, high=cfg.canny_high)

        bar_tgt = fig_dir / f"{ridx:03d}_tgtbar.png"
        if not os.path.exists(bar_tgt):
            palette_bar_figure(tgt_hist, bar_tgt, title="Target palette")

        # 2x2 ablation
        ablation = [
            ("Text only", 0.0, None),
            ("+Edges", 0.0, 1.6),
            ("+Colour", 0.75, None),
            ("+Both", 0.75, 1.6),
        ]
        ablation_imgs, ablation_labels, ablation_bars = [], [], []
        for label, ip_s, cn_s in ablation:
            # The ControlNet pipeline requires a control image on every call,
            # so "no edges" is expressed as a conditioning scale of 0.0 rather
            # than by omitting the image. Latency for these modes therefore
            # still includes the ControlNet forward pass.
            cn_eff = 0.0 if cn_s is None else cn_s
            gen, gen_hist, metrics = _generate_and_score(
                cfg, device, prompt, c_emb, control_pil, targ_edges, tgt_hist, ip_s, cn_eff
            )

            slug = _label_slug(label)
            gen_path = img_dir / f"{ridx:03d}_{slug}.png"
            gen.save(gen_path)

            bar_gen = fig_dir / f"{ridx:03d}_{slug}_bar.png"
            palette_bar_figure(gen_hist, bar_gen, title="Generated palette")

            ablation_imgs.append(gen)
            ablation_labels.append(
                f"{label}\nEMD={metrics['color_emd']:.4f}  "
                f"F1={metrics['edge_f1']:.3f}  CLIP={metrics['clip_score']:.3f}"
            )
            ablation_bars.append(str(bar_gen))

            rows.append(dict(
                prompt_idx=idx, row_id=ridx, mode=label,
                **metrics,
                image_path=str(gen_path)
            ))

        montage_path = fig_dir / f"{ridx:03d}_ablation_montage.png"
        # The first cell (text only) is shown without a palette bar.
        save_ablation_montage(ablation_imgs, ablation_labels, [None]+ablation_bars[1:], montage_path)

        # scale grid
        for ip_s, cn_s in itertools.product(cfg.ip_scales, cfg.cn_scales):
            gen, _, metrics = _generate_and_score(
                cfg, device, prompt, c_emb, control_pil, targ_edges, tgt_hist, ip_s, cn_s
            )

            # Replace the decimal points of the scales only; the ".png" suffix
            # must survive so the image writer can infer the file format.
            scale_tag = f"ip{ip_s}_cn{cn_s}".replace(".","p")
            gen_path = img_dir / f"{ridx:03d}_grid_{scale_tag}.png"
            gen.save(gen_path)

            rows.append(dict(
                prompt_idx=idx, row_id=ridx, mode="grid",
                **metrics,
                image_path=str(gen_path)
            ))

    out_csv = csv_dir / "metrics.csv"
    pd.DataFrame(rows).to_csv(out_csv, index=False)
    return out_csv


## 6) Run
This generates images, montages, and a metrics CSV under `cfg.outdir` (in the `imgs/`, `figs/`, and `csv/` subfolders).

In [ ]:
# Run the full evaluation and report where the outputs were written.
out_root = Path(cfg.outdir)
ensure_dir(cfg.outdir)
csv_path = evaluate(cfg, df_manifest, df_prompts, emb, hists, device)
print("✅ Done. CSV:", csv_path)
print("Images:", out_root / "imgs")
print("Figures:", out_root / "figs")

## 7) (Optional) Peek at a montage inline

In [ ]:
from IPython.display import display

# Show the first ablation montage (in sorted filename order), if any exist.
figs_dir = Path(cfg.outdir) / "figs"
montages = sorted(figs_dir.glob("*_ablation_montage.png"))
if not montages:
    print("No montages yet.")
else:
    display(Image.open(montages[0]))