<a href="https://colab.research.google.com/github/GabrielJobert/Simulation_paper---Effect_of_missing_data_on_K-means_performance---MATH60603A_STATISTICAL_LEARNING/blob/main/M%C3%A9moire_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Librairies

In [None]:
# List of libraries to install
libraries = ['torch', 'transformers', 'datasets', 'torchtext', 'faiss-cpu', 'rank_bm25', 'git+https://github.com/stanford-futuredata/ColBERT.git']

for lib in libraries:
    try:
        # Attempt to install the library
        !pip install {lib}
        print(f'Successfully installed {lib}')
    except Exception as e:
        # Catch and print any errors
        print(f'An error occurred while installing {lib}: {e}')


Successfully installed torch
Successfully installed transformers
Successfully installed datasets
Successfully installed torchtext
Successfully installed faiss-cpu
Successfully installed rank_bm25
Collecting git+https://github.com/stanford-futuredata/ColBERT.git
  Cloning https://github.com/stanford-futuredata/ColBERT.git to /tmp/pip-req-build-q08efapl
  Running command git clone --filter=blob:none --quiet https://github.com/stanford-futuredata/ColBERT.git /tmp/pip-req-build-q08efapl
  Resolved https://github.com/stanford-futuredata/ColBERT.git to commit 85837b6af6f92bbdd13238fbaa6f99f2f073da8e
  Preparing metadata (setup.py) ... [?25l[?25hdone
Successfully installed git+https://github.com/stanford-futuredata/ColBERT.git


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import faiss
from abc import ABC, abstractmethod
from rank_bm25 import BM25Okapi
from transformers import AutoModel, DPRQuestionEncoder, DPRContextEncoder

# Dataset

In [None]:
import os
import requests

def download_obliqa_dataset():
    # Define the URL of the dataset file
    url = "https://raw.githubusercontent.com/RegNLP/ObliQADataset/main/ObliQA_train.json"

    # Define the local path where the file will be saved
    local_path = "/content/ObliQADataset-main/ObliQA_train.json"

    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    # Download the dataset
    response = requests.get(url)

    # Raise an error if the request was unsuccessful
    response.raise_for_status()

    # Save the file
    with open(local_path, 'wb') as f:
        f.write(response.content)

    print(f"Dataset downloaded successfully and saved to {local_path}")

# Run the function
download_obliqa_dataset()


Dataset downloaded successfully and saved to /content/ObliQADataset-main/ObliQA_train.json


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import json

class DatasetLoader(Dataset):
    """
    General class to handle loading, preprocessing, and tokenization of various datasets for model training.
    Supports multiple datasets and can be extended to add more in the future.
    """

    def __init__(self, dataset_name: str, version: str = None, split: str = 'train',
                 tokenizer_name: str = 'bert-base-uncased', max_length: int = 384,
                 dataset_path: str = None, include_negatives: bool = False):
        """
        Initializes the DatasetLoader with dataset name, split, tokenizer, and preprocessing options.

        Args:
            dataset_name (str): Name of the dataset (e.g., 'squad', 'ms_marco', 'obliqa').
            version (str, optional): Specific version of the dataset (if applicable).
            split (str, optional): Dataset split to load ('train', 'validation', etc.).
            tokenizer_name (str, optional): Name of the tokenizer to use (e.g., 'bert-base-uncased').
            max_length (int, optional): Maximum length for tokenized inputs.
            dataset_path (str, optional): Path to the dataset file (for custom datasets like 'obliqa').
            include_negatives (bool, optional): Whether to include negative samples.
        """
        self.dataset_name = dataset_name
        self.version = version
        self.split = split
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length
        self.include_negatives = include_negatives
        self.dataset_path = dataset_path
        self.dataset = self._load_dataset()

    def _load_dataset(self):
        """
        Loads the specified dataset dynamically based on the dataset name.

        Returns:
            Dataset: The loaded dataset or list of data samples.
        """
        try:
            if self.dataset_name == 'squad':
                return load_dataset('squad_v2' if self.version == '2.0' else 'squad', split=self.split)
            elif self.dataset_name == 'ms_marco':
                return load_dataset('ms_marco', self.version, split=self.split)
            elif self.dataset_name == 'obliqa':
                return self._load_obliqa_dataset(self.dataset_path)
            else:
                raise ValueError(f"Unsupported dataset: {self.dataset_name}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise

    def _load_obliqa_dataset(self, dataset_path: str):
        """
        Loads the ObliQA dataset from a JSON file.

        Args:
            dataset_path (str): Path to the ObliQA dataset file.

        Returns:
            list: A list of loaded dataset samples.
        """
        with open(dataset_path, 'r') as f:
            return json.load(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset, tokenizes the query and passage,
        and prepares data for the model. Optionally includes negative sampling.

        Args:
            idx (int): Index of the item in the dataset.

        Returns:
            dict: A dictionary containing tokenized inputs for the model.
        """
        data = self.dataset[idx]

        if self.dataset_name == 'squad':
            return self._process_squad(data, idx)
        elif self.dataset_name == 'ms_marco':
            return self._process_ms_marco(data, idx)
        elif self.dataset_name == 'obliqa':
            return self._process_obliqa(data, idx)
        else:
            raise ValueError(f"Unsupported dataset: {self.dataset_name}")

    def _process_squad(self, data, idx):
        # Tokenize the context (passage) and question for the SQuAD dataset
        passage_inputs = self.tokenizer(
            data['context'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        query_inputs = self.tokenizer(
            data['question'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        if self.include_negatives:
            neg_idx = (idx + 1) % len(self.dataset)
            negative_data = self.dataset[neg_idx]
            negative_inputs = self.tokenizer(
                negative_data['context'],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            return {
                'query_input_ids': query_inputs['input_ids'].squeeze(),
                'passage_input_ids': passage_inputs['input_ids'].squeeze(),
                'negative_input_ids': negative_inputs['input_ids'].squeeze()
            }

        return {
            'query_input_ids': query_inputs['input_ids'].squeeze(),
            'passage_input_ids': passage_inputs['input_ids'].squeeze()
        }

    def _process_ms_marco(self, data, idx):
        # Tokenize the query and passage for the MS MARCO dataset
        query_inputs = self.tokenizer(
            data['query'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        selected_passages = [
            passage for passage, is_selected in zip(data['passages']['passage_text'], data['passages']['is_selected'])
            if is_selected == 1
        ] or [data['passages']['passage_text'][0]]  # Fallback

        passage_inputs = self.tokenizer(
            selected_passages[0],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'query_input_ids': query_inputs['input_ids'].squeeze(),
            'passage_input_ids': passage_inputs['input_ids'].squeeze()
        }

    def _process_obliqa(self, data, idx):
        # Tokenize the question and passage for the ObliQA dataset
        query_inputs = self.tokenizer(
            data['Question'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        passage_inputs = self.tokenizer(
            data['Passages'][0]['Passage'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'query_input_ids': query_inputs['input_ids'].squeeze(),
            'passage_input_ids': passage_inputs['input_ids'].squeeze(),
            'question_id': data['QuestionID'],
            'group': data['Group']
        }

    def get_dataloader(self, batch_size: int = 8, shuffle: bool = True):
        """
        Creates a DataLoader for the dataset.

        Args:
            batch_size (int, optional): Number of samples per batch.
            shuffle (bool, optional): Whether to shuffle the data.

        Returns:
            DataLoader: A DataLoader instance for the dataset.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)


In [None]:
if __name__ == "__main__":
    # Example 1: Testing with SQuAD v2.0
    print("Testing SQuAD v2.0 Dataset")
    try:
        squad_dataset_loader = DatasetLoader(
            dataset_name='squad',
            version='2.0',
            split='train',
            include_negatives=True
        )
        squad_dataloader = squad_dataset_loader.get_dataloader(batch_size=8, shuffle=True)

        for batch in squad_dataloader:
            print("SQuAD v2.0 Batch:", batch)
            break  # Print only one batch for testing
    except Exception as e:
        print(f"Error loading or processing SQuAD v2.0: {e}")

    # Example 2: Testing with MS MARCO v2.1
    print("\nTesting MS MARCO v2.1 Dataset")
    try:
        ms_marco_dataset_loader = DatasetLoader(
            dataset_name='ms_marco',
            version='v2.1',
            split='train',
            include_negatives=True
        )
        ms_marco_dataloader = ms_marco_dataset_loader.get_dataloader(batch_size=8, shuffle=True)

        for batch in ms_marco_dataloader:
            print("MS MARCO v2.1 Batch:", batch)
            break  # Print only one batch for testing
    except Exception as e:
        print(f"Error loading or processing MS MARCO v2.1: {e}")

    # Example 3: Testing with ObliQA Dataset
    print("\nTesting ObliQA Dataset")
    try:
        dataset_path = '/content/ObliQADataset-main/ObliQA_train.json'  # Adjust the path to your dataset file
        obliqa_dataset_loader = DatasetLoader(
            dataset_name='obliqa',
            dataset_path=dataset_path,
            include_negatives=True
        )
        obliqa_dataloader = obliqa_dataset_loader.get_dataloader(batch_size=8, shuffle=True)

        for batch in obliqa_dataloader:
            print("ObliQA Batch:", batch)
            break  # Print only one batch for testing
    except Exception as e:
        print(f"Error loading or processing ObliQA: {e}")


Testing SQuAD v2.0 Dataset


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


SQuAD v2.0 Batch: {'query_input_ids': tensor([[ 101, 2043, 2020,  ...,    0,    0,    0],
        [ 101, 1997, 2054,  ...,    0,    0,    0],
        [ 101, 7937, 3514,  ...,    0,    0,    0],
        ...,
        [ 101, 2247, 2007,  ...,    0,    0,    0],
        [ 101, 2054, 4277,  ...,    0,    0,    0],
        [ 101, 2054, 2001,  ...,    0,    0,    0]]), 'passage_input_ids': tensor([[  101,  3199,  2810,  ...,     0,     0,     0],
        [  101,  3025,  2695,  ...,     0,     0,     0],
        [  101, 16295,  1013,  ...,     0,     0,     0],
        ...,
        [  101,  2116,  1997,  ...,     0,     0,     0],
        [  101,  1999,  2494,  ...,     0,     0,     0],
        [  101,  2522, 22145,  ...,     0,     0,     0]]), 'negative_input_ids': tensor([[  101,  3199,  2810,  ...,     0,     0,     0],
        [  101,  3025,  2695,  ...,     0,     0,     0],
        [  101, 16295,  1013,  ...,     0,     0,     0],
        ...,
        [  101,  2116,  1997,  ...,     0,



MS MARCO v2.1 Batch: {'query_input_ids': tensor([[  101,  3517,  3465,  ...,     0,     0,     0],
        [  101,  2515,  1037,  ...,     0,     0,     0],
        [  101,  2054,  9983,  ...,     0,     0,     0],
        ...,
        [  101,  2054,  2003,  ...,     0,     0,     0],
        [  101,  2054,  2003,  ...,     0,     0,     0],
        [  101,  2003, 19857,  ...,     0,     0,     0]]), 'passage_input_ids': tensor([[  101,  1996,  5725,  ...,     0,     0,     0],
        [  101, 28079,  7395,  ...,     0,     0,     0],
        [  101,  2057,  3227,  ...,     0,     0,     0],
        ...,
        [  101,  3078,  8040,  ...,     0,     0,     0],
        [  101,  2004, 10288,  ...,     0,     0,     0],
        [  101,  2619,  2007,  ...,     0,     0,     0]])}

Testing ObliQA Dataset
ObliQA Batch: {'query_input_ids': tensor([[ 101, 2071, 2017,  ...,    0,    0,    0],
        [ 101, 2054, 2024,  ...,    0,    0,    0],
        [ 101, 2054, 2024,  ...,    0,    0,    0]

# Models

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from datasets import load_dataset
import faiss
from abc import ABC, abstractmethod
import numpy as np

class Retriever(ABC):
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        if self.device == 'cuda':
            self.model = self.model.half()

    @abstractmethod
    def encode_passages(self, passages):
        pass

    @abstractmethod
    def retrieve(self, query, index, passages, top_k):
        pass

class LanguageModel(ABC):
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        if self.device == 'cuda':
            self.model = self.model.half()

    @abstractmethod
    def generate(self, input_text):
        pass

class ContrieverRetriever(Retriever):
    def encode_passages(self, passages, batch_size=16):
        embeddings = []
        for i in range(0, len(passages), batch_size):
            batch = passages[i:i + batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                batch_embeddings = self.model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
            embeddings.append(batch_embeddings)
        return np.vstack(embeddings)

    def retrieve(self, query, index, passages, top_k=5):
        query_inputs = self.tokenizer(query, return_tensors='pt').to(self.device)
        with torch.no_grad():
            query_embedding = self.model(**query_inputs).last_hidden_state.mean(dim=1).cpu().numpy()
        distances, indices = index.search(query_embedding, top_k)
        return [passages[i] for i in indices[0]]

class FlanT5LanguageModel(LanguageModel):
    def generate(self, input_text):
        input_ids = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512).input_ids.to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, max_length=256)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

class RAGPipeline:
    def __init__(self, retriever, language_model, passages, index_file='passage_index.faiss'):
        self.retriever = retriever
        self.language_model = language_model
        self.passages = passages
        self.index_file = index_file
        self.index = self._load_index() or self._build_index(passages)

    def _build_index(self, passages, batch_size=16):
        passage_embeddings = self.retriever.encode_passages(passages, batch_size)
        dimension = passage_embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(passage_embeddings)
        faiss.write_index(index, self.index_file)
        return index

    def _load_index(self):
        try:
            return faiss.read_index(self.index_file)
        except:
            return None

    def __call__(self, query, top_k=5):
        retrieved_passages = self.retriever.retrieve(query, self.index, self.passages, top_k)
        input_text = " ".join(retrieved_passages[:top_k]) + " " + query
        answer = self.language_model.generate(input_text)
        return answer

if __name__ == "__main__":
    # Load a subset of the dataset with only 10 items
    dataset = load_dataset('squad_v2', split='train[:10]')
    passages = [item['context'] for item in dataset]

    # Instantiate retriever and language model
    retriever = ContrieverRetriever(model_name='facebook/contriever')
    language_model = FlanT5LanguageModel(model_name='google/flan-t5-base')

    # Create the RAG pipeline
    pipeline = RAGPipeline(retriever=retriever, language_model=language_model, passages=passages)

    # Test the pipeline
    query = "What is the capital of France?"
    result = pipeline(query)
    print(f"Query: {query}")
    print(f"Generated Answer: {result}")


Query: What is the capital of France?
Generated Answer: -selling girl groups of all time.


In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from datasets import load_dataset
import faiss
from abc import ABC, abstractmethod
import numpy as np
from rank_bm25 import BM25Okapi


# Base Retriever class
class Retriever(ABC):
    def __init__(self, model_name=None):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name) if model_name else None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModel.from_pretrained(model_name).to(self.device) if model_name else None
        if self.device == 'cuda' and self.model:
            self.model = self.model.half()

    @abstractmethod
    def encode_passages(self, passages):
        pass

    @abstractmethod
    def retrieve(self, query, index, passages, top_k):
        pass

# BM25-based Retriever
class BM25Retriever(Retriever):
    def __init__(self):
        super().__init__(None)
        self.bm25 = None

    def encode_passages(self, passages):
        # Tokenize passages for BM25
        tokenized_passages = [passage.split() for passage in passages]
        self.bm25 = BM25Okapi(tokenized_passages)

    def retrieve(self, query, index=None, passages=None, top_k=5):
        # Tokenize the query for BM25
        tokenized_query = query.split()
        scores = self.bm25.get_scores(tokenized_query)
        top_k_indices = np.argsort(scores)[-top_k:][::-1]
        return [passages[i] for i in top_k_indices]

# Contriever-based Retriever (already implemented)
class ContrieverRetriever(Retriever):
    def encode_passages(self, passages, batch_size=16):
        embeddings = []
        for i in range(0, len(passages), batch_size):
            batch = passages[i:i + batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                batch_embeddings = self.model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
            embeddings.append(batch_embeddings)
        return np.vstack(embeddings)

    def retrieve(self, query, index, passages, top_k=5):
        query_inputs = self.tokenizer(query, return_tensors='pt').to(self.device)
        with torch.no_grad():
            query_embedding = self.model(**query_inputs).last_hidden_state.mean(dim=1).cpu().numpy()
        distances, indices = index.search(query_embedding, top_k)
        return [passages[i] for i in indices[0]]

# Language Model base class
class LanguageModel(ABC):
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        if self.device == 'cuda':
            self.model = self.model.half()

    @abstractmethod
    def generate(self, input_text):
        pass

# FlanT5-based language model
class FlanT5LanguageModel(LanguageModel):
    def generate(self, input_text):
        input_ids = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512).input_ids.to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, max_length=256)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

# RAG pipeline class
class RAGPipeline:
    def __init__(self, retriever, language_model, passages, index_file='passage_index.faiss'):
        self.retriever = retriever
        self.language_model = language_model
        self.passages = passages
        self.index_file = index_file
        self.index = self._load_index() if isinstance(retriever, ContrieverRetriever) else None
        if not self.index and isinstance(retriever, ContrieverRetriever):
            self.index = self._build_index(passages)

    def _build_index(self, passages, batch_size=16):
        passage_embeddings = self.retriever.encode_passages(passages, batch_size)
        dimension = passage_embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(passage_embeddings)
        faiss.write_index(index, self.index_file)
        return index

    def _load_index(self):
        try:
            return faiss.read_index(self.index_file)
        except:
            return None

    def __call__(self, query, top_k=5):
        retrieved_passages = self.retriever.retrieve(query, self.index, self.passages, top_k)
        input_text = " ".join(retrieved_passages[:top_k]) + " " + query
        answer = self.language_model.generate(input_text)
        return answer

if __name__ == "__main__":
    # Load a subset of the dataset with only 10 items
    dataset = load_dataset('squad_v2', split='train[:10]')
    passages = [item['context'] for item in dataset]

    # Select BM25 retriever or Contriever retriever
    use_bm25 = True  # Toggle this to switch retrievers
    if use_bm25:
        retriever = BM25Retriever()
        retriever.encode_passages(passages)
    else:
        retriever = ContrieverRetriever(model_name='facebook/contriever')

    # Instantiate the language model
    language_model = FlanT5LanguageModel(model_name='google/flan-t5-base')

    # Create the RAG pipeline
    pipeline = RAGPipeline(retriever=retriever, language_model=language_model, passages=passages)

    # Test the pipeline
    query = "What is the capital of France?"
    result = pipeline(query)
    print(f"Query: {query}")
    print(f"Generated Answer: {result}")


Query: What is the capital of France?
Generated Answer: -selling girl groups of all time.


In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, DPRQuestionEncoder, DPRContextEncoder
from datasets import load_dataset
import faiss
from abc import ABC, abstractmethod
import numpy as np
from rank_bm25 import BM25Okapi
from colbert.modeling.colbert import ColBERT
from colbert.utils.runs import Run
from colbert.data import Collection, Queries
from colbert.infra import ColBERTConfig
from colbert.utils.utils import print_message


# Base Retriever class
class Retriever(ABC):
    def __init__(self, model_name=None):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name) if model_name else None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModel.from_pretrained(model_name).to(self.device) if model_name else None
        if self.device == 'cuda' and self.model:
            self.model = self.model.half()

    @abstractmethod
    def encode_passages(self, passages):
        pass

    @abstractmethod
    def retrieve(self, query, index, passages, top_k):
        pass

# BM25-based Retriever
class BM25Retriever(Retriever):
    def __init__(self):
        super().__init__(None)
        self.bm25 = None

    def encode_passages(self, passages):
        # Tokenize passages for BM25
        tokenized_passages = [passage.split() for passage in passages]
        self.bm25 = BM25Okapi(tokenized_passages)

    def retrieve(self, query, index=None, passages=None, top_k=5):
        # Tokenize the query for BM25
        tokenized_query = query.split()
        scores = self.bm25.get_scores(tokenized_query)
        top_k_indices = np.argsort(scores)[-top_k:][::-1]
        return [passages[i] for i in top_k_indices]

# Contriever-based Retriever (already implemented)
class ContrieverRetriever(Retriever):
    def encode_passages(self, passages, batch_size=16):
        embeddings = []
        for i in range(0, len(passages), batch_size):
            batch = passages[i:i + batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                batch_embeddings = self.model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
            embeddings.append(batch_embeddings)
        return np.vstack(embeddings)

    def retrieve(self, query, index, passages, top_k=5):
        query_inputs = self.tokenizer(query, return_tensors='pt').to(self.device)
        with torch.no_grad():
            query_embedding = self.model(**query_inputs).last_hidden_state.mean(dim=1).cpu().numpy()
        distances, indices = index.search(query_embedding, top_k)
        return [passages[i] for i in indices[0]]

# DPR-based Retriever
class DPRRetriever(Retriever):
    def __init__(self):
        super().__init__('facebook/dpr-question_encoder-single-nq-base')  # Initialize with the question encoder tokenizer
        self.query_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(self.device)
        self.passage_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(self.device)
        # We also need a tokenizer for encoding the passages and queries
        self.tokenizer = AutoTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

    def encode_passages(self, passages, batch_size=16):
        embeddings = []
        for i in range(0, len(passages), batch_size):
            batch = passages[i:i + batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                batch_embeddings = self.passage_encoder(**inputs).pooler_output.cpu().numpy()
            embeddings.append(batch_embeddings)
        return np.vstack(embeddings)

    def retrieve(self, query, index, passages, top_k=5):
        query_inputs = self.tokenizer(query, return_tensors='pt').to(self.device)
        with torch.no_grad():
            query_embedding = self.query_encoder(**query_inputs).pooler_output.cpu().numpy()
        distances, indices = index.search(query_embedding, top_k)
        return [passages[i] for i in indices[0]]

from transformers import BertModel, BertTokenizer
import torch.nn as nn

class ColBERTRetriever(Retriever):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        # Initialize linear layer for ColBERT's token interaction
        self.linear = nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size)

    def encode(self, texts):
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        outputs = self.model(**inputs).last_hidden_state
        return self.linear(outputs)  # Apply the linear layer for late interaction

    def retrieve(self, query, documents, top_k=3):
        query_embeds = self.encode(query).mean(dim=1)  # Simplified
        doc_embeds = self.encode(documents).mean(dim=1)

        scores = torch.matmul(query_embeds, doc_embeds.T)
        ranked_indices = torch.argsort(scores[0], descending=True)[:top_k]
        return [documents[i] for i in ranked_indices]


# Language Model base class
class LanguageModel(ABC):
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        if self.device == 'cuda':
            self.model = self.model.half()

    @abstractmethod
    def generate(self, input_text):
        pass

# FlanT5-based language model
class FlanT5LanguageModel(LanguageModel):
    def generate(self, input_text):
        input_ids = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512).input_ids.to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, max_length=256)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

# RAG pipeline class
class RAGPipeline:
    def __init__(self, retriever, language_model, passages, index_file='passage_index.faiss'):
        self.retriever = retriever
        self.language_model = language_model
        self.passages = passages
        self.index_file = index_file
        self.index = self._load_index() if isinstance(retriever, (ContrieverRetriever, DPRRetriever, ColBERTRetriever)) else None
        if not self.index and isinstance(retriever, (ContrieverRetriever, DPRRetriever, ColBERTRetriever)):
            self.index = self._build_index(passages)

    def _build_index(self, passages, batch_size=16):
        passage_embeddings = self.retriever.encode_passages(passages, batch_size)
        dimension = passage_embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(passage_embeddings)
        faiss.write_index(index, self.index_file)
        return index

    def _load_index(self):
        try:
            return faiss.read_index(self.index_file)
        except:
            return None

    def __call__(self, query, top_k=5):
        retrieved_passages = self.retriever.retrieve(query, self.index, self.passages, top_k)
        input_text = " ".join(retrieved_passages[:top_k]) + " " + query
        answer = self.language_model.generate(input_text)
        return answer


if __name__ == "__main__":
    # Load a subset of the dataset with only 10 items
    dataset = load_dataset('squad_v2', split='train[:10]')
    passages = [item['context'] for item in dataset]

    # Select BM25, DPR, Contriever, or ColBERT retriever
    use_bm25 = False  # Toggle to switch between retrievers
    use_dpr = False   # Toggle to switch between retrievers
    use_colbert = True  # Toggle to use ColBERT retriever

    if use_bm25:
        retriever = BM25Retriever()
        retriever.encode_passages(passages)
    elif use_dpr:
        retriever = DPRRetriever()
    elif use_colbert:
        retriever = ColBERTRetriever(colbert_model_path='bert-base-uncased')  # Provide ColBERT path
    else:
        retriever = ContrieverRetriever(model_name='facebook/contriever')

    # Instantiate the language model
    language_model = FlanT5LanguageModel(model_name='google/flan-t5-base')

    # Create the RAG pipeline
    pipeline = RAGPipeline(retriever=retriever, language_model=language_model, passages=passages)

    # Test the pipeline
    query = "What is the capital of France?"
    result = pipeline(query)
    print(f"Query: {query}")
    print(f"Generated Answer: {result}")



TypeError: Can't instantiate abstract class ColBERTRetriever with abstract method encode_passages