In [None]:
import torch
import numpy as np
from datasets import load_dataset
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image
from typing import List, Dict, Tuple
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

def compute_ndcg_at_k(relevance_scores: np.ndarray, k: int = 5) -> float:
    """Compute NDCG@K for a single ranked list.

    Uses the standard discount ``1 / log2(rank + 1)`` with 1-based ranks (the
    form used by e.g. scikit-learn's ``ndcg_score`` and common IR evaluation
    tooling). The older ``rel_1 + sum_{i>=2} rel_i / log2(i)`` variant gives
    ranks 1 and 2 the same weight and so is not comparable to benchmark
    numbers computed with the standard discount.

    Args:
        relevance_scores: Relevance of each retrieved item, in ranked order.
        k: Cut-off rank. Values <= 0 yield 0.0.

    Returns:
        DCG@K divided by the ideal DCG@K, in [0, 1]. Returns 0.0 for an empty
        list or a list with no relevant items.

    Note:
        The ideal ranking is derived from ``relevance_scores`` itself. Relevant
        items that were not retrieved are therefore not counted. That is exact
        when each query has a single relevant item, as in this notebook.
    """
    relevance = np.asarray(relevance_scores, dtype=float)
    if relevance.size == 0 or k <= 0:
        return 0.0

    k = min(k, relevance.size)
    # 1-based rank r is discounted by log2(r + 1): ranks 1..k -> log2(2..k+1)
    discounts = 1.0 / np.log2(np.arange(2, k + 2))

    dcg = float(np.sum(relevance[:k] * discounts))
    ideal = np.sort(relevance)[::-1][:k]
    idcg = float(np.sum(ideal * discounts))

    return dcg / idcg if idcg > 0 else 0.0

def compute_recall_at_k(ranked_indices: List[int], relevant_idx: int, k: int = 1) -> float:
    """Recall@K for a query with exactly one relevant item.

    Returns 1.0 when ``relevant_idx`` appears among the first ``k`` entries of
    ``ranked_indices``, otherwise 0.0.
    """
    window = ranked_indices[:k]
    return float(relevant_idx in window)

def compute_mrr(ranked_indices: List[int], relevant_idx: int) -> float:
    """Reciprocal rank of ``relevant_idx`` within ``ranked_indices``.

    The rank is 1-based. Returns 0.0 when the item is absent; averaging this
    value over queries gives MRR.
    """
    try:
        position = ranked_indices.index(relevant_idx)
    except ValueError:
        return 0.0
    else:
        return 1.0 / (position + 1)

class ArxivQADataset:
    """Thin wrapper around the ViDoRe subsampled ArxivQA split.

    Each loaded row carries a ``query`` string and an ``image``. The
    evaluation code treats image ``i`` as the single relevant document for
    query ``i``.
    """

    def __init__(self, split: str = "test"):
        print(f"Loading ArxivQA {split} split from HuggingFace...")
        self.dataset = load_dataset("vidore/arxivqa_test_subsampled", split=split)
        print(f"Loaded {len(self.dataset)} samples")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def get_queries_and_docs(self) -> Tuple[List[str], List[Image.Image], List[int]]:
        """Split rows into parallel query and image lists plus the relevance map.

        Row ``i`` contributes query ``i`` and image ``i``. The returned mapping is
        therefore the identity ``[0, 1, ..., n - 1]``.
        """
        queries, doc_images = [], []
        for row in self.dataset:
            queries.append(row['query'])
            doc_images.append(row['image'])
        query_to_doc = list(range(len(doc_images)))
        return queries, doc_images, query_to_doc

class ColPaliRetriever:
    """ColPali model with late interaction (MaxSim) scoring.

    Queries and page images are each encoded into a set of vectors. A
    query/document score is the sum, over query vectors, of the largest dot
    product with any document vector.
    """

    def __init__(self, model_name: str = "vidore/colpali-v1.2"):
        print(f"Loading ColPali model: {model_name}")
        self.processor = ColPaliProcessor.from_pretrained(model_name)
        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.device = self.model.device
        self.model.eval()
        print(f"Model loaded on device: {self.device}")

    def encode_queries(self, queries: List[str], batch_size: int = 8) -> List[torch.Tensor]:
        """Encode text queries into per-query multi-vector embeddings.

        Each batch output is split along its first dimension, so the result has
        one tensor per query.
        """
        all_query_embeddings = []

        print(f"Encoding {len(queries)} queries...")
        for i in tqdm(range(0, len(queries), batch_size)):
            batch = queries[i:i+batch_size]

            inputs = self.processor.process_queries(batch).to(self.device)

            with torch.no_grad():
                query_embeddings = self.model(**inputs)

            all_query_embeddings.extend(list(query_embeddings))

        return all_query_embeddings

    def encode_images(self, images: List[Image.Image], batch_size: int = 4) -> List[torch.Tensor]:
        """Encode document images into per-image multi-vector embeddings."""
        all_doc_embeddings = []

        print(f"Encoding {len(images)} document images...")
        for i in tqdm(range(0, len(images), batch_size)):
            batch = images[i:i+batch_size]

            inputs = self.processor.process_images(batch).to(self.device)

            with torch.no_grad():
                doc_embeddings = self.model(**inputs)

            all_doc_embeddings.extend(list(doc_embeddings))

        return all_doc_embeddings

    def late_interaction_score(self, query_embedding: torch.Tensor,
                               doc_embedding: torch.Tensor) -> float:
        """Compute the MaxSim late-interaction score of one query/document pair.

        Assumes both embeddings are 2-D ``(num_vectors, dim)`` tensors.

        The model runs in bfloat16, which keeps only about 3 significant digits.
        Summed scores near 16 are representable only in steps of 0.125, which
        causes spurious ties. Both operands are therefore upcast to float32
        before scoring.
        """
        query = query_embedding.float()
        doc = doc_embedding.to(device=query.device, dtype=torch.float32)

        similarity_matrix = query @ doc.T  # (query vectors, doc vectors)
        best_per_query_vector = similarity_matrix.max(dim=1).values
        return best_per_query_vector.sum().item()

    def retrieve(self, query_embeddings: List[torch.Tensor],
                 doc_embeddings: List[torch.Tensor],
                 top_k: int = 10) -> List[List[int]]:
        """Rank documents for each query by late-interaction score.

        Returns:
            For each query, up to ``top_k`` indices into ``doc_embeddings``,
            best first. Ties are broken deterministically toward the lower index.
        """
        rankings = []

        print(f"Computing late interaction scores for {len(query_embeddings)} queries...")
        for query_emb in tqdm(query_embeddings):
            scores = np.array([
                self.late_interaction_score(query_emb, doc_emb)
                for doc_emb in doc_embeddings
            ])
            # Stable sort on negated scores: descending, equal scores keep index order
            top_indices = np.argsort(-scores, kind="stable")[:top_k]
            rankings.append(top_indices.tolist())

        return rankings

def evaluate_retrieval(rankings: List[List[int]],
                       relevant_docs: List[int]) -> Dict[str, float]:
    """Average single-relevant-document retrieval metrics over all queries.

    Args:
        rankings: For each query, retrieved document indices, best first.
        relevant_docs: For each query, the index of its one relevant document.

    Returns:
        ``NDCG@5``, ``Recall@1`` and ``Recall@5`` as percentages, and ``MRR`` as
        a 0-1 fraction. MRR only sees the (possibly truncated) ranking passed in.
        Pairs are formed with ``zip``, so extra entries in the longer input are
        ignored.
    """
    per_metric = {"NDCG@5": [], "Recall@1": [], "Recall@5": [], "MRR": []}

    for ranked, target in zip(rankings, relevant_docs):
        hits = np.array([1 if doc == target else 0 for doc in ranked])
        per_metric["NDCG@5"].append(compute_ndcg_at_k(hits, k=5))
        per_metric["Recall@1"].append(compute_recall_at_k(ranked, target, k=1))
        per_metric["Recall@5"].append(compute_recall_at_k(ranked, target, k=5))
        per_metric["MRR"].append(compute_mrr(ranked, target))

    return {
        "NDCG@5": np.mean(per_metric["NDCG@5"]) * 100,
        "Recall@1": np.mean(per_metric["Recall@1"]) * 100,
        "Recall@5": np.mean(per_metric["Recall@5"]) * 100,
        "MRR": np.mean(per_metric["MRR"]),
    }

def main():
    """Run single-stage ColPali retrieval on ArxivQA and print the metrics."""
    rule = "=" * 80

    print(rule)
    print("ColPali Evaluation on ArxivQA Dataset")
    print(rule)

    dataset = ArxivQADataset(split="test")
    queries, doc_images, query_to_doc = dataset.get_queries_and_docs()

    print("\nDataset Statistics:")
    print(f"  Number of queries: {len(queries)}")
    print(f"  Number of documents: {len(doc_images)}")
    print(f"  Sample query: {queries[0][:100]}...")

    print("\n" + rule)
    print("Evaluating ColPali")
    print(rule)

    retriever = ColPaliRetriever("vidore/colpali-v1.2")
    query_embeddings = retriever.encode_queries(queries, batch_size=8)
    doc_embeddings = retriever.encode_images(doc_images, batch_size=4)
    rankings = retriever.retrieve(query_embeddings, doc_embeddings, top_k=10)
    results = evaluate_retrieval(rankings, query_to_doc)

    print("\n" + rule)
    print("COLPALI RESULTS ON ARXIVQA DATASET")
    print(rule)
    print(f"\nNDCG@5:   {results['NDCG@5']:.2f}")
    print(f"Recall@1: {results['Recall@1']:.2f}")
    print(f"Recall@5: {results['Recall@5']:.2f}")
    print(f"MRR:      {results['MRR']:.4f}")

    print("\n" + rule)
    print("Expected results from paper (Table 2):")
    print("  NDCG@5:   79.1")
    print("  Recall@1: 72.4")
    print(rule)


if __name__ == "__main__":
    main()

In [None]:
# NOTE: this cell must run before any cell that imports colpali_engine.
# Install the CUDA 11.8 PyTorch build FIRST. Once any `torch` is installed, a later
# `pip install torch --index-url ...` only reports "Requirement already satisfied".
%pip install torch --index-url https://download.pytorch.org/whl/cu118

# Remaining dependencies. `pymupdf` provides the `fitz` module and `requests` is
# used by the Open RAGBench loader; both are imported in later cells.
%pip install colpali-engine datasets pillow numpy tqdm requests pymupdf

In [None]:
"""
Two-Stage RAG System for ArxivQA Dataset
Stage 1: Retrieve relevant document
Stage 2: Retrieve relevant page from document
"""

import torch
import numpy as np
from datasets import load_dataset
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image
from typing import List, Dict, Tuple
from tqdm import tqdm
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

def compute_ndcg_at_k(relevance_scores: np.ndarray, k: int = 5) -> float:
    """Compute NDCG@K for a single ranked list.

    Uses the standard discount ``1 / log2(rank + 1)`` with 1-based ranks (the
    form used by e.g. scikit-learn's ``ndcg_score`` and common IR evaluation
    tooling). The older ``rel_1 + sum_{i>=2} rel_i / log2(i)`` variant gives
    ranks 1 and 2 the same weight and so is not comparable to benchmark
    numbers computed with the standard discount.

    Args:
        relevance_scores: Relevance of each retrieved item, in ranked order.
        k: Cut-off rank. Values <= 0 yield 0.0.

    Returns:
        DCG@K divided by the ideal DCG@K, in [0, 1]. Returns 0.0 for an empty
        list or a list with no relevant items.

    Note:
        The ideal ranking is derived from ``relevance_scores`` itself. Relevant
        items that were not retrieved are therefore not counted. That is exact
        when each query has a single relevant item, as in this notebook.
    """
    relevance = np.asarray(relevance_scores, dtype=float)
    if relevance.size == 0 or k <= 0:
        return 0.0

    k = min(k, relevance.size)
    # 1-based rank r is discounted by log2(r + 1): ranks 1..k -> log2(2..k+1)
    discounts = 1.0 / np.log2(np.arange(2, k + 2))

    dcg = float(np.sum(relevance[:k] * discounts))
    ideal = np.sort(relevance)[::-1][:k]
    idcg = float(np.sum(ideal * discounts))

    return dcg / idcg if idcg > 0 else 0.0

def compute_recall_at_k(ranked_indices: List[int], relevant_idx: int, k: int = 1) -> float:
    """Recall@K for a query with exactly one relevant item.

    Returns 1.0 when ``relevant_idx`` appears among the first ``k`` entries of
    ``ranked_indices``, otherwise 0.0.
    """
    window = ranked_indices[:k]
    return float(relevant_idx in window)

def compute_mrr(ranked_indices: List[int], relevant_idx: int) -> float:
    """Reciprocal rank of ``relevant_idx`` within ``ranked_indices``.

    The rank is 1-based. Returns 0.0 when the item is absent; averaging this
    value over queries gives MRR.
    """
    try:
        position = ranked_indices.index(relevant_idx)
    except ValueError:
        return 0.0
    else:
        return 1.0 / (position + 1)

