<a href="https://colab.research.google.com/github/FailedAnalysis/CS544/blob/main/544tp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-

In [2]:
!pip install peft



In [3]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill (from evaluate)
  Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.18-py311-none-any.whl.metadata (7.5 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m 

1. Data Processing

In [9]:
import json
import pandas as pd
import numpy as np
import torch
import re
import logging
import sys
import os
from tqdm.auto import tqdm

# Hugging Face and evaluation libraries
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments
)
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, TaskType
from sentence_transformers import SentenceTransformer, losses, models, util, InputExample, CrossEncoder
from sentence_transformers.evaluation import InformationRetrievalEvaluator
import evaluate
import datasets

# Set up logging - Moved to the absolute top level
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Ensure handler is only added once if script is run multiple times
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)


# --- Data Processing Functions ---

def load_json_data_squad(file_path):
    """Helper to load SQuAD JSON data."""
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return None
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if 'data' not in data:
             logger.error(f"JSON file {file_path} does not contain the expected 'data' key.")
             return None
        return data['data'] # SQuAD JSON has a top-level 'data' key
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from {file_path}: {e}")
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred loading {file_path}: {e}")
        return None

def process_squad_json(train_json_path, dev_json_path):
    """
    Loads SQuAD 2.0 data from local JSON files and processes it
    into DataFrames suitable for both Retriever and Generator.

    Args:
        train_json_path (str): Path to the train-v2.0.json file (or your renamed file).
        dev_json_path (str): Path to the dev-v2.0.json file (or your renamed file).

    Returns:
        tuple: (documents_df, retriever_train_df, retriever_test_df, generator_train_df, generator_dev_df)
               documents_df: DataFrame for corpus ['id', 'passage']
               retriever_train_df: DataFrame for retriever train ['id', 'question', 'relevant_passage_ids']
               retriever_test_df: DataFrame for retriever test ['id', 'question', 'relevant_passage_ids']
               generator_train_df: DataFrame for generator train ['id', 'question', 'answer', 'relevant_passage_ids']
               generator_dev_df: DataFrame for generator dev ['id', 'question', 'answer', 'relevant_passage_ids']
    """
    logger.info(f"Loading and processing SQuAD 2.0 data from local files: {train_json_path}, {dev_json_path}")

    train_data = load_json_data_squad(train_json_path)
    dev_data = load_json_data_squad(dev_json_path)

    if train_data is None and dev_data is None:
         logger.error("Failed to load both train and dev JSON files.")
         return pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    # Use empty lists if one fails but the other succeeds
    if train_data is None: train_data = []
    if dev_data is None: dev_data = []


    # --- Build Corpus and Context Map ---
    logger.info("Building corpus and context map...")
    context_to_id = {}
    documents_list = []
    doc_id_counter = 0

    def process_paragraphs_for_corpus(data_split, desc):
        nonlocal doc_id_counter
        if not isinstance(data_split, list): return
        for article in tqdm(data_split, desc=f"Processing articles ({desc} corpus)"):
            if not isinstance(article, dict) or 'paragraphs' not in article or not isinstance(article['paragraphs'], list): continue
            for paragraph in article['paragraphs']:
                if not isinstance(paragraph, dict) or 'context' not in paragraph: continue
                context = paragraph['context']
                if not isinstance(context, str):
                     logger.warning(f"Context is not a string in {desc}, type: {type(context)}. Skipping.")
                     continue
                stripped_context = context.strip()
                if stripped_context not in context_to_id:
                    doc_id = f"doc_{doc_id_counter}"
                    context_to_id[stripped_context] = doc_id
                    documents_list.append({'id': doc_id, 'passage': context})
                    doc_id_counter += 1

    process_paragraphs_for_corpus(train_data, "train")
    process_paragraphs_for_corpus(dev_data, "dev")

    documents_df = pd.DataFrame(documents_list)
    logger.info(f"Created corpus with {len(documents_df)} unique passages.")

    if documents_df.empty:
         logger.error("Corpus is empty after processing.")
         return documents_df, pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()


    # --- Process Questions for Retriever and Generator DataFrames ---
    logger.info("Processing questions for Retriever and Generator DataFrames...")

    def process_qas_for_dataframes(data_split, desc):
        retriever_list = [] # For retriever train/test
        generator_list = [] # For generator train/dev

        if not isinstance(data_split, list): return ([], [])

        for article in tqdm(data_split, desc=f"Processing questions ({desc})"):
            if not isinstance(article, dict) or 'paragraphs' not in article or not isinstance(article['paragraphs'], list): continue
            for paragraph in article['paragraphs']:
                if not isinstance(paragraph, dict) or 'context' not in paragraph or 'qas' not in paragraph or not isinstance(paragraph['qas'], list): continue

                context = paragraph['context']
                context_str = str(context) if not isinstance(context, str) else context
                stripped_context = context_str.strip()
                passage_id = context_to_id.get(stripped_context)

                if passage_id is None:
                     logger.warning(f"Context '{stripped_context[:50]}...' not found in corpus map for a question in {desc}. Skipping questions in this paragraph.")
                     continue

                for qa in paragraph['qas']:
                    if not isinstance(qa, dict) or 'id' not in qa or 'question' not in qa or 'is_impossible' not in qa or 'answers' not in qa: continue

                    qid = str(qa['id'])
                    question_text = str(qa['question'])
                    is_impossible = qa['is_impossible']
                    answers = qa['answers']

                    # For Retriever: relevant_passage_ids is the ID of the context paragraph
                    retriever_list.append({
                        'id': qid,
                        'question': question_text,
                        'relevant_passage_ids': [passage_id] # List containing the string passage ID
                    })

                    # For Generator: also need the answer text
                    answer_text = ""
                    if not is_impossible and answers and 'text' in answers and isinstance(answers['text'], list) and len(answers['text']) > 0:
                        answer_text = str(answers['text'][0])

                    generator_list.append({
                        'id': qid,
                        'question': question_text,
                        'answer': answer_text,
                        'relevant_passage_ids': [passage_id] # List containing the string passage ID (for generator training context)
                    })
        return retriever_list, generator_list

    retriever_train_list, generator_train_list = process_qas_for_dataframes(train_data, "train")
    retriever_test_list, generator_dev_list = process_qas_for_dataframes(dev_data, "dev")

    retriever_train_df = pd.DataFrame(retriever_train_list)
    retriever_test_df = pd.DataFrame(retriever_test_list)
    generator_train_df = pd.DataFrame(generator_train_list)
    generator_dev_df = pd.DataFrame(generator_dev_list)

    logger.info(f"Created retriever_train_df with {len(retriever_train_df)} questions.")
    logger.info(f"Created retriever_test_df with {len(retriever_test_df)} questions.")
    logger.info(f"Created generator_train_df with {len(generator_train_df)} questions.")
    logger.info(f"Created generator_dev_df with {len(generator_dev_df)} questions.")


    return documents_df, retriever_train_df, retriever_test_df, generator_train_df, generator_dev_df


2. Retriever

