# Hackathon 2 — SDXL + CLIP Reranking

This notebook generates images with **SDXL Base** (optionally **Refiner**) or **SD 3.5 Large**, scores the candidates with **CLIP**, and selects the best match.

### Modes
- **Mode A:** use your prompt as-is.
- **Mode B (optional):** an LLM (via vLLM) converts a long text/article into a concise, SDXL-friendly prompt.

### Quick tips
- Switch `PRESET` to `"fast"` for iteration, `"quality_1min"` for ~1 minute per request, or `"quality_max"` for best quality.
- If it's too slow, set `USE_REFINER = False`.
- If you hit VRAM limits, lower the resolution and the number of candidates, e.g. `CFG.height = CFG.width = 768` and `CFG.n_images = 2`.

In [None]:
# Quietly (-q) install jedi and vLLM; vLLM is only imported when USE_VLLM is True.
!pip install -q jedi vllm

In [None]:
# Dependencies for running outside Colab
# Install PyTorch with CUDA first
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

# Then install other dependencies
# pip install -r requirements.txt

## User Settings

In [None]:
# ========== User settings ==========
USE_VLLM = True           # True -> enable auto-prompt via vLLM (Mode B)
USE_REFINER = True        # Refiner improves details but is slower
PRESET = "quality_1min"    # "fast" | "quality_1min" | "quality_max"
USE_TORCH_COMPILE = True # True -> torch.compile the UNet/transformer: ~10-30% speedup (slower first run)

OUT_DIR = "images_out"  # directory where the selected image of each run is saved
LOG_JSONL = "images_log.jsonl"  # JSON-lines log; one record is appended per pipeline run

## Imports and Setup

In [None]:
import os, gc, json, re, time, uuid, unicodedata
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

if USE_VLLM:
    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    try:
        import multiprocessing as mp
        mp.set_start_method("spawn", force=True)
    except Exception:
        pass

import torch
from transformers import AutoTokenizer, CLIPModel, CLIPProcessor
from diffusers import DiffusionPipeline, StableDiffusion3Pipeline
from PIL import Image

try:
    from google.colab import userdata
except Exception:
    userdata = None

if USE_VLLM:
    from vllm import LLM as VLLM, SamplingParams

print("Imports complete.")

In [None]:
# Request expandable segments from the CUDA caching allocator unless the user already set this variable.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Device string used for every model and generator in this notebook.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def gpu_total_gb() -> float:
    """Return the total memory of CUDA device 0 in GiB, or 0.0 when CUDA is unavailable."""
    if torch.cuda.is_available():
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        return total_bytes / (1024**3)
    return 0.0

def gpu_free_gb() -> float:
    """Return the currently free memory of CUDA device 0 in GiB, or 0.0 when CUDA is unavailable."""
    if torch.cuda.is_available():
        free_bytes = torch.cuda.mem_get_info(0)[0]
        return free_bytes / (1024**3)
    return 0.0

