# Test-Time Diversity Steering via Early-Step Clean Image Estimation

This notebook implements and evaluates **Clean-Manifold Diversity Guidance** — a training-free method to improve diversity in text-to-image diffusion models at inference time.

**Three methods compared:**
| Method | Guidance domain | Description |
|---|---|---|
| `baseline` | — | Standard CFG sampling |
| `naive` | Noisy latent $x_t$ | Diversity gradient on noise-corrupted signal |
| `clean_estimate` | Clean estimate $\hat{x}_0$ | **Ours** — diversity gradient on Tweedie estimate |

## 1. Setup & Installation

In [None]:
!pip install -q torch torchvision diffusers transformers accelerate lpips matplotlib tqdm numpy Pillow image-reward scipy
!pip install -q git+https://github.com/openai/CLIP.git

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  Preparing metadata (setup.py) ... [?25l[?25hdone


In [19]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from itertools import combinations
from tqdm.auto import tqdm
from typing import List, Dict, Any
import json, os, gc

from diffusers import StableDiffusionPipeline, DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Device: cuda
GPU: Tesla T4
VRAM: 15.6 GB


## 2. Core Pipeline

In [20]:
# ──────────────────────────────────────────────────────────────────────
# Helper functions
# ──────────────────────────────────────────────────────────────────────

def predict_x0(x_t, model_output, alpha_prod_t, prediction_type):
    """Compute x̂₀ from model output via Tweedie's formula."""
    if prediction_type == "epsilon":
        x0 = (x_t - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
    elif prediction_type == "v_prediction":
        x0 = alpha_prod_t ** 0.5 * x_t - (1 - alpha_prod_t) ** 0.5 * model_output
    else:
        raise ValueError(f"Unknown prediction type: {prediction_type}")
    return x0.clamp(-20, 20)


def x0_to_model_output(x_t, x0, alpha_prod_t, prediction_type):
    """Convert (possibly modified) x̂₀ back to model-output space."""
    if prediction_type == "epsilon":
        return (x_t - alpha_prod_t ** 0.5 * x0) / (1 - alpha_prod_t) ** 0.5
    elif prediction_type == "v_prediction":
        return (alpha_prod_t ** 0.5 * x_t - x0) / (1 - alpha_prod_t) ** 0.5
    else:
        raise ValueError(f"Unknown prediction type: {prediction_type}")


def compute_diversity_gradient(tensor, loss_type="cosine"):
    """
    Gradient of pairwise similarity loss w.r.t. tensor.

    loss_type:
      "cosine"    – cosine sim on full flattened latent (texture-level)
      "structural" – cosine sim on avg-pooled latent    (layout-level, recommended)
      "mse"       – negative pairwise MSE

    Returns  (grad [B,C,H,W],  loss_value float)
    Gradient is NOT unit-normalised — its magnitude reflects actual similarity.
    """
    B = tensor.shape[0]
    if B < 2:
        return torch.zeros_like(tensor), 0.0

    with torch.enable_grad():
        x = tensor.detach().float().clone()
        x.requires_grad_(True)

        if loss_type == "structural":
            # avg-pool to 8×8 → captures global composition, ignores texture
            pooled = F.adaptive_avg_pool2d(x, (8, 8))
            flat = pooled.reshape(B, -1)
        elif loss_type in ("cosine", "mse"):
            flat = x.reshape(B, -1)
        else:
            raise ValueError(loss_type)

        if loss_type in ("cosine", "structural"):
            normed = F.normalize(flat, dim=-1)
            sim = torch.mm(normed, normed.t())
            mask = ~torch.eye(B, dtype=torch.bool, device=flat.device)
            loss = sim[mask].mean()
        elif loss_type == "mse":
            diffs = flat.unsqueeze(0) - flat.unsqueeze(1)
            sq_dist = (diffs ** 2).sum(-1)
            mask = ~torch.eye(B, dtype=torch.bool, device=flat.device)
            loss = -sq_dist[mask].mean()

        grad = torch.autograd.grad(loss, x)[0]

    # normalise so that ||grad||_batch = 1  (scale is controlled by λ only)
    batch_norm = grad.norm() + 1e-8
    grad = grad / batch_norm
    return grad.to(tensor.dtype), loss.item()


def diversity_weight(step_idx, num_guidance_steps, schedule="constant"):
    """Scalar weight in [0,1] for the diversity term at step_idx."""
    if num_guidance_steps == 0:
        return 0.0
    ratio = step_idx / max(num_guidance_steps - 1, 1)
    if schedule == "constant":
        return 1.0
    elif schedule == "linear_decay":
        return 1.0 - ratio
    elif schedule == "cosine_decay":
        return 0.5 * (1 + np.cos(np.pi * ratio))
    raise ValueError(schedule)

In [None]:
# ──────────────────────────────────────────────────────────────────────
# Diversity-Guided Pipeline
# ──────────────────────────────────────────────────────────────────────

class DiversityGuidedPipeline:
    def __init__(self, model_id, device="cuda", dtype=torch.float16):
        self.device = device
        self.dtype = dtype
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=dtype, safety_checker=None,
        )
        self.pipe.scheduler = DDIMScheduler.from_config(
            self.pipe.scheduler.config,
            clip_sample=False,
        )
        self.pipe = self.pipe.to(device)

        # ── memory optimisations (essential for 16 GB GPUs) ──
        self.pipe.enable_vae_slicing()          # decode one image at a time
        self.pipe.enable_attention_slicing(1)    # slice attention to save VRAM
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            print("xformers enabled")
        except Exception:
            pass

        self.prediction_type = self.pipe.scheduler.config.prediction_type
        print(f"Loaded {model_id}  prediction_type={self.prediction_type}")

    def _encode_prompt(self, prompt, negative_prompt, batch_size):
        tok, enc = self.pipe.tokenizer, self.pipe.text_encoder
        ids = tok(prompt, padding="max_length",
                  max_length=tok.model_max_length,
                  truncation=True, return_tensors="pt").input_ids.to(self.device)
        p_emb = enc(ids)[0].repeat(batch_size, 1, 1)
        neg = negative_prompt or ""
        nids = tok(neg, padding="max_length",
                   max_length=tok.model_max_length,
                   truncation=True, return_tensors="pt").input_ids.to(self.device)
        n_emb = enc(nids)[0].repeat(batch_size, 1, 1)
        return p_emb, n_emb

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        batch_size: int = 4,
        num_inference_steps: int = 50,
        guidance_scale: float = 9.0,
        diversity_scale: float = 30.0,
        early_stop_ratio: float = 0.2,
        method: str = "clean_estimate",
        loss_type: str = "structural",
        weight_schedule: str = "constant",
        seeds: list = None,
        height: int = 512,
        width: int = 512,
        negative_prompt: str = "",
    ) -> Dict[str, Any]:

        if seeds is None:
            seeds = list(range(42, 42 + batch_size))

        p_emb, n_emb = self._encode_prompt(prompt, negative_prompt, batch_size)

        latent_ch = self.pipe.unet.config.in_channels
        lh, lw = height // 8, width // 8
        gens = [torch.Generator(device=self.device).manual_seed(s)
                for s in seeds]
        latents = torch.cat([
            torch.randn(1, latent_ch, lh, lw, generator=g,
                        device=self.device, dtype=self.dtype)
            for g in gens
        ])

        self.pipe.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.pipe.scheduler.timesteps
        latents = latents * self.pipe.scheduler.init_noise_sigma

        n_guide = int(len(timesteps) * early_stop_ratio) if method != "baseline" else 0
        div_losses = []

        for si, t in enumerate(tqdm(timesteps, leave=False, desc=method)):
            do_guide = si < n_guide

            lat_in = torch.cat([latents] * 2)
            lat_in = self.pipe.scheduler.scale_model_input(lat_in, t)
            emb_in = torch.cat([n_emb, p_emb])

            noise_pred = self.pipe.unet(lat_in, t,
                                        encoder_hidden_states=emb_in).sample
            n_unc, n_cond = noise_pred.chunk(2)
            n_cfg = n_unc + guidance_scale * (n_cond - n_unc)

            if do_guide:
                w = diversity_weight(si, n_guide, weight_schedule)
                apt = self.pipe.scheduler.alphas_cumprod[t.long()]

                if method == "naive":
                    grad, lv = compute_diversity_gradient(latents, loss_type)
                    latents = self.pipe.scheduler.step(n_cfg, t, latents).prev_sample
                    latents = latents - (diversity_scale * w) * grad
                    div_losses.append(lv)
                    continue

                elif method == "clean_estimate":
                    x0h = predict_x0(latents, n_cfg, apt, self.prediction_type)
                    grad, lv = compute_diversity_gradient(x0h, loss_type)
                    latents = self.pipe.scheduler.step(n_cfg, t, latents).prev_sample
                    latents = latents - (diversity_scale * w) * grad
                    div_losses.append(lv)
                    continue

            latents = self.pipe.scheduler.step(n_cfg, t, latents).prev_sample

        # decode with slicing (one image at a time → constant VRAM)
        imgs_t = self.pipe.vae.decode(
            latents / self.pipe.vae.config.scaling_factor
        ).sample
        imgs_t = (imgs_t / 2 + 0.5).clamp(0, 1)
        imgs_pil = [
            Image.fromarray(
                (im.permute(1, 2, 0).cpu().float().numpy() * 255).astype(np.uint8)
            ) for im in imgs_t
        ]

        # free intermediate VRAM
        del latents, noise_pred, lat_in, n_cfg
        torch.cuda.empty_cache()

        return {"images": imgs_pil, "images_tensor": imgs_t.cpu(),
                "diversity_losses": div_losses}

## 3. Evaluation Metrics

In [22]:
# ── Patch transformers compatibility for ImageReward ──────────
# ImageReward's BLIP imports functions that moved in transformers >= 4.45
import transformers.modeling_utils as _mu
try:
    from transformers.pytorch_utils import (
        apply_chunking_to_forward,
        find_pruneable_heads_and_indices,
        prune_linear_layer,
    )
    for _name, _fn in [
        ("apply_chunking_to_forward", apply_chunking_to_forward),
        ("find_pruneable_heads_and_indices", find_pruneable_heads_and_indices),
        ("prune_linear_layer", prune_linear_layer),
    ]:
        if not hasattr(_mu, _name):
            setattr(_mu, _name, _fn)
    print("Patched transformers compatibility for ImageReward")
except ImportError:
    pass

# ──────────────────────────────────────────────────────────────

class DiversityMetrics:
    def __init__(self, device="cuda"):
        self.device = device
        self._lpips = None
        self._clip_model = None
        self._clip_proc = None
        self._inception = None
        self._reward = None

    # ── Lazy-loaded models ──────────────────────────────────

    @property
    def lpips_fn(self):
        if self._lpips is None:
            import lpips
            self._lpips = lpips.LPIPS(net="alex").to(self.device).eval()
        return self._lpips

    def _load_clip(self):
        from transformers import CLIPModel, CLIPProcessor
        name = "openai/clip-vit-base-patch32"
        self._clip_model = CLIPModel.from_pretrained(name).to(self.device).eval()
        self._clip_proc = CLIPProcessor.from_pretrained(name)

    @property
    def clip_model(self):
        if self._clip_model is None: self._load_clip()
        return self._clip_model

    @property
    def clip_proc(self):
        if self._clip_proc is None: self._load_clip()
        return self._clip_proc

    @property
    def inception_model(self):
        if self._inception is None:
            from torchvision.models import inception_v3
            self._inception = inception_v3(pretrained=True).to(self.device).eval()
            print("Inception-v3 loaded for IS/FID")
        return self._inception

    @property
    def reward_model(self):
        if self._reward is None:
            import ImageReward as RM
            self._reward = RM.load("ImageReward-v1.0", device=self.device)
            print("ImageReward-v1.0 loaded")
        return self._reward

    # ── Feature extraction helpers ──────────────────────────

    def _get_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        inputs = self.clip_proc(images=images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)
        feats = self.clip_model.get_image_features(pixel_values=pixel_values)
        if not isinstance(feats, torch.Tensor):
            feats = feats.pooler_output if hasattr(feats, "pooler_output") else feats[0]
        return feats

    def _get_text_features(self, text: str) -> torch.Tensor:
        inputs = self.clip_proc(text=[text], return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        feats = self.clip_model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
        if not isinstance(feats, torch.Tensor):
            feats = feats.pooler_output if hasattr(feats, "pooler_output") else feats[0]
        return feats

    @torch.no_grad()
    def _get_inception_outputs(self, images: List[Image.Image], batch_size=32):
        """Extract Inception-v3 pool features (2048-d) and class probs (1000-d)."""
        from torchvision import transforms as TF
        transform = TF.Compose([
            TF.Resize((299, 299), interpolation=TF.InterpolationMode.BILINEAR),
            TF.ToTensor(),
        ])
        model = self.inception_model
        feat_buf = []
        hook = model.avgpool.register_forward_hook(
            lambda m, inp, out: feat_buf.append(out.flatten(1))
        )
        all_feats, all_probs = [], []
        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size]
            tensors = torch.stack(
                [transform(img.convert("RGB")) for img in batch]
            ).to(self.device)
            feat_buf.clear()
            logits = model(tensors)
            if isinstance(logits, tuple):
                logits = logits[0]
            all_probs.append(F.softmax(logits, dim=1).cpu())
            all_feats.append(feat_buf[0].cpu())
        hook.remove()
        return torch.cat(all_feats), torch.cat(all_probs)

    # ── Per-batch diversity metrics ─────────────────────────

    @torch.no_grad()
    def pairwise_lpips(self, images: List[Image.Image]) -> float:
        tensors = []
        for img in images:
            arr = np.array(img.resize((256, 256))).astype(np.float32) / 255.0
            tensors.append(torch.from_numpy(arr).permute(2, 0, 1) * 2 - 1)
        tensors = torch.stack(tensors).to(self.device)
        vals = [self.lpips_fn(tensors[i:i+1], tensors[j:j+1]).item()
                for i, j in combinations(range(len(images)), 2)]
        return float(np.mean(vals)) if vals else 0.0

    @torch.no_grad()
    def clip_diversity(self, images: List[Image.Image]) -> float:
        feats = F.normalize(self._get_image_features(images), dim=-1)
        sim = torch.mm(feats, feats.t())
        B = len(images)
        mask = ~torch.eye(B, dtype=torch.bool, device=self.device)
        return 1.0 - sim[mask].mean().item()

    # ── Per-batch alignment metrics ─────────────────────────

    @torch.no_grad()
    def clip_score(self, images: List[Image.Image], prompt: str) -> float:
        img_feats  = F.normalize(self._get_image_features(images), dim=-1)
        text_feats = F.normalize(self._get_text_features(prompt),  dim=-1)
        sims = (img_feats @ text_feats.T).squeeze(-1)
        return sims.mean().item()

    @torch.no_grad()
    def image_reward_score(self, images: List[Image.Image], prompt: str) -> float:
        """Mean ImageReward score across the batch."""
        scores = []
        for img in images:
            s = self.reward_model.score(prompt, img)
            scores.append(float(s))
        return float(np.mean(scores))

    # ── Aggregate quality metrics ───────────────────────────

    @torch.no_grad()
    def inception_score(self, images: List[Image.Image], splits: int = 5):
        """Inception Score over a set of images.  Returns (mean, std)."""
        _, probs = self._get_inception_outputs(images)
        N = len(probs)
        splits = min(splits, N)
        scores = []
        for k in range(splits):
            part = probs[k * N // splits : (k + 1) * N // splits]
            if len(part) < 2:
                continue
            py = part.mean(0, keepdim=True)
            kl = (part * (part.log() - py.log())).sum(1).mean()
            scores.append(kl.exp().item())
        return (float(np.mean(scores)), float(np.std(scores))) if scores else (0.0, 0.0)

    @torch.no_grad()
    def compute_fid(self, real_images: List[Image.Image],
                    fake_images: List[Image.Image]) -> float:
        """FID between two image sets."""
        real_f, _ = self._get_inception_outputs(real_images)
        fake_f, _ = self._get_inception_outputs(fake_images)
        return self._fid_from_features(real_f, fake_f)

    @staticmethod
    def _fid_from_features(feats1: torch.Tensor, feats2: torch.Tensor) -> float:
        """Compute FID from pre-extracted Inception features."""
        from scipy import linalg
        f1 = feats1.double().numpy()
        f2 = feats2.double().numpy()
        mu1, mu2 = f1.mean(0), f2.mean(0)
        s1 = np.cov(f1, rowvar=False) + np.eye(f1.shape[1]) * 1e-6
        s2 = np.cov(f2, rowvar=False) + np.eye(f2.shape[1]) * 1e-6
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(s1 @ s2, disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return float(diff @ diff + np.trace(s1 + s2 - 2 * covmean))

    # ── Combined evaluation ─────────────────────────────────

    def evaluate_batch(self, images, prompt):
        return {
            "lpips_diversity": self.pairwise_lpips(images),
            "clip_score":      self.clip_score(images, prompt),
            "clip_diversity":  self.clip_diversity(images),
            "image_reward":    self.image_reward_score(images, prompt),
        }

Patched transformers compatibility for ImageReward


## 4. Benchmark Prompts

In [None]:
DIVERSITY_PROMPTS = [
    "a cat sitting",
    "a dog playing in a field",
    "a red sports car",
    "a bouquet of flowers in a vase",
    "a bird flying over the ocean",
    "a beautiful mountain landscape",
    "a city street at night with neon lights",
    "a cozy living room with a fireplace",
    "a tropical beach at sunset",
    "a forest path in autumn",
    "portrait of a woman with curly hair",
    "a man walking in the rain with an umbrella",
    "a child playing in a park",
    "an oil painting of a castle on a hill",
    "a photograph of a steaming coffee cup",
    "a watercolor painting of a sailboat on a lake",
    "a futuristic city skyline",
    "a magical forest with glowing mushrooms",
    "a steampunk robot in a workshop",
    "an astronaut floating in outer space",
]

GENEVAL_PROMPTS = [
    "a blue cube on a white surface",
    "a red sphere on a gray background",
    "a green pyramid on a wooden table",
    "a yellow cylinder on a dark surface",
    "a red cube and a blue sphere",
    "a cat and a dog sitting together",
    "a car next to a tree",
    "an apple and a banana on a plate",
    "a cat sitting on a wooden table",
    "a ball under a red chair",
    "a bird perched above a birdhouse",
    "a bicycle in front of a brick building",
    "two cats sleeping on a couch",
    "three red apples on a wooden table",
    "four colorful birds sitting on a wire",
    "a red car and a blue bicycle on a street",
    "a green frog and an orange butterfly in a garden",
    "a purple hat on a brown wooden table",
    "a white cat and a black dog on grass",
    "a pink flower in a blue vase",
    "a beautiful cinematic shark",
    "a majestic knight standing in a mystrical forest, cinematic lighting, highly detailed",
    "a beautiful cartoonlike village full of trees and houses."
]

ALL_PROMPTS = DIVERSITY_PROMPTS + GENEVAL_PROMPTS
print(f"Total prompts: {len(ALL_PROMPTS)}  "
      f"(diversity: {len(DIVERSITY_PROMPTS)}, geneval: {len(GENEVAL_PROMPTS)})")

Total prompts: 40  (diversity: 20, geneval: 20)


## 5. Configuration

**Four runs:**
| # | Model | Method | CFG | Purpose |
|---|---|---|---|---|
| 1 | `runwayml/stable-diffusion-v1-5` | baseline | 7 | Reference — base model with standard CFG |
| 2 | `Lykon/dreamshaper-8` | baseline | 7 | Aligned fine-tune — shows reduced diversity |
| 3 | `Lykon/dreamshaper-8` | naive | 7 | Diversity guidance on noisy latent |
| 4 | `Lykon/dreamshaper-8` | clean_estimate | 7 | **Ours** — diversity guidance on clean estimate |

In [None]:
# ─── Experiment configuration ────────────────────────────────────────

# Each run: (model_id, method, guidance_scale, diversity_scale)
RUNS = [
    ("runwayml/stable-diffusion-v1-5", "baseline",       3.5,  0.0),
    ("Lykon/dreamshaper-8",            "baseline",       7.5,  0.0),
    ("Lykon/dreamshaper-8",            "naive",          7.5, 10.0),
    ("Lykon/dreamshaper-8",            "clean_estimate", 7.5, 10.0),
]

BATCH_SIZE       = 4
NUM_STEPS        = 50
HEIGHT, WIDTH    = 512, 512
SEEDS            = [42, 100, 1234, 9999]   # one per image in the batch
NEGATIVE_PROMPT  = ""

# ─── Diversity guidance settings ─────────────────────────────────────
EARLY_STOP_RATIO  = 0.3
LOSS_TYPE         = "structural"
WEIGHT_SCHEDULE   = "constant"

# ─── Prompt selection ────────────────────────────────────────────────
PROMPTS = ["a majestic knight standing in a mystical forest, cinematic lighting, highly detailed."]
NUM_PROMPTS = 10

if NUM_PROMPTS is not None:
    PROMPTS = PROMPTS[:NUM_PROMPTS]

assert len(SEEDS) == BATCH_SIZE, f"Need {BATCH_SIZE} seeds, got {len(SEEDS)}"
print(f"Will run {len(RUNS)} configurations × {len(PROMPTS)} prompts = "
      f"{len(RUNS) * len(PROMPTS)} generations")

Will run 4 configurations × 1 prompts = 4 generations


## 6. Load Model & Metrics

In [36]:
metrics = DiversityMetrics(device=device)
print("Metrics module ready (models loaded lazily)")

Metrics module ready (models loaded lazily)


## 7. Run Experiment

In [90]:
output_dir = Path("outputs")
output_dir.mkdir(exist_ok=True)

all_results  = {}   # key → list[dict]
image_store  = {}   # key → list[list[PIL]]

def model_short(model_id):
    return model_id.split("/")[-1]

def run_key(model_id, method, cfg, div_scale):
    return f"{model_short(model_id)}/{method}/CFG={cfg}/λ={div_scale}"

# Group runs by model, preserving the order defined in RUNS
from collections import OrderedDict
runs_by_model = OrderedDict()
for r in RUNS:
    runs_by_model.setdefault(r[0], []).append(r)

for model_id, group in runs_by_model.items():
    mname = model_short(model_id)
    print(f"\n{'═' * 70}")
    print(f"  MODEL: {model_id}")
    print(f"{'═' * 70}")

    pipeline = DiversityGuidedPipeline(model_id, device=device)

    for model_id_r, method, cfg, div_scale in group:
        key = run_key(model_id_r, method, cfg, div_scale)
        print(f"\n{'━' * 70}")
        print(f"  {key}   early_stop={EARLY_STOP_RATIO}")
        print(f"{'━' * 70}")

        results_list = []
        images_list  = []

        for pi, prompt in enumerate(PROMPTS):
            out = pipeline.generate(
                prompt=prompt,
                batch_size=BATCH_SIZE,
                num_inference_steps=NUM_STEPS,
                guidance_scale=cfg,
                diversity_scale=div_scale,
                early_stop_ratio=EARLY_STOP_RATIO,
                method=method,
                loss_type=LOSS_TYPE,
                weight_schedule=WEIGHT_SCHEDULE,
                seeds=SEEDS,
                height=HEIGHT, width=WIDTH,
                negative_prompt=NEGATIVE_PROMPT,
            )

            m = metrics.evaluate_batch(out["images"], prompt)
            m["prompt"] = prompt
            m["diversity_losses"] = out["diversity_losses"]
            results_list.append(m)
            images_list.append(out["images"])

            print(f"  [{pi+1:2d}/{len(PROMPTS)}] "
                  f"LPIPS={m['lpips_diversity']:.4f}  "
                  f"CLIP={m['clip_score']:.4f}  "
                  f"CLIPdiv={m['clip_diversity']:.4f}  "
                  f"IR={m['image_reward']:.4f}  │ {prompt[:40]}")

        all_results[key] = results_list
        image_store[key] = images_list

    del pipeline
    gc.collect()
    torch.cuda.empty_cache()
    print(f"\n  Freed {mname} from GPU")

# save raw results
with open(output_dir / "results.json", "w") as f:
    json.dump(all_results, f, indent=2, default=str)
print(f"\nResults saved to {output_dir / 'results.json'}")


══════════════════════════════════════════════════════════════════════
  MODEL: runwayml/stable-diffusion-v1-5
══════════════════════════════════════════════════════════════════════


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Loaded runwayml/stable-diffusion-v1-5  prediction_type=epsilon

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  stable-diffusion-v1-5/baseline/CFG=3.5/λ=0.0   early_stop=0.4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━


baseline:   0%|          | 0/50 [00:00<?, ?it/s]

  [ 1/1] LPIPS=0.6498  CLIP=0.3392  CLIPdiv=0.1518  IR=0.3464  │ a majestic knight standing in a mystical

  Freed stable-diffusion-v1-5 from GPU

══════════════════════════════════════════════════════════════════════
  MODEL: Lykon/dreamshaper-8
══════════════════════════════════════════════════════════════════════


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Loaded Lykon/dreamshaper-8  prediction_type=epsilon

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  dreamshaper-8/baseline/CFG=7.5/λ=0.0   early_stop=0.4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━


baseline:   0%|          | 0/50 [00:00<?, ?it/s]

  [ 1/1] LPIPS=0.4571  CLIP=0.3565  CLIPdiv=0.0708  IR=0.9113  │ a majestic knight standing in a mystical

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  dreamshaper-8/naive/CFG=7.5/λ=10.0   early_stop=0.4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━


naive:   0%|          | 0/50 [00:00<?, ?it/s]

  [ 1/1] LPIPS=0.5847  CLIP=0.3515  CLIPdiv=0.0729  IR=0.5347  │ a majestic knight standing in a mystical

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  dreamshaper-8/clean_estimate/CFG=7.5/λ=10.0   early_stop=0.4
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━


clean_estimate:   0%|          | 0/50 [00:00<?, ?it/s]

  [ 1/1] LPIPS=0.7822  CLIP=0.3589  CLIPdiv=0.1446  IR=0.6476  │ a majestic knight standing in a mystical

  Freed dreamshaper-8 from GPU

Results saved to outputs/results.json


In [38]:
# ═══════════════════════════════════════════════════════════════
# 7.5  Aggregate Quality Metrics — Inception Score & FID
# ═══════════════════════════════════════════════════════════════

# Save generated images to per-method folders
for key, images_list in image_store.items():
    method_dir = output_dir / key.replace("/", "_").replace("=", "")
    method_dir.mkdir(parents=True, exist_ok=True)
    idx = 0
    for batch in images_list:
        for img in batch:
            img.save(method_dir / f"{idx:04d}.png")
            idx += 1
    print(f"Saved {idx} images -> {method_dir}")

# ── Inception Score (no reference needed) ──
print("\n── Inception Score (per method) ──")
aggregate_metrics = {}
for key, images_list in image_store.items():
    all_imgs = [img for batch in images_list for img in batch]
    is_mean, is_std = metrics.inception_score(all_imgs)
    aggregate_metrics[key] = {"is_mean": is_mean, "is_std": is_std}
    print(f"  {key:<50}  IS = {is_mean:.2f} +/- {is_std:.2f}")

# ── FID (requires reference real images) ──
FID_REF_DIR = Path("outputs/val2017")

if FID_REF_DIR.exists():
    ref_paths = sorted(FID_REF_DIR.glob("*.jpg"))[:2048]
    ref_images = [Image.open(p).convert("RGB")
                  for p in tqdm(ref_paths, desc="Loading FID reference")]
    print(f"Loaded {len(ref_images)} reference images")

    ref_feats, _ = metrics._get_inception_outputs(ref_images)

    print("\n── FID (per method, vs COCO val2017) ──")
    for key, images_list in image_store.items():
        all_imgs = [img for batch in images_list for img in batch]
        fake_feats, _ = metrics._get_inception_outputs(all_imgs)
        fid_val = metrics._fid_from_features(ref_feats, fake_feats)
        aggregate_metrics[key]["fid"] = fid_val
        print(f"  {key:<50}  FID = {fid_val:.2f}")

    del ref_images, ref_feats
    torch.cuda.empty_cache()
else:
    print(f"\nFID skipped -- reference images not found at {FID_REF_DIR}")
    print("  To enable FID, run:")
    print("  !wget http://images.cocodataset.org/zips/val2017.zip -O outputs/val2017.zip")
    print("  !unzip -q outputs/val2017.zip -d outputs/ && rm outputs/val2017.zip")

Saved 4 images -> outputs/dreamshaper-8_baseline_λ0.0
Saved 4 images -> outputs/dreamshaper-8_naive_λ10
Saved 4 images -> outputs/dreamshaper-8_clean_estimate_λ10

── Inception Score (per method) ──
Inception-v3 loaded for IS/FID
  dreamshaper-8/baseline/λ=0.0                        IS = 0.00 +/- 0.00
  dreamshaper-8/naive/λ=10                            IS = 0.00 +/- 0.00
  dreamshaper-8/clean_estimate/λ=10                   IS = 0.00 +/- 0.00

FID skipped -- reference images not found at outputs/val2017
  To enable FID, run:
  !wget http://images.cocodataset.org/zips/val2017.zip -O outputs/val2017.zip
  !unzip -q outputs/val2017.zip -d outputs/ && rm outputs/val2017.zip


In [None]:
has_fid = any("fid" in aggregate_metrics.get(k, {}) for k in all_results)

kw = 55
cols = f"{'Run':<{kw}} {'LPIPS':>7} {'CLIP':>7} {'CLIPdiv':>8} {'ImgRwd':>7} {'IS':>10}"
if has_fid:
    cols += f"  {'FID':>8}"
sep = "=" * len(cols)

print(sep)
print(cols)
print(f"{'':>{kw}} {'div':>7} {'align':>7} {'div':>8} {'align':>7} {'qual':>10}", end="")
if has_fid:
    print(f"  {'qual':>8}", end="")
print()
print("-" * len(cols))

for key, results in all_results.items():
    lp = np.mean([r["lpips_diversity"] for r in results])
    cs = np.mean([r["clip_score"]      for r in results])
    cd = np.mean([r["clip_diversity"]   for r in results])
    ir = np.mean([r["image_reward"]    for r in results])

    agg = aggregate_metrics.get(key, {})
    is_str = f"{agg['is_mean']:.2f}+/-{agg.get('is_std',0):.2f}" if "is_mean" in agg else "--"

    row = f"{key:<{kw}} {lp:>7.4f} {cs:>7.4f} {cd:>8.4f} {ir:>7.4f} {is_str:>10}"
    if has_fid:
        fid_str = f"{agg['fid']:.1f}" if "fid" in agg else "--"
        row += f"  {fid_str:>8}"
    print(row)

print(sep)
print()
print("Legend:  LPIPS/CLIPdiv = diversity (higher is better)")
print("        CLIP/ImgRwd   = alignment (higher is better)")
print("        IS            = quality   (higher is better)")
if has_fid:
    print("        FID           = quality   (lower  is better)")

## 9. Pareto Curve — Diversity vs. Quality

In [None]:
MARKERS = {"baseline": "o", "naive": "s", "clean_estimate": "^"}
COLORS  = {"baseline": "#d62728", "naive": "#1f77b4", "clean_estimate": "#2ca02c"}
# Distinguish models by fill style
MODEL_FILL = {"stable-diffusion-v1-5": "full", "dreamshaper-8": "none"}

fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

for ax, div_key, xlabel in [
    (axes[0], "lpips_diversity",  "LPIPS Diversity  (higher → more diverse)"),
    (axes[1], "clip_diversity",   "CLIP Feature Diversity  (higher → more diverse)"),
]:
    for key, results in all_results.items():
        parts = key.split("/")
        mname, method = parts[0], parts[1]
        divs   = [r[div_key]      for r in results]
        clips  = [r["clip_score"] for r in results]
        mk = MARKERS.get(method, "x")
        cl = COLORS.get(method, "gray")
        fill = MODEL_FILL.get(mname, "full")
        fc = cl if fill == "full" else "none"

        ax.scatter(divs, clips, marker=mk, facecolors=fc,
                   edgecolors=cl, s=18, alpha=0.25)
        ax.scatter(np.mean(divs), np.mean(clips), marker=mk,
                   facecolors=fc, edgecolors="k", s=120,
                   linewidths=0.8, zorder=5, label=key)

    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_ylabel("CLIP Score  (higher → better alignment)", fontsize=11)
    ax.legend(fontsize=6, loc="best")
    ax.grid(True, alpha=0.3)

fig.suptitle("Diversity vs Quality — SD1.5 (filled) vs DreamShaper-8 (hollow)",
             fontsize=13)
plt.tight_layout()
plt.savefig(output_dir / "pareto.png", dpi=150, bbox_inches="tight")
plt.show()

## 10. Diversity Loss Curves

In [None]:
fig, ax = plt.subplots(figsize=(8, 4.5))

LINESTYLES = {"stable-diffusion-v1-5": "-", "dreamshaper-8": "--"}

for key, results in all_results.items():
    curves = [r.get("diversity_losses", []) for r in results]
    if not any(curves):
        continue
    max_len = max(len(c) for c in curves)
    padded  = [c + [np.nan]*(max_len - len(c)) for c in curves]
    avg     = np.nanmean(padded, axis=0)
    parts   = key.split("/")
    mname, method = parts[0], parts[1]
    ax.plot(avg, label=key, color=COLORS.get(method, "gray"),
            linestyle=LINESTYLES.get(mname, "-"))

ax.set_xlabel("Guided step index")
ax.set_ylabel("Pairwise similarity  (lower → more diverse)")
ax.set_title("Diversity Loss During Guided Steps")
ax.legend(fontsize=6)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_dir / "loss_curves.png", dpi=150, bbox_inches="tight")
plt.show()

## 11. Per-Prompt Bar Chart

In [None]:
keys = list(all_results.keys())
n_prompts = len(PROMPTS)
short = [p[:25]+"…" if len(p)>25 else p for p in PROMPTS]

fig, ax = plt.subplots(figsize=(max(10, n_prompts*0.6), 5))
x = np.arange(n_prompts)
w = 0.8 / len(keys)

# hatch patterns to distinguish models
MODEL_HATCH = {"stable-diffusion-v1-5": "", "dreamshaper-8": "//"}

for mi, key in enumerate(keys):
    vals = [r["lpips_diversity"] for r in all_results[key]]
    parts = key.split("/")
    mname, method = parts[0], parts[1]
    ax.bar(x + mi*w, vals, w, label=key,
           color=COLORS.get(method, "gray"), alpha=0.85,
           hatch=MODEL_HATCH.get(mname, ""))

ax.set_xticks(x + w*(len(keys)-1)/2)
ax.set_xticklabels(short, rotation=45, ha="right", fontsize=7)
ax.set_ylabel("LPIPS Diversity ↑")
ax.set_title("Per-Prompt Diversity (hatched = DreamShaper-8)")
ax.legend(fontsize=5, ncol=2)
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig(output_dir / "per_prompt.png", dpi=150, bbox_inches="tight")
plt.show()

## 12. Image Grids — Visual Comparison

In [None]:
def show_grid(prompt_idx=0, selected_keys=None):
    """
    Show images for one prompt, matching the reference layout:
      - Title bar at top with prompt + config
      - Left text column with method name + metrics
      - Image columns to the right
    """
    if selected_keys is None:
        selected_keys = list(image_store.keys())

    n_rows = len(selected_keys)
    bs = len(image_store[selected_keys[0]][prompt_idx])
    prompt_text = all_results[selected_keys[0]][prompt_idx]["prompt"]

    # Layout: label column (narrow) + bs image columns
    label_w = 2.5   # inches for text labels
    img_sz  = 3.0   # inches per image cell
    fig_w = label_w + img_sz * bs
    fig_h = img_sz * n_rows + 0.8  # +title space

    fig = plt.figure(figsize=(fig_w, fig_h), dpi=120)
    gs = fig.add_gridspec(n_rows, bs + 1,
                          width_ratios=[label_w] + [img_sz] * bs,
                          wspace=0.05, hspace=0.12)

    # ── Title ──
    seeds_str = ", ".join(str(s) for s in SEEDS)
    fig.suptitle(f'Prompt: "{prompt_text}"\nBatch={bs}  Steps={NUM_STEPS}  Seeds=[{seeds_str}]',
                 fontsize=12, fontweight="bold", y=0.98, va="top")

    # ── Image rows ──
    for row, key in enumerate(selected_keys):
        imgs = image_store[key][prompt_idx]
        m = all_results[key][prompt_idx]
        parts = key.split("/")
        mname, method = parts[0], parts[1]
        color = COLORS.get(method, "gray")

        # Left label cell
        ax_label = fig.add_subplot(gs[row, 0])
        ax_label.axis("off")
        label_text = (f"{mname}\n"
                      f"{method}\n\n"
                      f"LPIPS: {m['lpips_diversity']:.3f}\n"
                      f"CLIP:  {m['clip_score']:.3f}\n"
                      f"IR:    {m['image_reward']:.3f}")
        ax_label.text(0.95, 0.5, label_text, transform=ax_label.transAxes,
                      fontsize=8, fontweight="bold", color=color,
                      va="center", ha="right", family="monospace",
                      bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                                edgecolor=color, alpha=0.9))

        # Image cells
        for col in range(bs):
            ax = fig.add_subplot(gs[row, col + 1])
            ax.imshow(imgs[col], interpolation="lanczos")
            ax.set_xticks([]); ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_edgecolor(color)
                spine.set_linewidth(2)

    plt.savefig(output_dir / f"grid_{prompt_idx:03d}.png",
                dpi=150, bbox_inches="tight")
    plt.show()


for i in range(min(3, len(PROMPTS))):
    show_grid(i)

## 13. Scale Ablation Plot

In [None]:
# ── Bar chart comparing the 4 runs across all metrics ──

keys = list(all_results.keys())
metric_names = ["lpips_diversity", "clip_score", "clip_diversity", "image_reward"]
metric_labels = ["LPIPS Diversity ↑", "CLIP Score ↑", "CLIP Diversity ↑", "ImageReward ↑"]

fig, axes = plt.subplots(1, 4, figsize=(18, 5))
x = np.arange(len(keys))

for ax, mname, mlabel in zip(axes, metric_names, metric_labels):
    vals = [np.mean([r[mname] for r in all_results[k]]) for k in keys]
    parts_list = [k.split("/") for k in keys]
    colors = [COLORS.get(p[1], "gray") for p in parts_list]
    hatches = ["" if p[0] == "stable-diffusion-v1-5" else "//" for p in parts_list]

    bars = ax.bar(x, vals, color=colors, alpha=0.85)
    for bar, h in zip(bars, hatches):
        bar.set_hatch(h)

    ax.set_xticks(x)
    ax.set_xticklabels([k.replace("/", "\n") for k in keys], fontsize=6)
    ax.set_ylabel(mlabel)
    ax.grid(axis="y", alpha=0.3)

fig.suptitle("Metric Comparison Across All 4 Runs (hatched = DreamShaper-8)", fontsize=13)
plt.tight_layout()
plt.savefig(output_dir / "metric_comparison.png", dpi=150, bbox_inches="tight")
plt.show()

## Early-Stop Ratio Ablation

This ablation studies the effect of `early_stop_ratio` (k) on diversity and quality for both models.

In [None]:
ABLATION_RATIOS  = [0.1, 0.2, 0.5, 1.0]
ABLATION_SCALE   = 10.0
ABLATION_CFG     = 7.0
ABLATION_MODEL   = "Lykon/dreamshaper-8"
ABLATION_PROMPTS = PROMPTS[:5]

ratio_results = {}

print(f"Early-Stop Ablation on {ABLATION_MODEL}")
abl_pipeline = DiversityGuidedPipeline(ABLATION_MODEL, device=device)

for ratio in ABLATION_RATIOS:
    key = f"clean_estimate/k={ratio}"
    print(f"\n--- {key} ---")
    res = []
    for pi, prompt in enumerate(ABLATION_PROMPTS):
        out = abl_pipeline.generate(
            prompt=prompt, batch_size=BATCH_SIZE,
            num_inference_steps=NUM_STEPS,
            guidance_scale=ABLATION_CFG,
            diversity_scale=ABLATION_SCALE,
            early_stop_ratio=ratio,
            method="clean_estimate",
            loss_type=LOSS_TYPE,
            seeds=SEEDS, height=HEIGHT, width=WIDTH,
        )
        m = metrics.evaluate_batch(out["images"], prompt)
        res.append(m)
        print(f"  [{pi+1}] LPIPS={m['lpips_diversity']:.4f} "
              f"CLIP={m['clip_score']:.4f} IR={m['image_reward']:.4f}")
    ratio_results[key] = res

del abl_pipeline
gc.collect()
torch.cuda.empty_cache()

# plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ratios = ABLATION_RATIOS
lpips_vals = [np.mean([r["lpips_diversity"] for r in ratio_results[f"clean_estimate/k={k}"]]) for k in ratios]
clip_vals  = [np.mean([r["clip_score"]      for r in ratio_results[f"clean_estimate/k={k}"]]) for k in ratios]
ir_vals    = [np.mean([r["image_reward"]    for r in ratio_results[f"clean_estimate/k={k}"]]) for k in ratios]

axes[0].plot(ratios, lpips_vals, "-o", color="#2ca02c")
axes[0].set_xlabel("early_stop_ratio (k)"); axes[0].set_ylabel("LPIPS Diversity")
axes[0].set_title("Diversity vs. Early-Stop Ratio")
axes[0].grid(True, alpha=0.3)

axes[1].plot(ratios, clip_vals, "-o", color="#1f77b4")
axes[1].set_xlabel("early_stop_ratio (k)"); axes[1].set_ylabel("CLIP Score")
axes[1].set_title("CLIP Alignment vs. Early-Stop Ratio")
axes[1].grid(True, alpha=0.3)

axes[2].plot(ratios, ir_vals, "-o", color="#d62728")
axes[2].set_xlabel("early_stop_ratio (k)"); axes[2].set_ylabel("ImageReward")
axes[2].set_title("ImageReward vs. Early-Stop Ratio")
axes[2].grid(True, alpha=0.3)

fig.suptitle(f"Early-Stop Ablation — {model_short(ABLATION_MODEL)} (λ={ABLATION_SCALE}, CFG={ABLATION_CFG})",
             fontsize=13)
plt.tight_layout()
plt.savefig(output_dir / "early_stop_ablation.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Free GPU memory if needed
# del pipeline, metrics
# gc.collect()
# torch.cuda.empty_cache()

print("Done! All plots & images saved to:", output_dir)