In [None]:
# MARP Phase 1 - Query Expansion & Asset Type/Count Detection
# Clean, easy-to-follow functions intended to be run in a Jupyter notebook cell-by-cell.

# 1) Install dependencies (run in a notebook cell if needed):
# !pip install openai python-dotenv pydantic

from typing import List, Dict, Any, Optional
import os
import json
import re
import time
from dotenv import load_dotenv
import openai
from pydantic import BaseModel, ValidationError, conlist

# ---------------------------
# Configuration & helpers
# ---------------------------

load_dotenv()  # loads variables from .env in project folder

from openai import OpenAI
client = OpenAI(api_key=OPENAI_API_KEY)
# Do NOT hardcode API keys here. Place them in your .env file as:
# OPENAI_API_KEY=sk-...

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if not OPENAI_API_KEY:
    raise EnvironmentError("OPENAI_API_KEY not found in environment. Put it in your .env file.")

openai.api_key = OPENAI_API_KEY

# ---------------------------
# Output schema from the LLM
# ---------------------------

class ExpansionOutput(BaseModel):
    expanded_queries: conlist(str, min_length=8, max_length=20)  # 8-20 queries
    asset_type: str  # one of: images, videos, sound, mixed
    asset_count: int  # total number of assets requested (approx)
    notes: Optional[str] = None

# ---------------------------
# Prompts and LLM call
# ---------------------------

LLM_PROMPT_TEMPLATE = '''
You are an assistant that converts a user's free-form asset request into a structured JSON response.
Given the user's raw prompt, perform 3 tasks ONLY:

1) Produce 8 to 12 high-quality, semantically diverse search queries (short phrases) suitable for searching creative asset providers (images, video, audio). Return these in the array "expanded_queries".

2) Decide the dominant asset type requested: choose one of ["images", "videos", "sound", "mixed"]. Put this in "asset_type".

3) Suggest a sensible number of assets to return (approximate integer) in "asset_count". Keep it realistic (1-50).

Constraints:
- Output must be valid JSON, and follow this schema:
  {{ "expanded_queries": [...], "asset_type": "images|videos|sound|mixed", "asset_count": <int>, "notes": "optional short note" }}
- expanded_queries should be short (3 words to 6 words), focused, and cover style/genre synonyms when relevant.
- Keep temperature low: be deterministic.
- Do not print any commentary or explanation ‚Äî only the JSON.

User prompt: """{user_prompt}"""
'''
def _extract_json_from_text(text: str) -> Dict:
    """Try to find a JSON object in the model output and parse it."""
    # Strip markdown code fences
    if text.startswith("```"):
        parts = text.split("```")
        # pick the largest JSON-looking chunk
        candidates = [p.strip() for p in parts if '{' in p and '}' in p]
        text = candidates[0] if candidates else text

    # Find first { ... } block
    first = text.find('{')
    last = text.rfind('}')
    if first == -1 or last == -1:
        raise ValueError("No JSON object found in model output.")
    json_str = text[first:last+1]

    # Attempt to fix common trailing commas / single quotes by simple replacements
    json_str = json_str.replace("'", '"')
    json_str = re.sub(r",\s*}", "}", json_str)
    json_str = re.sub(r",\s*]", "]", json_str)

    return json.loads(json_str)

def call_llm_for_expansions(user_prompt: str,
                            model: str = "gpt-4o-mini",
                            max_tokens: int = 512,
                            temperature: float = 0.0,
                            retries: int = 1) -> dict:
    """Call the OpenAI Chat Completions via the new client and return parsed JSON."""
    base_prompt = LLM_PROMPT_TEMPLATE.format(user_prompt=user_prompt)

    attempt = 0
    last_raw = None
    while attempt <= retries:
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": base_prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            n=1,
        )
        # New client returns choices similarly; pick the first choice's message content
        text = resp.choices[0].message.content.strip()
        last_raw = text

        try:
            data = _extract_json_from_text(text)
        except Exception as e:
            if attempt < retries:
                strict_prompt = (
                    "You previously returned invalid output. Now only output a single valid JSON object, "
                    "matching exactly the schema with keys: expanded_queries (array), asset_type (string), "
                    "asset_count (int), notes (optional string). No extra text. "
                    "USER PROMPT: " + user_prompt
                )
                base_prompt = strict_prompt
                attempt += 1
                time.sleep(0.3)
                continue
            else:
                return {"_raw": last_raw, "_error": f"JSON extraction failed: {e}"}

        if not isinstance(data, dict) or \
           'expanded_queries' not in data or \
           'asset_type' not in data or \
           'asset_count' not in data:
            if attempt < retries:
                base_prompt = (
                    "Output MUST be valid JSON with keys: expanded_queries (array), "
                    "asset_type (string), asset_count (int), notes (optional). "
                    "Do NOT include any commentary. USER PROMPT: " + user_prompt
                )
                attempt += 1
                time.sleep(0.3)
                continue
            else:
                return {"_raw": last_raw, "_error": "Missing required keys in parsed JSON."}

        return data

    return {"_raw": last_raw, "_error": "Exceeded retries without valid JSON."}# ---------------------------
# Public-facing convenience functions
# ---------------------------

def expand_and_detect(user_prompt: str, model: str = 'gpt-4o-mini') -> ExpansionOutput:
    raw = call_llm_for_expansions(user_prompt, model=model)
    if '_raw' in raw:
        # raise with the raw output for easier debugging in notebook
        raise RuntimeError(f"LLM returned unparsable output: {raw.get('_error')}\n\nRAW OUTPUT:\n{raw.get('_raw')}")
    # Validate and coerce types safely before Pydantic
    # Ensure expanded_queries is a list of strings
    if not isinstance(raw.get('expanded_queries'), list):
        raise RuntimeError(f"expanded_queries not a list in LLM output: {raw}")

    # Minimal cleanup: trim queries and deduplicate preserving order
    seen = set()
    cleaned_queries = []
    for q in raw['expanded_queries']:
        if not isinstance(q, str): continue
        s = q.strip()
        if s and s not in seen:
            cleaned_queries.append(s)
            seen.add(s)

    raw['expanded_queries'] = cleaned_queries

    try:
        validated = ExpansionOutput(**raw)
        return validated
    except ValidationError as ve:
        raise RuntimeError(f"Schema validation failed: {ve}\nLLM output was: {json.dumps(raw, indent=2)}")
