# EraEx: Full Production Deployment Pipeline

**Use this notebook to initialize the system from scratch.**

> **WARNING**: This process is computationally expensive. It generates embeddings for 780,000+ songs. Run on a machine with a GPU.

---


## 1. Setup & Dependencies
Ensure `ffmpeg` is installed and available on your system `PATH`.


In [None]:
%pip install librosa torch transformers sentence-transformers faiss-cpu python-dotenv tqdm accelerate vllm

import os
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import json
from dotenv import load_dotenv


# Probe for the optional vLLM backend up front so a missing or broken install
# is reported immediately. Any import-time error is printed, not raised.
try:
    import vllm  # noqa: F401
    print("vLLM available: True")
except Exception as exc:
    print(f"vLLM available: False ({exc})")
    print("Install/upgrade with: %pip install -U vllm, then restart runtime.")


def _resolve_project_root():
    """Locate the repository root: a directory containing both ``src/`` and ``config/``.

    Candidates are checked in priority order:
    ``$ERAEX_PROJECT_ROOT`` (when set and non-blank), the current working
    directory, its parent, then a few known Colab / Google Drive checkout
    locations. A candidate whose string form repeats an earlier one is checked
    only once.

    Returns:
        Path: the first candidate that contains both ``src`` and ``config``.

    Raises:
        RuntimeError: if no candidate qualifies.
    """
    override = os.getenv("ERAEX_PROJECT_ROOT", "").strip()
    here = Path.cwd().resolve()

    ordered = [Path(override).expanduser()] if override else []
    ordered += [
        here,
        here.parent,
        Path("/content/Team4_CPSC-5830-01-Capstone-Project"),
        Path("/content/drive/MyDrive/Team4_CPSC-5830-01-Capstone-Project"),
        Path("/content/drive/MyDrive/EraEx"),
    ]

    # Keep the first Path seen for each string form, preserving priority order.
    unique = {}
    for candidate in ordered:
        unique.setdefault(str(candidate), candidate)

    for candidate in unique.values():
        if all((candidate / marker).exists() for marker in ("src", "config")):
            return candidate

    raise RuntimeError(
        "Could not locate project root with src/ and config/. "
        "Set ERAEX_PROJECT_ROOT, e.g. %env ERAEX_PROJECT_ROOT=/content/Team4_CPSC-5830-01-Capstone-Project"
    )


# Resolve the repository root, prepend it to sys.path so `src.*` imports work,
# and load <root>/.env without overriding variables that are already set.
PROJECT_ROOT = _resolve_project_root()
sys.path.insert(0, str(PROJECT_ROOT))
load_dotenv(PROJECT_ROOT / ".env", override=False)
# Mirror whichever Hugging Face token variable is set into both names.
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
print(f"PROJECT_ROOT={PROJECT_ROOT}")

from src.core.text_embeddings import embedding_handler


## 2. Load Dataset


In [None]:
import os
import shutil
import subprocess

# --- Colab/H100 friendly I/O setup ---
# Put source CSV on fast local disk (/content) first, then sync outputs back.
# ERAEX_SOURCE_CSV overrides the source CSV path; ERAEX_LOCAL_RUN_DIR overrides
# the working directory (default: /content/eraex_run on Colab, else ../data).
SOURCE_DATA_PATH = Path(os.getenv("ERAEX_SOURCE_CSV", "../data/EraEx_Dataset_Rich.csv")).expanduser()
DEFAULT_LOCAL_RUN_DIR = "/content/eraex_run" if Path("/content").exists() else "../data"
LOCAL_RUN_DIR = Path(os.getenv("ERAEX_LOCAL_RUN_DIR", DEFAULT_LOCAL_RUN_DIR)).expanduser()
LOCAL_RUN_DIR.mkdir(parents=True, exist_ok=True)

LOCAL_DATA_PATH = LOCAL_RUN_DIR / "EraEx_Dataset_Rich.csv"
# NOTE(review): this is the same file as LOCAL_DATA_PATH. Saving the rich CSV
# (section 3) therefore overwrites the local working copy of the source, and the
# source itself when it already lives in LOCAL_RUN_DIR, regardless of
# OVERWRITE_SOURCE_CSV. Confirm this in-place behavior is intended.
RICH_CSV_PATH_DEFAULT = LOCAL_RUN_DIR / "EraEx_Dataset_Rich.csv"
INDEX_DIR = LOCAL_RUN_DIR / "indexes"
INDEX_DIR.mkdir(parents=True, exist_ok=True)

# Optional sync controls (for Drive or rclone remote).
# ERAEX_SYNC_ENABLED must be an integer string such as "0" or "1"; int() rejects "true".
SYNC_ENABLED = bool(int(os.getenv("ERAEX_SYNC_ENABLED", "0")))
SYNC_MODE = os.getenv("ERAEX_SYNC_MODE", "copy").strip().lower()  # copy | rclone
SYNC_TARGET_DIR = os.getenv("ERAEX_SYNC_TARGET_DIR", "").strip()   # e.g. /content/drive/MyDrive/eraex
RCLONE_REMOTE = os.getenv("ERAEX_RCLONE_REMOTE", "").strip()       # e.g. gdrive:eraex


def sync_path(local_path: Path, relative_name: str):
    """Best-effort mirror of a local file or directory to the configured sync target.

    Behavior is driven by the module-level settings ``SYNC_ENABLED``,
    ``SYNC_MODE`` (``"copy"`` or ``"rclone"``), ``SYNC_TARGET_DIR`` and
    ``RCLONE_REMOTE``. Nothing happens when syncing is disabled or
    ``local_path`` does not exist. Missing target settings and an unknown mode
    are reported with a ``[SYNC]`` message. rclone failures are printed rather
    than raised.

    Args:
        local_path: File or directory to push.
        relative_name: Destination path relative to the target dir / remote.
    """
    source = Path(local_path)
    if not (SYNC_ENABLED and source.exists()):
        return

    if SYNC_MODE == "copy":
        if not SYNC_TARGET_DIR:
            print("[SYNC] skipped: ERAEX_SYNC_TARGET_DIR not set")
            return
        destination = Path(SYNC_TARGET_DIR) / relative_name
        destination.parent.mkdir(parents=True, exist_ok=True)
        if not source.is_dir():
            shutil.copy2(source, destination)
        else:
            # Replace the whole destination tree so it ends up mirroring the local one.
            if destination.exists():
                shutil.rmtree(destination, ignore_errors=True)
            shutil.copytree(source, destination)
        print(f"[SYNC] copied -> {destination}")
        return

    if SYNC_MODE == "rclone":
        if not RCLONE_REMOTE:
            print("[SYNC] skipped: ERAEX_RCLONE_REMOTE not set")
            return
        remote_path = f"{RCLONE_REMOTE.rstrip('/')}/{relative_name}"
        # `sync` mirrors a directory; `copyto` copies a single file to an exact path.
        verb = "sync" if source.is_dir() else "copyto"
        command = ["rclone", verb, str(source), remote_path, "--progress"]
        try:
            subprocess.run(command, check=True)
            print(f"[SYNC] rclone -> {remote_path}")
        except Exception as exc:
            print(f"[SYNC] rclone failed: {exc}")
        return

    print(f"[SYNC] skipped: unknown SYNC_MODE={SYNC_MODE}")


# Fail fast when the source CSV is missing.
if not SOURCE_DATA_PATH.exists():
    raise FileNotFoundError(f"Source CSV not found: {SOURCE_DATA_PATH}")

# Refresh the local working copy only when it is absent or older (by mtime)
# than the source. Skipped entirely when both paths resolve to the same file.
if SOURCE_DATA_PATH.resolve() != LOCAL_DATA_PATH.resolve():
    needs_copy = (not LOCAL_DATA_PATH.exists()) or (SOURCE_DATA_PATH.stat().st_mtime > LOCAL_DATA_PATH.stat().st_mtime)
    if needs_copy:
        print(f"Copying source CSV to local fast storage: {SOURCE_DATA_PATH} -> {LOCAL_DATA_PATH}")
        shutil.copy2(SOURCE_DATA_PATH, LOCAL_DATA_PATH)

DATA_PATH = str(LOCAL_DATA_PATH)
df = pd.read_csv(DATA_PATH, low_memory=False)
print(f"Loaded {len(df)} tracks from {DATA_PATH}")
print(f"LOCAL_RUN_DIR={LOCAL_RUN_DIR}")
print(f"INDEX_DIR={INDEX_DIR}")

