In [1]:
!pip install openai requests python-dotenv Pillow numpy soundfile librosa ffmpeg-python tqdm open-clip-torch torch
# System-level: ffmpeg must be installed on the machine (apt install ffmpeg / brew install ffmpeg)





[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
pip install --upgrade openai


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
"""
llm_fetch_align_pipeline.py

End-to-end script using Hugging Face ALIGN (kakaobrain/align-base):
1) parse user prompt with OpenAI LLM to determine: cleaned prompt, #assets, which types (image/video/audio)
2) fetch metadata from Pexels (images + videos) and Freesound (audio) accordingly
3) download thumbnails/previews in-memory and preprocess (images or audio -> mel images)
4) embed prompt and assets using ALIGN (AlignProcessor + AlignModel) and compute cosine similarity (relevance)
5) print the top-K results overall (images, video thumbnails and audio are ranked together by score)

Environment variables:
  OPENAI_API_KEY, PEXELS_API_KEY, FREESOUND_API_KEY

NOTE: This script does not save files to disk.
"""
import os
import io
import json
import time
import requests
import numpy as np
from PIL import Image
from typing import List, Dict, Optional, Union
from dotenv import load_dotenv
from tqdm import tqdm

# audio libs
import soundfile as sf
import librosa
import ffmpeg

# OpenAI (both old openai import and new OpenAI client usage retained from your original script)
from openai import OpenAI
import openai

# Transformers ALIGN
import torch
from transformers import AlignModel, AlignProcessor

load_dotenv()

# -------------------------
# Config & env
# -------------------------
# API credentials are read from the process environment (populated from a local
# .env file by the load_dotenv() call above, when such a file exists).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PEXELS_KEY = os.getenv("PEXELS_API_KEY")
FREESOUND_KEY = os.getenv("FREESOUND_API_KEY")

# The OpenAI key is mandatory: without it the prompt parser cannot run, so fail fast.
if not OPENAI_API_KEY:
    raise RuntimeError("Set OPENAI_API_KEY in env")
# The provider keys are optional: the matching search_* functions return [] when
# their key is missing, so only a warning is printed here.
if not PEXELS_KEY:
    print("[WARN] No PEXELS_API_KEY found; Pexels fetching will return empty.")
if not FREESOUND_KEY:
    print("[WARN] No FREESOUND_API_KEY found; Freesound fetching will return empty.")

# Legacy module-level key assignment, kept alongside the client object below.
openai.api_key = OPENAI_API_KEY
# Client used by parse_prompt_with_openai(). It is constructed without arguments;
# NOTE(review): presumably it picks the key up from the OPENAI_API_KEY env var — confirm.
client = OpenAI()  # used for LLM parsing

# -------------------------
# 1) LLM prompt parser (OpenAI)
# -------------------------
def parse_prompt_with_openai(raw_prompt: str) -> dict:
    """
    Convert a free-form user prompt into a structured search plan via the OpenAI chat API.

    Uses the module-level ``client`` (``client.chat.completions.create(...)``) in JSON mode.

    Args:
        raw_prompt: the unprocessed text typed by the user.

    Returns:
        dict with keys ``clean_prompt`` (str), ``num_images``, ``num_videos``,
        ``num_audio`` (ints clamped to 0..10, matching the rule given to the model)
        and ``notes`` (str). If the reply cannot be interpreted as a JSON object, a
        fallback plan is returned: the raw prompt, 5 images, no videos, no audio.

    Raises:
        Whatever the OpenAI client raises for transport/auth errors (not caught here).
    """
    max_per_type = 10  # upper bound promised to the model in the system prompt below

    system = (
        "You are a helpful assistant that converts an unstructured user prompt "
        "into a structured search plan. Return JSON only with fields: clean_prompt (string), "
        "num_images (int), num_videos (int), num_audio (int), notes (string). "
        "Rules:\n"
        "- If the user didn't request videos, set num_videos to 0.\n"
        "- If the user didn't request audio, set num_audio to 0.\n"
        "- Make counts sensible (1-10 per type). Be conservative if uncertain.\n"
        "- Clean the prompt for image/video/audio search (no extra commentary)."
    )
    user = f"User prompt: {raw_prompt}\n\nReturn JSON only."

    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user}
        ],
        temperature=0.0,
        max_tokens=300,
        # Ask the API for a JSON object so the reply is not wrapped in prose or fences.
        response_format={"type": "json_object"},
    )

    # The v1 client returns objects (attribute access); content may legitimately be None.
    try:
        text = resp.choices[0].message.content or ""
    except (AttributeError, IndexError, TypeError):
        text = str(resp)

    # Be tolerant of stray markdown fences / commentary: keep only the outermost {...}.
    start, end = text.find("{"), text.rfind("}")
    payload = text[start:end + 1] if 0 <= start < end else text

    def _count(value) -> int:
        # int() may raise on junk; that is caught below and triggers the fallback plan.
        return max(0, min(max_per_type, int(value or 0)))

    try:
        j = json.loads(payload)
        return {
            "clean_prompt": j.get("clean_prompt") or raw_prompt,
            "num_images": _count(j.get("num_images")),
            "num_videos": _count(j.get("num_videos")),
            "num_audio": _count(j.get("num_audio")),
            "notes": j.get("notes", "")
        }
    except (ValueError, TypeError, AttributeError):
        # Covers invalid JSON, a non-object top level (no .get) and non-numeric counts.
        print("[warn] OpenAI response not parseable as JSON; falling back. Raw text:", text[:400])
        return {
            "clean_prompt": raw_prompt,
            "num_images": 5,
            "num_videos": 0,
            "num_audio": 0,
            "notes": "fallback heuristic used"
        }

