# Prompt2Song – Contrastive Retrieval & FAISS Indexing

Align prompt embeddings with fused song embeddings via contrastive learning and build a FAISS index for fast recommendation.

## Goals
- Load prompt and song artifacts produced by earlier notebooks
- Train lightweight projection heads using an InfoNCE-style objective with in-batch negatives
- Build a FAISS index over fused song embeddings for low-latency retrieval
- Provide an end-to-end query function that maps free-text prompts to recommended songs

Imports JSON handling, math/random utilities, PyTorch/NumPy, FAISS (if available), and supporting typing helpers.

In [None]:
import json
import math
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

try:
    import faiss  # type: ignore
except ImportError:
    faiss = None
    print("⚠️ FAISS not installed. Install faiss-cpu or faiss-gpu before running indexing steps.")


Resolves project directories for text/fusion artifacts, ensures retrieval output folder exists, and warns if prerequisites are missing.

In [None]:
NOTEBOOK_DIR = Path.cwd().resolve()
# The notebook may run from the repository root (which contains `datasets/`)
# or from a subfolder one level below it.
PROJECT_ROOT = NOTEBOOK_DIR if (NOTEBOOK_DIR / "datasets").exists() else NOTEBOOK_DIR.parent

# Inputs produced by earlier notebooks, plus this notebook's output folder.
TEXT_MODEL_DIR = PROJECT_ROOT / "artifacts" / "text_encoder" / "hf_model"
FUSION_DIR = PROJECT_ROOT / "artifacts" / "fusion"
RETRIEVAL_DIR = PROJECT_ROOT / "artifacts" / "retrieval"
RETRIEVAL_DIR.mkdir(exist_ok=True, parents=True)

# Warn (rather than fail) so the remaining cells can still be inspected.
if not (TEXT_MODEL_DIR.exists() and FUSION_DIR.exists()):
    print("⚠️ Missing prerequisites. Run notebooks 01 and 02 first.")


Defines the TextEmotionEncoder wrapper, which produces mean-pooled sentence embeddings from the transformer backbone and emotion probabilities from the fine-tuned classification head.

In [None]:
class TextEmotionEncoder(torch.nn.Module):
    def __init__(self, model_dir: Path, device: str | None = None):
        super().__init__()
        from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model = AutoModel.from_pretrained(model_dir).to(self.device)
        self.classifier = AutoModelForSequenceClassification.from_pretrained(model_dir).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    @torch.no_grad()
    def encode(self, texts: List[str], batch_size: int = 32, max_length: int = 256) -> np.ndarray:
        embeddings = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            tokens = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(self.device)
            outputs = self.base_model(**tokens)
            token_embeddings = outputs.last_hidden_state
            attention_mask = tokens.attention_mask.unsqueeze(-1)
            summed = (token_embeddings * attention_mask).sum(dim=1)
            counts = attention_mask.sum(dim=1)
            mean_pooled = summed / counts
            embeddings.append(mean_pooled.cpu().numpy())
        return np.vstack(embeddings)

    @torch.no_grad()
    def predict_label(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        preds = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            tokens = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            ).to(self.device)
            outputs = self.classifier(**tokens)
            logits = outputs.logits
            preds.append(logits.softmax(dim=-1).cpu().numpy())
        return np.vstack(preds)


### Load datasets

Parses the `text;label` lines of the emotion prompt dataset's train/val/test splits and prints the training-set size as a sanity check.

In [None]:
def load_emotion_split(path: Path) -> List[Dict[str, str]]:
    """Parse one ``text;label`` split file into a list of records.

    Blank lines are skipped. Each remaining line is split on its LAST ``;`` so
    prompts that themselves contain semicolons are kept intact.

    Args:
        path: UTF-8 text file with one ``text;label`` pair per line.

    Returns:
        List of ``{"text": ..., "label": ...}`` dicts with surrounding whitespace stripped.

    Raises:
        ValueError: if a non-blank line contains no ``;`` separator.
    """
    records: List[Dict[str, str]] = []
    with path.open("r", encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            text, sep, label = stripped.rpartition(";")
            if not sep:
                raise ValueError(f"{path}:{line_no}: expected 'text;label', got {stripped!r}")
            records.append({"text": text.strip(), "label": label.strip()})
    return records

# Load the three dataset splits in a fixed order: train, val, test.
train_prompts, val_prompts, test_prompts = (
    load_emotion_split(PROJECT_ROOT / "datasets" / "emotions_NLP" / split_file)
    for split_file in ("train.txt", "val.txt", "test.txt")
)
print(f"Loaded {len(train_prompts)} train prompts")


Initializes the encoder, loads the saved label mapping, and reports available emotion labels.

In [None]:
encoder = TextEmotionEncoder(TEXT_MODEL_DIR)
label2id_path = PROJECT_ROOT / "artifacts" / "text_encoder" / "label2id.json"

# Prefer the mapping saved by notebook 01; otherwise use the one stored in the
# classifier's config.
label2id = (
    json.loads(label2id_path.read_text(encoding="utf-8"))
    if label2id_path.exists()
    else encoder.classifier.config.label2id
)

# Inverse mapping, with ids coerced to int (JSON may store them as strings).
id2label = {int(label_id): name for name, label_id in label2id.items()}
label_names = sorted(label2id)
print("Labels:", label_names)


### Embed prompts

Encodes prompts into arrays of texts, labels, and embeddings for each split and saves the prompt embeddings to disk.

In [None]:
def to_arrays(records):
    """Turn ``{"text", "label"}`` records into model-ready arrays.

    Uses the module-level ``label2id`` mapping and ``encoder``.

    Returns:
        dict with ``"texts"`` (list of str), ``"labels"`` (int64 label ids) and
        ``"embeddings"`` (float32 encoder embeddings, one row per record).
    """
    texts = [rec["text"] for rec in records]
    raw_labels = [rec["label"] for rec in records]
    label_ids = [label2id[name] for name in raw_labels]
    encoded = encoder.encode(texts)
    label_array = np.array(label_ids, dtype=np.int64)
    return {
        "texts": texts,
        "labels": label_array,
        "embeddings": encoded.astype(np.float32),
    }

# Encode each split once; keys are the split names.
prompt_arrays = {
    split_name: to_arrays(split_records)
    for split_name, split_records in (
        ("train", train_prompts),
        ("val", val_prompts),
        ("test", test_prompts),
    )
}

# Persist each split's (unprojected) prompt embeddings under RETRIEVAL_DIR.
for split_name, embeddings_file in (
    ("train", "prompt_train_embeddings.npy"),
    ("val", "prompt_val_embeddings.npy"),
    ("test", "prompt_test_embeddings.npy"),
):
    np.save(RETRIEVAL_DIR / embeddings_file, prompt_arrays[split_name]["embeddings"])
del split_name, embeddings_file


### Load fused song embeddings and metadata

Loads lyric and fused song embeddings plus metadata generated by the previous notebook, ensuring prerequisites exist.

In [None]:
lyric_embeddings_path = FUSION_DIR / "lyric_embeddings.npy"
fused_embeddings_path = FUSION_DIR / "fused_song_embeddings.npy"
metadata_path = FUSION_DIR / "song_metadata.json"

# All three artifacts are read below, so check them together. A partial run of
# notebook 02 then fails with an actionable message instead of a bare
# np.load / read_text error on whichever file happens to be missing.
_missing_fusion_files = [
    p.name for p in (lyric_embeddings_path, fused_embeddings_path, metadata_path) if not p.exists()
]
if _missing_fusion_files:
    raise FileNotFoundError(
        "Run 02_audio_encoder_and_fusion.ipynb to generate fused song embeddings "
        f"(missing: {', '.join(_missing_fusion_files)})."
    )

lyric_embeddings = np.load(lyric_embeddings_path)
fused_embeddings = np.load(fused_embeddings_path)
song_metadata = json.loads(metadata_path.read_text(encoding="utf-8"))

print("Fused embeddings shape:", fused_embeddings.shape)


### Predict song emotion distributions using the classifier head

Runs the emotion classifier over each song's lyrics to obtain per-song emotion probabilities, saves these soft scores to disk, and takes the arg-max as each song's label id.

In [None]:
song_texts = list(song_metadata["titles"])
lyrics_texts = list(song_metadata.get("lyrics", []))

if not lyrics_texts:
    # If lyrics were not stored in metadata, fall back to re-reading the dataset
    import pandas as pd
    songs_df = pd.read_csv(PROJECT_ROOT / "datasets" / "song_features" / "songs_with_attributes_and_lyrics.csv")
    valid_mask = songs_df["lyrics"].fillna("").str.len() > 0
    lyrics_texts = songs_df.loc[valid_mask, "lyrics"].astype(str).tolist()
else:
    # Non-string entries (e.g. null lyrics) are classified as empty text.
    lyrics_texts = [lx if isinstance(lx, str) else "" for lx in lyrics_texts]

# song_label_ids are later used as row indices into fused_embeddings, so the
# lyrics must be row-aligned with the fused embeddings. The CSV fallback in
# particular is not guaranteed to match; fail loudly rather than mislabel songs.
if len(lyrics_texts) != fused_embeddings.shape[0]:
    raise ValueError(
        f"Got {len(lyrics_texts)} lyric texts but {fused_embeddings.shape[0]} fused song "
        "embeddings; song emotion labels would not line up with embedding rows."
    )

song_probs = encoder.predict_label(lyrics_texts, batch_size=16)
song_label_ids = song_probs.argmax(axis=1)
np.save(RETRIEVAL_DIR / "song_label_probs.npy", song_probs)


### Contrastive projection heads

Defines two MLP projection heads with separate (non-shared) weights, one for prompts and one for songs, that map both into a common `projection_dim`-dimensional space.

In [None]:
embedding_dim = fused_embeddings.shape[1]
# Prompt embeddings come from the text encoder, whose hidden size need not
# equal the fused song embedding width; size the prompt head from the data.
prompt_embedding_dim = prompt_arrays["train"]["embeddings"].shape[1]
projection_dim = 256


def build_projection_head(input_dim: int | None = None) -> nn.Sequential:
    """Build a 2-layer MLP mapping ``input_dim`` features to ``projection_dim``.

    Args:
        input_dim: Width of the input embeddings; defaults to ``embedding_dim``
            (the fused song embedding width) for backward compatibility.
    """
    in_features = embedding_dim if input_dim is None else input_dim
    return nn.Sequential(
        nn.Linear(in_features, projection_dim),
        nn.ReLU(),
        nn.Linear(projection_dim, projection_dim),
    )


prompt_projector = build_projection_head(prompt_embedding_dim)
song_projector = build_projection_head()


### Contrastive training loop

Groups prompt and song indices by label so we can sample matched pairs during contrastive training.

In [None]:
def _indices_by_label(label_values) -> Dict[int, List[int]]:
    """Map each integer label to the positions where it occurs, in order of first appearance."""
    grouped: Dict[int, List[int]] = {}
    for position, value in enumerate(label_values):
        key = int(value)
        if key in grouped:
            grouped[key].append(position)
        else:
            grouped[key] = [position]
    return grouped


# Training prompts grouped by their gold label id.
prompt_by_label: Dict[int, List[int]] = _indices_by_label(prompt_arrays["train"]["labels"])
# Songs grouped by the arg-max label predicted from their lyrics.
song_by_label: Dict[int, List[int]] = _indices_by_label(song_label_ids)

# Only labels with at least one prompt AND one song can form a matched pair.
labels_available = sorted(prompt_by_label.keys() & song_by_label.keys())
print("Labels with both prompts and songs:", labels_available)


Moves projectors to the active device, sets up the optimizer, and defines helper functions for batch sampling and the InfoNCE loss.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
prompt_projector = prompt_projector.to(device)
song_projector = song_projector.to(device)

optimizer = torch.optim.AdamW(list(prompt_projector.parameters()) + list(song_projector.parameters()), lr=1e-3)


def sample_batch(batch_size: int = 64, return_labels: bool = False):
    """Draw ``batch_size`` label-matched (prompt, song) pairs.

    For each pair, a label is drawn uniformly from ``labels_available``. Then one
    training prompt and one song carrying that label are drawn uniformly. Row
    ``i`` of the prompt batch is paired with row ``i`` of the song batch.

    Args:
        batch_size: Number of pairs to draw.
        return_labels: When True, also return the per-pair label ids so that
            ``info_nce_loss(..., labels=...)`` can treat every same-label pair
            in the batch as a positive rather than a negative.

    Returns:
        ``(prompt_batch, song_batch)`` float32 tensors on ``device``, plus an
        int64 label tensor when ``return_labels`` is True.
    """
    prompt_rows: List[int] = []
    song_rows: List[int] = []
    pair_labels: List[int] = []
    for _ in range(batch_size):
        label = random.choice(labels_available)
        prompt_rows.append(random.choice(prompt_by_label[label]))
        song_rows.append(random.choice(song_by_label[label]))
        pair_labels.append(label)
    # Gather all rows at once instead of stacking per-row copies.
    prompt_batch = torch.from_numpy(prompt_arrays["train"]["embeddings"][prompt_rows]).float().to(device)
    song_batch = torch.from_numpy(fused_embeddings[song_rows]).float().to(device)
    if return_labels:
        return prompt_batch, song_batch, torch.tensor(pair_labels, dtype=torch.long, device=device)
    return prompt_batch, song_batch


def info_nce_loss(
    prompt_z: torch.Tensor,
    song_z: torch.Tensor,
    temperature: float = 0.07,
    labels: torch.Tensor | None = None,
):
    """Symmetric InfoNCE loss over cosine similarities with in-batch negatives.

    Args:
        prompt_z: Projected prompt embeddings, one row per pair.
        song_z: Projected song embeddings, row-aligned with ``prompt_z``.
        temperature: Divisor applied to the cosine-similarity logits.
        labels: Optional per-row label ids. When None (default), only the
            diagonal pair is positive. Because batches are sampled by label, the
            default treats other same-label pairs as negatives. When given,
            every same-label pair is a positive and the target distribution is
            spread uniformly over them.

    Returns:
        Scalar loss: the mean of the prompt->song and song->prompt cross-entropies.
    """
    prompt_norm = F.normalize(prompt_z, dim=-1)
    song_norm = F.normalize(song_z, dim=-1)
    logits = prompt_norm @ song_norm.T
    logits = logits / temperature
    if labels is None:
        targets = torch.arange(logits.size(0), device=logits.device)
    else:
        labels = labels.to(logits.device)
        positives = (labels.unsqueeze(0) == labels.unsqueeze(1)).to(logits.dtype)
        # The positive mask is symmetric, so the same row-normalised targets
        # serve both directions.
        targets = positives / positives.sum(dim=1, keepdim=True)
    loss_prompt_to_song = F.cross_entropy(logits, targets)
    loss_song_to_prompt = F.cross_entropy(logits.T, targets)
    return (loss_prompt_to_song + loss_song_to_prompt) / 2


Implements the contrastive training loop that repeatedly samples matched pairs and optimizes the projection heads.

In [None]:
def train_contrastive(epochs: int = 10, steps_per_epoch: int = 100, batch_size: int = 64) -> List[float]:
    """Optimise both projection heads with the symmetric InfoNCE loss.

    Each step draws a fresh label-matched batch via ``sample_batch`` and takes
    one ``optimizer`` step on ``prompt_projector`` and ``song_projector``.

    Args:
        epochs: Number of epochs to run.
        steps_per_epoch: Optimizer steps per epoch (must be >= 1).
        batch_size: Pairs per step.

    Returns:
        Mean training loss for each epoch (also printed), for logging/plotting.

    Raises:
        ValueError: if ``steps_per_epoch`` < 1, since the per-epoch average
            would otherwise divide by zero.
    """
    if steps_per_epoch < 1:
        raise ValueError(f"steps_per_epoch must be >= 1, got {steps_per_epoch}")

    # Ensure training mode even if the heads were switched to eval() earlier.
    prompt_projector.train()
    song_projector.train()

    history: List[float] = []
    for epoch in range(epochs):
        running_loss = 0.0
        for _ in range(steps_per_epoch):
            prompt_batch, song_batch = sample_batch(batch_size)
            optimizer.zero_grad()
            prompt_z = prompt_projector(prompt_batch)
            song_z = song_projector(song_batch)
            loss = info_nce_loss(prompt_z, song_z)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / steps_per_epoch
        history.append(avg_loss)
        print(f"Epoch {epoch + 1}/{epochs} | loss={avg_loss:.4f}")
    return history

# Uncomment to train
# train_contrastive(epochs=15, steps_per_epoch=200, batch_size=128)


### Persist projection heads

Commented-out calls that save the trained projection heads to disk; uncomment them once training is complete.

In [None]:
# torch.save(prompt_projector.state_dict(), RETRIEVAL_DIR / "prompt_projector.pt")
# torch.save(song_projector.state_dict(), RETRIEVAL_DIR / "song_projector.pt")


### Build FAISS index

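Projects every fused song embedding with the song head, L2-normalizes the result, and stores it in an inner-product FAISS index (so scores are cosine similarities). Persists both the index and the projected embeddings as a NumPy cache.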

In [None]:
def build_retrieval_index(batch_size: int = 4096):
    """Project all fused song embeddings and write an inner-product FAISS index.

    Rows are L2-normalised with ``F.normalize``, the same normalisation used by
    ``info_nce_loss`` and ``recommend``, so inner product equals cosine
    similarity. Its epsilon also avoids NaN rows for zero-norm projections,
    which plain division by the norm would produce. Projection runs in chunks
    of ``batch_size`` rows to bound device memory.

    Writes ``faiss_song.index`` and ``projected_song_embeddings.npy`` under
    ``RETRIEVAL_DIR``.

    Args:
        batch_size: Songs projected per forward pass (must be >= 1).

    Raises:
        ImportError: if faiss is not installed.
        ValueError: if ``batch_size`` < 1.
    """
    if faiss is None:
        raise ImportError("faiss is required for indexing")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_songs = fused_embeddings.shape[0]
    # `or [0]` still runs one (empty) pass when there are no songs, so the
    # output keeps the projector's width instead of failing on concatenate.
    starts = range(0, num_songs, batch_size) or [0]
    chunks = []
    with torch.no_grad():
        for start in starts:
            song_tensor = torch.from_numpy(fused_embeddings[start:start + batch_size]).float().to(device)
            chunks.append(F.normalize(song_projector(song_tensor), dim=-1).cpu().numpy())
    # FAISS requires C-contiguous float32 input.
    projected = np.ascontiguousarray(np.concatenate(chunks, axis=0), dtype=np.float32)

    index = faiss.IndexFlatIP(projected.shape[1])
    index.add(projected)
    faiss.write_index(index, str(RETRIEVAL_DIR / "faiss_song.index"))
    np.save(RETRIEVAL_DIR / "projected_song_embeddings.npy", projected)
    print("Indexed", index.ntotal, "songs")

# build_retrieval_index()


### Prompt-to-song search helper

Provides helpers to reload the FAISS index and run prompt-to-song retrieval returning top-k recommendations.

In [None]:
def load_faiss_index(path: Path):
    """Read a FAISS index from ``path`` (a file written by ``faiss.write_index``).

    Raises:
        ImportError: if the optional faiss dependency is not installed.
    """
    if faiss is not None:
        return faiss.read_index(str(path))
    raise ImportError("Install faiss before using retrieval")


def recommend(prompt_text: str, top_k: int = 5, index=None):
    """Return the ``top_k`` songs whose projected embeddings best match a prompt.

    Args:
        prompt_text: Free-text user prompt.
        top_k: Maximum number of results; fewer are returned if the index holds
            fewer songs, and an empty list if ``top_k`` < 1.
        index: Optional pre-loaded FAISS index. Pass one to avoid re-reading
            ``faiss_song.index`` from disk on every call; by default it is
            loaded from ``RETRIEVAL_DIR``.

    Returns:
        List of dicts with ``score`` (cosine similarity), ``song_id``, ``title``
        and ``artists``, ordered best first.

    Raises:
        ImportError: if faiss is not installed.
    """
    if faiss is None:
        raise ImportError("Install faiss before using retrieval")
    if index is None:
        index = load_faiss_index(RETRIEVAL_DIR / "faiss_song.index")

    # Never ask FAISS for more neighbours than it holds.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    with torch.no_grad():
        prompt_embedding = encoder.encode([prompt_text])
        prompt_tensor = torch.from_numpy(prompt_embedding).float().to(device)
        projected_prompt = F.normalize(prompt_projector(prompt_tensor), dim=-1)

    query = projected_prompt.cpu().numpy().astype(np.float32)
    scores, indices = index.search(query, k)

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # FAISS pads missing results with -1, which would otherwise silently
        # index the LAST song in the metadata lists.
        if idx < 0:
            continue
        idx = int(idx)
        results.append({
            "score": float(score),
            "song_id": song_metadata["song_ids"][idx],
            "title": song_metadata["titles"][idx],
            "artists": song_metadata["artists"][idx],
        })
    return results

# Example invocation after training/indexing
# recommend("my dog passed", top_k=5)


### Validation ideas
- Evaluate retrieval quality by checking if prompts whose dominant emotion is *sadness* retrieve sad songs.
- Create qualitative plots comparing cosine similarity distributions for positive vs negative pairs.
- Consider additional supervision such as valence/arousal regression if labels become available.