# --- Column mapping audit (ensures all downstream fields are matched) ---
# Each canonical field maps to the first of its aliases present in df.columns,
# or None when no alias is present.
COLUMN_ALIASES = {
    "track_id": ["Track ID", "deezer_id", "video_id"],
    "title": ["Track Title", "title"],
    "artist": ["Artist Name", "artist_name"],
    "album": ["Album Title", "album_title"],
    "release_year": ["Release Year", "year"],
    "tags": ["Deezer Tags", "deezer_tags"],
    "playcount": ["Play Count", "deezer_playcount"],
    "deezer_rank": ["Deezer Rank", "deezer_rank", "rank"],
    "views": ["Views", "views"],
    "cover": ["Cover URL", "cover_url"],
    "description": ["Description", "description", "Track Description", "YouTube Description"],
    "instrumental": ["Instrumental", "instrumental", "is_instrumental"],
    "instrumental_confidence": ["Instrumental Confidence", "instrumental_confidence"],
}

COLUMN_MAP = {}
for canonical, aliases in COLUMN_ALIASES.items():
    matched = next((col for col in aliases if col in df.columns), None)
    COLUMN_MAP[canonical] = matched

print("\nColumn mapping summary:")
for canonical, col in COLUMN_MAP.items():
    print(f"- {canonical}: {col}")

# Raise if any field listed in `critical` has no matching column.
critical = ["track_id", "title", "artist", "tags", "deezer_rank"]
missing_critical = [name for name in critical if COLUMN_MAP.get(name) is None]
if missing_critical:
    raise RuntimeError(f"Missing critical columns: {missing_critical}")

# Notebook display of the first rows.
df.head(2)


## 2.5 One-Row LLM Description Preview (Qwen 2.5 3B)
Run this to preview one generated song description from your CSV before full enrichment.


In [None]:
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preview settings: model, which row to describe (positional, via df.iloc) and
# the sampling parameters passed to generate() below.
MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
PREVIEW_ROW_INDEX = 0
MAX_NEW_TOKENS = 120
TEMPERATURE = 0.30
TOP_P = 0.90


def _safe(value, default=""):
    """Return ``value`` as a stripped string, or ``default`` when it is missing.

    ``None``, empty or whitespace-only text, and the literal ``"nan"`` in any
    case (what ``str(float("nan"))`` produces) all count as missing.
    """
    if value is not None:
        text = str(value).strip()
        if text and text.lower() != "nan":
            return text
    return default


def _tags_text(row_dict):
    """Render a row's Deezer tags as one comma-separated string for the prompt.

    Reads ``"Deezer Tags"``, falling back to ``"deezer_tags"``. A list value is
    joined directly. A string is parsed in this order:

    1. as a JSON list;
    2. as a Python-literal list, so ``['rock', "children's music"]`` keeps its
       apostrophe tag;
    3. with the legacy single-to-double quote swap.

    If no parser yields a list, the cleaned raw string is returned. Missing or
    NaN values give ``""``.
    """
    import ast

    raw = row_dict.get("Deezer Tags", row_dict.get("deezer_tags", ""))
    if isinstance(raw, list):
        return ", ".join(str(v).strip() for v in raw if str(v).strip())
    raw = _safe(raw, "")
    if not raw:
        return ""
    # Blindly swapping ' for " breaks any tag containing an apostrophe, so try
    # real parsers first and keep the swap only as a last resort.
    parsers = (json.loads, ast.literal_eval, lambda s: json.loads(s.replace("'", '"')))
    for parse in parsers:
        try:
            parsed = parse(raw)
        except Exception:  # this parser rejected the text; try the next one
            continue
        if isinstance(parsed, list):
            return ", ".join(str(v).strip() for v in parsed if str(v).strip())
    return raw


def _extract_first_json(text):
    """Parse the first JSON value embedded in a model response.

    The whole stripped text is tried first. Failing that, each ``{`` is tried in
    turn as the start of a JSON object using ``json.JSONDecoder.raw_decode``.
    That parse stops at the object's matching closing brace, so trailing prose
    or a second object after it does not break parsing. A greedy ``{.*}`` regex
    would span both and fail.

    Returns:
        The decoded value, or ``None`` when nothing parses.
    """
    text = str(text or "").strip()
    if not text:
        return None
    try:
        return json.loads(text)
    except (ValueError, RecursionError):
        pass

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text[start:])
        except (ValueError, RecursionError):
            start = text.find("{", start + 1)
        else:
            return obj
    return None


def _strip_entities(text, entities):
    """Remove whole-word mentions of ``entities`` (title, artist, ...) from ``text``.

    Matching is case-insensitive and anchored, so an entity is removed only
    when it is not glued to other letters or digits. For example, artist
    ``"Rain"`` is removed from ``"by Rain."`` but ``"raining"`` is left intact.
    Entities shorter than 4 characters (after stripping) are ignored. The
    whitespace and leading punctuation left behind are tidied afterwards.
    """
    cleaned = str(text or "")
    for ent in entities:
        ent = str(ent or "").strip()
        if len(ent) < 4:
            continue
        # (?<!\w)/(?!\w) rather than \b, so entities that begin or end with
        # punctuation (e.g. "Sunn O)))") still match at their edges.
        pattern = rf"(?<!\w){re.escape(ent)}(?!\w)"
        cleaned = re.sub(pattern, "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s+", " ", cleaned)
    cleaned = re.sub(r"\s+([,.;:!?])", r"\1", cleaned)
    cleaned = re.sub(r"^[,;:.\- ]+", "", cleaned)
    return cleaned.strip()


# Build the preview prompt from one CSV row (positional index into df).
row = df.iloc[int(PREVIEW_ROW_INDEX)].to_dict()
title = _safe(row.get("Track Title", row.get("title", "Unknown")), "Unknown")
artist = _safe(row.get("Artist Name", row.get("artist_name", "Unknown")), "Unknown")
album = _safe(row.get("Album Title", row.get("album_title", "")), "")
year = _safe(row.get("Release Year", row.get("year", "")), "")
tags = _tags_text(row)
rank = _safe(row.get("Deezer Rank", row.get("deezer_rank", "")), "")

metadata_block = (
    f"Title: {title}\n"
    f"Artist: {artist}\n"
    f"Album: {album}\n"
    f"Release Year: {year}\n"
    f"Tags: {tags}\n"
    f"Popularity Rank: {rank}\n"
)

# NOTE: the enrichment cell builds a near-identical prompt in _build_llm_messages.
# That version adds "and avoid hallucinations" and caps tags at 12. Keep the two in sync.
user_prompt = (
    "Return strict JSON only with keys: description, instrumental, confidence. "
    "description must be 30-55 words, rich in sonic/mood/style cues, and must NOT mention title, artist, album, year, rank, or use quotes. "
    "Use only provided metadata.\n\n"
    + metadata_block
)

messages = [
    {
        "role": "system",
        "content": "You generate factual music metadata JSON for recommendation systems.",
    },
    {"role": "user", "content": user_prompt},
]

hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
# Half precision on GPU, full precision on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Loading model: {MODEL_ID}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
print(f"HF token detected: {bool(hf_token)}")

# Pass `token` only when one is available.
tokenizer_kwargs = {"token": hf_token} if hf_token else {}
model_kwargs = {"dtype": dtype, "device_map": "auto"}
if hf_token:
    model_kwargs["token"] = hf_token

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **tokenizer_kwargs)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Render the chat template to text and append the assistant-turn header.
chat_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(chat_text, return_tensors="pt")
# Move the inputs to the device that holds the model's first parameter.
target_device = next(model.parameters()).device
inputs = {k: v.to(target_device) for k, v in inputs.items()}

with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=int(MAX_NEW_TOKENS),
        do_sample=True,
        temperature=float(TEMPERATURE),
        top_p=float(TOP_P),
        repetition_penalty=1.08,
        pad_token_id=tokenizer.eos_token_id,
    )

# Keep only the continuation: drop the prompt's input_ids from the front.
new_tokens = output[0, inputs["input_ids"].shape[1]:]
raw_output = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# NOTE(review): a non-dict JSON value (e.g. a list) would make parsed.get() fail below.
parsed = _extract_first_json(raw_output) or {}
# Remove any title/artist/album/year the model mentioned despite the instructions.
description_preview = _strip_entities(parsed.get("description", ""), [title, artist, album])
if year:
    description_preview = re.sub(rf"\b{re.escape(str(year))}\b", "", description_preview)
    description_preview = re.sub(r"\s+", " ", description_preview).strip()

print("\n=== Preview Input ===")
print(metadata_block)
print("=== Raw Model Output ===")
print(raw_output)
print("=== Cleaned Description Preview ===")
print(description_preview)
print("=== Instrumental Preview ===")
print({
    "instrumental": parsed.get("instrumental", None),
    "confidence": parsed.get("confidence", parsed.get("instrumental_confidence", None)),
})