# -------------------------
# 2) Fetchers (Pexels & Freesound)
# -------------------------
def unified_record(provider, id, typ, title, description, url, thumbnail_url=None,
                   duration=None, uploader=None, published_at=None, tags=None, raw=None):
    """
    Build the provider-agnostic asset record shared by all fetchers.

    ``id`` is stringified; a falsy ``title`` becomes "", falsy ``tags`` becomes []
    and falsy ``raw`` becomes {}. Every other value is stored as given.
    ``typ`` is stored under the key "type" (image|video|audio).
    """
    return dict(
        provider=provider,
        id=str(id),
        type=typ,
        title=title if title else "",
        description=description,
        url=url,
        thumbnail_url=thumbnail_url,
        duration=duration,
        uploader=uploader,
        published_at=published_at,
        tags=tags if tags else [],
        raw_meta=raw if raw else {},
    )

def search_pexels_images(prompt: str, per_page: int = 10) -> List[Dict]:
    """
    Query the Pexels photo search endpoint and return unified image records.

    Returns [] without any network call when PEXELS_KEY is not set.
    HTTP errors propagate via ``raise_for_status()``.
    Each record uses the photo's "alt" text as title, the "original" source as url
    and the "medium" source as thumbnail.
    """
    if not PEXELS_KEY:
        return []

    response = requests.get(
        "https://api.pexels.com/v1/search",
        headers={"Authorization": PEXELS_KEY},
        params={"query": prompt, "per_page": per_page},
        timeout=15,
    )
    response.raise_for_status()

    def _to_record(photo):
        # "src" may be missing or null in the payload; treat both as an empty mapping.
        sources = photo.get("src") or {}
        return unified_record(
            provider="pexels",
            id=photo.get("id"),
            typ="image",
            title=photo.get("alt") or "",
            description=None,
            url=sources.get("original"),
            thumbnail_url=sources.get("medium"),
            duration=None,
            uploader=photo.get("photographer"),
            published_at=None,
            tags=[],
            raw=photo,
        )

    return [_to_record(photo) for photo in response.json().get("photos", [])]

