# Re-ranking with Cross-Encoders

## Welcome!
We've learned retrieval methods (BM25, Dense, Hybrid). But there's a problem:
**Initial retrieval is fast but not always accurate.**

## The Problem

```
Bi-Encoder (what we've been using):
┌─────────────────────────────────────────────────┐
│  Query → Embedding  ─────┐                      │
│                          ├── Compare (fast!)    │
│  Document → Embedding ───┘                      │
└─────────────────────────────────────────────────┘
Problem: Query and document are encoded SEPARATELY.
         No direct interaction = might miss nuances.
```
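Here is a minimal sketch of the bi-encoder side. It uses a hypothetical bag-of-words `toy_encode` in place of a neural embedding model. The point is structural: document vectors are computed without ever seeing the query, and scoring is just a cosine between two independent vectors.

```python
import math
from collections import Counter

# Hypothetical toy "encoder": word counts over a tiny fixed vocabulary.
# A real bi-encoder (e.g. all-MiniLM-L6-v2) produces dense neural embeddings instead.
VOCAB = ["lora", "fine-tuning", "low-rank", "weather", "sunny", "models"]

def toy_encode(text: str) -> list[float]:
    counts = Counter(text.lower().split())
    return [float(counts[term]) for term in VOCAB]

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

docs = [
    "lora is a low-rank method for fine-tuning models",
    "the weather is sunny",
]
# Documents are encoded once, independently of any query (can be done offline).
doc_vecs = [toy_encode(d) for d in docs]

# At query time only the query is encoded; it never "sees" the document text.
query_vec = toy_encode("what is lora fine-tuning")
scores = [cosine(query_vec, v) for v in doc_vecs]
print(scores)
```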

## The Solution: Cross-Encoders

```
Cross-Encoder:
┌─────────────────────────────────────────────────┐
│  [Query + Document] → Model → Relevance Score   │
└─────────────────────────────────────────────────┘
Query and document processed TOGETHER.
Much more accurate, but slower.
```
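To make "slower" concrete, here is a rough count of model forward passes. It ignores batching, and `encoder_calls` is a hypothetical helper used only for the arithmetic. A bi-encoder encodes each document once plus each query once. A cross-encoder needs one joint pass for every (query, document) pair.

```python
def encoder_calls(num_queries: int, num_docs: int) -> dict[str, int]:
    """Count model forward passes, ignoring batching (illustrative arithmetic only)."""
    return {
        # Bi-encoder: every document once (offline) + every query once.
        "bi_encoder": num_docs + num_queries,
        # Cross-encoder: every (query, document) pair needs its own joint pass.
        "cross_encoder": num_queries * num_docs,
    }

print(encoder_calls(num_queries=1_000, num_docs=1_000_000))
# {'bi_encoder': 1001000, 'cross_encoder': 1000000000}

# Restricting the cross-encoder to 20 candidates per query:
print(encoder_calls(num_queries=1_000, num_docs=20)["cross_encoder"])
# 20000
```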

## The Strategy: Two-Stage Retrieval

```
Stage 1 (Fast): Retrieve top 20-50 candidates (Bi-Encoder)
Stage 2 (Accurate): Re-rank to get top 3-5 (Cross-Encoder)
```

**This two-stage pattern is common in production RAG systems.**
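The two stages can be sketched as a plain function that takes any cheap scorer and any expensive scorer. This is a simplified illustration: `two_stage_search`, `word_overlap`, and `phrase_bonus` are hypothetical stand-ins for the bi-encoder and cross-encoder, not a library API.

```python
from typing import Callable, Sequence

def two_stage_search(
    query: str,
    docs: Sequence[str],
    fast_score: Callable[[str, str], float],  # stand-in for bi-encoder similarity
    slow_score: Callable[[str, str], float],  # stand-in for cross-encoder relevance
    initial_k: int = 20,
    final_k: int = 5,
) -> list[tuple[str, float]]:
    # Stage 1: cheap score over every document, keep the top initial_k candidates.
    candidates = sorted(docs, key=lambda d: fast_score(query, d), reverse=True)[:initial_k]
    # Stage 2: expensive score only on those candidates, keep the top final_k.
    rescored = [(d, slow_score(query, d)) for d in candidates]
    rescored.sort(key=lambda pair: pair[1], reverse=True)
    return rescored[:final_k]

def word_overlap(q: str, d: str) -> float:
    """Toy fast scorer: number of shared lowercase words."""
    return len(set(q.lower().split()) & set(d.lower().split()))

def phrase_bonus(q: str, d: str) -> float:
    """Toy slow scorer: word overlap plus a bonus for mentioning 'low-rank'."""
    return word_overlap(q, d) + (5 if "low-rank" in d else 0)

docs = ["lora uses low-rank matrices", "lora is popular", "sunny weather"]
result = two_stage_search("what is lora", docs, word_overlap, phrase_bonus, initial_k=2, final_k=1)
print(result)
```

Stage 1 ranks "lora is popular" first (two shared words), but Stage 2 promotes "lora uses low-rank matrices". The re-ranker can only reorder what Stage 1 kept; "sunny weather" never reaches it.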

## What You'll Learn
1. Bi-Encoder vs Cross-Encoder differences
2. How to use cross-encoders for re-ranking
3. Two-stage retrieval pipeline
4. When re-ranking helps most

## Step 1: Environment Setup

In [None]:
# Load environment variables
from dotenv import load_dotenv
import os

load_dotenv()
print("Environment loaded!")

In [None]:
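# Install required packages (run once)
# !pip install sentence-transformers langchain-community langchain-text-splitters langchain-huggingface langchain-chroma langchain-openai pypdf python-dotenv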

## Step 2: Load and Prepare Documents

In [None]:
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

# Load PDF
pdf_path = "../LangChain_projects/llm_fundamentals.pdf"
if not os.path.exists(pdf_path):
    pdf_path = "../RAG/llm_fundamentals.pdf"

loader = PyPDFLoader(pdf_path)
documents = loader.load()

# Split into chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50
)
chunks = text_splitter.split_documents(documents)

print(f"Loaded {len(chunks)} chunks")

# Create vector store for initial retrieval
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = Chroma.from_documents(
    documents=chunks,
    embedding=embeddings,
    collection_name="rerank_demo"
)

print("Vector store ready for initial retrieval!")

---
## Part 1: Understanding Bi-Encoder vs Cross-Encoder

Let's visualize the difference:

In [None]:
explanation = """
BI-ENCODER (Fast, Less Accurate)
================================
Used for: Initial retrieval from large collections

How it works:
1. Encode query ONCE → query_embedding
2. Encode all documents ONCE (pre-computed) → doc_embeddings
3. Compare with a cheap vector similarity (e.g., cosine)

       Query          Document
         ↓               ↓
    [Encoder]       [Encoder]
         ↓               ↓
    [Vector]        [Vector]
         ↓               ↓
         └───Compare────┘
              (fast!)

Pros: Very fast (can search millions of docs)
Cons: Query and doc don't "see" each other


CROSS-ENCODER (Slow, Very Accurate)
====================================
Used for: Re-ranking a small set of candidates

How it works:
1. Concatenate query + document
2. Pass through model TOGETHER
3. Model outputs relevance score directly

    [Query, Document]
           ↓
      [Encoder]
           ↓
    Relevance Score

Pros: Very accurate (sees full context)
Cons: Slow (one model pass per query-doc pair at query time; nothing can be pre-computed)
"""
print(explanation)

---
## Part 2: Load Cross-Encoder Model

In [None]:
from sentence_transformers import CrossEncoder

# Load a pre-trained cross-encoder model
# This model is specifically trained for re-ranking!
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

print("Cross-encoder model loaded!")
print("Model: cross-encoder/ms-marco-MiniLM-L-6-v2")
print("Trained on: MS MARCO (Microsoft Machine Reading Comprehension)")

---
## Part 3: Basic Cross-Encoder Usage

Let's see how the cross-encoder scores query-document pairs.

In [None]:
# Example: Score query-document pairs
query = "What is LoRA?"

# Some example documents (one relevant, one not)
doc_relevant = "LoRA (Low-Rank Adaptation) is a method for efficiently fine-tuning large language models by updating only small matrices."
doc_irrelevant = "The weather today is sunny with a high of 75 degrees."
doc_somewhat = "Fine-tuning large models requires significant computational resources."

# Cross-encoder takes pairs of (query, document)
pairs = [
    (query, doc_relevant),
    (query, doc_somewhat),
    (query, doc_irrelevant)
]

# Get relevance scores
scores = cross_encoder.predict(pairs)

print(f"Query: '{query}'")
print("\nCross-Encoder Scores (higher = more relevant):")
print("="*60)
print(f"\nRelevant doc:   {scores[0]:.4f}")
print(f"Document: {doc_relevant[:60]}...")
print(f"\nSomewhat related: {scores[1]:.4f}")
print(f"Document: {doc_somewhat[:60]}...")
print(f"\nIrrelevant doc: {scores[2]:.4f}")
print(f"Document: {doc_irrelevant[:60]}...")
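Compare the scores above with each other, not against a fixed scale. If the values you see are unbounded (for example, negative), they are most likely raw logits. A sigmoid maps them into (0, 1) without changing the ranking. The sketch below assumes the scores are logits, and its numbers are made up, not real output from this model.

```python
import math

def to_probability(score: float) -> float:
    """Numerically stable logistic sigmoid: maps any real score into (0, 1)."""
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    z = math.exp(score)  # score < 0, so z is in (0, 1) and cannot overflow
    return z / (1.0 + z)

# Hypothetical raw scores for illustration, NOT real output from the model above.
raw_scores = [8.5, -2.0, -11.0]
probs = [to_probability(s) for s in raw_scores]
print([round(p, 4) for p in probs])
```

The sigmoid is strictly increasing, so sorting by `probs` gives the same order as sorting by `raw_scores`.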

---
## Part 4: Two-Stage Retrieval Pipeline

Now let's build a complete pipeline:
1. **Stage 1**: Fast retrieval with bi-encoder (get top 20)
2. **Stage 2**: Re-rank with cross-encoder (get top 5)

In [None]:
def retrieve_and_rerank(query: str, initial_k: int = 20, final_k: int = 5):
    """
    Two-stage retrieval with re-ranking.
    
    Args:
        query: Search query
        initial_k: Number of candidates from stage 1 (bi-encoder)
        final_k: Number of final results after re-ranking
    
    Returns:
        List of (document, cross_encoder_score) tuples
    """
    # STAGE 1: Initial retrieval (fast, bi-encoder)
    print(f"Stage 1: Retrieving top {initial_k} candidates...")
    initial_results = vectorstore.similarity_search(query, k=initial_k)
    if not initial_results:
        return []  # nothing to re-rank
    
    # STAGE 2: Re-ranking (accurate, cross-encoder)
    print(f"Stage 2: Re-ranking {len(initial_results)} candidates with cross-encoder...")
    
    # Prepare pairs for cross-encoder
    pairs = [(query, doc.page_content) for doc in initial_results]
    
    # Get cross-encoder scores
    scores = cross_encoder.predict(pairs)
    
    # Combine documents with scores and sort
    scored_docs = list(zip(initial_results, scores))
    scored_docs.sort(key=lambda x: x[1], reverse=True)  # Sort by score descending
    
    # Return top final_k
    return scored_docs[:final_k]

print("Two-stage retrieval function ready!")
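The sort-then-slice step in `retrieve_and_rerank` is fine for 20 candidates. One standard-library alternative is `heapq.nlargest`, which the Python docs describe as equivalent to `sorted(iterable, key=key, reverse=True)[:n]`. `top_k_by_score` is a hypothetical helper name.

```python
import heapq

def top_k_by_score(items, scores, k):
    """Return the k (item, score) pairs with the highest scores, best first."""
    return heapq.nlargest(k, zip(items, scores), key=lambda pair: pair[1])

print(top_k_by_score(["a", "b", "c", "d"], [0.1, 2.5, -1.0, 0.7], k=2))
# [('b', 2.5), ('d', 0.7)]
```

Inside `retrieve_and_rerank`, this would correspond to `top_k_by_score(initial_results, scores, final_k)`.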

In [None]:
# Test the two-stage pipeline
query = "How does LoRA help with fine-tuning efficiency?"

print(f"Query: '{query}'")
print("="*80)

reranked_results = retrieve_and_rerank(query, initial_k=15, final_k=5)

print(f"\nTop {len(reranked_results)} Re-ranked Results:")
print("="*80)

for i, (doc, score) in enumerate(reranked_results, 1):
    print(f"\nRank {i} (Score: {score:.4f}):")
    print(f"{doc.page_content[:200]}...")
    print("-"*40)

---
## Part 5: Comparing Before and After Re-ranking

Let's compare the results with and without re-ranking.

In [None]:
def compare_with_without_reranking(query: str, k: int = 5):
    """
    Compare search results with and without re-ranking.
    """
    print(f"Query: '{query}'")
    print("\n" + "="*80)
    
    # WITHOUT re-ranking (just bi-encoder)
    print("\nWITHOUT Re-ranking (Bi-Encoder only):")
    print("-"*40)
    basic_results = vectorstore.similarity_search(query, k=k)
    for i, doc in enumerate(basic_results, 1):
        print(f"{i}. {doc.page_content[:100]}...")
    
    # WITH re-ranking
    print("\nWITH Re-ranking (Bi-Encoder + Cross-Encoder):")
    print("-"*40)
    reranked = retrieve_and_rerank(query, initial_k=20, final_k=k)
    for i, (doc, score) in enumerate(reranked, 1):
        print(f"{i}. (score: {score:.3f}) {doc.page_content[:80]}...")
    
    print("\n" + "="*80)

# Compare with a specific query
compare_with_without_reranking("What are the benefits of using LoRA for fine-tuning?")

In [None]:
# Try another query
compare_with_without_reranking("How does attention mechanism work in transformers?")
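Reading two printed lists side by side works for a couple of queries. To see exactly how re-ranking reorders results, you can compare positions by a stable identifier. This sketch assumes each result can be identified by a string, such as `doc.page_content` or a hypothetical `chunk_id` stored in metadata. The IDs below are made up.

```python
from typing import Optional

def rank_changes(before: list[str], after: list[str]) -> dict[str, tuple[Optional[int], int]]:
    """Map each item in `after` to (old_rank, new_rank); old_rank is None if absent from `before`."""
    old_rank = {item: i for i, item in enumerate(before, 1)}
    return {item: (old_rank.get(item), new) for new, item in enumerate(after, 1)}

bi_encoder_top5 = ["c3", "c7", "c1", "c9", "c4"]  # hypothetical chunk IDs
reranked_top5 = ["c1", "c3", "c12", "c7", "c9"]
print(rank_changes(bi_encoder_top5, reranked_top5))
```

A `None` old rank means the chunk sat outside the bi-encoder's top 5 but inside the wider `initial_k` pool that was re-ranked.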

---
## Part 6: Complete RAG Pipeline with Re-ranking

In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

# Initialize LLM
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.7,
    api_key=os.environ["OPENAI_API_KEY"]
)

# Create prompt template
prompt_template = ChatPromptTemplate.from_template("""
You are a helpful AI assistant. Answer the question based on the provided context.
If the context doesn't contain the answer, say "I don't have enough information to answer that."

Context:
{context}

Question: {question}

Answer:
""")

print("LLM and prompt ready!")

In [None]:
def rag_with_reranking(question: str, initial_k: int = 20, final_k: int = 5):
    """
    Complete RAG pipeline with re-ranking.
    
    Steps:
    1. Retrieve initial candidates (bi-encoder)
    2. Re-rank candidates (cross-encoder)
    3. Generate answer (LLM)
    """
    print(f"Question: {question}")
    print("\n" + "="*80)
    
    # Stage 1 & 2: Retrieve and re-rank
    reranked_results = retrieve_and_rerank(question, initial_k, final_k)
    
    # Prepare context from top re-ranked documents
    context = "\n\n".join([doc.page_content for doc, _ in reranked_results])
    
    # Stage 3: Generate answer
    print("Stage 3: Generating answer with LLM...")
    
    # Format prompt into chat messages (format() would return one flat string)
    messages = prompt_template.format_messages(context=context, question=question)
    
    # Get answer
    response = llm.invoke(messages)
    
    # Display results
    print("\n" + "="*80)
    print("ANSWER:")
    print(response.content)
    
    print("\n" + "="*80)
    print(f"SOURCES (Top {len(reranked_results)} re-ranked):")
    for i, (doc, score) in enumerate(reranked_results, 1):
        print(f"  {i}. (score: {score:.3f}) {doc.page_content[:80]}...")
    
    return response.content

print("RAG with re-ranking function ready!")

In [None]:
# Test the complete pipeline
answer = rag_with_reranking("What is LoRA and why is it useful for fine-tuning large models?")

In [None]:
# Try another question
answer = rag_with_reranking("Explain the difference between LoRA and QLoRA.")

---
## Part 7: Performance Considerations

In [None]:
import time

def measure_performance(query: str):
    """
    Measure time for each stage.
    """
    print(f"Query: '{query}'")
    print("\nPerformance Breakdown:")
    print("="*50)
    
    # Stage 1: Bi-encoder retrieval
    # (perf_counter is a monotonic, high-resolution clock suited to timing)
    start = time.perf_counter()
    initial_results = vectorstore.similarity_search(query, k=20)
    stage1_time = time.perf_counter() - start
    print(f"Stage 1 (Bi-Encoder, top 20):      {stage1_time*1000:.1f} ms")
    
    # Stage 2: Cross-encoder re-ranking
    start = time.perf_counter()
    pairs = [(query, doc.page_content) for doc in initial_results]
    scores = cross_encoder.predict(pairs)
    stage2_time = time.perf_counter() - start
    print(f"Stage 2 (Cross-Encoder, {len(pairs)} pairs): {stage2_time*1000:.1f} ms")
    
    print(f"\nTotal retrieval time:              {(stage1_time + stage2_time)*1000:.1f} ms")
    print("\nNote: The cross-encoder needs one model pass per pair,")
    print("      which is why we only re-rank a small candidate set.")
    print("      (The first call may also include one-off warm-up cost.)")

measure_performance("What is attention mechanism?")
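A single timing run can be skewed by one-off costs such as lazy initialization. Here is a small helper that discards warm-up calls and reports the median. `median_ms` is a hypothetical name. In this notebook you could call it as `median_ms(lambda: cross_encoder.predict(pairs))`.

```python
import statistics
import time
from typing import Callable

def median_ms(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Median wall-clock time of fn() in milliseconds, after discarding warm-up runs."""
    for _ in range(warmup):
        fn()  # warm-up calls absorb one-off costs (lazy init, caching)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000)
    return statistics.median(timings)

print(f"{median_ms(lambda: sum(range(100_000))):.2f} ms")
```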

---
## When to Use Re-ranking?

| Scenario | Re-ranking Helps? | Reason |
|----------|-------------------|--------|
| High accuracy needed | Yes | Cross-encoder is more accurate |
| Latency-sensitive | Maybe | Adds latency per query (roughly 100-500 ms, depending on hardware and candidate count) |
| Complex queries | Yes | Better at understanding nuance |
| Simple keyword lookup | No | BM25/bi-encoder is sufficient |
| Large result set needed | No | Scoring 100+ pairs per query is costly, and reordering matters little when most results are kept anyway |

**Best practice:**
- Initial retrieval: 20-50 candidates
- Re-rank to: 3-10 final results
- More candidates = better recall (the cross-encoder can only reorder what Stage 1 returns), but slower re-ranking
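If you have a few test queries with known relevant chunks, you can check the recall side of this trade-off directly. Recall@`initial_k` of Stage 1 is a ceiling on what re-ranking can surface, because the cross-encoder only reorders candidates. The helper and the chunk IDs below are hypothetical.

```python
def recall_at_k(retrieved: list[str], relevant: set[str], k: int) -> float:
    """Fraction of the relevant items that appear among the first k retrieved items."""
    if not relevant:
        return 0.0
    return len(set(retrieved[:k]) & relevant) / len(relevant)

retrieved = ["c3", "c7", "c1", "c9", "c4", "c12"]  # Stage 1 order (hypothetical IDs)
relevant = {"c1", "c12"}
for k in (5, 6):
    print(k, recall_at_k(retrieved, relevant, k))
```

Here a pool of 5 candidates misses `c12`, so no re-ranker could place it in the final results. Widening the pool to 6 recovers it.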

---
## Summary

### What You've Learned:
1. **Bi-Encoder vs Cross-Encoder**: Trade-off between speed and accuracy
2. **Two-Stage Retrieval**: Fast initial retrieval + accurate re-ranking
3. **Cross-Encoder Usage**: How to score query-document pairs
4. **Complete Pipeline**: Retrieval → Re-ranking → Generation

### Key Takeaways:
- Re-ranking often improves the precision of the top results, though the gain varies by corpus and query
- Use bi-encoder for initial retrieval (fast, scalable)
- Use cross-encoder for re-ranking (accurate, but slow)
- Re-rank only a small candidate set (20-50 docs) to keep latency manageable

### The Two-Stage Pattern:
```
All Documents (millions)
        ↓
  [Bi-Encoder]  ← Fast, approximate
        ↓
Top 20-50 Candidates
        ↓
  [Cross-Encoder] ← Slow, accurate
        ↓
Top 3-5 Results
        ↓
      [LLM]
        ↓
     Answer
```

### Next Up:
**Query Expansion** - Generate multiple queries from one question for better recall!