# ---------------------------
# Small interactive example (run in a notebook cell)
# ---------------------------
if __name__ == '__main__':
    # Example usage when running this script directly
    sample_prompt = "Videos and images of serene mountain landscape at sunrise with mist rising from the valley "
    print("User prompt:\n", sample_prompt)
    out = expand_and_detect(sample_prompt)
    print('\nStructured LLM output:')
    print(out.model_dump_json(indent=2))



User prompt:
 Videos and images of serene mountain landscape at sunrise with mist rising from the valley 

Structured LLM output:
{
  "expanded_queries": [
    "serene mountain sunrise",
    "misty valley landscapes",
    "peaceful mountain scenery",
    "sunrise over mountains",
    "mountain landscape videos",
    "tranquil nature images",
    "foggy valley at dawn",
    "pictures of mountain mist",
    "sunrise mountain views",
    "calm mountain vistas",
    "serenity in nature",
    "mountain sunrise photography"
  ],
  "asset_type": "mixed",
  "asset_count": 20,
  "notes": "Includes both videos and images."
}


In [18]:
# Phase 2 (general) - metadata-only fetcher using LLM-generated queries (no hardcoding)
# Run this cell in your Jupyter notebook. It expects `expand_and_detect()` (Phase 1) to be defined
# and .env keys for UNSPLASH_ACCESS_KEY, PEXELS_API_KEY, PIXABAY_API_KEY, FREESOUND_API_KEY to exist.
#
# Dependencies (install if needed):
# !pip install python-dotenv requests pandas

import os
import json
import requests
from typing import List, Dict, Any
from dotenv import load_dotenv
from pprint import pprint

load_dotenv()

UNSPLASH_ACCESS_KEY = os.getenv("UNSPLASH_ACCESS_KEY")
PEXELS_API_KEY = os.getenv("PEXELS_API_KEY")
PIXABAY_API_KEY = os.getenv("PIXABAY_API_KEY")
FREESOUND_API_KEY = os.getenv("FREESOUND_API_KEY")

# -- Normalized asset dict helper --
def make_asset(provider: str,
               asset_url: str,
               title: str = None,
               description: str = None,
               tags: List[str] = None,
               thumbnail: str = None,
               width: int = None,
               height: int = None,
               duration: float = None,
               raw: Dict[str, Any] = None) -> Dict[str, Any]:
    return {
        "provider": provider,
        "asset_url": asset_url,
        "title": title,
        "description": description,
        "tags": tags or [],
        "thumbnail": thumbnail,
        "width": width,
        "height": height,
        "duration": duration,
        "raw": raw or {},
    }

# -- Provider search functions (metadata-only) --
def search_unsplash(query: str, per_page: int = 5) -> List[Dict[str, Any]]:
    if not UNSPLASH_ACCESS_KEY:
        return []
    url = "https://api.unsplash.com/search/photos"
    params = {"query": query, "per_page": per_page}
    headers = {"Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"}
    resp = requests.get(url, params=params, headers=headers, timeout=15)
    resp.raise_for_status()
    data = resp.json()
    out = []
    for item in data.get("results", []):
        out.append(make_asset(
            provider="unsplash",
            asset_url=item.get("links", {}).get("html") or item.get("urls", {}).get("full"),
            title=item.get("alt_description") or item.get("description"),
            description=item.get("description") or item.get("alt_description"),
            tags=[t.get("title") for t in item.get("tags", [])] if item.get("tags") else [],
            thumbnail=item.get("urls", {}).get("small"),
            width=item.get("width"),
            height=item.get("height"),
            raw=item,
        ))
    return out

def search_pexels_images(query: str, per_page: int = 5) -> List[Dict[str, Any]]:
    if not PEXELS_API_KEY:
        return []
    url = "https://api.pexels.com/v1/search"
    params = {"query": query, "per_page": per_page}
    headers = {"Authorization": PEXELS_API_KEY}
    resp = requests.get(url, params=params, headers=headers, timeout=15)
    resp.raise_for_status()
    data = resp.json()
    out = []
    for item in data.get("photos", []):
        out.append(make_asset(
            provider="pexels",
            asset_url=item.get("url"),
            title=item.get("alt"),
            description=item.get("photographer"),
            tags=[],  # pexels doesn't provide tags in this endpoint
            thumbnail=item.get("src", {}).get("tiny"),
            width=item.get("width"),
            height=item.get("height"),
            raw=item,
        ))
    return out

def search_pixabay(query: str, per_page: int = 5, image_type: str = "photo") -> List[Dict[str, Any]]:
    if not PIXABAY_API_KEY:
        return []
    url = "https://pixabay.com/api/"
    params = {"key": PIXABAY_API_KEY, "q": query, "per_page": per_page, "image_type": image_type, "safesearch": "true"}
    resp = requests.get(url, params=params, timeout=15)
    resp.raise_for_status()
    data = resp.json()
    out = []
    for item in data.get("hits", []):
        out.append(make_asset(
            provider="pixabay",
            asset_url=item.get("pageURL"),
            title=item.get("tags"),
            description=item.get("user"),
            tags=[t.strip() for t in (item.get("tags") or "").split(",") if t.strip()],
            thumbnail=item.get("previewURL"),
            width=item.get("imageWidth"),
            height=item.get("imageHeight"),
            raw=item,
        ))
    return out