In [5]:
class Retriever:

    def __init__(self, retriever_model_name="all-mpnet-base-v2", reranker_model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Initialize a SentenceTransformer model for retrieval and a CrossEncoder for re-ranking.

        Args:
            retriever_model_name (str): Name of the transformer model to load for retrieval (bi-encoder).
            reranker_model_name (str): Name of the transformer model to load for re-ranking (cross-encoder).
        """
        logger.info(f"Initializing retriever model: {retriever_model_name}")
        self.retriever_model = SentenceTransformer(retriever_model_name)

        logger.info(f"Initializing reranker model: {reranker_model_name}")
        self.reranker_model = CrossEncoder(reranker_model_name)

        # Internal caches and placeholders
        self.corpus = {}         # dict of {doc_id -> passage}
        self.corpus_ids = []     # list of doc_ids
        self.corpus_texts = []   # list of passages
        self.corpus_embeddings = None

        # For training/evaluation
        self.train_examples = []
        self.queries = {}        # {query_id: query_text} used for IR evaluation
        self.relevant_docs = {}  # {query_id: {doc_id: 1}}

    def load_corpus(self, documents_df):
        """
        Load the corpus documents into memory.

        Args:
            documents_df (pd.DataFrame): Must have columns ['id', 'passage']
        """
        logger.info("Loading corpus documents...")
        documents_df['id'] = documents_df['id'].astype(str)
        documents_df['passage'] = documents_df['passage'].fillna("").astype(str)

        self.corpus = dict(zip(documents_df['id'], documents_df['passage']))
        self.corpus_ids = list(self.corpus.keys())
        self.corpus_texts = list(self.corpus.values())

        logger.info(f"Corpus size: {len(self.corpus)} documents.")

    def prepare_data(self, train_df, negative_samples=3, eval_ratio=0.2):
        """
        Prepares training data (with negative samples) and an evaluation set.

        Args:
            train_df (pd.DataFrame): Must have columns ['id', 'question', 'relevant_passage_ids'] (list of strings)
            negative_samples (int): Number of negative samples per query
            eval_ratio (float): Fraction of train_df to hold out for evaluation
        """
        if not self.corpus:
            raise ValueError("No corpus loaded. Call load_corpus(documents_df) first.")

        logger.info(f"Preparing training data from DataFrame ({len(train_df)} rows)...")

        train_df['id'] = train_df['id'].astype(str)
        train_df['question'] = train_df['question'].fillna("").astype(str)
        # Ensure relevant_passage_ids is a list of strings
        train_df['relevant_passage_ids'] = train_df['relevant_passage_ids'].apply(
             lambda x: [str(i).strip() for i in (eval(str(x)) if isinstance(x, str) else x) if str(i).strip()] if pd.notna(x) else []
        )


        # Determine eval set
        total_len = len(train_df)
        eval_size = int(eval_ratio * total_len)
        eval_size = min(eval_size, total_len)
        if total_len > 0 and eval_ratio > 0 and eval_size == 0: eval_size = 1
        eval_indices = set(np.random.choice(range(total_len), size=eval_size, replace=False)) if eval_size > 0 else set()

        train_examples = []
        queries = {}
        relevant_docs = {}

        for idx, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Preparing data"):
            qid = str(row['id'])
            question_text = row['question']
            rel_ids = row['relevant_passage_ids'] # This is now a list of strings

            # Positive examples
            valid_rel_ids = [rid for rid in rel_ids if rid in self.corpus]
            for rid in valid_rel_ids:
                 train_examples.append(InputExample(
                     texts=[question_text, self.corpus[rid]],
                     label=1.0
                 ))
            if not valid_rel_ids:
                 logger.warning(f"Query {qid} has no relevant documents found in corpus. Cannot create positive examples for training.")


            # Negative examples (random sample passages not in rel_ids)
            # Only sample negatives if there's at least one valid positive example possible
            if valid_rel_ids:
                all_irrelevant_ids = [pid for pid in self.corpus.keys() if pid not in rel_ids]
                if len(all_irrelevant_ids) > 0 and negative_samples > 0:
                    neg_sample_ids = np.random.choice(
                        all_irrelevant_ids,
                        min(negative_samples, len(all_irrelevant_ids)),
                        replace=False
                    )
                    for nid in neg_sample_ids:
                        train_examples.append(InputExample(
                            texts=[question_text, self.corpus[nid]],
                            label=0.0
                        ))

            # If this index is in eval set, store queries/relevant_docs
            if idx in eval_indices:
                 queries[qid] = question_text
                 # Ensure only relevant docs present in the corpus are added for evaluation ground truth
                 relevant_docs[qid] = {doc_id: 1 for doc_id in rel_ids if doc_id in self.corpus}
                 # If a query has no ground truth relevant docs in the corpus, exclude it from evaluation
                 if not relevant_docs[qid]:
                      if qid in queries: del queries[qid]
                      if qid in relevant_docs: del relevant_docs[qid]


        self.train_examples = train_examples
        self.queries = queries
        self.relevant_docs = relevant_docs

        logger.info(f"Total training examples: {len(self.train_examples)}")
        logger.info(f"Eval queries: {len(self.queries)}; corpus size: {len(self.corpus)}")


    def train(self,
              epochs=1,
              evaluation_steps=250,
              warmup_steps=200,
              output_path="output/retriever-model"):
        """
        Train the SentenceTransformer model.
        """
        if not self.train_examples:
            raise ValueError("No training examples found. Run prepare_data(...) first.")

        logger.info(f"Training retriever model on {len(self.train_examples)} examples...")
        train_batch_size = 16 # Adjust based on GPU memory
        train_dataloader = DataLoader(self.train_examples, batch_size=train_batch_size, shuffle=True)
        train_loss = losses.MultipleNegativesRankingLoss(self.retriever_model)

        ir_evaluator = None
        if evaluation_steps > 0 and len(self.queries) > 0 and len(self.corpus) > 0 and len(self.relevant_docs) > 0:
            logger.info("Setting up Information Retrieval Evaluator...")
            ir_evaluator = InformationRetrievalEvaluator(
                queries={qid: str(q) for qid, q in self.queries.items()},
                corpus={doc_id: str(doc) for doc_id, doc in self.corpus.items()},
                relevant_docs=self.relevant_docs,
                show_progress_bar=True,
                corpus_chunk_size=100000
            )
        elif evaluation_steps > 0:
             logger.warning(f"Skipping evaluation during training (evaluation_steps > 0) as eval data is incomplete.")
             evaluation_steps = 0


        logger.info("Starting training...")
        self.retriever_model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            evaluator=ir_evaluator,
            epochs=epochs,
            evaluation_steps=evaluation_steps,
            warmup_steps=warmup_steps,
            output_path=output_path,
            show_progress_bar=True
        )
        logger.info("Retriever training complete!")


    def fit(self,
            train_df: pd.DataFrame,
            negative_samples=3,
            eval_ratio=0.2,
            epochs=1,
            evaluation_steps=250,
            warmup_steps=200,
            output_path="output/retriever-model"):
        """
        Convenience method to prepare data and then train in one shot.
        """
        self.prepare_data(train_df, negative_samples=negative_samples, eval_ratio=eval_ratio)
        self.train(epochs=epochs,
                   evaluation_steps=evaluation_steps,
                   warmup_steps=warmup_steps,
                   output_path=output_path)


    def precompute_corpus_embeddings(self, batch_size=1024):
        """
        Compute and cache corpus embeddings for faster retrieval.
        """
        if not self.corpus_texts:
            raise ValueError("No corpus found. Did you run load_corpus(...) first?")

        logger.info(f"Computing embeddings for {len(self.corpus_texts)} passages...")
        all_embeddings = []

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Encoding device: {device}")
        self.retriever_model.to(device)

        try:
            for start_idx in tqdm(range(0, len(self.corpus_texts), batch_size), desc="Encoding corpus"):
                batch = self.corpus_texts[start_idx:start_idx + batch_size]
                batch = [str(t) for t in batch]
                batch_embeddings = self.retriever_model.encode(
                    batch,
                    convert_to_tensor=True,
                    show_progress_bar=False
                )
                all_embeddings.append(batch_embeddings.cpu())

            self.corpus_embeddings = torch.cat(all_embeddings, dim=0)
            logger.info(f"Corpus embeddings shape: {self.corpus_embeddings.shape}")

        except Exception as e:
            logger.error(f"Error during corpus encoding: {e}")
            self.corpus_embeddings = None
            raise


    def re_rank_passages(self, query, initial_results):
        """
        Re-ranks an initial list of retrieved passages using the cross-encoder model.
        """
        if not initial_results: return []
        logger.info(f"Re-ranking {len(initial_results)} passages for query...")
        cross_encoder_input = [[str(query), str(res['passage'])] for res in initial_results]
        if not cross_encoder_input: return []

        try:
            rerank_scores = self.reranker_model.predict(cross_encoder_input)
            for i, res in enumerate(initial_results):
                res['rerank_score'] = float(rerank_scores[i])
            reranked_results = sorted(initial_results, key=lambda x: x.get('rerank_score', -float('inf')), reverse=True)
            logger.info("Re-ranking complete.")
            return reranked_results
        except Exception as e:
            logger.error(f"Error during re-ranking: {e}")
            logger.warning("Re-ranking failed. Returning initial results sorted by original score.")
            return sorted(initial_results, key=lambda x: x.get('score', -float('inf')), reverse=True)


    def retrieve_top_k(self, query, top_k=5, initial_retrieval_k=100, use_reranking=True):
        """
        Retrieve top_k passages for a single query string, optionally using re-ranking.
        """
        if self.corpus_embeddings is None:
            logger.info("Corpus embeddings not precomputed. Computing now.")
            self.precompute_corpus_embeddings()
            if self.corpus_embeddings is None:
                 logger.error("Failed to compute corpus embeddings. Cannot perform retrieval.")
                 return []

        device = self.corpus_embeddings.device if self.corpus_embeddings.device.type != 'cpu' else ("cuda" if torch.cuda.is_available() else "cpu")
        query_embedding = self.retriever_model.encode(str(query), convert_to_tensor=True, device=device)

        cos_scores = util.cos_sim(query_embedding, self.corpus_embeddings)[0]

        initial_k = initial_retrieval_k if use_reranking else top_k
        initial_k = min(initial_k, len(self.corpus_texts))
        initial_k = max(1, initial_k) if len(self.corpus_texts) > 0 else 0

        if initial_k == 0:
             logger.warning("Corpus is empty or initial_k is 0. Cannot retrieve.")
             return []

        top_values, top_indices = torch.topk(cos_scores, k=initial_k)
        top_indices_list = top_indices.cpu().numpy()

        initial_retrieved_passages = []
        corpus_ids_str = [str(cid) for cid in self.corpus_ids]

        for idx in top_indices_list:
            cid = corpus_ids_str[idx]
            passage = self.corpus.get(cid, "")
            initial_retrieved_passages.append({
                'corpus_id': cid,
                'passage': passage,
                'score': cos_scores[idx].item()
            })

        if use_reranking:
            reranked_results = self.re_rank_passages(query, initial_retrieved_passages)
            return reranked_results[:top_k]
        else:
            return sorted(initial_retrieved_passages, key=lambda x: x.get('score', -float('inf')), reverse=True)[:top_k]


    def retrieve_for_test(self, test_df, top_k=5, initial_retrieval_k=100, use_reranking=True):
        """
        Retrieves top_k passages for each question in a test DataFrame.
        """
        if self.corpus_embeddings is None:
            logger.info("Corpus embeddings not precomputed. Computing now.")
            self.precompute_corpus_embeddings()
            if self.corpus_embeddings is None:
                 logger.error("Failed to compute corpus embeddings. Cannot perform retrieval for test set.")
                 result_df = test_df.copy()
                 result_df['relevant_passage_ids'] = [[]] * len(test_df)
                 return result_df

        if test_df.empty:
             logger.warning("Test DataFrame is empty. Skipping retrieval for test set.")
             return test_df.copy().assign(relevant_passage_ids=[[] for _ in range(len(test_df))]) # Add empty column


        queries = test_df['question'].astype(str).tolist()

        device = self.corpus_embeddings.device if self.corpus_embeddings.device.type != 'cpu' else ("cuda" if torch.cuda.is_available() else "cpu")
        query_embeddings = self.retriever_model.encode(queries, convert_to_tensor=True, device=device, show_progress_bar=True)

        cos_scores = util.cos_sim(query_embeddings, self.corpus_embeddings)

        relevant_passage_ids = []

        initial_k = initial_retrieval_k if use_reranking else top_k
        initial_k = min(initial_k, len(self.corpus_texts))
        initial_k = max(1, initial_k) if len(self.corpus_texts) > 0 else 0

        if initial_k == 0:
             logger.warning("Corpus is empty or initial_k is 0. Cannot retrieve for test set.")
             relevant_passage_ids = [[]] * len(test_df)
        else:
            corpus_ids_str = [str(cid) for cid in self.corpus_ids]

            for i in tqdm(range(len(test_df)), desc="Retrieving for test set"):
                row_scores = cos_scores[i]
                top_values, top_indices = torch.topk(row_scores, k=initial_k)
                top_indices_list = top_indices.cpu().numpy()

                initial_retrieved_passages = []
                for idx in top_indices_list:
                     cid = corpus_ids_str[idx]
                     passage = self.corpus.get(cid, "")
                     initial_retrieved_passages.append({
                         'corpus_id': cid,
                         'passage': passage,
                         'score': row_scores[idx].item()
                     })

                if use_reranking:
                    query_text = queries[i]
                    reranked_results = self.re_rank_passages(query_text, initial_retrieved_passages)
                    top_k_ids = [res['corpus_id'] for res in reranked_results[:top_k]]
                else:
                    top_k_ids = [res['corpus_id'] for res in sorted(initial_retrieved_passages, key=lambda x: x.get('score', -float('inf')), reverse=True)[:top_k]]

                relevant_passage_ids.append(top_k_ids)

        result_df = test_df.copy()
        result_df['relevant_passage_ids'] = relevant_passage_ids
        return result_df


    def evaluate(self, test_df, top_k=5, metrics_k=10, initial_retrieval_k=100, use_reranking=True):
        """
        Evaluate the retriever using IR metrics.
        """
        if metrics_k > top_k:
             logger.warning(f"metrics_k ({metrics_k}) is greater than top_k ({top_k}). Metrics will be calculated based on top_{top_k} results.")
             metrics_k = top_k
        if metrics_k <= 0:
             logger.warning("metrics_k is <= 0. Skipping evaluation.")
             return {}
        if test_df.empty or 'relevant_passage_ids' not in test_df.columns:
             logger.warning("Test DataFrame is empty or missing 'relevant_passage_ids' for evaluation. Skipping evaluation.")
             return { f"recall@{metrics_k}": 0.0, f"precision@{metrics_k}": 0.0, "mrr": 0.0 }


        logger.info(f"Starting evaluation with metrics_k={metrics_k} (top_k={top_k}, use_reranking={use_reranking})")

        # Ensure ground truth relevant_passage_ids are lists of strings
        test_df['relevant_passage_ids'] = test_df['relevant_passage_ids'].apply(lambda x: [str(i).strip() for i in (eval(str(x)) if isinstance(x, str) else x) if str(i).strip()] if pd.notna(x) else []
        )

        # Retrieve predicted top_k for each query
        retrieved_df = self.retrieve_for_test(test_df[['id', 'question']].copy(),
                                              top_k=top_k,
                                              initial_retrieval_k=initial_retrieval_k,
                                              use_reranking=use_reranking)

        results_dict = {}
        relevant_dict = {}

        test_df['id'] = test_df['id'].astype(str)
        retrieved_df['id'] = retrieved_df['id'].astype(str)

        predicted_ids_map = dict(zip(retrieved_df['id'], retrieved_df['relevant_passage_ids']))

        common_query_ids = set(test_df['id']).intersection(set(retrieved_df['id']))
        if not common_query_ids:
             logger.warning("No common query IDs between original test_df and retrieved_df. Cannot compute metrics.")
             return { f"recall@{metrics_k}": 0.0, f"precision@{metrics_k}": 0.0, "mrr": 0.0 }

        filtered_test_df = test_df[test_df['id'].isin(common_query_ids)].set_index('id')
        filtered_retrieved_df = retrieved_df[retrieved_df['id'].isin(common_query_ids)].set_index('id')


        for qid in tqdm(common_query_ids, desc="Processing queries for metrics"):
            predicted_list = filtered_retrieved_df.loc[qid, 'relevant_passage_ids']
            results_dict[qid] = predicted_list[:metrics_k]

            true_ids = filtered_test_df.loc[qid, 'relevant_passage_ids']
            relevant_dict[qid] = {doc_id: 1 for doc_id in true_ids if doc_id in self.corpus}

            if not relevant_dict[qid] and qid in results_dict:
                 del results_dict[qid]

        return self.evaluate_ir_metrics(results_dict, relevant_dict, k=metrics_k)


    def evaluate_ir_metrics(self, results, relevant_docs, k=10):
        """Compute common IR metrics (Recall@k, Precision@k, MRR)."""
        common_qids = set(results.keys()).intersection(set(relevant_docs.keys()))

        if not common_qids:
             logger.warning("Cannot compute IR metrics: No common queries with both predictions and ground truth relevant docs in corpus.")
             return { f"recall@{k}": 0.0, f"precision@{k}": 0.0, "mrr": 0.0 }

        filtered_results = {qid: results[qid] for qid in common_qids}
        filtered_relevant_docs = {qid: relevant_docs[qid] for qid in common_qids}

        recall = self._calculate_recall_at_k(filtered_results, filtered_relevant_docs, k)
        precision = self._calculate_precision_at_k(filtered_results, filtered_relevant_docs, k)
        mrr = self._calculate_mrr(filtered_results, filtered_relevant_docs)

        return {
            f"recall@{k}": recall,
            f"precision@{k}": precision,
            "mrr": mrr
        }

    def _calculate_recall_at_k(self, results, relevant_docs, k):
        recalls = []
        for query_id, retrieved_docs in results.items():
            if query_id in relevant_docs:
                relevant = set(relevant_docs[query_id].keys())
                retrieved = set(retrieved_docs if isinstance(retrieved_docs, list) else [])
                if len(relevant) > 0:
                    recall = len(relevant.intersection(retrieved)) / len(relevant)
                    recalls.append(recall)
        return sum(recalls) / len(recalls) if recalls else 0.0

    def _calculate_precision_at_k(self, results, relevant_docs, k):
        precisions = []
        for query_id, retrieved_docs in results.items():
            if query_id in relevant_docs:
                relevant = set(relevant_docs[query_id].keys())
                retrieved = retrieved_docs if isinstance(retrieved_docs, list) else []
                if len(retrieved) > 0:
                    precision = len(relevant.intersection(set(retrieved))) / len(retrieved)
                    precisions.append(precision)
                else:
                    precisions.append(0.0)
        return sum(precisions) / len(precisions) if precisions else 0.0

    def _calculate_mrr(self, results, relevant_docs):
        mrr_scores = []
        for query_id, retrieved_docs in results.items():
            if query_id in relevant_docs:
                 relevant = set(relevant_docs[query_id].keys())
                 if isinstance(retrieved_docs, list):
                     found_relevant = False
                     for i, doc_id in enumerate(retrieved_docs):
                         if doc_id in relevant:
                             mrr_scores.append(1.0 / (i + 1))
                             found_relevant = True
                             break
                     if not found_relevant:
                         mrr_scores.append(0.0)
                 else:
                      mrr_scores.append(0.0)
        return sum(mrr_scores) / len(mrr_scores) if mrr_scores else 0.0

3. Generator

In [8]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: fineGrained).
The token `tewelf` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `tewelf`


In [10]:
class Generator:
    """
    A class for training, evaluating, and predicting with a text generation (causal LM) model.
    """

    def __init__(self, model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
        logger.info(f"Loading model and tokenizer for {model_name}")
        # Added trust_remote_code=True for potential custom code in some models
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16, # Use bfloat16 for better training stability if supported
            bnb_4bit_use_double_quant=True,
        )
        # Added trust_remote_code=True for potential custom code in some models
        self.model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, trust_remote_code=True)

        if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.tokenizer.eos_token_id

        self.doc_dict = {}
        self.qa_df = None # Stores prepared training/eval data for generator
        self.tokenized_dataset = None

    def build_doc_dict(self, documents_df):
        logger.info("Building document dictionary from DataFrame...")
        documents_df["id"] = documents_df["id"].astype(str)
        documents_df["passage"] = documents_df["passage"].fillna("").astype(str)
        self.doc_dict = dict(zip(documents_df["id"], documents_df["passage"]))
        logger.info(f"Built doc_dict with {len(self.doc_dict)} documents.")

    def create_prompt(self, question, context, answer=None):
        """Build a prompt string without any truncation."""
        # Simple prompt template for RAG
        if answer is not None:
            prompt = f"Question: {question}\nContext: {context}\nAnswer: {answer}"
        else:
            prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        return prompt

    def build_truncated_prompt(self, question, context, max_new_tokens):
        """
        Build a prompt ensuring its tokenized length leaves space for max_new_tokens.
        Truncates the context if necessary.
        """
        question = str(question)
        context = str(context)

        max_positions = self.model.config.max_position_embeddings
        allowed_prompt_length = max_positions - max_new_tokens

        q_part_template = f"Question: {question}\nContext: "
        a_part_template = "\nAnswer:"

        # Use encode for more accurate token length calculation including special tokens if any are added by encode
        q_tokens_len = len(self.tokenizer.encode(q_part_template, add_special_tokens=False))
        a_tokens_len = len(self.tokenizer.encode(a_part_template, add_special_tokens=False))

        fixed_length = q_tokens_len + a_tokens_len

        allowed_for_context = max(0, allowed_prompt_length - fixed_length)

        if not isinstance(context, str) or not context.strip():
             truncated_context = ""
        else:
             context_tokens = self.tokenizer.encode(context, add_special_tokens=False)
             truncated_context_tokens = context_tokens[:allowed_for_context]
             truncated_context = self.tokenizer.decode(truncated_context_tokens, skip_special_tokens=True)

        truncated_prompt = f"Question: {question}\nContext: {truncated_context}\nAnswer:"

        final_prompt_length = len(self.tokenizer.encode(truncated_prompt, add_special_tokens=True))

        if final_prompt_length > allowed_prompt_length:
             logger.warning(
                 f"Final truncated prompt length ({final_prompt_length}) still exceeds allowed ({allowed_prompt_length}) "
                 f"before adding generation tokens. Review truncation logic or parameters."
             )

        full_raw_prompt_length = len(self.tokenizer.encode(f"Question: {question}\nContext: {context}\nAnswer:", add_special_tokens=True))
        if final_prompt_length < full_raw_prompt_length:
             logger.debug(
                 f"Prompt truncated from {full_raw_prompt_length} tokens to {final_prompt_length} tokens "
                 f"(allowed prompt length {allowed_prompt_length}). Query: {question[:50]}..."
             )

        return truncated_prompt


    def prepare_training_data(self, train_df):
        """
        Prepare training data by combining questions, context from relevant passages, and answers into prompts.

        Expects train_df to have the following columns:
          - 'question'
          - 'answer'
          - 'relevant_passage_ids' : a list of string IDs (e.g., ["doc_1", "doc_5"]) or a string representation
                                     (These are the GROUND TRUTH context IDs for generator training)
        """
        if not self.doc_dict:
            raise ValueError("Document dictionary is empty. Call build_doc_dict(documents_df) first.")
        if train_df.empty:
             logger.warning("Input training DataFrame is empty. Skipping data preparation.")
             self.qa_df = pd.DataFrame(columns=["question", "relevant_docs", "answer", "prompt"])
             return


        logger.info(f"Preparing training data from DataFrame ({len(train_df)} rows)...")
        records = []
        train_df['relevant_passage_ids'] = train_df['relevant_passage_ids'].apply(
             lambda x: [str(i).strip() for i in (eval(str(x)) if isinstance(x, str) else x) if str(i).strip()] if pd.notna(x) else []
        )
        train_df['question'] = train_df['question'].fillna("").astype(str)
        train_df['answer'] = train_df['answer'].fillna("").astype(str)


        for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Preparing training data"):
            question = row["question"]
            answer = row["answer"]
            doc_ids = row["relevant_passage_ids"] # This is now a list of strings (ground truth context ID)

            relevant_texts = []
            for pid in doc_ids:
                if pid in self.doc_dict:
                    relevant_texts.append(self.doc_dict[pid])
                else:
                    logger.warning(f"Relevant passage ID {pid} not found in doc_dict for question ID {row.get('id', 'N/A')}. Skipping passage for training prompt.")

            combined_passages = " ".join(relevant_texts)

            # Create the prompt including the answer for training
            prompt = self.create_prompt(question, combined_passages, answer)

            records.append({
                "question": question,
                "relevant_docs": combined_passages,
                "answer": answer,
                "prompt": prompt
            })

        self.qa_df = pd.DataFrame(records)
        logger.info(f"Prepared {len(self.qa_df)} training records.")

    def tokenize_training_data(self, max_length=512):
        """
        Convert the prepared training DataFrame into a tokenized Hugging Face Dataset.
        max_length is the sequence length for training (prompt + answer).
        """
        if self.qa_df is None or self.qa_df.empty:
            raise ValueError("Training data not prepared or is empty. Cannot tokenize.")

        logger.info("Converting prepared data to Dataset and tokenizing...")
        dataset = datasets.Dataset.from_pandas(self.qa_df[["prompt"]])

        model_max_length = self.model.config.max_position_embeddings
        if max_length > model_max_length:
            logger.warning(f"Requested max_length ({max_length}) exceeds model max_position_embeddings ({model_max_length}). Using model_max_length.")
            max_length = model_max_length
        if max_length <= 0:
             logger.error("max_length must be positive. Tokenization failed.")
             self.tokenized_dataset = None
             return


        def tokenize_function(examples):
            tokenized_inputs = self.tokenizer(
                examples["prompt"],
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )
            return tokenized_inputs

        self.tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["prompt"])
        logger.info(f"Tokenization complete. Using max_length={max_length}.")


    def train_model(self, output_dir="output/generator-finetuned", num_train_epochs=1, batch_size=4,
                    gradient_accumulation_steps=8, logging_steps=50, learning_rate=2e-5):
        """
        Train (fine-tune) the generator model using the tokenized dataset.
        """
        if self.tokenized_dataset is None or len(self.tokenized_dataset) == 0:
            raise ValueError("Tokenized training data not found or is empty. Cannot train.")

        logger.info("Setting up PEFT/LoRA for training...")

        lora_config = LoraConfig(
            r=64,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # Common Llama targets
        )

        try:
             peft_model = get_peft_model(self.model, lora_config)
             peft_model.print_trainable_parameters()
             logger.info("PEFT model prepared successfully.")
        except Exception as e:
             logger.error(f"Error applying PEFT config: {e}")
             logger.error("Please check target_modules for your specific Llama3 model.")
             raise # Re-raise the error


        logger.info("Setting up training arguments and Trainer...")

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim="paged_adamw_8bit",
            learning_rate=learning_rate,
            fp16=True, # Requires GPU
            logging_steps=logging_steps,
            save_steps=logging_steps * 5,
            save_total_limit=3,
            evaluation_strategy="no", # Disable eval during training
            logging_dir=f"{output_dir}/logs",
            report_to="tensorboard",
            run_name=f"llama3_{output_dir}_run",
            push_to_hub=False,
        )

        trainer = Trainer(
            model=peft_model,
            args=training_args,
            train_dataset=self.tokenized_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
        )

        logger.info("Starting training...")
        trainer.train()
        logger.info("Generator training complete!")

        # Save the LoRA adapters
        peft_model.save_pretrained(output_dir)
        logger.info(f"LoRA adapters saved to {output_dir}")


    def evaluate_generator(self, eval_df, max_eval_samples=100, max_new_tokens=64,
                           do_sample=False, seed=42):
        """
        Evaluate the generator model on a subset of examples using metrics like BLEU/ROUGE.
        Uses GROUND TRUTH contexts from relevant_passage_ids in eval_df.
        """
        if not self.doc_dict:
            raise ValueError("Document dictionary is empty. Ensure build_doc_dict(documents_df) has been called.")
        if eval_df.empty or 'relevant_passage_ids' not in eval_df.columns or 'answer' not in eval_df.columns:
             logger.warning("Evaluation DataFrame is empty or missing required columns ('relevant_passage_ids', 'answer'). Skipping evaluation.")
             return {}

        logger.info(f"Evaluating generator model on {min(len(eval_df), max_eval_samples)} samples...")

        subset_df = eval_df.sample(n=min(len(eval_df), max_eval_samples), random_state=seed).reset_index(drop=True)

        predictions = []
        references = []

        self.model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Generation device: {device}")
        self.model.to(device)

        gen_pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if device == "cuda" else -1,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
        )

        try:
            bleu_metric = evaluate.load("bleu")
            rouge_metric = evaluate.load("rouge")
            logger.info("Evaluation metrics loaded.")
        except Exception as e:
            logger.error(f"Failed to load evaluation metrics: {e}. Evaluation will skip metric calculation.")
            bleu_metric = None
            rouge_metric = None

        with torch.no_grad():
            for i, row in tqdm(subset_df.iterrows(), total=len(subset_df), desc="Generating for Evaluation"):
                question = str(row["question"])
                gold_answer = str(row["answer"])

                relevant_ids_raw = row["relevant_passage_ids"]
                doc_ids = []
                if isinstance(relevant_ids_raw, list):
                     doc_ids = [str(pid).strip() for pid in relevant_ids_raw if str(pid).strip()]
                elif isinstance(relevant_ids_raw, str):
                      try: doc_ids = [str(pid).strip() for pid in eval(relevant_ids_raw) if str(pid).strip()]
                      except (SyntaxError, NameError, TypeError): doc_ids = [pid.strip() for pid in relevant_ids_raw.strip().split(",") if pid.strip()]
                doc_ids = [d for d in doc_ids if d]

                relevant_texts = []
                for pid in doc_ids:
                    if pid in self.doc_dict:
                        relevant_texts.append(self.doc_dict[pid])
                    else:
                        logger.warning(f"Eval passage ID {pid} not found in doc_dict for question ID {row.get('id', 'N/A')}. Skipping passage.")

                combined_context = " ".join(relevant_texts)

                prompt = self.build_truncated_prompt(question, combined_context, max_new_tokens)

                try:
                    gen_output = gen_pipe(
                        prompt,
                        max_new_tokens=max_new_tokens,
                        num_return_sequences=1,
                        do_sample=do_sample,
                        temperature=0.7 if do_sample else 1.0,
                        top_k=50 if do_sample else None,
                        top_p=0.95 if do_sample else None,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        return_full_text=False
                    )
                    if gen_output and isinstance(gen_output, list) and len(gen_output) > 0 and 'generated_text' in gen_output[0]:
                         pred_answer = gen_output[0]['generated_text'].strip()
                    else:
                         pred_answer = ""
                         logger.warning(f"Pipeline returned unexpected output for prompt starting '{prompt[:50]}...'. Output: {gen_output}")

                except Exception as e:
                    logger.error(f"Error during generation for prompt starting '{prompt[:50]}...': {e}")
                    pred_answer = ""


                predictions.append(pred_answer)
                references.append(gold_answer)

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        evaluation_results = {}
        if bleu_metric:
            references_formatted = [[ans] for ans in references]
            try:
                bleu_score = bleu_metric.compute(predictions=predictions, references=references_formatted)
                evaluation_results.update({f"bleu_{k}": v for k,v in bleu_score.items()})
                logger.info(f"BLEU score: {bleu_score.get('bleu', 'N/A'):.4f}")
            except Exception as e:
                 logger.error(f"Error computing BLEU: {e}")

        if rouge_metric:
             try:
                  rouge_scores = rouge_metric.compute(predictions=predictions, references=references)
                  evaluation_results.update(rouge_scores)
                  logger.info(f"ROUGE scores: {rouge_scores}")
             except Exception as e:
                  logger.error(f"Error computing ROUGE: {e}")

        return evaluation_results, predictions, references


    def predict(self, test_df, max_new_tokens=64, do_sample=False, seed=42):
        """
        Generate answers for a test DataFrame containing:
          - 'id', 'question', and 'relevant_passage_ids' (predicted by the retriever)
        The generated answers are saved in a new column "predicted_answer".
        """
        if not self.doc_dict:
            raise ValueError("Document dictionary is empty. Ensure build_doc_dict(documents_df) has been called.")
        if test_df.empty or 'relevant_passage_ids' not in test_df.columns:
             logger.warning("Test DataFrame is empty or missing 'relevant_passage_ids'. Skipping prediction.")
             return test_df.copy().assign(predicted_answer=[""] * len(test_df))


        logger.info(f"Generating answers for test data ({len(test_df)} rows)...")

        results_df = test_df.copy()
        predictions = []

        self.model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Generation device: {device}")
        self.model.to(device)

        gen_pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if device == "cuda" else -1,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
        )

        results_df['relevant_passage_ids'] = results_df['relevant_passage_ids'].apply(
             lambda x: [str(i).strip() for i in (eval(str(x)) if isinstance(x, str) else x) if str(i).strip()] if pd.notna(x) else []
        )
        results_df['question'] = results_df['question'].fillna("").astype(str)


        with torch.no_grad():
            for i, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Generating answers"):
                question = row["question"]
                doc_ids = row["relevant_passage_ids"] # Predicted IDs from retriever

                relevant_texts = []
                for pid in doc_ids:
                    if pid in self.doc_dict:
                        relevant_texts.append(self.doc_dict[pid])
                    else:
                        logger.warning(f"Predicted passage ID {pid} not found in doc_dict for question ID {row.get('id', 'N/A')}. Skipping passage.")

                combined_context = " ".join(relevant_texts)

                prompt = self.build_truncated_prompt(question, combined_context, max_new_tokens)

                try:
                    gen_output = gen_pipe(
                        prompt,
                        max_new_tokens=max_new_tokens,
                        num_return_sequences=1,
                        do_sample=do_sample,
                        temperature=0.7 if do_sample else 1.0,
                        top_k=50 if do_sample else None,
                        top_p=0.95 if do_sample else None,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        return_full_text=False
                    )
                    if gen_output and isinstance(gen_output, list) and len(gen_output) > 0 and 'generated_text' in gen_output[0]:
                         pred_answer = gen_output[0]['generated_text'].strip()
                    else:
                         pred_answer = ""
                         logger.warning(f"Pipeline returned unexpected output for prompt starting '{prompt[:50]}...'. Output: {gen_output}")

                except Exception as e:
                    logger.error(f"Error during generation for prompt starting '{prompt[:50]}...': {e}")
                    pred_answer = ""

                predictions.append(pred_answer)

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        results_df["predicted_answer"] = predictions
        logger.info("Answer generation complete.")
        return results_df

4. Encapsulation

In [11]:
class RAGSystem:
    """
    A complete RAG QnA system combining a Retriever and a Generator.
    """
    def __init__(self, retriever_model_path=None, generator_model_path=None,
                 retriever_base_model="all-mpnet-base-v2", reranker_base_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
                 generator_base_model="meta-llama/Meta-Llama-3-8B-Instruct"):
        """
        Initializes the RAG system.

        Args:
            retriever_model_path (str, optional): Path to the fine-tuned retriever model.
                                                  If None, loads the base model.
            generator_model_path (str, optional): Path to the fine-tuned generator (LoRA adapters).
                                                  If None, loads the base model (requires PEFT).
            retriever_base_model (str): Base model name for the retriever.
            reranker_base_model (str): Base model name for the reranker.
            generator_base_model (str): Base model name for the generator.
        """
        logger.info("Initializing RAG System...")

        # Initialize Retriever
        # If a fine-tuned model path is provided, load it. Otherwise, use the base model.
        if retriever_model_path and os.path.exists(retriever_model_path):
             logger.info(f"Loading fine-tuned retriever model from {retriever_model_path}")
             # SentenceTransformer can load from a directory path
             self.retriever_model_instance = SentenceTransformer(retriever_model_path)
             # Need to re-initialize Retriever class to use this loaded model
             self.retriever = Retriever(retriever_model_name=retriever_model_path, reranker_model_name=reranker_base_model)
             # However, the Retriever class init always loads from name.
             # Let's modify RAGSystem to hold the models directly or pass them.
             # A cleaner approach is to pass the loaded models to the Retriever/Generator instances.

             # Let's re-structure RAGSystem slightly to hold the loaded models
             logger.info("Loading base retriever model for RAGSystem...")
             base_retriever_model = SentenceTransformer(retriever_base_model)
             if retriever_model_path and os.path.exists(retriever_model_path):
                 # Attempt to load the fine-tuned model over the base
                 try:
                     # SentenceTransformer load from path replaces the internal model
                     base_retriever_model.load(retriever_model_path)
                     logger.info(f"Loaded fine-tuned retriever weights from {retriever_model_path}")
                 except Exception as e:
                     logger.error(f"Failed to load fine-tuned retriever model from {retriever_model_path}: {e}")
                     logger.warning("Proceeding with base retriever model only.")
             self.retriever = Retriever(retriever_base_model, reranker_base_model) # Use base names in init
             # Assign the potentially fine-tuned model instance to the retriever
             self.retriever.retriever_model = base_retriever_model # Assign the loaded model instance

        else:
             logger.info(f"Loading base retriever model {retriever_base_model}")
             self.retriever = Retriever(retriever_base_model, reranker_base_model)


        # Initialize Generator
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        logger.info(f"Loading base generator model {generator_base_model} with quantization...")
        base_generator_model = AutoModelForCausalLM.from_pretrained(
            generator_base_model,
            quantization_config=bnb_config,
            trust_remote_code=True,
            device_map="auto"
        )
        base_tokenizer = AutoTokenizer.from_pretrained(generator_base_model, trust_remote_code=True)
        if base_tokenizer.pad_token is None:
             base_tokenizer.pad_token = base_tokenizer.eos_token
        base_generator_model.config.pad_token_id = base_tokenizer.eos_token_id


        # Load fine-tuned LoRA adapters if path is provided and exists
        if generator_model_path and os.path.exists(generator_model_path):
             logger.info(f"Loading fine-tuned generator adapters from {generator_model_path}")
             try:
                 from peft import PeftModel
                 self.generator_model = PeftModel.from_pretrained(base_generator_model, generator_model_path)
                 logger.info("LoRA adapters loaded and merged with base model.")
             except Exception as e:
                 logger.error(f"Failed to load PEFT adapters from {generator_model_path}: {e}")
                 logger.warning("Proceeding with base generator model only.")
                 self.generator_model = base_generator_model
        else:
             logger.info("No fine-tuned generator path provided or path does not exist. Using base generator model.")
             self.generator_model = base_generator_model

        self.generator_tokenizer = base_tokenizer


        # Placeholder for corpus documents (needed by Retriever and Generator's build_doc_dict)
        self.documents_df = None
        # Retriever's corpus is loaded via self.retriever.load_corpus()
        # Generator needs a doc_dict internally for predict method.
        # Let's add a doc_dict to RAGSystem and pass it or manage passage text retrieval here.
        self.corpus_doc_dict = {} # RAGSystem will hold the doc_dict


    def load_corpus(self, documents_df):
        """
        Load the corpus documents into the RAG system.
        This populates both the Retriever's corpus and the RAGSystem's internal doc_dict.
        """
        logger.info("Loading corpus into RAG system...")
        self.documents_df = documents_df.copy()
        self.retriever.load_corpus(self.documents_df) # Load into Retriever
        # Build doc_dict for RAGSystem using the same corpus
        self.corpus_doc_dict = dict(zip(self.documents_df["id"].astype(str), self.documents_df["passage"].fillna("").astype(str)))
        logger.info("Corpus loaded into RAG system's internal storage.")


    def precompute_retriever_embeddings(self, batch_size=1024):
        """
        Precompute corpus embeddings for the retriever.
        Must be called after load_corpus.
        """
        if self.documents_df is None or self.documents_df.empty:
            raise ValueError("Corpus not loaded. Call load_corpus(documents_df) first.")
        self.retriever.precompute_corpus_embeddings(batch_size=batch_size)


    def answer_question(self, question: str, top_k_retrieval=5, initial_retrieval_k=100, use_reranking=True,
                        max_new_tokens=100, do_sample=False):
        """
        Answers a question using the RAG pipeline.

        Args:
            question (str): The user's question.
            top_k_retrieval (int): Number of top passages to retrieve after potential re-ranking.
            initial_retrieval_k (int): Number of candidates for initial retrieval before re-ranking.
            use_reranking (bool): Whether to use the cross-encoder for re-ranking.
            max_new_tokens (int): Maximum tokens for the generator to produce.
            do_sample (bool): Whether to use sampling for generation.

        Returns:
            str: The generated answer.
        """
        if self.retriever.corpus_embeddings is None:
            logger.warning("Retriever embeddings not precomputed. Computing now...")
            try:
                self.precompute_retriever_embeddings()
            except Exception as e:
                logger.error(f"Failed to precompute embeddings: {e}. Cannot perform retrieval.")
                return "Error: Could not initialize retriever."

        if not self.corpus_doc_dict:
             logger.error("Corpus document dictionary not loaded. Cannot retrieve passage texts.")
             return "Error: Corpus not loaded."

        logger.info(f"Answering question: '{question}'")

        # 1. Retrieval
        logger.info(f"Retrieving top {top_k_retrieval} passages...")
        retrieved_results = self.retriever.retrieve_top_k(
            question,
            top_k=top_k_retrieval,
            initial_retrieval_k=initial_retrieval_k,
            use_reranking=use_reranking
        )

        if not retrieved_results:
            logger.warning("No passages retrieved.")
            return "Could not find relevant information to answer the question."

        # 2. Prepare Context for Generator
        # Combine the text of the retrieved passages using the RAGSystem's doc_dict
        combined_context = " ".join([self.corpus_doc_dict.get(res.get('corpus_id', ''), '') for res in retrieved_results])

        if not combined_context.strip():
             logger.warning("Retrieved passages are empty or contain no text.")
             # Optionally return a message indicating no relevant info found even if IDs were retrieved
             # Or proceed with empty context to see if the LLM can answer from general knowledge
             # For now, let's proceed with empty context if it's just whitespace
             if not combined_context:
                 logger.warning("Combined context is empty.")


        logger.info(f"Combined context length: {len(combined_context)} characters.")
        # logger.debug(f"Retrieved passages (first 100 chars): {[self.corpus_doc_dict.get(res.get('corpus_id', ''), '')[:100] + '...' for res in retrieved_results]}")


        # 3. Generation
        logger.info("Generating answer...")

        # Use the RAGSystem's generator model and tokenizer
        model = self.generator_model
        tokenizer = self.generator_tokenizer
        max_positions = model.config.max_position_embeddings

        # Build a prompt that is safely truncated for generation
        max_new_tokens_gen = max_new_tokens # Use the parameter passed to answer_question
        allowed_prompt_length = max_positions - max_new_tokens_gen

        q_part_template = f"Question: {question}\nContext: "
        a_part_template = "\nAnswer:"

        q_tokens_len = len(tokenizer.encode(q_part_template, add_special_tokens=False))
        a_tokens_len = len(tokenizer.encode(a_part_template, add_special_tokens=False))

        fixed_length = q_tokens_len + a_tokens_len
        allowed_for_context = max(0, allowed_prompt_length - fixed_length)

        context_to_truncate = combined_context # Use the combined text from retrieved passages

        if not isinstance(context_to_truncate, str) or not context_to_truncate.strip():
             truncated_context = ""
        else:
             context_tokens = tokenizer.encode(context_to_truncate, add_special_tokens=False)
             truncated_context_tokens = context_tokens[:allowed_for_context]
             truncated_context = tokenizer.decode(truncated_context_tokens, skip_special_tokens=True)

        prompt = f"Question: {question}\nContext: {truncated_context}\nAnswer:"

        # Verify final prompt length
        final_prompt_length = len(tokenizer.encode(prompt, add_special_tokens=True))
        if final_prompt_length > allowed_prompt_length:
             logger.warning(f"Final prompt length ({final_prompt_length}) exceeds allowed ({allowed_prompt_length}) before generation.")


        # Use a pipeline for generation
        try:
            model.eval()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)

            gen_pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=0 if device == "cuda" else -1,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
            )

            gen_output = gen_pipe(
                prompt,
                max_new_tokens=max_new_tokens_gen,
                num_return_sequences=1,
                do_sample=do_sample,
                temperature=0.7 if do_sample else 1.0,
                top_k=50 if do_sample else None,
                top_p=0.95 if do_sample else None,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                return_full_text=False
            )

            if gen_output and isinstance(gen_output, list) and len(gen_output) > 0 and 'generated_text' in gen_output[0]:
                 generated_text = gen_output[0]['generated_text'].strip()
            else:
                 generated_text = ""
                 logger.warning(f"Pipeline returned unexpected output for prompt starting '{prompt[:50]}...'. Output: {gen_output}")

            if torch.cuda.is_available():
                 torch.cuda.empty_cache()

            final_answer = generated_text

            logger.info(f"Generated answer: {final_answer}")
            return final_answer

        except Exception as e:
            logger.error(f"Error during generation: {e}")
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            return "Error: Could not generate answer."



5. Main

In [None]:
if __name__ == "__main__":
    # Define file paths
    train_json_file = 'train_data.json'
    dev_json_file = 'dev_data.json'
    retriever_output_dir = "output/retriever-model"
    generator_output_dir = "output/generator-finetuned"

    # Ensure output directories exist
    os.makedirs(retriever_output_dir, exist_ok=True)
    os.makedirs(generator_output_dir, exist_ok=True)

    # 1. Data Processing
    logger.info("--- Starting Data Processing ---")
    documents_df, retriever_train_df, retriever_test_df, generator_train_df, generator_dev_df = process_squad_json(train_json_file, dev_json_file)

    if documents_df.empty or retriever_train_df.empty or generator_train_df.empty:
        logger.error("Data processing failed or resulted in empty essential DataFrames. Exiting.")
        sys.exit(1)

    logger.info("--- Data Processing Complete ---")

    # 2. Train Retriever
    logger.info("--- Starting Retriever Training ---")
    retriever = Retriever(retriever_model_name="all-mpnet-base-v2", reranker_model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
    retriever.load_corpus(documents_df)
    retriever.fit(retriever_train_df,
                  negative_samples=3,
                  eval_ratio=0.1, # Use a portion of train data for retriever eval during training
                  epochs=1, # Adjust epochs
                  evaluation_steps=500, # Evaluate during training
                  warmup_steps=200,
                  output_path=retriever_output_dir)

    # Evaluate Retriever on the dedicated test set (dev data)
    if not retriever_test_df.empty:
        logger.info("Evaluating trained Retriever on test data...")
        retriever_eval_metrics = retriever.evaluate(retriever_test_df,
                                                    top_k=10,
                                                    metrics_k=5, # Evaluate Recall/Precision/MRR at 5
                                                    initial_retrieval_k=100,
                                                    use_reranking=True)
        logger.info(f"Retriever test evaluation metrics: {retriever_eval_metrics}")
    else:
        logger.warning("Retriever test data is empty. Skipping final retriever evaluation.")

    logger.info("--- Retriever Training Complete ---")

    # Clean up Retriever model from GPU memory if needed before loading Generator
    if torch.cuda.is_available():
        # Ensure the model is on CPU before deleting if it was moved to GPU
        retriever.retriever_model.cpu()
        if hasattr(retriever, 'reranker_model'):
             try: retriever.reranker_model.model.to('cpu') # Access the underlying model if possible
             except: pass # Ignore if it fails
        del retriever
        torch.cuda.empty_cache()
        # torch.cuda.synchronize()


    # 3. Train Generator
    logger.info("--- Starting Generator Training ---")
    # You might need to log in to Hugging Face if using a gated model like Llama 3
    # from huggingface_hub import login
    # login() # Or set HUGGING_FACE_HUB_TOKEN environment variable

    generator = Generator(model_name="meta-llama/Meta-Llama-3-8B-Instruct") # Requires accepting terms on HF
    generator.build_doc_dict(documents_df) # Generator needs doc_dict to get context text for training prompts

    # Prepare and tokenize generator training data
    logger.info("Preparing and tokenizing generator training data...")
    generator.prepare_training_data(generator_train_df)
    # Adjust max_length based on your GPU memory and Llama3 context window (8192)
    max_seq_length_for_training = 1024 # Example, adjust as needed
    generator.tokenize_training_data(max_length=max_seq_length_for_training)

    # Train the Generator model (LoRA fine-tuning)
    logger.info("Starting Generator training...")
    generator.train_model(
        output_dir=generator_output_dir,
        num_train_epochs=1, # Adjust epochs
        batch_size=1, # Adjust batch size per device
        gradient_accumulation_steps=8, # Accumulate gradients
        logging_steps=100,
        learning_rate=5e-5
    )

    # Evaluate Generator on the dedicated dev set (using ground truth contexts)
    if not generator_dev_df.empty:
         logger.info("Evaluating trained Generator on dev data (using ground truth contexts)...")
         # The evaluate_generator method expects relevant_passage_ids in the eval_df
         eval_metrics, generated_answers_eval, ground_truth_answers_eval = generator.evaluate_generator(
             generator_dev_df,
             max_eval_samples=500, # Evaluate on a subset
             max_new_tokens=100,
             do_sample=False
         )
         logger.info(f"Generator dev evaluation metrics: {eval_metrics}")
         # Optionally save evaluation results if needed
    else:
        logger.warning("Generator dev data is empty. Skipping generator evaluation.")


    logger.info("--- Generator Training Complete ---")

    # Clean up Generator model from GPU memory if needed
    if torch.cuda.is_available():
        # Ensure the model is on CPU before deleting
        generator.model.to('cpu')
        if hasattr(generator, 'peft_model'): # If PEFT model exists
            try: generator.peft_model.to('cpu')
            except: pass
        del generator
        torch.cuda.empty_cache()
        # torch.cuda.synchronize()


    # 4. Encapsulation and Inference
    logger.info("--- Setting up RAG System for Inference ---")

    # Instantiate the RAGSystem with paths to the trained models
    # Note: The RAGSystem loads the base models and then the adapters/checkpoints.
    # Ensure the paths point to the directories containing the saved models/adapters.
    rag_system = RAGSystem(
        retriever_model_path=retriever_output_dir, # Path to trained retriever model
        generator_model_path=generator_output_dir, # Path to trained generator LoRA adapters
        retriever_base_model="all-mpnet-base-v2", # Base models are needed for loading
        reranker_base_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        generator_base_model="meta-llama/Meta-Llama-3-8B-Instruct"
    )

    # Load the corpus into the RAGSystem
    rag_system.load_corpus(documents_df)

    # Precompute retriever embeddings for faster inference
    logger.info("Precomputing retriever embeddings for RAG inference...")
    rag_system.precompute_retriever_embeddings(batch_size=1024) # Adjust batch size


    # --- Example Usage: Answer a question ---
    logger.info("--- Testing RAG System with Example Questions ---")

    example_questions = [
        "What is the capital of France?",
        "When was the first manned mission to the Moon?",
        "Tell me about the history of the internet.",
        "What is the airspeed velocity of an unladen swallow?", # Example of potentially unanswerable question
        "Who was the first president of the United States?" # Example of a question likely answerable from SQuAD
    ]

    for question in example_questions:
        logger.info(f"\nQuestion: {question}")
        answer = rag_system.answer_question(
            question,
            top_k_retrieval=5, # Retrieve top 5 passages
            initial_retrieval_k=50, # Initially retrieve 50 before reranking
            use_reranking=True, # Use reranking
            max_new_tokens=100, # Max tokens for generator
            do_sample=False # Use greedy decoding for predictable answers
        )
        logger.info(f"Answer: {answer}")
        print("-" * 50) # Separator


    logger.info("--- RAG System Setup and Testing Complete ---")



2025-04-28 03:06:42,844 - INFO - --- Starting Data Processing ---


INFO:__main__:--- Starting Data Processing ---


2025-04-28 03:06:42,845 - INFO - Loading and processing SQuAD 2.0 data from local files: train_data.json, dev_data.json


INFO:__main__:Loading and processing SQuAD 2.0 data from local files: train_data.json, dev_data.json


2025-04-28 03:06:44,036 - INFO - Building corpus and context map...


INFO:__main__:Building corpus and context map...


Processing articles (train corpus):   0%|          | 0/442 [00:00<?, ?it/s]

Processing articles (dev corpus):   0%|          | 0/35 [00:00<?, ?it/s]

2025-04-28 03:06:44,103 - INFO - Created corpus with 20233 unique passages.


INFO:__main__:Created corpus with 20233 unique passages.


2025-04-28 03:06:44,104 - INFO - Processing questions for Retriever and Generator DataFrames...


INFO:__main__:Processing questions for Retriever and Generator DataFrames...


Processing questions (train):   0%|          | 0/442 [00:00<?, ?it/s]

Processing questions (dev):   0%|          | 0/35 [00:00<?, ?it/s]

2025-04-28 03:06:44,955 - INFO - Created retriever_train_df with 130319 questions.


INFO:__main__:Created retriever_train_df with 130319 questions.


2025-04-28 03:06:44,956 - INFO - Created retriever_test_df with 11873 questions.


INFO:__main__:Created retriever_test_df with 11873 questions.


2025-04-28 03:06:44,958 - INFO - Created generator_train_df with 130319 questions.


INFO:__main__:Created generator_train_df with 130319 questions.


2025-04-28 03:06:44,959 - INFO - Created generator_dev_df with 11873 questions.


INFO:__main__:Created generator_dev_df with 11873 questions.


2025-04-28 03:06:44,989 - INFO - --- Data Processing Complete ---


INFO:__main__:--- Data Processing Complete ---


2025-04-28 03:06:44,990 - INFO - --- Starting Retriever Training ---


INFO:__main__:--- Starting Retriever Training ---


2025-04-28 03:06:44,991 - INFO - Initializing retriever model: all-mpnet-base-v2


INFO:__main__:Initializing retriever model: all-mpnet-base-v2
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

2025-04-28 03:06:59,792 - INFO - Initializing reranker model: cross-encoder/ms-marco-MiniLM-L-6-v2


INFO:__main__:Initializing reranker model: cross-encoder/ms-marco-MiniLM-L-6-v2


config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

2025-04-28 03:07:07,316 - INFO - Loading corpus documents...


INFO:__main__:Loading corpus documents...


2025-04-28 03:07:07,331 - INFO - Corpus size: 20233 documents.


INFO:__main__:Corpus size: 20233 documents.


2025-04-28 03:07:07,332 - INFO - Preparing training data from DataFrame (130319 rows)...


INFO:__main__:Preparing training data from DataFrame (130319 rows)...


Preparing data:   0%|          | 0/130319 [00:00<?, ?it/s]

2025-04-28 03:16:40,633 - INFO - Total training examples: 521276


INFO:__main__:Total training examples: 521276


2025-04-28 03:16:40,634 - INFO - Eval queries: 13031; corpus size: 20233


INFO:__main__:Eval queries: 13031; corpus size: 20233


2025-04-28 03:16:40,640 - INFO - Training retriever model on 521276 examples...


INFO:__main__:Training retriever model on 521276 examples...


2025-04-28 03:16:40,642 - INFO - Setting up Information Retrieval Evaluator...


INFO:__main__:Setting up Information Retrieval Evaluator...


2025-04-28 03:16:40,660 - INFO - Starting training...


INFO:__main__:Starting training...


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mzbw321[0m ([33mzbw321-boston-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Cosine Accuracy@1,Cosine Accuracy@3,Cosine Accuracy@5,Cosine Accuracy@10,Cosine Precision@1,Cosine Precision@3,Cosine Precision@5,Cosine Precision@10,Cosine Recall@1,Cosine Recall@3,Cosine Recall@5,Cosine Recall@10,Cosine Ndcg@10,Cosine Mrr@10,Cosine Map@100
500,2.4561,No log,0.529123,0.706162,0.767401,0.838155,0.529123,0.235387,0.15348,0.083816,0.529123,0.706162,0.767401,0.838155,0.681649,0.631767,0.637175
1000,2.3032,No log,0.519684,0.699639,0.764101,0.833474,0.519684,0.233213,0.15282,0.083347,0.519684,0.699639,0.764101,0.833474,0.674405,0.623685,0.629384
1500,2.3224,No log,0.50165,0.681989,0.750058,0.824035,0.50165,0.22733,0.150012,0.082403,0.50165,0.681989,0.750058,0.824035,0.659896,0.607683,0.613481
2000,2.3035,No log,0.514159,0.690124,0.753895,0.825033,0.514159,0.230041,0.150779,0.082503,0.514159,0.690124,0.753895,0.825033,0.667171,0.616882,0.622668
2500,2.3099,No log,0.522216,0.701404,0.764561,0.833704,0.522216,0.233801,0.152912,0.08337,0.522216,0.701404,0.764561,0.833704,0.676141,0.62587,0.631448
3000,2.3091,No log,0.50165,0.685289,0.748983,0.821042,0.50165,0.22843,0.149797,0.082104,0.50165,0.685289,0.748983,0.821042,0.659331,0.607769,0.613596
3500,2.2942,No log,0.518686,0.696493,0.756427,0.825877,0.518686,0.232164,0.151285,0.082588,0.518686,0.696493,0.756427,0.825877,0.670768,0.621285,0.627269
4000,2.3027,No log,0.517151,0.698872,0.758806,0.830021,0.517151,0.232957,0.151761,0.083002,0.517151,0.698872,0.758806,0.830021,0.672296,0.621947,0.627655
4500,2.2968,No log,0.507712,0.691275,0.75566,0.828563,0.507712,0.230425,0.151132,0.082856,0.507712,0.691275,0.75566,0.828563,0.665808,0.613958,0.61973
5000,2.2867,No log,0.53626,0.713299,0.774614,0.843373,0.53626,0.237766,0.154923,0.084337,0.53626,0.713299,0.774614,0.843373,0.688211,0.638716,0.643931


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.94s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.86s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.87s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.83s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.83s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.93s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.87s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.90s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.86s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.89s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.91s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.94s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.78s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.76s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.79s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.84s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.85s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.83s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.74s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.84s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.85s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.78s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.78s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.80s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.77s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.77s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.81s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.72s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.84s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.81s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.78s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.87s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.85s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.82s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.86s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.85s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.87s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.81s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.82s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.86s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.87s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.84s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.88s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.94s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.89s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.91s/it]


Batches:   0%|          | 0/408 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:51<00:00, 51.91s/it]
