# Bandit Retrieval

## Load Data

In [3]:
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import os
import json

class Dataloader(ABC):
    """Abstract base for IR dataset loaders.

    Subclasses supply the dataset handle plus question / passage / relevance
    extraction. ``load_data`` calls them in a fixed order (dataset first, then
    questions, passages, and finally the relevance map) so that subclasses may
    rely on state set by earlier steps.
    """

    @abstractmethod
    def load_dataset(self):
        """Return the underlying dataset object (stored as ``self.dataset``)."""
        pass

    @abstractmethod
    def load_questions(self):
        """Return ``(question_ids, question_texts)``."""
        pass

    @abstractmethod
    def load_passages(self):
        """Return ``(passage_ids, passage_texts)``."""
        pass

    @abstractmethod
    def create_relevance_map(self):
        """Return a mapping ``query_id -> {passage_id: relevance}``."""
        pass

    def load_data(self):
        """Load everything and return ``(q_ids, q_texts, p_ids, p_texts, relevance_map)``."""
        self.dataset = self.load_dataset()
        question_ids, question_texts = self.load_questions()
        passage_ids, passage_texts = self.load_passages()
        relevance_map = self.create_relevance_map()

        return question_ids, question_texts, passage_ids, passage_texts, relevance_map

    def visualize_relevance(self, fig_path):
        """Save a histogram of all qrel relevance values to *fig_path*.

        The histogram is drawn on a dedicated figure that is closed afterwards,
        so repeated calls (or later notebook plots) do not draw on top of it.
        """
        relevance = [qrel.relevance for qrel in self.dataset.qrels_iter()]
        fig, ax = plt.subplots()
        try:
            # bins must be >= 1 even when the dataset has no qrels.
            ax.hist(relevance, bins=max(len(set(relevance)), 1))
            ax.set_xlabel('Relevance')
            ax.set_ylabel('Frequency')
            ax.set_title('Relevance Distribution')
            fig.savefig(fig_path)
        finally:
            plt.close(fig)

    def save_relevance_map(self, data_dir, relevance_map):
        """Write *relevance_map* as indented JSON to ``<data_dir>/relevance_map.json``.

        *data_dir* is created if it does not exist yet.
        """
        os.makedirs(data_dir, exist_ok=True)
        data_path = os.path.join(data_dir, "relevance_map.json")
        with open(data_path, 'w', encoding='utf-8') as json_file:
            json.dump(relevance_map, json_file, indent=4)

In [4]:
from collections import defaultdict
import ir_datasets

class Scidocs(Dataloader):
    """Loader for the BEIR SciDocs collection via ``ir_datasets``.

    Questions and passages with identical text are collapsed onto the first ID
    seen. ``self.question_sub`` / ``self.passage_sub`` record each
    duplicate-ID -> canonical-ID substitution so qrels can be re-keyed.
    """

    def load_dataset(self):
        return ir_datasets.load("beir/scidocs")

    @staticmethod
    def _deduplicate(records, id_attr):
        """Collapse records with identical ``.text`` onto the first ID seen.

        Returns ``(ids, texts, sub)`` where ``ids``/``texts`` are aligned lists
        of the kept records and ``sub`` maps every dropped duplicate ID to the
        canonical ID kept for its text.
        """
        canonical = {}  # text -> first ID seen (insertion ordered)
        sub = {}
        for record in records:
            record_id = getattr(record, id_attr)
            if record.text in canonical:
                sub[record_id] = canonical[record.text]
            else:
                canonical[record.text] = record_id
        return list(canonical.values()), list(canonical.keys()), sub

    def load_questions(self):
        # avoid duplicate questions
        question_ids, question_texts, self.question_sub = self._deduplicate(
            self.dataset.queries_iter(), 'query_id'
        )
        return question_ids, question_texts

    def load_passages(self):
        # avoid duplicate passages
        passage_ids, passage_texts, self.passage_sub = self._deduplicate(
            self.dataset.docs_iter(), 'doc_id'
        )
        return passage_ids, passage_texts

    def create_relevance_map(self):
        """Map canonical query IDs to ``{canonical doc ID: relevance}`` for positive qrels.

        Requires ``load_questions`` and ``load_passages`` to have run (they set
        the substitution maps). When several qrels collapse onto the same
        canonical pair, the highest relevance is kept rather than whichever
        happened to come last. Returns a plain dict so that lookups of unknown
        queries never insert empty entries.
        """
        relevance_map = defaultdict(dict)
        for qrel in self.dataset.qrels_iter():
            if qrel.relevance <= 0:
                continue

            query_id = self.question_sub.get(qrel.query_id, qrel.query_id)
            doc_id = self.passage_sub.get(qrel.doc_id, qrel.doc_id)

            judged = relevance_map[query_id]
            judged[doc_id] = max(judged.get(doc_id, qrel.relevance), qrel.relevance)

        return dict(relevance_map)