def search_freesound(query: str, per_page: int = 5) -> List[Dict[str, Any]]:
    if not FREESOUND_API_KEY:
        return []
    url = "https://freesound.org/apiv2/search/text/"
    params = {"query": query, "page_size": per_page}
    headers = {"Authorization": f"Token {FREESOUND_API_KEY}"}
    resp = requests.get(url, params=params, headers=headers, timeout=15)
    resp.raise_for_status()
    data = resp.json()
    out = []
    for item in data.get("results", []):
        out.append(make_asset(
            provider="freesound",
            asset_url=item.get("url"),
            title=item.get("name"),
            description=item.get("description"),
            tags=item.get("tags", []),
            thumbnail=None,
            duration=item.get("duration"),
            raw=item,
        ))
    return out

# -- Aggregator (general, uses LLM expansion output) --
def fetch_assets_from_providers(expansion_result: Dict[str, Any],
                                per_provider: int = 5,
                                queries_to_use: int = None) -> List[Dict[str, Any]]:
    """
    expansion_result: dict or pydantic model containing keys:
       - expanded_queries (list[str])
       - asset_type (images|videos|sound|mixed)
    per_provider: results per provider per query
    queries_to_use: how many expanded queries to use (None == use all)
    """
    # Accept either dict or pydantic model (ExpansionOutput)
    if hasattr(expansion_result, "model_dump"):
        data = expansion_result.model_dump()
    elif hasattr(expansion_result, "dict"):
        data = expansion_result.dict()
    else:
        data = dict(expansion_result)

    queries = data.get("expanded_queries", []) or []
    if queries_to_use is None:
        queries_to_use = len(queries)
    queries = queries[:queries_to_use]
    asset_type = data.get("asset_type", "images")

    aggregated: List[Dict[str, Any]] = []

    for q in queries:
        # Images (unsplash, pexels, pixabay)
        if asset_type in ("images", "mixed", None):
            try:
                aggregated.extend(search_unsplash(q, per_page=per_provider))
            except Exception as e:
                aggregated.append({"provider": "unsplash", "error": str(e), "query": q})
            try:
                aggregated.extend(search_pexels_images(q, per_page=per_provider))
            except Exception as e:
                aggregated.append({"provider": "pexels", "error": str(e), "query": q})
            try:
                aggregated.extend(search_pixabay(q, per_page=per_provider))
            except Exception as e:
                aggregated.append({"provider": "pixabay", "error": str(e), "query": q})

        # Sound
        if asset_type in ("sound", "mixed"):
            try:
                aggregated.extend(search_freesound(q, per_page=per_provider))
            except Exception as e:
                aggregated.append({"provider": "freesound", "error": str(e), "query": q})

        # Videos (Pexels videos)
        if asset_type == "videos":
            if PEXELS_API_KEY:
                url = "https://api.pexels.com/videos/search"
                try:
                    resp = requests.get(url, params={"query": q, "per_page": per_provider}, headers={"Authorization": PEXELS_API_KEY}, timeout=15)
                    resp.raise_for_status()
                    data = resp.json()
                    for item in data.get("videos", []):
                        aggregated.append(make_asset(
                            provider="pexels",
                            asset_url=item.get("url"),
                            title=item.get("user", {}).get("name"),
                            description=item.get("url"),
                            tags=[],
                            thumbnail=item.get("image"),
                            duration=item.get("duration"),
                            raw=item,
                        ))
                except Exception as e:
                    aggregated.append({"provider": "pexels_video", "error": str(e), "query": q})

    return aggregated

# -- Example usage (no hardcoded queries) --
# assume `expand_and_detect()` from Phase 1 exists in the notebook
# and returns a Pydantic ExpansionOutput or dict.

sample_prompt = "A moody cinematic photo of a lone fisherman at sunrise on a misty lake, high contrast"

# call Phase-1 function to get expansions
expansion_out = expand_and_detect(sample_prompt)   # returns ExpansionOutput or raises
# convert to plain dict safely
if hasattr(expansion_out, "model_dump"):
    expansion_result = expansion_out.model_dump()
elif hasattr(expansion_out, "dict"):
    expansion_result = expansion_out.dict()
else:
    expansion_result = dict(expansion_out)

# fetch assets using the generated queries (you can tune per_provider / queries_to_use)
assets = fetch_assets_from_providers(expansion_result, per_provider=3, queries_to_use=6)

# show results (first 50) in a readable table using pandas
import pandas as pd
df = pd.json_normalize(assets)
pd.set_option("display.max_colwidth", 200)
print(f"Fetched {len(df)} items. Showing first 50 rows:")
display(df.head(50))


Fetched 54 items. Showing first 50 rows:


Unnamed: 0,provider,asset_url,title,description,tags,thumbnail,width,height,duration,raw.id,...,raw.user_id,raw.user,raw.userImageURL,raw.noAiTraining,raw.isAiGenerated,raw.isGRated,raw.isLowQuality,raw.userURL,raw.topic_submissions.nature.status,raw.topic_submissions.wallpapers.status
0,unsplash,https://unsplash.com/photos/a-boat-on-the-water-gBaDh4y8S0A,a boat on the water,Three fishermen in a boat at sunset.,[],https://images.unsplash.com/photo-1657272179712-dbf132167f1d?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=M3w4MDU2NDN8MHwxfHNlYXJjaHwxfHxtb29keSUyMGZpc2hlcm1hbiUyMHN1bnJpc2V8ZW58MHx8fHwxNzYxNDc0MT...,5184,3456,,gBaDh4y8S0A,...,,,,,,,,,,
1,unsplash,https://unsplash.com/photos/two-people-and-a-dog-fishing-at-sunset-sJLEqM7PfTo,Two people and a dog fishing at sunset,Two people and a dog fishing at sunset,[],https://images.unsplash.com/photo-1757512440405-e55c1b17918a?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=M3w4MDU2NDN8MHwxfHNlYXJjaHwyfHxtb29keSUyMGZpc2hlcm1hbiUyMHN1bnJpc2V8ZW58MHx8fHwxNzYxNDc0MT...,4689,3143,,sJLEqM7PfTo,...,,,,,,,,,,
2,unsplash,https://unsplash.com/photos/silhouette-of-man-standing-on-rocks-facing-ocean-during-golden-hour-WmH61roBSwI,silhouette of man standing on rocks facing ocean during golden hour,"While i on Puger beach, i took a photo",[],https://images.unsplash.com/photo-1541690956806-f9621a226394?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=M3w4MDU2NDN8MHwxfHNlYXJjaHwzfHxtb29keSUyMGZpc2hlcm1hbiUyMHN1bnJpc2V8ZW58MHx8fHwxNzYxNDc0MT...,4007,3005,,WmH61roBSwI,...,,,,,,,,,,
3,pexels,https://www.pexels.com/photo/man-in-black-water-proof-fishing-suit-in-the-water-9962601/,"A lone fisherman stands in calm water at dawn, showcasing peaceful solitude and hobby.",Niklas Jeromin,[],https://images.pexels.com/photos/9962601/pexels-photo-9962601.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,4000,6000,,9962601,...,,,,,,,,,,
4,pexels,https://www.pexels.com/photo/silhouette-of-fisherman-at-calm-seaside-32524935/,"Monochrome photo of a fisherman in shallow waters at sunrise, creating a serene silhouette effect.",Orkhan Aliyev,[],https://images.pexels.com/photos/32524935/pexels-photo-32524935.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,2189,2189,,32524935,...,,,,,,,,,,
5,pexels,https://www.pexels.com/photo/silhouette-of-a-person-riding-on-the-boat-during-sunset-7536552/,Silhouette of a fisherman on a boat during the golden sunrise over the ocean.,MightyTeja,[],https://images.pexels.com/photos/7536552/pexels-photo-7536552.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,4898,3265,,7536552,...,,,,,,,,,,
6,pixabay,https://pixabay.com/photos/fishing-boat-at-sea-fishing-8095632/,"fishing boat at sea, fishing, fisherman fishing, morning, dawn, nature, at sea, floating, fishing, fishing, fishing, morning, morning, morning, morning, morning",HieuNghiaMini,"[fishing boat at sea, fishing, fisherman fishing, morning, dawn, nature, at sea, floating, fishing, fishing, fishing, morning, morning, morning, morning, morning]",https://cdn.pixabay.com/photo/2023/06/29/04/30/fishing-boat-at-sea-8095632_150.jpg,4000,2250,,8095632,...,25859970.0,HieuNghiaMini,https://cdn.pixabay.com/user/2023/07/21/00-14-10-754_250x250.jpg,False,False,True,False,https://pixabay.com/users/25859970/,,
7,pixabay,https://pixabay.com/photos/dawn-beach-fishing-sunset-scenic-6562295/,"dawn, beach, fishing, sunset, scenic, sea, sky, silhouette, nature, seascape, vietnam, twilight, dusk",xuanduongvan87,"[dawn, beach, fishing, sunset, scenic, sea, sky, silhouette, nature, seascape, vietnam, twilight, dusk]",https://cdn.pixabay.com/photo/2021/08/21/09/26/dawn-6562295_150.jpg,6000,4000,,6562295,...,22814888.0,xuanduongvan87,https://cdn.pixabay.com/user/2021/08/10/04-26-38-240_250x250.jpg,False,False,True,False,https://pixabay.com/users/22814888/,,
8,pixabay,https://pixabay.com/photos/nature-fisherman-landscape-7175030/,"nature, fisherman, landscape, reflection, mountains, li river, guilin china, nature, nature, nature, nature, nature, fisherman, fisherman, landscape, landscape, landscape",mercierzeng,"[nature, fisherman, landscape, reflection, mountains, li river, guilin china, nature, nature, nature, nature, nature, fisherman, fisherman, landscape, landscape, landscape]",https://cdn.pixabay.com/photo/2022/05/05/01/05/nature-7175030_150.jpg,7360,4912,,7175030,...,26409936.0,mercierzeng,https://cdn.pixabay.com/user/2022/05/05/00-39-46-458_250x250.jpg,False,False,True,False,https://pixabay.com/users/26409936/,,
9,unsplash,https://unsplash.com/photos/green-trees-on-mountain-during-daytime-GoM3XZKgkAo,green trees on mountain during daytime,shadow lake,[],https://images.unsplash.com/photo-1620251423176-a21fc67ae47b?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=M3w4MDU2NDN8MHwxfHNlYXJjaHwxfHxjaW5lbWF0aWMlMjBsYWtlJTIwcGhvdG98ZW58MHx8fHwxNzYxNDc0MTA4fD...,5379,3026,,GoM3XZKgkAo,...,,,,,,,,,,


In [19]:
# Phase 3 (metadata matching) - compute SAS using OpenAI embeddings and pick top-K
# Requires: pip install numpy pandas openai (or use your existing OpenAI client import)
import numpy as np
import math
import pandas as pd
from typing import List, Dict, Any

# config: change model if you prefer
EMBEDDING_MODEL = "text-embedding-3-small"  # small, fast embedding model

def _text_for_asset(asset: Dict[str, Any]) -> str:
    """Create a concise metadata text representation for an asset for embedding."""
    parts = []
    if asset.get("title"):
        parts.append(asset["title"])
    if asset.get("description"):
        parts.append(asset["description"])
    # tags may be list or comma string
    tags = asset.get("tags") or []
    if isinstance(tags, str):
        tags = [t.strip() for t in tags.split(",") if t.strip()]
    if tags:
        parts.append(" ".join(tags))
    # include provider and url hints sparingly
    # parts.append(asset.get("provider",""))
    return " ||| ".join(parts) if parts else (asset.get("provider") or "")

def _batch_embeddings(texts: List[str], batch_size: int = 64) -> List[List[float]]:
    """Get embeddings for a list of texts using the OpenAI client (batched)."""
    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        resp = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
        # resp.data is a list; each item has .embedding
        batch_emb = [d.embedding for d in resp.data]
        embeddings.extend(batch_emb)
    return embeddings