class ArxivQADatasetWithDocs:
    """ArxivQA samples grouped into fixed-size pseudo-documents.

    No document identifier is read from the data. Instead, consecutive rows
    are bucketed: row ``idx`` becomes page ``idx % pages_per_doc`` of document
    ``idx // pages_per_doc``, and query ``idx`` is relevant to that page. The
    grouping is synthetic; it does not follow the source papers.
    """

    def __init__(self, split: str = "test", pages_per_doc: int = 5):
        """
        Args:
            split: Dataset split to load.
            pages_per_doc: Number of consecutive rows per pseudo-document.

        Raises:
            ValueError: if ``pages_per_doc`` is smaller than 1.
        """
        if pages_per_doc < 1:
            raise ValueError(f"pages_per_doc must be >= 1, got {pages_per_doc}")
        self.pages_per_doc = pages_per_doc

        print(f"Loading ArxivQA {split} split from HuggingFace...")
        self.dataset = load_dataset("vidore/arxivqa_test_subsampled", split=split)
        print(f"Loaded {len(self.dataset)} samples")

        self._organize_by_documents()

    def _organize_by_documents(self):
        """Group pages by their (synthetic) parent document."""
        print("Organizing pages by documents...")

        # doc_id -> [(global_page_idx, image), ...] in ascending global index order
        self.doc_to_pages = defaultdict(list)
        # global query/page idx -> (doc_id, page_in_doc)
        self.query_to_doc_page = {}
        # Not populated by this class; kept so existing attribute access still works.
        self.documents = []

        for idx, item in enumerate(self.dataset):
            doc_id, page_in_doc = divmod(idx, self.pages_per_doc)
            self.doc_to_pages[doc_id].append((idx, item['image']))
            self.query_to_doc_page[idx] = (doc_id, page_in_doc)

        self.document_ids = sorted(self.doc_to_pages.keys())

        print(f"Organized into {len(self.document_ids)} documents")
        if self.document_ids:
            print(f"Average pages per document: {len(self.dataset) / len(self.document_ids):.1f}")

    def get_queries(self) -> List[str]:
        """Get all queries in row order.

        Column access reads only the ``query`` column, so page images are not
        decoded.
        """
        return list(self.dataset['query'])

    def get_document_representations(self) -> Tuple[List[int], List[List[Image.Image]]]:
        """Get document IDs and, in matching order, each document's page images."""
        doc_ids = list(self.document_ids)
        doc_pages = [[img for _, img in self.doc_to_pages[doc_id]] for doc_id in doc_ids]
        return doc_ids, doc_pages

    def get_all_pages(self) -> List[Image.Image]:
        """Get all individual page images in global index order.

        The images cached during organisation are reused rather than decoded
        again. Document ids are sorted and each document's pages are stored in
        index order, so concatenating them reproduces global order. The returned
        objects are the same images handed out by
        ``get_document_representations``.
        """
        return [img for doc_id in self.document_ids for _, img in self.doc_to_pages[doc_id]]

    def get_ground_truth(self, query_idx: int) -> Tuple[int, int]:
        """Get ground truth (doc_id, page_idx_in_doc) for a query."""
        return self.query_to_doc_page[query_idx]

    def get_page_global_idx(self, doc_id: int, page_in_doc: int) -> int:
        """Convert (doc_id, page_in_doc) to the global page/row index."""
        return self.doc_to_pages[doc_id][page_in_doc][0]

class TwoStageColPaliRAG:
    """Two-stage retrieval system using ColPali late interaction.

    Stage 1 ranks documents against one pooled embedding per document: the
    element-wise mean of the document's page embeddings. Stage 2 re-scores
    the individual pages of the top documents. The pages are then merged into
    one ranked list of ``(doc_id, page_in_doc)`` pairs.
    """

    def __init__(self, model_name: str = "vidore/colpali-v1.2"):
        print(f"Loading ColPali model: {model_name}")
        self.processor = ColPaliProcessor.from_pretrained(model_name)
        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.device = self.model.device
        self.model.eval()
        print(f"Model loaded on device: {self.device}")

        # Set by index_documents(): pooled per-document embeddings, the matching
        # document ids, and doc_id -> list of that document's page embeddings.
        self.doc_embeddings = None
        self.doc_ids = None
        self.doc_to_pages_map = None
        # Set only by index_all_pages(); the retrieval methods do not read it.
        self.page_embeddings = None

    def encode_queries(self, queries: List[str], batch_size: int = 8) -> List[torch.Tensor]:
        """Encode text queries into one multi-vector tensor per query."""
        all_embeddings = []

        print(f"Encoding {len(queries)} queries...")
        for i in tqdm(range(0, len(queries), batch_size)):
            batch = queries[i:i+batch_size]
            inputs = self.processor.process_queries(batch).to(self.device)

            with torch.no_grad():
                embeddings = self.model(**inputs)

            all_embeddings.extend(list(embeddings))

        return all_embeddings

    def encode_images(self, images: List[Image.Image], batch_size: int = 4,
                      show_progress: bool = True) -> List[torch.Tensor]:
        """Encode images into one multi-vector tensor per image.

        Args:
            show_progress: Set False to hide the progress bar, e.g. when
                called inside another progress loop.
        """
        all_embeddings = []

        for i in tqdm(range(0, len(images), batch_size), disable=not show_progress):
            batch = images[i:i+batch_size]
            inputs = self.processor.process_images(batch).to(self.device)

            with torch.no_grad():
                embeddings = self.model(**inputs)

            all_embeddings.extend(list(embeddings))

        return all_embeddings

    def late_interaction_score(self, query_emb: torch.Tensor, doc_emb: torch.Tensor) -> float:
        """Compute the MaxSim late-interaction score of one query/document pair.

        Assumes 2-D ``(num_vectors, dim)`` inputs. Both operands are upcast to
        float32. bfloat16 keeps only about 3 significant digits, so summed
        scores would otherwise collide, and pooled document embeddings (float32)
        could not be multiplied with bfloat16 queries.
        """
        query = query_emb.float()
        doc = doc_emb.to(device=query.device, dtype=torch.float32)
        similarity_matrix = query @ doc.T
        return similarity_matrix.max(dim=1).values.sum().item()

    def index_documents(self, doc_ids: List[int], doc_pages: List[List[Image.Image]],
                        batch_size: int = 4):
        """
        Stage 1 Indexing: create one pooled embedding per document.

        Each document embedding is the element-wise float32 mean of its page
        embeddings. ``torch.stack`` therefore requires all pages of a document
        to share one embedding shape. Per-page embeddings are kept for Stage 2.

        Raises:
            ValueError: if ``doc_ids`` and ``doc_pages`` differ in length, or a
                document has no pages.
        """
        print("\n" + "="*80)
        print("STAGE 1: Indexing Documents")
        print("="*80)

        if len(doc_ids) != len(doc_pages):
            raise ValueError(
                f"doc_ids ({len(doc_ids)}) and doc_pages ({len(doc_pages)}) must have the same length"
            )

        self.doc_ids = doc_ids
        self.doc_embeddings = []
        self.doc_to_pages_map = {}

        for doc_id, pages in tqdm(zip(doc_ids, doc_pages), total=len(doc_ids),
                                  desc="Encoding documents"):
            if len(pages) == 0:
                raise ValueError(f"Document {doc_id} has no pages to index")

            # Inner bar disabled so it does not interleave with the per-document bar.
            page_embeddings = self.encode_images(pages, batch_size=batch_size, show_progress=False)
            doc_embedding = torch.stack(page_embeddings).float().mean(dim=0)

            self.doc_embeddings.append(doc_embedding)
            self.doc_to_pages_map[doc_id] = page_embeddings

        print(f"Indexed {len(self.doc_embeddings)} documents")

    def index_all_pages(self, all_pages: List[Image.Image], batch_size: int = 4):
        """
        Stage 2 Indexing: index all individual pages into ``self.page_embeddings``.

        Not required by ``retrieve_two_stage``, which uses the per-document page
        embeddings stored by ``index_documents``.
        """
        print("\n" + "="*80)
        print("STAGE 2: Indexing All Pages")
        print("="*80)

        print("Encoding all pages...")
        self.page_embeddings = self.encode_images(all_pages, batch_size=batch_size)
        print(f"Indexed {len(self.page_embeddings)} pages")

    def _require_index(self):
        """Raise a clear error if index_documents() has not been called yet."""
        if self.doc_embeddings is None or self.doc_to_pages_map is None:
            raise RuntimeError("No documents indexed. Call index_documents() first.")

    def _rank(self, query_embedding: torch.Tensor, embeddings: List[torch.Tensor],
              top_k: int) -> List[Tuple[int, float]]:
        """Return ``(position, score)`` for the ``top_k`` best embeddings, best first.

        Ties keep the lower position first.
        """
        scores = np.array([self.late_interaction_score(query_embedding, emb) for emb in embeddings])
        order = np.argsort(-scores, kind="stable")[:top_k]
        return [(int(pos), float(scores[pos])) for pos in order]

    def retrieve_document(self, query_embedding: torch.Tensor, top_k: int = 5) -> List[int]:
        """
        Stage 1: Retrieve the ids of the top-k documents for the query.
        """
        self._require_index()
        ranked = self._rank(query_embedding, self.doc_embeddings, top_k)
        return [self.doc_ids[pos] for pos, _ in ranked]

    def retrieve_page_from_document(self, query_embedding: torch.Tensor,
                                    doc_id: int, top_k: int = 5) -> List[int]:
        """
        Stage 2: Retrieve the top-k page positions (within ``doc_id``) for the query.
        """
        self._require_index()
        ranked = self._rank(query_embedding, self.doc_to_pages_map[doc_id], top_k)
        return [pos for pos, _ in ranked]

    def retrieve_two_stage(self, query_embeddings: List[torch.Tensor],
                           top_k_docs: int = 3, top_k_pages: int = 5) -> List[List[Tuple[int, int]]]:
        """
        Two-stage retrieval:
        1. Retrieve the top-k documents.
        2. Within each, rank the pages and keep the top-k.
        All surviving pages are then merged by score.

        Returns:
            For each query, up to ``top_k_pages`` ``(doc_id, page_in_doc)``
            tuples, best first.
        """
        print("\n" + "="*80)
        print("TWO-STAGE RETRIEVAL")
        print("="*80)

        self._require_index()
        all_results = []

        for query_emb in tqdm(query_embeddings, desc="Retrieving"):
            candidates = []
            for doc_id in self.retrieve_document(query_emb, top_k=top_k_docs):
                # Reuse the page scores from ranking instead of recomputing them
                ranked_pages = self._rank(query_emb, self.doc_to_pages_map[doc_id], top_k_pages)
                candidates.extend(((doc_id, page_idx), score) for page_idx, score in ranked_pages)

            # list.sort is stable even with reverse=True: equal scores keep doc-rank order
            candidates.sort(key=lambda item: item[1], reverse=True)
            all_results.append([doc_page for doc_page, _ in candidates[:top_k_pages]])

        return all_results

def evaluate_two_stage(results: List[List[Tuple[int, int]]],
                       ground_truth: List[Tuple[int, int]],
                       dataset: ArxivQADatasetWithDocs) -> Dict[str, float]:
    """Score two-stage retrieval output against per-query ground truth.

    Args:
        results: For each query, ranked ``(doc_id, page_in_doc)`` pairs.
        ground_truth: For each query, the relevant ``(doc_id, page_in_doc)``.
        dataset: Used to map pairs to global page indices.

    Returns:
        - ``NDCG@5``, ``Recall@1``, ``Recall@5`` (percent) and ``MRR`` (0-1),
          computed over the merged page list.
        - ``Doc_Recall`` (percent): queries whose ground-truth document appears
          among the documents of the returned pages. This is not the full
          Stage-1 document list.
        - ``Page_Given_Doc_Recall`` (percent): among those queries, the share
          whose exact ground-truth page was also returned. It is 0.0 when no
          query qualifies.
    """
    ndcg_at_5, recall_at_1, recall_at_5, reciprocal_ranks = [], [], [], []
    doc_hits, page_hits_given_doc = [], []

    for predicted, (gt_doc, gt_page) in zip(results, ground_truth):
        predicted_global = [dataset.get_page_global_idx(d, p) for d, p in predicted]
        target_global = dataset.get_page_global_idx(gt_doc, gt_page)

        relevance = np.array([1 if g == target_global else 0 for g in predicted_global])
        ndcg_at_5.append(compute_ndcg_at_k(relevance, k=5))
        recall_at_1.append(compute_recall_at_k(predicted_global, target_global, k=1))
        recall_at_5.append(compute_recall_at_k(predicted_global, target_global, k=5))
        reciprocal_ranks.append(compute_mrr(predicted_global, target_global))

        returned_docs = [d for d, _ in predicted]
        if gt_doc in returned_docs:
            doc_hits.append(1.0)
            pages_in_gt_doc = [p for d, p in predicted if d == gt_doc]
            page_hits_given_doc.append(1.0 if gt_page in pages_in_gt_doc else 0.0)
        else:
            doc_hits.append(0.0)

    if page_hits_given_doc:
        page_given_doc = np.mean(page_hits_given_doc) * 100
    else:
        page_given_doc = 0.0

    return {
        "NDCG@5": np.mean(ndcg_at_5) * 100,
        "Recall@1": np.mean(recall_at_1) * 100,
        "Recall@5": np.mean(recall_at_5) * 100,
        "MRR": np.mean(reciprocal_ranks),
        "Doc_Recall": np.mean(doc_hits) * 100,
        "Page_Given_Doc_Recall": page_given_doc,
    }

