In [3]:
#!/usr/bin/env python3
"""
iterative_vitl14_pipeline.py

Single-model pipeline (ViT-L-14 OpenCLIP) + iterative LLM-driven refinement.
- Parse user prompt with LLM to produce a search plan (images/videos/audio counts and cleaned query)
- Fetch metadata from Pexels (images+videos) and Freesound (audio)
- Preprocess thumbnails/previews in-memory -> CLIP-ready tensors
- Embed prompt + assets with ViT-L-14 (OpenCLIP) and compute cosine similarity
- If not enough assets have score >= threshold, call LLM refinement prompt to suggest query refinements
  and fetch more, up to max_iter.

Environment variables:
  OPENAI_API_KEY, PEXELS_API_KEY, FREESOUND_API_KEY (optional)
Optional config via env:
  MAX_ITER (default 4), THRESHOLD (default 0.5), PER_ITER_FETCH (default 5)
"""

import os
import io
import json
import time
import requests
import numpy as np
from PIL import Image
from typing import List, Dict, Optional, Tuple, Set
from dotenv import load_dotenv
from tqdm import tqdm

# audio libs
import soundfile as sf
import librosa
import ffmpeg

# ML libs
import torch
import open_clip

# Load .env (if present)
load_dotenv()

# -------------------------
# Config / env
# -------------------------
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PEXELS_KEY = os.getenv("PEXELS_API_KEY")
FREESOUND_KEY = os.getenv("FREESOUND_API_KEY")

MAX_ITER = int(os.getenv("MAX_ITER", "4"))
THRESHOLD = float(os.getenv("THRESHOLD", "0.2"))
PER_ITER_FETCH = int(os.getenv("PER_ITER_FETCH", "5"))

if not OPENAI_API_KEY:
    raise RuntimeError("Set OPENAI_API_KEY in environment (or load .env with load_dotenv()).")
if not PEXELS_KEY:
    print("[warn] No PEXELS_API_KEY set; Pexels fetcher will return empty.")
if not FREESOUND_KEY:
    print("[warn] No FREESOUND_API_KEY set; Freesound fetcher will return empty.")

# -------------------------
# LLM system prompts
# -------------------------
# Primary planning prompt (one-shot): instruct LLM to return JSON only.
PLANNER_SYSTEM_PROMPT = """You are a planning assistant that turns an unstructured user prompt
into a short, clean search plan for fetching media assets (images, videos, audio).
Return JSON ONLY with exactly these keys:
- clean_prompt: string (a concise search query suitable for stock APIs)
- num_images: integer (0-10)
- num_videos: integer (0-10)
- num_audio: integer (0-10)
- notes: string (brief reason for the choices)

Methods to follow:
1) Always consider all three media types. Decide which are relevant and allocate counts accordingly.
2) If the user explicitly asks for a medium (e.g., 'background music'), favor that medium.
3) For visual scenes (people, actions, landscapes), include images and at least 1 video.
4) For auditory prompts, include audio (SFX or music).
5) Choose counts for diversity (broad prompts -> more results; specific prompts -> fewer).
6) Clamp counts to 0..10. Output integers.
7) Output JSON only, no extra text.
"""

# Refinement prompt: given previous results and which assets were below threshold,
# ask the LLM to refine or propose new search queries / modifiers.
REFINER_SYSTEM_PROMPT = """You are an assistant that refines search queries to improve result relevance.
You are given:
- the original user prompt
- the clean_prompt used previously
- a short list of example titles/descriptions (or failures) that had low relevance
Your job: provide JSON ONLY with:
- action: one of ["refine_query","expand_query","suggest_filters","stop"]
- queries: list of 1-4 alternative or refined search query strings to try (can include modifiers like "close-up", "studio", "child", "man speaking", "conversation", "two people")
- per_query_counts: list of integers same length as queries, indicating how many assets to fetch per query (1-10)
- notes: short reasoning

Guidelines:
- Prefer small semantic changes that increase the chance of matching the intent (e.g. add "close-up", specify age/gender if prompt implies it).
- If the current results look OK, return action "stop".
- Output JSON only.
"""