def search_pexels_videos(prompt: str, per_page: int = 8) -> List[Dict]:
    """
    Query the Pexels video search endpoint and return unified video records.

    Returns [] without any network call when PEXELS_KEY is not set.
    HTTP errors propagate via ``raise_for_status()``.
    The record url is the link of the first file whose quality is "hd"; when that
    yields nothing, the first listed file's link is used instead. The title is the
    uploader's name, or the stringified video id when no name is present.
    """
    if not PEXELS_KEY:
        return []

    response = requests.get(
        "https://api.pexels.com/videos/search",
        headers={"Authorization": PEXELS_KEY},
        params={"query": prompt, "per_page": per_page},
        timeout=15,
    )
    response.raise_for_status()

    records = []
    for video in response.json().get("videos", []):
        candidates = video.get("video_files", []) or []
        # First HD rendition wins (even if its link is empty — the fallback below covers that).
        link = next((f.get("link") for f in candidates if f.get("quality") == "hd"), None)
        if not link and video.get("video_files"):
            link = video.get("video_files")[0].get("link")

        owner = (video.get("user") or {}).get("name")
        records.append(unified_record(
            provider="pexels",
            id=video.get("id"),
            typ="video",
            title=owner or str(video.get("id")),
            description=video.get("url"),
            url=link,
            thumbnail_url=video.get("image"),
            duration=video.get("duration"),
            uploader=owner,
            published_at=None,
            tags=[],
            raw=video,
        ))
    return records

def search_freesound(prompt: str, per_page: int = 10) -> List[Dict]:
    """
    Query the Freesound text-search endpoint and return unified audio records.

    Returns [] without any network call when FREESOUND_KEY is not set.
    HTTP errors propagate via ``raise_for_status()``.

    The request explicitly lists the sound fields this function reads. Without the
    ``fields`` parameter the search response only carries a small default subset,
    so ``previews`` (and with it the record url), ``description``, ``duration`` and
    ``created`` would all come back missing.

    Args:
        prompt: search text.
        per_page: number of results requested (sent as ``page_size``).

    Returns:
        List of unified records of type "audio"; ``url`` is the HQ mp3 preview when
        present, otherwise the LQ mp3 preview, otherwise None.
    """
    if not FREESOUND_KEY:
        return []
    endpoint = "https://freesound.org/apiv2/search/text/"
    headers = {"Authorization": f"Token {FREESOUND_KEY}"}
    params = {
        "query": prompt,
        "page_size": per_page,
        # Every key read from each result below must be requested here.
        "fields": "id,name,description,tags,username,created,duration,previews",
    }
    r = requests.get(endpoint, headers=headers, params=params, timeout=15)
    r.raise_for_status()
    out = []
    for it in r.json().get("results", []):
        previews = it.get("previews") or {}
        preview = previews.get("preview-hq-mp3") or previews.get("preview-lq-mp3")
        out.append(unified_record(
            provider="freesound",
            id=it.get("id"),
            typ="audio",
            title=it.get("name"),
            description=it.get("description"),
            url=preview,
            thumbnail_url=None,
            duration=it.get("duration"),
            uploader=it.get("username"),
            published_at=it.get("created"),
            tags=it.get("tags") or [],
            raw=it
        ))
    return out

def fetch_assets(clean_prompt: str, num_images:int, num_videos:int, num_audio:int) -> List[Dict]:
    """
    Fetch up to the requested number of assets of each type for ``clean_prompt``.

    Providers are queried in the order images, videos, audio; a provider is only
    called when its requested count is positive, and its result list is truncated
    to that count. Returns the concatenated records in that same order.
    """
    requests_by_type = (
        (num_images, search_pexels_images),
        (num_videos, search_pexels_videos),
        (num_audio, search_freesound),
    )
    collected = []
    for wanted, search in requests_by_type:
        if wanted > 0:
            found = search(clean_prompt, per_page=wanted)
            collected.extend(found[:wanted])
    return collected

# -------------------------
# 3) In-memory download & preprocessing to ALIGN-ready inputs
# -------------------------
def download_bytes(url: str, timeout=20) -> Optional[bytes]:
    """
    Fetch ``url`` and return the response body as bytes, entirely in memory.

    Returns None for a falsy url, and also (after printing a warning) when the
    request or the HTTP status check raises.
    """
    if url:
        try:
            response = requests.get(url, timeout=timeout, stream=True)
            response.raise_for_status()
            return response.content
        except Exception as err:
            print("[warn] download failed:", err)
    return None