def main():
    """Run the two-stage ColPali pipeline on ArxivQA and print the metrics."""
    banner = "=" * 80

    print(banner)
    print("TWO-STAGE RAG SYSTEM - ArxivQA Evaluation")
    print(banner)

    dataset = ArxivQADatasetWithDocs(split="test")
    queries = dataset.get_queries()
    doc_ids, doc_pages = dataset.get_document_representations()
    all_pages = dataset.get_all_pages()

    print("\nDataset Statistics:")
    print(f"  Total queries: {len(queries)}")
    print(f"  Total documents: {len(doc_ids)}")
    print(f"  Total pages: {len(all_pages)}")

    rag = TwoStageColPaliRAG("vidore/colpali-v1.2")
    rag.index_documents(doc_ids, doc_pages, batch_size=4)

    print("\nEncoding queries...")
    query_embeddings = rag.encode_queries(queries, batch_size=8)
    results = rag.retrieve_two_stage(query_embeddings, top_k_docs=3, top_k_pages=5)
    ground_truth = list(map(dataset.get_ground_truth, range(len(queries))))

    print("\n" + banner)
    print("EVALUATION RESULTS")
    print(banner)

    metrics = evaluate_two_stage(results, ground_truth, dataset)

    print("\nStandard Retrieval Metrics:")
    print(f"  NDCG@5:   {metrics['NDCG@5']:.2f}")
    print(f"  Recall@1: {metrics['Recall@1']:.2f}")
    print(f"  Recall@5: {metrics['Recall@5']:.2f}")
    print(f"  MRR:      {metrics['MRR']:.4f}")

    print("\nTwo-Stage Specific Metrics:")
    print(f"  Document Recall:           {metrics['Doc_Recall']:.2f}%")
    print(f"  Page Recall (given doc):   {metrics['Page_Given_Doc_Recall']:.2f}%")

    print("\n" + banner)
    print("Comparison to Single-Stage ColPali (from paper):")
    print("  Expected NDCG@5:   79.1")
    print("  Expected Recall@1: 72.4")
    print(banner)


if __name__ == "__main__":
    main()

In [None]:
import json
import requests
from pathlib import Path
from typing import List, Dict
from tqdm import tqdm
import time
import os


class OpenRAGBenchLoader:
    """
    Load and process vectara/open_ragbench dataset for evaluation.

    Metadata files (queries, answers, qrels, PDF URLs) are downloaded from the
    HuggingFace Hub into ``dataset_dir``. The arXiv PDFs they reference are
    downloaded into a separate directory.
    """

    # Seconds before any single HTTP request is abandoned.
    REQUEST_TIMEOUT = 20
    # Pause (seconds) after each PDF download attempt, to throttle requests.
    PDF_DOWNLOAD_DELAY = 1

    def __init__(self, dataset_dir="./open_ragbench"):
        """
        Args:
            dataset_dir: Where to store downloaded dataset
        """
        self.dataset_dir = Path(dataset_dir)
        self.dataset_dir.mkdir(parents=True, exist_ok=True)

        # `resolve/` returns actual file content even for Git-LFS-tracked files;
        # `raw/` would return the LFS pointer text for those.
        self.base_url = "https://huggingface.co/datasets/vectara/open_ragbench/resolve/main"

        self.queries = None
        self.answers = None
        self.corpus = {}
        self.qrels = None
        self.pdf_urls = None

    @staticmethod
    def _read_json(path):
        """Parse a JSON file as UTF-8, independent of the platform default encoding."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _write_atomically(path, data):
        """Write ``data`` bytes to ``path`` via a temporary sibling file.

        An interrupted write therefore never leaves a truncated file that a
        later ``exists()`` check would accept as complete.
        """
        tmp_path = path.with_name(path.name + ".part")
        tmp_path.write_bytes(data)
        tmp_path.replace(path)

    def download_dataset(self, max_queries=None):
        """
        Download the Open RAGBench dataset (metadata + PDFs).

        Metadata files that already exist are kept. A failed metadata download
        is reported and skipped.

        Args:
            max_queries: number of queries to process (downloads PDFs for these queries)
        """
        files_to_download = {
            'queries.json': f"{self.base_url}/pdf/arxiv/queries.json",
            'answers.json': f"{self.base_url}/pdf/arxiv/answers.json",
            'pdf_urls.json': f"{self.base_url}/pdf/arxiv/pdf_urls.json",
            'qrels.json': f"{self.base_url}/pdf/arxiv/qrels.json",
        }

        print("Downloading Open RAGBench dataset...")

        for filename, url in files_to_download.items():
            filepath = self.dataset_dir / filename

            if filepath.exists():
                print(f"  ✓ {filename} already exists")
                continue

            print(f"  Downloading {filename}...")
            try:
                response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            except requests.RequestException as e:
                print(f"  ✗ Failed to download {filename}: {e}")
                continue

            if response.status_code == 200:
                # Store raw bytes: no charset guessing, no platform text encoding.
                self._write_atomically(filepath, response.content)
            else:
                print(f"  ✗ Failed to download {filename}: {response.status_code}")

        print("\n  Downloading PDFs...")
        self._download_pdfs_for_queries(max_queries=max_queries)

        print("\n✓ Dataset download complete!")

    def _download_pdfs_for_queries(self, max_queries=None, output_dir="./arxiv_pdfs"):
        """
        Download the PDFs referenced by the first N queries in qrels.json.

        Args:
            max_queries: Number of queries to process (None = all)
            output_dir: Where to save PDFs
        """
        pdf_urls_path = self.dataset_dir / "pdf_urls.json"
        qrels_path = self.dataset_dir / "qrels.json"

        if not pdf_urls_path.exists() or not qrels_path.exists():
            print("  ✗ Missing pdf_urls.json or qrels.json")
            return

        pdf_urls = self._read_json(pdf_urls_path)
        qrels = self._read_json(qrels_path)

        query_ids = list(qrels.keys())
        if max_queries is not None:
            query_ids = query_ids[:max_queries]

        arxiv_ids = set()
        for query_id in query_ids:
            docs = qrels[query_id]
            if isinstance(docs, dict) and 'doc_id' in docs:
                arxiv_ids.add(docs['doc_id'])
            elif isinstance(docs, dict):
                for doc_id in docs.keys():
                    if '_page_' in doc_id:
                        arxiv_ids.add(doc_id.split('_page_')[0])

        # Sorted: set iteration order of strings varies between interpreter runs.
        arxiv_ids = sorted(arxiv_ids)

        print(f"  Downloading {len(arxiv_ids)} PDFs for {len(query_ids)} queries...")
        self._fetch_pdfs(arxiv_ids, pdf_urls, output_dir, desc="  Downloading PDFs", indent="    ")
        print(f"  ✓ PDFs stored in {output_dir}/")

    def _fetch_pdfs(self, arxiv_ids, pdf_urls, output_dir, desc, indent):
        """Download ``<arxiv_id>.pdf`` for each id into ``output_dir``.

        Existing files are skipped. Missing URLs, HTTP errors and network or
        file errors are reported with ``indent`` and skipped, so one bad entry
        does not abort the batch.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for arxiv_id in tqdm(arxiv_ids, desc=desc):
            pdf_path = output_path / f"{arxiv_id}.pdf"

            if pdf_path.exists():
                continue

            if arxiv_id not in pdf_urls:
                print(f"{indent}✗ No URL found for {arxiv_id}")
                continue

            try:
                response = requests.get(pdf_urls[arxiv_id], timeout=self.REQUEST_TIMEOUT)
                if response.status_code == 200:
                    self._write_atomically(pdf_path, response.content)
                else:
                    print(f"{indent}✗ Failed for {arxiv_id}: HTTP {response.status_code}")
            except (requests.RequestException, OSError) as e:
                print(f"{indent}✗ Error downloading {arxiv_id}: {e}")

            time.sleep(self.PDF_DOWNLOAD_DELAY)

    def load_dataset(self):
        """Load downloaded dataset files into memory; missing files are skipped."""

        print("Loading Open RAGBench dataset...")

        queries_path = self.dataset_dir / "queries.json"
        if queries_path.exists():
            self.queries = self._read_json(queries_path)
            print(f"  ✓ Loaded {len(self.queries)} queries")

        answers_path = self.dataset_dir / "answers.json"
        if answers_path.exists():
            self.answers = self._read_json(answers_path)
            print(f"  ✓ Loaded {len(self.answers)} answers")

        qrels_path = self.dataset_dir / "qrels.json"
        if qrels_path.exists():
            self.qrels = self._read_json(qrels_path)
            print(f"  ✓ Loaded relevance judgments for {len(self.qrels)} queries")

        # Nothing in this class downloads corpus/; it is loaded only if present.
        corpus_dir = self.dataset_dir / "corpus"
        if corpus_dir.exists():
            # Sorted so that, if files share keys, the surviving value is deterministic.
            for corpus_file in sorted(corpus_dir.glob("*.json")):
                self.corpus.update(self._read_json(corpus_file))
            print(f"  ✓ Loaded {len(self.corpus)} corpus documents")

        print("✓ Dataset loaded!\n")

    def get_ground_truth_pages(self, query_id):
        """Extract relevant ``(arxiv_id, location)`` pairs for ``query_id`` from qrels.

        Two qrels entry shapes are handled:
          * ``{"doc_id": ..., "section_id": ...}`` -> ``[(doc_id, section_id)]``.
            NOTE(review): this is a section id, not a page number.
          * ``{"<arxiv_id>_page_<n>": ..., ...}`` -> ``[(arxiv_id, n), ...]``.

        Returns ``[]`` if qrels are not loaded or the query is unknown.
        """
        if not self.qrels or query_id not in self.qrels:
            return []

        ground_truth = []
        docs = self.qrels[query_id]

        if isinstance(docs, dict) and 'doc_id' in docs:
            ground_truth.append((docs['doc_id'], docs['section_id']))
        elif isinstance(docs, dict):
            for doc_id in docs.keys():
                if '_page_' in doc_id:
                    arxiv_id, page = doc_id.split('_page_')
                    ground_truth.append((arxiv_id, int(page)))

        return ground_truth

    def convert_to_evaluation_format(self, max_samples=None):
        """Convert the loaded dataset into the list-of-dicts format used for evaluation.

        Args:
            max_samples: Keep only the first N queries (None = all).

        Raises:
            ValueError: if queries or answers have not been loaded.
        """
        if not self.queries or not self.answers:
            raise ValueError("Dataset not loaded. Call load_dataset() first.")

        evaluation_dataset = []
        query_ids = list(self.queries.keys())

        if max_samples is not None:
            query_ids = query_ids[:max_samples]

        for query_id in query_ids:
            query_data = self.queries[query_id]
            ground_truth = self.get_ground_truth_pages(query_id)

            # NOTE(review): for doc_id/section_id qrels the "page" number below
            # is a section id; confirm downstream consumers expect that.
            evidence_pages = [
                f"{arxiv_id}_page_{page_num:03d}"
                for arxiv_id, page_num in ground_truth
            ]

            evaluation_dataset.append({
                "query_id": query_id,
                "question": query_data["query"],
                "answer": self.answers.get(query_id, ""),
                "evidence_pages": evidence_pages,
                "evidence_type": query_data.get("source", "unknown"),
                "query_type": query_data.get("type", "unknown"),
                "page_count": "multi" if len(evidence_pages) > 1 else "single"
            })

        print(f"✓ Converted {len(evaluation_dataset)} queries to evaluation format")
        return evaluation_dataset

    def download_pdfs(self, output_dir="./arxiv_pdfs", max_pdfs=None):
        """
        Download PDF files listed in pdf_urls.json (standalone method).

        Args:
            output_dir: Where to save PDFs
            max_pdfs: Maximum number of PDFs to download (None = all)
        """
        pdf_urls_path = self.dataset_dir / "pdf_urls.json"

        if not pdf_urls_path.exists():
            print("✗ pdf_urls.json not found. Run download_dataset() first.")
            return

        pdf_urls = self._read_json(pdf_urls_path)

        arxiv_ids = list(pdf_urls.keys())
        if max_pdfs is not None:
            arxiv_ids = arxiv_ids[:max_pdfs]

        print(f"Downloading {len(arxiv_ids)} PDFs...")
        self._fetch_pdfs(arxiv_ids, pdf_urls, output_dir, desc="Downloading PDFs", indent="  ")
        print(f"✓ PDFs stored in {output_dir}/")

In [None]:
"""
Two-Stage RAG System for Open RAGBench Dataset
Stage 1: Retrieve relevant PDF documents
Stage 2: Retrieve relevant pages from documents
"""

import torch
import numpy as np
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image
from typing import List, Dict, Tuple
from tqdm import tqdm
from collections import defaultdict
import warnings
import json
from pathlib import Path
import fitz
import os

warnings.filterwarnings('ignore')  # suppress all Python warnings for the rest of the kernel session

# Presumably silences MuPDF's own error messages (PyMuPDF TOOLS API); confirm against the PyMuPDF docs
fitz.TOOLS.mupdf_display_errors(False)

import sys
# Make modules in the current working directory importable (presumably local helper modules)
sys.path.append('.')


def compute_ndcg_at_k(relevance_scores: np.ndarray, k: int = 5) -> float:
    """Compute NDCG@K for a single ranked list.

    Uses the standard discount ``1 / log2(rank + 1)`` with 1-based ranks (the
    form used by e.g. scikit-learn's ``ndcg_score`` and common IR evaluation
    tooling). The older ``rel_1 + sum_{i>=2} rel_i / log2(i)`` variant gives
    ranks 1 and 2 the same weight and so is not comparable to benchmark
    numbers computed with the standard discount.

    Args:
        relevance_scores: Relevance of each retrieved item, in ranked order.
        k: Cut-off rank. Values <= 0 yield 0.0.

    Returns:
        DCG@K divided by the ideal DCG@K, in [0, 1]. Returns 0.0 for an empty
        list or a list with no relevant items.

    Note:
        The ideal ranking is derived from ``relevance_scores`` itself. Relevant
        items that were not retrieved are therefore not counted.
    """
    relevance = np.asarray(relevance_scores, dtype=float)
    if relevance.size == 0 or k <= 0:
        return 0.0

    k = min(k, relevance.size)
    # 1-based rank r is discounted by log2(r + 1): ranks 1..k -> log2(2..k+1)
    discounts = 1.0 / np.log2(np.arange(2, k + 2))

    dcg = float(np.sum(relevance[:k] * discounts))
    ideal = np.sort(relevance)[::-1][:k]
    idcg = float(np.sum(ideal * discounts))

    return dcg / idcg if idcg > 0 else 0.0