## Load Data and Build ID/Index Maps

In [5]:
# Load SciDocs: deduplicated question/passage lists plus the positive-qrel relevance map.
scidocs = Scidocs()
(
    question_ids,
    question_texts,
    passage_ids,
    passage_texts,
    relevance_map,
) = scidocs.load_data()

In [52]:
# Two-way lookups between dataset IDs and list positions (which are also the
# row numbers of the embedding matrices computed below).
question_id_to_idx = {qid: position for position, qid in enumerate(question_ids)}
passage_id_to_idx = {pid: position for position, pid in enumerate(passage_ids)}
question_idx_to_id = dict(enumerate(question_ids))
passage_idx_to_id = dict(enumerate(passage_ids))

In [7]:
## Pre-compute Embeddings with Sentence Transformer

import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm

def precompute_embeddings(texts, model_name='all-MiniLM-L6-v2', batch_size=32, model=None):
    """
    Precompute embeddings for a list of texts using Sentence Transformer.
    
    Args:
        texts: List of text strings to embed
        model_name: Name of the sentence transformer model to load when
            ``model`` is not given
        batch_size: Number of texts encoded per ``model.encode`` call
        model: Optional already-loaded SentenceTransformer; pass it to avoid
            reloading the same model for several text lists
        
    Returns:
        numpy array of shape (len(texts), embedding_dim); an empty
        (0, embedding_dim) array when ``texts`` is empty
    """
    if model is None:
        print(f"Loading Sentence Transformer model: {model_name}")
        model = SentenceTransformer(model_name)
    
    print(f"Computing embeddings for {len(texts)} texts...")
    if len(texts) == 0:
        # np.vstack([]) would raise; return a correctly-shaped empty matrix instead.
        return np.empty((0, model.get_sentence_embedding_dimension()), dtype=np.float32)

    # Encode in batches so tqdm can report progress over the whole corpus.
    embeddings = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch = texts[start:start + batch_size]
        embeddings.append(model.encode(batch, show_progress_bar=False))
    
    return np.vstack(embeddings)

# Precompute question and passage embeddings
model_name = 'all-MiniLM-L6-v2'  # You can change this to other models like 'paraphrase-mpnet-base-v2' for better quality
print("Pre-computing embeddings...")

question_embeddings = precompute_embeddings(question_texts, model_name=model_name)
passage_embeddings = precompute_embeddings(passage_texts, model_name=model_name)

print(f"Question embeddings shape: {question_embeddings.shape}")
print(f"Passage embeddings shape: {passage_embeddings.shape}")

# Save embeddings (optional) so a later session can np.load them instead of re-encoding.
import os
os.makedirs('embeddings', exist_ok=True)  # no-op when the directory already exists

np.save('embeddings/question_embeddings.npy', question_embeddings)
np.save('embeddings/passage_embeddings.npy', passage_embeddings)

Pre-computing embeddings...
Loading Sentence Transformer model: all-MiniLM-L6-v2
Computing embeddings for 1000 texts...


  0%|          | 0/32 [00:00<?, ?it/s]

Loading Sentence Transformer model: all-MiniLM-L6-v2
Computing embeddings for 25265 texts...


  0%|          | 0/790 [00:00<?, ?it/s]

Question embeddings shape: (1000, 384)
Passage embeddings shape: (25265, 384)


## Implement Bandit Retrieval with GP-UCB

In [87]:
import numpy as np
import scipy.stats
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel, Matern
from sentence_transformers import SentenceTransformer
import time
import random
from tqdm.notebook import tqdm

