In [None]:
"""
Google Colab Notebook for Mass Embedding Generation.

This script processes the 'knowledge_chunks.jsonl' file using the BGE-M3 model.
It implements batching and checkpointing to handle large datasets on limited hardware (T4/L4 GPUs).

Environment Variables / Secrets Required:
- DRIVE_FOLDER_NAME: The name of your DVC remote folder in Google Drive.
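- DVC_FILE_HASH: The 32-character MD5 hash of the knowledge_chunks file as stored in the DVC remote (the `md5:` value in the file's `.dvc` / `dvc.lock` entry; it may also appear in `dvc push` logs).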
"""

import os
import sys

In [None]:
# --- Helper for Secrets Management ---
def get_secret(name: str) -> "str | None":
    """
    Retrieve a configuration value by name.

    Google Colab Secrets are consulted first. If they are unavailable (not running
    inside Colab) or the lookup fails or yields an empty value, the value is read
    from os.environ instead.

    Args:
        name: Name of the Colab secret / environment variable.

    Returns:
        The value as a string, or None if it is set in neither place.
    """
    try:
        from google.colab import userdata
    except ImportError:
        # Not running inside Colab: environment variables are the only source.
        return os.getenv(name)

    try:
        value = userdata.get(name)
    except Exception:
        # Treat any failure of the Colab lookup (e.g. secret not defined or notebook
        # access not granted) as "not set in Colab" and fall back to the environment.
        value = None

    # An empty Colab secret should not shadow a real environment variable.
    return value if value else os.getenv(name)

In [None]:
# --- Configuration & Validation ---
# Retrieve configuration from environment/secrets
DRIVE_FOLDER_NAME = get_secret("DRIVE_FOLDER_NAME")
FILE_HASH = get_secret("DVC_FILE_HASH")

# Pasted secrets often carry a stray trailing space/newline. Both values are later
# interpolated verbatim into a filesystem path, so normalise them here.
if DRIVE_FOLDER_NAME:
    DRIVE_FOLDER_NAME = DRIVE_FOLDER_NAME.strip()
if FILE_HASH:
    FILE_HASH = FILE_HASH.strip()

# Validate configuration.
# Raise an ordinary exception instead of calling sys.exit(). In a notebook, sys.exit()
# surfaces as SystemExit plus an IPython "To exit..." warning rather than a normal
# error. Raising also matches the FileNotFoundError used by the path-setup cell.
if not DRIVE_FOLDER_NAME or not FILE_HASH:
    print("❌ ERROR: Missing configuration.")
    print("Please set 'DRIVE_FOLDER_NAME' and 'DVC_FILE_HASH' in Colab Secrets or os.environ.")
    raise RuntimeError(
        "Missing configuration: DRIVE_FOLDER_NAME and/or DVC_FILE_HASH is not set."
    )

# The path-setup cell splits the hash into <first 2 chars>/<rest> to locate the object
# under files/md5/, which only makes sense for a plain 32-character lowercase-hex MD5.
if len(FILE_HASH) != 32 or any(c not in "0123456789abcdef" for c in FILE_HASH):
    raise ValueError(
        "DVC_FILE_HASH does not look like an MD5 hash "
        f"(expected 32 lowercase hex characters), got {FILE_HASH!r}."
    )

print(f"✅ Configuration loaded for Drive Folder: {DRIVE_FOLDER_NAME}")
print(f"✅ Target File Hash: {FILE_HASH[:8]}...")

In [None]:
# --- Mount Google Drive ---
# Make Google Drive available under /content/drive. force_remount=True re-mounts it
# even when this cell has already been run in the current session.
import google.colab.drive as drive

drive.mount('/content/drive', force_remount=True)

In [None]:
# --- Dependencies ---
!pip install sentence-transformers tqdm numpy torch

In [None]:
# --- Path Setup ---
# DVC remote layout for a tracked file: <remote>/files/md5/<hash[:2]>/<hash[2:]>
file_hash_prefix, file_hash_suffix = FILE_HASH[:2], FILE_HASH[2:]
dvc_file_path = (
    f"/content/drive/MyDrive/{DRIVE_FOLDER_NAME}"
    f"/files/md5/{file_hash_prefix}/{file_hash_suffix}"
)

# Fail fast when the DVC object is not where the configuration says it should be.
if not os.path.exists(dvc_file_path):
    raise FileNotFoundError(
        f"ERROR: Could not find file at {dvc_file_path}. "
        "Please verify your DRIVE_FOLDER_NAME and DVC_FILE_HASH secrets."
    )

print(f"SUCCESS: Data file found at:\n{dvc_file_path}")
CHUNKS_FILE = dvc_file_path

# Embeddings and checkpoints live in an "outputs" folder inside the same Drive folder.
OUTPUT_DIR = f"/content/drive/MyDrive/{DRIVE_FOLDER_NAME}/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Embeddings will be saved to: {OUTPUT_DIR}")

In [None]:
# --- Embedding Generation Logic ---
# Embeds every chunk of CHUNKS_FILE with BGE-M3 and writes numbered checkpoint batches to
# OUTPUT_DIR:
#   batch_<k>_ids.json     -> list of chunk_ids, in input order
#   batch_<k>_vectors.npy  -> the embeddings returned by model.encode for those chunks
# Batch k always holds embeddable chunks k*MANUAL_BATCH_SIZE .. (k+1)*MANUAL_BATCH_SIZE - 1.
# Only lines that parse as a JSON object with string 'text' and are not "monster" chunks
# are counted. The grouping is therefore identical on every run, and an interrupted run
# resumes by re-running this cell.
# NOTE(review): batch files written by an earlier version of this cell follow this
# numbering only if no monster chunks were skipped. If in doubt, delete
# OUTPUT_DIR/batch_* and regenerate.

# PyTorch Memory Optimization. Set before torch is imported so the setting is in place
# before the CUDA caching allocator is initialised.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
print("PYTORCH_CUDA_ALLOC_CONF set to 'expandable_segments:True'")

import gc
import json

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

if not torch.cuda.is_available():
    raise RuntimeError(
        "No CUDA GPU available. In Colab, select Runtime > Change runtime type > GPU."
    )

print("Loading BGE-M3 model onto GPU...")
model = SentenceTransformer('BAAI/bge-m3', device='cuda')
print("Model loaded successfully.")

# Define Output Paths
MONSTER_CHUNKS_FILE = os.path.join(OUTPUT_DIR, "monster_chunks_to_fix.jsonl")

# Config Parameters
MANUAL_BATCH_SIZE = 32
MAX_CHUNK_LENGTH = 20000  # measured with len(text), i.e. characters, not tokens


def _batch_paths(batch_num: int) -> tuple:
    """Return the (vectors_path, ids_path) checkpoint files for batch `batch_num`."""
    return (
        os.path.join(OUTPUT_DIR, f"batch_{batch_num}_vectors.npy"),
        os.path.join(OUTPUT_DIR, f"batch_{batch_num}_ids.json"),
    )


def _embed_and_save_batch(batch_num: int, texts: list, chunk_ids: list,
                          show_progress_bar: bool = False) -> None:
    """Embed `texts` and persist them, with `chunk_ids`, as checkpoint batch `batch_num`.

    Each file is written under a temporary name and then moved into place with
    os.replace: ids first, vectors last. The vectors file is what marks a batch as done,
    so a crash part-way through never leaves a truncated file that a later run would
    mistake for a finished batch.
    """
    vector_file, ids_file = _batch_paths(batch_num)

    gc.collect()
    torch.cuda.empty_cache()
    embeddings = model.encode(texts, batch_size=len(texts), show_progress_bar=show_progress_bar)

    tmp_ids = ids_file + ".tmp"
    with open(tmp_ids, 'w', encoding='utf-8') as f_ids:
        json.dump(chunk_ids, f_ids)
    os.replace(tmp_ids, ids_file)

    # Save through a file handle so np.save does not append ".npy" to the temporary name.
    tmp_vectors = vector_file + ".tmp"
    with open(tmp_vectors, 'wb') as f_vec:
        np.save(f_vec, embeddings)
    os.replace(tmp_vectors, vector_file)
    del embeddings


def _load_logged_monsters(path: str) -> set:
    """Return the JSON lines already present in the monster-chunk log (empty set if none)."""
    if not os.path.exists(path):
        return set()
    with open(path, 'r', encoding='utf-8') as fh:
        return {row.rstrip('\n') for row in fh if row.strip()}


processed_count = 0   # chunks whose embeddings are on disk (written now or by an earlier run)
resumed_count = 0     # part of processed_count found in existing checkpoints
skipped_count = 0     # monster chunks (len(text) > MAX_CHUNK_LENGTH)
malformed_count = 0   # lines that are not a JSON object with a string 'text' field
batch_texts = []
batch_chunk_ids = []
batch_num = 0
batch_already_done = False
eligible_index = 0    # running index over embeddable chunks; fixes batch membership

print(f"Starting embedding generation. Batch Size: {MANUAL_BATCH_SIZE}")

try:
    logged_monsters = _load_logged_monsters(MONSTER_CHUNKS_FILE)

    with open(CHUNKS_FILE, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Processing chunks"):
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                malformed_count += 1
                continue
            # Valid JSON that is not an object, or whose 'text' is null/non-string, cannot
            # be embedded; count it instead of crashing the whole run.
            if not isinstance(data, dict) or not isinstance(data.get('text', ''), str):
                malformed_count += 1
                continue
            text = data.get('text', '')
            chunk_id = data.get('chunk_id', 'unknown')

            # Guardrail: Filter Monster Chunks. Every run rescans the whole file, so only
            # log records not already in the monster file (avoids duplicates on resume).
            if len(text) > MAX_CHUNK_LENGTH:
                tqdm.write(f"SKIPPING monster chunk: {chunk_id} ({len(text)} chars)")
                skipped_count += 1
                record = json.dumps(data, ensure_ascii=False)
                if record not in logged_monsters:
                    with open(MONSTER_CHUNKS_FILE, 'a', encoding='utf-8') as monster_f:
                        monster_f.write(record + '\n')
                    logged_monsters.add(record)
                continue

            # Checkpointing: at the first chunk of each batch, check once whether that
            # batch is already on disk; all of its chunks are then skipped or embedded together.
            if eligible_index % MANUAL_BATCH_SIZE == 0:
                batch_num = eligible_index // MANUAL_BATCH_SIZE
                batch_already_done = os.path.exists(_batch_paths(batch_num)[0])
            eligible_index += 1

            if batch_already_done:
                processed_count += 1
                resumed_count += 1
                continue

            batch_texts.append(text)
            batch_chunk_ids.append(chunk_id)

            # Process Batch
            if len(batch_texts) >= MANUAL_BATCH_SIZE:
                _embed_and_save_batch(batch_num, batch_texts, batch_chunk_ids)
                tqdm.write(f"Processed and saved batch {batch_num}")
                processed_count += len(batch_texts)
                batch_texts = []
                batch_chunk_ids = []

    # Final Batch Processing: a trailing partial batch. Its batch_num was fixed when its
    # first chunk was seen, and it was not on disk (otherwise batch_texts would be empty).
    if batch_texts:
        print("Processing final batch...")
        _embed_and_save_batch(batch_num, batch_texts, batch_chunk_ids, show_progress_bar=True)
        print(f"Processed final batch {batch_num}")
        processed_count += len(batch_texts)
        batch_texts = []
        batch_chunk_ids = []

    print("="*50)
    print(f"Complete. Processed: {processed_count}, Skipped: {skipped_count}")
    print(f"Resumed from checkpoints: {resumed_count}, Malformed lines ignored: {malformed_count}")
    print("="*50)

except Exception as e:
    print(f"An error occurred: {e}")
    # Release cached GPU memory so the cell can be re-run (resuming from the last complete
    # checkpoint) without restarting the runtime. Then re-raise so a failed run is not
    # mistaken for a completed one.
    gc.collect()
    torch.cuda.empty_cache()
    raise