def compute_recall_at_k(ranked_indices: List[int], relevant_idx: int, k: int = 1) -> float:
    """Binary Recall@k for a single relevant item.

    Returns 1.0 if ``relevant_idx`` appears among the first ``k`` entries of
    ``ranked_indices``, otherwise 0.0.
    """
    top_k = ranked_indices[:k]
    return float(relevant_idx in top_k)

def compute_mrr(ranked_indices: List[int], relevant_idx: int) -> float:
    """Reciprocal rank of ``relevant_idx`` in ``ranked_indices``.

    Works with any iterable of ids (lists, tuples, 1-D numpy arrays such as
    ``np.argsort`` output); ``list.index`` would fail on the latter.

    Returns:
        ``1 / rank`` (1-based) of the first occurrence, or 0.0 if it is absent.
    """
    for rank, idx in enumerate(ranked_indices, start=1):
        if idx == relevant_idx:
            return 1.0 / rank
    return 0.0


class PDFProcessor:
    """Renders locally stored arXiv PDFs (``<pdf_dir>/<arxiv_id>.pdf``) to RGB page images.

    Successfully rendered documents are cached in ``pdf_cache`` keyed by arxiv id.
    """

    # US Letter page size in PDF points (1 pt = 1/72 inch). Only used to size a blank
    # placeholder when a page fails before any page of the document has rendered.
    _FALLBACK_PAGE_PT = (612, 792)

    def __init__(self, pdf_dir: str = "./arxiv_pdfs", dpi: int = 144):
        self.pdf_dir = Path(pdf_dir)
        self.dpi = dpi
        self.pdf_cache = {}

    def _blank_page(self, images: List[Image.Image]) -> Image.Image:
        """White stand-in for an unrenderable page, sized like the previous page if any."""
        if images:
            size = images[-1].size
        else:
            scale = self.dpi / 72
            size = (max(1, round(self._FALLBACK_PAGE_PT[0] * scale)),
                    max(1, round(self._FALLBACK_PAGE_PT[1] * scale)))
        return Image.new("RGB", size, "white")

    def pdf_to_images(self, arxiv_id: str) -> List[Image.Image]:
        """Render every page of ``arxiv_id`` at ``self.dpi``.

        The returned list always has one entry per PDF page, so list index == page
        number (0-based). A page that fails to render is replaced by a blank image
        instead of being dropped; dropping it would shift every later page's index
        and break comparison with ground-truth page numbers.

        Returns:
            The page images, or ``[]`` if the file is missing, cannot be opened, or no
            page rendered at all. Only non-empty results are cached.
        """
        if arxiv_id in self.pdf_cache:
            return self.pdf_cache[arxiv_id]

        pdf_path = self.pdf_dir / f"{arxiv_id}.pdf"
        if not pdf_path.exists():
            return []

        images = []
        rendered = 0
        try:
            # Context manager closes the document even if iteration fails part-way.
            with fitz.open(pdf_path) as doc:
                for page_num in range(len(doc)):
                    try:
                        pix = doc[page_num].get_pixmap(dpi=self.dpi)
                        images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
                        rendered += 1
                    except Exception as page_error:
                        print(f"  Warning: Could not render page {page_num} of {arxiv_id} "
                              f"({page_error}); using a blank placeholder")
                        images.append(self._blank_page(images))
        except Exception as e:
            print(f"  Error processing PDF {arxiv_id}: {str(e)[:100]}")
            return []

        if rendered == 0:
            return []

        self.pdf_cache[arxiv_id] = images
        return images

class OpenRAGBenchDataset:
    """Open RAGBench queries paired with rendered page images of their ground-truth PDFs.

    After construction:
      * ``queries``, ``query_ids`` and ``ground_truth`` are parallel lists. They contain
        only queries whose ground-truth PDF exists in ``pdf_processor.pdf_dir`` and
        rendered successfully. ``ground_truth[i]`` is the first entry returned by
        ``loader.get_ground_truth_pages`` for that query (an ``(arxiv_id, page_num)`` pair).
      * ``arxiv_ids`` (sorted) and ``doc_to_pages`` cover exactly the same documents, so
        ``doc_to_pages[arxiv_id]`` is valid for every id in ``arxiv_ids``.

    NOTE(review): the corpus is built only from documents that are ground truth for
    the selected queries, and only the first ground-truth page of each query is kept.
    Both make this easier than a full-corpus evaluation.
    """

    def __init__(self, loader: OpenRAGBenchLoader, pdf_processor: PDFProcessor, max_queries: int = None):
        self.loader = loader
        self.pdf_processor = pdf_processor
        self.max_queries = max_queries

        self._organize_data()

    def _organize_data(self):
        """Collect queries, restrict them to local PDFs, render those PDFs, then report."""
        print("Organizing Open RAGBench dataset...")
        self._collect_queries()
        self._restrict_to_available_pdfs()
        self._render_documents()
        self._report()

    def _collect_queries(self):
        """Fill the parallel query lists from the loader (first ground-truth page per query)."""
        self.queries = []
        self.query_ids = []
        self.ground_truth = []

        all_query_ids = list(self.loader.queries.keys())
        if self.max_queries is not None:
            all_query_ids = all_query_ids[:self.max_queries]
            print(f"Limiting to first {self.max_queries} queries")

        for query_id in all_query_ids:
            self.query_ids.append(query_id)
            self.queries.append(self.loader.queries[query_id]['query'])
            gt_pages = self.loader.get_ground_truth_pages(query_id)
            self.ground_truth.append(gt_pages[0] if gt_pages else (None, None))

        print(f"Processing {len(self.queries)} queries")

    def _keep_queries_for_docs(self, allowed_docs: set):
        """Drop every query whose ground-truth document id is not in ``allowed_docs``."""
        kept = [
            (qid, query, gt)
            for qid, query, gt in zip(self.query_ids, self.queries, self.ground_truth)
            if gt[0] in allowed_docs
        ]
        self.query_ids = [qid for qid, _, _ in kept]
        self.queries = [query for _, query, _ in kept]
        self.ground_truth = [gt for _, _, gt in kept]

    def _restrict_to_available_pdfs(self):
        """Keep only documents (and their queries) whose PDF is present on disk."""
        needed_docs = {gt[0] for gt in self.ground_truth if gt[0] is not None}
        pdf_dir = Path(self.pdf_processor.pdf_dir)
        available_pdfs = {pdf_file.stem for pdf_file in pdf_dir.glob("*.pdf")}

        print(f"Found {len(available_pdfs)} available PDFs")
        print(f"Need {len(needed_docs)} documents for these queries")

        usable_docs = needed_docs & available_pdfs
        self.arxiv_ids = sorted(usable_docs)
        self._keep_queries_for_docs(usable_docs)

        print(f"âœ“ Filtered to {len(self.queries)} queries with available documents")

    def _render_documents(self):
        """Render every usable PDF; drop documents (and their queries) that fail."""
        print(f"Loading pages from {len(self.arxiv_ids)} documents...")
        self.doc_to_pages = {}
        failed_pdfs = []

        for arxiv_id in tqdm(self.arxiv_ids, desc="Converting PDFs to images"):
            images = self.pdf_processor.pdf_to_images(arxiv_id)
            if len(images) > 0:
                self.doc_to_pages[arxiv_id] = images
            else:
                failed_pdfs.append(arxiv_id)

        if failed_pdfs:
            print(f"âš  Warning: {len(failed_pdfs)} PDFs failed to load")
            self._keep_queries_for_docs(set(self.doc_to_pages))
            # Keep arxiv_ids consistent with doc_to_pages so callers can safely look up
            # doc_to_pages[arxiv_id] for every listed id.
            self.arxiv_ids = [arxiv_id for arxiv_id in self.arxiv_ids if arxiv_id in self.doc_to_pages]

    def _report(self):
        """Print final query/document/page counts."""
        print(f"âœ“ Successfully loaded {len(self.queries)} queries")
        print(f"âœ“ Successfully loaded {len(self.doc_to_pages)} documents")
        total_pages = sum(len(pages) for pages in self.doc_to_pages.values())
        print(f"âœ“ Total pages: {total_pages}")
        avg_pages = total_pages / len(self.doc_to_pages) if self.doc_to_pages else 0
        print(f"âœ“ Average pages per document: {avg_pages:.1f}")

class TwoStageColPaliRAG:
    """Two-stage visual retriever built on ColPali multi-vector embeddings.

    Stage 1 scores one aggregated multi-vector embedding per document. Stage 2 scores
    the individual pages of the top documents. All scores use late interaction
    (``late_interaction_score``).
    """

    # Supported values for index_documents(aggregation=...).
    _AGGREGATIONS = ("mean", "max", "weighted")

    def __init__(self, model_name: str = "vidore/colpali-v1.2"):
        print(f"Loading ColPali model: {model_name}")
        self.processor = ColPaliProcessor.from_pretrained(model_name)
        self.model = ColPali.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.device = self.model.device
        self.model.eval()
        print(f"Model loaded on device: {self.device}")

        # arxiv_id -> aggregated document embedding (stage 1)
        self.doc_embeddings = {}
        # arxiv_id -> list of per-page embeddings, list index == page number (stage 2)
        self.page_embeddings = {}

    def encode_queries(self, queries: List[str], batch_size: int = 8) -> List[torch.Tensor]:
        """Embed query strings in batches; returns one tensor per query."""
        all_embeddings = []

        print(f"Encoding {len(queries)} queries...")
        for i in tqdm(range(0, len(queries), batch_size)):
            batch = queries[i:i+batch_size]
            inputs = self.processor.process_queries(batch).to(self.device)

            with torch.no_grad():
                embeddings = self.model(**inputs)

            all_embeddings.extend(list(embeddings))

        return all_embeddings

    def encode_images(self, images: List[Image.Image], batch_size: int = 4) -> List[torch.Tensor]:
        """Embed page images in batches; returns one tensor per image, in input order."""
        all_embeddings = []

        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size]
            inputs = self.processor.process_images(batch).to(self.device)

            with torch.no_grad():
                embeddings = self.model(**inputs)

            all_embeddings.extend(list(embeddings))

        return all_embeddings

    def late_interaction_score(self, query_emb: torch.Tensor, doc_emb: torch.Tensor) -> float:
        """MaxSim score: for each query vector take its best dot product over the doc vectors, then sum.

        Assumes 2-D ``(tokens, dim)`` tensors of the same dtype (``torch.matmul`` does not
        promote dtypes).
        """
        similarity_matrix = torch.matmul(query_emb, doc_emb.T)
        return similarity_matrix.max(dim=1).values.sum().item()

    @staticmethod
    def _aggregate_pages(stacked: torch.Tensor, aggregation: str) -> torch.Tensor:
        """Collapse stacked page embeddings ``(num_pages, ...)`` into one document embedding."""
        if aggregation == "max":
            return stacked.max(dim=0).values
        if aggregation == "weighted":
            # Softmax over a linearly decreasing ramp: earlier pages get slightly more weight.
            weights = torch.softmax(torch.linspace(1.0, 0.5, stacked.shape[0]), dim=0)
            weights = weights.view(-1, *([1] * (stacked.dim() - 1))).to(stacked.device)
            # Accumulate in float32, then cast back. Leaving the result float32 would make
            # torch.matmul against bfloat16 query embeddings fail with a dtype mismatch.
            return (stacked.float() * weights).sum(dim=0).to(stacked.dtype)
        return stacked.mean(dim=0)  # "mean" and any unrecognized value

    def index_documents(self, arxiv_ids: List[str], doc_pages: Dict[str, List[Image.Image]], 
                       batch_size: int = 4, aggregation: str = "max"):
        """Embed every page and build one aggregated embedding per document.

        Args:
            arxiv_ids: documents to index. Ids that are missing from ``doc_pages`` or have
                no pages are skipped rather than raising ``KeyError``.
            doc_pages: arxiv_id -> page images (list index == page number).
            batch_size: images per forward pass.
            aggregation: "mean", "max" or "weighted". Other values fall back to "mean"
                with a warning.
        """
        print("\n" + "="*80)
        print("STAGE 1: Indexing Documents")
        print(f"Aggregation strategy: {aggregation}")
        print("="*80)

        if aggregation not in self._AGGREGATIONS:
            print(f"Warning: unknown aggregation '{aggregation}', falling back to 'mean'")

        for arxiv_id in tqdm(arxiv_ids, desc="Encoding documents"):
            pages = doc_pages.get(arxiv_id, [])
            if len(pages) == 0:
                continue

            page_embeddings = self.encode_images(pages, batch_size=batch_size)
            self.page_embeddings[arxiv_id] = page_embeddings
            # Stacking assumes every page embedding of a document has the same shape.
            self.doc_embeddings[arxiv_id] = self._aggregate_pages(torch.stack(page_embeddings), aggregation)

        print(f"âœ“ Indexed {len(self.doc_embeddings)} documents")

    def retrieve_document(self, query_embedding: torch.Tensor, top_k: int = 5) -> List[str]:
        """Return up to ``top_k`` arxiv ids ranked by late-interaction score against doc embeddings."""
        scores = {
            arxiv_id: self.late_interaction_score(query_embedding, doc_emb)
            for arxiv_id, doc_emb in self.doc_embeddings.items()
        }
        sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return [arxiv_id for arxiv_id, _ in sorted_docs[:top_k]]

    def _page_scores(self, query_embedding: torch.Tensor, arxiv_id: str) -> np.ndarray:
        """Late-interaction score of every indexed page of ``arxiv_id`` (index == page number)."""
        return np.array([
            self.late_interaction_score(query_embedding, page_emb)
            for page_emb in self.page_embeddings[arxiv_id]
        ])

    def retrieve_page_from_document(self, query_embedding: torch.Tensor, 
                                    arxiv_id: str, top_k: int = 5) -> List[int]:
        """Return up to ``top_k`` page numbers of ``arxiv_id``, best first ([] if not indexed)."""
        if arxiv_id not in self.page_embeddings:
            return []

        scores = self._page_scores(query_embedding, arxiv_id)
        return np.argsort(scores)[::-1][:top_k].tolist()

    def retrieve_two_stage(self, query_embeddings: List[torch.Tensor],
                          top_k_docs: int = 5, top_k_pages: int = 5,
                          rerank: bool = True) -> List[List[Tuple[str, int]]]:
        """Retrieve the top documents per query, then rank their pages.

        Returns, for each query, up to ``top_k_pages`` ``(arxiv_id, page_num)`` pairs, best first.

        NOTE(review): the non-rerank path takes each document's top ``top_k_pages`` pages
        and then the global top ``top_k_pages``. That selects the same pages as
        ``rerank=True`` except for tie-breaking, so the flag barely changes results.
        """
        print("\n" + "="*80)
        print("TWO-STAGE RETRIEVAL")
        print(f"Stage 1: Retrieving top {top_k_docs} documents")
        print(f"Stage 2: Retrieving top {top_k_pages} pages")
        print(f"Re-ranking: {rerank}")
        print("="*80)

        all_results = []

        for query_emb in tqdm(query_embeddings, desc="Retrieving"):
            top_docs = self.retrieve_document(query_emb, top_k=top_k_docs)
            candidates = []

            for arxiv_id in top_docs:
                if arxiv_id not in self.page_embeddings:
                    continue

                # Score each page once and reuse the scores for selection and ranking.
                scores = self._page_scores(query_emb, arxiv_id)
                if rerank:
                    page_nums = range(len(scores))
                else:
                    page_nums = np.argsort(scores)[::-1][:top_k_pages]

                for page_num in page_nums:
                    candidates.append(((arxiv_id, int(page_num)), float(scores[page_num])))

            candidates.sort(key=lambda x: x[1], reverse=True)
            all_results.append([doc_page for doc_page, _ in candidates[:top_k_pages]])

        return all_results


