# Lab 3.5.5 Solutions: Cross-Encoder Reranking

Complete solutions for two-stage retrieval with cross-encoder reranking.

## Setup

In [None]:
import sys
sys.path.insert(0, '..')

import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings

# Report which accelerator the notebook sees; prints 'CPU' when CUDA is absent.
if torch.cuda.is_available():
    accelerator_name = torch.cuda.get_device_name(0)
else:
    accelerator_name = 'CPU'
print(f"GPU: {accelerator_name}")

In [None]:
# Load and prepare data
def load_and_chunk(
    data_dir: str = "../data/sample_documents",
    pattern: str = "*.md",
    chunk_size: int = 500,
    chunk_overlap: int = 100,
) -> "List[Document]":
    """
    Read every matching file under ``data_dir`` and split it into chunks.

    Args:
        data_dir: Directory that holds the source documents.
        pattern: Glob pattern selecting which files to load.
        chunk_size: ``chunk_size`` passed to RecursiveCharacterTextSplitter.
        chunk_overlap: ``chunk_overlap`` passed to RecursiveCharacterTextSplitter.

    Returns:
        The chunks produced by ``split_documents``; each source Document
        carries ``metadata["source"]`` set to the originating file name.
    """
    documents = [
        Document(
            page_content=file_path.read_text(encoding='utf-8'),
            metadata={"source": file_path.name},
        )
        # glob() does not promise an ordering; sorting keeps chunk indices
        # (and therefore tie-breaking in later ranking) stable across runs.
        for file_path in sorted(Path(data_dir).glob(pattern))
    ]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(documents)

# Build the chunk corpus shared by all exercises.
chunks = load_and_chunk()
print(f"Loaded {len(chunks)} chunks")

# Choose the device once; the embedding model below is created on it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Bi-encoder embedding model, configured to normalize its output vectors.
embedding_model = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs=dict(device=device),
    encode_kwargs=dict(normalize_embeddings=True),
)

## Exercise 1 Solution: Understand Bi-Encoder vs Cross-Encoder

**Task**: Implement and compare both architectures.

In [None]:
class BiEncoderScorer:
    """
    Bi-encoder: Encodes query and document separately.

    Architecture:
        Query -> Encoder -> query_emb
        Document -> Encoder -> doc_emb
        Score = cosine_similarity(query_emb, doc_emb)

    Pros: Fast (can pre-compute doc embeddings)
    Cons: Limited interaction between query and document
    """

    def __init__(self, embedding_model: HuggingFaceEmbeddings):
        self.embedding_model = embedding_model

    @staticmethod
    def _cosine(query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Cosine similarity between ``query_emb`` and every row of ``doc_embs``.

        The 1e-8 term keeps the division finite for zero-norm vectors.
        """
        return np.dot(doc_embs, query_emb) / (
            np.linalg.norm(doc_embs, axis=1) * np.linalg.norm(query_emb) + 1e-8
        )

    def score(self, query: str, document: str) -> float:
        """Score a query-document pair.

        Delegates to ``score_batch`` so both entry points share one formula.
        """
        return self.score_batch(query, [document])[0]

    def score_batch(self, query: str, documents: List[str]) -> List[float]:
        """Score multiple documents (efficient with batch encoding).

        Args:
            query: Query text, embedded with ``embed_query``.
            documents: Document texts, embedded together with ``embed_documents``.

        Returns:
            One cosine similarity per document, in input order. An empty
            ``documents`` list yields ``[]`` without calling the model.
        """
        if not documents:
            # np.array([]) is 1-D, so the axis=1 norm below would raise.
            return []

        query_emb = np.array(self.embedding_model.embed_query(query))
        doc_embs = np.array(self.embedding_model.embed_documents(documents))
        return self._cosine(query_emb, doc_embs).tolist()


class CrossEncoderScorer:
    """
    Cross-encoder: Encodes query and document together.

    Architecture:
        [CLS] query [SEP] document [SEP] -> Encoder -> Score

    Pros: Full attention between query and document, higher accuracy
    Cons: Slow (nothing can be pre-computed; every query needs one model
          forward pass per candidate document)
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-large",
        device: Optional[str] = None
    ):
        """
        Load the cross-encoder.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.
            device: Device string; when None, "cuda" is used if available,
                otherwise "cpu".
        """
        # Imported lazily so the class can be defined without the package loaded.
        from sentence_transformers import CrossEncoder

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = CrossEncoder(model_name, device=device)
        self.device = device
        print(f"Cross-encoder loaded on {device}")

    def score(self, query: str, document: str) -> float:
        """Score a single query-document pair."""
        return float(self.model.predict([[query, document]])[0])

    def score_batch(self, query: str, documents: List[str]) -> List[float]:
        """Score multiple documents.

        Args:
            query: Query text paired with every document.
            documents: Candidate document texts.

        Returns:
            One score per document, in input order. An empty ``documents``
            list yields ``[]`` without calling the model.
        """
        if not documents:
            # Avoid handing the model an empty batch.
            return []

        pairs = [[query, doc] for doc in documents]
        scores = self.model.predict(pairs, batch_size=32, show_progress_bar=False)
        # asarray tolerates any sequence-like return, not only ndarray.
        return np.asarray(scores, dtype=float).tolist()