# Generic helper: call OpenAI chat completions via HTTP
def call_openai_chat(messages: List[Dict], model: str="gpt-4o-mini", max_tokens:int=300, temperature:float=0.0) -> str:
    url = "https://api.openai.com/v1/chat/completions"
    headers = {"Authorization": f"Bearer {OPENAI_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens}
    r = requests.post(url, headers=headers, json=payload, timeout=30)
    r.raise_for_status()
    data = r.json()
    # robustly extract assistant content
    try:
        return data["choices"][0]["message"]["content"]
    except Exception:
        return json.dumps(data)

def parse_plan_via_llm(raw_prompt: str) -> Dict:
    user = f"User prompt: {raw_prompt}\nReturn JSON only."
    text = call_openai_chat([{"role":"system","content":PLANNER_SYSTEM_PROMPT},{"role":"user","content":user}])
    # extract JSON blob
    if "{" in text and "}" in text:
        try:
            s = text.index("{"); e = text.rindex("}") + 1
            return json.loads(text[s:e])
        except Exception:
            pass
    # fallback: try to parse entire text
    try:
        return json.loads(text)
    except Exception:
        # fallback heuristic
        return {"clean_prompt": raw_prompt, "num_images": 3, "num_videos": 1, "num_audio": 0, "notes":"fallback heuristics"}

def refine_queries_via_llm(original_prompt: str, prev_clean: str, low_examples: List[str]) -> Dict:
    """
    Ask LLM to refine queries given a list of low-relevance example titles/descriptions.
    Returns dict with action, queries(list), per_query_counts(list), notes.
    """
    user_payload = {
        "original_prompt": original_prompt,
        "prev_clean_prompt": prev_clean,
        "low_examples": low_examples[:10]
    }
    user_text = "Context (JSON):\n" + json.dumps(user_payload) + "\n\nReturn JSON only as specified."
    text = call_openai_chat([{"role":"system","content":REFINER_SYSTEM_PROMPT},{"role":"user","content":user_text}], max_tokens=300)
    if "{" in text and "}" in text:
        try:
            s = text.index("{"); e = text.rindex("}") + 1
            return json.loads(text[s:e])
        except Exception:
            pass
    try:
        return json.loads(text)
    except Exception:
        return {"action":"stop","queries":[],"per_query_counts":[],"notes":"fallback stop"}

# -------------------------
# Fetchers (Pexels & Freesound) - metadata only
# -------------------------
def unified_record(provider, id, typ, title, description, url, thumbnail_url=None,
                   duration=None, uploader=None, published_at=None, tags=None, raw=None):
    return {
        "provider": provider,
        "id": str(id),
        "type": typ,
        "title": title or "",
        "description": description,
        "url": url,
        "thumbnail_url": thumbnail_url,
        "duration": duration,
        "uploader": uploader,
        "published_at": published_at,
        "tags": tags or [],
        "raw_meta": raw or {}
    }

def search_pexels_images(prompt: str, per_page: int = 10) -> List[Dict]:
    if not PEXELS_KEY:
        return []
    endpoint = "https://api.pexels.com/v1/search"
    headers = {"Authorization": PEXELS_KEY}
    params = {"query": prompt, "per_page": per_page}
    r = requests.get(endpoint, headers=headers, params=params, timeout=15)
    r.raise_for_status()
    out = []
    for it in r.json().get("photos", []):
        out.append(unified_record(
            provider="pexels",
            id=it.get("id"),
            typ="image",
            title=it.get("alt") or "",
            description=None,
            url=(it.get("src") or {}).get("original"),
            thumbnail_url=(it.get("src") or {}).get("medium"),
            duration=None,
            uploader=it.get("photographer"),
            published_at=None,
            tags=[],
            raw=it
        ))
    return out

def search_pexels_videos(prompt: str, per_page: int = 8) -> List[Dict]:
    if not PEXELS_KEY:
        return []
    endpoint = "https://api.pexels.com/videos/search"
    headers = {"Authorization": PEXELS_KEY}
    params = {"query": prompt, "per_page": per_page}
    r = requests.get(endpoint, headers=headers, params=params, timeout=15)
    r.raise_for_status()
    out = []
    for it in r.json().get("videos", []):
        file_url = None
        for vf in it.get("video_files", []) or []:
            if vf.get("quality") == "hd":
                file_url = vf.get("link"); break
        if not file_url and it.get("video_files"):
            file_url = it.get("video_files")[0].get("link")
        out.append(unified_record(
            provider="pexels",
            id=it.get("id"),
            typ="video",
            title=(it.get("user") or {}).get("name") or str(it.get("id")),
            description=it.get("url"),
            url=file_url,
            thumbnail_url=it.get("image"),
            duration=it.get("duration"),
            uploader=(it.get("user") or {}).get("name"),
            published_at=None,
            tags=[],
            raw=it
        ))
    return out

def search_freesound(prompt: str, per_page: int = 10) -> List[Dict]:
    if not FREESOUND_KEY:
        return []
    endpoint = "https://freesound.org/apiv2/search/text/"
    headers = {"Authorization": f"Token {FREESOUND_KEY}"}
    params = {"query": prompt, "page_size": per_page}
    r = requests.get(endpoint, headers=headers, params=params, timeout=15)
    r.raise_for_status()
    out = []
    for it in r.json().get("results", []):
        preview = (it.get("previews") or {}).get("preview-hq-mp3") or (it.get("previews") or {}).get("preview-lq-mp3")
        out.append(unified_record(
            provider="freesound",
            id=it.get("id"),
            typ="audio",
            title=it.get("name"),
            description=it.get("description"),
            url=preview,
            thumbnail_url=None,
            duration=it.get("duration"),
            uploader=it.get("username"),
            published_at=it.get("created"),
            tags=it.get("tags") or [],
            raw=it
        ))
    return out

def fetch_assets_for_query(query: str, num_images:int=0, num_videos:int=0, num_audio:int=0) -> List[Dict]:
    results = []
    if num_images > 0:
        results.extend(search_pexels_images(query, per_page=num_images)[:num_images])
    if num_videos > 0:
        results.extend(search_pexels_videos(query, per_page=num_videos)[:num_videos])
    if num_audio > 0:
        results.extend(search_freesound(query, per_page=num_audio)[:num_audio])
    return results

# -------------------------
# Download + preprocess (in-memory)
# -------------------------
def download_bytes(url: str, timeout=20) -> Optional[bytes]:
    if not url:
        return None
    try:
        r = requests.get(url, timeout=timeout, stream=True)
        r.raise_for_status()
        return r.content
    except Exception as e:
        # don't spam, just warn
        # print("[warn] download failed:", e)
        return None

def download_image_pil(url: str, timeout=15) -> Optional[Image.Image]:
    b = download_bytes(url, timeout=timeout)
    if not b:
        return None
    try:
        img = Image.open(io.BytesIO(b)).convert("RGB")
        return img
    except Exception:
        return None

def center_crop_and_resize_pil(img: Image.Image, size:int=224) -> Image.Image:
    w,h = img.size
    m = min(w,h)
    left = (w-m)//2
    top = (h-m)//2
    img = img.crop((left, top, left+m, top+m))
    img = img.resize((size,size), Image.LANCZOS)
    return img

def pil_to_clip_tensor(img: Image.Image, size:int=224, normalize:bool=True) -> Optional[np.ndarray]:
    if img is None:
        return None
    img = center_crop_and_resize_pil(img, size=size)
    arr = np.asarray(img).astype(np.float32)/255.0  # H W C
    arr = np.transpose(arr,(2,0,1)).copy()  # C H W
    if normalize:
        arr = arr*2.0 - 1.0
    return arr

def decode_audio_bytes_to_waveform(audio_bytes: bytes, target_sr:int=16000, duration:float=5.0) -> Optional[np.ndarray]:
    if not audio_bytes:
        return None
    try:
        bio = io.BytesIO(audio_bytes)
        data, sr = sf.read(bio, dtype='float32')
        if data.ndim > 1:
            data = np.mean(data, axis=1)
        if sr != target_sr:
            data = librosa.resample(data, sr, target_sr)
        desired = int(target_sr * duration)
        if len(data) > desired:
            start = max(0, (len(data)-desired)//2)
            data = data[start:start+desired]
        elif len(data) < desired:
            data = np.concatenate([data, np.zeros(desired - len(data), dtype=np.float32)])
        return data.astype(np.float32)
    except Exception:
        try:
            proc = (
                ffmpeg.input('pipe:0')
                .output('pipe:1', format='f32le', ar=target_sr, ac=1)
                .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
            )
            out, err = proc.communicate(input=audio_bytes)
            if proc.returncode != 0:
                return None
            data = np.frombuffer(out, dtype=np.float32)
            desired = int(target_sr * duration)
            if len(data) > desired:
                start = max(0, (len(data)-desired)//2)
                data = data[start:start+desired]
            elif len(data) < desired:
                data = np.concatenate([data, np.zeros(desired - len(data), dtype=np.float32)])
            return data.astype(np.float32)
        except Exception:
            return None

def waveform_to_mel_image_tensor(wav: np.ndarray, sr:int=16000, n_mels:int=128, n_fft:int=2048, hop_length:int=512, size:int=224) -> Optional[np.ndarray]:
    if wav is None:
        return None
    S = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    S_db = librosa.power_to_db(S, ref=np.max)
    S_min, S_max = S_db.min(), S_db.max()
    S_norm = (S_db - S_min) / (S_max - S_min + 1e-9)
    img_arr = (S_norm * 255.0).astype(np.uint8)
    pil = Image.fromarray(img_arr)
    pil = pil.resize((size,size), Image.LANCZOS).convert("RGB")
    arr = np.asarray(pil).astype(np.float32)/255.0
    arr = np.transpose(arr,(2,0,1)).copy()
    arr = arr*2.0 - 1.0
    return arr

def preprocess_records(records: List[Dict], image_size:int=224, audio_duration:float=5.0, audio_sr:int=16000) -> List[Dict]:
    processed = []
    for r in tqdm(records, desc="preprocess", leave=False):
        rec = dict(r)
        typ = rec["type"]
        if typ == "image":
            url = rec.get("thumbnail_url") or rec.get("url")
            pil = download_image_pil(url)
            rec["img_pil"] = pil
            rec["img_tensor"] = pil_to_clip_tensor(pil, size=image_size) if pil is not None else None
        elif typ == "video":
            url = rec.get("thumbnail_url") or rec.get("url")
            pil = download_image_pil(url)
            rec["img_pil"] = pil
            rec["img_tensor"] = pil_to_clip_tensor(pil, size=image_size) if pil is not None else None
        elif typ == "audio":
            url = rec.get("url")
            audio_bytes = download_bytes(url)
            wav = decode_audio_bytes_to_waveform(audio_bytes, target_sr=audio_sr, duration=audio_duration)
            rec["waveform"] = wav
            rec["audio_mel_tensor"] = waveform_to_mel_image_tensor(wav, sr=audio_sr, size=image_size) if wav is not None else None
        else:
            rec["img_tensor"] = None
            rec["audio_mel_tensor"] = None
        processed.append(rec)
    return processed

# -------------------------
# ViT-L-14 CLIP wrapper (single model)
# -------------------------
class ViTL14Wrapper:
    def __init__(self, model_name="ViT-L-14", pretrained="openai", device:Optional[str]=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[info] Loading ViT-L-14 (OpenCLIP) on {self.device} — this may download >1.7GB if not cached.")
        model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
        self.model = model.to(self.device).eval()
        self.preprocess = preprocess
        self.tokenizer = open_clip.get_tokenizer(model_name)
        try:
            self.text_dim = self.model.text_projection.shape[1]
        except Exception:
            # fallback guess
            self.text_dim = 768
        self.dtype = next(self.model.parameters()).dtype

    @torch.no_grad()
    def embed_images(self, imgs: List[np.ndarray], batch_size:int=8) -> np.ndarray:
        """imgs: list of numpy arrays (C,H,W) float in [-1,1] or [0,1]"""
        if len(imgs) == 0:
            return np.zeros((0, self.model.visual.output_dim), dtype=np.float32)
        embs = []
        for i in range(0, len(imgs), batch_size):
            batch = imgs[i:i+batch_size]
            t = torch.from_numpy(np.stack(batch, axis=0)).to(self.device)
            # if in [-1,1] -> map to [0,1] for model preprocess expectation
            if t.max() <= 1.0 + 1e-6 and t.min() >= -1.0 - 1e-6:
                t = (t + 1.0) / 2.0
            t = t.type(self.dtype)
            img_emb = self.model.encode_image(t)
            img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
            embs.append(img_emb.cpu().numpy())
        return np.vstack(embs).astype(np.float32)

    @torch.no_grad()
    def embed_texts(self, texts: List[str], batch_size:int=16) -> np.ndarray:
        """
        Robust text embedding handling different tokenizer outputs (tensor / list / dict).
        """
        if len(texts) == 0:
            return np.zeros((0, self.text_dim), dtype=np.float32)
        out_embs = []
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i+batch_size]
            toks = self.tokenizer(chunk)
            # Normalize tokenizer output
            if isinstance(toks, torch.Tensor):
                toks_t = toks.to(self.device)
                try:
                    txt = self.model.encode_text(toks_t)
                except TypeError:
                    txt = self.model.encode_text(input_ids=toks_t)
            elif isinstance(toks, (list, tuple, np.ndarray)):
                toks_t = torch.tensor(toks, device=self.device)
                try:
                    txt = self.model.encode_text(toks_t)
                except TypeError:
                    txt = self.model.encode_text(input_ids=toks_t)
            elif isinstance(toks, dict):
                dict_t = {}
                for k,v in toks.items():
                    if isinstance(v, torch.Tensor):
                        dict_t[k] = v.to(self.device)
                    else:
                        try:
                            dict_t[k] = torch.tensor(v, device=self.device)
                        except Exception:
                            pass
                # if single tensor value -> positional
                tensor_vals = [v for v in dict_t.values() if isinstance(v, torch.Tensor)]
                if len(tensor_vals) == 1:
                    try:
                        txt = self.model.encode_text(tensor_vals[0])
                    except TypeError:
                        txt = self.model.encode_text(input_ids=tensor_vals[0])
                else:
                    try:
                        txt = self.model.encode_text(**dict_t)
                    except TypeError:
                        txt = self.model.encode_text(tensor_vals[0]) if tensor_vals else torch.zeros((len(chunk), self.text_dim), device=self.device)
            else:
                raise RuntimeError(f"Unsupported tokenizer return type: {type(toks)}")
            txt = txt / (txt.norm(dim=-1, keepdim=True) + 1e-12)
            out_embs.append(txt.cpu().numpy())
        return np.vstack(out_embs).astype(np.float32)

# -------------------------
# Utilities: scoring and dedup
# -------------------------
def cosine_sim_vecs(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # a: (N,d), b: (d,) -> (N,)
    if a.size == 0 or b.size == 0:
        return np.zeros((a.shape[0],), dtype=np.float32)
    bnorm = b / (np.linalg.norm(b) + 1e-12)
    an = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-12)
    return (an @ bnorm).reshape(-1)

def dedupe_records(existing_ids: Set[str], new_records: List[Dict]) -> List[Dict]:
    out = []
    for r in new_records:
        if r["provider"] + "_" + r["id"] in existing_ids:
            continue
        out.append(r)
    return out

# -------------------------
# Orchestration: iterative fetching + scoring
# -------------------------
def iterative_pipeline(user_prompt: str, threshold: float = THRESHOLD, per_iter_fetch: int = PER_ITER_FETCH, max_iter:int=MAX_ITER, top_k:int=10):
    # 1) initial plan
    plan = parse_plan_via_llm(user_prompt)
    print("LLM initial plan:", plan)
    # instantiate ViT-L-14 model once
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vit = ViTL14Wrapper(model_name="ViT-L-14", pretrained="openai", device=device)

    # Data stores
    all_records: List[Dict] = []
    seen_ids: Set[str] = set()
    iter_count = 0

    # helper to fetch from plan or queries list
    def fetch_from_plan_entry(clean_q: str, imgs:int, vids:int, auds:int) -> List[Dict]:
        recs = fetch_assets_for_query(clean_q, num_images=imgs, num_videos=vids, num_audio=auds)
        # unify dedupe by provider_id tag
        new = dedupe_records(seen_ids, recs)
        for r in new:
            seen_ids.add(r["provider"] + "_" + r["id"])
        return new

    # initial batch fetch using plan
    initial_recs = fetch_from_plan_entry(plan["clean_prompt"], plan.get("num_images",0), plan.get("num_videos",0), plan.get("num_audio",0))
    all_records.extend(initial_recs)
    print(f"Fetched initial {len(initial_recs)} assets.")

    # Main iterative loop
    while iter_count < max_iter:
        iter_count += 1
        print(f"\n=== Iteration {iter_count} === total candidates so far: {len(all_records)}")
        # preprocess newly fetched ones (we preprocess all; could optimize)
        processed = preprocess_records(all_records, image_size=224)

        # build arrays of embeddable image-like tensors
        img_tensors = []
        img_map_idx = []  # map tensor index -> all_records index
        for idx, rec in enumerate(processed):
            if rec["type"] in ("image","video") and rec.get("img_tensor") is not None:
                img_tensors.append(rec["img_tensor"].astype(np.float32))
                img_map_idx.append(idx)
        # audio -> mel-images also considered as images
        audio_tensors = []
        audio_map_idx = []
        for idx, rec in enumerate(processed):
            if rec["type"] == "audio" and rec.get("audio_mel_tensor") is not None:
                audio_tensors.append(rec["audio_mel_tensor"].astype(np.float32))
                audio_map_idx.append(idx)

        combined_tensors = img_tensors + audio_tensors
        combined_map_idx = img_map_idx + audio_map_idx

        if len(combined_tensors) == 0:
            print("[info] No embeddable items found yet.")
            # attempt to fetch per_iter_fetch new items using the same query
            more = fetch_from_plan_entry(plan["clean_prompt"], per_iter_fetch, 0, 0)
            if not more:
                print("[info] No more items available for same query; attempting LLM refinement.")
                ref = refine_queries_via_llm(user_prompt, plan["clean_prompt"], [r.get("title","") for r in all_records])
                if ref.get("action","stop") == "stop":
                    print("[info] Refiner suggested stop. Exiting.")
                    break
                # else apply queries suggested by ref
                for q, c in zip(ref.get("queries",[]), ref.get("per_query_counts",[])):
                    new = fetch_from_plan_entry(q, int(c), 0, 0)
                    all_records.extend(new)
            else:
                all_records.extend(more)
            continue

        print(f"Embedding {len(combined_tensors)} items using ViT-L-14...")
        emb_imgs = vit.embed_images(combined_tensors, batch_size=8)
        # embed prompt
        txt_emb = vit.embed_texts([plan["clean_prompt"]])[0]
        # compute similarity vector
        sims = cosine_sim_vecs(emb_imgs, txt_emb)

        # attach scores back to records
        for local_i, sim in enumerate(sims):
            rec_idx = combined_map_idx[local_i]
            all_records[rec_idx].setdefault("scores", {})
            all_records[rec_idx]["scores"]["vitl14"] = float(sim)

        # count how many distinct assets exceed threshold
        good_records = [r for r in all_records if r.get("scores",{}).get("vitl14", -999) >= threshold]
        print(f"Found {len(good_records)} assets with score >= {threshold} (threshold).")

        # If enough (we might define 'enough' as sum of initially requested counts)
        desired_total = plan.get("num_images",0) + plan.get("num_videos",0) + plan.get("num_audio",0)
        # but if initial desired_total=0 choose a small target (e.g., 3)
        if desired_total <= 0:
            desired_total = max(3, per_iter_fetch)

        if len(good_records) >= desired_total:
            print(f"Target satisfied: {len(good_records)} >= {desired_total}. Stopping iterations.")
            break

        # Otherwise, call the refiner LLM to propose query refinements
        low_examples = []
        # pick some low-scoring examples to show LLM
        sorted_by_score = sorted(all_records, key=lambda r: r.get("scores",{}).get("vitl14", -999), reverse=True)
        # list titles of bottom K (or those below threshold)
        low = [r for r in sorted_by_score if r.get("scores",{}).get("vitl14", -999) < threshold]
        for r in low[:6]:
            title = r.get("title") or r.get("description") or ""
            low_examples.append(title + " || " + (r.get("thumbnail_url") or r.get("url") or ""))

        print("[info] Requesting query refinements from LLM ...")
        ref = refine_queries_via_llm(user_prompt, plan["clean_prompt"], low_examples)
        action = ref.get("action", "stop")
        print("Refiner action:", action)
        if action == "stop":
            print("[info] Refiner suggested stop or no useful suggestions. Ending.")
            break
        # Apply returned queries
        queries = ref.get("queries", []) or []
        per_counts = ref.get("per_query_counts", []) or []
        # pad per_counts if needed
        if len(per_counts) < len(queries):
            per_counts = per_counts + [per_iter_fetch] * (len(queries)-len(per_counts))
        new_found = 0
        for q, c in zip(queries, per_counts):
            try:
                imgs_to_fetch = int(min(max(1, int(c)), 10))
            except Exception:
                imgs_to_fetch = per_iter_fetch
            new_recs = fetch_from_plan_entry(q, imgs_to_fetch, 0, 0)
            print(f"Fetched {len(new_recs)} for query: {q}")
            new_found += len(new_recs)
            all_records.extend(new_recs)
        if new_found == 0:
            print("[info] Refiner queries returned no new assets. Ending iterations.")
            break
        # optionally update plan.clean_prompt to best refined suggestion
        if queries:
            plan["clean_prompt"] = queries[0]

    # end loop
    # Final scoring summary: sort all_records by vitl14 score desc
    final_with_scores = [r for r in all_records if r.get("scores",{}).get("vitl14", -999) != -999]
    final_sorted = sorted(final_with_scores, key=lambda r: -r["scores"]["vitl14"])
    print("\nFinal top results (ViT-L-14 scores):")
    for i, rr in enumerate(final_sorted[:top_k], start=1):
        print(f"{i}. [{rr['type']}] {rr['provider']} id={rr['id']} score={rr['scores']['vitl14']:.4f}")
        print(f"    title: {rr['title']}")
        print(f"    url  : {rr.get('url')}")
        print()
    return final_sorted

# -------------------------
# CLI entry
# -------------------------
if __name__ == "__main__":
    prompt = input("Enter user prompt: ").strip()
    if not prompt:
        prompt = "a girl talking to a man"
    results = iterative_pipeline(prompt, threshold=THRESHOLD, per_iter_fetch=PER_ITER_FETCH, max_iter=MAX_ITER, top_k=10)


LLM initial plan: {'clean_prompt': 'girl talking to a man', 'num_images': 5, 'num_videos': 2, 'num_audio': 0, 'notes': 'Images are included to capture the scene, and videos provide dynamic context. No audio is needed as the focus is on the visual interaction.'}
[info] Loading ViT-L-14 (OpenCLIP) on cpu — this may download >1.7GB if not cached.
Fetched initial 7 assets.

=== Iteration 1 === total candidates so far: 7


                                                         

Embedding 7 items using ViT-L-14...
Found 0 assets with score >= 0.2 (threshold).
[info] Requesting query refinements from LLM ...
Refiner action: refine_query
Fetched 5 for query: girl talking to a man close-up
Fetched 5 for query: young girl speaking with an adult man
Fetched 2 for query: girl having a conversation with a man
Fetched 4 for query: child talking to a man

=== Iteration 2 === total candidates so far: 23


                                                           

Embedding 23 items using ViT-L-14...
Found 0 assets with score >= 0.2 (threshold).
[info] Requesting query refinements from LLM ...
Refiner action: refine_query
Fetched 0 for query: girl talking to a man close-up
Fetched 5 for query: young girl conversing with an adult man
Fetched 0 for query: girl and man having a conversation
Fetched 4 for query: girl speaking with a man in a casual setting

=== Iteration 3 === total candidates so far: 32


                                                           

Embedding 32 items using ViT-L-14...


KeyboardInterrupt: 