def evaluate_two_stage(results: List[List[Tuple[str, int]]], 
                      ground_truth: List[Tuple[str, int]]) -> Dict[str, float]:
    """Score two-stage retrieval output against one ground-truth page per query.

    Args:
        results: per query, ranked ``(arxiv_id, page_num)`` pairs from the retriever.
        ground_truth: per query, the relevant ``(arxiv_id, page_num)``. Queries whose
            arxiv_id is ``None`` are skipped.

    Returns:
        Every metric is a percentage (0-100) except ``MRR`` (0-1). Metrics with nothing
        to average are reported as 0.0 instead of NaN. ``Stage1_Stage2_Product`` is
        ``Doc_Recall x Page_Given_Doc_Recall`` expressed as a percentage.

    NOTE(review): despite its name, ``Doc_Recall@5`` checks whether the gold document
    appears anywhere in the result list (its full length), not only in the first 5
    results. ``End_to_End_Success`` also covers the whole list. The key names are kept
    because callers read them.
    """
    def pct(values: List[float]) -> float:
        # Mean as a percentage; 0.0 (not NaN) when nothing was scored.
        return float(np.mean(values)) * 100 if values else 0.0

    ndcg_5_scores = []
    recall_1_scores = []
    recall_5_scores = []
    mrr_scores = []

    doc_recall_scores = []
    page_given_doc_recall = []
    end_to_end_success = []

    for result_list, (gt_arxiv_id, gt_page_num) in zip(results, ground_truth):
        if gt_arxiv_id is None:
            continue

        retrieved_ids = [f"{arxiv_id}_page_{page_num}" for arxiv_id, page_num in result_list]
        gt_id = f"{gt_arxiv_id}_page_{gt_page_num}"
        relevance = np.array([1 if rid == gt_id else 0 for rid in retrieved_ids])

        ndcg_5_scores.append(compute_ndcg_at_k(relevance, k=5))
        recall_1_scores.append(1.0 if gt_id in retrieved_ids[:1] else 0.0)
        recall_5_scores.append(1.0 if gt_id in retrieved_ids[:5] else 0.0)
        mrr_scores.append(1.0 / (retrieved_ids.index(gt_id) + 1) if gt_id in retrieved_ids else 0.0)

        doc_correct = any(arxiv_id == gt_arxiv_id for arxiv_id, _ in result_list)
        doc_recall_scores.append(1.0 if doc_correct else 0.0)

        # Page recall is only defined for queries whose document was retrieved.
        if doc_correct:
            pages_from_correct_doc = [page for arxiv, page in result_list if arxiv == gt_arxiv_id]
            page_given_doc_recall.append(1.0 if gt_page_num in pages_from_correct_doc else 0.0)

        end_to_end_success.append(1.0 if (gt_arxiv_id, gt_page_num) in result_list else 0.0)

    # Both factors are fractions; scale once so this reads as a percentage like the others.
    stage_product = (
        float(np.mean(doc_recall_scores)) * float(np.mean(page_given_doc_recall)) * 100
        if page_given_doc_recall else 0.0
    )

    return {
        "NDCG@5": pct(ndcg_5_scores),
        "Recall@1": pct(recall_1_scores),
        "Recall@5": pct(recall_5_scores),
        "MRR": float(np.mean(mrr_scores)) if mrr_scores else 0.0,
        "Doc_Recall@5": pct(doc_recall_scores),
        "Page_Given_Doc_Recall": pct(page_given_doc_recall),
        "End_to_End_Success": pct(end_to_end_success),
        "Stage1_Stage2_Product": stage_product,
    }


# Run the two-stage ColPali pipeline end to end on Open RAGBench and print metrics.
print("="*80)
print("TWO-STAGE RAG SYSTEM - Open RAGBench Evaluation")
print("="*80)

MAX_QUERIES = 50
PDF_DIR = "./arxiv_pdfs"

TOP_K_DOCS = 20
TOP_K_PAGES = 20
AGGREGATION = "max"
RERANK = True

print("\n" + "="*80)
print("Loading Open RAGBench Dataset")
print("="*80)

loader = OpenRAGBenchLoader()

loader.download_dataset(max_queries=MAX_QUERIES)
loader.load_dataset()

pdf_processor = PDFProcessor(pdf_dir=PDF_DIR)

colpali_dataset = OpenRAGBenchDataset(loader, pdf_processor, max_queries=MAX_QUERIES)

print(f"\nDataset Statistics:")
print(f"  Queries: {len(colpali_dataset.queries)}")
print(f"  Documents: {len(colpali_dataset.arxiv_ids)}")
total_pages = sum(len(pages) for pages in colpali_dataset.doc_to_pages.values())
print(f"  Total pages: {total_pages}")
# Average over documents that actually have pages; avoid ZeroDivisionError on an empty corpus.
num_loaded_docs = len(colpali_dataset.doc_to_pages)
avg_pages_per_doc = total_pages / num_loaded_docs if num_loaded_docs else 0.0
print(f"  Avg pages/doc: {avg_pages_per_doc:.1f}")

print("\n" + "="*80)
print("Initializing ColPali RAG System")
print("="*80)

colpali_rag = TwoStageColPaliRAG("vidore/colpali-v1.2")

colpali_rag.index_documents(colpali_dataset.arxiv_ids, colpali_dataset.doc_to_pages, 
                   batch_size=4, aggregation=AGGREGATION)

query_embeddings = colpali_rag.encode_queries(colpali_dataset.queries, batch_size=8)

results = colpali_rag.retrieve_two_stage(query_embeddings, 
                                 top_k_docs=TOP_K_DOCS, 
                                 top_k_pages=TOP_K_PAGES,
                                 rerank=RERANK)

print("\n" + "="*80)
print("EVALUATION RESULTS")
print("="*80)

metrics = evaluate_two_stage(results, colpali_dataset.ground_truth)

print("\n Standard Retrieval Metrics:")
print(f"  NDCG@5:   {metrics['NDCG@5']:.2f}%")
print(f"  Recall@1: {metrics['Recall@1']:.2f}%")
print(f"  Recall@5: {metrics['Recall@5']:.2f}%")
print(f"  MRR:      {metrics['MRR']:.4f}")

print("\n Two-Stage Breakdown:")
print(f"  Stage 1 - Document Recall@5:       {metrics['Doc_Recall@5']:.2f}%")
print(f"  Stage 2 - Page Recall (given doc): {metrics['Page_Given_Doc_Recall']:.2f}%")
print(f"  End-to-End Success (exact match):  {metrics['End_to_End_Success']:.2f}%")
print(f"  Expected compound (Stage1 Ã— Stage2): {metrics['Stage1_Stage2_Product']:.2f}%")

print("\n Current Configuration:")
print(f"  â€¢ top_k_docs: {TOP_K_DOCS}")
print(f"  â€¢ top_k_pages: {TOP_K_PAGES}")
print(f"  â€¢ aggregation: {AGGREGATION}")
print(f"  â€¢ rerank: {RERANK}")

print("\n Tips to Improve Accuracy:")
print("  1. Increase TOP_K_DOCS (10â†’20) - casts wider net in Stage 1")
print("  2. Use aggregation='max' - captures most salient content")
print("  3. Enable rerank=True - global re-ranking of all candidates")
print("  4. Increase TOP_K_PAGES for evaluation metrics")

print("\n" + "="*80)