# Create scorers
# One scorer per architecture; the bi-encoder reuses the shared embedding model.
bi_encoder = BiEncoderScorer(embedding_model=embedding_model)
cross_encoder = CrossEncoderScorer()

print("\nBoth scorers created!")

In [None]:
# Compare scoring behavior
test_query = "What is the memory capacity of DGX Spark?"

# Candidates range from a direct answer down to an unrelated sentence.
test_documents = [
    "DGX Spark features 128GB of unified memory shared between CPU and GPU.",
    "The memory system uses LPDDR5X technology for high bandwidth.",
    "GPU memory is important for training large models.",
    "The weather today is sunny and warm."
]

print(f"Query: {test_query}\n")
print("Scoring comparison:")
print(f"{'Document':<60} {'Bi-Encoder':<12} {'Cross-Encoder':<12}")
print("-"*84)

bi_scores = bi_encoder.score_batch(test_query, test_documents)
cross_scores = cross_encoder.score_batch(test_query, test_documents)

# One table row per document; long texts are cut to 55 characters plus "...".
for row_idx, doc in enumerate(test_documents):
    bi_score = bi_scores[row_idx]
    cross_score = cross_scores[row_idx]
    if len(doc) > 55:
        doc_short = doc[:55] + "..."
    else:
        doc_short = doc
    print(f"{doc_short:<60} {bi_score:<12.4f} {cross_score:<12.4f}")

## Exercise 2 Solution: Implement Two-Stage Retrieval

**Task**: Build complete two-stage pipeline with reranking.

