# EraEx: Build FAISS + BM25 Indexes (Colab)

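This notebook builds:
1. FAISS IVF+PQ indexes for dense retrieval (one index file per year)
2. A single BM25 index for sparse retrieval, covering the documents of all years

It ends with a quick test search against one of the FAISS indexes.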

In [None]:
# Install the notebook's dependencies into the running kernel's environment.
# NOTE(review): requirements.txt is resolved relative to the current working
# directory, and this cell runs before Google Drive is mounted in the next
# cell — confirm the file is actually present at that location.
%pip install -r requirements.txt

In [None]:
# Colab-only step: mount Google Drive at /content/drive so files stored on
# Drive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from pathlib import Path
import numpy as np
import polars as pl
import faiss
import pickle

# Project root on the mounted Drive; every path below hangs off it.
PROJECT_DIR = Path('/content/drive/MyDrive/EraEx')
# Inputs: per-year embedding artifacts (embeddings_{year}.npy, ids_{year}.parquet).
EMBEDDINGS_DIR = PROJECT_DIR / 'data' / 'embeddings'
# Inputs: per-year document tables (year={year}/data.parquet).
READY_DIR = PROJECT_DIR / 'data' / 'processed' / 'music_ready'
# Outputs: the FAISS and BM25 index files written by this notebook.
INDEXES_DIR = PROJECT_DIR / 'data' / 'indexes'

# Create the output folder up front; exist_ok=True makes the cell safe to re-run.
INDEXES_DIR.mkdir(parents=True, exist_ok=True)

# Years to index: 2012 through 2018 (the upper bound of range is exclusive).
YEAR_RANGE = range(2012, 2019)

# FAISS IndexIVFPQ settings used by build_faiss_index.
N_LIST = 256  # upper bound on the number of inverted lists (lowered for small years)
M_PQ = 48  # upper bound on the number of PQ sub-quantizers
N_BITS = 8  # bits per PQ code

## 1. Build FAISS Indexes

In [None]:
def build_faiss_index(year, seed=42):
    """Build and save an IVF+PQ FAISS index for one year's embeddings.

    Loads ``embeddings_{year}.npy`` from EMBEDDINGS_DIR, trains an
    ``IndexIVFPQ`` on a random sample of at most 100,000 vectors, adds every
    vector, and writes the result to ``faiss_{year}.index`` in INDEXES_DIR.

    Nothing is built (a message is printed instead) when the embeddings file
    is missing, when the index file already exists, or when there are fewer
    vectors than PQ centroids (``2 ** N_BITS``), since the codebooks cannot be
    trained on so few points.

    Args:
        year: Year whose embeddings should be indexed.
        seed: Seed for the training-sample RNG so repeated builds draw the
            same sample. Pass ``None`` for a fresh random sample each time.

    Returns:
        None. The index is written to disk as a side effect.
    """
    emb_path = EMBEDDINGS_DIR / f'embeddings_{year}.npy'
    idx_path = INDEXES_DIR / f'faiss_{year}.index'

    if not emb_path.exists():
        print(f'{year}: No embeddings')
        return

    if idx_path.exists():
        print(f'{year}: Index exists, skipping')
        return

    print(f'\nBuilding FAISS index for {year}...')
    # ascontiguousarray only copies when the file is not already C-contiguous
    # float32, unlike astype() which always allocates a second full array.
    embeddings = np.ascontiguousarray(np.load(emb_path), dtype=np.float32)
    n_vectors, dim = embeddings.shape
    print(f'  Vectors: {n_vectors:,} x {dim}')

    # Each PQ codebook has 2**N_BITS centroids and k-means needs at least that
    # many training points. This guard also covers empty files and the
    # n_vectors < 50 case, where n_vectors // 50 would give zero lists.
    min_train_points = 2 ** N_BITS
    if n_vectors < min_train_points:
        print(f'  {year}: Only {n_vectors:,} vectors, need at least '
              f'{min_train_points:,} to train, skipping')
        return

    n_train = min(n_vectors, 100000)
    if n_train < n_vectors:
        rng = np.random.default_rng(seed)
        train_idx = rng.choice(n_vectors, n_train, replace=False)
        train_data = embeddings[train_idx]
    else:
        # Training on everything: no need to materialise a shuffled copy.
        train_data = embeddings

    # Aim for roughly 50+ vectors per inverted list, but never fewer than 1 list.
    n_list_actual = max(1, min(N_LIST, n_vectors // 50))
    # FAISS requires dim to be divisible by the number of sub-quantizers, so
    # take the largest divisor of dim that does not exceed M_PQ.
    m_pq_actual = next(
        (m for m in range(min(M_PQ, dim), 0, -1) if dim % m == 0),
        1,
    )

    # NOTE(review): the coarse quantizer is inner-product, but no metric is
    # passed to IndexIVFPQ, so FAISS's default metric applies (METRIC_L2 per
    # the FAISS docs). Confirm whether METRIC_INNER_PRODUCT was intended;
    # left unchanged because consumers of the stored index are not visible here.
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFPQ(quantizer, dim, n_list_actual, m_pq_actual, N_BITS)

    print(f'  Training (n_list={n_list_actual}, m_pq={m_pq_actual})...')
    index.train(train_data)

    print('  Adding vectors...')
    index.add(embeddings)

    # Write to a temporary name and rename afterwards, so an interrupted run
    # cannot leave a truncated file that the "Index exists" check would accept.
    tmp_path = idx_path.with_name(idx_path.name + '.tmp')
    faiss.write_index(index, str(tmp_path))
    tmp_path.replace(idx_path)
    print(f'  Saved: {idx_path}')

In [None]:
# Build one index per configured year. Years with missing embeddings or an
# existing index file are reported and skipped inside build_faiss_index.
for year in YEAR_RANGE:
    build_faiss_index(year)

print('\nFAISS indexes complete!')

## 2. Build BM25 Index

In [None]:
import bm25s

def build_bm25_index():
    """Build one BM25 index over the documents of every year and pickle it.

    For each year in YEAR_RANGE, reads ``year={year}/data.parquet`` under
    READY_DIR and collects the ``doc_text_music`` texts (missing/empty values
    become ``''``) together with the ``track_id`` values as strings. The
    corpus is tokenized with ``bm25s.tokenize`` and indexed with
    ``bm25s.BM25``.

    The result is written to ``bm25_index.pkl`` in INDEXES_DIR as a dict with
    the keys ``'bm25'``, ``'doc_ids'`` and ``'corpus_tokens'``. Nothing is
    built when that file already exists or when no documents were found.

    Returns:
        None. The index is written to disk as a side effect.
    """
    bm25_path = INDEXES_DIR / 'bm25_index.pkl'

    if bm25_path.exists():
        print('BM25 index exists, skipping')
        return

    print('Loading all documents for BM25...')
    all_docs = []
    all_ids = []

    for year in YEAR_RANGE:
        data_path = READY_DIR / f'year={year}' / 'data.parquet'
        if not data_path.exists():
            # Report the gap instead of skipping silently.
            print(f'  {year}: No data')
            continue

        # Only these two columns are used, so do not load the rest of the table.
        df = pl.read_parquet(data_path, columns=['doc_text_music', 'track_id'])
        texts = df['doc_text_music'].to_list()
        ids = df['track_id'].to_list()

        # Positions in all_docs and all_ids must stay aligned: row i of the
        # BM25 corpus is identified by all_ids[i].
        all_docs.extend(t if t else '' for t in texts)
        all_ids.extend(str(i) for i in ids)

        print(f'  {year}: {len(texts):,} docs')

    print(f'\nTotal documents: {len(all_docs):,}')

    if not all_docs:
        # Do not persist an empty index: the "exists" check above would then
        # skip the real build on every later run.
        print('No documents found, BM25 index not built')
        return

    print('Building BM25 index...')
    corpus_tokens = bm25s.tokenize(all_docs)

    bm25 = bm25s.BM25()
    bm25.index(corpus_tokens)

    # Dict layout is kept unchanged; whatever loads this file lives outside
    # this notebook.
    index_data = {
        'bm25': bm25,
        'doc_ids': all_ids,
        'corpus_tokens': corpus_tokens,
    }

    # Write to a temporary name and rename afterwards, so an interrupted run
    # cannot leave a truncated pickle that the "exists" check would accept.
    tmp_path = bm25_path.with_name(bm25_path.name + '.tmp')
    with open(tmp_path, 'wb') as f:
        pickle.dump(index_data, f)
    tmp_path.replace(bm25_path)

    print(f'Saved: {bm25_path}')

In [None]:
# Build (or reuse) the sparse index, then report everything found in INDEXES_DIR.
build_bm25_index()

banner = '=' * 50
print('\n' + banner)
print('ALL INDEXES COMPLETE')
print(banner)

# One line per entry, in sorted path order, with its size in MB (1 MB = 1e6 bytes).
for artifact in sorted(INDEXES_DIR.glob('*')):
    megabytes = artifact.stat().st_size / 1e6
    print(f'{artifact.name}: {megabytes:.1f} MB')

## 3. Test Search

In [None]:
from sentence_transformers import SentenceTransformer

# Smoke test: embed one query and look it up in a single year's FAISS index.
test_query = 'i miss my ex'
print(f'Test query: "{test_query}"')

# NOTE(review): the query encoder must be the model that produced the stored
# embeddings (same dimension and normalization) — confirm against the
# embedding notebook.
model = SentenceTransformer('all-MiniLM-L6-v2')
q_emb = model.encode(test_query, normalize_embeddings=True).astype(np.float32).reshape(1, -1)

year = 2015
top_k = 5
idx_path = INDEXES_DIR / f'faiss_{year}.index'
if idx_path.exists():
    index = faiss.read_index(str(idx_path))
    # Number of inverted lists scanned per query.
    index.nprobe = 10
    scores, indices = index.search(q_emb, top_k)

    # Row i of this table is assumed to identify vector i of the index — TODO confirm.
    ids_df = pl.read_parquet(EMBEDDINGS_DIR / f'ids_{year}.parquet')
    track_ids = ids_df['track_id'].to_list()
    if len(track_ids) != index.ntotal:
        print(f'  Warning: {len(track_ids):,} ids but {index.ntotal:,} indexed vectors')

    print(f'\nTop {top_k} from {year}:')
    for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
        if idx < 0:
            # FAISS pads with -1 when fewer than top_k neighbours are found.
            # Without this check, track_ids[-1] would print the last track
            # as if it were a hit.
            print(f'  {rank}. (no result)')
            continue
        print(f'  {rank}. Score: {score:.4f}, ID: {track_ids[idx]}')
else:
    print(f'\nNo FAISS index for {year}: {idx_path}')