# Release preview model memory before enrichment cell.
del model
del tokenizer
try:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
except Exception:
    pass


## 3. Enrich CSV (Description + Instrumental)
Pure Qwen mode: all description and instrumental fields are generated by Qwen2.5-3B (no local heuristic enrichment).


In [None]:
from tqdm.auto import tqdm
import os
import re
import time
import json
import gc
import hashlib
import subprocess
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.core.media_metadata import as_optional_bool, clean_description

RICH_CSV_PATH = str(RICH_CSV_PATH_DEFAULT)
# When True, the enriched dataframe is also written over DATA_PATH at the end of this cell.
OVERWRITE_SOURCE_CSV = False

# Pure Qwen enrichment with fast single-GPU execution options.
ENABLE_LLM_ENRICHMENT = True
LLM_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
LLM_BACKEND = os.getenv("ERAEX_LLM_BACKEND", "auto").strip().lower()  # auto | vllm | transformers
LLM_MAX_ROWS = int(os.getenv("ERAEX_LLM_MAX_ROWS", "0"))               # 0 = all rows in shard
LLM_MAX_NEW_TOKENS = int(os.getenv("ERAEX_LLM_MAX_NEW_TOKENS", "120"))
LLM_TEMPERATURE = float(os.getenv("ERAEX_LLM_TEMPERATURE", "0.30"))
LLM_TOP_P = float(os.getenv("ERAEX_LLM_TOP_P", "0.90"))
LLM_REPETITION_PENALTY = float(os.getenv("ERAEX_LLM_REPETITION_PENALTY", "1.08"))

# Sharding: this run processes the row positions where pos % NUM_SHARDS == SHARD_INDEX.
NUM_SHARDS = max(1, int(os.getenv("ERAEX_NUM_SHARDS", "1")))
SHARD_INDEX = int(os.getenv("ERAEX_SHARD_INDEX", "0"))
if SHARD_INDEX < 0 or SHARD_INDEX >= NUM_SHARDS:
    raise ValueError(f"Invalid shard settings: SHARD_INDEX={SHARD_INDEX}, NUM_SHARDS={NUM_SHARDS}")


def _run_nvidia_smi(query):
    """Return the first line of ``nvidia-smi --query-gpu=<query>`` output.

    Output is requested as CSV without header or units. Any failure (binary
    missing, no driver, non-zero exit) yields ``""``.
    """
    command = [
        "nvidia-smi",
        f"--query-gpu={query}",
        "--format=csv,noheader,nounits",
    ]
    try:
        lines = subprocess.check_output(command, text=True).strip().splitlines()
    except Exception:
        return ""
    return lines[0].strip() if lines else ""


def _gpu_vram_gb():
    """Total memory of the first GPU, converted from MiB to GiB.

    Returns ``0.0`` when nvidia-smi gives nothing numeric.
    """
    reported_mib = _run_nvidia_smi("memory.total")
    try:
        mib = float(reported_mib)
    except Exception:
        return 0.0
    return mib / 1024.0


def _gpu_name():
    """Name of the first GPU reported by nvidia-smi, or ``"unknown-gpu"``."""
    reported = _run_nvidia_smi("name")
    return reported if reported else "unknown-gpu"


# Batch and concurrency defaults are tiered by the first GPU's total memory.
# It is 0.0 when nvidia-smi is unavailable, which selects the smallest tier.
# Each default can be overridden by the matching ERAEX_* variable below.
GPU_VRAM_GB = _gpu_vram_gb()
DEFAULT_TF_BATCH_SIZE = (
    24 if GPU_VRAM_GB >= 90 else
    16 if GPU_VRAM_GB >= 60 else
    12 if GPU_VRAM_GB >= 40 else
    8 if GPU_VRAM_GB >= 24 else
    4 if GPU_VRAM_GB >= 16 else 2
)
DEFAULT_VLLM_MAX_NUM_SEQS = (
    256 if GPU_VRAM_GB >= 90 else
    192 if GPU_VRAM_GB >= 60 else
    128 if GPU_VRAM_GB >= 40 else
    96 if GPU_VRAM_GB >= 24 else 64
)
DEFAULT_VLLM_PROMPT_BATCH_SIZE = (
    1024 if GPU_VRAM_GB >= 90 else
    768 if GPU_VRAM_GB >= 60 else
    512 if GPU_VRAM_GB >= 40 else 256
)

TRANSFORMERS_BATCH_SIZE = int(os.getenv("ERAEX_TF_BATCH_SIZE", str(DEFAULT_TF_BATCH_SIZE)))
VLLM_MAX_NUM_SEQS = int(os.getenv("ERAEX_VLLM_MAX_NUM_SEQS", str(DEFAULT_VLLM_MAX_NUM_SEQS)))
VLLM_PROMPT_BATCH_SIZE = int(os.getenv("ERAEX_VLLM_PROMPT_BATCH_SIZE", str(DEFAULT_VLLM_PROMPT_BATCH_SIZE)))
VLLM_GPU_MEMORY_UTIL = float(os.getenv("ERAEX_VLLM_GPU_MEM_UTIL", "0.92"))
VLLM_MAX_MODEL_LEN = int(os.getenv("ERAEX_VLLM_MAX_MODEL_LEN", "8192"))

# JSONL cache of generated payloads, keyed by prompt hash. New entries are
# buffered and appended once at least CACHE_FLUSH_EVERY accumulate (checked
# after each batch); the remainder is appended at the end of the run.
CACHE_PATH = Path(INDEX_DIR) / "qwen_cache.jsonl"
CACHE_FLUSH_EVERY = int(os.getenv("ERAEX_CACHE_FLUSH_EVERY", "500"))


def _get_hf_token():
    """Return the Hugging Face token from ``HF_TOKEN`` or ``HUGGINGFACE_HUB_TOKEN``.

    ``HF_TOKEN`` takes precedence. Empty values count as unset. Returns
    ``None`` when neither variable holds a token.
    """
    for variable in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN"):
        value = os.getenv(variable)
        if value:
            return value
    return None


def _release_gpu_memory():
    """Drop model/tokenizer globals left by earlier cells and free cached CUDA memory.

    Deletes whichever of the known names exist in this module's globals, runs
    the garbage collector, then empties the CUDA caching allocator when CUDA is
    available. Errors from deletion and from the CUDA calls are swallowed.
    """
    namespace = globals()
    stale_names = ("model", "tokenizer", "_tf_model", "_tf_tokenizer", "_llm_model", "_llm_tokenizer")
    for name in stale_names:
        if name not in namespace:
            continue
        try:
            del namespace[name]
        except Exception:
            pass
    gc.collect()
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass


def _resolve_backend(preferred):
    """Map the requested LLM backend to the one that will actually be used.

    ``"transformers"`` is returned unchanged. ``"vllm"`` and ``"auto"`` both
    probe for an importable ``vllm`` package and return ``"vllm"`` on success.
    On failure, ``"vllm"`` raises and ``"auto"`` falls back to
    ``"transformers"`` with a printed notice.

    Raises:
        ValueError: For anything other than ``auto``, ``vllm`` or ``transformers``.
        RuntimeError: When ``vllm`` was requested explicitly but cannot be imported.
    """
    if preferred not in {"auto", "vllm", "transformers"}:
        raise ValueError(f"Unsupported backend: {preferred}")
    if preferred == "transformers":
        return "transformers"
    try:
        import vllm  # noqa: F401
    except Exception as exc:
        if preferred == "vllm":
            raise RuntimeError("vLLM requested but unavailable. Install with `%pip install vllm`.") from exc
        print(f"[LLM] vLLM unavailable ({exc}); falling back to transformers")
        return "transformers"
    return "vllm"


def _pick(row, keys, default=""):
    """Return the first usable ``row[key]`` for ``key`` in ``keys``, else ``default``.

    A value is skipped when it is ``None``, a float NaN, or stringifies to
    empty/whitespace or ``"nan"`` (any case). The original value is returned,
    not its string form; missing keys count as ``""``.
    """
    for key in keys:
        candidate = row.get(key, "")
        if candidate is None or (isinstance(candidate, float) and pd.isna(candidate)):
            continue
        as_text = str(candidate).strip()
        if as_text and as_text.lower() != "nan":
            return candidate
    return default