In [None]:
"""
Text-Based RAG Baseline for Open RAGBench
Uses PDF text-layer extraction (PyMuPDF, no OCR) + text embeddings (traditional approach)
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from typing import List, Dict, Tuple
from tqdm import tqdm
from collections import defaultdict
import warnings
import json
from pathlib import Path
import fitz
import os

# Silence all Python warnings for the rest of the kernel session (also hides useful ones).
warnings.filterwarnings('ignore')
# Suppress MuPDF's own error printing; PDFTextExtractor reports failures itself.
fitz.TOOLS.mupdf_display_errors(False)

import sys
# Make modules in the current working directory importable.
sys.path.append('.')


def compute_ndcg_at_k(relevance_scores: np.ndarray, k: int = 5) -> float:
    """Compute NDCG@k for one ranked list of graded relevance scores.

    Uses the standard discount ``rel_i / log2(i + 1)`` (1-based rank ``i``), as in
    trec_eval / scikit-learn / MTEB. The older ``rel_1 + sum rel_i / log2(i)`` form
    gave a rank-2 hit the same credit as a rank-1 hit.

    Args:
        relevance_scores: relevance of each retrieved item, in ranked order.
        k: cutoff; values larger than the list length are clamped.

    Returns:
        DCG@k / IDCG@k, or 0.0 when the list is empty, ``k <= 0`` or nothing is relevant.
        The ideal ranking is computed from ``relevance_scores`` itself.
    """
    scores = np.asarray(relevance_scores, dtype=float)
    if scores.size == 0 or k <= 0:
        return 0.0

    k = min(k, scores.size)
    # Rank 1 -> 1/log2(2), rank 2 -> 1/log2(3), ...
    discounts = 1.0 / np.log2(np.arange(2, k + 2))

    dcg = float(np.sum(scores[:k] * discounts))
    ideal_scores = np.sort(scores)[::-1][:k]
    idcg = float(np.sum(ideal_scores * discounts))

    return dcg / idcg if idcg > 0 else 0.0

def compute_recall_at_k(ranked_indices: List[int], relevant_idx: int, k: int = 1) -> float:
    """Binary Recall@k for a single relevant item.

    Returns 1.0 if ``relevant_idx`` appears among the first ``k`` entries of
    ``ranked_indices``, otherwise 0.0.
    """
    top_k = ranked_indices[:k]
    return float(relevant_idx in top_k)

def compute_mrr(ranked_indices: List[int], relevant_idx: int) -> float:
    """Reciprocal rank of ``relevant_idx`` in ``ranked_indices``.

    Works with any iterable of ids (lists, tuples, 1-D numpy arrays such as
    ``np.argsort`` output); ``list.index`` would fail on the latter.

    Returns:
        ``1 / rank`` (1-based) of the first occurrence, or 0.0 if it is absent.
    """
    for rank, idx in enumerate(ranked_indices, start=1):
        if idx == relevant_idx:
            return 1.0 / rank
    return 0.0


class PDFTextExtractor:
    """Extracts the embedded text layer of local arXiv PDFs (``<pdf_dir>/<arxiv_id>.pdf``), one string per page.

    Non-empty results are cached in ``text_cache`` keyed by arxiv id.
    """

    def __init__(self, pdf_dir: str = "./arxiv_pdfs"):
        self.pdf_dir = Path(pdf_dir)
        self.text_cache = {}

    @staticmethod
    def _page_text(doc, page_num: int) -> str:
        """Stripped text of one page, or a placeholder so list index stays == page number."""
        try:
            text = doc[page_num].get_text().strip()
        except Exception:
            return "[Error extracting text from this page]"
        # Very short text usually means a scanned / figure-only page.
        if len(text) < 50:
            return "[Page contains minimal text or is image-based]"
        return text

    def extract_text_from_pdf(self, arxiv_id: str) -> List[str]:
        """Return one text string per PDF page of ``arxiv_id``.

        Returns ``[]`` if the PDF is missing or cannot be opened. Pages that fail or
        have under 50 characters of text get a placeholder string instead of being
        dropped.
        """
        if arxiv_id in self.text_cache:
            return self.text_cache[arxiv_id]

        pdf_path = self.pdf_dir / f"{arxiv_id}.pdf"
        if not pdf_path.exists():
            return []

        try:
            # Context manager closes the document even if iteration fails part-way.
            with fitz.open(pdf_path) as doc:
                page_texts = [self._page_text(doc, page_num) for page_num in range(len(doc))]
        except Exception as e:
            print(f"  Error extracting text from {arxiv_id}: {str(e)[:100]}")
            return []

        if len(page_texts) > 0:
            self.text_cache[arxiv_id] = page_texts

        return page_texts


class TextBasedDataset:
    """Open RAGBench queries paired with per-page extracted text of their ground-truth PDFs.

    After construction:
      * ``queries``, ``query_ids`` and ``ground_truth`` are parallel lists. They contain
        only queries whose ground-truth PDF exists locally and yielded text.
        ``ground_truth[i]`` is the first entry of ``loader.get_ground_truth_pages`` for
        that query (an ``(arxiv_id, page_num)`` pair).
      * ``arxiv_ids`` (sorted) and ``doc_to_page_texts`` cover exactly the same documents.

    NOTE(review): as with the ColPali dataset, the corpus only contains ground-truth
    documents of the selected queries.
    """

    def __init__(self, loader: OpenRAGBenchLoader, text_extractor: PDFTextExtractor, max_queries: int = None):
        self.loader = loader
        self.text_extractor = text_extractor
        self.max_queries = max_queries

        self._organize_data()

    def _organize_data(self):
        """Collect queries, restrict them to local PDFs, extract their text, then report."""
        print("Organizing dataset with text extraction...")
        self._collect_queries()
        self._restrict_to_available_pdfs()
        self._extract_documents()
        self._report()

    def _collect_queries(self):
        """Fill the parallel query lists from the loader (first ground-truth page per query)."""
        self.queries = []
        self.query_ids = []
        self.ground_truth = []

        all_query_ids = list(self.loader.queries.keys())
        if self.max_queries is not None:
            all_query_ids = all_query_ids[:self.max_queries]
            print(f"Limiting to first {self.max_queries} queries")

        for query_id in all_query_ids:
            self.query_ids.append(query_id)
            self.queries.append(self.loader.queries[query_id]['query'])
            gt_pages = self.loader.get_ground_truth_pages(query_id)
            self.ground_truth.append(gt_pages[0] if gt_pages else (None, None))

        print(f"Processing {len(self.queries)} queries")

    def _keep_queries_for_docs(self, allowed_docs: set):
        """Drop every query whose ground-truth document id is not in ``allowed_docs``."""
        kept = [
            (qid, query, gt)
            for qid, query, gt in zip(self.query_ids, self.queries, self.ground_truth)
            if gt[0] in allowed_docs
        ]
        self.query_ids = [qid for qid, _, _ in kept]
        self.queries = [query for _, query, _ in kept]
        self.ground_truth = [gt for _, _, gt in kept]

    def _restrict_to_available_pdfs(self):
        """Keep only documents (and their queries) whose PDF is present on disk."""
        needed_docs = {gt[0] for gt in self.ground_truth if gt[0] is not None}
        pdf_dir = Path(self.text_extractor.pdf_dir)
        available_pdfs = {pdf_file.stem for pdf_file in pdf_dir.glob("*.pdf")}

        print(f"Found {len(available_pdfs)} available PDFs")
        usable_docs = needed_docs & available_pdfs
        self.arxiv_ids = sorted(usable_docs)
        self._keep_queries_for_docs(usable_docs)

        print(f"âœ“ Filtered to {len(self.queries)} queries with available documents")

    def _extract_documents(self):
        """Extract text for every usable PDF; drop documents (and their queries) that fail."""
        print(f"Extracting text from {len(self.arxiv_ids)} documents...")
        self.doc_to_page_texts = {}
        failed_docs = []

        for arxiv_id in tqdm(self.arxiv_ids, desc="Extracting text from PDFs"):
            page_texts = self.text_extractor.extract_text_from_pdf(arxiv_id)
            if len(page_texts) > 0:
                self.doc_to_page_texts[arxiv_id] = page_texts
            else:
                failed_docs.append(arxiv_id)

        if failed_docs:
            print(f"âš  Warning: {len(failed_docs)} PDFs failed text extraction")
            self._keep_queries_for_docs(set(self.doc_to_page_texts))
            # Keep arxiv_ids consistent with doc_to_page_texts so callers can safely look
            # up doc_to_page_texts[arxiv_id] for every listed id.
            self.arxiv_ids = [arxiv_id for arxiv_id in self.arxiv_ids if arxiv_id in self.doc_to_page_texts]

    def _report(self):
        """Print final query/document/page counts."""
        total_pages = sum(len(texts) for texts in self.doc_to_page_texts.values())
        print(f"âœ“ Successfully loaded {len(self.queries)} queries")
        print(f"âœ“ Successfully extracted text from {len(self.doc_to_page_texts)} documents")
        print(f"âœ“ Total pages: {total_pages}")
        avg_pages = total_pages / len(self.doc_to_page_texts) if self.doc_to_page_texts else 0
        print(f"âœ“ Average pages per document: {avg_pages:.1f}")


class TextEmbeddingModel:
    """Dense text encoder: a Hugging Face transformer with attention-masked mean pooling."""

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        print(f"Loading text embedding model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        print(f"Model loaded on device: {self.device}")

    def encode_texts(self, texts: List[str], batch_size: int = 32, max_length: int = 512) -> np.ndarray:
        """Embed ``texts`` into a ``(len(texts), hidden_size)`` array.

        Token vectors are averaged over real tokens only. A plain ``.mean(dim=1)``
        would also average the padding positions, so a text's embedding would change
        with the length of the other texts in its batch. Inputs longer than
        ``max_length`` tokens are truncated.
        """
        if len(texts) == 0:
            return np.zeros((0, self.model.config.hidden_size), dtype=np.float32)

        all_embeddings = []

        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]

            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                hidden = self.model(**inputs).last_hidden_state
                # (batch, seq, 1) mask: 1 for real tokens, 0 for padding.
                mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                summed = (hidden * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1e-9)
                all_embeddings.append((summed / counts).cpu().numpy())

        return np.vstack(all_embeddings)

    def cosine_similarity(self, query_embeds: np.ndarray, doc_embeds: np.ndarray) -> np.ndarray:
        """Pairwise cosine similarity between rows: ``(n_query, dim) x (n_doc, dim) -> (n_query, n_doc)``."""
        query_norm = query_embeds / (np.linalg.norm(query_embeds, axis=1, keepdims=True) + 1e-9)
        doc_norm = doc_embeds / (np.linalg.norm(doc_embeds, axis=1, keepdims=True) + 1e-9)
        return query_norm @ doc_norm.T


class TextBasedRAG:
    """Two-stage text retriever: mean-pooled document vectors for stage 1, page vectors for stage 2.

    All similarities come from ``embedding_model.cosine_similarity``.
    """

    def __init__(self, embedding_model: TextEmbeddingModel):
        self.embedding_model = embedding_model
        # arxiv_id -> (1, dim) mean of the document's page embeddings
        self.doc_embeddings = {}
        # arxiv_id -> (num_pages, dim) page embeddings, row index == page number
        self.page_embeddings = {}
        self.arxiv_ids = []

    def index_documents(self, arxiv_ids: List[str], doc_page_texts: Dict[str, List[str]],
                       batch_size: int = 32):
        """Embed every page of every document and store a mean-pooled document vector.

        Ids that are missing from ``doc_page_texts`` or have no pages are skipped
        instead of raising ``KeyError``.
        """
        print("\n" + "="*80)
        print("INDEXING DOCUMENTS (Text-Based)")
        print("="*80)

        self.arxiv_ids = arxiv_ids

        for arxiv_id in tqdm(arxiv_ids, desc="Encoding documents"):
            page_texts = doc_page_texts.get(arxiv_id, [])
            if len(page_texts) == 0:
                continue

            page_embeds = self.embedding_model.encode_texts(page_texts, batch_size=batch_size)
            self.page_embeddings[arxiv_id] = page_embeds
            # keepdims=True keeps a (1, dim) row, ready for cosine_similarity.
            self.doc_embeddings[arxiv_id] = np.mean(page_embeds, axis=0, keepdims=True)

        print(f"âœ“ Indexed {len(self.doc_embeddings)} documents")

    def retrieve_document(self, query_embed: np.ndarray, top_k: int = 5) -> List[str]:
        """Return up to ``top_k`` document ids ranked by cosine similarity to ``query_embed`` (1, dim)."""
        if not self.doc_embeddings:
            return []

        doc_ids = list(self.doc_embeddings)
        # One matrix product over all documents instead of one call per document.
        doc_matrix = np.vstack([self.doc_embeddings[arxiv_id] for arxiv_id in doc_ids])
        scores = self.embedding_model.cosine_similarity(query_embed, doc_matrix)[0]

        # sorted() is stable, so tied documents keep their indexing order.
        ranked = sorted(range(len(doc_ids)), key=lambda i: scores[i], reverse=True)
        return [doc_ids[i] for i in ranked[:top_k]]

    def retrieve_page_from_document(self, query_embed: np.ndarray, arxiv_id: str,
                                    top_k: int = 5) -> List[Tuple[int, float]]:
        """Return up to ``top_k`` ``(page_num, score)`` pairs of ``arxiv_id``, best first ([] if not indexed)."""
        if arxiv_id not in self.page_embeddings:
            return []

        page_embeds = self.page_embeddings[arxiv_id]
        scores = self.embedding_model.cosine_similarity(query_embed, page_embeds)[0]
        top_indices = np.argsort(scores)[::-1][:top_k]

        # Plain Python int/float so the results match the annotated return type.
        return [(int(idx), float(scores[idx])) for idx in top_indices]

    def retrieve_two_stage(self, queries: List[str], top_k_docs: int = 10,
                          top_k_pages: int = 10, batch_size: int = 32) -> List[List[Tuple[str, int]]]:
        """Retrieve the top documents per query, then return the global top ``(arxiv_id, page_num)`` pairs."""
        print("\n" + "="*80)
        print("TWO-STAGE RETRIEVAL (Text-Based)")
        print(f"Stage 1: Retrieving top {top_k_docs} documents")
        print(f"Stage 2: Retrieving top {top_k_pages} pages")
        print("="*80)

        if len(queries) == 0:
            return []

        print("Encoding queries...")
        query_embeds = self.embedding_model.encode_texts(queries, batch_size=batch_size)

        all_results = []

        for query_embed in tqdm(query_embeds, desc="Retrieving"):
            query_embed = query_embed.reshape(1, -1)

            top_docs = self.retrieve_document(query_embed, top_k=top_k_docs)

            all_candidates = []
            for arxiv_id in top_docs:
                page_results = self.retrieve_page_from_document(query_embed, arxiv_id, top_k=top_k_pages)
                for page_num, score in page_results:
                    all_candidates.append(((arxiv_id, page_num), score))

            all_candidates.sort(key=lambda x: x[1], reverse=True)
            all_results.append([doc_page for doc_page, _ in all_candidates[:top_k_pages]])

        return all_results


def evaluate_two_stage(results: List[List[Tuple[str, int]]], 
                      ground_truth: List[Tuple[str, int]]) -> Dict[str, float]:
    """Score two-stage retrieval output against one ground-truth page per query.

    Args:
        results: per query, ranked ``(arxiv_id, page_num)`` pairs from the retriever.
        ground_truth: per query, the relevant ``(arxiv_id, page_num)``. Queries whose
            arxiv_id is ``None`` are skipped.

    Returns:
        Every metric is a percentage (0-100) except ``MRR`` (0-1). Metrics with nothing
        to average are reported as 0.0 instead of NaN. ``Stage1_Stage2_Product`` is
        ``Doc_Recall x Page_Given_Doc_Recall`` expressed as a percentage.

    NOTE(review): despite its name, ``Doc_Recall@5`` checks whether the gold document
    appears anywhere in the result list (its full length), not only in the first 5
    results. ``End_to_End_Success`` also covers the whole list. The key names are kept
    because callers read them.
    """
    def pct(values: List[float]) -> float:
        # Mean as a percentage; 0.0 (not NaN) when nothing was scored.
        return float(np.mean(values)) * 100 if values else 0.0

    ndcg_5_scores = []
    recall_1_scores = []
    recall_5_scores = []
    mrr_scores = []

    doc_recall_scores = []
    page_given_doc_recall = []
    end_to_end_success = []

    for result_list, (gt_arxiv_id, gt_page_num) in zip(results, ground_truth):
        if gt_arxiv_id is None:
            continue

        retrieved_ids = [f"{arxiv_id}_page_{page_num}" for arxiv_id, page_num in result_list]
        gt_id = f"{gt_arxiv_id}_page_{gt_page_num}"
        relevance = np.array([1 if rid == gt_id else 0 for rid in retrieved_ids])

        ndcg_5_scores.append(compute_ndcg_at_k(relevance, k=5))
        recall_1_scores.append(1.0 if gt_id in retrieved_ids[:1] else 0.0)
        recall_5_scores.append(1.0 if gt_id in retrieved_ids[:5] else 0.0)
        mrr_scores.append(1.0 / (retrieved_ids.index(gt_id) + 1) if gt_id in retrieved_ids else 0.0)

        doc_correct = any(arxiv_id == gt_arxiv_id for arxiv_id, _ in result_list)
        doc_recall_scores.append(1.0 if doc_correct else 0.0)

        # Page recall is only defined for queries whose document was retrieved.
        if doc_correct:
            pages_from_correct_doc = [page for arxiv, page in result_list if arxiv == gt_arxiv_id]
            page_given_doc_recall.append(1.0 if gt_page_num in pages_from_correct_doc else 0.0)

        end_to_end_success.append(1.0 if (gt_arxiv_id, gt_page_num) in result_list else 0.0)

    # Both factors are fractions; scale once so this reads as a percentage like the others.
    stage_product = (
        float(np.mean(doc_recall_scores)) * float(np.mean(page_given_doc_recall)) * 100
        if page_given_doc_recall else 0.0
    )

    return {
        "NDCG@5": pct(ndcg_5_scores),
        "Recall@1": pct(recall_1_scores),
        "Recall@5": pct(recall_5_scores),
        "MRR": float(np.mean(mrr_scores)) if mrr_scores else 0.0,
        "Doc_Recall@5": pct(doc_recall_scores),
        "Page_Given_Doc_Recall": pct(page_given_doc_recall),
        "End_to_End_Success": pct(end_to_end_success),
        "Stage1_Stage2_Product": stage_product,
    }


# Run the text-extraction + dense-embedding baseline on Open RAGBench and print metrics.
print("="*80)
print("TEXT-BASED RAG BASELINE - Open RAGBench Evaluation")
print("="*80)

MAX_QUERIES = 50
PDF_DIR = "./arxiv_pdfs"
TOP_K_DOCS = 20
TOP_K_PAGES = 20
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

print("\n" + "="*80)
print("Loading Open RAGBench Dataset")
print("="*80)

loader = OpenRAGBenchLoader()
loader.download_dataset(max_queries=MAX_QUERIES)
loader.load_dataset()

text_extractor = PDFTextExtractor(pdf_dir=PDF_DIR)

text_dataset = TextBasedDataset(loader, text_extractor, max_queries=MAX_QUERIES)

print(f"\nDataset Statistics:")
print(f"  Queries: {len(text_dataset.queries)}")
print(f"  Documents: {len(text_dataset.arxiv_ids)}")
total_pages = sum(len(texts) for texts in text_dataset.doc_to_page_texts.values())
print(f"  Total pages: {total_pages}")
# Average over documents that actually have text; avoid ZeroDivisionError on an empty corpus.
num_loaded_docs = len(text_dataset.doc_to_page_texts)
avg_pages_per_doc = total_pages / num_loaded_docs if num_loaded_docs else 0.0
print(f"  Avg pages/doc: {avg_pages_per_doc:.1f}")

print("\n" + "="*80)
print("Initializing Text-Based RAG System")
print("="*80)

embedding_model = TextEmbeddingModel(model_name=EMBEDDING_MODEL)
text_rag = TextBasedRAG(embedding_model)

text_rag.index_documents(text_dataset.arxiv_ids, text_dataset.doc_to_page_texts, batch_size=32)

results = text_rag.retrieve_two_stage(text_dataset.queries, 
                                 top_k_docs=TOP_K_DOCS,
                                 top_k_pages=TOP_K_PAGES,
                                 batch_size=32)

print("\n" + "="*80)
print("EVALUATION RESULTS (Text-Based Baseline)")
print("="*80)

metrics = evaluate_two_stage(results, text_dataset.ground_truth)

print("\nStandard Retrieval Metrics:")
print(f"  NDCG@5:   {metrics['NDCG@5']:.2f}%")
print(f"  Recall@1: {metrics['Recall@1']:.2f}%")
print(f"  Recall@5: {metrics['Recall@5']:.2f}%")
print(f"  MRR:      {metrics['MRR']:.4f}")

print("\nTwo-Stage Breakdown:")
print(f"  Stage 1 - Document Recall@5:       {metrics['Doc_Recall@5']:.2f}%")
print(f"  Stage 2 - Page Recall (given doc): {metrics['Page_Given_Doc_Recall']:.2f}%")
print(f"  End-to-End Success (exact match):  {metrics['End_to_End_Success']:.2f}%")

print("\n Analysis:")
print(f"  â€¢ Text-based RAG (OCR + embeddings)")
print(f"  â€¢ Stage 1: {metrics['Doc_Recall@5']:.0f}% document recall")
print(f"  â€¢ Stage 2: {metrics['Page_Given_Doc_Recall']:.0f}% page recall (given doc)")
print(f"  â€¢ Compound: {metrics['Stage1_Stage2_Product']:.0f}%")

print("\nLimitations of Text-Based Approach:")
print("  âš  Loses visual information (figures, tables, layouts)")
print("  âš  OCR errors on complex PDFs")
print("  âš  Long text truncation (512 tokens max)")
print("  âš  No understanding of visual structure")

print("\n" + "="*80)

In [None]:
"""
Latency Benchmark for RAG Systems
Measures P95 and P99 latencies for ColPali and Text-Based RAG
"""

import torch
import numpy as np
import time
from typing import List, Dict, Tuple
from tqdm import tqdm
import json
from pathlib import Path
from dataclasses import dataclass, asdict
import sys
sys.path.append('.')

@dataclass
class LatencyMetrics:
    """Summary statistics for one profiled operation (all ``*_ms`` fields in milliseconds)."""

    operation: str      # name the samples were recorded under
    mean_ms: float
    median_ms: float
    p50_ms: float       # 50th percentile, reported alongside median_ms
    p95_ms: float
    p99_ms: float
    min_ms: float
    max_ms: float
    std_ms: float
    num_samples: int    # how many latency samples were summarized

    def to_dict(self):
        """Return all fields as a plain ``dict`` in declaration order (e.g. for JSON)."""
        as_plain_dict = asdict(self)
        return as_plain_dict

class LatencyProfiler:
    """Collects latency samples (milliseconds) per named operation and summarizes them."""

    def __init__(self):
        # operation name -> list of recorded latencies in ms, in recording order
        self.measurements = {}

    def measure(self, operation_name: str):
        """Return a context manager that times its body and records it under ``operation_name``."""
        return LatencyContext(self, operation_name)

    def add_measurement(self, operation: str, latency_ms: float):
        """Append one latency sample (ms) to ``operation``'s list, creating it on first use."""
        self.measurements.setdefault(operation, []).append(latency_ms)

    def compute_metrics(self, operation: str) -> LatencyMetrics:
        """Summarize the samples recorded for ``operation``.

        Returns ``None`` when ``operation`` has never been recorded.
        """
        if operation not in self.measurements:
            return None

        samples = np.array(self.measurements[operation])

        return LatencyMetrics(
            operation=operation,
            mean_ms=float(samples.mean()),
            median_ms=float(np.median(samples)),
            p50_ms=float(np.percentile(samples, 50)),
            p95_ms=float(np.percentile(samples, 95)),
            p99_ms=float(np.percentile(samples, 99)),
            min_ms=float(samples.min()),
            max_ms=float(samples.max()),
            std_ms=float(samples.std()),
            num_samples=len(samples),
        )

    def get_all_metrics(self) -> Dict[str, LatencyMetrics]:
        """Return ``{operation: LatencyMetrics}`` for every recorded operation."""
        summary = {}
        for op in self.measurements:
            summary[op] = self.compute_metrics(op)
        return summary

    def reset(self):
        """Discard every recorded sample by starting a fresh measurements dict."""
        self.measurements = {}

