# EraEx Re-Embed + Re-Index (Colab, Local Copy Workflow)

This notebook rebuilds only:
- `embeddings.npy`
- `faiss_index.bin`

It is designed for **Google Colab** and **Google Drive**:
1. Mount Drive
2. Copy required project files to local Colab storage (`/content/...`) for faster reads/writes
3. Rebuild embeddings + FAISS index locally
4. Sync outputs back to Drive
5. Optionally zip local outputs for easier download

## Inputs required (already prepared)
- `data/indexes/id_map.json`
- `data/indexes/metadata.json`

If you changed CSV enrichment/metadata fields, rebuild `id_map.json` + `metadata.json` first (in your full pipeline), then run this notebook.


In [None]:
# 1) Install dependencies (Colab)
# Upgrade pip itself first, then install the extra packages. `faiss-cpu`
# provides the `faiss` module imported in Step 3; sentence-transformers and
# python-dotenv are presumably needed by the project's `src` package
# (NOTE(review): confirm against requirements.txt).
!pip -q install --upgrade pip
!pip -q install sentence-transformers faiss-cpu python-dotenv

import os
# setdefault only writes the variable when it is not already set, so a value
# exported by the user beforehand is kept.
# NOTE(review): presumably this silences the tokenizers fork/parallelism warning — confirm.
os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')


In [None]:
# 2) Mount Google Drive + configure paths
from google.colab import drive
from pathlib import Path
import shutil
import json
import sys
import gc
import time
import re

# Make Google Drive visible inside the Colab VM.
DRIVE_MOUNT = '/content/drive'
drive.mount(DRIVE_MOUNT)

# Drive locations probed (in order) for the repo; edit if your folder is named differently.
CANDIDATE_PROJECT_ROOTS = [
    '/content/drive/MyDrive/Team4_CPSC-5830-01-Capstone-Project',
    '/content/drive/MyDrive/EraEx',
]

# The first candidate that exists becomes the project root; None when nothing matched.
DRIVE_PROJECT_ROOT = next(
    (Path(root) for root in CANDIDATE_PROJECT_ROOTS if Path(root).exists()),
    None,
)
if DRIVE_PROJECT_ROOT is None:
    raise FileNotFoundError(
        'Could not auto-detect project root in Drive. Set DRIVE_PROJECT_ROOT manually.'
    )

# Local (Colab disk) mirror of the project layout, plus the matching Drive index folder.
LOCAL_ROOT = Path('/content/eraex_reembed_work')
LOCAL_PROJECT_ROOT = LOCAL_ROOT / 'project'
LOCAL_INDEX_DIR = LOCAL_PROJECT_ROOT / 'data' / 'indexes'
DRIVE_INDEX_DIR = DRIVE_PROJECT_ROOT / 'data' / 'indexes'

print('Drive project root:', DRIVE_PROJECT_ROOT)
print('Local project root:', LOCAL_PROJECT_ROOT)

# Start from an empty local workspace so nothing stale survives from a previous run.
if LOCAL_ROOT.exists():
    shutil.rmtree(LOCAL_ROOT)
LOCAL_INDEX_DIR.mkdir(parents=True, exist_ok=True)

# Only the pieces needed for re-embed/re-index are copied (lighter than the whole repo).
COPY_DIRS = ['src', 'config']
COPY_FILES = [
    'requirements.txt',
    'data/indexes/id_map.json',
    'data/indexes/metadata.json',
]
OPTIONAL_FILES = [
    'data/indexes/embeddings.progress.json',  # resume support if present
]

# Required directories: abort as soon as one is missing in Drive.
for rel_dir in COPY_DIRS:
    drive_dir = DRIVE_PROJECT_ROOT / rel_dir
    if not drive_dir.exists():
        raise FileNotFoundError(f'Missing required directory in Drive: {drive_dir}')
    shutil.copytree(drive_dir, LOCAL_PROJECT_ROOT / rel_dir)
    print(f'Copied dir: {rel_dir}')

# Required files: the local parent folder is created before the Drive file is checked.
for rel_file in COPY_FILES:
    drive_file = DRIVE_PROJECT_ROOT / rel_file
    local_file = LOCAL_PROJECT_ROOT / rel_file
    local_file.parent.mkdir(parents=True, exist_ok=True)
    if not drive_file.exists():
        raise FileNotFoundError(f'Missing required file in Drive: {drive_file}')
    shutil.copy2(drive_file, local_file)
    print(f'Copied file: {rel_file}')

# Optional files: copied when present, silently skipped otherwise.
for rel_file in OPTIONAL_FILES:
    drive_file = DRIVE_PROJECT_ROOT / rel_file
    if not drive_file.exists():
        continue
    local_file = LOCAL_PROJECT_ROOT / rel_file
    local_file.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(drive_file, local_file)
    print(f'Copied optional file: {rel_file}')

# Put the local copy first on sys.path so `src.*` imports resolve to it.
sys.path.insert(0, str(LOCAL_PROJECT_ROOT))


In [None]:
# 3) Imports (project + runtime libs)
import numpy as np
import torch
import faiss
from tqdm.auto import tqdm

from src.core.text_embeddings import embedding_handler
from src.core.media_metadata import build_track_embedding_text_context_first


In [None]:
# 4) Load id_map + metadata from local copy
INDEX_DIR = str(LOCAL_INDEX_DIR)
ID_MAP_PATH = LOCAL_INDEX_DIR / 'id_map.json'
METADATA_PATH = LOCAL_INDEX_DIR / 'metadata.json'

# Both files are read as UTF-8 text and parsed as JSON.
ids = json.loads(ID_MAP_PATH.read_text(encoding='utf-8'))
metadata = json.loads(METADATA_PATH.read_text(encoding='utf-8'))

id_count = len(ids)
metadata_count = len(metadata)
print(f'id_map count: {id_count}')
print(f'metadata count: {metadata_count}')
# A size difference is only reported, not treated as an error.
if id_count != metadata_count:
    print('WARNING: id_map and metadata counts differ. Embedding loop uses id_map order only.')

# Optional quick sanity preview: first id and up to 20 of its metadata keys (sorted).
if ids:
    first_id = str(ids[0])
    print('Sample ID:', first_id)
    sample_meta = metadata.get(first_id) or {}
    print('Sample metadata keys:', sorted(sample_meta.keys())[:20])


In [None]:
# 5) Context-first embedding text builder (shared with runtime recommendation path)
def _suggest_embed_batch_size(vram_gb):
    """Return an embedding batch size for a GPU with *vram_gb* GB of memory.

    Tiers are checked from largest to smallest; anything below 10 GB
    (including 0) gets the smallest batch size, 32.
    """
    # (minimum VRAM in GB, batch size), ordered from the largest tier down.
    tiers = (
        (40, 256),
        (24, 160),
        (16, 96),
        (10, 64),
    )
    for min_vram_gb, batch_size in tiers:
        if vram_gb >= min_vram_gb:
            return batch_size
    return 32


def _build_text_for_track(track_id):
    """Build the embedding text for one track.

    The track's entry is looked up in the module-level ``metadata`` mapping
    under ``str(track_id)`` (an empty dict is used when the key is absent)
    and handed to ``build_track_embedding_text_context_first``.
    """
    lookup_key = str(track_id)
    track_meta = metadata.get(lookup_key, {})
    return build_track_embedding_text_context_first(track_meta)


In [None]:
# 6) Embedding config + model setup
# Hardware probe: GPU name / total VRAM when CUDA is available, CPU fallbacks otherwise.
HAS_CUDA = torch.cuda.is_available()
GPU_NAME = torch.cuda.get_device_name(0) if HAS_CUDA else 'CPU'
VRAM_GB = float(torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)) if HAS_CUDA else 0.0
CPU_COUNT = os.cpu_count() or 4

# Batch size follows the VRAM tiers on GPU and is fixed at 32 on CPU.
BATCH_SIZE = _suggest_embed_batch_size(VRAM_GB) if HAS_CUDA else 32
# Rows encoded between two checkpoint writes: 220 batches, but never fewer than 22000 rows.
EMBED_CHUNK_SIZE = max(BATCH_SIZE * 220, 22000)
# Assigned to model.max_seq_length below (NOTE(review): presumably a token limit — confirm).
MAX_SEQ_LENGTH = 320
EMBEDDINGS_PATH = LOCAL_INDEX_DIR / 'embeddings.npy'
EMBED_PROGRESS_PATH = LOCAL_INDEX_DIR / 'embeddings.progress.json'

if HAS_CUDA:
    # Allow TF32 for matmul/cuDNN kernels (speed over full float32 precision).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision('high')
    except Exception as exc:
        # Best effort (e.g. the function is missing on older torch builds):
        # keep going, but make the skipped setting visible.
        print(f'WARNING: could not set float32 matmul precision: {exc!r}')

model = embedding_handler.load_model()
if model is None:
    # Nothing to configure; the tweaks below are skipped.
    print('WARNING: embedding_handler.load_model() returned None; '
          'max_seq_length / fp16 settings were not applied.')
else:
    try:
        model.max_seq_length = int(MAX_SEQ_LENGTH)
    except Exception as exc:
        # Best effort: report instead of silently keeping the model's own limit.
        print(f'WARNING: could not set model.max_seq_length={MAX_SEQ_LENGTH}: {exc!r}')
    if HAS_CUDA:
        try:
            # Half precision on GPU; on failure the model keeps its current dtype.
            model.half()
        except Exception as exc:
            print(f'WARNING: could not switch model to half precision: {exc!r}')

print(f'Device: {GPU_NAME}')
print(f'VRAM: {VRAM_GB:.1f} GB | CPU cores: {CPU_COUNT}')
print(f'Embedding batch size: {BATCH_SIZE}')
print(f'Embedding chunk size: {EMBED_CHUNK_SIZE}')
print(f'Tracks to encode: {len(ids)}')
# Show the first 280 characters of the text that will be embedded for the first track.
if ids:
    preview = _build_text_for_track(ids[0])
    print(f'Sample text: {preview[:280]}...')


In [None]:
# 7) Generate / resume embeddings.npy (local Colab disk)
def _load_resume_index(progress_path):
    """Return the row index at which embedding should resume.

    Reads ``next_index`` from the JSON file at *progress_path* and clamps it
    to the range ``0 .. len(ids)`` (``ids`` is the module-level id list).
    Returns 0 when the file does not exist or when anything goes wrong while
    reading or interpreting it.
    """
    if not progress_path.exists():
        return 0
    try:
        with open(progress_path, 'r', encoding='utf-8') as handle:
            saved_state = json.load(handle)
        saved_row = int(saved_state.get('next_index', 0))
        # Clamp: never past the end of ids, never negative.
        capped_row = min(saved_row, len(ids))
        return capped_row if capped_row > 0 else 0
    except Exception:
        # Unreadable or malformed progress means "start from the beginning".
        return 0


def _save_resume_index(progress_path, next_index):
    """Write the resume checkpoint as JSON to *progress_path*.

    The file records the next row to encode (``next_index``), the current
    length of the module-level ``ids`` list (``total``) and a timestamp
    formatted as ``YYYY-MM-DD HH:MM:SS`` (``updated_at``).
    """
    checkpoint = dict(
        next_index=int(next_index),
        total=int(len(ids)),
        updated_at=time.strftime('%Y-%m-%d %H:%M:%S'),
    )
    serialized = json.dumps(checkpoint, ensure_ascii=False, indent=2)
    with open(progress_path, 'w', encoding='utf-8') as out:
        out.write(serialized)

# Fail early with a clear message instead of an IndexError on ids[0] below.
if not ids:
    raise RuntimeError('No ids found in id_map.json')

# Probe dimension: encode one real track text to learn the embedding width.
probe_vec = embedding_handler.encode([_build_text_for_track(ids[0])], batch_size=1)
d = int(probe_vec.shape[1])
print('Embedding dim:', d)
del probe_vec
gc.collect()
if HAS_CUDA:
    torch.cuda.empty_cache()

expected_shape = (len(ids), d)
resume_idx = _load_resume_index(EMBED_PROGRESS_PATH)

# Reuse an existing embeddings.npy only when it matches the expected layout.
embeddings_mm = None
if EMBEDDINGS_PATH.exists():
    embeddings_mm = np.load(EMBEDDINGS_PATH, mmap_mode='r+')
    if embeddings_mm.shape != expected_shape or embeddings_mm.dtype != np.float32:
        print('Existing embeddings shape/dtype mismatch. Recreating memmap...')
        embeddings_mm = None  # drop the mapping before deleting the file
        EMBEDDINGS_PATH.unlink(missing_ok=True)
elif resume_idx:
    # Step 2 copies embeddings.progress.json from Drive but not embeddings.npy,
    # so saved progress can refer to rows that do not exist locally. Resuming
    # from it would leave those rows zero-filled in the new file.
    print(
        f'Ignoring saved progress (next_index={resume_idx}): '
        'embeddings.npy not found. Starting from row 0.'
    )

if embeddings_mm is None:
    embeddings_mm = np.lib.format.open_memmap(
        EMBEDDINGS_PATH, mode='w+', dtype=np.float32, shape=expected_shape
    )
    # A fresh file has no encoded rows: restart, and reset the checkpoint right
    # away so an interruption before the first chunk cannot leave stale progress.
    resume_idx = 0
    _save_resume_index(EMBED_PROGRESS_PATH, 0)

if resume_idx:
    print(f'Resuming embeddings from row {resume_idx}...')

for start in tqdm(range(resume_idx, len(ids), EMBED_CHUNK_SIZE), desc='Embedding chunks'):
    end = min(start + EMBED_CHUNK_SIZE, len(ids))
    chunk_ids = ids[start:end]
    chunk_texts = [_build_text_for_track(track_id) for track_id in chunk_ids]
    chunk_embeddings = embedding_handler.encode(
        chunk_texts,
        batch_size=BATCH_SIZE,
        show_progress_bar=False,
        normalize_embeddings=True,
    ).astype(np.float32)
    embeddings_mm[start:end] = chunk_embeddings
    # Flush the rows to disk before recording them as done in the checkpoint.
    embeddings_mm.flush()
    _save_resume_index(EMBED_PROGRESS_PATH, end)

    # Release per-chunk buffers (and cached GPU memory) before the next chunk.
    del chunk_texts, chunk_embeddings
    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()

# Final load check: read the finished file back as a regular in-memory array.
embeddings = np.load(EMBEDDINGS_PATH)
print(f'Saved embeddings.npy shape={embeddings.shape} at {EMBEDDINGS_PATH}')


In [None]:
# 8) Build FAISS index (inner product) from embeddings.npy
# Sanity checks: there must be ids, and exactly one embedding row per id.
if len(ids) == 0:
    raise RuntimeError('No ids found in id_map.json')
row_count, id_count = embeddings.shape[0], len(ids)
if row_count != id_count:
    raise RuntimeError(f'Count mismatch: embeddings={row_count} vs id_map={id_count}')

faiss_path = LOCAL_INDEX_DIR / 'faiss_index.bin'

# Flat inner-product index over all rows.
# NOTE(review): Step 7 requests normalized embeddings, so inner product
# presumably behaves as cosine similarity — confirm.
d = int(embeddings.shape[1])
index = faiss.IndexFlatIP(d)
index.add(embeddings)

# Write the index next to embeddings.npy in the local index folder.
faiss.write_index(index, str(faiss_path))
print(f'Saved faiss_index.bin ntotal={index.ntotal}, dim={d} at {faiss_path}')


In [None]:
# 9) Sync outputs back to Drive + prepare local zip for download
import zipfile

from google.colab import files

# Project-relative paths of the artifacts this notebook writes.
SYNC_OUTPUTS = [
    'data/indexes/embeddings.npy',
    'data/indexes/embeddings.progress.json',
    'data/indexes/faiss_index.bin',
]

# Outputs that are missing locally are skipped for both the Drive sync and the zip.
existing_outputs = [rel for rel in SYNC_OUTPUTS if (LOCAL_PROJECT_ROOT / rel).exists()]

for rel in existing_outputs:
    src = LOCAL_PROJECT_ROOT / rel
    dst = DRIVE_PROJECT_ROOT / rel
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
    print(f'Synced -> Drive: {rel}')

# Create a local zip containing only the regenerated index artifacts (easy Colab download).
# Archiving the whole data/indexes folder would also pack the id_map.json and
# metadata.json copied in Step 2, so the files are added one by one, keeping
# their project-relative paths inside the archive.
zip_base = '/content/eraex_reembed_outputs'
zip_path = f'{zip_base}.zip'
with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for rel in existing_outputs:
        archive.write(LOCAL_PROJECT_ROOT / rel, arcname=rel)
print('Local zip ready:', zip_path)
print('You can download it with: files.download(zip_path)')

# Uncomment if you want immediate browser download (large files may take a while)
# files.download(zip_path)


## Notes
- This notebook **does not rebuild** `id_map.json` or `metadata.json`.
- It reuses your existing metadata and only regenerates embeddings + FAISS index.
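- If you change embedding text logic again, first delete `data/indexes/embeddings.progress.json` (the local copy, and the copy in Drive if present, since Step 2 copies it back in), then rerun **Step 7** and **Step 8**. Otherwise Step 7 resumes from the saved `next_index` and skips rows that were already encoded.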