def _parse_tags(raw_tags):
    """Parse a tags cell into a list of tag strings, dropping blank tags.

    Accepts:

    - a list;
    - a JSON list string, e.g. ``'["rock", "pop"]'``;
    - a Python-literal list string such as ``['rock', "children's music"]``.

    Non-strings, empty strings and unparseable text give ``[]``.
    """
    import ast

    if isinstance(raw_tags, list):
        return [str(tag) for tag in raw_tags if str(tag).strip()]
    if not isinstance(raw_tags, str):
        return []
    text = raw_tags.strip()
    if not text:
        return []
    # Blindly swapping ' for " corrupts tags containing apostrophes, which made
    # the whole list come back empty; try real parsers first, swap last.
    parsers = (json.loads, ast.literal_eval, lambda s: json.loads(s.replace("'", '"')))
    for parse in parsers:
        try:
            parsed = parse(text)
        except Exception:  # this parser rejected the text; try the next one
            continue
        if isinstance(parsed, list):
            return [str(tag) for tag in parsed if str(tag).strip()]
    return []


def _extract_first_json(text):
    """Parse the first JSON value embedded in a model response.

    The whole stripped text is tried first. Failing that, each ``{`` is tried in
    turn as the start of a JSON object using ``json.JSONDecoder.raw_decode``.
    That parse stops at the object's matching closing brace, so trailing prose
    or a second object after it does not break parsing. A greedy ``{.*}`` regex
    would span both and fail.

    Returns:
        The decoded value, or ``None`` when nothing parses.
    """
    text = str(text or "").strip()
    if not text:
        return None
    try:
        return json.loads(text)
    except (ValueError, RecursionError):
        pass

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text[start:])
        except (ValueError, RecursionError):
            start = text.find("{", start + 1)
        else:
            return obj
    return None


def _strip_entities(text, entities):
    """Remove whole-word mentions of ``entities`` (title, artist, ...) from ``text``.

    Matching is case-insensitive and anchored, so an entity is removed only
    when it is not glued to other letters or digits. For example, artist
    ``"Rain"`` is removed from ``"by Rain."`` but ``"raining"`` is left intact.
    Entities shorter than 4 characters (after stripping) are ignored. The
    whitespace and leading punctuation left behind are tidied afterwards.
    """
    cleaned = str(text or "")
    for ent in entities:
        ent = str(ent or "").strip()
        if len(ent) < 4:
            continue
        # (?<!\w)/(?!\w) rather than \b, so entities that begin or end with
        # punctuation (e.g. "Sunn O)))") still match at their edges.
        pattern = rf"(?<!\w){re.escape(ent)}(?!\w)"
        cleaned = re.sub(pattern, "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s+", " ", cleaned)
    cleaned = re.sub(r"\s+([,.;:!?])", r"\1", cleaned)
    cleaned = re.sub(r"^[,;:.\- ]+", "", cleaned)
    return cleaned.strip()


def _clean_generated_description(raw_desc, row_dict):
    """Scrub a generated description of the row's identifying metadata.

    Steps, in order:

    1. Strip surrounding whitespace and quotes.
    2. Remove mentions of the row's title/artist/album via ``_strip_entities``.
    3. Remove whole-word occurrences of its release year.
    4. Collapse whitespace.
    5. Pass the result through ``clean_description`` with ``max_chars=280``.
    """
    def field(keys):
        return str(_pick(row_dict, keys, "")).strip()

    title = field(["Track Title", "title"])
    artist = field(["Artist Name", "artist_name"])
    album = field(["Album Title", "album_title"])
    year = field(["Release Year", "year"])

    text = str(raw_desc or "").strip().strip('"').strip("'")
    text = _strip_entities(text, [title, artist, album])
    if year:
        text = re.sub(rf"\b{re.escape(year)}\b", "", text)
    collapsed = re.sub(r"\s+", " ", text).strip()
    return clean_description(collapsed, max_chars=280)


def _prompt_fields(row_dict):
    """Normalized metadata that determines a row's prompt; used to derive cache keys.

    Title, artist and album are stripped and lower-cased. Year and rank are
    stripped strings. Tags are stripped, lower-cased, blank-filtered and capped
    at 16.
    """
    def text_of(keys, fallback=""):
        return str(_pick(row_dict, keys, fallback)).strip()

    raw_tags = _parse_tags(_pick(row_dict, ["Deezer Tags", "deezer_tags"], "[]"))
    normalized_tags = [tag for tag in (str(raw).strip().lower() for raw in raw_tags) if tag]
    return {
        "title": text_of(["Track Title", "title"], "Unknown").lower(),
        "artist": text_of(["Artist Name", "artist_name"], "Unknown").lower(),
        "album": text_of(["Album Title", "album_title"]).lower(),
        "year": text_of(["Release Year", "year"]),
        "rank": text_of(["Deezer Rank", "deezer_rank", "rank"]),
        "tags": normalized_tags[:16],
    }