In [None]:
class TwoStageRetriever:
    """
    Two-stage retrieval: Bi-encoder for recall, Cross-encoder for precision.

    Stage 1 (Bi-encoder):
    - Fast approximate search
    - High recall, moderate precision
    - Returns top-N candidates (typically 50-100)

    Stage 2 (Cross-encoder):
    - Accurate reranking
    - High precision
    - Reranks to top-K (typically 5-10)

    Stage 1 scores are plain dot products between the query embedding and the
    pre-computed document embeddings.
    NOTE(review): that equals cosine similarity only when the embedding model
    returns normalized vectors - confirm for the model passed in.
    """

    def __init__(
        self,
        documents: List[Document],
        embedding_model: HuggingFaceEmbeddings,
        reranker_model: str = "BAAI/bge-reranker-large",
        device: Optional[str] = None
    ):
        """
        Embed the corpus and load the reranker.

        Args:
            documents: Chunks to index; their ``page_content`` is embedded.
            embedding_model: Model providing ``embed_documents`` / ``embed_query``.
            reranker_model: Model identifier handed to ``CrossEncoder``.
            device: Device for the reranker; when None, "cuda" is used if
                available, otherwise "cpu".
        """
        self.documents = documents
        self.embedding_model = embedding_model

        # Pre-compute document embeddings for Stage 1
        print("Pre-computing document embeddings...")
        texts = [doc.page_content for doc in documents]
        self.embeddings = np.array(embedding_model.embed_documents(texts))
        print(f"Embeddings shape: {self.embeddings.shape}")

        # Initialize cross-encoder for Stage 2
        print("Loading cross-encoder...")
        from sentence_transformers import CrossEncoder
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.reranker = CrossEncoder(reranker_model, device=device)
        print("Two-stage retriever ready!")

    @staticmethod
    def _top_indices(similarities: np.ndarray, n: int) -> np.ndarray:
        """Indices of the ``n`` highest similarities, best first.

        Returns an empty array for ``n <= 0``; a bare ``[-n:]`` slice with
        ``n == 0`` would otherwise select every document.
        """
        n = min(n, len(similarities))
        if n <= 0:
            return np.empty(0, dtype=int)
        return np.argsort(similarities)[-n:][::-1]

    def search(
        self,
        query: str,
        k: int = 5,
        first_stage_k: int = 50
    ) -> List[Dict[str, Any]]:
        """
        Two-stage search.

        Args:
            query: Search query
            k: Final number of results
            first_stage_k: Candidates from Stage 1

        Returns:
            List of result dicts with scores and timing, best rerank score
            first. Each dict repeats the per-query ``stage1_time_ms`` and
            ``stage2_time_ms``. The list is empty when ``k <= 0``,
            ``first_stage_k <= 0`` or the corpus is empty.
        """
        if k <= 0 or not self.documents:
            return []

        # Stage 1: Bi-encoder retrieval
        stage1_start = time.perf_counter()
        query_emb = np.array(self.embedding_model.embed_query(query))
        similarities = np.dot(self.embeddings, query_emb)
        top_indices = self._top_indices(similarities, first_stage_k)
        stage1_time = (time.perf_counter() - stage1_start) * 1000

        if len(top_indices) == 0:
            return []

        # Position p in this list is the document with bi-encoder rank p + 1.
        candidates = [self.documents[i] for i in top_indices]

        # Stage 2: Cross-encoder reranking
        stage2_start = time.perf_counter()
        pairs = [[query, doc.page_content] for doc in candidates]
        rerank_scores = self.reranker.predict(pairs, batch_size=32, show_progress_bar=False)
        stage2_time = (time.perf_counter() - stage2_start) * 1000

        # Stable sort on the negated score: ties keep their bi-encoder order.
        order = sorted(
            range(len(candidates)),
            key=lambda pos: -rerank_scores[pos]
        )[:k]

        # Build results
        results = []
        for final_rank, pos in enumerate(order, start=1):
            doc = candidates[pos]
            bi_rank = pos + 1

            results.append({
                "document": doc,
                "content": doc.page_content,
                "metadata": doc.metadata,
                "final_rank": final_rank,
                "rerank_score": float(rerank_scores[pos]),
                "bi_encoder_score": float(similarities[top_indices[pos]]),
                "bi_encoder_rank": bi_rank,
                # Positive means the reranker promoted the document.
                "rank_change": bi_rank - final_rank,
                "stage1_time_ms": stage1_time,
                "stage2_time_ms": stage2_time
            })

        return results

    def search_without_reranking(self, query: str, k: int = 5) -> List[Dict]:
        """Single-stage search (bi-encoder only) for comparison.

        Args:
            query: Search query
            k: Number of results

        Returns:
            List of result dicts ordered by bi-encoder score, best first;
            empty when ``k <= 0`` or the corpus is empty.
        """
        if k <= 0 or not self.documents:
            return []

        start = time.perf_counter()
        query_emb = np.array(self.embedding_model.embed_query(query))
        similarities = np.dot(self.embeddings, query_emb)
        top_indices = self._top_indices(similarities, k)
        elapsed = (time.perf_counter() - start) * 1000

        return [
            {
                "document": self.documents[i],
                "content": self.documents[i].page_content,
                "metadata": self.documents[i].metadata,
                "rank": rank,
                "score": float(similarities[i]),
                "time_ms": elapsed
            }
            for rank, i in enumerate(top_indices, start=1)
        ]

# Create two-stage retriever
retriever = TwoStageRetriever(documents=chunks, embedding_model=embedding_model)

## Exercise 3 Solution: Analyze Reranking Impact

**Task**: Measure how reranking changes rankings.