def download_image_pil(url: str, timeout=15) -> Optional[Image.Image]:
    """
    Download ``url`` via download_bytes() and decode it into an RGB PIL image.

    Returns None when nothing was downloaded, or (after printing a warning) when
    PIL cannot open/convert the payload.
    """
    payload = download_bytes(url, timeout=timeout)
    if payload:
        try:
            return Image.open(io.BytesIO(payload)).convert("RGB")
        except Exception as err:
            print("[warn] PIL open failed:", err)
    return None

def center_crop_and_resize_pil(img: Image.Image, size:int=224) -> Image.Image:
    """
    Crop the largest centered square out of ``img`` and resize it to size x size
    using LANCZOS resampling.
    """
    width, height = img.size
    side = min(width, height)
    # Offsets that center the square along the longer dimension.
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    square = img.crop((x0, y0, x0 + side, y0 + side))
    return square.resize((size, size), Image.LANCZOS)

def pil_to_chw_float(img: Image.Image, size:int=224, normalize:bool=True) -> np.ndarray:
    """
    Center-crop/resize ``img`` to size x size and return it as a float32 array in
    channel-first (C, H, W) layout.

    Pixel values are scaled to [0, 1]; with ``normalize=True`` they are further
    mapped to [-1, 1] via ``x * 2 - 1``.
    """
    square = center_crop_and_resize_pil(img, size=size)
    hwc = np.asarray(square).astype(np.float32) / 255.0  # H W C
    chw = np.transpose(hwc, (2, 0, 1)).copy()  # C H W
    return chw * 2.0 - 1.0 if normalize else chw

