# Lab 3.5.5: Reranking Pipeline

**Module:** 3.5 - RAG Systems & Vector Databases  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Advanced)

---

## 🎯 Learning Objectives

By the end of this notebook, you will:

- [ ] Understand why reranking improves retrieval quality
- [ ] Implement two-stage retrieval with cross-encoders
- [ ] Load and use BGE-reranker on DGX Spark GPU
- [ ] Benchmark quality improvement vs latency cost
- [ ] Find the optimal first-stage K for your use case

---

## 📚 Prerequisites

- Completed: Labs 3.5.1-3.5.4
- Understanding of: Bi-encoders vs Cross-encoders

---

## 🌍 Real-World Context

**The Problem:** Even with hybrid search, your top 5 results aren't always the best 5. The embedding model might rank a tangentially related chunk above the perfect match.

**The Solution:** Two-stage retrieval! First, quickly get 50 candidates. Then, use a more powerful model to rerank and select the final 5.

**Industry Usage:** Web search engines such as Google Search and Bing, as well as enterprise search systems, use multi-stage retrieval and ranking. Cohere's Rerank API is a commercial offering of the reranking stage.

---

## 🧒 ELI5: Bi-Encoders vs Cross-Encoders

> **Bi-Encoder (Fast, Less Accurate):**
> Imagine you and your friend each read a different book, then try to guess if your books are similar by describing them in a few sentences. Fast, but you might miss subtle connections.
>
> **Cross-Encoder (Slow, More Accurate):**
> Now imagine both of you read BOTH books together, discussing as you go. "Oh, this part connects to that part!" Much slower, but you catch every connection.
>
> **Two-Stage Retrieval:**
> - Stage 1 (Bi-Encoder): Quickly filter 1000 books down to 50 candidates
> - Stage 2 (Cross-Encoder): Carefully compare those 50 to pick the best 5

---

## Part 1: Setup

In [None]:
# Install dependencies
!pip install -q \
    langchain langchain-community langchain-huggingface \
    chromadb sentence-transformers \
    rank_bm25

print("✅ Dependencies installed!")

In [None]:
import os
import time
from pathlib import Path
from typing import List, Dict, Tuple, Any
from dataclasses import dataclass
import numpy as np

# Newer LangChain releases expose these from the split-out packages
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

from sentence_transformers import CrossEncoder, SentenceTransformer

import torch
import gc

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Load and chunk documents
DOCS_PATH = Path("../data/sample_documents")

documents = []
for file_path in sorted(DOCS_PATH.glob("*.md")):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    documents.append(Document(
        page_content=content,
        metadata={"source": file_path.name}
    ))

splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
chunks = splitter.split_documents(documents)

print(f"📚 Loaded {len(documents)} documents → {len(chunks)} chunks")

---

## Part 2: Understanding Cross-Encoders

### Bi-Encoder vs Cross-Encoder Architecture

```
BI-ENCODER:                          CROSS-ENCODER:
                                     
Query ──► Encoder ──► Query Vec      Query + Doc ──► Encoder ──► Score
                      ↓ cosine
Doc ‚îÄ‚îÄ‚ñ∫ Encoder ‚îÄ‚îÄ‚ñ∫ Doc Vec

- Encode separately                  - Encode together
- Compare vectors                    - Direct relevance score
- Fast (can pre-compute docs)        - Slow (must run for each pair)
- Used for: Initial retrieval        - Used for: Reranking
```
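
To make the cost difference in the diagram concrete, the sketch below counts encoder forward passes for each architecture. The function names and the workload sizes are illustrative assumptions, not measurements from this lab.

```python
def bi_encoder_passes(num_docs: int, num_queries: int) -> dict:
    """Documents are encoded once offline; each query needs one online pass."""
    return {
        "offline": num_docs,
        "online_per_query": 1,
        "total": num_docs + num_queries,
    }


def cross_encoder_passes(num_docs: int, num_queries: int) -> dict:
    """Every (query, doc) pair needs its own pass; nothing can be cached offline."""
    return {
        "offline": 0,
        "online_per_query": num_docs,
        "total": num_docs * num_queries,
    }


def two_stage_passes(num_docs: int, num_queries: int, first_stage_k: int) -> dict:
    """Bi-encoder over the whole corpus, cross-encoder over the top candidates only."""
    reranked = min(first_stage_k, num_docs)
    return {
        "offline": num_docs,
        "online_per_query": 1 + reranked,
        "total": num_docs + num_queries * (1 + reranked),
    }


# Hypothetical workload: 100,000 chunks, 1,000 queries, rerank the top 50
workload = [
    ("bi-encoder", bi_encoder_passes(100_000, 1_000)),
    ("cross-encoder", cross_encoder_passes(100_000, 1_000)),
    ("two-stage", two_stage_passes(100_000, 1_000, first_stage_k=50)),
]
for name, counts in workload:
    print(f"{name:<14} online passes per query: {counts['online_per_query']:>7,}")
```

Passes are not equally expensive (a cross-encoder pass reads the query and the document together), so treat these counts as a rough guide to how the work scales rather than a latency estimate.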

### Key Classes from sentence-transformers

| Class | Purpose |
|-------|---------|
| `SentenceTransformer(model)` | Bi-encoder: Encodes text into vectors separately |
| `CrossEncoder(model)` | Cross-encoder: Takes query+doc pairs, outputs relevance scores |
| `.encode(texts)` | SentenceTransformer method to get embeddings |
| `.predict(pairs)` | CrossEncoder method to get relevance scores for query-doc pairs |

In [None]:
# Load models
# ============
# HuggingFaceEmbeddings is a LangChain wrapper around SentenceTransformer
# It provides a consistent interface for embedding models

print("🔄 Loading embedding model (bi-encoder)...")
bi_encoder = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)
print("✅ Bi-encoder loaded!")

# CrossEncoder from sentence-transformers
# ========================================
# CrossEncoder takes a pair of texts [query, document] and outputs a relevance score
# Unlike bi-encoders, it cannot pre-compute document embeddings
# 
# Usage:
#   scores = cross_encoder.predict([[query1, doc1], [query1, doc2], ...])
#   Returns: array of relevance scores (higher = more relevant)

print("\n🔄 Loading reranker model (cross-encoder)...")
cross_encoder = CrossEncoder(
    "BAAI/bge-reranker-large",  # Specialized model for reranking
    device="cuda" if torch.cuda.is_available() else "cpu"
)
print("✅ Cross-encoder loaded!")

In [None]:
# Demonstrate the difference
query = "How much memory does DGX Spark have?"
doc1 = "The DGX Spark has 128GB of unified LPDDR5X memory."
doc2 = "Memory management is important for large models."

print(f"📝 Query: '{query}'")
print(f"📄 Doc 1: '{doc1}'")
print(f"📄 Doc 2: '{doc2}'")
print("-" * 60)

# Bi-encoder: embed the query and the documents separately, then compare vectors
query_emb = bi_encoder.embed_query(query)
doc1_emb, doc2_emb = bi_encoder.embed_documents([doc1, doc2])

# Embeddings are normalized, so the dot product equals cosine similarity
bi_score1 = np.dot(query_emb, doc1_emb)
bi_score2 = np.dot(query_emb, doc2_emb)

print(f"\n🔵 Bi-Encoder Scores (cosine similarity):")
print(f"   Doc 1: {bi_score1:.4f}")
print(f"   Doc 2: {bi_score2:.4f}")
print(f"   Bi-encoder prefers: Doc {'1' if bi_score1 > bi_score2 else '2'}")

# Cross-encoder: process together, get relevance score
cross_scores = cross_encoder.predict([
    [query, doc1],
    [query, doc2]
])

print(f"\n🟢 Cross-Encoder Scores (relevance):")
print(f"   Doc 1: {cross_scores[0]:.4f}")
print(f"   Doc 2: {cross_scores[1]:.4f}")
print(f"   Cross-encoder prefers: Doc {'1' if cross_scores[0] > cross_scores[1] else '2'}")

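### 🔍 What Just Happened?

Both encoders should rank Doc 1 above Doc 2, and the cross-encoder typically separates the two more sharply. Keep in mind that the two models score on different scales (cosine similarity vs. a learned relevance score), so compare the rankings and relative gaps rather than the raw numbers. The cross-encoder can: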

1. See the query and document together
2. Understand exact word matches ("DGX Spark", "128GB", "memory")
3. Evaluate semantic relevance more precisely

The trade-off? Cross-encoders are ~10-100x slower because they can't pre-compute document embeddings.
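
The bi-encoder scores in the cell above are plain dot products. They can be read as cosine similarities only because the embedding model was loaded with `normalize_embeddings=True`. A minimal check of that identity, using toy 3-dimensional vectors as stand-ins for real embeddings:

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec]


def cosine(a, b):
    """Cosine similarity computed from unnormalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


# Toy vectors (illustrative only, not real model output)
query_vec = [1.0, 2.0, 2.0]
doc_vec = [2.0, 0.0, 1.0]

cos = cosine(query_vec, doc_vec)
dot_normalized = dot(l2_normalize(query_vec), l2_normalize(doc_vec))
print(f"cosine: {cos:.6f}, dot of normalized vectors: {dot_normalized:.6f}")
```

If the embeddings were not normalized, the raw dot product would also reflect vector length, and longer vectors could outrank more similar ones.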

---

## Part 3: Building the Two-Stage Retriever

In [None]:
class TwoStageRetriever:
    """
    Two-stage retrieval with bi-encoder + cross-encoder reranking.
    """
    
    def __init__(
        self,
        documents: List[Document],
        bi_encoder: HuggingFaceEmbeddings,
        cross_encoder: CrossEncoder
    ):
        """
        Initialize with documents and models.
        """
        self.documents = documents
        self.bi_encoder = bi_encoder
        self.cross_encoder = cross_encoder
        
        # Pre-compute document embeddings
        print("   Computing document embeddings...")
        texts = [doc.page_content for doc in documents]
        self.embeddings = np.array(bi_encoder.embed_documents(texts))
        print(f"   ✅ Embedded {len(documents)} documents")
        
    def search(
        self,
        query: str,
        k: int = 5,
        first_stage_k: int = 50
    ) -> List[Tuple[Document, float, Dict]]:
        """
        Two-stage retrieval.
        
        Args:
            query: Search query
            k: Final number of results
            first_stage_k: Candidates from first stage
            
        Returns:
            List of (document, score, metadata) tuples
        """
        # Stage 1: Fast bi-encoder retrieval
        stage1_start = time.time()
        
        query_emb = np.array(self.bi_encoder.embed_query(query))
        similarities = np.dot(self.embeddings, query_emb)
        top_indices = np.argsort(similarities)[-first_stage_k:][::-1]
        
        candidates = [(self.documents[i], similarities[i]) for i in top_indices]
        stage1_time = time.time() - stage1_start
        
        # Stage 2: Cross-encoder reranking
        stage2_start = time.time()
        
        pairs = [[query, doc.page_content] for doc, _ in candidates]
        rerank_scores = self.cross_encoder.predict(pairs)
        
        # Sort by rerank score, carrying each candidate's stage-1 rank along.
        # Candidates are already ordered by bi-encoder score, so rank = position + 1.
        reranked = sorted(
            zip(range(1, len(candidates) + 1), candidates, rerank_scores),
            key=lambda x: -x[2]
        )
        
        stage2_time = time.time() - stage2_start
        
        # Build results with metadata
        results = []
        for i, (bi_rank, (doc, bi_score), rerank_score) in enumerate(reranked[:k]):
            results.append((
                doc,
                float(rerank_score),
                {
                    "bi_encoder_score": float(bi_score),
                    "bi_encoder_rank": bi_rank,
                    "rerank_score": float(rerank_score),
                    "final_rank": i + 1,
                    "stage1_time_ms": stage1_time * 1000,
                    "stage2_time_ms": stage2_time * 1000
                }
            ))
        
        return results
    
    def search_without_reranking(self, query: str, k: int = 5) -> List[Tuple[Document, float]]:
        """
        Single-stage retrieval (bi-encoder only) for comparison.
        """
        query_emb = np.array(self.bi_encoder.embed_query(query))
        similarities = np.dot(self.embeddings, query_emb)
        top_indices = np.argsort(similarities)[-k:][::-1]
        
        return [(self.documents[i], similarities[i]) for i in top_indices]


# Build the two-stage retriever
print("🔄 Building two-stage retriever...")
two_stage = TwoStageRetriever(chunks, bi_encoder, cross_encoder)
print("✅ Two-stage retriever ready!")

In [None]:
# Test the two-stage retriever
query = "What is the attention mechanism in transformers?"

print(f"🔍 Query: '{query}'")
print("=" * 70)

# Compare with and without reranking
without_rerank = two_stage.search_without_reranking(query, k=5)
with_rerank = two_stage.search(query, k=5, first_stage_k=50)

print("\n🔵 WITHOUT Reranking (Bi-encoder only):")
for i, (doc, score) in enumerate(without_rerank):
    print(f"   {i+1}. [{score:.4f}] {doc.metadata['source']}")
    print(f"      {doc.page_content[:80]}...")

print("\n🟢 WITH Reranking (Two-stage):")
for doc, score, meta in with_rerank:
    bi_rank = meta['bi_encoder_rank']
    final_rank = meta['final_rank']
    movement = bi_rank - final_rank
    arrow = "↑" if movement > 0 else ("↓" if movement < 0 else "=")
    
    print(f"   {final_rank}. [{score:.4f}] {doc.metadata['source']} (was #{bi_rank} {arrow})")
    print(f"      {doc.page_content[:80]}...")

---

## Part 4: Benchmarking Quality Improvement

In [None]:
# Evaluation dataset
eval_dataset = [
    {"question": "What is the memory capacity of DGX Spark?", "expected_source": "dgx_spark_technical_guide.md"},
    {"question": "How do Tensor Cores work?", "expected_source": "dgx_spark_technical_guide.md"},
    {"question": "Explain self-attention in transformers", "expected_source": "transformer_architecture_explained.md"},
    {"question": "What is positional encoding?", "expected_source": "transformer_architecture_explained.md"},
    {"question": "How does LoRA reduce memory?", "expected_source": "lora_finetuning_guide.md"},
    {"question": "What is QLoRA?", "expected_source": "lora_finetuning_guide.md"},
    {"question": "How does GPTQ quantization work?", "expected_source": "quantization_methods.md"},
    {"question": "What is GGUF format?", "expected_source": "quantization_methods.md"},
    {"question": "What are the benefits of RAG?", "expected_source": "rag_architecture_patterns.md"},
    {"question": "How to choose a vector database?", "expected_source": "vector_database_comparison.md"},
]

print(f"📋 Evaluation dataset: {len(eval_dataset)} queries")

In [None]:
def evaluate_retriever(
    retriever: TwoStageRetriever,
    eval_dataset: List[Dict],
    use_reranking: bool = True,
    first_stage_k: int = 50,
    k: int = 5
) -> Dict[str, Any]:
    """
    Evaluate retriever on the evaluation dataset.
    """
    correct_at_1 = 0
    correct_at_3 = 0
    correct_at_5 = 0
    total_time_ms = 0
    
    for item in eval_dataset:
        question = item["question"]
        expected = item["expected_source"]
        
        start = time.time()
        if use_reranking:
            results = retriever.search(question, k=k, first_stage_k=first_stage_k)
            sources = [r[0].metadata.get('source') for r in results]
        else:
            results = retriever.search_without_reranking(question, k=k)
            sources = [r[0].metadata.get('source') for r in results]
        total_time_ms += (time.time() - start) * 1000
        
        if sources and sources[0] == expected:
            correct_at_1 += 1
        if expected in sources[:3]:
            correct_at_3 += 1
        if expected in sources[:5]:
            correct_at_5 += 1
    
    n = len(eval_dataset)
    return {
        "recall@1": correct_at_1 / n,
        "recall@3": correct_at_3 / n,
        "recall@5": correct_at_5 / n,
        "avg_latency_ms": total_time_ms / n
    }


# Evaluate both methods
print("📊 Evaluating retrieval methods...")

no_rerank_metrics = evaluate_retriever(two_stage, eval_dataset, use_reranking=False)
rerank_metrics = evaluate_retriever(two_stage, eval_dataset, use_reranking=True)

print(f"\n{'Method':<25} {'R@1':<10} {'R@5':<10} {'Latency':<15}")
print("-" * 60)
print(f"{'Bi-encoder Only':<25} {no_rerank_metrics['recall@1']:<10.0%} "
      f"{no_rerank_metrics['recall@5']:<10.0%} {no_rerank_metrics['avg_latency_ms']:.1f}ms")
print(f"{'Two-Stage (Reranking)':<25} {rerank_metrics['recall@1']:<10.0%} "
      f"{rerank_metrics['recall@5']:<10.0%} {rerank_metrics['avg_latency_ms']:.1f}ms")

# Relative change in R@1 (as a percentage of the baseline) and the latency ratio
baseline_r1 = no_rerank_metrics['recall@1']
r1_improvement = (rerank_metrics['recall@1'] - baseline_r1) / baseline_r1 * 100 if baseline_r1 > 0 else 0.0
latency_increase = rerank_metrics['avg_latency_ms'] / no_rerank_metrics['avg_latency_ms']

print(f"\n🎯 R@1 change: {r1_improvement:+.1f}% relative to bi-encoder only")
print(f"⏱️ Latency: {latency_increase:.1f}x the bi-encoder-only latency")
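
Note what `evaluate_retriever` actually measures: the keys are named `recall@k`, but a query counts as correct when any of its top-k chunks comes from the expected source file, which is closer to a source-level hit rate. A minimal sketch of that calculation on made-up rankings (the file names are placeholders, not the lab's documents):

```python
from typing import List


def hit_rate_at_k(ranked_sources: List[List[str]], expected: List[str], k: int) -> float:
    """Fraction of queries whose expected source appears among the top-k results."""
    hits = sum(
        1 for sources, exp in zip(ranked_sources, expected) if exp in sources[:k]
    )
    return hits / len(expected)


# Hypothetical ranked source lists for three queries
ranked = [
    ["a.md", "b.md", "c.md"],  # expected source at rank 1
    ["b.md", "a.md", "c.md"],  # expected source at rank 2
    ["c.md", "c.md", "b.md"],  # expected source never retrieved
]
expected = ["a.md", "a.md", "a.md"]

for k in (1, 2, 3):
    print(f"hit rate@{k}: {hit_rate_at_k(ranked, expected, k):.2f}")
```

With only 10 evaluation queries, each query moves the metric by 10 percentage points, so small differences between configurations may not be meaningful.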

---

## Part 5: Finding Optimal First-Stage K

In [None]:
# Test different first-stage K values
print("🔬 Finding optimal first-stage K...")
print("-" * 60)

k_values = [10, 20, 30, 50, 75, 100]
results = []

for first_stage_k in k_values:
    metrics = evaluate_retriever(
        two_stage, eval_dataset, 
        use_reranking=True, 
        first_stage_k=first_stage_k
    )
    results.append((first_stage_k, metrics))
    print(f"   K={first_stage_k:3d}: R@1={metrics['recall@1']:.0%}, "
          f"R@5={metrics['recall@5']:.0%}, Latency={metrics['avg_latency_ms']:.1f}ms")

# Find best K by R@1. max() returns the first maximum, so ties go to the
# smallest (and typically fastest) K because k_values is in ascending order.
best_k, best_metrics = max(results, key=lambda x: x[1]['recall@1'])
print(f"\n🏆 Best K: {best_k} (R@1: {best_metrics['recall@1']:.0%})")

In [None]:
# Visualize the trade-off
print("\n📊 Quality vs Latency Trade-off:")
print("=" * 60)

for first_stage_k, metrics in results:
    r1 = metrics['recall@1']
    latency = metrics['avg_latency_ms']
    
    # Visual bar
    quality_bar = "█" * int(r1 * 20)
    latency_bar = "░" * int(latency / 10)
    
    print(f"K={first_stage_k:3d} | Quality: {quality_bar:<20} {r1:.0%}")
    print(f"       | Latency: {latency_bar:<20} {latency:.0f}ms")
    print()
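
Picking the K with the highest R@1 ignores latency. One possible selection rule, sketched here as an assumption rather than the lab's method, is to take the smallest K that stays within a latency budget and comes within a tolerance of the best achievable R@1. The function name, the budget, and the sweep numbers below are all made up for illustration.

```python
from typing import List, Optional, Tuple


def pick_first_stage_k(
    sweep: List[Tuple[int, float, float]],
    latency_budget_ms: float,
    tolerance: float = 0.0,
) -> Optional[int]:
    """Choose K from rows of (k, recall_at_1, avg_latency_ms).

    Returns the smallest K within the latency budget whose recall@1 is within
    `tolerance` of the best within-budget recall@1, or None if no row fits.
    """
    within_budget = [row for row in sweep if row[2] <= latency_budget_ms]
    if not within_budget:
        return None
    best_recall = max(row[1] for row in within_budget)
    good_enough = [row for row in within_budget if row[1] >= best_recall - tolerance]
    return min(good_enough, key=lambda row: row[0])[0]


# Made-up sweep results for illustration only (not measured on any hardware)
example_sweep = [(10, 0.70, 20.0), (20, 0.80, 35.0), (50, 0.90, 80.0), (100, 0.90, 150.0)]

print(pick_first_stage_k(example_sweep, latency_budget_ms=100.0))
print(pick_first_stage_k(example_sweep, latency_budget_ms=40.0))
```

To apply it to the sweep above, you could build the rows from `results` as `(k, m['recall@1'], m['avg_latency_ms'])`.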

---

## Part 6: Production Reranking Pipeline

For production use, we use `SentenceTransformer` directly (instead of the LangChain wrapper) for better control over batching and performance.

### SentenceTransformer vs HuggingFaceEmbeddings

| Feature | SentenceTransformer | HuggingFaceEmbeddings |
|---------|--------------------|-----------------------|
| Library | sentence-transformers | LangChain wrapper |
| Batching | Built-in batch_size parameter | Via encode_kwargs |
| Return type | numpy array | Python list |
| Use case | Direct control, production | LangChain integration |

In [None]:
class ProductionReranker:
    """
    Production-ready reranking pipeline with optimizations.
    
    Uses sentence-transformers directly for better performance:
    - SentenceTransformer: For fast bi-encoder embeddings with batching
    - CrossEncoder: For accurate reranking with batch processing
    """
    
    def __init__(
        self,
        documents: List[Document],
        bi_encoder_model: str = "BAAI/bge-large-en-v1.5",
        reranker_model: str = "BAAI/bge-reranker-large",
        device: str = "cuda"
    ):
        self.documents = documents
        self.device = device if torch.cuda.is_available() else "cpu"
        
        # Load models using sentence-transformers directly
        # SentenceTransformer: Bi-encoder for fast initial retrieval
        # - .encode(texts, batch_size, normalize_embeddings) → numpy array
        print(f"   Loading bi-encoder...")
        self.bi_encoder = SentenceTransformer(bi_encoder_model, device=self.device)
        
        # CrossEncoder: For accurate pairwise relevance scoring
        # - .predict(pairs, batch_size) → numpy array of scores
        print(f"   Loading reranker...")
        self.reranker = CrossEncoder(reranker_model, device=self.device)
        
        # Pre-compute embeddings with batching for efficiency
        # normalize_embeddings=True enables cosine similarity via dot product
        print(f"   Computing embeddings...")
        texts = [doc.page_content for doc in documents]
        self.embeddings = self.bi_encoder.encode(
            texts,
            batch_size=32,              # Process 32 docs at a time
            show_progress_bar=False,
            normalize_embeddings=True    # Enable dot product = cosine similarity
        )
        
        print(f"   ✅ Ready!")
    
    def search(
        self,
        query: str,
        k: int = 5,
        first_stage_k: int = 50,
        score_threshold: float = None
    ) -> List[Dict[str, Any]]:
        """
        Production search with optional score threshold.
        
        Args:
            query: Search query
            k: Final number of results to return
            first_stage_k: Candidates to retrieve before reranking
            score_threshold: Minimum reranker score to include
        """
        # Stage 1: Bi-encoder retrieval (fast)
        # encode single query, get embedding vector
        query_emb = self.bi_encoder.encode(
            query,
            normalize_embeddings=True
        )
        
        # Dot product with all document embeddings (= cosine sim for normalized vectors)
        similarities = np.dot(self.embeddings, query_emb)
        top_indices = np.argsort(similarities)[-first_stage_k:][::-1]
        
        # Stage 2: Reranking with CrossEncoder (accurate)
        # predict() takes list of [query, doc] pairs
        pairs = [[query, self.documents[i].page_content] for i in top_indices]
        rerank_scores = self.reranker.predict(
            pairs,
            batch_size=32,
            show_progress_bar=False
        )
        
        # Build results sorted by reranker score
        results = []
        sorted_results = sorted(
            zip(top_indices, rerank_scores),
            key=lambda x: -x[1]
        )
        
        for idx, score in sorted_results[:k]:
            # Optional score threshold filtering (a threshold of 0.0 is still applied)
            if score_threshold is not None and score < score_threshold:
                continue
                
            results.append({
                "content": self.documents[idx].page_content,
                "metadata": self.documents[idx].metadata,
                "score": float(score),
                "bi_encoder_score": float(similarities[idx])
            })
        
        return results


# Build production reranker
print("🔄 Building production reranker...")
prod_reranker = ProductionReranker(chunks)
print("✅ Production reranker ready!")

In [None]:
# Test production reranker
query = "How do I fine-tune a large language model efficiently?"

print(f"🔍 Query: '{query}'")
print("=" * 70)

# Time the query
start = time.time()
results = prod_reranker.search(query, k=5, first_stage_k=50)
latency = (time.time() - start) * 1000

print(f"⏱️ Total latency: {latency:.1f}ms\n")

for i, result in enumerate(results):
    print(f"🔹 Result {i+1}:")
    print(f"   Source: {result['metadata']['source']}")
    print(f"   Score: {result['score']:.4f} (bi-encoder: {result['bi_encoder_score']:.4f})")
    print(f"   Content: {result['content'][:100]}...")
    print()

---

## ⚠️ Common Mistakes

### Mistake 1: First-Stage K Too Small
```python
# ❌ Wrong: Only get 5 candidates, then rerank
results = retriever.search(query, k=5, first_stage_k=5)

# ✅ Right: Get more candidates for reranking
results = retriever.search(query, k=5, first_stage_k=50)
```
**Why:** If the best document isn't in the top 5 from the bi-encoder, reranking can't find it!
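
A toy illustration of this failure mode, with hypothetical scores standing in for real model output: the first stage places the best chunk fourth, so a reranker that only sees the top 2 candidates never gets a chance to promote it.

```python
from typing import Dict, List


def two_stage(
    first_stage: Dict[str, float],
    rerank: Dict[str, float],
    k: int,
    first_stage_k: int,
) -> List[str]:
    """Keep the top first_stage_k by first-stage score, then order those by rerank score."""
    candidates = sorted(first_stage, key=first_stage.get, reverse=True)[:first_stage_k]
    return sorted(candidates, key=rerank.get, reverse=True)[:k]


# Hypothetical scores: the bi-encoder ranks the best chunk ("best") only 4th
first_stage_scores = {"a": 0.82, "b": 0.80, "c": 0.79, "best": 0.75, "d": 0.60}
rerank_scores = {"a": 0.40, "b": 0.35, "c": 0.10, "best": 0.95, "d": 0.05}

print(two_stage(first_stage_scores, rerank_scores, k=2, first_stage_k=2))
print(two_stage(first_stage_scores, rerank_scores, k=2, first_stage_k=5))
```

The first call returns `['a', 'b']` because "best" was filtered out before reranking. The second returns `['best', 'a']`.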

### Mistake 2: Not Batching Reranker Calls
```python
# ❌ Wrong: One by one (slow!)
for pair in pairs:
    score = reranker.predict([pair])[0]

# ✅ Right: Batch processing
scores = reranker.predict(pairs, batch_size=32)
```
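
`predict(pairs, batch_size=32)` handles batching for you. The sketch below only illustrates why it helps: the number of model invocations drops from one per pair to one per batch. `fake_score_batch` is a hypothetical stand-in for a model call, not part of sentence-transformers.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]


def score_in_batches(
    pairs: Sequence[Pair],
    score_batch: Callable[[Sequence[Pair]], List[float]],
    batch_size: int = 32,
) -> List[float]:
    """Score pairs in fixed-size chunks, preserving the input order."""
    scores: List[float] = []
    for start in range(0, len(pairs), batch_size):
        scores.extend(score_batch(pairs[start:start + batch_size]))
    return scores


batch_sizes_seen: List[int] = []  # records the size of each simulated model call


def fake_score_batch(batch: Sequence[Pair]) -> List[float]:
    """Hypothetical scorer: counts words shared by the query and the document."""
    batch_sizes_seen.append(len(batch))
    return [float(len(set(q.split()) & set(d.split()))) for q, d in batch]


pairs = [("gpu memory", f"doc {i} about gpu memory") for i in range(50)]
scores = score_in_batches(pairs, fake_score_batch, batch_size=32)
print(f"{len(scores)} scores from {len(batch_sizes_seen)} calls: {batch_sizes_seen}")
```

On a GPU, each call carries fixed overhead, so fewer and larger calls usually make better use of the hardware, up to the limit of available memory.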

### Mistake 3: Using Reranker for Initial Retrieval
```python
# ❌ Wrong: Rerank all documents (impractical for large corpora)
pairs = [[query, doc.page_content] for doc in all_documents]  # 1M documents?
scores = reranker.predict(pairs)  # Hours of compute!

# ✅ Right: Bi-encoder first, then rerank only the candidates
# (retriever is a TwoStageRetriever, as in Mistake 1)
results = retriever.search(query, k=5, first_stage_k=100)  # Reranks 100 pairs, not 1M
```

---

## ✋ Try It Yourself

### Exercise 1: Different Reranker Models
Try `BAAI/bge-reranker-base` instead of `-large`. Compare quality and speed.

### Exercise 2: Score Threshold
Add a minimum score threshold to filter low-confidence results.

### Exercise 3: Combine with Hybrid Search
Use hybrid search (Lab 3.5.4) as the first stage instead of pure bi-encoder.

<details>
<summary>💡 Hint for Exercise 3</summary>

```python
class HybridTwoStageRetriever:
    def __init__(self, hybrid_retriever, cross_encoder):
        self.hybrid = hybrid_retriever
        self.reranker = cross_encoder
    
    def search(self, query, k=5, first_stage_k=50):
        # Stage 1: Hybrid search
        candidates = self.hybrid.search(query, k=first_stage_k)
        
        # Stage 2: Rerank
        pairs = [[query, c[0].page_content] for c in candidates]
        scores = self.reranker.predict(pairs)
        
        # Return top-k by rerank score
        ...
```
</details>

---

## 🎉 Checkpoint

You've learned:
- ✅ The difference between bi-encoders and cross-encoders
- ✅ How to implement two-stage retrieval with reranking
- ✅ How to benchmark quality improvement vs latency
- ✅ How to find the optimal first-stage K

**Key Insight:** Reranking can significantly improve retrieval quality with manageable latency cost, especially on GPU-equipped systems like DGX Spark!

---

## 🧹 Cleanup

In [None]:
# Clean up
del bi_encoder, cross_encoder, two_stage, prod_reranker
gc.collect()
torch.cuda.empty_cache()

print("✅ Cleanup complete!")

---

## Next Steps

In the next lab, we'll learn how to **evaluate RAG systems** with RAGAS!

➡️ Continue to [Lab 3.5.6: RAGAS Evaluation](./lab-3.5.6-evaluation.ipynb)