In [None]:
def analyze_reranking_impact(
    retriever: TwoStageRetriever,
    test_queries: List[str],
    k: int = 5,
    first_stage_k: int = 50
) -> Dict[str, Any]:
    """
    Analyze how reranking changes document rankings.

    Only the final top-``k`` documents of each query are inspected, so
    "moved down" counts documents that dropped but still stayed in the top-k.

    Args:
        retriever: Retriever whose ``search`` results carry ``rank_change``,
            ``bi_encoder_rank`` and the two stage timings.
        test_queries: Query strings to run.
        k: Final number of results per query.
        first_stage_k: Stage 1 candidate pool size.

    Returns:
        Dict with:
        - avg_rank_change / max_rank_change: absolute rank shift in positions
        - pct_moved_up / pct_moved_down / pct_unchanged: share of documents
        - top1_change_rate: share of queries whose final top-1 was not the
          bi-encoder top-1
        - avg_stage1_ms / avg_stage2_ms: mean per-query latency
        Queries that return no results are skipped; if nothing was evaluated
        every metric is 0.
    """
    rank_shifts = []
    top1_changed = 0
    moved_up = 0
    moved_down = 0
    evaluated_queries = 0

    latencies = {"stage1": [], "stage2": []}

    for query in test_queries:
        results = retriever.search(query, k=k, first_stage_k=first_stage_k)
        if not results:
            # No top-1 or timing to read; skip rather than index results[0].
            continue
        evaluated_queries += 1

        for r in results:
            rank_change = r["rank_change"]
            rank_shifts.append(abs(rank_change))

            if rank_change > 0:
                moved_up += 1
            elif rank_change < 0:
                moved_down += 1

        # Check if top-1 changed
        top_hit = results[0]
        if top_hit["bi_encoder_rank"] != 1:
            top1_changed += 1

        # Timings are per query and repeated on every result dict.
        latencies["stage1"].append(top_hit["stage1_time_ms"])
        latencies["stage2"].append(top_hit["stage2_time_ms"])

    total_docs = len(rank_shifts)
    if total_docs == 0:
        return {
            "avg_rank_change": 0.0,
            "max_rank_change": 0,
            "pct_moved_up": 0.0,
            "pct_moved_down": 0.0,
            "pct_unchanged": 0.0,
            "top1_change_rate": 0.0,
            "avg_stage1_ms": 0.0,
            "avg_stage2_ms": 0.0
        }

    return {
        "avg_rank_change": float(np.mean(rank_shifts)),
        "max_rank_change": int(np.max(rank_shifts)),
        "pct_moved_up": 100 * moved_up / total_docs,
        "pct_moved_down": 100 * moved_down / total_docs,
        "pct_unchanged": 100 * (total_docs - moved_up - moved_down) / total_docs,
        "top1_change_rate": 100 * top1_changed / evaluated_queries,
        "avg_stage1_ms": float(np.mean(latencies["stage1"])),
        "avg_stage2_ms": float(np.mean(latencies["stage2"]))
    }

# Test queries
# Unlabeled queries: only ranking movement and latency are measured here.
test_queries = [
    "What is the memory capacity of DGX Spark?",
    "How does LoRA reduce trainable parameters?",
    "Explain transformer attention mechanism",
    "What quantization methods are available?",
    "Compare vector databases for RAG",
    "How to fine-tune a large language model efficiently?",
    "What is the difference between GPTQ and AWQ?",
    "Explain retrieval augmented generation"
]

# Run every query through the two-stage retriever and aggregate the metrics.
impact = analyze_reranking_impact(retriever, test_queries)

impact_rule = "=" * 60
impact_stage1_ms = impact['avg_stage1_ms']
impact_stage2_ms = impact['avg_stage2_ms']

print(impact_rule)
print("RERANKING IMPACT ANALYSIS")
print(impact_rule)
print(f"Average rank change:     {impact['avg_rank_change']:.2f} positions")
print(f"Maximum rank change:     {impact['max_rank_change']:.0f} positions")
print(f"Documents moved up:      {impact['pct_moved_up']:.1f}%")
print(f"Documents moved down:    {impact['pct_moved_down']:.1f}%")
print(f"Documents unchanged:     {impact['pct_unchanged']:.1f}%")
print(f"Top-1 change rate:       {impact['top1_change_rate']:.1f}%")
print("\nLatency breakdown:")
print(f"  Stage 1 (bi-encoder):  {impact_stage1_ms:.2f}ms")
print(f"  Stage 2 (reranking):   {impact_stage2_ms:.2f}ms")
print(f"  Total:                 {impact_stage1_ms + impact_stage2_ms:.2f}ms")

## Exercise 4 Solution: Compare With and Without Reranking

**Task**: Evaluate quality improvement from reranking.