class LatencyContext:
    """Context manager that times its ``with`` body and reports it to a profiler.

    On a normal exit the elapsed wall-clock time in milliseconds is stored on
    ``self.latency_ms`` and passed to
    ``profiler.add_measurement(operation, latency_ms)``.

    If the body raises, the elapsed time is still stored on ``self.latency_ms``
    but is NOT recorded, so failed calls cannot distort the latency
    distribution. The exception propagates unchanged.

    When CUDA is available the device is synchronized before starting and
    before stopping the clock. CUDA kernels run asynchronously, so without
    this the timer would capture only kernel launch.
    """

    def __init__(self, profiler: "LatencyProfiler", operation: str):
        self.profiler = profiler
        self.operation = operation
        self.start_time = None
        self.latency_ms = None

    @staticmethod
    def _sync_device():
        # Wait for queued GPU work so the measured interval covers exactly the
        # work issued inside the block (and no leftover work from before it).
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self):
        self._sync_device()
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sync_device()
        end_time = time.perf_counter()
        self.latency_ms = (end_time - self.start_time) * 1000
        if exc_type is None:
            self.profiler.add_measurement(self.operation, self.latency_ms)
        # Returning False lets any exception from the body propagate.
        return False


class ColPaliLatencyBenchmark:
    """Latency benchmarks for the ColPali (page-image) two-stage RAG system.

    Each ``benchmark_*`` method times calls on ``rag_system``
    (``encode_queries``, ``encode_images``, ``retrieve_document``,
    ``retrieve_page_from_document``, ``retrieve_two_stage``) and records them
    in ``profiler`` under ``colpali_*`` operation names.

    Batched operations are timed only on full batches, so every sample
    labelled ``batchN`` really processed N items.

    Attributes read from ``dataset``: ``queries``, ``doc_to_pages``
    (arxiv id -> pages passed to ``encode_images``) and ``arxiv_ids``.
    """

    def __init__(self, rag_system, dataset, profiler: LatencyProfiler):
        self.rag = rag_system
        self.dataset = dataset
        self.profiler = profiler

    def benchmark_query_encoding(self, num_samples: int = 100):
        """Time encoding of the first ``num_samples`` queries, singly and in batches of 8.

        An untimed warm-up call runs first so one-time setup cost is excluded.
        """
        print("\n📊 Benchmarking ColPali Query Encoding...")

        sample_queries = self.dataset.queries[:num_samples]
        if len(sample_queries) == 0:
            print("  No queries available - skipping.")
            return

        # Untimed warm-up.
        _ = self.rag.encode_queries(sample_queries[:5], batch_size=1)

        for query in tqdm(sample_queries, desc="  Single query encoding"):
            with self.profiler.measure("colpali_query_encode_single"):
                _ = self.rag.encode_queries([query], batch_size=1)

        batch_size = 8
        # Full batches only: a short trailing batch would be faster and skew
        # the "batch8" percentiles.
        for start in range(0, len(sample_queries) - batch_size + 1, batch_size):
            batch = sample_queries[start:start + batch_size]
            with self.profiler.measure("colpali_query_encode_batch8"):
                _ = self.rag.encode_queries(batch, batch_size=batch_size)

    def benchmark_document_encoding(self, num_docs: int = 20):
        """Time page-image encoding for the first ``num_docs`` documents.

        For each document, the first page is encoded alone. Then up to its
        first 12 pages are encoded in full batches of 4.
        """
        print("\n📊 Benchmarking ColPali Document Encoding...")

        sample_docs = list(self.dataset.doc_to_pages.items())[:num_docs]
        batch_size = 4
        max_pages = 12

        for arxiv_id, pages in tqdm(sample_docs, desc="  Document encoding"):
            if len(pages) > 0:
                with self.profiler.measure("colpali_page_encode_single"):
                    _ = self.rag.encode_images([pages[0]], batch_size=1)

            capped_pages = pages[:max_pages]
            for start in range(0, len(capped_pages) - batch_size + 1, batch_size):
                batch = capped_pages[start:start + batch_size]
                with self.profiler.measure("colpali_page_encode_batch4"):
                    _ = self.rag.encode_images(batch, batch_size=batch_size)

    def benchmark_stage1_retrieval(self, num_queries: int = 100):
        """Time document-level retrieval (top-5 and top-10) for each query.

        Queries are encoded up front (untimed), so only retrieval is measured.
        """
        print("\n📊 Benchmarking ColPali Stage 1 (Document Retrieval)...")

        sample_queries = self.dataset.queries[:num_queries]
        if len(sample_queries) == 0:
            print("  No queries available - skipping.")
            return
        query_embeddings = self.rag.encode_queries(sample_queries, batch_size=8)

        for query_emb in tqdm(query_embeddings, desc="  Stage 1 retrieval"):
            with self.profiler.measure("colpali_stage1_retrieve_top5"):
                _ = self.rag.retrieve_document(query_emb, top_k=5)

            with self.profiler.measure("colpali_stage1_retrieve_top10"):
                _ = self.rag.retrieve_document(query_emb, top_k=10)

    def benchmark_stage2_retrieval(self, num_queries: int = 100):
        """Time page-level retrieval (top-5 and top-10) within a single document.

        Every query is scored against the first document in ``arxiv_ids``, so
        the numbers reflect that document's page count.
        """
        print("\n📊 Benchmarking ColPali Stage 2 (Page Retrieval)...")

        if len(self.dataset.arxiv_ids) == 0:
            print("  No documents available - skipping.")
            return
        sample_queries = self.dataset.queries[:num_queries]
        if len(sample_queries) == 0:
            print("  No queries available - skipping.")
            return
        query_embeddings = self.rag.encode_queries(sample_queries, batch_size=8)

        sample_doc = self.dataset.arxiv_ids[0]

        for query_emb in tqdm(query_embeddings, desc="  Stage 2 retrieval"):
            with self.profiler.measure("colpali_stage2_retrieve_top5"):
                _ = self.rag.retrieve_page_from_document(query_emb, sample_doc, top_k=5)

            with self.profiler.measure("colpali_stage2_retrieve_top10"):
                _ = self.rag.retrieve_page_from_document(query_emb, sample_doc, top_k=10)

    def benchmark_end_to_end(self, num_queries: int = 50):
        """Time query encoding plus two-stage retrieval, per query, at two cut-offs."""
        print("\n📊 Benchmarking ColPali End-to-End Retrieval...")

        sample_queries = self.dataset.queries[:num_queries]

        for query in tqdm(sample_queries, desc="  End-to-end retrieval"):
            # Encoding is inside the timed region: this is the full per-query cost.
            with self.profiler.measure("colpali_e2e_top5_docs_top5_pages"):
                query_emb = self.rag.encode_queries([query], batch_size=1)
                _ = self.rag.retrieve_two_stage(query_emb, top_k_docs=5, top_k_pages=5)

            with self.profiler.measure("colpali_e2e_top10_docs_top10_pages"):
                query_emb = self.rag.encode_queries([query], batch_size=1)
                _ = self.rag.retrieve_two_stage(query_emb, top_k_docs=10, top_k_pages=10)