# audio decode & mel -> 3-channel image-like tensor (C,H,W float in [-1,1])
def decode_audio_bytes_to_waveform(audio_bytes: bytes, target_sr:int=16000, duration:float=5.0) -> Optional[np.ndarray]:
    """
    Decode an in-memory audio file into a fixed-length mono float32 waveform.

    Decoding is attempted with soundfile first (multi-channel audio is averaged to
    mono and resampled to ``target_sr`` with librosa when needed). If that raises for
    any reason, the bytes are piped through ffmpeg instead, which outputs mono
    float32 PCM at ``target_sr``.

    The result always has ``int(target_sr * duration)`` samples: longer audio is
    center-cropped, shorter audio is zero-padded at the end.

    Args:
        audio_bytes: raw contents of an audio file.
        target_sr: output sample rate in Hz.
        duration: output length in seconds.

    Returns:
        1-D float32 array, or None when ``audio_bytes`` is empty or both decoders fail
        (a warning with both errors is printed in the latter case).
    """
    if not audio_bytes:
        return None

    def _fit_length(samples: np.ndarray) -> np.ndarray:
        # Center-crop or right-pad with zeros so the clip is exactly `desired` samples.
        desired = int(target_sr * duration)
        if len(samples) > desired:
            start = max(0, (len(samples) - desired) // 2)
            samples = samples[start:start + desired]
        elif len(samples) < desired:
            samples = np.concatenate([samples, np.zeros(desired - len(samples), dtype=np.float32)])
        return samples.astype(np.float32)

    try:
        bio = io.BytesIO(audio_bytes)
        data, sr = sf.read(bio, dtype='float32')
        if data.ndim > 1:
            data = np.mean(data, axis=1)
        if sr != target_sr:
            # Sample rates must be passed by keyword: recent librosa releases reject
            # them positionally, which used to force every resample into the fallback.
            data = librosa.resample(data, orig_sr=sr, target_sr=target_sr)
        return _fit_length(data)
    except Exception as e:
        # fallback: ffmpeg
        try:
            proc = (
                ffmpeg.input('pipe:0')
                .output('pipe:1', format='f32le', ar=target_sr, ac=1)
                .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
            )
            out, err = proc.communicate(input=audio_bytes)
            if proc.returncode != 0:
                raise RuntimeError("ffmpeg decode fail")
            return _fit_length(np.frombuffer(out, dtype=np.float32))
        except Exception as e2:
            print("[warn] audio decode failed:", e, e2)
            return None

def waveform_to_mel_image_tensor(wav: np.ndarray, sr:int=16000, n_mels:int=128, n_fft:int=2048, hop_length:int=512, size:int=224) -> Optional[np.ndarray]:
    """
    Render a waveform as an image-like mel-spectrogram tensor.

    The mel power spectrogram is converted to dB (relative to its maximum), min-max
    scaled to 0..255 as uint8, resized to size x size with LANCZOS, converted to RGB
    and returned as a float32 (C, H, W) array mapped to [-1, 1].

    Returns None when ``wav`` is None.
    """
    if wav is None:
        return None

    mel_power = librosa.feature.melspectrogram(
        y=wav, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    # Min-max scale to [0, 1]; the epsilon guards a constant spectrogram.
    lowest, highest = mel_db.min(), mel_db.max()
    unit = (mel_db - lowest) / (highest - lowest + 1e-9)

    gray = Image.fromarray((unit * 255.0).astype(np.uint8))  # H(n_mels) x T
    rgb = gray.resize((size, size), Image.LANCZOS).convert("RGB")

    chw = np.transpose(np.asarray(rgb).astype(np.float32) / 255.0, (2, 0, 1)).copy()
    return chw * 2.0 - 1.0

def preprocess_records(records: List[Dict], image_size:int=224, audio_duration:float=5.0, audio_sr:int=16000) -> List[Dict]:
    """
    Download and preprocess every record in memory, returning enriched copies.

    The input dicts are not modified; each output is a shallow copy with extra keys:
      - image / video: ``img_pil`` (downloaded from thumbnail_url, else url) and
        ``img_tensor`` (CHW float array, or None when the download failed)
      - audio: ``waveform`` and ``audio_mel_tensor`` (None when decoding failed)
      - any other type: ``img_tensor`` and ``audio_mel_tensor`` both None
    """
    enriched = []
    for source in tqdm(records, desc="preprocess"):
        rec = dict(source)
        kind = rec["type"]

        if kind in ("image", "video"):
            # Videos are represented by their still thumbnail, exactly like images.
            frame = download_image_pil(rec.get("thumbnail_url") or rec.get("url"))
            rec["img_pil"] = frame
            rec["img_tensor"] = None if frame is None else pil_to_chw_float(frame, size=image_size)
        elif kind == "audio":
            payload = download_bytes(rec.get("url"))
            clip = decode_audio_bytes_to_waveform(payload, target_sr=audio_sr, duration=audio_duration)
            rec["waveform"] = clip
            # The mel tensor is turned into a PIL image later, at embedding time.
            rec["audio_mel_tensor"] = (
                None if clip is None
                else waveform_to_mel_image_tensor(clip, sr=audio_sr, size=image_size)
            )
        else:
            rec["img_tensor"] = None
            rec["audio_mel_tensor"] = None

        enriched.append(rec)
    return enriched

# -------------------------
# 4) ALIGN embedding & scoring (Hugging Face AlignModel + AlignProcessor)
# -------------------------
class ALIGNEmbedder:
    """
    Embedder wrapper around Hugging Face AlignModel + AlignProcessor.
    model_name_or_path: e.g. "kakaobrain/align-base"
    Works with inputs that are:
      - numpy arrays shaped (3, H, W) float in [-1,1] or [0,1] or uint8 [0,255]
      - numpy arrays shaped (H, W, 3) uint8 or float
      - PIL.Image
    Returns L2-normalized embeddings as numpy arrays.

    Float range detection is a heuristic based on the array's min/max: arrays with
    no negative values and max <= 1 are read as [0,1]; arrays within [-1,1] that do
    contain negative values are read as [-1,1]; anything else is clipped to 0..255.
    A [-1,1] image that happens to contain no negative value is therefore read as
    [0,1] — pass a PIL image or uint8 array when that ambiguity matters.
    """
    def __init__(self, model_name_or_path: str = "kakaobrain/align-base", device: Optional[str] = None):
        """
        Load the processor and model, move the model to ``device`` and put it in eval mode.

        device defaults to "cuda" when available, else "cpu".
        Raises RuntimeError when the processor or model cannot be loaded.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name_or_path = model_name_or_path

        # load processor + model
        try:
            self.processor = AlignProcessor.from_pretrained(model_name_or_path)
            self.model = AlignModel.from_pretrained(model_name_or_path).to(self.device).eval()
        except Exception as e:
            # Chain the original error so the root cause stays visible in the traceback.
            raise RuntimeError(f"Failed to load ALIGN model/processor '{model_name_or_path}': {e}") from e

        # Embedding width used for empty results; 640 is the fallback when the
        # config does not expose projection_dim.
        try:
            self.embedding_dim = self.model.config.projection_dim if hasattr(self.model.config, "projection_dim") else 640
        except Exception:
            self.embedding_dim = 640

        # dtype of the model weights; pixel tensors are cast to it before the forward pass.
        try:
            self.dtype = next(self.model.parameters()).dtype
        except Exception:
            self.dtype = torch.float32

    def _convert_to_processor_image(self, arr_or_pil: Union[np.ndarray, "Image.Image"]):
        """
        Normalize one input into a PIL image for the processor.

        PIL images and non-ndarray objects are returned unchanged. Arrays whose first
        axis has length 3 are treated as (C, H, W) and transposed to (H, W, C); float
        arrays are rescaled to 0..255 according to their detected value range (see the
        class docstring); the result is cast to uint8.
        """
        if isinstance(arr_or_pil, Image.Image):
            return arr_or_pil
        if not isinstance(arr_or_pil, np.ndarray):
            return arr_or_pil

        arr = arr_or_pil
        # If CHW convert to HWC
        if arr.ndim == 3 and arr.shape[0] == 3:  # C, H, W
            arr = np.transpose(arr, (1, 2, 0))

        if np.issubdtype(arr.dtype, np.floating):
            eps = 1e-6
            lo, hi = float(arr.min()), float(arr.max())
            # Test the narrower [0,1] range FIRST: it is a subset of [-1,1], so checking
            # [-1,1] first would make this branch unreachable and squash [0,1] inputs
            # into the upper half of the 0..255 range.
            if lo >= -eps and hi <= 1.0 + eps:
                arr = arr * 255.0
            elif lo >= -1.0 - eps and hi <= 1.0 + eps:
                arr = (arr + 1.0) / 2.0 * 255.0
            # Otherwise the values are assumed to already be on a 0..255 scale.
            arr = arr.clip(0, 255).astype(np.uint8)
        else:
            arr = arr.astype(np.uint8)
        return Image.fromarray(arr)

    @torch.no_grad()
    def embed_images(self, imgs: List[Union[np.ndarray, "Image.Image"]], batch_size: int = 16) -> np.ndarray:
        """
        Embed images in batches of ``batch_size``.

        Returns a float32 array of shape (len(imgs), D) with L2-normalized rows, or a
        (0, embedding_dim) array for empty input.
        """
        if len(imgs) == 0:
            return np.zeros((0, self.embedding_dim), dtype=np.float32)

        embs = []
        for i in range(0, len(imgs), batch_size):
            batch = imgs[i:i+batch_size]
            proc_imgs = [self._convert_to_processor_image(x) for x in batch]
            proc_out = self.processor(images=proc_imgs, return_tensors="pt")
            pixel_values = proc_out["pixel_values"].to(self.device).type(self.dtype)
            img_feats = self.model.get_image_features(pixel_values=pixel_values)
            # The epsilon keeps an all-zero feature vector from dividing by zero.
            img_feats = img_feats / (img_feats.norm(dim=-1, keepdim=True) + 1e-12)
            embs.append(img_feats.cpu().numpy())
        return np.vstack(embs).astype(np.float32)

    @torch.no_grad()
    def embed_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """
        Embed strings in batches of ``batch_size`` (padded and truncated by the processor).

        Returns a float32 array of shape (len(texts), D) with L2-normalized rows, or a
        (0, embedding_dim) array for empty input.
        """
        if len(texts) == 0:
            return np.zeros((0, self.embedding_dim), dtype=np.float32)

        embs = []
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i+batch_size]
            proc_out = self.processor(text=chunk, return_tensors="pt", padding=True, truncation=True)
            input_ids = proc_out["input_ids"].to(self.device)
            attention_mask = proc_out.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)
                text_feats = self.model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
            else:
                text_feats = self.model.get_text_features(input_ids=input_ids)
            text_feats = text_feats / (text_feats.norm(dim=-1, keepdim=True) + 1e-12)
            embs.append(text_feats.cpu().numpy())
        return np.vstack(embs).astype(np.float32)

    @staticmethod
    def cosine_sim_matrix(img_embs: np.ndarray, txt_embs: np.ndarray) -> np.ndarray:
        """
        Cosine similarity between every row of ``img_embs`` and every row of ``txt_embs``.

        Rows are re-normalized here, so the inputs need not be unit length. Returns an
        array of shape (n_img, n_txt); all zeros of that shape when either input is empty.
        """
        if img_embs.size == 0 or txt_embs.size == 0:
            return np.zeros((img_embs.shape[0], txt_embs.shape[0]), dtype=np.float32)
        a = img_embs / (np.linalg.norm(img_embs, axis=1, keepdims=True) + 1e-12)
        b = txt_embs / (np.linalg.norm(txt_embs, axis=1, keepdims=True) + 1e-12)
        return a.dot(b.T)

# -------------------------
# 5) Pipeline orchestration (uses ALIGNEmbedder)
# -------------------------
def pipeline_run(user_prompt: str, top_k:int=5, align_model_id: str = "kakaobrain/align-base"):
    """
    Run the whole flow for one prompt: plan -> fetch -> preprocess -> embed -> rank.

    Images and video thumbnails are embedded from their PIL image; audio is embedded
    from its mel-spectrogram rendered as an image. Every embedded asset is scored by
    cosine similarity against the embedding of the cleaned prompt.

    Prints the plan, progress messages and the ``top_k`` best results across all
    asset types. Returns the complete list of result dicts (provider, id, type,
    title, score, url, thumbnail) sorted by descending score.
    """
    # 1) structured plan from the LLM
    plan = parse_prompt_with_openai(user_prompt)
    print("LLM plan:", plan)

    # 2) provider metadata for the requested asset counts
    assets = fetch_assets(plan["clean_prompt"], plan["num_images"], plan["num_videos"], plan["num_audio"])
    print(f"Fetched total {len(assets)} assets from providers")

    # 3) in-memory download + preprocessing
    processed = preprocess_records(assets, image_size=224, audio_duration=5.0, audio_sr=16000)

    # 4) embedder and the single query embedding, shape (D,)
    align = ALIGNEmbedder(model_name_or_path=align_model_id)
    txt_emb = align.embed_texts([plan["clean_prompt"]])[0]

    def _mel_to_pil(mel):
        # CHW float mel tensor -> HWC uint8 PIL image for the processor.
        hwc = np.transpose(mel, (1, 2, 0)) if (mel.ndim == 3 and mel.shape[0] == 3) else mel
        if hwc.max() <= 1.0 + 1e-6 and hwc.min() >= -1.0 - 1e-6:
            # Values within [-1, 1] are mapped onto 0..255. (This test also covers
            # [0, 1] inputs, so no separate branch for them can ever be reached.)
            scaled = (hwc + 1.0) / 2.0 * 255.0
        else:
            scaled = hwc
        return Image.fromarray(scaled.clip(0, 255).astype(np.uint8))

    def _scored_rows(embs, owners):
        # Cosine similarity of each embedding row to the prompt, paired with its record.
        sims = (embs @ txt_emb) / (np.linalg.norm(embs, axis=1) * (np.linalg.norm(txt_emb)+1e-12))
        rows = []
        for pos, sim in enumerate(sims):
            rec = processed[owners[pos]]
            rows.append({
                "provider": rec["provider"],
                "id": rec["id"],
                "type": rec["type"],
                "title": rec["title"],
                "score": float(sim),
                "url": rec.get("url"),
                "thumbnail": rec.get("thumbnail_url") or None
            })
        return rows

    # image-like inputs: photos + video thumbnails that actually downloaded
    img_indices = [
        i for i, rec in enumerate(processed)
        if rec["type"] in ("image", "video") and rec.get("img_pil") is not None
    ]
    img_inputs = [processed[i]["img_pil"] for i in img_indices]
    if img_inputs:
        print(f"Embedding {len(img_inputs)} images with ALIGN...")
        img_embs = align.embed_images(img_inputs, batch_size=16)
    else:
        img_embs = np.zeros((0, align.embedding_dim), dtype=np.float32)

    # audio inputs: mel tensors rendered as images
    audio_indices = [
        i for i, rec in enumerate(processed)
        if rec["type"] == "audio" and rec.get("audio_mel_tensor") is not None
    ]
    audio_inputs = [_mel_to_pil(processed[i]["audio_mel_tensor"]) for i in audio_indices]
    if audio_inputs:
        print(f"Embedding {len(audio_inputs)} audio mel-images with ALIGN...")
        audio_embs = align.embed_images(audio_inputs, batch_size=8)
    else:
        audio_embs = np.zeros((0, align.embedding_dim), dtype=np.float32)

    # 5) score (images first, then audio) and rank everything together
    results = []
    if img_inputs:
        results.extend(_scored_rows(img_embs, img_indices))
    if audio_inputs:
        results.extend(_scored_rows(audio_embs, audio_indices))

    results_sorted = sorted(results, key=lambda row: -row["score"])
    print("\nTop results:")
    for rank, row in enumerate(results_sorted[:top_k], start=1):
        print(f"{rank}. [{row['type']}] {row['provider']} id={row['id']} score={row['score']:.4f}")
        print(f"    title: {row['title']}")
        print(f"    url  : {row['url']}")
        print()
    return results_sorted

# -------------------------
# CLI
# -------------------------
if __name__ == "__main__":
    # Read a prompt interactively; an empty answer falls back to a built-in demo prompt.
    raw = input("Enter user prompt: ").strip() or "a happy person running on a beach at sunset with soft warm lighting"
    # Keep the ranked results in `out` so they can be inspected after the run.
    out = pipeline_run(raw, top_k=10, align_model_id="kakaobrain/align-base")


LLM plan: {'clean_prompt': 'a girl talking to a man', 'num_images': 5, 'num_videos': 0, 'num_audio': 0, 'notes': ''}
Fetched total 5 assets from providers


preprocess: 100%|██████████| 5/5 [00:01<00:00,  4.70it/s]


Embedding 5 images with ALIGN...

Top results:
1. [image] pexels id=5710922 score=0.1274
    title: A group therapy session with six adults seated in a circle, discussing support and mental health.
    url  : https://images.pexels.com/photos/5710922/pexels-photo-5710922.jpeg

2. [image] pexels id=5711017 score=0.1199
    title: A diverse group of people sitting in a circle during a therapy session in a sports hall.
    url  : https://images.pexels.com/photos/5711017/pexels-photo-5711017.jpeg

3. [image] pexels id=6668312 score=0.1123
    title: Father and daughter share a bonding moment reading a storybook indoors.
    url  : https://images.pexels.com/photos/6668312/pexels-photo-6668312.jpeg

4. [image] pexels id=5710988 score=0.1099
    title: A group therapy session indoors with diverse adults in a supportive environment.
    url  : https://images.pexels.com/photos/5710988/pexels-photo-5710988.jpeg

5. [image] pexels id=5711382 score=0.0868
    title: Black and white photo of a suppo