<p align="center">
  <img src="docs/inference_sequence_diagram.png" alt="inferencing diagram" width="1200"/>
</p>

# TITLE & OVERVIEW

This notebook runs DEGIS image generation in two modes: Style (baseline IP-Adapter) and Sinkhorn-constrained color alignment (the DEGIS approach).  
The notebook builds a hypothetical scenario in which a car dealer wants to create advertisements that follow a predefined color palette and layout.  
We evaluate three edge maps (layouts), and for each we use up to three color sources.  

Each evaluation section shows the source color image, the chosen edge map, a Style result, and a Sinkhorn-constrained result, each with its metrics.

## ASSETS EXPECTED
- Precomputed `edge_maps` (`.npy`), the CSV manifest of color-source images, and a trained `color_head` checkpoint.  
- A ready `generator` pipeline.

## OUTPUTS OF INTEREST
- One visualization image per run.  
- Metrics: **Sinkhorn** distance to the source palette (lower is better), **CLIP-text cosine** (higher is better), **edge IoU** against the layout (higher is better), **attempts**, and **generation time**.


In [None]:
# =============================================================================
# ENVIRONMENT CHECK - Run this cell first!
# =============================================================================
# This cell verifies that the DEGIS environment is properly set up
# Make sure you've run ./setup_server_fixed.sh first!

import sys
import os

# Is this kernel running from the project's virtualenv? The setup script creates
# it in a directory named "degis-env", which then shows up in the interpreter path.
if 'degis-env' not in sys.executable:
    print("Warning: DEGIS environment not detected")
    print("Please run: ./setup_server_fixed.sh")
    print("Then activate: source degis-env/bin/activate")
else:
    print("DEGIS environment is active")
    print(f"Python: {sys.executable}")

# Is the DEGIS package importable from this kernel?
try:
    import degis
except ImportError:
    print("DEGIS package not found")
    print("Please run: ./setup_server_fixed.sh")
else:
    print("DEGIS package is available")

print("\nReady to start image generation!")

# IMPORTS - CONTEXT

Loads core libs (torch, numpy, PIL, matplotlib, torchvision), DEGIS dataset helpers, SDXL + ControlNet with DEGIS patches, and histogram utilities.  
Prints the active device (`cuda` or `cpu`). No parameters are changed here.

In [None]:
import geomloss
import numpy as np
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms
from IPython.display import display
import os
import glob

# Import DEGIS package components
from degis import IPAdapterXLGenerator, load_trained_color_head, get_color_embedding
from degis.shared.utils import plot_color_palette, display_images_grid, display_comparison_grid
from degis.shared.image_features.color_histograms import compute_color_histogram, compute_lab_histogram
from degis.shared.utils.image_utils import create_control_edge_pil
from degis.data.dataset import UnifiedImageDataset

# Import IP-Adapter XL with DEGIS patches
import ip_adapter_patch  # This applies the DEGIS patches
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

for _banner in (
    "All imports successful!",
    "Ready for high-quality image generation with IP-Adapter XL!",
):
    print(_banner)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# PATHS & GLOBAL CONFIG

Paths to the CSV manifest, the precomputed edge maps (`.npy`), and the **color head checkpoint**, plus the Hugging Face identifiers of the **image encoder** and **ControlNet**.

In [None]:
# Data: manifest of color-source images and the precomputed layout edge maps.
csv_path = "data/laion_5m_manifest.csv"
precomputed_adimagenet_edge_maps_path = "data/adimagenet_edge_maps.npy"

# Trained color-head weights.
color_head_checkpoint_path = "evaluation_runs/laion_5m_xl_lab514_tk20_b4096/best_color_head_tmp.pth"

# Hugging Face model identifiers.
image_encoder_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
controlnet_id = "diffusers/controlnet-canny-sdxl-1.0"

# LOAD DATASETS & ARRAYS - EXPECTED SHAPES

Loads the thumbnail dataset (indexed by `color_index`), memory-maps the precomputed **edge maps** array (indexed by `layout_index`), and loads the trained **color head**.  
The edge-map array shape is printed as a sanity check.


In [None]:
from degis.inference import load_trained_color_head
# Color sources: the manifest's `local_path` column is exposed to the dataset as
# `file_path`; items are 224x224 images addressed by `color_index`.
df = pd.read_csv(csv_path)
color_dataset = UnifiedImageDataset(
    df.rename(columns={"local_path": "file_path"}),
    mode="url_df",
    size=(224, 224),
    subset_ratio=1.0,
)

# Layouts: memory-mapped (read lazily from disk), addressed by `layout_index`.
edge_maps = np.load(precomputed_adimagenet_edge_maps_path, mmap_mode="r")
print(f"Loaded edge maps: {edge_maps.shape}")

# Color head: clip_dim / hist_dim must match the dimensions the checkpoint was
# trained with (presumably 1280-d CLIP embeddings and 514-bin histograms).
color_head = load_trained_color_head(
    checkpoint_path=color_head_checkpoint_path,
    clip_dim=1280,
    hist_dim=514,
    device=device,
)
print("Color head loaded successfully")


# HF CACHE SETUP & INITIALIZE PIPELINE

Sets local caches for Hugging Face components and initializes the IP-Adapter XL pipeline.  
First run may download weights; reruns use the cache.

In [None]:
# Setup cache directory
HF_CACHE = "/data/hf-cache" if os.path.exists("/data") else "./hf-cache"
os.makedirs(HF_CACHE, exist_ok=True)

os.environ["HF_HOME"] = HF_CACHE
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(HF_CACHE, "hub")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(HF_CACHE, "transformers")
os.environ["DIFFUSERS_CACHE"] = os.path.join(HF_CACHE, "diffusers")
os.environ["TORCH_HOME"] = os.path.join(HF_CACHE, "torch")

print(f"Using cache directory: {HF_CACHE}")

# Create IP-Adapter XL generator
generator = IPAdapterXLGenerator(device=device)

# Setup the pipeline
generator.setup_pipeline(
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    controlnet_id=controlnet_id,
    ip_ckpt=None,  # Set to None to use default IP-Adapter weights
    image_encoder_path=image_encoder_path,
    cache_dir=HF_CACHE,
    torch_dtype=torch.float16,
)

print("IP-Adapter XL pipeline setup complete")


# GENERATION FUNCTION - STYLE VS Sinkhorn

Reusable function to: clear VRAM, build control image, compute target histogram, run **Style** generation, run **Sinkhorn-constrained** generation, compute metrics, and render a composite visualization.


In [None]:
import torch, gc
from degis.inference import generate_by_style, generate_by_color_sinkhorn_constrained
from degis.shared.utils.visualization import visualize_generation_comparison, create_generation_metrics
from degis.shared.utils.image_utils import create_control_edge_pil
from degis.inference.core_generation import get_color_embedding

def generate_comparison_snippet(
    color_index: int,
    layout_index: int,
    prompt: str = "a cat playing with a ball",
    guidance_scale: float = 6.5,
    steps: int = 40,
    controlnet_conditioning_scale: float = 0.8,
    num_samples: int = 1,
    attn_ip_scale: float = 0.8,
    text_token_scale: float = 1.0,
    ip_token_scale: float = None,
    ip_uncond_scale: float = 0.0,
    zero_ip_in_uncond: bool = True,
    # Sinkhorn generation parameters
    target_sinkhorn_threshold: float = 0.1,
    max_attempts: int = 20,
    top_k: int = 20,
    color_space: str = None,
):
    """Generate a Style image and a Sinkhorn-constrained image for one pair and compare them.

    The color source ``color_dataset[color_index]`` supplies the IP-Adapter style
    reference, the color-head embedding and the target color histogram. The layout
    ``edge_maps[layout_index]`` is rendered as a 512 px ControlNet edge image. Both
    generations share the prompt, sampler and IP-Adapter settings. Metrics are
    printed and a composite visualization is displayed.

    Args:
        color_index: index of the color-source image in ``color_dataset``.
        layout_index: index of the edge map in ``edge_maps``.
        prompt: text prompt used by both generation paths.
        guidance_scale: classifier-free guidance scale (both paths).
        steps: number of inference steps (both paths).
        controlnet_conditioning_scale: ControlNet strength (both paths).
        num_samples: number of Style samples requested; only the first is used.
        attn_ip_scale, text_token_scale, ip_token_scale, ip_uncond_scale,
        zero_ip_in_uncond: IP-Adapter settings forwarded unchanged to both paths.
        target_sinkhorn_threshold: stop threshold for the constrained outer loop.
        max_attempts: attempt budget for the constrained outer loop.
        top_k: top-k forwarded to the constrained generator. It is also used to
            score the Style image, so both Sinkhorn numbers use the same top-k.
        color_space: histogram color space. ``None`` computes a "lab" histogram
            and then asks ``detect_color_space`` for the space name.

    Returns:
        ``(visualization, style_metrics, sinkhorn_metrics)``. The first item is the
        composite returned by ``visualize_generation_comparison``. The two dicts
        contain ``generation_time`` (formatted string), ``sinkhorn``, ``cosine`` and
        ``iou``; the Sinkhorn dict also contains ``attempts``.
    """
    import time

    from IPython.display import display
    from degis.shared.clip_vit_bigg14 import compute_clip_embedding_xl
    from degis.inference.generation_functions import (
        calculate_cosine_similarity,
        calculate_sinkhorn_distance_topk,
        compute_histogram_for_color_space,
        detect_color_space,
    )

    def _canny_bool(pil_img, sigma=1.0):
        """uint8 0/1 edge mask: scikit-image Canny, or PIL FIND_EDGES (> 32) if that path fails."""
        try:
            from skimage.color import rgb2gray
            from skimage.feature import canny
            g = rgb2gray(np.asarray(pil_img).astype(np.float32) / 255.0)
            return canny(g, sigma=sigma).astype(np.uint8)
        except Exception:
            from PIL import ImageFilter
            e = pil_img.convert("L").filter(ImageFilter.FIND_EDGES)
            return (np.asarray(e) > 32).astype(np.uint8)

    def _edge_scores(pred_bool, ref_bool):
        """Return (F1, IoU, SSIM) of a 2-D edge mask vs a reference mask; SSIM is NaN if unavailable."""
        pred = np.asarray(pred_bool).astype(bool)
        ref = np.asarray(ref_bool).astype(bool)
        if pred.shape != ref.shape:
            # Nearest-neighbour resample so masks at different resolutions can be compared.
            rows = np.arange(ref.shape[0]) * pred.shape[0] // ref.shape[0]
            cols = np.arange(ref.shape[1]) * pred.shape[1] // ref.shape[1]
            pred = pred[np.ix_(rows, cols)]
        tp = np.logical_and(pred, ref).sum()
        fp = np.logical_and(pred, ~ref).sum()
        fn = np.logical_and(~pred, ref).sum()
        prec = tp / (tp + fp + 1e-9)
        rec = tp / (tp + fn + 1e-9)
        f1 = 2 * prec * rec / (prec + rec + 1e-9)
        iou = tp / (tp + fp + fn + 1e-9)
        try:
            from skimage.metrics import structural_similarity as ssim
            # Recent scikit-image requires data_range for float images. Without it
            # the call raised, and SSIM silently became NaN.
            ssimv = ssim(pred.astype(np.float32), ref.astype(np.float32), data_range=1.0)
        except Exception:
            ssimv = np.nan
        return f1, iou, ssimv

    # Clear GPU memory
    gc.collect()
    torch.cuda.empty_cache()

    print(f"Generating comparison for color_index={color_index}, layout_index={layout_index}")
    print(f"Prompt: '{prompt}'")

    # Color source: PIL image (display + style reference), then CLIP -> color-head embedding.
    img_t, _ = color_dataset[color_index]
    pil_img = transforms.ToPILImage()(img_t)
    z_clip = compute_clip_embedding_xl(pil_img).to(device).unsqueeze(0)
    color_embedding = get_color_embedding(color_head, z_clip)

    # Layout: ControlNet edge image and its binary mask (reference for edge metrics).
    control_image = create_control_edge_pil(edge_maps[layout_index], size=512)
    control_bool = (np.asarray(control_image.convert("L")) > 0).astype(np.uint8)

    # Target histogram for the Sinkhorn constraint.
    color_histogram = compute_histogram_for_color_space(pil_img, color_space or "lab", bins=8)
    if color_space is None:
        # NOTE(review): the histogram above is always computed in "lab". If
        # detect_color_space ever returns a different space, the Style histogram
        # below is built in that other space. Confirm detect_color_space's contract.
        color_space = detect_color_space(color_histogram)

    print(f"Color space: {color_space}, Histogram shape: {color_histogram.shape}")

    # Score the Style image with the same top_k the constrained path receives,
    # clamped to the histogram length (as in run_eval_pairs).
    eval_top_k = min(top_k, len(color_histogram)) if top_k is not None else len(color_histogram)

    # 1. Style Generation (one-shot)
    print("\nRunning style generation...")
    start_time = time.perf_counter()
    style_images = generate_by_style(
        generator=generator,
        pil_image=pil_img,
        control_image=control_image,
        prompt=prompt,
        num_samples=num_samples,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        attn_ip_scale=attn_ip_scale,
        text_token_scale=text_token_scale,
        ip_token_scale=ip_token_scale,
        ip_uncond_scale=ip_uncond_scale,
        zero_ip_in_uncond=zero_ip_in_uncond
    )
    style_time = time.perf_counter() - start_time
    style_generated_image = style_images[0]

    style_histogram = compute_histogram_for_color_space(style_generated_image, color_space, bins=8)
    style_sinkhorn = calculate_sinkhorn_distance_topk(color_histogram, style_histogram, top_k=eval_top_k)
    style_cosine = calculate_cosine_similarity(prompt, style_generated_image)
    _, iou_s, _ = _edge_scores(_canny_bool(style_generated_image), control_bool)

    style_metrics = {
        'generation_time': f"{style_time:.2f}",
        'sinkhorn': style_sinkhorn,
        'cosine': style_cosine,
        'iou': iou_s
    }

    print(f"Style generation: {style_time:.2f}s, Sinkhorn: {style_sinkhorn:.4f}, Cosine: {style_cosine:.4f}, IoU: {iou_s:.4f}")

    # 2. Sinkhorn-Constrained Generation
    print("\nRunning Sinkhorn-constrained generation...")
    start_time = time.perf_counter()
    sinkhorn_images, sinkhorn_score, sinkhorn_cosine, attempts = generate_by_color_sinkhorn_constrained(
        generator=generator,
        color_embedding=color_embedding,
        control_image=control_image,
        original_histogram=color_histogram,
        prompt=prompt,
        target_sinkhorn_threshold=target_sinkhorn_threshold,
        max_attempts=max_attempts,
        top_k=top_k,
        color_space=color_space,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        attn_ip_scale=attn_ip_scale,
        text_token_scale=text_token_scale,
        ip_token_scale=ip_token_scale,
        ip_uncond_scale=ip_uncond_scale,
        zero_ip_in_uncond=zero_ip_in_uncond,
        verbose=True
    )
    sinkhorn_time = time.perf_counter() - start_time
    sinkhorn_generated_image = sinkhorn_images[0]

    _, iou_e, _ = _edge_scores(_canny_bool(sinkhorn_generated_image), control_bool)

    sinkhorn_metrics = {
        'generation_time': f"{sinkhorn_time:.2f}",
        'sinkhorn': sinkhorn_score,
        'cosine': sinkhorn_cosine,
        'iou': iou_e,
        'attempts': attempts
    }

    print(f"Sinkhorn generation: {sinkhorn_time:.2f}s, Sinkhorn: {sinkhorn_score:.4f}, Cosine: {sinkhorn_cosine:.4f}, IoU: {iou_e:.4f}, Attempts: {attempts}")

    # 3. Create comprehensive visualization
    print("\nCreating comprehensive visualization...")

    visualization = visualize_generation_comparison(
        color_source_image=pil_img,
        edge_map_image=control_image,
        style_generated_image=style_generated_image,
        sinkhorn_generated_image=sinkhorn_generated_image,
        color_histogram=color_histogram,
        color_space=color_space,
        style_metrics=style_metrics,
        sinkhorn_metrics=sinkhorn_metrics,
        grid_size=256,
        font_size=16
    )

    print(f"Visualization created: {visualization.size}")
    print("Final comparison:")
    print(f"   Style: Sinkhorn={style_sinkhorn:.4f}, Cosine={style_cosine:.4f}, IoU={iou_s:.4f}")
    print(f"   Sinkhorn:   Sinkhorn={sinkhorn_score:.4f}, Cosine={sinkhorn_cosine:.4f}, IoU={iou_e:.4f}")

    display(visualization)

    return visualization, style_metrics, sinkhorn_metrics


# 1st Generation Scenario

We generate three versions of the advertisement from the same layout image (layout 30), one per color palette (color sources 1000, 1001 and 1008), keeping the prompt and all generation settings fixed for comparability.


In [None]:
visualization_30_1000, style_metrics_30_1000, sinkhorn_metrics_30_1000 = generate_comparison_snippet(
    # what to generate: color source 1000 on layout 30
    color_index=1000,
    layout_index=30,
    prompt="a car, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=30,
    guidance_scale=7,
    controlnet_conditioning_scale=0.5,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

In [None]:
visualization_30_1001, style_metrics_30_1001, sinkhorn_metrics_30_1001 = generate_comparison_snippet(
    # what to generate: color source 1001 on layout 30
    color_index=1001,
    layout_index=30,
    prompt="a car, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=30,
    guidance_scale=7,
    controlnet_conditioning_scale=0.5,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

In [None]:
visualization_30_1008, style_metrics_30_1008, sinkhorn_metrics_30_1008 = generate_comparison_snippet(
    # what to generate: color source 1008 on layout 30
    color_index=1008,
    layout_index=30,
    prompt="a car, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=30,
    guidance_scale=7,
    controlnet_conditioning_scale=0.5,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

# 2nd Generation Scenario

Same setup as the 1st scenario, but with a different layout image (layout 71, a footer banner).


In [None]:
visualization_71_1000, style_metrics_71_1000, sinkhorn_metrics_71_1000 = generate_comparison_snippet(
    # what to generate: color source 1000 on layout 71
    color_index=1000,
    layout_index=71,
    prompt="a car, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=30,
    guidance_scale=7,
    controlnet_conditioning_scale=0.5,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

In [None]:
visualization_71_1001, style_metrics_71_1001, sinkhorn_metrics_71_1001 = generate_comparison_snippet(
    # what to generate: color source 1001 on layout 71
    color_index=1001,
    layout_index=71,
    prompt="a car, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=30,
    guidance_scale=7,
    controlnet_conditioning_scale=0.5,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

In [None]:
visualization_71_1008, style_metrics_71_1008, sinkhorn_metrics_71_1008 = generate_comparison_snippet(
    # what to generate: color source 1008 on layout 71
    color_index=1008,
    layout_index=71,
    prompt="a car, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=30,
    guidance_scale=7,
    controlnet_conditioning_scale=0.5,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)


# 3rd Generation Scenario
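Uses one layout (400) and two color palettes (1000 and 1008), with a prompt ("a boat") that does not describe the layout. The sampling settings differ from the earlier scenarios, and `text_token_scale` / `ip_token_scale` also differ between the two runs (2.5 / 0.7 vs. 3.5 / 0.9).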

In [None]:
visualization_400_1000, style_metrics_400_1000, sinkhorn_metrics_400_1000 = generate_comparison_snippet(
    # what to generate: color source 1000 on layout 400 (prompt does not match layout)
    color_index=1000,
    layout_index=400,
    prompt="a boat, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet (more steps, weaker ControlNet than scenarios 1-2)
    steps=50,
    guidance_scale=7,
    controlnet_conditioning_scale=0.4,
    # IP-Adapter gauges
    attn_ip_scale=0.6,
    text_token_scale=2.5,
    ip_token_scale=0.7,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

In [None]:
visualization_400_1008, style_metrics_400_1008, sinkhorn_metrics_400_1008 = generate_comparison_snippet(
    # what to generate: color source 1008 on layout 400 (prompt does not match layout)
    color_index=1008,
    layout_index=400,
    prompt="a boat, advertisement style, professional photography",
    num_samples=1,
    # sampler / ControlNet
    steps=50,
    guidance_scale=7,
    controlnet_conditioning_scale=0.4,
    # IP-Adapter gauges: text/IP token scales are higher than the color-1000 run (2.5 / 0.7)
    attn_ip_scale=0.6,
    text_token_scale=3.5,
    ip_token_scale=0.9,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # Sinkhorn-constrained outer loop
    max_attempts=20,
    target_sinkhorn_threshold=0.02,
)

# Quantitative Evaluation Snippets

In [None]:
# Build a tidy evaluation DataFrame for N prompt/layout pairs
import time, numpy as np, pandas as pd, torch, gc
from torchvision import transforms
from PIL import Image

from degis.inference import generate_by_style, generate_by_color_sinkhorn_constrained
from degis.inference.core_generation import get_color_embedding
from degis.shared.utils.image_utils import create_control_edge_pil
from degis.inference.generation_functions import (
    compute_histogram_for_color_space,
    calculate_sinkhorn_distance_topk,
    calculate_cosine_similarity,
)

# ---- edge helpers (F1 / IoU / SSIM) ----
def _canny_bool(pil_img, sigma=1.0):
    import numpy as np
    try:
        from skimage.color import rgb2gray
        from skimage.feature import canny
        g = rgb2gray(np.asarray(pil_img).astype(np.float32) / 255.0)
        e = canny(g, sigma=sigma)
        return e.astype(np.uint8)
    except Exception:
        from PIL import ImageFilter
        e = pil_img.convert("L").filter(ImageFilter.FIND_EDGES)
        return (np.asarray(e) > 32).astype(np.uint8)

def _edge_scores(pred_bool, ref_bool):
    import numpy as np
    pred = pred_bool.astype(bool); ref = ref_bool.astype(bool)
    tp = np.logical_and(pred, ref).sum()
    fp = np.logical_and(pred, ~ref).sum()
    fn = np.logical_and(~pred, ref).sum()
    prec = tp / (tp + fp + 1e-9)
    rec  = tp / (tp + fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    iou  = tp / (tp + fp + fn + 1e-9)
    try:
        from skimage.metrics import structural_similarity as ssim
        ssimv = ssim(pred_bool.astype(np.float32), ref_bool.astype(np.float32))
    except Exception:
        ssimv = np.nan
    return f1, iou, ssimv

# Shared tensor -> PIL converter (mode inferred from the input tensor).
toPIL = transforms.ToPILImage(mode=None)

@torch.no_grad()
def run_eval_pairs(
    pairs,
    prompt="a car, advertisement style, professional photography",
    guidance_scale=7,
    steps=30,
    controlnet_conditioning_scale=0.5,
    attn_ip_scale=0.6,
    text_token_scale=1.5,
    ip_token_scale=0.4,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=False,
    # palette metrics
    color_space="lab",
    bins=8,
    # Sinkhorn-constrained outer loop
    target_sinkhorn_threshold=0.02,
    max_attempts=20,
    top_k=20,
    verbose=False,
):
    """Run Style and Sinkhorn-constrained generation for each pair and collect metrics.

    Args:
        pairs: iterable of ``(color_index, layout_index)``. The color index points
            into ``color_dataset``; the layout index points into ``edge_maps``.
        prompt, guidance_scale, steps, controlnet_conditioning_scale: shared sampler settings.
        attn_ip_scale, text_token_scale, ip_token_scale, ip_uncond_scale,
        zero_ip_in_uncond: IP-Adapter settings forwarded to both paths.
        color_space, bins: histogram settings for the target/Style histograms.
        target_sinkhorn_threshold, max_attempts, top_k: constrained outer-loop settings.
        verbose: print the per-pair Sinkhorn improvement and attempt count.

    Returns:
        pandas.DataFrame with one row per pair. Columns: color_index, layout_index,
        prompt; style_/sinkhorn_ {sinkhorn, cos, f1, iou, ssim}; attempts;
        style_time and sinkhorn_time (seconds). As side effects, the first rows
        are displayed and the row count is printed.
    """
    # Imported once up front instead of on every loop iteration.
    from degis.shared.clip_vit_bigg14 import compute_clip_embedding_xl

    rows = []
    for color_index, layout_index in pairs:
        gc.collect()
        torch.cuda.empty_cache()

        # --- source PIL for style baseline + target histogram
        img_t, _ = color_dataset[color_index]
        pil_img = toPIL(img_t)
        target_hist = compute_histogram_for_color_space(pil_img, color_space, bins=bins)

        # --- color-head embedding for the palette-token path
        z_clip = compute_clip_embedding_xl(pil_img).to(device).unsqueeze(0)
        color_embedding = get_color_embedding(color_head, z_clip)

        # --- control (edge) image + boolean mask for metrics
        control_img = create_control_edge_pil(edge_maps[layout_index], size=512)
        control_bool = (np.asarray(control_img.convert("L")) > 0).astype(np.uint8)

        # ============ 1) STYLE baseline ============
        t0 = time.perf_counter()
        style_images = generate_by_style(
            generator=generator,
            pil_image=pil_img,                          # Pass PIL (not clip_image_embeds)
            control_image=control_img,
            prompt=prompt,
            num_samples=1,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            attn_ip_scale=attn_ip_scale,
            text_token_scale=text_token_scale,
            ip_token_scale=ip_token_scale,
            ip_uncond_scale=ip_uncond_scale,
            zero_ip_in_uncond=zero_ip_in_uncond,
        )
        style_time = time.perf_counter() - t0
        style_img = style_images[0]

        style_hist = compute_histogram_for_color_space(style_img, color_space, bins=bins)
        eval_top_k = min(top_k, len(target_hist)) if top_k is not None else len(target_hist)
        style_sinkhorn = calculate_sinkhorn_distance_topk(target_hist, style_hist, top_k=eval_top_k, blur=0.01)
        style_cos = calculate_cosine_similarity(prompt, style_img)
        f1_s, iou_s, ssim_s = _edge_scores(_canny_bool(style_img), control_bool)

        # ============ 2) Sinkhorn-constrained (palette-tokens) ============
        # NOTE(review): sinkhorn_score is the value reported by the constrained
        # generator, which uses its own top_k/blur settings. style_sinkhorn above
        # uses top_k=eval_top_k and blur=0.01. Confirm both use the same metric
        # before reading style_sinkhorn - sinkhorn_sinkhorn as an apples-to-apples delta.
        t1 = time.perf_counter()
        sinkhorn_images, sinkhorn_score, sinkhorn_cos, attempts = generate_by_color_sinkhorn_constrained(
            generator=generator,
            color_embedding=color_embedding,
            control_image=control_img,
            original_histogram=target_hist,
            prompt=prompt,
            target_sinkhorn_threshold=target_sinkhorn_threshold,
            max_attempts=max_attempts,
            top_k=top_k,
            color_space=color_space,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            attn_ip_scale=attn_ip_scale,
            text_token_scale=text_token_scale,
            ip_token_scale=ip_token_scale,
            ip_uncond_scale=ip_uncond_scale,
            zero_ip_in_uncond=zero_ip_in_uncond,
            verbose=False,
        )
        sinkhorn_time = time.perf_counter() - t1
        sinkhorn_img = sinkhorn_images[0]
        f1_e, iou_e, ssim_e = _edge_scores(_canny_bool(sinkhorn_img), control_bool)

        rows.append(dict(
            color_index=int(color_index),
            layout_index=int(layout_index),
            prompt=prompt,
            style_sinkhorn=float(style_sinkhorn), style_cos=float(style_cos),
            sinkhorn_sinkhorn=float(sinkhorn_score), sinkhorn_cos=float(sinkhorn_cos),
            style_f1=float(f1_s), style_iou=float(iou_s), style_ssim=float(ssim_s),
            sinkhorn_f1=float(f1_e), sinkhorn_iou=float(iou_e), sinkhorn_ssim=float(ssim_e),
            attempts=int(attempts),
            style_time=float(style_time), sinkhorn_time=float(sinkhorn_time),
        ))
        if verbose:
            print(f"[{color_index},{layout_index}] ΔSinkhorn={style_sinkhorn-sinkhorn_score:+.5f} | attempts={attempts}")

    results = pd.DataFrame(rows)
    display(results.head())
    print(f"Collected {len(results)} rows.")
    return results

# Controlled evaluation grid, color-major (all layouts for color 1000, then 1001, ...):
# 20 color sources (color_index 1000-1019) x 4 layouts (30, 71, 400, 507) = 80 pairs.
pairs = [
    (int(color_idx), int(layout_idx))
    for color_idx in range(1000, 1020)
    for layout_idx in [30, 71, 400, 507]
]

print("Testing these color_index, layout_index pairs:")
for pair_no, (c_idx, l_idx) in enumerate(pairs, start=1):
    print(f"  {pair_no:2d}: color_index={c_idx:7d}, layout_index={l_idx:4d}")
print(f"Total pairs: {len(pairs)}")

# NOTE(review): the outer loop here stops as soon as Sinkhorn <= 0.05, while the
# Pass@tau / Success@tau cells below score at 0.020-0.030. Runs may therefore stop
# before reaching those tighter thresholds; confirm this is intended.
df_eval = run_eval_pairs(pairs, target_sinkhorn_threshold=0.05, max_attempts=20, verbose=False)

In [None]:
# Eval-1 - Pass@{0.020, 0.025, 0.030}

import numpy as np, pandas as pd

def pass_at_t(df, taus=(0.020, 0.025, 0.030), col="sinkhorn_sinkhorn"):
    """Pass rate at each threshold: the fraction of rows with ``df[col] <= tau``.

    NaN scores count as failures. The resulting (``tau``, ``pass_rate``) table is
    displayed and returned.
    """
    res = pd.DataFrame([
        {"tau": tau, "pass_rate": float((df[col].values <= tau).mean())}
        for tau in taus
    ])
    display(res)
    return res

# Thresholds at which Pass@tau is reported for both paths.
PASS_TAUS = (0.020, 0.025, 0.030)

print("Palette-token run (Sinkhorn-constrained):")
pass_rates_sinkhorn = pass_at_t(df_eval, taus=PASS_TAUS, col="sinkhorn_sinkhorn")

print("\nSTYLE baseline (for context):")
# Assigned rather than left as the cell's last expression: pass_at_t already
# displays its table, and a bare call would render the Style table a second time.
pass_rates_style = pass_at_t(df_eval, taus=PASS_TAUS, col="style_sinkhorn")

In [None]:
# Eval-2 - paired improvement over STYLE baseline

import numpy as np, pandas as pd, matplotlib.pyplot as plt

# Per-pair improvement: a positive delta means the palette (Sinkhorn-constrained)
# image is closer to the source palette than the Style baseline.
df_delta = df_eval.assign(
    delta_sinkhorn=df_eval["style_sinkhorn"] - df_eval["sinkhorn_sinkhorn"]
)
delta_col = df_delta["delta_sinkhorn"]
print(f"ΔSinkhorn (style − palette): mean={delta_col.mean():.5f}, "
      f"median={delta_col.median():.5f}")

fig, ax = plt.subplots()
ax.hist(delta_col, bins=20)
ax.set_xlabel("ΔSinkhorn (positive = palette better)")
ax.set_ylabel("count")
ax.set_title("ΔSinkhorn vs STYLE baseline")
plt.show()

# Five pairs where the palette path gained least (most negative delta first).
df_delta.sort_values(by="delta_sinkhorn").head(5)

In [None]:
# Eval-3 - layout fidelity (edge F1 / IoU / SSIM)

def _brief_stats(x): 
    import numpy as np
    return f"mean={np.nanmean(x):.3f}  std={np.nanstd(x):.3f}"

# One line per path: F1 / IoU / SSIM of generated edges vs the layout edges.
for label, prefix in (("STYLE", "style"), ("PALETTE", "sinkhorn")):
    print(f"{label:<7} - F1:", _brief_stats(df_eval[f"{prefix}_f1"].values),
          "| IoU:", _brief_stats(df_eval[f"{prefix}_iou"].values),
          "| SSIM:", _brief_stats(df_eval[f"{prefix}_ssim"].values))

In [None]:
# Eval-4 - text–image cosine (higher = better)

import numpy as np

# Mean prompt-image CLIP cosine per path, plus the mean per-pair difference.
m_style = df_eval["style_cos"].mean()
m_sinkhorn = df_eval["sinkhorn_cos"].mean()
delta = df_eval["sinkhorn_cos"].sub(df_eval["style_cos"]).mean()

print(
    f"CLIP cosine - STYLE mean={m_style:.4f} | PALETTE mean={m_sinkhorn:.4f} "
    f"| Δ (palette−style)={delta:+.4f}"
)

In [None]:
# Eval-5 - sweep ip_token_scale ∈ {0.4, 0.7, 1.0} on a small subset

import matplotlib.pyplot as plt

scales = [0.4, 0.7, 1.0]
SWEEP_TAU = 0.025  # outer-loop target and the threshold for the reported pass rate

# itertuples() returns a one-shot iterator; materialize it once so the same
# first-10 subset can be reused for every scale.
subset_pairs = list(
    df_eval[["color_index", "layout_index"]].head(10).itertuples(index=False, name=None)
)

rows = []
for s in scales:
    df_s = run_eval_pairs(subset_pairs,
                          ip_token_scale=s,
                          target_sinkhorn_threshold=SWEEP_TAU,
                          max_attempts=20,
                          verbose=False)
    rows.append({"ip_token_scale": s,
                 "sinkhorn_mean": df_s["sinkhorn_sinkhorn"].mean(),
                 "sinkhorn_median": df_s["sinkhorn_sinkhorn"].median(),
                 f"pass@{SWEEP_TAU}": (df_s["sinkhorn_sinkhorn"] <= SWEEP_TAU).mean()})

df_sweep = pd.DataFrame(rows)
display(df_sweep)

fig, ax = plt.subplots()
ax.plot(df_sweep["ip_token_scale"], df_sweep["sinkhorn_mean"], marker="o")
ax.set_xlabel("ip_token_scale")
ax.set_ylabel("Sinkhorn (mean)")
ax.set_title("Sensitivity of palette strength")
plt.show()


In [None]:
# Eval-6 - attempts statistics and success rate within budget

import matplotlib.pyplot as plt

tau = 0.025
budget = 20
# NOTE(review): df_eval was produced with target_sinkhorn_threshold=0.05, so the
# outer loop may have stopped before reaching tau; these rates reflect that setting.

within_tau = df_eval["sinkhorn_sinkhorn"] <= tau
succ = within_tau.mean()
att_succ = df_eval.loc[within_tau, "attempts"]
print(f"Success@τ={tau:.3f} within {budget} attempts: {succ*100:.1f}%")
if not att_succ.empty:
    print(f"Attempts for successful runs - mean={att_succ.mean():.2f}, median={att_succ.median():.1f}, max={att_succ.max()}")

# One histogram bin per attempt count 1..budget.
fig, ax = plt.subplots()
ax.hist(df_eval["attempts"], bins=range(1, budget + 2))
ax.set_xlabel("Attempts")
ax.set_ylabel("count")
ax.set_title("Outer-loop attempts distribution")
plt.show()

In [None]:
# Eval-7 - trainable params added (lightweight claim)

import torch, json, glob, os

# Only parameters with requires_grad=True are counted.
# NOTE(review): if load_trained_color_head freezes the head, this reports 0 —
# confirm it leaves requires_grad enabled.
color_params = sum(p.numel() for p in color_head.parameters() if p.requires_grad)
rest_params = None

# Prefer the summary saved next to the loaded color-head checkpoint, so both
# counts describe the same run. Otherwise fall back to the alphabetically first
# summary under evaluation_runs/, which may belong to a different run.
summary_path = os.path.join(os.path.dirname(color_head_checkpoint_path), "best_summary.json")
if not os.path.isfile(summary_path):
    candidates = sorted(glob.glob("evaluation_runs/**/best_summary.json", recursive=True))
    summary_path = candidates[0] if candidates else None

if summary_path is not None:
    try:
        with open(summary_path, "r") as f:
            js = json.load(f)
    except (OSError, ValueError) as err:  # json.JSONDecodeError is a ValueError
        print(f"Could not read {summary_path}: {err}")
    else:
        if isinstance(js, dict):
            rest_params = js.get("param_count_rest")

if rest_params is None:
    print(f"Trainable params - ColorHead: {color_params:,} (RestHead: n/a at inference)")
else:
    print(f"Trainable params - ColorHead: {color_params:,} | RestHead: {int(rest_params):,} | Total: {color_params+int(rest_params):,}")


In [None]:
# --- helpers (edges, metrics) ---
import numpy as np, time, torch
from PIL import Image
try:
    from skimage.feature import canny as sk_canny
    from skimage.metrics import structural_similarity as ssim
    _HAS_SKIMAGE = True
except Exception:
    _HAS_SKIMAGE = False
from numpy.typing import ArrayLike

def _edge_from_pil(img: Image.Image, sigma: float = 1.0) -> np.ndarray:
    """Return a float32 edge map (values 0.0 / 1.0) for a PIL image.

    Uses scikit-image's Canny detector (with ``sigma``) when ``_HAS_SKIMAGE`` is
    set. Otherwise, pixels whose gradient magnitude reaches the 90th percentile
    are marked as edges (``sigma`` is unused in that path). Pixels with zero
    gradient are never edges, so a flat image yields an empty map in the fallback.
    """
    g = np.asarray(img.convert("L"), dtype=np.float32) / 255.0
    if _HAS_SKIMAGE:
        return sk_canny(g, sigma=sigma).astype(np.float32)
    gx = np.gradient(g, axis=1)
    gy = np.gradient(g, axis=0)
    mag = np.hypot(gx, gy)
    thr = np.quantile(mag, 0.90)
    # Without the `mag > 0` guard, an image that is >= 90% flat gets thr == 0 and
    # every pixel (including the flat ones) would be marked as an edge.
    return ((mag >= thr) & (mag > 0)).astype(np.float32)

def _bin(x: ArrayLike, thr: float = 0.5) -> np.ndarray:
    return (np.asarray(x, dtype=np.float32) >= thr).astype(np.uint8)

def f1_iou(pred: ArrayLike, targ: ArrayLike, thr: float = 0.5):
    """F1 and IoU of two maps after binarizing both at ``thr``.

    When both maps have no positive pixels, the denominators are zero and both
    scores are defined as 1.0 (perfect agreement on "no edges").
    """
    p = _bin(pred, thr).astype(bool)
    t = _bin(targ, thr).astype(bool)
    tp = np.count_nonzero(p & t)
    fp = np.count_nonzero(p & ~t)
    fn = np.count_nonzero(~p & t)
    denom_f1 = 2 * tp + fp + fn
    denom_iou = tp + fp + fn
    f1 = 2 * tp / denom_f1 if denom_f1 else 1.0
    iou = tp / denom_iou if denom_iou else 1.0
    return float(f1), float(iou)

# --- evaluation buffer: one dict per evaluated case (see evaluate_case below) ---
records = list()

# --- one-shot case evaluator (returns a dict; does NOT mutate global state) ---
def _score_generated(
    img: Image.Image,
    *,
    prompt: str,
    hist_ref,
    edge_ref: np.ndarray,
    color_space: str,
    bins: int,
    top_k: int,
):
    """Colour, text and edge metrics for one generated image.

    Args:
        img: generated image (PIL).
        prompt: text prompt, scored against ``img`` with CLIP cosine.
        hist_ref: histogram of the colour-source image, built with the same
            ``color_space`` / ``bins``.
        edge_ref: reference edge map, 2-D float array with values in [0, 1].
        color_space, bins, top_k: forwarded to the DEGIS histogram / Sinkhorn helpers.

    Returns:
        ``(sinkhorn, cos, edge_map, f1, iou, ssim)``. ``ssim`` is NaN when
        scikit-image could not be imported.
    """
    from degis.inference.generation_functions import (
        compute_histogram_for_color_space, calculate_sinkhorn_distance_topk, calculate_cosine_similarity,
    )

    gen_hist = compute_histogram_for_color_space(img, color_space=color_space, bins=bins)
    sinkhorn = calculate_sinkhorn_distance_topk(hist_ref, gen_hist, top_k=top_k)
    cos = calculate_cosine_similarity(prompt, img)

    # Edge metrics compare pixel by pixel, so the image must match the reference
    # map's size. If the generator returned another resolution, resize a copy
    # (used only for edge extraction) instead of failing in f1_iou / ssim.
    ref_size = (edge_ref.shape[1], edge_ref.shape[0])  # PIL sizes are (W, H)
    edge_src = img if img.size == ref_size else img.resize(ref_size)
    edge = _edge_from_pil(edge_src)
    f1, iou = f1_iou(edge, edge_ref, thr=0.5)
    # Both maps are floats in [0, 1]. Without an explicit data_range, scikit-image
    # either assumes [-1, 1] for float input (older releases) or raises
    # ValueError (newer releases).
    ssim_val = float(ssim(edge_ref, edge, data_range=1.0)) if _HAS_SKIMAGE else np.nan
    return sinkhorn, cos, edge, f1, iou, ssim_val


def evaluate_case(
    *,                                   # keyword-only for clarity
    color_img: Image.Image,              # style reference image (PIL)
    control_edge_ref: Image.Image,       # reference edge map (PIL, 0..255)
    color_head,                          # trained color head (nn.Module)
    generator,                           # IP-Adapter pipeline wrapper
    prompt: str,
    z_clip_row: np.ndarray,              # 1D numpy (D,)
    color_space: str = "lab",
    bins: int = 80,                      # for on-the-fly histogram + Sinkhorn
    # gauges
    attn_ip_scale: float = 0.7,
    text_token_scale: float = 1.2,
    ip_token_scale: float = 0.6,
    ip_uncond_scale: float = 0.0,
    zero_ip_in_uncond: bool = True,
    # sampling
    guidance_scale: float = 8.0,
    steps: int = 60,
    controlnet_conditioning_scale: float = 0.7,
    # outer loop (only used in palette path helper if you set target)
    target_sinkhorn_threshold: float | None = None,
    max_attempts: int = 20,
    top_k: int = 20,
):
    """Generate one case with both methods and score each against the same references.

    (A) Style baseline: ``generate_by_style`` conditioned on ``color_img`` with
        stock IP-Adapter token handling (``ip_token_scale=None``,
        ``ip_uncond_scale=None``, ``zero_ip_in_uncond=False``).
    (B) Palette path: a colour embedding from
        ``get_color_embedding(color_head, z_clip_row)`` is passed to
        ``generate_by_color_sinkhorn_constrained``. The outer-loop settings
        (``target_sinkhorn_threshold``, ``max_attempts``) are forwarded to that
        helper unchanged.

    Each image is scored on three criteria:
    - Sinkhorn distance to ``color_img``'s histogram (lower is better).
    - CLIP text cosine to ``prompt`` (higher is better).
    - F1 / IoU / SSIM of its edges against ``control_edge_ref``.

    Returns:
        A new dict (no global state is mutated) containing:
        - ``edge_ref``: the reference edge map as float32 in [0, 1].
        - Style path: ``style_img``, ``style_time`` (seconds), ``sinkhorn_style``,
          ``cos_style``, ``edge_style``, ``f1_style``, ``iou_style``, ``ssim_style``.
        - Palette path: ``sinkhorn_img``, ``sinkhorn_time`` (seconds),
          ``sinkhorn_palette``, ``cos_palette``, ``edge_sinkhorn``,
          ``f1_sinkhorn``, ``iou_sinkhorn``, ``ssim_sinkhorn``, ``attempts``.
        - The inputs ``color_img``, ``control_edge_ref``, ``prompt`` and
          ``z_clip_row``, so the case can be re-run later (e.g. ablations).
        ``ssim_*`` values are NaN when scikit-image is unavailable.
    """
    from degis.inference import generate_by_style, generate_by_color_sinkhorn_constrained
    from degis.inference.generation_functions import compute_histogram_for_color_space, get_color_embedding

    # --- shared references for both paths ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    z_clip = torch.tensor(z_clip_row, dtype=torch.float32, device=device).unsqueeze(0)  # [1, D]
    hist_color = compute_histogram_for_color_space(color_img, color_space=color_space, bins=bins)
    edge_ref = np.asarray(control_edge_ref.convert("L"), dtype=np.float32) / 255.0  # 0..1
    metric_kwargs = dict(
        prompt=prompt, hist_ref=hist_color, edge_ref=edge_ref,
        color_space=color_space, bins=bins, top_k=top_k,
    )

    # --- (A) style baseline ---
    t_start = time.perf_counter()
    style_img = generate_by_style(
        generator=generator,
        pil_image=color_img,
        control_image=control_edge_ref,
        prompt=prompt,
        num_samples=1,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        attn_ip_scale=attn_ip_scale,
        text_token_scale=text_token_scale,
        ip_token_scale=None,          # stock behaviour (style ref)
        ip_uncond_scale=None,
        zero_ip_in_uncond=False,
    )[0]
    style_time = time.perf_counter() - t_start
    style_sinkhorn, style_cos, style_edge, style_f1, style_iou, style_ssim = _score_generated(
        style_img, **metric_kwargs
    )

    # --- (B) palette-tokens path (outer loop disabled unless a target is given) ---
    t_start = time.perf_counter()
    color_emb = get_color_embedding(color_head, z_clip)  # [B, D]
    # The helper's own best score / cosine are discarded; the metrics are
    # recomputed below with the same functions used for the style path.
    sinkhorn_imgs, _, _, attempts = generate_by_color_sinkhorn_constrained(
        generator=generator,
        color_embedding=color_emb,
        control_image=control_edge_ref,
        original_histogram=hist_color,
        prompt=prompt,
        target_sinkhorn_threshold=target_sinkhorn_threshold,
        max_attempts=max_attempts,
        top_k=top_k,
        color_space=color_space,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        attn_ip_scale=attn_ip_scale,
        text_token_scale=text_token_scale,
        ip_token_scale=ip_token_scale,
        ip_uncond_scale=ip_uncond_scale,
        zero_ip_in_uncond=zero_ip_in_uncond,
        verbose=False
    )
    sinkhorn_img = sinkhorn_imgs[0]
    sinkhorn_time = time.perf_counter() - t_start
    gen_sinkhorn, gen_cos, gen_edge, gen_f1, gen_iou, gen_ssim = _score_generated(
        sinkhorn_img, **metric_kwargs
    )

    return {
        # reference
        "edge_ref": edge_ref,
        # style path
        "style_img": style_img, "style_time": style_time, "sinkhorn_style": float(style_sinkhorn), "cos_style": float(style_cos),
        "edge_style": style_edge, "f1_style": style_f1, "iou_style": style_iou, "ssim_style": style_ssim,
        # palette path
        "sinkhorn_img": sinkhorn_img, "sinkhorn_time": sinkhorn_time, "sinkhorn_palette": float(gen_sinkhorn), "cos_palette": float(gen_cos),
        "edge_sinkhorn": gen_edge, "f1_sinkhorn": gen_f1, "iou_sinkhorn": gen_iou, "ssim_sinkhorn": gen_ssim,
        "attempts": int(attempts),
        # inputs, kept so a later cell (e.g. the ip_token_scale ablation) can re-run this case
        "color_img": color_img, "control_edge_ref": control_edge_ref,
        "prompt": prompt, "z_clip_row": z_clip_row,
    }

In [None]:
# Pass@τ: the fraction of cases whose Sinkhorn distance to the source palette
# is at or below τ (higher is better), reported for both generation paths.
taus = [0.020, 0.025, 0.030]  # thresholds to report

sinkhorn_vals_palette, sinkhorn_vals_style = (
    np.array([rec[key] for rec in records], dtype=float)
    for key in ("sinkhorn_palette", "sinkhorn_style")
)

def _pass_at_tau(arr, tau):
    """Share of entries in ``arr`` that are <= ``tau``, as a Python float."""
    hits = arr <= tau
    return float(hits.mean())

print("Pass@τ (palette-tokens)  |  (style baseline)")
for t in taus:
    print("  τ={:.3f}:  {:.3f}   |   {:.3f}".format(
        t, _pass_at_tau(sinkhorn_vals_palette, t), _pass_at_tau(sinkhorn_vals_style, t)
    ))

In [None]:
# Paired per-case improvement of the Sinkhorn-constrained path over the style
# baseline: delta > 0 means the palette path landed closer to the source colours.
delta = np.subtract(sinkhorn_vals_style, sinkhorn_vals_palette)
print("ΔSinkhorn (style − palette): mean={:.5f}, median={:.5f}".format(delta.mean(), np.median(delta)))

# Coarse 7-bin distribution of the paired differences.
hist, edges = np.histogram(delta, bins=7)
print("hist ΔSinkhorn:", [int(count) for count in hist])
print("bins:", [round(float(edge), 5) for edge in edges])

In [None]:
# Per-case edge-adherence scores: *_s = style baseline, *_p = palette/Sinkhorn path.
f1_s, f1_p, iou_s, iou_p, ssim_s, ssim_p = (
    np.array([rec[key] for rec in records], dtype=float)
    for key in ("f1_style", "f1_sinkhorn", "iou_style", "iou_sinkhorn", "ssim_style", "ssim_sinkhorn")
)

def _m(m):
    """Format ``m`` as 'mean±sd' with 3 decimals, ignoring NaN entries."""
    return "{:.3f}±{:.3f}".format(np.nanmean(m), np.nanstd(m))

print("Edge adherence  (mean±sd)")
print(f"  F1:   style={_m(f1_s)}   palette={_m(f1_p)}")
print(f"  IoU:  style={_m(iou_s)}  palette={_m(iou_p)}")
print(f"  SSIM: style={_m(ssim_s)} palette={_m(ssim_p)}  (nan if skimage not present)")

In [None]:
# Prompt adherence: CLIP text-image cosine per path (higher is better) and the
# mean paired difference.
cos_s, cos_p = (
    np.array([rec[key] for rec in records], dtype=float)
    for key in ("cos_style", "cos_palette")
)
print("CLIP cosine - style:   mean={:.4f}".format(cos_s.mean()))
print("CLIP cosine - palette: mean={:.4f}".format(cos_p.mean()))
print("Δ (palette − style):   mean={:.4f}".format(np.mean(cos_p - cos_s)))

In [None]:
# --- ablation: palette-path Sinkhorn vs ip_token_scale ---
# Re-runs up to ABLATION_MAX_CASES stored cases once per scale, with all other
# hyperparameters fixed. Each record must carry the inputs needed to regenerate
# it: "color_img", "control_edge_ref", "prompt", "z_clip_row".
# NOTE: evaluate_case always runs BOTH the style baseline and the palette path,
# so this cell also regenerates the style image (which does not depend on
# ip_token_scale). Only "sinkhorn_palette" is kept.
scales = [0.4, 0.7, 1.0]
ABLATION_MAX_CASES = 20
_REQUIRED_KEYS = ("color_img", "control_edge_ref", "prompt", "z_clip_row")

subset_idx = list(range(min(ABLATION_MAX_CASES, len(records))))

# Check every record up front so a missing input produces an actionable error,
# rather than a bare KeyError partway through the (slow) generation loop.
_missing = sorted({k for i in subset_idx for k in _REQUIRED_KEYS if k not in records[i]})
if _missing:
    raise KeyError(
        f"records are missing {_missing}; store these inputs when building records "
        "so the ablation can re-run each case"
    )

table = {}
for s in scales:
    vals = []
    for i in subset_idx:
        r = records[i]
        case = evaluate_case(
            color_img=r["color_img"],
            control_edge_ref=r["control_edge_ref"],
            color_head=color_head,
            generator=generator,
            prompt=r["prompt"],
            z_clip_row=r["z_clip_row"],
            ip_token_scale=s,
            text_token_scale=1.2,
            attn_ip_scale=0.7,
            zero_ip_in_uncond=True,
            ip_uncond_scale=0.0,
            guidance_scale=8.0, steps=60,
            controlnet_conditioning_scale=0.7,
        )
        vals.append(case["sinkhorn_palette"])
    table[s] = vals

print("Sinkhorn vs ip_token_scale (lower is better)")
for s in scales:
    arr = np.array(table[s], dtype=float)
    if arr.size == 0:
        # Avoid printing NaN plus a "mean of empty slice" RuntimeWarning.
        print(f"  {s:.1f}: no cases (records is empty)")
        continue
    print(f"  {s:.1f}: mean={arr.mean():.4f}, median={np.median(arr):.4f}, n={len(arr)}")

In [None]:
# Outer-loop cost of the Sinkhorn-constrained path: attempts used per case, and
# the share of cases that reach each τ without exceeding an attempt budget.
attempts, sinkhorn_p = (
    np.array([rec[key] for rec in records], dtype=float)
    for key in ("attempts", "sinkhorn_palette")
)

taus = [0.020, 0.025, 0.030]
# Attempt budget. evaluate_case's default max_attempts is also 20; if reported
# attempts never exceed max_attempts, this filter is a no-op (TODO confirm).
budget = 20
print("Attempts stats: mean={:.2f}, median={:.1f}".format(attempts.mean(), np.median(attempts)))
for t in taus:
    succ = np.logical_and(sinkhorn_p <= t, attempts <= budget)
    print(f"Success@{t:.3f} within {budget} attempts: {succ.mean():.3f}")


    

In [None]:
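# TODO: build `df_eval` before enabling this export (e.g. a pandas DataFrame of
# the scalar metric fields of `records`); no cell above creates it.
# df_eval.to_csv("evaluation_runs/evaluation_metrics.csv", index=False)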