# 🔍 FAISS Embedding Index Creation

This notebook creates the FAISS vector index for retrieval during inference.

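- **Input:** `chunks.jsonl` from notebook 01
- **Output:** `faiss_index/` folder containing the index (`dental.index`), metadata (`metadata.jsonl`), config (`config.json`) and a retriever helper (`retriever.py`)

**Embedding Model:** `pritamdeka/S-PubMedBert-MS-MARCO` (medical domain)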

## 1. Setup

In [None]:
# Install dependencies
!pip install -q sentence-transformers faiss-cpu tqdm

In [None]:
# Mount Google Drive: notebook 01's output lives there, and the index is persisted there.
from google.colab import drive
drive.mount('/content/drive')

import os

# Paths (everything lives under one project folder on Drive)
DATA_DIR = "/content/drive/MyDrive/RAFT_dental_data"
CHUNKS_FILE = f"{DATA_DIR}/chunks.jsonl"
OUTPUT_DIR = f"{DATA_DIR}/faiss_index"

# Fail fast with an actionable message. Merely printing a warning would let
# "Run All" continue into the loading cell and die there with a bare FileNotFoundError.
if not os.path.exists(CHUNKS_FILE):
    raise FileNotFoundError(f"✗ {CHUNKS_FILE} not found. Run notebook 01 first!")
print("✓ Found chunks file")

# Only create the output folder once we know there is something to index.
os.makedirs(OUTPUT_DIR, exist_ok=True)

## 2. Load Chunks

In [None]:
import json
from tqdm.notebook import tqdm

# Load all chunks produced by notebook 01: one JSON object per line.
# Blank lines are skipped (json.loads("") would raise). A malformed line is
# reported with its line number instead of a bare JSONDecodeError.
chunks = []
with open(CHUNKS_FILE, 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(tqdm(f, desc="Loading chunks"), start=1):
        if not line.strip():
            continue
        try:
            chunks.append(json.loads(line))
        except json.JSONDecodeError as e:
            raise ValueError(f"Malformed JSON on line {line_no} of {CHUNKS_FILE}") from e

# Everything downstream (preview, embedding, index) needs at least one chunk.
assert chunks, f"No chunks found in {CHUNKS_FILE}"

print(f"Loaded {len(chunks)} chunks")

# Preview the first chunk. token_count is read with .get() because it may be absent.
sample_chunk = chunks[0]
print("\nSample chunk:")
print(f"  ID: {sample_chunk['chunk_id']}")
print(f"  Source: {sample_chunk['source']}")
print(f"  Category: {sample_chunk['category']}")
print(f"  Tokens: {sample_chunk.get('token_count', 'N/A')}")

## 3. Load Embedding Model

In [None]:
from sentence_transformers import SentenceTransformer
import numpy as np

# Medical-domain sentence-embedding model. This name is also written to
# config.json, so the retriever can reload the same model at query time.
MODEL_NAME = "pritamdeka/S-PubMedBert-MS-MARCO"

print(f"Loading embedding model: {MODEL_NAME}")
embed_model = SentenceTransformer(MODEL_NAME)

embedding_dim = embed_model.get_sentence_embedding_dimension()
print("✓ Model loaded")
print(f"  Embedding dimension: {embedding_dim}")
# Chunks longer than this (in model tokens) are truncated by the encoder,
# so the tail of very long chunks does not influence their embedding.
print(f"  Max sequence length: {embed_model.max_seq_length}")

## 4. Generate Embeddings

In [None]:
# Extract texts in the same order as `chunks`: row i of `embeddings` becomes
# FAISS vector id i, which is later looked up as chunks[i] / metadata line i.
chunk_texts = [c['text'] for c in chunks]

EMBED_BATCH_SIZE = 32  # lower this if the GPU runs out of memory

print(f"Generating embeddings for {len(chunk_texts)} chunks...")
print("This may take a few minutes...")

# Generate embeddings in batches
embeddings = embed_model.encode(
    chunk_texts,
    batch_size=EMBED_BATCH_SIZE,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,  # unit-length vectors -> inner product == cosine similarity
)

# Positional id mapping only holds if there is exactly one vector per chunk.
assert embeddings.shape == (len(chunks), embedding_dim), f"unexpected shape {embeddings.shape}"

print("\n✓ Generated embeddings")
print(f"  Shape: {embeddings.shape}")
print(f"  Dtype: {embeddings.dtype}")

## 5. Build FAISS Index

In [None]:
import faiss

# Exact (flat) inner-product index. The embeddings were L2-normalized, so
# inner product equals cosine similarity.
print("Building FAISS index...")

index = faiss.IndexFlatIP(embedding_dim)
# FAISS expects a C-contiguous float32 matrix. ascontiguousarray guarantees both
# properties and returns the input unchanged when it already satisfies them.
index.add(np.ascontiguousarray(embeddings, dtype=np.float32))

# Vector ids are assigned sequentially, so they must line up 1:1 with `chunks`.
assert index.ntotal == len(chunks), "index size does not match chunk count"

print("✓ FAISS index built")
print(f"  Total vectors: {index.ntotal}")
print(f"  Dimension: {index.d}")

In [None]:
# Sanity-check the index with a sample query before saving it.
test_query = "What are the indications for root canal treatment?"

# Encode the query exactly like the chunks (normalized -> cosine scores).
query_embedding = embed_model.encode([test_query], normalize_embeddings=True).astype(np.float32)

# Never request more neighbours than the index holds.
k = min(3, index.ntotal)
scores, indices = index.search(query_embedding, k)

print(f"Query: {test_query}")
print(f"\nTop {k} results:")
for rank, (idx, score) in enumerate(zip(indices[0], scores[0]), start=1):
    # FAISS pads missing results with id -1; chunks[-1] would silently show the last chunk.
    if idx < 0:
        continue
    chunk = chunks[idx]
    print(f"\n{rank}. Score: {score:.4f}")
    print(f"   Source: {chunk['source']}")
    print(f"   Category: {chunk['category']}")
    print(f"   Text: {chunk['text'][:150]}...")

## 6. Save Index and Metadata

In [None]:
# Save FAISS index
index_path = os.path.join(OUTPUT_DIR, "dental.index")
faiss.write_index(index, index_path)
print(f"✓ Saved FAISS index to {index_path}")

# Save metadata for retrieval. Line i describes FAISS vector id i, so order must match `chunks`.
metadata = [
    {
        "id": i,
        "chunk_id": chunk['chunk_id'],
        "source": chunk['source'],
        "category": chunk['category'],
        # page_number may be missing for some sources (TODO confirm against notebook 01);
        # store null rather than failing the whole export with a KeyError.
        "page_number": chunk.get('page_number'),
        "text": chunk['text'],
    }
    for i, chunk in enumerate(chunks)
]

metadata_path = os.path.join(OUTPUT_DIR, "metadata.jsonl")
with open(metadata_path, 'w', encoding='utf-8') as f:
    for m in metadata:
        f.write(json.dumps(m, ensure_ascii=False) + "\n")

print(f"✓ Saved metadata to {metadata_path}")

# Save config. It describes the index that was actually written, so use its vector count.
config = {
    "model_name": MODEL_NAME,
    "embedding_dim": embedding_dim,
    "num_vectors": int(index.ntotal),
    "index_type": "IndexFlatIP",
    "normalized": True,
}

config_path = os.path.join(OUTPUT_DIR, "config.json")
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)

print(f"✓ Saved config to {config_path}")

## 7. Save Retriever Helper Module (`retriever.py`)

In [None]:
# Write a standalone retriever module next to the index, so later notebooks can
# load and query it without re-running this one.
# NOTE: the module source is a normal (non-raw) string. Avoid backslashes and
# triple single quotes inside it.
retriever_code = '''
"""
Dental FAISS Retriever
Load and use the FAISS index for retrieval.
"""

import json
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict
from pathlib import Path


class DentalRetriever:
    """Retrieve relevant dental documents using FAISS."""

    def __init__(self, index_dir: str):
        """
        Initialize retriever.

        Args:
            index_dir: Directory containing FAISS index and metadata

        Raises:
            ValueError: if the index and metadata.jsonl disagree on document count
        """
        index_dir = Path(index_dir)

        # Load config
        with open(index_dir / "config.json", encoding="utf-8") as f:
            self.config = json.load(f)

        # Load FAISS index
        self.index = faiss.read_index(str(index_dir / "dental.index"))

        # Load metadata. It was written as UTF-8 with ensure_ascii=False, so the
        # encoding must be explicit on platforms whose default is not UTF-8.
        self.metadata = []
        with open(index_dir / "metadata.jsonl", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    self.metadata.append(json.loads(line))

        # Vector id i maps to metadata line i; a size mismatch means stale or mixed files.
        if self.index.ntotal != len(self.metadata):
            raise ValueError(
                f"Index has {self.index.ntotal} vectors but metadata has "
                f"{len(self.metadata)} entries"
            )

        # Load the same embedding model that built the index
        self.embed_model = SentenceTransformer(self.config["model_name"])

        print(f"Loaded retriever with {len(self.metadata)} documents")

    def retrieve(self, query: str, k: int = 5) -> List[Dict]:
        """
        Retrieve top-k relevant documents.

        Args:
            query: Search query
            k: Number of documents to retrieve

        Returns:
            List of document dictionaries with scores, best match first.
            Fewer than k are returned if the index holds fewer documents.
        """
        # Never request more neighbours than exist; k <= 0 means nothing to do.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        # Encode query (normalized, matching how the index was built)
        query_embedding = self.embed_model.encode(
            [query],
            normalize_embeddings=True
        ).astype(np.float32)

        # Search
        scores, indices = self.index.search(query_embedding, k)

        # Build results
        results = []
        for idx, score in zip(indices[0], scores[0]):
            # FAISS uses id -1 for "no result"; skip it instead of wrapping to the last document.
            if idx < 0:
                continue
            doc = self.metadata[idx].copy()
            doc["score"] = float(score)
            results.append(doc)

        return results


if __name__ == "__main__":
    # Example usage
    retriever = DentalRetriever("faiss_index")

    results = retriever.retrieve("root canal indications", k=3)
    for r in results:
        print(f"Score: {r['score']:.4f} | {r['source']}")
'''

retriever_path = os.path.join(OUTPUT_DIR, "retriever.py")

# Catch quoting/escaping mistakes in the embedded source now, not at import time later.
compile(retriever_code, retriever_path, "exec")

with open(retriever_path, 'w', encoding='utf-8') as f:
    f.write(retriever_code)

print(f"✓ Saved retriever helper to {retriever_path}")

## 8. Summary

In [None]:
# On-disk sizes of the two large artifacts, in MB
index_size = os.path.getsize(index_path) / (1024**2)
metadata_size = os.path.getsize(metadata_path) / (1024**2)

RULE = "=" * 60

print(RULE)
print("FAISS INDEX CREATION COMPLETE")
print(RULE)
print(f"📊 Total vectors: {index.ntotal}")
print(f"📐 Embedding dimension: {embedding_dim}")
print(f"🧠 Model: {MODEL_NAME}")
print(f"💾 Index size: {index_size:.1f} MB")
print(f"📄 Metadata size: {metadata_size:.1f} MB")
print(f"📁 Output: {OUTPUT_DIR}")
print(RULE)
print("Files created:")
for file_name, description in [
    ("dental.index", "FAISS index"),
    ("metadata.jsonl", "document metadata"),
    ("config.json", "index configuration"),
    ("retriever.py", "helper class"),
]:
    print(f"  - {file_name} ({description})")
print(RULE)
print("Next: Run 05_model_training.ipynb for QLoRA fine-tuning")
print(RULE)

## 9. Test Full Retrieval Pipeline

In [None]:
# Spot-check retrieval quality on a few representative questions.
test_queries = [
    "What are the steps for root canal treatment?",
    "Explain the classification of dental caries",
    "What are contraindications for dental implants?",
    "How to manage dental emergencies?"
]

TOP_K = 3
SOURCE_PREVIEW_CHARS = 40

# Never request more neighbours than the index holds.
top_k = min(TOP_K, index.ntotal)

print("Testing retrieval with sample queries:\n")

# Encode all queries in one batch (normalized, like the chunks) and search them together.
query_embeddings = embed_model.encode(test_queries, normalize_embeddings=True).astype(np.float32)
batch_scores, batch_indices = index.search(query_embeddings, top_k)

for query, row_scores, row_indices in zip(test_queries, batch_scores, batch_indices):
    print(f"Query: {query}")
    print(f"Top {top_k} results:")
    for idx, score in zip(row_indices, row_scores):
        # FAISS pads missing results with id -1; chunks[-1] would silently show the last chunk.
        if idx < 0:
            continue
        chunk = chunks[idx]
        source = chunk['source']
        # Only show an ellipsis when the source name was actually shortened.
        if len(source) > SOURCE_PREVIEW_CHARS:
            source = f"{source[:SOURCE_PREVIEW_CHARS]}..."
        # page_number may be missing for some chunks; show '?' instead of raising KeyError.
        print(f"  [{score:.3f}] {source} (p.{chunk.get('page_number', '?')})")
    print()