In [None]:
def evaluate_retrieval_quality(
    retriever: TwoStageRetriever,
    test_queries: List[Dict[str, Any]],
    k: int = 5
) -> Dict[str, Dict]:
    """
    Compare retrieval quality with and without reranking.

    A result counts as a hit when the query's ``expected_source`` is a
    substring of the result's ``metadata["source"]``.

    Args:
        retriever: Retriever providing ``search`` and ``search_without_reranking``.
        test_queries: Dicts with a "query" key and an "expected_source" key.
            Entries with a missing or empty "expected_source" are skipped.
        k: Number of results inspected per query.

    Returns:
        ``{"without_reranking": {...}, "with_reranking": {...}}`` where each
        inner dict holds:
        - "recall@k": fraction of queries with at least one hit in the top-k
        - "mrr": mean reciprocal rank of the first hit (0 when there is none)
        Both are 0.0 when no labeled query was evaluated.
    """
    per_method = {
        "without_reranking": {"recalls": [], "mrrs": []},
        "with_reranking": {"recalls": [], "mrrs": []}
    }

    for query_data in test_queries:
        query = query_data["query"]
        expected = query_data.get("expected_source", "")
        if not expected:
            # '' is a substring of every source, so an unlabeled query would
            # be scored as a rank-1 hit and inflate both metrics.
            continue

        runs = {
            "without_reranking": retriever.search_without_reranking(query, k=k),
            "with_reranking": retriever.search(query, k=k)
        }

        for method, search_results in runs.items():
            # 1-based rank of the first result from the expected source.
            hit_rank = next(
                (
                    rank
                    for rank, r in enumerate(search_results, start=1)
                    if expected in r["metadata"].get("source", "")
                ),
                None
            )
            found = hit_rank is not None
            per_method[method]["recalls"].append(1.0 if found else 0.0)
            per_method[method]["mrrs"].append(1 / hit_rank if found else 0.0)

    # Aggregate
    aggregated = {}
    for method, metrics in per_method.items():
        recalls = metrics["recalls"]
        mrrs = metrics["mrrs"]
        aggregated[method] = {
            "recall@k": float(np.mean(recalls)) if recalls else 0.0,
            "mrr": float(np.mean(mrrs)) if mrrs else 0.0
        }

    return aggregated

# Labeled test queries
# (query, expected source substring) pairs expanded into the dict format
# that the evaluation helpers read.
labeled_queries = [
    {"query": query_text, "expected_source": source_key}
    for query_text, source_key in [
        ("DGX Spark memory capacity", "dgx_spark"),
        ("LoRA parameter reduction", "lora"),
        ("transformer attention", "transformer"),
        ("quantization methods comparison", "quantization"),
        ("vector database for retrieval", "vector_database"),
        ("128GB unified memory GPU", "dgx_spark"),
        ("efficient fine-tuning with low rank", "lora"),
        ("NVFP4 4-bit quantization", "quantization"),
    ]
]

quality = evaluate_retrieval_quality(retriever, labeled_queries)

print("\n" + "="*60)
print("QUALITY COMPARISON: WITH vs WITHOUT RERANKING")
print("="*60)
print(f"{'Metric':<20} {'Without Reranking':<20} {'With Reranking':<20} {'Improvement':<15}")
print("-"*75)

baseline_metrics = quality["without_reranking"]
reranked_metrics = quality["with_reranking"]

for metric in ("recall@k", "mrr"):
    without = baseline_metrics[metric]
    with_rr = reranked_metrics[metric]
    # Relative change in percent; the epsilon guards a zero baseline.
    delta = with_rr - without
    improvement = delta / (without + 1e-8) * 100

    print(f"{metric:<20} {without:<20.3f} {with_rr:<20.3f} {improvement:+.1f}%")

## Exercise 5 Solution: Optimize First-Stage K

**Task**: Find optimal candidate pool size.