class LLMInterface:
    """Interface for querying LLM models for relevance scoring.

    Only simulation mode is implemented (``is_simulation`` is set to True in
    ``__init__``): a score is the ground-truth relevance when the pair is
    judged in a supplied relevance map, otherwise the cosine similarity of the
    query and passage embeddings.
    """
    
    def __init__(self, model_name="gpt-3.5-turbo", api_key=None, temperature=0.0, delay=0.0,
                question_embeddings=None, passage_embeddings=None):
        """
        Initialize the LLM interface.
        
        Args:
            model_name: Name of the LLM model to use
            api_key: API key for the LLM service (if applicable)
            temperature: Temperature parameter for LLM sampling
            delay: Simulated delay in seconds (for testing without actual LLM calls)
            question_embeddings: Pre-computed embeddings for questions
            passage_embeddings: Pre-computed embeddings for passages
        """
        self.model_name = model_name
        self.api_key = api_key
        self.temperature = temperature
        self.delay = delay
        self.call_count = 0
        
        # Store precomputed embeddings
        self.question_embeddings = question_embeddings
        self.passage_embeddings = passage_embeddings
        self.has_precomputed_embeddings = (question_embeddings is not None and 
                                          passage_embeddings is not None)
        
        # For simulation purposes
        self.is_simulation = True
        self.sim_scores = {}  # (query text, passage text) -> cached similarity
        # Encoder for the on-the-fly fallback, created once on first use
        # (constructing it on every call is extremely slow).
        self._sentence_model = None

    @staticmethod
    def _cosine(a, b):
        """Cosine similarity of two vectors as a Python float; 0.0 if either has zero norm."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def _get_sentence_model(self):
        """Return the cached fallback SentenceTransformer, loading it on first use."""
        if getattr(self, '_sentence_model', None) is None:
            self._sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        return self._sentence_model
    
    def get_relevance_score(self, query, passage, relevance_map=None, query_id=None, passage_id=None,
                           query_idx=None, passage_idx=None):
        """
        Get relevance score for a query-passage pair from the LLM.
        
        In a real implementation, this would query the actual LLM.
        For demonstration purposes, we simulate responses based on either:
        1. The ground-truth relevance if the pair is present in relevance_map
        2. Otherwise the cosine similarity of the query/passage embeddings
           (precomputed rows when query_idx/passage_idx are given, else encoded
           on the fly); these similarities are cached per (query, passage) text
        
        Args:
            query: The query text
            passage: The passage text
            relevance_map: Optional ground truth relevance map for simulation
            query_id: ID of the query in the relevance map
            passage_id: ID of the passage in the relevance map
            query_idx: Index of the query in precomputed embeddings
            passage_idx: Index of the passage in precomputed embeddings
            
        Returns:
            The ground-truth relevance value (unscaled) when the pair is judged;
            otherwise a cosine similarity in [-1, 1] (0.0 for zero-norm
            embeddings). Outside simulation mode, a placeholder random value.
        """
        self.call_count += 1
        
        if self.delay > 0:
            time.sleep(self.delay)  # Simulate LLM API latency
        
        if not self.is_simulation:
            # In a real implementation, this would call the actual LLM API
            # Placeholder for actual implementation
            return random.random()

        # Ground truth first. Explicit None checks so that an ID or index of 0
        # is not mistaken for "missing".
        if relevance_map is not None and query_id is not None and passage_id is not None:
            judged = relevance_map.get(query_id)
            if judged is not None and passage_id in judged:
                return judged[passage_id]
        
        # Tuple key: unambiguous, unlike joining the two texts with "_".
        key = (query, passage)
        if key in self.sim_scores:
            return self.sim_scores[key]
        
        if self.has_precomputed_embeddings and query_idx is not None and passage_idx is not None:
            query_embedding = self.question_embeddings[query_idx]
            passage_embedding = self.passage_embeddings[passage_idx]
        else:
            # Fall back to computing embeddings on the fly
            sentence_model = self._get_sentence_model()
            query_embedding = sentence_model.encode(query)
            passage_embedding = sentence_model.encode(passage)
        
        score = self._cosine(query_embedding, passage_embedding)
        self.sim_scores[key] = score
        return score

class BanditRetrieval:
    """Bandit-based retrieval system using GP-UCB for selective LLM querying."""
    
    def __init__(self, questions, passages, query_encoder='tfidf', passage_encoder='tfidf', 
                 kernel=None, beta=2.0, normalize_features=True, gp_noise=0.1,
                 init_samples=5, budget=100, llm=None, 
                 question_embeddings=None, passage_embeddings=None, passage_ids=None):
        """
        Initialize the bandit retrieval system.
        
        Args:
            questions: List of question texts
            passages: List of passage texts
            query_encoder: Method for encoding queries ('tfidf' or 'sentence-transformer')
            passage_encoder: Method for encoding passages
            kernel: Kernel for Gaussian Process; if None uses ConstantKernel * Matern(nu=1.5)
            beta: Exploration parameter for UCB
            normalize_features: Whether to L2-normalize the pair feature vectors
            gp_noise: Noise level (``alpha``) for the Gaussian Process
            init_samples: Number of random initial samples per query before using GP-UCB
                (capped at the budget)
            budget: Maximum number of LLM calls per query
            llm: LLM interface for retrieving relevance scores
            question_embeddings: Pre-computed embeddings for questions (numpy array)
            passage_embeddings: Pre-computed embeddings for passages (numpy array)
            passage_ids: Optional sequence mapping passage index -> passage ID (the key
                used in relevance maps). When omitted, the notebook-level
                ``passage_idx_to_id`` dict is used.
        """
        self.questions = questions
        self.passages = passages
        self.query_encoder = query_encoder
        self.passage_encoder = passage_encoder
        self.beta = beta
        self.normalize_features = normalize_features
        self.init_samples = init_samples
        self.budget = budget
        self.llm = llm if llm else LLMInterface()
        self.passage_ids = passage_ids
        
        # Store pre-computed embeddings
        self.question_embeddings = question_embeddings
        self.passage_embeddings = passage_embeddings
        self.using_precomputed = question_embeddings is not None and passage_embeddings is not None
        
        if kernel is None:
            # Matern kernel is often good for text-based tasks
            self.kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=1.5)
        else:
            self.kernel = kernel
            
        self.gp_noise = gp_noise
        
        # Text encoders are only needed when no pre-computed embeddings are given
        if not self.using_precomputed:
            self._initialize_encoders()
        else:
            print("Using pre-computed embeddings")

    def _initialize_encoders(self):
        """Build the encoders named by ``query_encoder`` / ``passage_encoder``.

        TF-IDF is fitted on all questions and passages so both live in one
        vocabulary. Unknown encoder names are rejected later, when encoding.
        """
        encoders = {self.query_encoder, self.passage_encoder}
        if 'tfidf' in encoders:
            self.tfidf = TfidfVectorizer()
            self.tfidf.fit(list(self.questions) + list(self.passages))
        if 'sentence-transformer' in encoders:
            self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

    def _encode_text(self, text, encoder, role):
        """Encode one text with the named encoder; *role* only feeds the error message."""
        if encoder == 'tfidf':
            return self.tfidf.transform([text]).toarray().flatten()
        if encoder == 'sentence-transformer':
            return self.sentence_model.encode(text)
        raise ValueError(f"Unknown {role} encoder: {encoder}")
    
    def _encode_query_passage_pair(self, query_idx, passage_idx):
        """
        Encode a query-passage pair as the concatenation of the query and passage vectors.
        
        Args:
            query_idx: Index of the query in self.questions
            passage_idx: Index of the passage in self.passages
            
        Returns:
            feature_vector: Concatenated (and optionally L2-normalized) feature vector
        """
        if self.using_precomputed:
            query_vec = self.question_embeddings[query_idx]
            passage_vec = self.passage_embeddings[passage_idx]
        else:
            query_vec = self._encode_text(self.questions[query_idx], self.query_encoder, 'query')
            passage_vec = self._encode_text(self.passages[passage_idx], self.passage_encoder, 'passage')

        feature_vector = np.concatenate([query_vec, passage_vec])
        if self.normalize_features:
            norm = np.linalg.norm(feature_vector)
            if norm > 0:
                feature_vector = feature_vector / norm
        return feature_vector
        
    def _ucb_score(self, means, stds, t):
        """Calculate UCB scores for a given set of means and standard deviations."""
        return means + self.beta * stds / np.sqrt(t)

    def _passage_id(self, passage_idx):
        """Map a passage index to the ID used as key in relevance maps."""
        if self.passage_ids is not None:
            return self.passage_ids[passage_idx]
        return passage_idx_to_id[passage_idx]

    def _query_llm(self, query, query_idx, passage_idx, query_id=None, relevance_map=None):
        """Ask the LLM for one relevance score.

        Always passes the embedding indices (so a simulated LLM can reuse the
        precomputed vectors) and, when a relevance map is supplied, the passage
        *ID*, which is what the map is keyed by.
        """
        kwargs = {'query_idx': query_idx, 'passage_idx': passage_idx}
        if relevance_map and query_id:
            kwargs.update(relevance_map=relevance_map, query_id=query_id,
                          passage_id=self._passage_id(passage_idx))
        return self.llm.get_relevance_score(query, self.passages[passage_idx], **kwargs)
    
    def retrieve(self, query_idx, query_id=None, relevance_map=None, verbose=True, return_trajectories=False):
        """
        Retrieve relevant passages for a given query using GP-UCB.
        
        Args:
            query_idx: Index of the query in self.questions
            query_id: ID of the query (for relevance map lookup)
            relevance_map: Map of ground truth relevance scores (for simulation)
            verbose: Whether to show progress information
            return_trajectories: Whether to return mean/std trajectories for visualization
            
        Returns:
            ranked_passages: Array of passage indices ranked by relevance
            scores: Per-passage scores (observed LLM score, else GP posterior mean)
            If return_trajectories is True, also returns lists of means and stds
            (one array of length n_passages per GP-UCB iteration)
        """
        query = self.questions[query_idx]
        n_passages = len(self.passages)
        means_history = []
        stds_history = []

        if n_passages == 0:
            ranked_indices = np.array([], dtype=int)
            final_means = np.array([], dtype=float)
            if return_trajectories:
                return ranked_indices, final_means, means_history, stds_history
            return ranked_indices, final_means

        # Never spend more LLM calls than the budget allows (or than there are passages).
        max_observations = min(self.budget, n_passages)
        
        gp = GaussianProcessRegressor(
            kernel=self.kernel,
            alpha=self.gp_noise,
            normalize_y=True,
            n_restarts_optimizer=2
        )

        # Pair features are fixed for this query: encode every passage once up
        # front instead of re-encoding all candidates on each iteration.
        X_all = np.vstack([self._encode_query_passage_pair(query_idx, i) for i in range(n_passages)])

        observed_order = []    # passage indices in observation order
        observed_scores = {}   # passage index -> observed relevance score
        unobserved = np.ones(n_passages, dtype=bool)

        def observe(passage_idx):
            # One LLM call; record the score and drop the passage from the candidate pool.
            observed_scores[passage_idx] = self._query_llm(
                query, query_idx, passage_idx, query_id, relevance_map
            )
            observed_order.append(passage_idx)
            unobserved[passage_idx] = False

        def fit_observed():
            # Fit the GP to everything observed so far (an unfitted GP predicts from its prior).
            if observed_order:
                y = np.array([observed_scores[i] for i in observed_order], dtype=float)
                gp.fit(X_all[observed_order], y)
        
        # Initialize by sampling a few passages randomly
        n_init = min(self.init_samples, max_observations)
        for passage_idx in np.random.choice(n_passages, n_init, replace=False).tolist():
            observe(passage_idx)
        
        # Main GP-UCB loop
        pbar = tqdm(total=max_observations, disable=not verbose)
        pbar.update(len(observed_order))
        
        while len(observed_order) < max_observations:
            fit_observed()

            candidate_indices = np.flatnonzero(unobserved)
            if candidate_indices.size == 0:
                break  # All passages have been observed

            means, stds = gp.predict(X_all[candidate_indices], return_std=True)
            
            # Snapshot for visualization: observed passages show their score with zero uncertainty
            mean_full = np.zeros(n_passages)
            std_full = np.zeros(n_passages)
            mean_full[candidate_indices] = means
            std_full[candidate_indices] = stds
            for idx, score in observed_scores.items():
                mean_full[idx] = score
            means_history.append(mean_full)
            stds_history.append(std_full)
            
            t = len(observed_order) + 1  # Current iteration
            ucb_scores = self._ucb_score(means, stds, t)
            observe(int(candidate_indices[np.argmax(ucb_scores)]))
            pbar.update(1)
            
            # Early stopping once the posterior is confident everywhere
            if np.max(stds) < 0.05 and len(observed_order) > self.init_samples:
                if verbose:
                    print(f"Stopping early at {len(observed_order)} observations due to convergence")
                break
                
        pbar.close()
        
        # Final prediction for all passages, overridden by actual scores where observed
        fit_observed()
        final_means = gp.predict(X_all)
        for idx, score in observed_scores.items():
            final_means[idx] = score
            
        ranked_indices = np.argsort(-final_means)
        
        if return_trajectories:
            return ranked_indices, final_means, means_history, stds_history
        return ranked_indices, final_means

## Evaluate and Visualize Results

In [88]:
def evaluate_retrieval(ranked_indices, true_relevance, cutoffs=(5, 10, 20, 50, 100), id_to_idx=None):
    """
    Evaluate retrieval performance using precision, recall, and nDCG at different cutoffs.
    
    Args:
        ranked_indices: Passage indices in ranked order (best first). May cover
            every passage or only a subset (e.g. the passages a baseline sampled).
        true_relevance: Dictionary mapping passage IDs to relevance scores
        cutoffs: Cutoff values k; cutoffs larger than len(ranked_indices) are skipped
        id_to_idx: Dictionary mapping passage ID -> passage index. Defaults to the
            notebook-level ``passage_id_to_idx``.
        
    Returns:
        metrics: Dictionary with 'P@k', 'R@k' and 'nDCG@k' for each evaluated cutoff
    """
    if id_to_idx is None:
        id_to_idx = passage_id_to_idx

    # Relevance keyed by passage index; IDs absent from the mapping are skipped.
    rel_by_idx = {}
    for passage_id, score in true_relevance.items():
        idx = id_to_idx.get(passage_id)
        if idx is not None:
            rel_by_idx[idx] = score

    # Relevance of each ranked passage, in rank order (unjudged -> 0).
    sorted_rel = np.array([rel_by_idx.get(int(i), 0.0) for i in ranked_indices], dtype=float)
    # Recall denominator and ideal ordering use all relevant passages, including
    # any that the (possibly partial) ranking does not contain.
    positives = [score for score in rel_by_idx.values() if score > 0]
    total_relevant = len(positives)
    ideal_rel = np.array(sorted(positives, reverse=True), dtype=float)

    metrics = {}
    for k in cutoffs:
        if k > len(sorted_rel):
            continue

        top_k = sorted_rel[:k]
        hits = np.sum(top_k > 0)

        metrics[f'P@{k}'] = hits / k
        metrics[f'R@{k}'] = hits / total_relevant if total_relevant > 0 else 0.0

        dcg = np.sum(top_k / np.log2(np.arange(2, k + 2)))
        ideal_k = ideal_rel[:k]
        idcg = np.sum(ideal_k / np.log2(np.arange(2, len(ideal_k) + 2)))
        metrics[f'nDCG@{k}'] = dcg / idcg if idcg > 0 else 0.0

    return metrics

def plot_metrics(metrics_list, query_ids):
    """Plot average metrics across all queries.

    Args:
        metrics_list: Per-query metric dicts as returned by evaluate_retrieval
        query_ids: IDs of the evaluated queries (currently unused; kept for API compatibility)

    Returns:
        (fig, avg_metrics): the figure with three bar charts (Precision@k,
        Recall@k, nDCG@k) and a dict of metric name -> mean over queries
    """
    import matplotlib.pyplot as plt

    # Collect every metric's values across queries, then average them
    collected = {}
    for metrics in metrics_list:
        for name, value in metrics.items():
            collected.setdefault(name, []).append(value)
    avg_metrics = {name: np.mean(values) for name, values in collected.items()}

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    panels = [('P@', 'Precision@k'), ('R@', 'Recall@k'), ('nDCG@', 'nDCG@k')]

    for ax, (prefix, title) in zip(axes, panels):
        # This metric family, ordered by numeric cutoff
        group = sorted(
            ((name, value) for name, value in avg_metrics.items() if name.startswith(prefix)),
            key=lambda item: int(item[0].split('@')[1]),
        )
        labels = [name for name, _ in group]
        values = [value for _, value in group]

        ax.bar(labels, values, color='skyblue')
        ax.set_title(title)
        ax.set_ylim(0, 1)
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        for position, value in enumerate(values):
            ax.text(position, value + 0.02, f'{value:.3f}', ha='center')

    plt.tight_layout()
    return fig, avg_metrics

def visualize_gp_ucb_trajectory(means_history, stds_history, query, top_passages=5):
    """Visualize the GP-UCB algorithm's trajectory of means and uncertainties.

    Plots, side by side, the posterior mean and standard deviation per
    iteration for the ``top_passages`` passages with the highest mean in the
    last iteration.

    Returns:
        The ``matplotlib.pyplot`` module (so callers can ``.show()`` it). If no
        GP-UCB iteration was recorded, nothing is plotted.
    """
    import matplotlib.pyplot as plt

    if len(means_history) == 0:
        # e.g. init_samples >= budget: the GP-UCB loop never ran.
        print("No GP-UCB iterations to visualize.")
        return plt

    iterations = range(len(means_history))
    top_k_indices = np.argsort(-means_history[-1])[:top_passages]

    plt.figure(figsize=(14, 7))
    panels = [
        (means_history, f'Mean Estimates for Top {top_passages} Passages', 'Relevance Score'),
        (stds_history, f'Standard Deviations for Top {top_passages} Passages', 'Uncertainty (σ)'),
    ]
    for position, (history, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for i in top_k_indices:
            plt.plot(iterations, [step[i] for step in history], '-o', label=f'Passage {i}')
        plt.title(title)
        plt.xlabel('Iteration')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.suptitle(f'GP-UCB Trajectory for Query: "{query[:50]}..."', y=1.05)
    return plt

## Run Experiment

In [1]:
# Create bandit retrieval system with pre-computed embeddings
llm = LLMInterface(delay=0.0)  # For simulation
bandit = BanditRetrieval(
    questions=question_texts,
    passages=passage_texts,
    query_encoder='sentence-transformer',  # This will be ignored when using pre-computed embeddings
    passage_encoder='sentence-transformer',  # This will be ignored when using pre-computed embeddings
    beta=2.0,
    # Keep init_samples well below budget: with init_samples >= budget the
    # GP-UCB loop never runs and there is no trajectory to plot.
    init_samples=5,
    budget=50,
    llm=llm,
    question_embeddings=question_embeddings,  # Pass pre-computed embeddings
    passage_embeddings=passage_embeddings     # Pass pre-computed embeddings
)

# Select a few queries for demonstration
n_queries = 1
selected_query_indices = np.random.choice(len(question_texts), n_queries, replace=False)
selected_query_ids = [question_ids[idx] for idx in selected_query_indices]

# Store results for evaluation
results = []
metrics_list = []

for i, (query_idx, query_id) in enumerate(zip(selected_query_indices, selected_query_ids)):
    print(f"\nProcessing query {i+1}/{n_queries}: {question_texts[query_idx][:100]}...")
    
    # Run bandit retrieval
    ranked_indices, scores, means_history, stds_history = bandit.retrieve(
        query_idx=query_idx,
        query_id=query_id,
        relevance_map=relevance_map,
        verbose=True,
        return_trajectories=True,
    )
    
    # Store results
    results.append((query_idx, query_id, ranked_indices, scores, means_history, stds_history))
    
    # Evaluate if ground truth is available
    if query_id in relevance_map:
        true_relevance = relevance_map[query_id]
        metrics = evaluate_retrieval(ranked_indices, true_relevance)
        metrics_list.append(metrics)
        print("Metrics:", {k: f"{v:.3f}" for k, v in metrics.items()})
        
    # Display top 3 retrieved passages
    print("\nTop retrieved passages:")
    for j in range(min(3, len(ranked_indices))):
        passage_idx = ranked_indices[j]
        score = scores[passage_idx]
        print(f"{j+1}. Score: {score:.3f}")
        print(f"   Passage: {passage_texts[passage_idx][:200]}...")
    
    # Visualize GP-UCB trajectory
    visualize_gp_ucb_trajectory(
        means_history, 
        stds_history, 
        question_texts[query_idx],
        top_passages=5
    )
    plt.show()

# Plot overall metrics if we have evaluations
if metrics_list:
    fig, avg_metrics = plot_metrics(metrics_list, selected_query_ids)
    plt.show()
    print("Average metrics:", {k: f"{v:.3f}" for k, v in avg_metrics.items()})


NameError: name 'LLMInterface' is not defined

## Comparative Analysis

In [None]:
def random_baseline(questions, passages, query_idx, budget, llm, relevance_map=None, query_id=None,
                    passage_ids=None):
    """
    Baseline that randomly samples passages.
    
    Args:
        questions: List of question texts
        passages: List of passage texts
        query_idx: Index of the query
        budget: Number of passages to sample
        llm: LLM interface for retrieving relevance scores
        relevance_map: Ground truth relevance map (keyed by passage ID)
        query_id: ID of the query in the relevance map
        passage_ids: Optional sequence mapping passage index -> passage ID; defaults
            to the notebook-level ``passage_idx_to_id`` dict
        
    Returns:
        ranked_indices: Sampled passage indices ranked by their LLM score
        scores: Array of length len(passages): LLM score for sampled passages, 0 elsewhere
    """
    query = questions[query_idx]
    n_passages = len(passages)
    
    # Randomly sample passages
    sampled_indices = np.random.choice(n_passages, min(budget, n_passages), replace=False).tolist()
    
    scores = {}
    for idx in sampled_indices:
        # Embedding indices let a simulated LLM reuse precomputed vectors.
        kwargs = {'query_idx': query_idx, 'passage_idx': idx}
        if relevance_map and query_id:
            # The relevance map is keyed by passage ID, not by list index.
            passage_id = passage_ids[idx] if passage_ids is not None else passage_idx_to_id[idx]
            kwargs.update(relevance_map=relevance_map, query_id=query_id, passage_id=passage_id)
        scores[idx] = llm.get_relevance_score(query, passages[idx], **kwargs)
    
    # Rank the sampled passages by their scores
    ranked_indices = sorted(scores, key=lambda idx: -scores[idx])
    
    # Unobserved passages keep a default score of 0
    all_scores = np.zeros(n_passages)
    for idx, score in scores.items():
        all_scores[idx] = score
    
    return ranked_indices, all_scores

def tfidf_baseline(questions, passages, query_idx, budget, llm, top_k=100, relevance_map=None, query_id=None):
    """
    Baseline that uses TF-IDF to pre-rank passages and then scores the top-k with the LLM.

    Args:
        questions: List of question texts
        passages: List (or other sequence) of passage texts
        query_idx: Index of the query in ``questions``
        budget: Maximum number of passages to score with the LLM (negative -> 0)
        llm: Object exposing ``get_relevance_score(query, passage, ...)``
        top_k: Number of top TF-IDF passages eligible for LLM scoring
        relevance_map: Ground truth relevance map. When it is non-empty and
            ``query_id`` is given, both are forwarded to
            ``llm.get_relevance_score`` together with the passage index.
        query_id: ID of the query in the relevance map (falsy IDs such as
            ``0`` are accepted)

    Returns:
        ranked_indices: Every passage index. LLM-scored passages come first,
            ordered by descending LLM score. The remaining passages follow,
            ordered by descending TF-IDF similarity.
        scores: NumPy array of length ``len(passages)``. Each entry is the LLM
            score for passages the LLM scored, and ``0.5 *`` the min-max
            normalised TF-IDF similarity for all others.
    """
    query = questions[query_idx]
    n_passages = len(passages)
    if n_passages == 0:
        return [], np.zeros(0)

    # Fit TF-IDF on the query plus the corpus; row 0 is the query.
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform([query, *passages])
    similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])[0]

    # Whole-corpus order by TF-IDF similarity; stable so ties break by index.
    tfidf_order = np.argsort(-similarities, kind="stable").tolist()

    # LLM-score the best TF-IDF candidates, capped by top_k and the budget.
    # Clamping at 0 avoids a negative slice bound selecting almost everything.
    n_llm = max(0, min(budget, top_k, n_passages))
    llm_indices = tfidf_order[:n_llm]

    # `is not None` so a valid falsy query ID still uses the ground-truth path.
    use_ground_truth = bool(relevance_map) and query_id is not None

    scores = {}
    for idx in llm_indices:
        passage = passages[idx]
        if use_ground_truth:
            score = llm.get_relevance_score(
                query, passage, relevance_map=relevance_map,
                query_id=query_id, passage_id=idx
            )
        else:
            score = llm.get_relevance_score(query, passage)
        scores[idx] = score

    # LLM-scored passages first (by LLM score), then the rest in TF-IDF order.
    ranked_indices = sorted(scores, key=lambda idx: -scores[idx])
    ranked_indices += [idx for idx in tfidf_order if idx not in scores]

    # Min-max normalise TF-IDF; a constant similarity vector (e.g. the query
    # shares no terms with any passage) would otherwise divide by zero -> NaN.
    sim_min = similarities.min()
    sim_range = similarities.max() - sim_min
    if sim_range > 0:
        norm_similarities = (similarities - sim_min) / sim_range
    else:
        norm_similarities = np.zeros(n_passages)

    # Unscored passages get TF-IDF scaled to [0, 0.5]; LLM scores override.
    all_scores = norm_similarities * 0.5
    for idx, score in scores.items():
        all_scores[idx] = score

    return ranked_indices, all_scores

# Compare methods on a single query
query_idx = selected_query_indices[0]
query_id = selected_query_ids[0]
budget = 50

print(f"Query: {question_texts[query_idx]}")

# GP-UCB (already run above, just re-use the first result).
# NOTE(review): assumes results[i][2] / results[i][3] hold the ranked indices /
# scores from the earlier run loop — confirm against that cell.
gp_ucb_indices, gp_ucb_scores = results[0][2], results[0][3]

# Random baseline
random_indices, random_scores = random_baseline(
    question_texts, passage_texts, query_idx, budget, llm,
    relevance_map=relevance_map, query_id=query_id
)

# TF-IDF baseline
tfidf_indices, tfidf_scores = tfidf_baseline(
    question_texts, passage_texts, query_idx, budget, llm, top_k=100,
    relevance_map=relevance_map, query_id=query_id
)

# Evaluate methods (only possible when ground truth exists for this query)
if query_id in relevance_map:
    true_relevance = relevance_map[query_id]
    gp_ucb_metrics = evaluate_retrieval(gp_ucb_indices, true_relevance)
    random_metrics = evaluate_retrieval(random_indices, true_relevance)
    tfidf_metrics = evaluate_retrieval(tfidf_indices, true_relevance)

    print("\nGP-UCB Metrics:", {k: f"{v:.3f}" for k, v in gp_ucb_metrics.items()})
    print("Random Baseline Metrics:", {k: f"{v:.3f}" for k, v in random_metrics.items()})
    print("TF-IDF Baseline Metrics:", {k: f"{v:.3f}" for k, v in tfidf_metrics.items()})

    # Compare methods with bar chart
    import pandas as pd

    # Single mapping from method label to its metrics, so labels and values
    # cannot drift apart (previously a hard-coded "* 3" had to match).
    method_metrics = {
        'GP-UCB': gp_ucb_metrics,
        'Random': random_metrics,
        'TF-IDF': tfidf_metrics,
    }
    metrics_names = sorted(k for k in gp_ucb_metrics if k.startswith('nDCG'))
    methods = list(method_metrics)

    # Long-format rows: one (method, metric, value) triple per bar
    data = {
        'Method': [],
        'Metric': [],
        'Value': []
    }
    for metric in metrics_names:
        for method_name, method_values in method_metrics.items():
            data['Method'].append(method_name)
            data['Metric'].append(metric)
            data['Value'].append(method_values[metric])

    df = pd.DataFrame(data)

    # Only add an ellipsis when the query text was actually truncated
    title_query = question_texts[query_idx]
    if len(title_query) > 50:
        title_query = f"{title_query[:50]}..."

    # Plot comparison
    plt.figure(figsize=(12, 6))
    sns.barplot(x='Metric', y='Value', hue='Method', data=df)
    plt.title(f'Comparison of Retrieval Methods (Query: {title_query})')
    plt.ylim(0, 1)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()
else:
    print(f"\nNo ground-truth relevance for query {query_id}; skipping evaluation.")