def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    # handle zero vectors
    if np.all(a == 0) or np.all(b == 0):
        return 0.0
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def rank_assets_by_metadata_similarity(assets: List[Dict[str, Any]],
                                       expanded_queries: List[str],
                                       top_k: int = 10) -> pd.DataFrame:
    """
    For each asset, compute embedding of its metadata text and compute cosine similarity
    to each expanded query embedding. Use the maximum similarity across queries as SAS.
    Return a DataFrame of top_k assets ordered by SAS (desc).
    """
    # 1) prepare asset texts
    asset_texts = [_text_for_asset(a) for a in assets]
    # 2) prepare query texts (expanded queries)
    query_texts = [q.strip() for q in expanded_queries if q and q.strip()]

    # 3) embed all texts (queries + assets) in batches
    all_texts = query_texts + asset_texts
    all_embs = _batch_embeddings(all_texts)
    query_embs = [np.array(v, dtype=float) for v in all_embs[:len(query_texts)]]
    asset_embs = [np.array(v, dtype=float) for v in all_embs[len(query_texts):]]

    # 4) compute SAS per asset = max cosine similarity to any query
    sas_scores = []
    for a_emb in asset_embs:
        sims = [_cosine_sim(a_emb, q_emb) for q_emb in query_embs]
        sas = max(sims) if sims else 0.0
        sas_scores.append(sas)

    # 5) attach scores and form DataFrame
    rows = []
    for idx, asset in enumerate(assets):
        rows.append({
            "provider": asset.get("provider"),
            "asset_url": asset.get("asset_url"),
            "title": asset.get("title"),
            "description": asset.get("description"),
            "tags": asset.get("tags"),
            "thumbnail": asset.get("thumbnail"),
            "width": asset.get("width"),
            "height": asset.get("height"),
            "duration": asset.get("duration"),
            "sas_score": sas_scores[idx],
            "raw": asset.get("raw")
        })

    df = pd.DataFrame(rows)
    df = df.sort_values("sas_score", ascending=False).reset_index(drop=True)
    # normalize SAS to 0..1 (optional) ‚Äî cosine is already in [-1,1], but embeddings should give [0,1] mostly
    # clip negatives to 0
    df["sas_score"] = df["sas_score"].clip(lower=0.0)
    # return top_k
    return df.head(top_k)

# -----------------------
# Run ranking on your assets + queries
# -----------------------
# Ensure variables exist: `assets` (list) and `expansion_result` (dict or ExpansionOutput)
if 'assets' not in globals():
    raise RuntimeError("`assets` list not found. Run Phase-2 fetch cell first to produce `assets`.")

# get expanded queries list safely
if hasattr(expansion_result, "get"):
    expanded_queries = expansion_result.get("expanded_queries", [])
elif hasattr(expansion_out, "model_dump"):
    expanded_queries = expansion_out.model_dump().get("expanded_queries", [])
elif hasattr(expansion_out, "dict"):
    expanded_queries = expansion_out.dict().get("expanded_queries", [])
else:
    # fallback: try variable from earlier
    expanded_queries = globals().get("expanded_queries", [])

if not expanded_queries:
    raise RuntimeError("No expanded queries found. Ensure Phase-1 returned expanded_queries in `expansion_result`.")

top_df = rank_assets_by_metadata_similarity(assets, expanded_queries, top_k=10)
print(f"Top {len(top_df)} assets by metadata SAS:")
display(top_df[['provider','asset_url','title','tags','sas_score']])


Top 10 assets by metadata SAS:


Unnamed: 0,provider,asset_url,title,tags,sas_score
0,pexels,https://www.pexels.com/photo/fisherman-on-fishing-boat-in-black-and-white-18176651/,"A fisherman stands on a boat in the mist, using a net in a serene black and white scene.",[],0.680699
1,unsplash,https://unsplash.com/photos/two-people-in-a-boat-on-a-misty-lake-PWdXaCvgly0,two people in a boat on a misty lake,[],0.656573
2,pexels,https://www.pexels.com/photo/man-in-black-water-proof-fishing-suit-in-the-water-9962601/,"A lone fisherman stands in calm water at dawn, showcasing peaceful solitude and hobby.",[],0.64997
3,pixabay,https://pixabay.com/photos/fisherman-boat-lake-fog-sunrise-4411420/,"fisherman, boat, lake, fog, sunrise, tree, early morning, nature, landscape, misty, fisherman, boat, boat, boat, boat, boat, sunrise","[fisherman, boat, lake, fog, sunrise, tree, early morning, nature, landscape, misty, fisherman, boat, boat, boat, boat, boat, sunrise]",0.63692
4,pixabay,https://pixabay.com/photos/fishing-boat-at-sea-fishing-8095632/,"fishing boat at sea, fishing, fisherman fishing, morning, dawn, nature, at sea, floating, fishing, fishing, fishing, morning, morning, morning, morning, morning","[fishing boat at sea, fishing, fisherman fishing, morning, dawn, nature, at sea, floating, fishing, fishing, fishing, morning, morning, morning, morning, morning]",0.628279
5,pexels,https://www.pexels.com/photo/grayscale-photo-of-person-fishing-on-seashore-9082178/,A lone fisherman casts his line in the serene grayscale setting of a tranquil shoreline at twilight.,[],0.602111
6,pexels,https://www.pexels.com/photo/silhouette-of-fisherman-at-calm-seaside-32524935/,"Monochrome photo of a fisherman in shallow waters at sunrise, creating a serene silhouette effect.",[],0.601953
7,pexels,https://www.pexels.com/photo/silhouette-of-a-person-riding-on-the-boat-during-sunset-7536552/,Silhouette of a fisherman on a boat during the golden sunrise over the ocean.,[],0.599565
8,pixabay,https://pixabay.com/photos/sunset-the-sea-tidal-flat-1669264/,"sunset, the sea, nature, tidal flat, silhouettes, high contrast","[sunset, the sea, nature, tidal flat, silhouettes, high contrast]",0.585762
9,unsplash,https://unsplash.com/photos/a-boat-on-the-water-gBaDh4y8S0A,a boat on the water,[],0.581232