def gpu_cleanup():
    """Run the garbage collector, then empty the CUDA cache and synchronize if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def get_gpu_metrics() -> Dict[str, float]:
    """Return used/total memory of CUDA device 0 in GiB (2 decimals).

    Returns an empty dict when CUDA is unavailable or the memory query fails,
    so callers can always embed the result in a log record.
    """
    if not torch.cuda.is_available():
        return {}
    gib = 1024**3
    try:
        free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    except Exception:
        # Best effort: metrics are informational only.
        return {}
    return {
        "gpu_memory_used_gb": round((total_bytes - free_bytes) / gib, 2),
        "gpu_memory_total_gb": round(total_bytes / gib, 2),
    }

if DEVICE == "cuda":
    # Allow TF32 for cuDNN and CUDA matmuls.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best effort only; the notebook keeps running if this call fails.
        pass
    print(f"DEVICE: {DEVICE}" + f" | GPU: {torch.cuda.get_device_name(0)} | VRAM: {gpu_total_gb():.1f} GB")
else:
    print(f"DEVICE: {DEVICE}")

In [None]:
# Hugging Face token: Colab secrets take precedence, the environment variable is the fallback.
HF_TOKEN = None
if userdata is not None:
    try:
        HF_TOKEN = userdata.get("HF_TOKEN")
    except Exception:
        # Secret missing or not accessible; fall through to the environment.
        pass
if not HF_TOKEN:
    HF_TOKEN = os.environ.get("HF_TOKEN")

print("✅ HF_TOKEN loaded." if HF_TOKEN else "⚠️  HF_TOKEN not found. Gated models may fail.")

## Configuration

In [None]:
@dataclass
class ImgCfg:
    # Settings shared by generation, CLIP scoring and logging.
    # apply_preset() overwrites preset, n_images, the step counts, guidance, clip_min and the resolution.
    preset: str = ""  # normalized preset name; "quality_max" routes generation to SD 3.5
    clip_min: float = 0.22  # a run is rejected when its best CLIP score is below this value
    n_images: int = 2  # number of candidate images (one seed each) generated per run
    seed_base: int = 42  # first seed when the caller passes no base_seed; candidates use consecutive seeds
    steps_base: int = 35  # num_inference_steps for the SDXL base / SD 3.5 pipeline
    steps_refine: int = 18  # num_inference_steps passed to the SDXL refiner
    guidance: float = 7.0  # guidance_scale passed to every pipeline call
    high_noise_frac: float = 0.78  # the refiner is called with strength = 1 - high_noise_frac
    height: int = 1024  # output height in pixels
    width: int = 1024  # output width in pixels
    out_dir: str = OUT_DIR  # directory for saved images
    log_jsonl: str = LOG_JSONL  # path of the JSON-lines log file
    save_images: bool = True  # when False the best image is not written to disk
    max_prompt_chars: int = 900  # hard character cap applied to the generation prompt
    style_suffix: str = ", documentary photo, natural light, high detail, sharp focus"  # appended to the generation prompt, removed again for CLIP scoring
    negative_prompt: str = "text, watermark, logo, brand, nsfw, nude, gore, disfigured, low quality, blurry"  # negative prompt passed to the generators

# Default configuration instance; the output directory is created up front.
CFG = ImgCfg()
os.makedirs(CFG.out_dir, exist_ok=True)

def apply_preset(cfg: ImgCfg, preset: str) -> ImgCfg:
    """Overwrite the tunable fields of *cfg* in place for the named preset and return it.

    The preset name is lower-cased and stripped before it is stored on
    ``cfg.preset``. Unrecognized names only update ``cfg.preset``; all other
    fields keep their current values. Candidate counts for the quality presets
    depend on the total VRAM reported by ``gpu_total_gb()``.
    """
    name = (preset or "").lower().strip()
    cfg.preset = name
    total_vram = gpu_total_gb()

    if name == "fast":
        cfg.n_images = 2
        cfg.steps_base = 25
        cfg.steps_refine = 10
        cfg.guidance = 6.5
        cfg.clip_min = 0.20
        cfg.height = cfg.width = 768
        return cfg

    if name == "quality_1min":
        cfg.n_images = 2 if total_vram < 24 else 4
        cfg.steps_base = 37
        cfg.steps_refine = 20
        cfg.guidance = 7.0
        cfg.clip_min = 0.22
        cfg.height = cfg.width = 1024
        return cfg

    if name == "quality_max":
        if total_vram >= 40:
            cfg.n_images = 5
        elif total_vram >= 24:
            cfg.n_images = 3
        else:
            cfg.n_images = 2
        cfg.steps_base = 40
        cfg.steps_refine = 25
        cfg.guidance = 7.5
        cfg.clip_min = 0.23
        cfg.height = cfg.width = 1024

    return cfg

# Apply the user-selected preset to the shared configuration and show the result.
CFG = apply_preset(CFG, PRESET)
print(f"CFG: {CFG}")

## vLLM (Mode B) — Optional

In [None]:
# Module-level LLM state; stays None unless a vLLM engine is loaded.
llm = llm_tok = loaded_llm_id = None
LLM_CANDIDATES = [
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]

def _choose_llm_dtype() -> str:
    """Return "bfloat16" when CUDA device 0 has compute capability major >= 8, else "float16"."""
    if torch.cuda.is_available():
        major = torch.cuda.get_device_capability(0)[0]
        if major >= 8:
            return "bfloat16"
    return "float16"

if USE_VLLM:
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    dtype = _choose_llm_dtype()
    print(f"vLLM dtype: {dtype}")

    # Try the candidates in order and keep the first one that loads completely.
    for mid in LLM_CANDIDATES:
        try:
            _tok = AutoTokenizer.from_pretrained(mid, token=HF_TOKEN)
            if _tok.pad_token is None:
                _tok.pad_token = _tok.eos_token
            # Llama needs special params for Colab compatibility
            is_llama = "llama" in mid.lower()

            extra_kwargs = {}
            if is_llama:
                extra_kwargs["enforce_eager"] = True        # Disables CUDA graphs
                extra_kwargs["tensor_parallel_size"] = 1    # Single-GPU mode

            _engine = VLLM(
                model=mid,
                dtype=dtype,
                max_model_len=2048,
                gpu_memory_utilization=0.35,
                disable_log_stats=True,
                **extra_kwargs,
            )
        except Exception as e:
            print(f"⚠️ vLLM failed: {mid} -> {e}")
        else:
            # Publish tokenizer and engine together so llm_tok never belongs to a model that failed to load.
            llm, llm_tok, loaded_llm_id = _engine, _tok, mid
            print(f"✅ Loaded vLLM: {mid}")
            break
        # Only reached after a failure: release whatever the failed attempt allocated
        # before trying the next candidate. Done outside the except clause so the
        # exception no longer references the failed attempt's objects.
        gpu_cleanup()
    if llm is None:
        print("❌ No LLM loaded.")

## Model Management
Smart loading/unloading to prevent OOM.

In [None]:
# Module-level cache of loaded pipelines. _active_model is set to "sdxl" or "sd35" by the loaders
# and reset to None when unload_diffusion_models() is called without `keep`.
_sdxl_base, _sdxl_refiner, _sd35_pipe, _active_model = None, None, None, None

# Hugging Face model ids used by the loaders below.
SDXL_BASE_ID = "stabilityai/stable-diffusion-xl-base-1.0"
SDXL_REFINER_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
SD35_ID = "stabilityai/stable-diffusion-3.5-large"

def _get_sd_dtype():
    """Pick the diffusion dtype: float32 off-GPU, otherwise bfloat16 if supported, else float16."""
    if DEVICE != "cuda":
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# Resolved once; every diffusion pipeline is loaded with this dtype.
DTYPE_SD = _get_sd_dtype()

def _low_vram_mode() -> bool:
    """Return True when running on CUDA with less than 24 GiB of total device memory."""
    if DEVICE != "cuda":
        return False
    return gpu_total_gb() < 24

def unload_diffusion_models(keep: str = None):
    """Drop cached diffusion pipelines, except the family named by *keep*.

    Args:
        keep: "sdxl" keeps the SDXL base/refiner, "sd35" keeps the SD 3.5
            pipeline, None drops everything and clears ``_active_model``.

    Each dropped pipeline is moved to the CPU before its reference is cleared,
    and ``gpu_cleanup()`` always runs at the end.
    """
    global _sdxl_base, _sdxl_refiner, _sd35_pipe, _active_model

    drop_sdxl = keep != "sdxl"
    drop_sd35 = keep != "sd35" and _sd35_pipe is not None

    if drop_sdxl:
        if _sdxl_base is not None:
            _sdxl_base.to("cpu")
        if _sdxl_refiner is not None:
            _sdxl_refiner.to("cpu")
        _sdxl_base = None
        _sdxl_refiner = None

    if drop_sd35:
        _sd35_pipe.to("cpu")
        _sd35_pipe = None

    if keep is None:
        _active_model = None

    gpu_cleanup()

def _load_sdxl_pipeline(model_id: str, label: str):
    """Load one SDXL pipeline onto DEVICE and apply the notebook-wide runtime options.

    Args:
        model_id: Hugging Face model id to load.
        label: Human-readable name used in the compile status message.

    Returns:
        The configured pipeline.
    """
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=DTYPE_SD,
        # The "fp16" weight variant is requested only for float16; other dtypes use the default weights.
        variant="fp16" if DTYPE_SD == torch.float16 else None,
        use_safetensors=True,
        token=HF_TOKEN,
    ).to(DEVICE)
    pipe.set_progress_bar_config(disable=True)
    if _low_vram_mode():
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

    # Optional: compile the UNet for faster inference; a failure is reported and ignored.
    if USE_TORCH_COMPILE and hasattr(torch, "compile"):
        try:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
            print(f"✅ {label} UNet compiled")
        except Exception as e:
            print(f"⚠️ Compilation skipped: {e}")
    return pipe


def load_sdxl(use_refiner: bool = True):
    """Return the cached ``(base, refiner)`` SDXL pipelines, loading whatever is missing.

    Only the missing pipeline is loaded: if the base is already cached and the
    refiner is requested later, the base is reused instead of being loaded and
    compiled a second time. SD 3.5 is unloaded first when it is the active model.

    Args:
        use_refiner: When True the refiner is loaded as well.

    Returns:
        Tuple ``(base, refiner)``; ``refiner`` is None if it was never loaded.
    """
    global _sdxl_base, _sdxl_refiner, _active_model

    need_base = _sdxl_base is None
    need_refiner = use_refiner and _sdxl_refiner is None
    if not (need_base or need_refiner):
        return _sdxl_base, _sdxl_refiner

    if _active_model == "sd35":
        print("Unloading SD3.5...")
        unload_diffusion_models(keep="sdxl")
    gpu_cleanup()

    if need_base:
        print("Loading SDXL Base...")
        _sdxl_base = _load_sdxl_pipeline(SDXL_BASE_ID, "SDXL Base")

    if need_refiner:
        print("Loading SDXL Refiner...")
        _sdxl_refiner = _load_sdxl_pipeline(SDXL_REFINER_ID, "SDXL Refiner")

    _active_model = "sdxl"
    print("✅ SDXL ready.")
    return _sdxl_base, _sdxl_refiner

def load_sd35():
    """Return the cached SD 3.5 Large pipeline, loading it on first use.

    SDXL is unloaded first when it is the active model. The transformer is
    optionally compiled; a compile failure is reported and ignored.
    """
    global _sd35_pipe, _active_model

    if _sd35_pipe is None:
        if _active_model == "sdxl":
            print("Unloading SDXL...")
            unload_diffusion_models(keep="sd35")
        gpu_cleanup()

        print("Loading SD 3.5 Large...")
        _sd35_pipe = StableDiffusion3Pipeline.from_pretrained(
            SD35_ID,
            torch_dtype=DTYPE_SD,
            token=HF_TOKEN,
        ).to(DEVICE)
        _sd35_pipe.set_progress_bar_config(disable=True)

        # Optional: compile the transformer for faster inference.
        compile_requested = USE_TORCH_COMPILE and hasattr(torch, "compile")
        if compile_requested:
            try:
                _sd35_pipe.transformer = torch.compile(_sd35_pipe.transformer, mode="reduce-overhead")
                print("✅ SD3.5 transformer compiled")
            except Exception as e:
                print(f"⚠️ Compilation skipped: {e}")

        _active_model = "sd35"
        print("✅ SD 3.5 ready.")

    return _sd35_pipe

print("Model management ready.")

## Generation Functions
Every pipeline call (base and refiner alike) receives a freshly seeded generator, so each seed yields a reproducible image.

In [None]:
def sdxl_generate_single(prompt: str, seed: int, negative_prompt: str, cfg: ImgCfg, use_refiner: bool = True) -> Image.Image:
    """Generate one SDXL image for *seed*.

    With the refiner enabled and loaded, the base pipeline produces latents
    that the refiner then processes with ``strength = 1 - cfg.high_noise_frac``;
    otherwise the base pipeline returns the PIL image directly. An empty
    negative prompt is passed on as None.
    """
    base, refiner = load_sdxl(use_refiner=use_refiner)
    refine = use_refiner and refiner is not None
    neg = negative_prompt or None

    base_images = base(
        prompt=prompt,
        negative_prompt=neg,
        num_inference_steps=cfg.steps_base,
        guidance_scale=cfg.guidance,
        generator=torch.Generator(device=DEVICE).manual_seed(int(seed)),
        height=cfg.height,
        width=cfg.width,
        output_type="latent" if refine else "pil",
    ).images

    if not refine:
        return base_images[0]

    # The refiner gets its own generator seeded with the same value.
    refined = refiner(
        prompt=prompt,
        negative_prompt=neg,
        num_inference_steps=cfg.steps_refine,
        guidance_scale=cfg.guidance,
        generator=torch.Generator(device=DEVICE).manual_seed(int(seed)),
        image=base_images,
        strength=1.0 - float(cfg.high_noise_frac),
    )
    return refined.images[0]

def sdxl_generate_batch(prompt: str, seeds: List[int], negative_prompt: str, cfg: ImgCfg, use_refiner: bool = True) -> List[Image.Image]:
    """Generate one SDXL image per seed in a single batched call.

    The prompt (and the negative prompt, if non-empty) is repeated once per
    seed and each batch element gets its own seeded generator. With the
    refiner enabled and loaded, base latents are passed through the refiner
    using a second, freshly seeded set of generators.
    """
    base, refiner = load_sdxl(use_refiner=use_refiner)
    seed_values = [int(s) for s in seeds]
    count = len(seed_values)
    prompt_list = [prompt] * count
    neg_list = [negative_prompt] * count if negative_prompt else None
    refine = use_refiner and refiner is not None

    def fresh_generators():
        # One independent generator per batch element.
        return [torch.Generator(device=DEVICE).manual_seed(s) for s in seed_values]

    base_images = base(
        prompt=prompt_list,
        negative_prompt=neg_list,
        num_inference_steps=cfg.steps_base,
        guidance_scale=cfg.guidance,
        generator=fresh_generators(),
        height=cfg.height,
        width=cfg.width,
        output_type="latent" if refine else "pil",
    ).images

    if not refine:
        return base_images

    return refiner(
        prompt=prompt_list,
        negative_prompt=neg_list,
        num_inference_steps=cfg.steps_refine,
        guidance_scale=cfg.guidance,
        generator=fresh_generators(),
        image=base_images,
        strength=1.0 - float(cfg.high_noise_frac),
    ).images

print("SDXL functions defined.")

In [None]:
# SD 3.5 generation: single-image and batch variants
def sd35_generate_single(prompt: str, seed: int, negative_prompt: str, cfg: ImgCfg, use_refiner: bool = False) -> Image.Image:
    """Generate one SD 3.5 image for *seed*.

    ``use_refiner`` is accepted for signature parity with the SDXL functions
    and is not used.
    """
    pipe = load_sd35()
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=cfg.height,
        width=cfg.width,
        num_inference_steps=cfg.steps_base,
        guidance_scale=cfg.guidance,
        generator=generator,
    )
    return output.images[0]

def sd35_generate_batch(prompt: str, seeds: List[int], negative_prompt: str, cfg: ImgCfg, use_refiner: bool = False) -> List[Image.Image]:
    """Generate one SD 3.5 image per seed in a single batched call.

    The prompt (and the negative prompt, if non-empty) is repeated once per
    seed; each batch element gets its own seeded generator. ``use_refiner``
    is accepted for signature parity and is not used.
    """
    pipe = load_sd35()
    seed_values = [int(s) for s in seeds]
    count = len(seed_values)
    prompt_list = [prompt] * count
    neg_list = [negative_prompt] * count if negative_prompt else None
    generators = [torch.Generator(device=DEVICE).manual_seed(s) for s in seed_values]
    output = pipe(
        prompt=prompt_list,
        negative_prompt=neg_list,
        height=cfg.height,
        width=cfg.width,
        num_inference_steps=cfg.steps_base,
        guidance_scale=cfg.guidance,
        generator=generators,
    )
    return output.images

print("SD 3.5 functions defined.")

In [None]:
# Router
def generate_single(prompt: str, seed: int, negative_prompt: str, cfg: ImgCfg, use_refiner: bool = True) -> Image.Image:
    """Generate one image, using SD 3.5 for the "quality_max" preset and SDXL otherwise."""
    wants_sd35 = cfg.preset == "quality_max"
    if not wants_sd35:
        return sdxl_generate_single(prompt, seed, negative_prompt, cfg, use_refiner)
    # use_refiner is not forwarded on the SD 3.5 path.
    return sd35_generate_single(prompt, seed, negative_prompt, cfg)

def generate_batch(prompt: str, seeds: List[int], negative_prompt: str, cfg: ImgCfg, use_refiner: bool = True) -> List[Image.Image]:
    """Generate one image per seed, using SD 3.5 for the "quality_max" preset and SDXL otherwise."""
    wants_sd35 = cfg.preset == "quality_max"
    if not wants_sd35:
        return sdxl_generate_batch(prompt, seeds, negative_prompt, cfg, use_refiner)
    # use_refiner is not forwarded on the SD 3.5 path.
    return sd35_generate_batch(prompt, seeds, negative_prompt, cfg)

# Report the route from the normalized preset stored on CFG, the same value the routers compare against.
print(f"Router: '{CFG.preset}' -> {'SD3.5' if CFG.preset == 'quality_max' else 'SDXL'}")

## CLIP Scoring

In [None]:
CLIP_ID = "openai/clip-vit-large-patch14"
# Lazily populated caches for the CLIP model, processor and tokenizer.
_clip_model = _clip_processor = _clip_tokenizer = None

def get_clip_tokenizer():
    """Return the cached CLIP tokenizer, loading the processor for CLIP_ID on first use."""
    global _clip_tokenizer
    if _clip_tokenizer is not None:
        return _clip_tokenizer
    _clip_tokenizer = CLIPProcessor.from_pretrained(CLIP_ID).tokenizer
    return _clip_tokenizer

def load_clip():
    """Return the cached ``(model, processor)`` pair for CLIP_ID, loading both on first use.

    ``gpu_cleanup()`` runs before the first load; the model is moved to DEVICE
    and put in eval mode.
    """
    global _clip_model, _clip_processor
    if _clip_model is None:
        gpu_cleanup()
        _clip_model = CLIPModel.from_pretrained(CLIP_ID).to(DEVICE).eval()
        _clip_processor = CLIPProcessor.from_pretrained(CLIP_ID)
    return _clip_model, _clip_processor

@torch.inference_mode()
def clip_score_batch(text: str, images: "List[Image.Image]") -> List[float]:
    """Return the CLIP cosine similarity between *text* and each image.

    Args:
        text: Prompt to score against; whitespace is collapsed first.
        images: Images to score.

    Returns:
        One float per image, in input order. Returns an empty list when the
        text is empty/blank or no images are given; in that case the CLIP
        model is not loaded at all.
    """
    text = " ".join((text or "").strip().split())
    if not text or not images:
        return []

    # Load the model only once there is something to score.
    model, processor = load_clip()

    t_inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True, max_length=77).to(DEVICE)
    i_inputs = processor(images=images, return_tensors="pt").to(DEVICE)

    text_feat = model.get_text_features(input_ids=t_inputs["input_ids"], attention_mask=t_inputs.get("attention_mask"))
    img_feat = model.get_image_features(pixel_values=i_inputs["pixel_values"])

    # L2-normalize both sides so the dot product is a cosine similarity.
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

    return (img_feat @ text_feat.T).squeeze(-1).float().cpu().tolist()

print("✅ CLIP ready.")

## Prompt Utilities

In [None]:
def normalize_text(text: str) -> str:
    """Return *text* NFKC-normalized, stripped of zero-width characters, lower-cased and whitespace-collapsed."""
    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r"[\u200b-\u200f\u2060\ufeff]", "", text)
    return " ".join(text.lower().split())

def basic_prompt_filter(prompt: str) -> Dict[str, Any]:
    """Check a prompt against a small list of banned terms.

    A match is ignored when it is directly preceded by the standalone word
    "no" (e.g. "no gore"). Words that merely end in "no" ("casino nude",
    "volcano gore") do not count as a negation.

    Args:
        prompt: Raw prompt text; None is treated as empty.

    Returns:
        ``{"safe": bool, "hits": sorted list of matched terms}``.
    """
    p = normalize_text(prompt or "")
    banned = [r"n[s$]fw", r"nud[e3]", r"nudity", r"p[o0]rn", r"s[e3]x", r"g[o0]r[e3]", r"explicit", r"xxx", r"naked", r"erotic"]
    hits = set()
    for pat in banned:
        for m in re.finditer(rf"\b{pat}\b", p):
            # Four characters of context: "no " plus the character before it,
            # so the word boundary in front of "no" can be verified.
            prev = p[max(0, m.start() - 4):m.start()]
            if not re.search(r"\bno $", prev):
                hits.add(m.group())
    return {"safe": not hits, "hits": sorted(hits)}

def normalize_for_generation(prompt: str, cfg: "ImgCfg") -> str:
    """Build the generation prompt: collapse whitespace, append the style suffix, cap the length.

    When the suffix has to be appended and the result would exceed
    ``cfg.max_prompt_chars``, the user text is shortened so that the suffix
    survives intact instead of being cut off by the final truncation.

    Args:
        prompt: Raw prompt; None or blank yields "".
        cfg: Configuration providing ``style_suffix`` and ``max_prompt_chars``.

    Returns:
        The normalized prompt, at most ``cfg.max_prompt_chars`` characters long.
    """
    p = re.sub(r"\s+", " ", (prompt or "").strip())
    if not p:
        return ""
    suffix = cfg.style_suffix
    limit = cfg.max_prompt_chars
    if suffix and suffix.lower() not in p.lower():
        room = limit - len(suffix)
        # Shorten the user text only when the suffix itself fits in the limit;
        # otherwise fall back to plain truncation of the combined string.
        if room > 0 and len(p) > room:
            p = p[:room].rstrip(" ,")
        p = (p + suffix).strip()
    return p[:limit]

def make_clip_prompt(prompt: str, cfg: ImgCfg, max_tokens: int = 77) -> str:
    """Derive the text used for CLIP scoring from a generation prompt.

    The style suffix is removed, the text is limited to 40 whitespace-separated
    words, and it is then shortened so that its tokenization *including*
    special tokens fits in ``max_tokens``.

    Args:
        prompt: Generation prompt; None is treated as empty.
        cfg: Configuration providing ``style_suffix``.
        max_tokens: Token budget, special tokens included.

    Returns:
        The shortened prompt (possibly empty).
    """
    tok = get_clip_tokenizer()
    p = re.sub(r"\s+", " ", (prompt or "").strip())
    if cfg.style_suffix:
        idx = p.lower().find(cfg.style_suffix.lower())
        if idx != -1:
            p = (p[:idx] + p[idx + len(cfg.style_suffix):]).strip(" ,")
    p = " ".join(p.split()[:40])
    ids = tok(p, add_special_tokens=True, truncation=False)["input_ids"]
    if len(ids) > max_tokens:
        # Truncate the content tokens only and reserve room for the special tokens;
        # slicing the special-token-inclusive ids would overshoot the budget by one
        # once the decoded text is encoded again.
        content_ids = tok(p, add_special_tokens=False, truncation=False)["input_ids"]
        n_special = len(ids) - len(content_ids)
        keep = max(0, max_tokens - n_special)
        p = tok.decode(content_ids[:keep], skip_special_tokens=True).strip()
    return p

def auto_prompt_from_text(text: str) -> str:
    """Ask the loaded LLM to turn *text* into one short visual prompt.

    The generated prompt is printed in a banner and returned with whitespace
    collapsed.

    Raises:
        RuntimeError: If vLLM is disabled or no LLM was loaded.
    """
    if not USE_VLLM or llm is None:
        raise RuntimeError("LLM not loaded")

    sampling = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=120)
    messages = [
        {"role": "system", "content": "You are a prompt-writer for an image model."},
        {
            "role": "user",
            "content": f"Turn the text into ONE short visual prompt for a documentary photo. Focus on concrete scene. Output only the prompt.\n\nTEXT:\n{text.strip()}",
        },
    ]
    chat_prompt = llm_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    completions = llm.generate([chat_prompt], sampling)
    generated = completions[0].outputs[0].text
    print(f"\n{'═'*60}\n{'AUTO-PROMPT':^60}\n{'═'*60}\n{generated.strip()}\n{'═'*60}\n")
    return re.sub(r"\s+", " ", generated).strip()

print("Prompt utilities ready.")

## Logging

In [None]:
def ts_id() -> str:
    """Return an id of the form ``YYYYMMDD_HHMMSS_xxxxxx`` (local time plus 6 random hex digits)."""
    stamp = time.strftime("%Y%m%d_%H%M%S")
    suffix = uuid.uuid4().hex[:6]
    return f"{stamp}_{suffix}"

def write_log(rec: Dict[str, Any], path: str):
    """Append *rec* as one JSON line to *path*.

    The record is modified in place: "gpu_metrics" and a UTC "timestamp" are
    added before it is serialized.
    """
    rec["gpu_metrics"] = get_gpu_metrics()
    rec["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    with open(path, "a", encoding="utf-8") as log_file:
        line = json.dumps(rec, ensure_ascii=False)
        log_file.write(f"{line}\n")

print("Logging ready.")

## Main Pipeline

In [None]:
def run_image_pipeline(*, mode: str, user_prompt: Optional[str] = None, text: Optional[str] = None,
                        item_id: Optional[str] = None, base_seed: Optional[int] = None,
                        cfg: ImgCfg = CFG, use_refiner: bool = USE_REFINER,
                        return_candidates_images: bool = False) -> Dict[str, Any]:
    """Run prompt preparation, generation, CLIP reranking and saving for one request.

    Args:
        mode: "user_prompt" (use ``user_prompt`` as-is) or "auto_prompt"
            (derive the prompt from ``text`` with the LLM).
        user_prompt: Prompt for "user_prompt" mode.
        text: Source text for "auto_prompt" mode.
        item_id: Identifier used in the log and file name; generated when omitted.
        base_seed: First seed; defaults to ``cfg.seed_base``. Candidates use
            consecutive seeds.
        cfg: Generation configuration.
        use_refiner: Forwarded to the generators.
        return_candidates_images: When True the result also carries
            "candidate_images" and "best_image" (PIL images).

    Returns:
        A record with "ok". On failure "reason" is one of
        missing_user_prompt, missing_text, auto_prompt_failed, blocked_prompt,
        empty_prompt, generation_failed or low_clip_alignment. Every outcome is
        appended to ``cfg.log_jsonl``.

    Raises:
        AssertionError: If ``mode`` is not one of the two supported values.
    """
    assert mode in {"user_prompt", "auto_prompt"}
    item_id = item_id or ts_id()
    base_seed = cfg.seed_base if base_seed is None else int(base_seed)

    # 1) Get prompt
    if mode == "user_prompt":
        raw = (user_prompt or "").strip()
        if not raw:
            rec = {"ok": False, "reason": "missing_user_prompt", "mode": mode, "item_id": item_id}
            write_log(rec, cfg.log_jsonl)
            return rec
    else:
        if not (text or "").strip():
            rec = {"ok": False, "reason": "missing_text", "mode": mode, "item_id": item_id}
            write_log(rec, cfg.log_jsonl)
            return rec
        try:
            raw = auto_prompt_from_text(text)
        except Exception as e:
            rec = {"ok": False, "reason": "auto_prompt_failed", "err": str(e)[:300], "mode": mode, "item_id": item_id}
            write_log(rec, cfg.log_jsonl)
            return rec

    # 2) Filter
    f = basic_prompt_filter(raw)
    if not f["safe"]:
        rec = {"ok": False, "reason": "blocked_prompt", "hits": f["hits"], "mode": mode, "item_id": item_id}
        write_log(rec, cfg.log_jsonl)
        return rec

    sd_prompt = normalize_for_generation(raw, cfg)
    clip_prompt = make_clip_prompt(sd_prompt, cfg)
    if not sd_prompt or not clip_prompt:
        rec = {"ok": False, "reason": "empty_prompt", "mode": mode, "item_id": item_id}
        write_log(rec, cfg.log_jsonl)
        return rec

    # 3) Generate
    seeds = [base_seed + i for i in range(cfg.n_images)]
    candidates = [{"seed": int(s)} for s in seeds]
    ok_imgs, ok_indices = [], []

    t0 = time.time()
    batch_error = None
    try:
        imgs = generate_batch(sd_prompt, seeds, cfg.negative_prompt, cfg, use_refiner)
        ok_imgs, ok_indices = imgs, list(range(len(seeds)))
    except Exception as e:
        # Keep only the message. The retry must run outside this clause: while the
        # handler is active the exception's traceback keeps the failed batch's
        # tensors alive, so an out-of-memory retry could not reclaim them.
        batch_error = str(e)

    if batch_error is not None:
        print(f"⚠️ Batch failed: {batch_error}, trying sequential...")
        gpu_cleanup()
        for idx, s in enumerate(seeds):
            try:
                ok_imgs.append(generate_single(sd_prompt, s, cfg.negative_prompt, cfg, use_refiner))
                ok_indices.append(idx)
            except Exception as e2:
                candidates[idx]["error"] = str(e2)[:200]
    gen_s = time.time() - t0

    if not ok_imgs:
        rec = {"ok": False, "reason": "generation_failed", "mode": mode, "item_id": item_id,
               "sd_prompt": sd_prompt, "candidates": candidates}
        write_log(rec, cfg.log_jsonl)
        return rec

    # 4) CLIP score
    t1 = time.time()
    scores = clip_score_batch(clip_prompt, ok_imgs)
    clip_s = time.time() - t1

    # ok_indices maps each generated image back to its slot in `candidates`.
    for idx, sc in zip(ok_indices, scores):
        candidates[idx]["clip"] = float(sc)

    best_local = max(range(len(scores)), key=lambda i: scores[i])
    best_clip, best_idx = float(scores[best_local]), ok_indices[best_local]
    best_seed, best_img = seeds[best_idx], ok_imgs[best_local]

    if best_clip < cfg.clip_min:
        rec = {"ok": False, "reason": "low_clip_alignment", "best_clip": best_clip, "clip_min": cfg.clip_min,
               "mode": mode, "item_id": item_id, "candidates": candidates}
        write_log(rec, cfg.log_jsonl)
        return rec

    # 5) Save
    img_path = None
    if cfg.save_images:
        img_path = os.path.join(cfg.out_dir, f"{item_id}_seed{best_seed}_clip{best_clip:.3f}.png")
        best_img.save(img_path)

    # Log record (without images - they're not JSON serializable)
    log_rec = {"ok": True, "mode": mode, "item_id": item_id, "sd_prompt": sd_prompt, "clip_prompt": clip_prompt,
               "best_seed": best_seed, "best_clip": best_clip, "img_path": img_path,
               "timing_s": {"gen": round(gen_s, 2), "clip": round(clip_s, 2)}, "candidates": candidates, "preset": cfg.preset}
    write_log(log_rec, cfg.log_jsonl)

    # Return record (with images if requested)
    result = log_rec.copy()
    if return_candidates_images:
        result["candidate_images"] = ok_imgs
        result["best_image"] = best_img

    return result

print("✅ Pipeline ready.")

## Convenience Function

In [None]:
def generate_image(mode: str = "user_prompt", prompt: str = "", text: str = "", show: bool = True, **kwargs) -> Dict[str, Any]:
    """Run the pipeline once and optionally display the result.

    Args:
        mode: "user_prompt" or "auto_prompt".
        prompt: Prompt used in "user_prompt" mode.
        text: Source text used in "auto_prompt" mode.
        show: Display the best image (when it is part of the result) and print
            a summary on success.
        **kwargs: Forwarded to ``run_image_pipeline``. ``return_candidates_images``
            defaults to the value of ``show``.

    Returns:
        The record returned by ``run_image_pipeline``. Failures are always
        printed, regardless of ``show``.
    """
    # Use return_candidates_images from kwargs if provided, otherwise use show
    return_images = kwargs.pop("return_candidates_images", show)

    result = run_image_pipeline(mode=mode, user_prompt=prompt if mode == "user_prompt" else None,
                                 text=text if mode == "auto_prompt" else None, return_candidates_images=return_images, **kwargs)

    if not result.get("ok"):
        print(f"❌ Failed: {result.get('reason', 'unknown')}")
        if "hits" in result:
            print(f"   Blocked: {result['hits']}")
        return result

    if show:
        # The image is only present when return_candidates_images was left enabled;
        # the summary is printed either way so a successful run is never silent.
        if "best_image" in result:
            display(result["best_image"])
        print(f"\n✅ Best: seed={result['best_seed']}, CLIP={result['best_clip']:.4f}")
        print(f"   Timing: gen={result['timing_s']['gen']:.1f}s, clip={result['timing_s']['clip']:.2f}s")
        if result.get("img_path"):
            print(f"   Saved: {result['img_path']}")
    return result

print("✅ Ready! Use generate_image() to create images.")

## Examples
### Mode A: Direct Prompt

In [None]:
# Example (Mode A): generate from a direct prompt; with the default show=True the best image is displayed inline.
result = generate_image(prompt="A golden retriever puppy playing in autumn leaves, warm afternoon sunlight")

### Mode B: Auto-Prompt from Text

In [None]:
# Example (Mode B): the LLM condenses this passage into an image prompt.
# Requires USE_VLLM = True and a successfully loaded LLM; otherwise the run fails with "auto_prompt_failed".
text = '''
  Climate scientists have observed unprecedented changes in Arctic ice patterns.
  The polar regions are experiencing warming at twice the global average rate,
  leading to significant changes in wildlife habitats and migration patterns.
'''

result = generate_image(mode="auto_prompt", text=text)

### Show All Candidates

In [None]:
# Show all candidates
result = generate_image(prompt="A cozy coffee shop interior with morning light", show=False, return_candidates_images=True)

if result.get("ok") and "candidate_images" in result:
    import matplotlib.pyplot as plt
    images = result["candidate_images"]
    # "candidates" lists every seed, including ones whose generation failed, while
    # "candidate_images" holds only the generated images. Only generated images get
    # a "clip" score, so filtering on it keeps images and titles aligned.
    scored = [cand for cand in result["candidates"] if "clip" in cand]
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=(5*n, 5))
    if n == 1: axes = [axes]
    for ax, img, cand in zip(axes, images, scored):
        ax.imshow(img)
        ax.set_title(f"Seed {cand['seed']}\nCLIP: {cand.get('clip', 0):.4f}")
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    print(f"\n✅ Best: seed={result['best_seed']}, CLIP={result['best_clip']:.4f}")