# EraEx: ColBERT Embedding (Colab GPU)

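This notebook generates one dense, L2-normalized embedding per music track for each year partition of the processed data, and saves the vectors (with their track ids) to `data/embeddings/`.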

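- **Model**: colbert-ir/colbertv2.0 (token embeddings mean-pooled into a single vector per track)
- **Fallback**: sentence-transformers/all-MiniLM-L6-v2 (used when `USE_COLBERT = False` or the ColBERT model fails to load)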

**Requirements**: GPU runtime recommended (T4 or better). Without CUDA the code falls back to CPU, which is much slower.

In [None]:
# Install the project's Python dependencies.
# NOTE(review): the path is resolved against the current working directory,
# and Google Drive is only mounted in the next cell. On Colab, confirm that
# requirements.txt actually exists in the cwd when this cell runs.
%pip install -r requirements.txt

In [None]:
# Mount Google Drive so the project folder under /content/drive is reachable.
# Outside Colab, `google.colab` cannot be imported and this cell does nothing.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    pass  # not running on Colab

In [None]:
import torch

# Report which accelerator (if any) the embedding cells below will use.
if torch.cuda.is_available():
    print('CUDA available: True')
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    print('CUDA available: False')

In [None]:
from pathlib import Path
import numpy as np
import polars as pl
import os
from tqdm import tqdm

# Resolve the project root. On Colab it lives on the mounted Drive. Otherwise
# the notebook presumably runs from a direct subfolder of the project
# (e.g. notebooks/) — TODO confirm.
try:
    from google.colab import drive  # imported only to detect a Colab runtime
except ImportError:
    PROJECT_DIR = Path.cwd().parent
else:
    PROJECT_DIR = Path('/content/drive/MyDrive/EraEx')

READY_DIR = PROJECT_DIR.joinpath('data', 'processed', 'music_ready')  # input: year=<YYYY>/data.parquet
EMBEDDINGS_DIR = PROJECT_DIR.joinpath('data', 'embeddings')          # output: per-year .npy + id files
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

YEAR_RANGE = range(2012, 2019)  # 2012 through 2018 inclusive
USE_COLBERT = True  # Set False to use SBERT fallback

print(f'Project Dir: {PROJECT_DIR}')
print(f'Looking for data in: {READY_DIR}')

In [None]:
from sentence_transformers import SentenceTransformer

# Load the embedding model. ColBERT is tried first (if enabled). Any failure
# while importing or loading it falls back to SBERT.
# Sets the globals `model`, `EMBEDDER_TYPE` and (ColBERT only) `tokenizer`.
if USE_COLBERT:
    try:
        import warnings

        from transformers import AutoTokenizer, AutoModel, logging

        # Silence Hugging Face's own logger (e.g. unused-weight notices).
        logging.set_verbosity_error()

        MODEL_NAME = 'colbert-ir/colbertv2.0'

        # Suppress UserWarnings only while loading. A global filter would
        # also hide unrelated warnings for the rest of the session.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UserWarning)
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

            print(f'Loading {MODEL_NAME}...')
            print("(Note: Ignore 'UNEXPECTED linear.weight' warnings - normal for Dense mode)")

            # NOTE(review): AutoModel presumably builds the plain BERT encoder,
            # so ColBERT's projection layer (linear.weight) is not used and
            # embeddings come from the encoder's hidden states — confirm.
            model = AutoModel.from_pretrained(MODEL_NAME)

        model.to('cuda' if torch.cuda.is_available() else 'cpu')
        model.eval()  # inference only: disables dropout
        EMBEDDER_TYPE = 'colbert'
        print('✓ Model Loaded Successfully')
    except Exception as e:
        # Deliberate best-effort boundary: report and fall back rather than crash.
        print(f'ColBERT failed: {e}')
        print('Falling back to SBERT')
        USE_COLBERT = False

if not USE_COLBERT:
    model = SentenceTransformer('all-MiniLM-L6-v2')
    EMBEDDER_TYPE = 'sbert'
    print('Using SBERT: all-MiniLM-L6-v2')

In [None]:
def embed_sbert(texts, batch_size=256):
    """Encode *texts* with the global SentenceTransformer ``model``.

    Embeddings are L2-normalized by the encoder (``normalize_embeddings``) and
    returned as a float16 NumPy array to halve storage.
    """
    encode_options = dict(
        batch_size=batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    vectors = model.encode(texts, **encode_options)
    return vectors.astype(np.float16)

def embed_colbert(texts, batch_size=32, max_length=180):
    """Embed *texts* with the global ColBERT encoder via masked mean pooling.

    Each batch is tokenized (truncated to ``max_length`` tokens) with the global
    ``tokenizer`` and run through the global ``model``. The last-layer token
    vectors are averaged over attention-mask positions (padding excluded), then
    L2-normalized.

    Args:
        texts: sequence of strings to embed.
        batch_size: number of texts per forward pass.
        max_length: tokenizer truncation length (default 180, as before).

    Returns:
        float16 NumPy array with one row per input text. An empty input yields
        an array of shape ``(0, model.config.hidden_size)``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if len(texts) == 0:
        # np.vstack([]) raises ValueError, so return an empty matrix of the right width.
        return np.empty((0, model.config.hidden_size), dtype=np.float16)

    batches = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch = texts[start:start + batch_size]

        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            token_emb = model(**inputs).last_hidden_state
            mask = inputs['attention_mask'].unsqueeze(-1).float()

            summed = (token_emb * mask).sum(dim=1)
            # Clamp guards against an all-padding row (would divide by zero).
            counts = mask.sum(dim=1).clamp(min=1)
            pooled = summed / counts
            pooled = pooled / (pooled.norm(dim=1, keepdim=True) + 1e-9)

        # Downcast per batch: same values as converting at the end, but the
        # accumulated host buffer is half the size.
        batches.append(pooled.cpu().numpy().astype(np.float16))

    return np.vstack(batches)

In [None]:
def _print_ready_parent_listing():
    """Print the contents of READY_DIR's parent to help diagnose a path mismatch."""
    print("  > Debug: Contents of parent folder:")
    try:
        if READY_DIR.parent.exists():
            print(list(READY_DIR.parent.iterdir()))
        else:
            print(f"    Parent {READY_DIR.parent} does not exist!")
    except OSError as e:
        print(f"    Could not list parent: {e}")


def _write_atomically(path, write):
    """Call ``write(tmp_path)`` on a temporary sibling of *path*, then rename it to *path*.

    If the run is interrupted mid-write (e.g. a Colab disconnect), only the
    ``.tmp`` file is left behind. There is never a truncated file at *path*
    that the resume check in process_year would mistake for finished output.
    """
    tmp = path.with_name(path.name + '.tmp')
    write(tmp)
    tmp.replace(path)


def process_year(year):
    """Embed one year's tracks and save the vectors plus their track ids.

    Reads ``READY_DIR/year=<year>/data.parquet`` and embeds its
    ``doc_text_music`` column with the active embedder (EMBEDDER_TYPE).
    Writes ``embeddings_<year>.npy`` and ``ids_<year>.parquet`` to
    EMBEDDINGS_DIR. Row i of the array corresponds to row i of the ids file.

    A year is skipped when both output files already exist. Problems (missing
    input, unreadable parquet, missing columns) are reported by printing and
    returning early rather than raising.
    """
    data_path = READY_DIR / f'year={year}' / 'data.parquet'
    if not data_path.exists():
        print(f'{year}: Not found at {data_path}')
        # Show the folder layout once (for the first configured year) instead
        # of repeating it for every missing year.
        if year == YEAR_RANGE[0]:
            _print_ready_parent_listing()
        return

    emb_path = EMBEDDINGS_DIR / f'embeddings_{year}.npy'
    ids_path = EMBEDDINGS_DIR / f'ids_{year}.parquet'

    if emb_path.exists() and ids_path.exists():
        print(f'{year}: Already exists, skipping')
        return

    print(f'\nProcessing {year}...')
    try:
        df = pl.read_parquet(data_path)
    except Exception as e:
        print(f"Error reading parquet {data_path}: {e}")
        return

    print(f'  Rows: {df.height:,}')

    # Validate the schema before spending GPU time, not after.
    missing = {'doc_text_music', 'track_id'} - set(df.columns)
    if missing:
        print(f'  Missing column(s) {sorted(missing)} in {data_path}, skipping')
        return

    # Falsy texts (None/empty) become '' so every row gets an embedding and
    # row order stays aligned with track_id.
    texts = [t if t else '' for t in df['doc_text_music'].to_list()]

    print(f'  Embedding with {EMBEDDER_TYPE}...')
    if EMBEDDER_TYPE == 'colbert':
        embeddings = embed_colbert(texts)
    else:
        embeddings = embed_sbert(texts)

    def _save_npy(tmp_path):
        # Write through a file handle: given a path not ending in .npy,
        # np.save would append the suffix and the rename would miss it.
        with open(tmp_path, 'wb') as fh:
            np.save(fh, embeddings)

    _write_atomically(emb_path, _save_npy)
    print(f'  Saved: {emb_path}')

    _write_atomically(ids_path, df.select(['track_id']).write_parquet)
    print(f'  Saved: {ids_path}')

In [None]:
# Embed every configured year (process_year skips years whose outputs already
# exist), then list the embedding files that are now on disk.
for year in YEAR_RANGE:
    process_year(year)

print('\n' + '=' * 50)
print('EMBEDDING COMPLETE')
print('=' * 50)

for f in sorted(EMBEDDINGS_DIR.glob('*.npy')):
    # Memory-map instead of loading: only the header is needed for the shape,
    # so full embedding matrices are not pulled into RAM just to report it.
    emb = np.load(f, mmap_mode='r')
    print(f'{f.name}: {emb.shape}')