In [None]:
def optimize_first_stage_k(
    retriever: TwoStageRetriever,
    test_queries: List[Dict],
    k_values: Optional[List[int]] = None,
    final_k: int = 5
) -> List[Dict]:
    """
    Find optimal first_stage_k balancing quality and latency.

    Args:
        retriever: Retriever whose ``search`` is run for every query.
        test_queries: Dicts with a "query" key and an "expected_source" key.
            Queries with a missing or empty "expected_source" still
            contribute to latency but are excluded from recall and MRR.
        k_values: Candidate pool sizes to try; defaults to [10, 20, 50, 100].
        final_k: Number of results kept after reranking.

    Returns:
        One dict per entry of ``k_values`` with keys "first_stage_k",
        "recall@k", "mrr" and "latency_ms" (mean of stage 1 + stage 2 time).
        Quality metrics are 0.0 when no labeled query was evaluated;
        "latency_ms" is NaN when no query returned a result to read timing from.
    """
    if k_values is None:
        # Built per call instead of sharing one mutable default list.
        k_values = [10, 20, 50, 100]

    results = []

    for first_k in k_values:
        recalls = []
        mrrs = []
        latencies = []

        for query_data in test_queries:
            query = query_data["query"]
            expected = query_data.get("expected_source", "")

            search_results = retriever.search(query, k=final_k, first_stage_k=first_k)

            if search_results:
                # Timings are per query and repeated on every result dict.
                head = search_results[0]
                latencies.append(head["stage1_time_ms"] + head["stage2_time_ms"])

            if not expected:
                # '' is a substring of every source and would always "hit".
                continue

            # 1-based rank of the first result from the expected source.
            hit_rank = next(
                (
                    rank
                    for rank, r in enumerate(search_results, start=1)
                    if expected in r["metadata"].get("source", "")
                ),
                None
            )
            found = hit_rank is not None
            recalls.append(1.0 if found else 0.0)
            mrrs.append(1 / hit_rank if found else 0.0)

        results.append({
            "first_stage_k": first_k,
            "recall@k": float(np.mean(recalls)) if recalls else 0.0,
            "mrr": float(np.mean(mrrs)) if mrrs else 0.0,
            "latency_ms": float(np.mean(latencies)) if latencies else float("nan")
        })

    return results

# Run optimization
k_results = optimize_first_stage_k(
    retriever,
    labeled_queries,
    k_values=[10, 20, 30, 50, 75, 100]
)

print("\n" + "="*70)
print("FIRST-STAGE K OPTIMIZATION")
print("="*70)
print(f"{'First-Stage K':<15} {'Recall@5':<12} {'MRR':<12} {'Latency (ms)':<15}")
print("-"*54)

# Winners: highest MRR for quality, lowest mean latency for speed.
best_quality = max(k_results, key=lambda row: row['mrr'])
best_latency = min(k_results, key=lambda row: row['latency_ms'])

for r in k_results:
    # Tag the row(s) that won a category.
    markers = [
        label
        for label, winner in (("BEST QUALITY", best_quality), ("FASTEST", best_latency))
        if r == winner
    ]
    if markers:
        marker_str = " <-- " + ", ".join(markers)
    else:
        marker_str = ""

    print(f"{r['first_stage_k']:<15} {r['recall@k']:<12.3f} {r['mrr']:<12.3f} {r['latency_ms']:<15.2f}{marker_str}")

quality_k = best_quality['first_stage_k']
latency_k = best_latency['first_stage_k']
print(f"\nRecommendation: Use first_stage_k={quality_k} for best quality")
print(f"Alternative: Use first_stage_k={latency_k} for lowest latency")

## Summary

In [None]:
# Static takeaway text for the lab, printed verbatim below.
# NOTE(review): the latency figures inside are not computed by this notebook -
# confirm them against measurements before relying on them.
summary = """
RERANKING BEST PRACTICES
========================

1. ARCHITECTURE CHOICE
   - Bi-encoder: Fast, good for initial retrieval
   - Cross-encoder: Accurate, use for reranking top candidates
   - Combine both for optimal quality/latency tradeoff

2. FIRST-STAGE K SELECTION
   - Too low (< 20): May miss relevant documents
   - Too high (> 100): Increases reranking latency
   - Sweet spot: 30-50 for most use cases

3. MODEL SELECTION
   - BGE-reranker-large: Best quality, ~300ms/query
   - BGE-reranker-base: Good balance, ~100ms/query
   - MiniLM-based: Fastest, slight quality drop

4. DGX SPARK OPTIMIZATION
   - Use GPU for both bi-encoder and cross-encoder
   - Batch reranking requests for throughput
   - Consider caching embeddings in GPU memory

5. WHEN TO USE RERANKING
   - High-stakes queries (customer support, legal)
   - Complex semantic queries
   - When latency budget allows (~300ms extra)
   
6. WHEN TO SKIP RERANKING
   - Simple keyword queries
   - Real-time applications with strict latency
   - When bi-encoder quality is sufficient
"""

print(summary)