# Retrieval Module Verification

This notebook tests the `src.retrieval` module which handles:
- **Dense Vector Search** using Qdrant + BGE embeddings
- **Cross-Encoder Reranking** for high precision (top 20 → top 5)
- **Parent Content Extraction** via `format_docs_for_gen()`

Uses a **small subset (50 docs)** for fast testing.

In [None]:
import sys
import os
import json
import zipfile

# Resolve the project root (parent of the current working directory) once,
# then make it importable so `src.*` modules can be found.
PROJECT_ROOT = os.path.abspath("..")
sys.path.append(PROJECT_ROOT)

# Test configuration: on-disk location of the throwaway Qdrant database and
# the number of corpus lines to load.
QDRANT_PATH = os.path.join(PROJECT_ROOT, "qdrant_test_db")
MAX_DOCS = 50

print(
    f"Project root: {PROJECT_ROOT}\n"
    f"Test subset size: {MAX_DOCS} documents"
)

## Step 0: Prepare Test Collection

Create a small Qdrant collection for testing retrieval.

In [None]:
# Make sure the passage-level corpus is available, extracting it from the
# bundled zip archive when only the archive is present.
corpus_dir = os.path.join(PROJECT_ROOT, "dataset/corpora/passage_level")
jsonl_file = os.path.join(corpus_dir, "govt.jsonl")
zip_file = os.path.join(corpus_dir, "govt.jsonl.zip")

if os.path.exists(jsonl_file):
    print("Corpus ready: govt.jsonl")
elif os.path.exists(zip_file):
    print("Extracting corpus...")
    with zipfile.ZipFile(zip_file, 'r') as zf:
        zf.extractall(corpus_dir)
    print("Corpus extracted")
else:
    # Neither the corpus nor its archive exists. Report it instead of claiming
    # the corpus is ready; only a warning, because a later cell reads the
    # corpus solely when it has to build the test collection.
    print(f"WARNING: corpus not found (expected {jsonl_file} or {zip_file})")

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore

# Reuse the on-disk test collection when it is already there; otherwise build
# it from the first MAX_DOCS lines of the corpus.
need_create = True
if os.path.exists(QDRANT_PATH):
    client = None
    try:
        client = QdrantClient(path=QDRANT_PATH)
        info = client.get_collection("mtrag_test")
        print(f"Test collection exists: {info.points_count} points")
        need_create = False
    except Exception as exc:
        # The check is best-effort: on any failure fall back to rebuilding,
        # but say why instead of silently swallowing the error.
        print(f"Test collection not usable ({type(exc).__name__}: {exc}); recreating it")
    finally:
        # Always release the probe client. NOTE(review): qdrant-client's local
        # mode is expected to lock the storage folder, so a client left open
        # here would block the QdrantClient(path=...) calls that follow.
        if client is not None:
            client.close()

if need_create:
    print("Creating test collection...")

    # Load subset: at most MAX_DOCS lines of the JSONL corpus, skipping
    # records whose text is empty.
    docs = []
    with open(jsonl_file, 'r', encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= MAX_DOCS:
                break
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            item = json.loads(line)
            text = item.get("text", "").strip()
            if text:
                docs.append(Document(page_content=text, metadata={"doc_id": item.get("id", str(i))}))
    print(f"   â€¢ Loaded {len(docs)} documents")

    # Chunk
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)
    print(f"   â€¢ Split into {len(chunks)} chunks")

    # Build embeddings (using small model for speed)
    print("   â€¢ Building embeddings (bge-small-en)...")
    embedding_model = HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs={"device": "cpu"}
    )

    # Create collection (dropping any stale one first) and index the chunks.
    client = QdrantClient(path=QDRANT_PATH)
    try:
        if client.collection_exists("mtrag_test"):
            client.delete_collection("mtrag_test")
        client.create_collection(
            collection_name="mtrag_test",
            # Vector size must match the embedding model's output dimension.
            vectors_config=VectorParams(size=384, distance=Distance.COSINE)
        )

        vectorstore = QdrantVectorStore(client=client, collection_name="mtrag_test", embedding=embedding_model)
        vectorstore.add_documents(chunks)
    finally:
        # Close the client so the next cell can open the same storage path
        # with a fresh QdrantClient.
        client.close()
    print(f"\nTest collection created with {len(chunks)} chunks")

## Step 1: Initialize Retriever

Create a retriever from the vector store.

In [None]:
# Build a fresh embedding model, Qdrant client and vector store so this cell
# does not depend on objects left over from the collection-creation cell.
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5", model_kwargs={"device": "cpu"})
client = QdrantClient(path=QDRANT_PATH)
vectorstore = QdrantVectorStore(client=client, collection_name="mtrag_test", embedding=embedding_model)

# Retriever that returns the 5 nearest chunks per query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# Emit the configuration summary in a single write.
summary_lines = [
    "Retriever Configuration:",
    f"   â€¢ Type: {type(retriever).__name__}",
    "   â€¢ Top-K: 5 documents",
    "   â€¢ Embedding model: bge-small-en-v1.5",
    "",
    "Retriever initialized!",
]
print("\n".join(summary_lines))

## Step 2: Test Retrieval

Query the retriever and examine results.

In [None]:
# Run one sample query through the retriever and show the top hits.
query = "government regulations"
print(f"Query: '{query}'")

docs = retriever.invoke(query)

# Fail loudly on an empty result so the success message below is only printed
# when something was actually retrieved.
assert docs, f"Retriever returned no documents for query '{query}'"

print(f"\nRetrieved {len(docs)} documents:")
# Preview only the first three hits to keep the output short.
for i, doc in enumerate(docs[:3]):
    print(f"\n   Document {i+1}:")
    print(f"   â€¢ Content: {doc.page_content[:100]}...")
    print(f"   â€¢ Doc ID: {doc.metadata.get('doc_id', 'N/A')}")

print("\nRetrieval working correctly!")

## Step 3: Test `format_docs_for_gen()`

This function from `src.retrieval`:
1. Extracts parent content from retrieved documents
2. Deduplicates to avoid repetition
3. Concatenates into context string for LLM

In [None]:
from src.retrieval import format_docs_for_gen

# Turn the documents retrieved in the previous cell into one context value.
context = format_docs_for_gen(docs)

# Count distinct source documents by the `doc_id` metadata attached at
# indexing time; `len(docs)` would only report the number of retrieved hits,
# several of which may come from the same source document.
unique_doc_ids = {doc.metadata.get("doc_id") for doc in docs}

print(f"Context Statistics:")
print(f"   â€¢ Total length: {len(context)} characters")
print(f"   â€¢ Unique documents: {len(unique_doc_ids)}")

print(f"\nðŸ“„ Context Preview (first 300 chars):")
print(f"   {context[:300]}...")

print("\nformat_docs_for_gen() working correctly!")

## Cleanup
Remove test files after verification.

In [None]:
import shutil

# Close the Qdrant client before deleting its on-disk storage.
client.close()

# Delete the throwaway database directory if it is still present.
db_present = os.path.exists(QDRANT_PATH)
if db_present:
    shutil.rmtree(QDRANT_PATH)
    print(f"Removed test database: {QDRANT_PATH}")

print("\nAll retrieval tests passed!")