def _prompt_key(row_dict):
    """SHA-1 hex digest of ``_prompt_fields(row_dict)`` serialized as canonical JSON.

    Serialization uses sorted keys, compact separators and non-ASCII kept as-is,
    so rows with identical prompt fields share one key.
    """
    canonical = json.dumps(
        _prompt_fields(row_dict),
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha1(canonical.encode("utf-8"))
    return digest.hexdigest()


def _build_llm_messages(row_dict):
    """Build the system/user chat messages that request description/instrumental JSON.

    The user turn starts with the output-format instructions. It then lists the
    row's title, artist, album, release year, up to 12 tags and popularity rank.
    """
    def text_of(keys, fallback=""):
        return str(_pick(row_dict, keys, fallback)).strip()

    title = text_of(["Track Title", "title"], "Unknown") or "Unknown"
    artist = text_of(["Artist Name", "artist_name"], "Unknown") or "Unknown"
    album = text_of(["Album Title", "album_title"])
    year = text_of(["Release Year", "year"])
    rank = text_of(["Deezer Rank", "deezer_rank", "rank"])
    tag_list = _parse_tags(_pick(row_dict, ["Deezer Tags", "deezer_tags"], "[]"))

    instructions = (
        "Return strict JSON only with keys: description, instrumental, confidence. "
        "description must be 30-55 words, rich in sonic/mood/style cues, and must NOT mention title, artist, album, year, rank, or use quotes. "
        "Use only provided metadata and avoid hallucinations.\n\n"
    )
    metadata = (
        f"Title: {title}\n"
        f"Artist: {artist}\n"
        f"Album: {album}\n"
        f"Release Year: {year}\n"
        f"Tags: {', '.join(tag_list[:12])}\n"
        f"Popularity Rank: {rank}\n"
    )

    system_turn = {
        "role": "system",
        "content": "You generate factual music metadata JSON for recommendation systems.",
    }
    user_turn = {"role": "user", "content": instructions + metadata}
    return [system_turn, user_turn]


def _parse_payload_from_raw(raw_output, row_dict):
    """Turn raw model text into a normalized enrichment payload.

    Returns ``None`` when no JSON object can be extracted. Otherwise returns a
    dict with:

    - ``description``: cleaned via ``_clean_generated_description``;
    - ``instrumental``: the ``as_optional_bool`` result;
    - ``instrumental_confidence``: a float in [0, 1], or ``None``.

    Confidence is read from ``instrumental_confidence``, else ``confidence``.
    Values above 1 are treated as percentages.
    """
    parsed = _extract_first_json(raw_output)
    if not isinstance(parsed, dict):
        return None

    description = _clean_generated_description(parsed.get("description", ""), row_dict)
    instrumental = as_optional_bool(parsed.get("instrumental"))

    raw_confidence = parsed.get("instrumental_confidence", parsed.get("confidence", None))
    try:
        confidence = None if raw_confidence is None else float(raw_confidence)
    except Exception:
        confidence = None
    if confidence is not None:
        scaled = confidence / 100.0 if confidence > 1.0 else confidence
        confidence = min(1.0, max(0.0, scaled))

    return {
        "description": description,
        "instrumental": instrumental,
        "instrumental_confidence": confidence,
    }


def _load_cache(path):
    """Read the JSONL enrichment cache at ``path`` into ``{key: payload}``.

    The following lines are skipped:

    - blank lines;
    - malformed JSON;
    - records without a non-empty ``key``.

    When a key repeats, the last record wins. A missing file gives ``{}``.
    """
    entries = {}
    if not path.exists():
        return entries
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except Exception:
                continue
            cache_key = str(record.get("key", "")).strip()
            if not cache_key:
                continue
            entries[cache_key] = {
                "description": str(record.get("description", "") or ""),
                "instrumental": record.get("instrumental", None),
                "instrumental_confidence": record.get("instrumental_confidence", None),
            }
    return entries


def _append_cache_entries(path, entries):
    """Append ``(key, payload)`` pairs to the JSONL cache at ``path``.

    Each pair is written as one JSON line with ``key``, ``description``
    (default ``""``), ``instrumental`` and ``instrumental_confidence`` (default
    ``None``). Non-ASCII text is kept as-is. Parent directories are created as
    needed. An empty ``entries`` returns without touching the filesystem.
    """
    if not entries:
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as handle:
        for cache_key, payload in entries:
            line = json.dumps(
                {
                    "key": cache_key,
                    "description": payload.get("description", ""),
                    "instrumental": payload.get("instrumental"),
                    "instrumental_confidence": payload.get("instrumental_confidence"),
                },
                ensure_ascii=False,
            )
            handle.write(line + "\n")


def _apply_payload_to_row(idx, payload, source):
    """Write an enrichment payload into row ``idx`` of the global ``df_rich``.

    Each column is written only when the payload provides a usable value:

    - ``Description``: after ``clean_description``, if non-empty;
    - ``Instrumental``: if ``as_optional_bool`` yields a bool;
    - ``Instrumental Confidence``: if numeric. Values above 1 are treated as
      percentages, then clamped to [0, 1].

    ``Description Source`` becomes ``source`` when any column was written,
    otherwise ``"qwen-empty"``.

    Returns:
        bool: Whether any column was written.
    """
    wrote_any = False

    description = clean_description(str(payload.get("description", "") or ""), max_chars=280)
    if description:
        df_rich.at[idx, "Description"] = description
        wrote_any = True

    instrumental = as_optional_bool(payload.get("instrumental"))
    if instrumental is not None:
        df_rich.at[idx, "Instrumental"] = bool(instrumental)
        wrote_any = True

    raw_confidence = payload.get("instrumental_confidence", None)
    try:
        confidence = float(raw_confidence) if raw_confidence is not None else None
    except Exception:
        confidence = None
    if confidence is not None:
        if confidence > 1.0:
            confidence /= 100.0
        df_rich.at[idx, "Instrumental Confidence"] = float(min(1.0, max(0.0, confidence)))
        wrote_any = True

    df_rich.at[idx, "Description Source"] = source if wrote_any else "qwen-empty"
    return wrote_any


def _load_transformers_model_once():
    """Load, or reuse, the Hugging Face tokenizer and causal LM for ``LLM_MODEL_ID``.

    The pair is memoized in the module globals ``_tf_tokenizer`` / ``_tf_model``,
    so re-running the cell does not reload weights. Loading uses float16 when
    CUDA is available (float32 otherwise), ``device_map="auto"``, and the HF
    token when one is set. A missing pad token is set to the EOS token.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    global _tf_tokenizer, _tf_model
    namespace = globals()
    if "_tf_tokenizer" in namespace and "_tf_model" in namespace:
        return _tf_tokenizer, _tf_model

    token = _get_hf_token()
    auth = {"token": token} if token else {}
    precision = torch.float16 if torch.cuda.is_available() else torch.float32

    _tf_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, **auth)
    _tf_model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID, dtype=precision, device_map="auto", **auth
    )
    if _tf_tokenizer.pad_token is None:
        _tf_tokenizer.pad_token = _tf_tokenizer.eos_token
    return _tf_tokenizer, _tf_model


# Initialize output dataframe for pure Qwen population.
# Legacy description columns are dropped so that only Qwen output fills "Description".
df_rich = df.copy()
for legacy_col in ["description", "Track Description", "YouTube Description"]:
    if legacy_col in df_rich.columns:
        df_rich = df_rich.drop(columns=[legacy_col])

# Canonical string Track ID taken from the first non-empty id column.
# NOTE(review): if the id column was parsed as float (e.g. integers with missing
# values), str() yields "12345.0"-style ids. Confirm the column dtype.
df_rich["Track ID"] = df_rich.apply(
    lambda r: str(_pick(r.to_dict(), ["Track ID", "deezer_id", "video_id"], "")).strip(),
    axis=1,
)
# Reset the enrichment output columns; every row starts as "qwen-pending".
df_rich["Description"] = ""
df_rich["Instrumental"] = pd.NA
df_rich["Instrumental Confidence"] = 0.0
df_rich["Description Source"] = "qwen-pending"


# Run-level counters. "rows" are dataframe rows; one generated or cached
# payload may be applied to several rows that share the same prompt key.
rows_updated = 0
rows_failed = 0
rows_cache_hits = 0
unique_generated = 0

if ENABLE_LLM_ENRICHMENT:
    backend = _resolve_backend(LLM_BACKEND)
    hf_token = _get_hf_token()

    print(f"LLM model={LLM_MODEL_ID}")
    print(f"backend={backend} | GPU_VRAM_GB={GPU_VRAM_GB:.1f}")
    print(f"CUDA device: {_gpu_name()}")
    print(f"HF token detected: {bool(hf_token)}")
    print(f"shard={SHARD_INDEX + 1}/{NUM_SHARDS}")

    # Round-robin sharding over row positions; LLM_MAX_ROWS caps the shard size.
    all_indices = list(df_rich.index)
    shard_indices = [idx for pos, idx in enumerate(all_indices) if pos % NUM_SHARDS == SHARD_INDEX]
    if LLM_MAX_ROWS > 0:
        shard_indices = shard_indices[:LLM_MAX_ROWS]

    # Group rows by prompt key so each distinct prompt is generated only once.
    key_to_indices = {}
    key_to_row = {}
    for idx in tqdm(shard_indices, desc="Preparing prompt keys"):
        row_dict = df_rich.loc[idx].to_dict()
        key = _prompt_key(row_dict)
        key_to_indices.setdefault(key, []).append(idx)
        if key not in key_to_row:
            key_to_row[key] = row_dict

    # Apply payloads cached by earlier runs; only cache misses are generated.
    unique_keys = list(key_to_indices.keys())
    cache = _load_cache(CACHE_PATH)
    pending_keys = []

    for key in unique_keys:
        cached_payload = cache.get(key)
        if cached_payload is None:
            pending_keys.append(key)
            continue
        for idx in key_to_indices[key]:
            ok = _apply_payload_to_row(idx, cached_payload, "qwen-cache")
            rows_updated += int(ok)
            rows_failed += int(not ok)
            rows_cache_hits += 1

    print(
        f"target_rows={len(shard_indices)} | unique_prompts={len(unique_keys)} | "
        f"cache_hits_rows={rows_cache_hits} | pending_unique_prompts={len(pending_keys)}"
    )

    cache_buffer = []
    run_start = time.time()

    if backend == "vllm":
        from vllm import LLM, SamplingParams

        # os.environ only accepts str values: setdefault(name, None) raises
        # TypeError when the variable is unset, so export a token only if found.
        if hf_token:
            os.environ.setdefault("HF_TOKEN", hf_token)
            os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", hf_token)
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        chat_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, token=hf_token)
        sampling_params = SamplingParams(
            temperature=float(LLM_TEMPERATURE),
            top_p=float(LLM_TOP_P),
            max_tokens=int(LLM_MAX_NEW_TOKENS),
            repetition_penalty=float(LLM_REPETITION_PENALTY),
        )

        _release_gpu_memory()

        # Progressively smaller engine configs, tried in order until one initializes.
        init_candidates = [
            (int(VLLM_MAX_NUM_SEQS), float(VLLM_GPU_MEMORY_UTIL), int(VLLM_MAX_MODEL_LEN)),
            (max(64, int(VLLM_MAX_NUM_SEQS * 0.75)), min(float(VLLM_GPU_MEMORY_UTIL), 0.88), int(VLLM_MAX_MODEL_LEN)),
            (max(64, int(VLLM_MAX_NUM_SEQS * 0.50)), min(float(VLLM_GPU_MEMORY_UTIL), 0.84), min(int(VLLM_MAX_MODEL_LEN), 6144)),
            (64, min(float(VLLM_GPU_MEMORY_UTIL), 0.80), min(int(VLLM_MAX_MODEL_LEN), 4096)),
        ]

        dedup_init_candidates = []
        seen = set()
        for cand in init_candidates:
            key = (int(cand[0]), round(float(cand[1]), 3), int(cand[2]))
            if key in seen:
                continue
            seen.add(key)
            dedup_init_candidates.append(cand)

        vllm_engine = None
        active_max_num_seqs = None
        active_gpu_mem_util = None
        active_max_model_len = None
        vllm_init_error = None

        for cand_max_num_seqs, cand_gpu_mem, cand_max_model_len in dedup_init_candidates:
            try:
                print(
                    f"[LLM] vLLM init attempt: max_num_seqs={cand_max_num_seqs}, "
                    f"gpu_mem_util={cand_gpu_mem:.2f}, max_model_len={cand_max_model_len}"
                )
                vllm_engine = LLM(
                    model=LLM_MODEL_ID,
                    tensor_parallel_size=1,
                    dtype="float16",
                    gpu_memory_utilization=float(cand_gpu_mem),
                    max_num_seqs=int(cand_max_num_seqs),
                    max_model_len=int(cand_max_model_len),
                )
                active_max_num_seqs = int(cand_max_num_seqs)
                active_gpu_mem_util = float(cand_gpu_mem)
                active_max_model_len = int(cand_max_model_len)
                break
            except Exception as exc:
                vllm_init_error = exc
                print(f"[LLM] vLLM init failed for this config: {exc}")
                time.sleep(1)

        if vllm_engine is not None:
            effective_prompt_batch = max(64, min(int(VLLM_PROMPT_BATCH_SIZE), int(active_max_num_seqs) * 4))
            print(
                f"vLLM config: max_num_seqs={active_max_num_seqs}, "
                f"prompt_batch_size={effective_prompt_batch}, gpu_mem_util={active_gpu_mem_util}, "
                f"max_model_len={active_max_model_len}"
            )

            for start in tqdm(range(0, len(pending_keys), effective_prompt_batch), desc="qwen enrichment (vllm)"):
                batch_keys = pending_keys[start : start + effective_prompt_batch]
                prompts = [
                    chat_tokenizer.apply_chat_template(
                        _build_llm_messages(key_to_row[key]),
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    for key in batch_keys
                ]
                outputs = vllm_engine.generate(prompts, sampling_params=sampling_params, use_tqdm=False)

                for key, out in zip(batch_keys, outputs):
                    raw = ""
                    if out.outputs:
                        raw = str(out.outputs[0].text or "").strip()

                    payload = _parse_payload_from_raw(raw, key_to_row[key])
                    if payload is None:
                        for idx in key_to_indices[key]:
                            df_rich.at[idx, "Description Source"] = "qwen-parse-failed"
                            rows_failed += 1
                        continue

                    unique_generated += 1
                    cache_buffer.append((key, payload))

                    for idx in key_to_indices[key]:
                        ok = _apply_payload_to_row(idx, payload, "qwen2.5-3b")
                        rows_updated += int(ok)
                        rows_failed += int(not ok)

                if len(cache_buffer) >= CACHE_FLUSH_EVERY:
                    _append_cache_entries(CACHE_PATH, cache_buffer)
                    cache_buffer = []
        else:
            print(f"[LLM] vLLM init failed after retries ({vllm_init_error}); switching to transformers backend")
            backend = "transformers"

    if backend != "vllm":
        tf_tokenizer, tf_model = _load_transformers_model_once()
        # Decoder-only models need left padding for batched generation, so every
        # prompt ends exactly where the new tokens are appended.
        tf_tokenizer.padding_side = "left"
        batch_size = max(1, int(TRANSFORMERS_BATCH_SIZE))
        print(f"transformers batch_size={batch_size}")

        for start in tqdm(range(0, len(pending_keys), batch_size), desc="qwen enrichment (transformers)"):
            batch_keys = pending_keys[start : start + batch_size]
            row_dicts = [key_to_row[key] for key in batch_keys]
            messages_batch = [_build_llm_messages(row_dict) for row_dict in row_dicts]
            chat_texts = [
                tf_tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
                for msgs in messages_batch
            ]

            encoded = tf_tokenizer(
                chat_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=768,
            )
            # generate() returns [padded prompt | new tokens] for every row, so the
            # continuation starts at the padded width for all rows. Slicing at each
            # row's unpadded length (attention-mask sum) would leak prompt tokens
            # into the decoded text of shorter prompts.
            prompt_len = int(encoded["input_ids"].shape[1])
            target_device = next(tf_model.parameters()).device
            encoded = {k: v.to(target_device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = tf_model.generate(
                    **encoded,
                    max_new_tokens=int(LLM_MAX_NEW_TOKENS),
                    do_sample=True,
                    temperature=float(LLM_TEMPERATURE),
                    top_p=float(LLM_TOP_P),
                    repetition_penalty=float(LLM_REPETITION_PENALTY),
                    pad_token_id=tf_tokenizer.eos_token_id,
                )

            for pos, key in enumerate(batch_keys):
                new_tokens = outputs[pos, prompt_len:]
                raw = tf_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

                payload = _parse_payload_from_raw(raw, key_to_row[key])
                if payload is None:
                    for idx in key_to_indices[key]:
                        df_rich.at[idx, "Description Source"] = "qwen-parse-failed"
                        rows_failed += 1
                    continue

                unique_generated += 1
                cache_buffer.append((key, payload))

                for idx in key_to_indices[key]:
                    ok = _apply_payload_to_row(idx, payload, "qwen2.5-3b")
                    rows_updated += int(ok)
                    rows_failed += int(not ok)

            if len(cache_buffer) >= CACHE_FLUSH_EVERY:
                _append_cache_entries(CACHE_PATH, cache_buffer)
                cache_buffer = []

    # Flush whatever is still buffered (no-op when empty).
    _append_cache_entries(CACHE_PATH, cache_buffer)

    elapsed_h = (time.time() - run_start) / 3600.0
    print(
        f"rows_updated={rows_updated}, rows_failed={rows_failed}, rows_cache_hits={rows_cache_hits}, "
        f"unique_generated={unique_generated}, elapsed={elapsed_h:.2f}h"
    )
else:
    print("LLM enrichment disabled")


# Persist the enriched dataframe (optionally over DATA_PATH as well), then
# mirror it to the configured sync target.
df_rich.to_csv(RICH_CSV_PATH, index=False, encoding="utf-8")
print(f"Saved rich CSV: {RICH_CSV_PATH} rows={len(df_rich)}")
if OVERWRITE_SOURCE_CSV:
    df_rich.to_csv(DATA_PATH, index=False, encoding="utf-8")
    print(f"Overwrote source CSV: {DATA_PATH}")

sync_path(Path(RICH_CSV_PATH), "EraEx_Dataset_Rich.csv")
print(f"rows_updated={rows_updated}, rows_failed={rows_failed}, rows_cache_hits={rows_cache_hits}")
# Notebook display: first rows of the enrichment output columns.
df_rich[["Track ID", "Track Title", "Artist Name", "Description", "Instrumental", "Instrumental Confidence", "Description Source"]].head(3)


## 4. Build `id_map.json` + `metadata.json` from Rich CSV


In [None]:
# Let this section run without section 3 in the same session. In that case
# RICH_CSV_PATH (defined in section 3) is missing, so fall back to the
# section-2 default and reload the rich CSV from disk. tqdm and re are
# imported here because this section uses them and they otherwise come only
# from earlier cells.
import re

from tqdm.auto import tqdm

if "RICH_CSV_PATH" not in globals():
    RICH_CSV_PATH = str(RICH_CSV_PATH_DEFAULT)
if "df_rich" not in globals():
    df_rich = pd.read_csv(RICH_CSV_PATH, low_memory=False)


def _parse_tags_from_row(row):
    raw = row.get("Deezer Tags", row.get("deezer_tags", "[]"))
    if isinstance(raw, list):
        return [str(tag).strip() for tag in raw if str(tag).strip()]
    if not isinstance(raw, str):
        return []
    text = raw.strip()
    if not text:
        return []
    try:
        parsed = json.loads(text.replace("'", '"'))
        if isinstance(parsed, list):
            return [str(tag).strip() for tag in parsed if str(tag).strip()]
    except Exception:
        pass
    return []


def _safe_float(value, default=0.0):
    try:
        num = float(value)
        if pd.isna(num):
            return float(default)
        return float(num)
    except Exception:
        return float(default)


def _safe_unit_float(value, default=0.0):
    """Coerce *value* to a score in ``[0.0, 1.0]``.

    The value is parsed with ``_safe_float`` (so unparsable/NaN input becomes
    *default*). Anything above 1.0 is divided by 100 -- presumably a
    percentage -- and the result is clamped to ``[0.0, 1.0]``.
    """
    num = _safe_float(value, default=default)
    scaled = num / 100.0 if num > 1.0 else num
    return float(min(1.0, max(0.0, scaled)))


def _derive_vibe_tags(tags, tempo, energy, brightness, mood, valence):
    vibes = set()
    for tag in tags:
        for tok in re.findall(r"[a-z0-9]+", str(tag).lower()):
            if tok:
                vibes.add(tok)

    if mood >= 0.68:
        vibes.add("moody")
    if valence >= 0.66:
        vibes.add("uplifting")
    elif valence <= 0.36:
        vibes.add("melancholic")
    if energy >= 0.66:
        vibes.add("energetic")
    elif energy <= 0.38:
        vibes.add("chill")
    if tempo >= 0.70:
        vibes.add("fast")
    elif tempo <= 0.38:
        vibes.add("slow")
    if brightness <= 0.35:
        vibes.add("dark")
    elif brightness >= 0.68:
        vibes.add("bright")

    ordered = sorted(vibes)
    return ordered[:24]


def _first_present(row, keys):
    """Return the first non-missing value among *keys* in *row*, else None.

    A value is missing when the key is absent, the value is None or a pandas
    NA scalar (``read_csv`` yields NaN for empty cells), or it is a
    whitespace-only string.
    """
    for key in keys:
        value = row.get(key)
        if value is None:
            continue
        if isinstance(value, str):
            if value.strip():
                return value
            continue
        if pd.api.types.is_scalar(value) and pd.isna(value):
            continue
        return value
    return None


def _row_text(row, keys, default=""):
    """Return the first present value of *keys* as a stripped str, else *default*."""
    value = _first_present(row, keys)
    if value is None:
        return default
    return str(value).strip() or default


def _row_int(row, keys, default=0):
    """Return the first present value of *keys* as an int (via float), else *default*."""
    value = _first_present(row, keys)
    if value is None:
        return default
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


ids = []
metadata = {}

for _, row in tqdm(df_rich.iterrows(), total=len(df_rich), desc="Building metadata"):
    row = row.to_dict()
    # Empty CSV cells arrive as NaN floats. NaN is truthy and str(NaN) == "nan",
    # so every text field goes through _row_text: a missing value falls back
    # to the next column alias / the default instead of becoming "nan".
    track_id = _row_text(row, ("Track ID",))
    if not track_id:
        continue

    title = _row_text(row, ("Track Title", "title"), "Unknown")
    artist = _row_text(row, ("Artist Name", "artist_name"), "Unknown")
    album = _row_text(row, ("Album Title", "album_title"), "Unknown")
    tags = _parse_tags_from_row(row)

    year = _row_int(row, ("Release Year", "year"))
    playcount = _row_int(row, ("Play Count", "deezer_playcount"))
    deezer_rank = _row_int(row, ("Deezer Rank", "deezer_rank", "rank"))
    views = _row_int(row, ("Views", "views"))

    # Audio scores normalised to [0, 1]; missing values become 0.0.
    tempo = _safe_unit_float(_first_present(row, ("tempo", "Tempo")), default=0.0)
    energy = _safe_unit_float(_first_present(row, ("energy", "Energy")), default=0.0)
    brightness = _safe_unit_float(_first_present(row, ("brightness", "Brightness")), default=0.0)
    mood = _safe_unit_float(_first_present(row, ("mood", "Mood")), default=0.0)
    valence = _safe_unit_float(_first_present(row, ("valence", "Valence")), default=0.0)

    description = clean_description(_row_text(row, ("Description",)), max_chars=280)
    instrumental = as_optional_bool(row.get("Instrumental"))
    confidence = _safe_unit_float(row.get("Instrumental Confidence", 0.0), default=0.0)

    vibe_tags = _derive_vibe_tags(tags, tempo, energy, brightness, mood, valence)

    # NOTE(review): a duplicate Track ID is appended to `ids` again while
    # `metadata` keeps only the last row for it -- confirm ids are unique.
    ids.append(track_id)
    metadata[track_id] = {
        "id": track_id,
        "video_id": track_id,
        "title": title,
        "name": title,
        "artist_name": artist,
        "album_title": album,
        "description": description,
        "description_source": _row_text(row, ("Description Source",)),
        "instrumental": instrumental,
        "instrumental_confidence": float(confidence),
        "year": year,
        "release_date": f"{year}-01-01" if year else "",
        "deezer_tags": tags,
        "genres": tags,
        "vibe_tags": vibe_tags,
        "tempo": float(tempo),
        "energy": float(energy),
        "brightness": float(brightness),
        "mood": float(mood),
        "valence": float(valence),
        "deezer_playcount": playcount,
        "deezer_rank": deezer_rank,
        "rank": deezer_rank,
        "views": views,
        "cover_url": _row_text(row, ("Cover URL", "cover_url")),
    }

# Persist the index-side artifacts. id_map.json is the ordered id list: the
# embedding cell below encodes ids[i] into row i, and the FAISS index is built
# in that same order. metadata.json maps track id -> metadata record.
index_dir = Path(INDEX_DIR)
index_dir.mkdir(parents=True, exist_ok=True)
with open(index_dir / "id_map.json", "w", encoding="utf-8") as f:
    json.dump(ids, f)
# ensure_ascii=False writes non-ASCII titles/artists as-is instead of \u escapes.
with open(index_dir / "metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, ensure_ascii=False)

# Mirror both files via sync_path (defined outside this view).
sync_path(index_dir / "id_map.json", "indexes/id_map.json")
sync_path(index_dir / "metadata.json", "indexes/metadata.json")
# len(ids) can exceed len(metadata) when df_rich contains duplicate Track IDs.
print(f"Saved id_map.json ({len(ids)}) and metadata.json ({len(metadata)})")
if ids:
    sample_meta = metadata[str(ids[0])]
    print("Sample metadata keys:", sorted(sample_meta.keys()))


## 5. Generate Context-First BGE-M3 Embeddings (Genre/Tags/Vibe/Description/Audio)


In [None]:
import gc
import os
import re
import json
import torch
from tqdm.auto import tqdm
from src.core.media_metadata import as_optional_bool, clean_description

# Suggest embedding batch size from available VRAM.
def _suggest_embed_batch_size(vram_gb):
    if vram_gb >= 40:
        return 256
    if vram_gb >= 24:
        return 160
    if vram_gb >= 16:
        return 96
    if vram_gb >= 10:
        return 64
    return 32


def _as_list(value):
    if isinstance(value, list):
        return value
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return []
        try:
            parsed = json.loads(text.replace("'", '"'))
            if isinstance(parsed, list):
                return parsed
        except Exception:
            pass
    return []


def _safe_unit_float(value, default=0.0):
    try:
        num = float(value)
    except Exception:
        num = float(default)
    if num > 1.0:
        num = num / 100.0
    return float(min(1.0, max(0.0, num)))


def _bucket(v, low=0.35, high=0.67):
    if v < low:
        return "low"
    if v > high:
        return "high"
    return "medium"


def _audio_vibe_tokens(tempo, energy, brightness, mood, valence):
    vibes = []
    if mood >= 0.68:
        vibes.append("moody")
    if valence >= 0.66:
        vibes.append("uplifting")
    elif valence <= 0.36:
        vibes.append("melancholic")
    if energy >= 0.66:
        vibes.append("energetic")
    elif energy <= 0.38:
        vibes.append("chill")
    if tempo >= 0.70:
        vibes.append("fast")
    elif tempo <= 0.38:
        vibes.append("slow")
    if brightness <= 0.35:
        vibes.append("dark")
    elif brightness >= 0.68:
        vibes.append("bright")
    dedup = []
    seen = set()
    for v in vibes:
        if v not in seen:
            dedup.append(v)
            seen.add(v)
    return dedup


# Build one context-first embedding text per track id from metadata.
def _build_text_for_track(track_id):
    """Render the text that is embedded for *track_id*.

    Reads the track's record from the global ``metadata`` dict; an unknown id
    yields a text built purely from defaults ("Unknown", zero scores). The
    descriptive context comes first: genres, tags, vibe labels, description,
    audio scores and buckets, and vocal type. The artist/title identity comes
    last, intended to carry lower emphasis while still supporting exact-name
    queries.

    Each fragment ends in exactly one terminal punctuation mark. Fragments
    that already end in ".", "!" or "?" are left alone, so the text never
    contains ".." artefacts such as "valence 0.500.. Audio".
    """
    meta = metadata.get(str(track_id), {})

    title = str(meta.get("title", meta.get("name", "Unknown")) or "Unknown").strip() or "Unknown"
    artist = str(meta.get("artist_name", "Unknown") or "Unknown").strip() or "Unknown"

    tags = [str(t).strip().lower() for t in _as_list(meta.get("deezer_tags", [])) if str(t).strip()]
    genres = [str(g).strip().lower() for g in _as_list(meta.get("genres", [])) if str(g).strip()]
    vibe_tags = [str(v).strip().lower() for v in _as_list(meta.get("vibe_tags", [])) if str(v).strip()]

    tempo = _safe_unit_float(meta.get("tempo", 0.0), default=0.0)
    energy = _safe_unit_float(meta.get("energy", 0.0), default=0.0)
    brightness = _safe_unit_float(meta.get("brightness", 0.0), default=0.0)
    mood = _safe_unit_float(meta.get("mood", 0.0), default=0.0)
    valence = _safe_unit_float(meta.get("valence", 0.0), default=0.0)

    # Merge score-derived vibes into the stored ones, keeping first-seen order.
    audio_vibes = _audio_vibe_tokens(tempo, energy, brightness, mood, valence)
    if audio_vibes:
        vibe_tags = list(dict.fromkeys(vibe_tags + audio_vibes))

    description = clean_description(str(meta.get("description", "") or ""), max_chars=280)
    instrumental = as_optional_bool(meta.get("instrumental"))
    if instrumental is True:
        vocal_type = "instrumental"
    elif instrumental is False:
        vocal_type = "non-instrumental"
    else:
        vocal_type = "unknown-vocals"

    parts = []

    # Context-heavy fields first.
    if genres:
        parts.append(f"Genre profile: {' '.join(genres[:10])}")
    if tags:
        parts.append(f"Tag profile: {' '.join(tags[:16])}")
    if vibe_tags:
        parts.append(f"Vibe profile: {' '.join(vibe_tags[:16])}")
    if description:
        parts.append(f"Description: {description}")

    parts.append(
        "Audio feature scores: "
        f"tempo {tempo:.3f}, energy {energy:.3f}, brightness {brightness:.3f}, mood {mood:.3f}, valence {valence:.3f}."
    )
    parts.append(
        "Audio feature buckets: "
        f"tempo {_bucket(tempo)}, energy {_bucket(energy)}, brightness {_bucket(brightness)}, "
        f"mood {_bucket(mood)}, valence {_bucket(valence)}."
    )
    parts.append(f"Vocal type: {vocal_type}")

    # Keep identity with lower emphasis for exact-name queries.
    parts.append(f"Identity reference: artist {artist}. title {title}.")

    # Several fragments already end with "." (the audio lines, the identity
    # line, often the description). Only terminate those that do not, rather
    # than blindly joining with ". ", which produced "..".
    sentences = []
    for part in parts:
        text = str(part).strip() if part else ""
        if not text:
            continue
        sentences.append(text if text[-1] in ".!?" else f"{text}.")
    return " ".join(sentences)


# Hardware detection: name and total memory of CUDA device 0 (VRAM in GiB).
HAS_CUDA = torch.cuda.is_available()
GPU_NAME = torch.cuda.get_device_name(0) if HAS_CUDA else "CPU"
VRAM_GB = float(torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)) if HAS_CUDA else 0.0
CPU_COUNT = os.cpu_count() or 4  # informational only (printed below)

# Encoding knobs. BATCH_SIZE is passed to embedding_handler.encode.
# EMBED_CHUNK_SIZE is the number of tracks encoded per chunk before the next
# cell flushes them to disk and records resume progress.
BATCH_SIZE = _suggest_embed_batch_size(VRAM_GB) if HAS_CUDA else 32
EMBED_CHUNK_SIZE = max(BATCH_SIZE * 220, 22000)
# Presumably a token cap applied via the model's max_seq_length attribute
# (sentence-transformers convention) -- confirm against embedding_handler.
MAX_SEQ_LENGTH = 320
EMBEDDINGS_PATH = Path(INDEX_DIR) / "embeddings.npy"
EMBED_PROGRESS_PATH = Path(INDEX_DIR) / "embeddings.progress.json"

# Allow TF32 matmuls on CUDA: faster on supporting GPUs at slightly reduced
# float32 precision.
if HAS_CUDA:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

# NOTE(review): assumes load_model() returns the same model object that
# embedding_handler.encode uses, so the tweaks below take effect -- confirm.
model = embedding_handler.load_model()
if model is not None:
    # Best-effort tweaks: failures are ignored and the model keeps its defaults.
    try:
        model.max_seq_length = int(MAX_SEQ_LENGTH)
    except Exception:
        pass
    if HAS_CUDA:
        try:
            model.half()  # fp16 weights on GPU
        except Exception:
            pass

print(f"Device: {GPU_NAME}")
print(f"VRAM: {VRAM_GB:.1f} GB | CPU cores: {CPU_COUNT}")
print(f"Embedding batch size: {BATCH_SIZE}")
print(f"Embedding chunk size: {EMBED_CHUNK_SIZE}")
print(f"Tracks to encode: {len(ids)}")
if ids:
    preview = _build_text_for_track(ids[0])
    print(f"Sample text: {preview[:240]}...")


In [None]:
# 2. Encode in chunks and save with resume support
import hashlib

if not ids:
    raise RuntimeError("No track IDs found. Build metadata first.")

# Probe embedding dimension once.
probe_text = _build_text_for_track(ids[0])
probe_vec = embedding_handler.encode([probe_text], batch_size=1)
dim = int(probe_vec.shape[1])
total = len(ids)

# Fingerprint of this run's inputs: the exact id order plus the rendered text
# of the first track (a cheap proxy for the text template). Resuming is only
# safe when the saved progress was written for the same inputs. Matching the
# row count alone would silently keep stale vectors for the already-encoded
# prefix if ids were rebuilt in another order or the template changed.
# Progress files from older runs lack this key and simply start over.
_fp = hashlib.sha256()
_fp.update("\n".join(map(str, ids)).encode("utf-8"))
_fp.update(b"\0")
_fp.update(probe_text.encode("utf-8"))
run_fingerprint = _fp.hexdigest()

resume_idx = 0
if EMBEDDINGS_PATH.exists() and EMBED_PROGRESS_PATH.exists():
    try:
        progress = json.loads(EMBED_PROGRESS_PATH.read_text(encoding="utf-8"))
        if (
            int(progress.get("count", -1)) == total
            and int(progress.get("dim", -1)) == dim
            and progress.get("fingerprint") == run_fingerprint
        ):
            # Clamp so a corrupt progress file cannot point outside the matrix.
            resume_idx = min(max(int(progress.get("next_idx", 0)), 0), total)
            print(f"Resuming embeddings from row {resume_idx}...")
    except Exception:
        resume_idx = 0  # unreadable progress file: re-encode from scratch

embeddings_mm = None
if EMBEDDINGS_PATH.exists() and resume_idx > 0:
    embeddings_mm = np.load(EMBEDDINGS_PATH, mmap_mode="r+")
    # Only reuse the on-disk matrix if it really has the expected layout.
    if embeddings_mm.shape != (total, dim) or embeddings_mm.dtype != np.float32:
        embeddings_mm = None
if embeddings_mm is None:
    embeddings_mm = np.lib.format.open_memmap(
        str(EMBEDDINGS_PATH), mode="w+", dtype=np.float32, shape=(total, dim)
    )
    resume_idx = 0

# Row i of the matrix holds the embedding of ids[i].
for start in tqdm(range(resume_idx, total, EMBED_CHUNK_SIZE), desc="Encoding chunks"):
    end = min(start + EMBED_CHUNK_SIZE, total)
    chunk_ids = ids[start:end]
    chunk_texts = [_build_text_for_track(track_id) for track_id in chunk_ids]
    chunk_embeddings = embedding_handler.encode(
        chunk_texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=False,
    ).astype(np.float32, copy=False)

    # Flush the chunk before recording progress, so progress never runs ahead
    # of the data on disk.
    embeddings_mm[start:end] = chunk_embeddings
    embeddings_mm.flush()
    EMBED_PROGRESS_PATH.write_text(
        json.dumps({"count": total, "dim": dim, "next_idx": end, "fingerprint": run_fingerprint}),
        encoding="utf-8",
    )

    del chunk_texts
    del chunk_embeddings
    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()

embeddings = np.load(EMBEDDINGS_PATH)
if EMBED_PROGRESS_PATH.exists():
    EMBED_PROGRESS_PATH.unlink()
sync_path(EMBEDDINGS_PATH, "indexes/embeddings.npy")
print(f"Saved embeddings.npy shape={embeddings.shape} at {EMBEDDINGS_PATH}")


## 6. Build FAISS Index


In [None]:
import faiss

# FAISS row i must correspond to ids[i] (the id_map.json order), so the counts must agree.
if embeddings.shape[0] != len(ids):
    raise RuntimeError(f"Count mismatch: embeddings={embeddings.shape[0]} vs id_map={len(ids)}")

# Exact (brute-force) inner-product index over all vectors.
# NOTE(review): inner product equals cosine similarity only if
# embedding_handler.encode returns L2-normalised vectors -- confirm; otherwise
# call faiss.normalize_L2(embeddings) before add().
d = embeddings.shape[1]
index = faiss.IndexFlatIP(d)
index.add(embeddings)

faiss_path = Path(INDEX_DIR) / "faiss_index.bin"
faiss.write_index(index, str(faiss_path))
sync_path(faiss_path, "indexes/faiss_index.bin")
print(f"FAISS index built and saved. ntotal={index.ntotal}, dim={index.d}")