In [20]:
!pip install transformers torch pillow requests pandas --quiet



[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


WMRS=(0.30√óSAS)+(0.45√óCS)+(0.25√óRS)

where:

SAS ‚Üí Semantic Alignment Score ‚Äî measures text-to-text similarity (metadata ‚Üî query)

CS ‚Üí Correctness Score ‚Äî measures image‚Äìmetadata alignment (image ‚Üî caption/tags)

RS ‚Üí Relevance/Style Score ‚Äî measures image‚Äìstyle alignment (image ‚Üî user intent/style)

In [21]:
# Phase 4-ish: Visual relevance checks with CLIP ViT-L/14 (image content vs metadata & style)
# Run this in your Jupyter notebook. It expects:
#  - `top_df` : pandas.DataFrame from the metadata-ranking step containing at least:
#       ['provider','asset_url','title','description','tags','sas_score','raw']
#  - `expansion_result` : dict or Pydantic model containing 'expanded_queries' and optional 'notes'
#
# Installs (run once in a cell if needed):
# !pip install transformers torch pillow requests pandas --quiet

import os
import math
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import pandas as pd
import torch
from transformers import CLIPProcessor, CLIPModel
from typing import List, Dict, Any

# ---------- config ----------
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"  # ViT-L/14
BATCH_SIZE = 8
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- helper funcs ----------
def _download_image(url: str, timeout: float = 10.0) -> Image.Image:
    """Download an image and return a PIL Image. Raises on failure."""
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(BytesIO(resp.content)).convert("RGB")

def _batch_iter(xs: List, n: int):
    for i in range(0, len(xs), n):
        yield xs[i:i + n]

# ---------- load CLIP (image+text towers) ----------
device = torch.device(DEVICE)
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

# ---------- embedding utilities ----------
def embed_images(images: List[Image.Image]) -> np.ndarray:
    """Return numpy array (N, D) of normalized image embeddings."""
    all_embs = []
    for batch in _batch_iter(images, BATCH_SIZE):
        inputs = processor(images=batch, return_tensors="pt").to(device)
        with torch.no_grad():
            img_outputs = model.get_image_features(**inputs)
        img_emb = img_outputs.cpu().numpy()
        # L2 normalize
        norms = np.linalg.norm(img_emb, axis=1, keepdims=True) + 1e-12
        img_emb = img_emb / norms
        all_embs.append(img_emb)
    return np.vstack(all_embs) if all_embs else np.zeros((0, model.visual_projection.out_features))

def embed_texts(texts: List[str]) -> np.ndarray:
    """Return numpy array (N, D) of normalized text embeddings using CLIP text tower."""
    all_embs = []
    for batch in _batch_iter(texts, BATCH_SIZE):
        inputs = processor(text=batch, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            txt_outputs = model.get_text_features(**inputs)
        txt_emb = txt_outputs.cpu().numpy()
        norms = np.linalg.norm(txt_emb, axis=1, keepdims=True) + 1e-12
        txt_emb = txt_emb / norms
        all_embs.append(txt_emb)
    return np.vstack(all_embs) if all_embs else np.zeros((0, model.text_projection.out_features))

def cosine_sim_matrix(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Return cosine similarity matrix shape (len(A), len(B)). Assumes rows are normalized."""
    if A.size == 0 or B.size == 0:
        return np.zeros((A.shape[0], B.shape[0]))
    return A @ B.T

# ---------- main function ----------
def check_visual_relevance(top_df: pd.DataFrame,
                           expansion_result: Dict[str, Any],
                           top_k: int = 10,
                           weight_sas: float = 0.30,
                           weight_cs: float = 0.45,
                           weight_rs: float = 0.25) -> pd.DataFrame:
    """
    Compute:
      - CS (Correctness Score): similarity between image embedding and asset metadata text (title/description/tags)
      - RS (Relevance/Style Score): similarity between image embedding and a style/query embedding (derived from expansion_result['notes'] or top expanded_queries)
      - WMRS = weighted sum of SAS (already present), CS, RS
    Returns DataFrame with top_k sorted by WMRS desc.
    """
    # ensure top_df has needed columns
    df = top_df.copy().reset_index(drop=True)
    # limit to top_k rows (caller already selected top 10 usually)
    df = df.head(top_k).copy()

    # Build metadata text per asset (same as Phase-3)
    def _metadata_text(row):
        parts = []
        if row.get("title"):
            parts.append(str(row["title"]))
        if row.get("description"):
            parts.append(str(row["description"]))
        tags = row.get("tags") or []
        if isinstance(tags, str):
            tags = [t.strip() for t in tags.split(",") if t.strip()]
        if tags:
            parts.append(" ".join(tags))
        return " ||| ".join(parts) if parts else ""

    df["metadata_text"] = df.apply(_metadata_text, axis=1)

    # Style/query text: prefer notes, else join top 3 expanded queries
    notes = ""
    if isinstance(expansion_result, dict):
        notes = expansion_result.get("notes", "") or ""
        queries = expansion_result.get("expanded_queries", []) or []
    else:
        # try to access Pydantic model
        try:
            notes = expansion_result.notes or ""
            queries = expansion_result.expanded_queries or []
        except Exception:
            notes = ""
            queries = []

    if notes and notes.strip():
        style_text = notes.strip()
    else:
        style_text = " | ".join(queries[:3]) if queries else ""

    # 1) Download images (use thumbnail if direct asset_url fails)
    images = []
    urls = []
    for _, row in df.iterrows():
        url = row.get("thumbnail") or row.get("asset_url")
        if not url:
            images.append(None)
            urls.append(None)
            continue
        try:
            img = _download_image(url)
            images.append(img)
            urls.append(url)
        except Exception:
            # fallback: try asset_url if thumbnail failed
            fallback = row.get("asset_url")
            try:
                img = _download_image(fallback) if fallback else None
                images.append(img)
                urls.append(fallback)
            except Exception:
                images.append(None)
                urls.append(None)

    # For any None images, replace with a small black image to keep shapes consistent (their sims will be low)
    for i, img in enumerate(images):
        if img is None:
            images[i] = Image.new("RGB", (224,224), color=(0,0,0))
            urls[i] = None

    # 2) Compute embeddings
    img_embs = embed_images(images)  # shape (N, D)
    metadata_texts = df["metadata_text"].fillna("").tolist()
    meta_embs = embed_texts(metadata_texts)  # shape (N, D)
    style_emb = embed_texts([style_text])  # shape (1, D)

    # 3) CS = cosine(image, metadata) per asset (use dot since normalized)
    cs_sims = np.array([float(x) for x in np.diag(cosine_sim_matrix(img_embs, meta_embs))]) if img_embs.shape[0] and meta_embs.shape[0] else np.zeros(img_embs.shape[0])

    # 4) RS = cosine(image, style)
    rs_sims = cosine_sim_matrix(img_embs, style_emb)[:, 0] if img_embs.shape[0] and style_emb.shape[0] else np.zeros(img_embs.shape[0])

    # 5) Collect final WMRS
    sas_scores = df.get("sas_score").fillna(0).astype(float).values
    # Ensure all are in [0,1] - clip to reasonable range
    cs = np.clip(cs_sims, 0.0, 1.0)
    rs = np.clip(rs_sims, 0.0, 1.0)
    sas = np.clip(sas_scores, 0.0, 1.0)

    wmrs = weight_sas * sas + weight_cs * cs + weight_rs * rs

    # attach back to DataFrame
    df["cs_score"] = cs
    df["rs_score"] = rs
    df["sas_score"] = sas
    df["wmrs"] = wmrs
    df["image_source_url"] = urls

    # sort by wmrs descending and return
    df = df.sort_values("wmrs", ascending=False).reset_index(drop=True)
    return df

# ---------- Usage ----------
# Ensure top_df and expansion_result variables exist
# Example:
# result_df = check_visual_relevance(top_df, expansion_result, top_k=10)
# display(result_df[['provider','image_source_url','asset_url','sas_score','cs_score','rs_score','wmrs']])


  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [22]:
# 1) sanity checks
assert 'check_visual_relevance' in globals(), "check_visual_relevance function not defined. Run the cell that defines it."
assert 'top_df' in globals(), "top_df not found ‚Äî run the metadata-ranking cell first."
assert 'expansion_result' in globals() or 'expansion_out' in globals(), "expansion_result (Phase-1 output) not found."

# 2) call the function and show results
try:
    # choose top_k as you want (10)
    result_df = check_visual_relevance(top_df, expansion_result if 'expansion_result' in globals() else expansion_out, top_k=10)
    print("check_visual_relevance completed. Rows returned:", len(result_df))
    # display main columns
    display(result_df[['provider','image_source_url','asset_url','sas_score','cs_score','rs_score','wmrs']].head(10))
except Exception as e:
    print("Error while running visual relevance check:", repr(e))


check_visual_relevance completed. Rows returned: 10


Unnamed: 0,provider,image_source_url,asset_url,sas_score,cs_score,rs_score,wmrs
0,pexels,https://images.pexels.com/photos/18176651/pexels-photo-18176651.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,https://www.pexels.com/photo/fisherman-on-fishing-boat-in-black-and-white-18176651/,0.680699,0.284126,0.187424,0.378922
1,pixabay,https://cdn.pixabay.com/photo/2019/08/17/04/18/fisherman-4411420_150.jpg,https://pixabay.com/photos/fisherman-boat-lake-fog-sunrise-4411420/,0.63692,0.312765,0.173988,0.375318
2,pexels,https://images.pexels.com/photos/32524935/pexels-photo-32524935.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,https://www.pexels.com/photo/silhouette-of-fisherman-at-calm-seaside-32524935/,0.601953,0.300066,0.20491,0.366843
3,unsplash,https://images.unsplash.com/photo-1711502896149-73c89f08f5d0?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=M3w4MDU2NDN8MHwxfHNlYXJjaHwzfHxtaXN0eSUyMGxha2UlMjBmaXNoZXJtYW58ZW58MHx8fHwxNzYxNDc0MTEwfD...,https://unsplash.com/photos/two-people-in-a-boat-on-a-misty-lake-PWdXaCvgly0,0.656573,0.285077,0.163804,0.366208
4,pixabay,https://cdn.pixabay.com/photo/2023/06/29/04/30/fishing-boat-at-sea-8095632_150.jpg,https://pixabay.com/photos/fishing-boat-at-sea-fishing-8095632/,0.628279,0.238427,0.186089,0.342298
5,pexels,https://images.pexels.com/photos/9082178/pexels-photo-9082178.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,https://www.pexels.com/photo/grayscale-photo-of-person-fishing-on-seashore-9082178/,0.602111,0.242763,0.202949,0.340614
6,pexels,https://images.pexels.com/photos/9962601/pexels-photo-9962601.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,https://www.pexels.com/photo/man-in-black-water-proof-fishing-suit-in-the-water-9962601/,0.64997,0.242058,0.146151,0.340455
7,pexels,https://images.pexels.com/photos/7536552/pexels-photo-7536552.jpeg?auto=compress&cs=tinysrgb&dpr=1&fit=crop&h=200&w=280,https://www.pexels.com/photo/silhouette-of-a-person-riding-on-the-boat-during-sunset-7536552/,0.599565,0.264351,0.165842,0.340288
8,unsplash,https://images.unsplash.com/photo-1657272179712-dbf132167f1d?crop=entropy&cs=tinysrgb&fit=max&fm=jpg&ixid=M3w4MDU2NDN8MHwxfHNlYXJjaHwxfHxtb29keSUyMGZpc2hlcm1hbiUyMHN1bnJpc2V8ZW58MHx8fHwxNzYxNDc0MT...,https://unsplash.com/photos/a-boat-on-the-water-gBaDh4y8S0A,0.581232,0.257548,0.163239,0.331076
9,pixabay,https://cdn.pixabay.com/photo/2016/09/14/11/32/sunset-1669264_150.jpg,https://pixabay.com/photos/sunset-the-sea-tidal-flat-1669264/,0.585762,0.210497,0.182263,0.316018


In [23]:
# Compare stored top-10 vs next top-10, compute relevance for all 20 and report dominators.
# Assumes:
# - `assets` (list of all fetched metadata assets) exists
# - `expansion_result` (dict/Pydantic) exists
# - helper functions available: rank_assets_by_metadata_similarity, check_visual_relevance
# - CLIP model already loaded by check_visual_relevance

import pandas as pd
from IPython.display import display

# 1) sanity checks
if 'assets' not in globals():
    raise RuntimeError("`assets` not found. Run Phase-2 fetch cell first.")
if 'expansion_result' not in globals() and 'expansion_out' not in globals():
    raise RuntimeError("`expansion_result` (Phase-1 output) not found. Ensure expand_and_detect(...) was run.")

# get expansion_result variable
exp_res = expansion_result if 'expansion_result' in globals() else expansion_out

# 2) build full SAS-ranked DataFrame for all assets (use existing rank function)
# Use a reasonably large top_k to get full ordering; function returns top_k rows, so request len(assets)
full_ranked_df = rank_assets_by_metadata_similarity(assets, exp_res.get("expanded_queries") if isinstance(exp_res, dict) else exp_res.expanded_queries, top_k=len(assets))

# 3) determine stored (first 10) and challenger (next 10)
stored_count = 10
challenger_count = 10

# if user already computed and saved previous visual relevance results, reuse them
if 'stored_relevance_df' in globals():
    stored_df = stored_relevance_df.reset_index(drop=True)
    # ensure stored_df has SAS column; if not, derive via asset_url matching from full_ranked_df
    if 'sas_score' not in stored_df.columns:
        stored_df = stored_df.merge(full_ranked_df[['asset_url','sas_score']], on='asset_url', how='left')
else:
    # take top 10 by SAS as stored
    stored_df = full_ranked_df.head(stored_count).copy().reset_index(drop=True)
    # save for later reuse
    stored_relevance_df = stored_df.copy()

# challenger: next N assets that are not in stored (by asset_url)
stored_urls = set(stored_df['asset_url'].astype(str).tolist())
remaining_df = full_ranked_df[~full_ranked_df['asset_url'].astype(str).isin(stored_urls)].reset_index(drop=True)
challenger_df = remaining_df.head(challenger_count).copy().reset_index(drop=True)

if challenger_df.empty:
    raise RuntimeError("No remaining assets found for challenger selection. Check `assets` and `full_ranked_df`.")

print(f"Stored (SAS top {stored_count}) count: {len(stored_df)}")
print(f"Challenger (next {challenger_count}) count: {len(challenger_df)}")

# 4) combine the two DataFrames and run visual relevance (CLIP) over the combined set
combined_df_for_clip = pd.concat([stored_df, challenger_df], ignore_index=True)

# check_visual_relevance expects a DataFrame similar to top_df; call it with top_k = len(combined)
combined_wmrs_df = check_visual_relevance(combined_df_for_clip, exp_res, top_k=len(combined_df_for_clip))

# 5) determine dominance: challenger asset dominates stored set if its wmrs > max wmrs among stored
max_stored_wmrs = combined_wmrs_df.loc[combined_wmrs_df['asset_url'].isin(stored_urls), 'wmrs'].max()
combined_wmrs_df['is_challenger'] = combined_wmrs_df['asset_url'].apply(lambda u: u not in stored_urls)
combined_wmrs_df['dominates_stored'] = combined_wmrs_df.apply(
    lambda r: bool(r['is_challenger'] and (r['wmrs'] > (max_stored_wmrs if not pd.isna(max_stored_wmrs) else -1e9))),
    axis=1
)

# 6) show results sorted by wmrs desc
display_cols = ['provider','image_source_url','asset_url','sas_score','cs_score','rs_score','wmrs','is_challenger','dominates_stored']
print(f"Max WMRS among stored set: {max_stored_wmrs:.4f}" if not pd.isna(max_stored_wmrs) else "No stored WMRS found.")
print("\nCombined top 20 (sorted by WMRS):")
display(combined_wmrs_df.sort_values('wmrs', ascending=False).reset_index(drop=True)[display_cols])

# 7) summary: which challengers dominated stored set
dominators = combined_wmrs_df[combined_wmrs_df['dominates_stored']].sort_values('wmrs', ascending=False)
print(f"\nNumber of challenger assets that dominate stored set: {len(dominators)}")
if len(dominators) > 0:
    print("Dominating challenger assets (top results):")
    display(dominators[display_cols].head(10))
else:
    print("No challenger asset had WMRS greater than the max of the stored set.")

# 8) persist combined results for later analysis
combined_results_df = combined_wmrs_df.sort_values('wmrs', ascending=False).reset_index(drop=True)
# save to variable for interactive use
last_combined_results = combined_results_df.copy()

# (Optional) save to CSV in working directory:
combined_results_df.to_csv("marp_combined_top20_relevance.csv", index=False)
print("\nCombined results saved to marp_combined_top20_relevance.csv")


Stored (SAS top 10) count: 10
Challenger (next 10) count: 10


Token indices sequence length is longer than the specified maximum sequence length for this model (101 > 77). Running this sequence through the model will result in indexing errors


ValueError: Sequence length must be less than max_position_embeddings (got `sequence length`: 101 and max_position_embeddings: 77