class TextBasedLatencyBenchmark:
    """Latency benchmarks for the text-based (OCR + embeddings) two-stage RAG system.

    Each ``benchmark_*`` method times calls on ``rag_system``
    (``embedding_model.encode_texts``, ``retrieve_document``,
    ``retrieve_page_from_document``, ``retrieve_two_stage``) and records them
    in ``profiler`` under ``text_*`` operation names.

    Batched operations are timed only on full batches, so every sample
    labelled ``batchN`` really processed N items.

    Attributes read from ``dataset``: ``queries``, ``doc_to_page_texts``
    (arxiv id -> page texts) and ``arxiv_ids``.
    """

    def __init__(self, rag_system, dataset, profiler: LatencyProfiler):
        self.rag = rag_system
        self.dataset = dataset
        self.profiler = profiler

    def benchmark_query_encoding(self, num_samples: int = 100):
        """Time encoding of the first ``num_samples`` queries, singly and in batches of 32.

        An untimed warm-up call runs first so one-time setup cost is excluded.
        """
        print("\n📊 Benchmarking Text-Based Query Encoding...")

        sample_queries = self.dataset.queries[:num_samples]
        if len(sample_queries) == 0:
            print("  No queries available - skipping.")
            return

        # Untimed warm-up.
        _ = self.rag.embedding_model.encode_texts(sample_queries[:5], batch_size=1)

        for query in tqdm(sample_queries, desc="  Single query encoding"):
            with self.profiler.measure("text_query_encode_single"):
                _ = self.rag.embedding_model.encode_texts([query], batch_size=1)

        batch_size = 32
        # Full batches only: with 100 queries the old loop also timed a
        # 4-query batch as "batch32", a quarter of that metric's samples.
        for start in range(0, len(sample_queries) - batch_size + 1, batch_size):
            batch = sample_queries[start:start + batch_size]
            with self.profiler.measure("text_query_encode_batch32"):
                _ = self.rag.embedding_model.encode_texts(batch, batch_size=batch_size)

    def benchmark_document_encoding(self, num_docs: int = 20):
        """Time page-text encoding for the first ``num_docs`` documents.

        For each document, the first page is encoded alone. If the document
        has at least 8 pages, its first 8 pages are also encoded as one batch.
        """
        print("\n📊 Benchmarking Text-Based Document Encoding...")

        sample_docs = list(self.dataset.doc_to_page_texts.items())[:num_docs]

        for arxiv_id, page_texts in tqdm(sample_docs, desc="  Document encoding"):
            if len(page_texts) > 0:
                with self.profiler.measure("text_page_encode_single"):
                    _ = self.rag.embedding_model.encode_texts([page_texts[0]], batch_size=1)

            if len(page_texts) >= 8:
                with self.profiler.measure("text_page_encode_batch8"):
                    _ = self.rag.embedding_model.encode_texts(page_texts[:8], batch_size=8)

    def benchmark_stage1_retrieval(self, num_queries: int = 100):
        """Time document-level retrieval (top-5 and top-10) for each query.

        Queries are encoded up front (untimed), so only retrieval is measured.
        """
        print("\n📊 Benchmarking Text-Based Stage 1 (Document Retrieval)...")

        sample_queries = self.dataset.queries[:num_queries]
        if len(sample_queries) == 0:
            print("  No queries available - skipping.")
            return
        query_embeds = self.rag.embedding_model.encode_texts(sample_queries, batch_size=32)

        for query_emb in tqdm(query_embeds, desc="  Stage 1 retrieval"):
            # Reshape each query vector into a one-row matrix for the retriever.
            # Assumes encode_texts returns array-like rows - TODO confirm.
            query_emb = query_emb.reshape(1, -1)

            with self.profiler.measure("text_stage1_retrieve_top5"):
                _ = self.rag.retrieve_document(query_emb, top_k=5)

            with self.profiler.measure("text_stage1_retrieve_top10"):
                _ = self.rag.retrieve_document(query_emb, top_k=10)

    def benchmark_stage2_retrieval(self, num_queries: int = 100):
        """Time page-level retrieval (top-5 and top-10) within a single document.

        Every query is scored against the first document in ``arxiv_ids``, so
        the numbers reflect that document's page count.
        """
        print("\n📊 Benchmarking Text-Based Stage 2 (Page Retrieval)...")

        if len(self.dataset.arxiv_ids) == 0:
            print("  No documents available - skipping.")
            return
        sample_queries = self.dataset.queries[:num_queries]
        if len(sample_queries) == 0:
            print("  No queries available - skipping.")
            return
        query_embeds = self.rag.embedding_model.encode_texts(sample_queries, batch_size=32)

        sample_doc = self.dataset.arxiv_ids[0]

        for query_emb in tqdm(query_embeds, desc="  Stage 2 retrieval"):
            query_emb = query_emb.reshape(1, -1)

            with self.profiler.measure("text_stage2_retrieve_top5"):
                _ = self.rag.retrieve_page_from_document(query_emb, sample_doc, top_k=5)

            with self.profiler.measure("text_stage2_retrieve_top10"):
                _ = self.rag.retrieve_page_from_document(query_emb, sample_doc, top_k=10)

    def benchmark_end_to_end(self, num_queries: int = 50):
        """Time full two-stage retrieval from raw query text, per query, at two cut-offs."""
        print("\n📊 Benchmarking Text-Based End-to-End Retrieval...")

        sample_queries = self.dataset.queries[:num_queries]

        for query in tqdm(sample_queries, desc="  End-to-end retrieval"):
            # retrieve_two_stage takes query strings here, so encoding is part
            # of the timed region, matching the ColPali end-to-end benchmark.
            with self.profiler.measure("text_e2e_top5_docs_top5_pages"):
                _ = self.rag.retrieve_two_stage([query], top_k_docs=5, top_k_pages=5, batch_size=1)

            with self.profiler.measure("text_e2e_top10_docs_top10_pages"):
                _ = self.rag.retrieve_two_stage([query], top_k_docs=10, top_k_pages=10, batch_size=1)


def print_latency_report(metrics: Dict[str, LatencyMetrics], title: str):
    """Print a fixed-width table of mean / P50 / P95 / P99 latency per operation.

    Rows are ordered by the keys of ``metrics``; each latency cell is printed
    with two decimals followed by ``ms``.
    """
    rule = "=" * 100
    print("\n" + rule)
    print(f"{title}")
    print(rule)
    print(f"{'Operation':<45} {'Mean':<10} {'P50':<10} {'P95':<10} {'P99':<10} {'Samples':<10}")
    print("-" * 100)

    for key in sorted(metrics):
        row = metrics[key]
        latency_cells = " ".join(
            f"{value:>8.2f}ms"
            for value in (row.mean_ms, row.p50_ms, row.p95_ms, row.p99_ms)
        )
        print(f"{row.operation:<45} {latency_cells} {row.num_samples:>10}")

    print(rule)

def compare_systems(colpali_metrics: Dict[str, "LatencyMetrics"],
                   text_metrics: Dict[str, "LatencyMetrics"]):
    """Print a side-by-side P95/P99 latency comparison of ColPali vs text-based RAG.

    For each operation pair present in both dicts, prints both systems' P95
    and P99 latencies. It then names the winner (lower P95; ties go to
    ColPali) and how many times faster the winner is at P95. Operation pairs
    missing from either dict are skipped.

    Args:
        colpali_metrics: operation name -> metrics object for the ColPali system
            (must expose ``p95_ms`` and ``p99_ms``).
        text_metrics: operation name -> metrics object for the text-based system.
    """
    print("\n" + "="*100)
    print("SYSTEM COMPARISON - P95 and P99 Latencies")
    print("="*100)

    comparison_ops = [
        ("Query Encoding (Single)", "colpali_query_encode_single", "text_query_encode_single"),
        ("Stage 1: Document Retrieval (Top-5)", "colpali_stage1_retrieve_top5", "text_stage1_retrieve_top5"),
        ("Stage 2: Page Retrieval (Top-5)", "colpali_stage2_retrieve_top5", "text_stage2_retrieve_top5"),
        ("End-to-End (Top-5 docs, Top-5 pages)", "colpali_e2e_top5_docs_top5_pages", "text_e2e_top5_docs_top5_pages"),
        ("End-to-End (Top-10 docs, Top-10 pages)", "colpali_e2e_top10_docs_top10_pages", "text_e2e_top10_docs_top10_pages"),
    ]

    print(f"\n{'Operation':<45} {'System':<15} {'P95 (ms)':<12} {'P99 (ms)':<12} {'Winner'}")
    print("-"*100)

    for op_name, colpali_key, text_key in comparison_ops:
        if colpali_key not in colpali_metrics or text_key not in text_metrics:
            continue
        colpali_m = colpali_metrics[colpali_key]
        text_m = text_metrics[text_key]

        print(f"{op_name:<45} {'ColPali':<15} {colpali_m.p95_ms:>10.2f}  {colpali_m.p99_ms:>10.2f}")

        winner = "Text-Based ✓" if text_m.p95_ms < colpali_m.p95_ms else "ColPali ✓"
        # Report slower/faster so the number reads as "winner is Nx faster".
        # The old colpali/text ratio dropped below 1x whenever ColPali won.
        fastest, slowest = sorted((colpali_m.p95_ms, text_m.p95_ms))
        speedup = slowest / fastest if fastest > 0 else 1.0
        print(f"{'':<45} {'Text-Based':<15} {text_m.p95_ms:>10.2f}  {text_m.p99_ms:>10.2f}  {winner} ({speedup:.1f}x)")
        print("-"*100)
            print("-"*100)

def save_results(colpali_metrics: Dict[str, "LatencyMetrics"],
                text_metrics: Dict[str, "LatencyMetrics"],
                output_file: str = "latency_benchmark_results.json"):
    """Write both systems' latency metrics to a JSON file.

    The file contains ``{"colpali": {...}, "text_based": {...}, "timestamp": str}``.
    Each inner dict maps operation name -> ``metric.to_dict()``. The timestamp
    is local time formatted as ``%Y-%m-%d %H:%M:%S``.

    Args:
        colpali_metrics: operation name -> metrics object with ``to_dict()``.
        text_metrics: operation name -> metrics object with ``to_dict()``.
        output_file: destination path; missing parent directories are created.
    """
    results = {
        "colpali": {k: v.to_dict() for k, v in colpali_metrics.items()},
        "text_based": {k: v.to_dict() for k, v in text_metrics.items()},
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\n✓ Results saved to {output_file}")


def run_full_benchmark(colpali_rag, colpali_dataset, text_rag, text_dataset,
                      num_queries: int = 100, num_docs: int = 20):
    """Benchmark ColPali and text-based RAG, print reports, save JSON, return metrics.

    Both systems go through the same five stages in the same order: query
    encoding, document encoding, stage-1 retrieval, stage-2 retrieval and
    end-to-end retrieval. End-to-end uses at most 50 queries. Each system
    gets its own LatencyProfiler.

    Returns:
        ``(colpali_metrics, text_metrics)``: per-system dicts mapping operation
        name to its summary metrics.
    """
    rule = "=" * 100

    def banner(title):
        print("\n" + rule)
        print(title)
        print(rule)

    def run_suite(bench):
        bench.benchmark_query_encoding(num_samples=num_queries)
        bench.benchmark_document_encoding(num_docs=num_docs)
        bench.benchmark_stage1_retrieval(num_queries=num_queries)
        bench.benchmark_stage2_retrieval(num_queries=num_queries)
        bench.benchmark_end_to_end(num_queries=min(50, num_queries))

    banner("LATENCY BENCHMARK - ColPali vs Text-Based RAG")

    colpali_profiler, text_profiler = LatencyProfiler(), LatencyProfiler()

    banner("BENCHMARKING COLPALI RAG SYSTEM")
    run_suite(ColPaliLatencyBenchmark(colpali_rag, colpali_dataset, colpali_profiler))

    banner("BENCHMARKING TEXT-BASED RAG SYSTEM")
    run_suite(TextBasedLatencyBenchmark(text_rag, text_dataset, text_profiler))

    colpali_metrics = colpali_profiler.get_all_metrics()
    text_metrics = text_profiler.get_all_metrics()

    print_latency_report(colpali_metrics, "COLPALI RAG - LATENCY METRICS")
    print_latency_report(text_metrics, "TEXT-BASED RAG - LATENCY METRICS")
    compare_systems(colpali_metrics, text_metrics)

    save_results(colpali_metrics, text_metrics)

    return colpali_metrics, text_metrics

In [None]:
# Run both latency suites (100 queries, 20 documents).
# Assumes colpali_rag / colpali_dataset were built in an earlier cell;
# text_rag / text_dataset come from the text-based RAG cell above.
colpali_metrics, text_metrics = run_full_benchmark(
    colpali_rag, colpali_dataset,
    text_rag, text_dataset,
    num_queries=100, num_docs=20,
)