<a href="https://colab.research.google.com/github/GabrielJobert/EU-27-trade-network-analysis/blob/main/M%C3%A9moire_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# List of libraries to install
libraries = ['torch', 'transformers', 'datasets', 'torchtext', 'faiss-cpu']
# NOTE(review): torchtext is not imported by any cell in this notebook, and its
# releases are built against specific torch versions, so installing it may
# replace the preinstalled torch — consider dropping it from the list.

import subprocess
import sys

for lib in libraries:
    # Run pip as a subprocess of the kernel's own interpreter instead of the
    # IPython `!pip` shell escape: `!` commands do not raise when pip fails, so
    # a failed install would still be reported as a success. check=True turns a
    # non-zero pip exit status into CalledProcessError.
    try:
        subprocess.run(
            [sys.executable, '-m', 'pip', 'install', lib],
            check=True,
            capture_output=True,
            text=True,
        )
        print(f'Successfully installed {lib}')
    except subprocess.CalledProcessError as e:
        # Surface pip's own diagnostics, which capture_output kept out of the cell.
        print(f'An error occurred while installing {lib}: {e}\n{e.stderr}')
    except OSError as e:
        # The interpreter / pip executable could not be launched at all.
        print(f'An error occurred while installing {lib}: {e}')


Successfully installed torch
Successfully installed transformers
Successfully installed datasets
Successfully installed torchtext
Successfully installed faiss-cpu


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import faiss

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

class RAGDatasetHandler(Dataset):
    """
    Load, preprocess and tokenize a QA dataset into query/passage pairs for RAG model training.

    Each item is a dict of 1-D tensors of length ``max_length`` (padded/truncated):
    the tokenized question (query), its context (positive passage) and, when
    ``include_negatives`` is set, the context of another item as a negative passage.
    """

    def __init__(self, dataset_name: str, version: str = None, split: str = 'train',
                 tokenizer_name: str = 'bert-base-uncased', max_length: int = 384,
                 include_negatives: bool = False):
        """
        Initializes the RAGDatasetHandler with dataset name, split, tokenizer, and preprocessing options.

        Args:
            dataset_name (str): Name of the dataset (e.g., 'squad').
            version (str, optional): Specific version of the dataset (e.g., '2.0' for SQuAD 2.0).
            split (str, optional): Dataset split to load ('train', 'validation', etc.).
            tokenizer_name (str, optional): Name of the tokenizer to use (e.g., 'bert-base-uncased').
            max_length (int, optional): Length every tokenized input is padded/truncated to.
            include_negatives (bool, optional): Whether to add a negative passage to each item.
        """
        self.dataset_name = dataset_name
        self.version = version
        self.split = split
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length
        self.include_negatives = include_negatives
        self.dataset = self._load_dataset()

    def _load_dataset(self):
        """
        Loads the specified dataset using the Hugging Face datasets library.

        Returns:
            Dataset: The loaded dataset.

        Raises:
            ValueError: If the dataset name/version combination is not supported.
            Exception: Whatever ``load_dataset`` raises (reported, then re-raised).
        """
        if self.dataset_name == 'squad' and self.version == '2.0':
            hub_name = 'squad_v2'
        elif self.dataset_name == 'squad':
            hub_name = 'squad'
        else:
            # Implement loading for other datasets here as needed
            raise ValueError(f"Unsupported dataset: {self.dataset_name} version: {self.version}")
        try:
            return load_dataset(hub_name, split=self.split)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize ``text`` to exactly ``max_length`` tokens; return 1-D (input_ids, attention_mask)."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # The tokenizer returns a batch of one, shape (1, max_length). Drop only
        # that batch axis: a bare squeeze() would also collapse the sequence
        # axis when max_length == 1.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def _negative_index(self, idx, context):
        """Return the index of the next item (cyclically after ``idx``) whose context differs from ``context``."""
        n = len(self.dataset)
        # SQuAD-style datasets have one row per question, so neighbouring rows
        # usually share a context; taking idx + 1 blindly would often yield the
        # positive passage itself as the "negative". Scanning forward skips
        # those rows (typically only a few steps, since such rows are grouped).
        for step in range(1, n):
            candidate = (idx + step) % n
            if self.dataset[candidate]['context'] != context:
                return candidate
        # Every item shares this context: no true negative exists.
        return (idx + 1) % n

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset, tokenizes the question and context,
        and prepares data for the model. Optionally includes a negative passage.

        Args:
            idx (int): Index of the item in the dataset.

        Returns:
            dict: ``query_*`` and ``passage_*`` input_ids/attention_mask tensors,
            plus ``negative_*`` tensors when ``include_negatives`` is set.
        """
        data = self.dataset[idx]

        query_input_ids, query_attention_mask = self._encode(data['question'])
        passage_input_ids, passage_attention_mask = self._encode(data['context'])

        item = {
            'query_input_ids': query_input_ids,
            'query_attention_mask': query_attention_mask,
            'passage_input_ids': passage_input_ids,
            'passage_attention_mask': passage_attention_mask,
        }

        if self.include_negatives:
            # Deterministic negative: the next item with a different context.
            negative_data = self.dataset[self._negative_index(idx, data['context'])]
            negative_input_ids, negative_attention_mask = self._encode(negative_data['context'])
            item['negative_input_ids'] = negative_input_ids
            item['negative_attention_mask'] = negative_attention_mask

        return item

    def get_dataloader(self, batch_size: int = 8, shuffle: bool = True):
        """
        Creates a DataLoader for the dataset.

        Args:
            batch_size (int, optional): Number of samples per batch.
            shuffle (bool, optional): Whether to shuffle the data.

        Returns:
            DataLoader: A DataLoader instance for the dataset.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle)

# Example usage
if __name__ == "__main__":
    handler = RAGDatasetHandler(
        dataset_name='squad',
        version='2.0',
        split='train',
        include_negatives=True,
    )
    loader = handler.get_dataloader(batch_size=8, shuffle=True)

    # Show only the first batch, if the loader yields any.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(first_batch)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


{'query_input_ids': tensor([[ 101, 2043, 2106,  ...,    0,    0,    0],
        [ 101, 2044, 3930,  ...,    0,    0,    0],
        [ 101, 2054, 3139,  ...,    0,    0,    0],
        ...,
        [ 101, 2054, 2001,  ...,    0,    0,    0],
        [ 101, 2976, 2375,  ...,    0,    0,    0],
        [ 101, 2339, 2001,  ...,    0,    0,    0]]), 'query_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'passage_input_ids': tensor([[ 101, 2045, 2003,  ...,    0,    0,    0],
        [ 101, 3078, 3348,  ...,    0,    0,    0],
        [ 101, 1996, 3470,  ...,    0,    0,    0],
        ...,
        [ 101, 7862, 2031,  ...,    0,    0,    0],
        [ 101, 2976, 2375,  ...,    0,    0,    0],
        [ 101, 2750, 3278,  ...,    0,    0,    0]]), 'passage_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
  

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from datasets import load_dataset
import faiss
from abc import ABC, abstractmethod
import numpy as np

class Retriever(ABC):
    """Abstract base for dense retrievers backed by a Hugging Face encoder.

    Loads the tokenizer and encoder for ``model_name``. The encoder runs on CUDA
    in float16 when a GPU is available, and on CPU otherwise.
    """

    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        encoder = AutoModel.from_pretrained(model_name).to(self.device)
        # Half precision only on GPU (presumably to save memory / speed up inference).
        self.model = encoder.half() if use_gpu else encoder

    @abstractmethod
    def encode_passages(self, passages):
        """Embed ``passages``; implemented by subclasses."""

    @abstractmethod
    def retrieve(self, query, index, passages, top_k):
        """Return the passages most relevant to ``query``; implemented by subclasses."""

class LanguageModel(ABC):
    """Abstract base for text generators backed by a Hugging Face seq2seq model.

    Loads the tokenizer and seq2seq model for ``model_name``. The model runs on
    CUDA in float16 when a GPU is available, and on CPU otherwise.
    """

    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        gpu_available = torch.cuda.is_available()
        self.device = {True: 'cuda', False: 'cpu'}[gpu_available]
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        if gpu_available:
            seq2seq = seq2seq.half()
        self.model = seq2seq

    @abstractmethod
    def generate(self, input_text):
        """Produce output text for ``input_text``; implemented by subclasses."""

class ContrieverRetriever(Retriever):
    """Dense retriever that embeds text by mean-pooling the encoder's last hidden state.

    Embeddings are returned as float32 NumPy arrays, the dtype FAISS indexes expect.
    """

    def _embed(self, texts):
        """Embed a string or list of strings; returns an array of shape (n_texts, hidden_size)."""
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            # Upcast before pooling: on CUDA the model runs in float16, while
            # FAISS needs float32 vectors.
            hidden = self.model(**inputs).last_hidden_state.float()
        # Masked mean over the token axis. Padding positions (attention_mask == 0)
        # are excluded, so a text's embedding does not depend on the length of the
        # other texts in its batch.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return pooled.cpu().numpy()

    def encode_passages(self, passages, batch_size=16):
        """Embed ``passages`` in batches of ``batch_size``; returns one row per passage, in order."""
        batches = [
            self._embed(passages[start:start + batch_size])
            for start in range(0, len(passages), batch_size)
        ]
        return np.vstack(batches)

    def retrieve(self, query, index, passages, top_k=5):
        """Return up to ``top_k`` passages whose vectors in ``index`` best match ``query``.

        ``index`` is expected to hold one vector per entry of ``passages``, in the same order.
        """
        # The query is truncated like passages, so an over-long query cannot exceed the model's limit.
        query_embedding = self._embed(query)
        _, indices = index.search(query_embedding, top_k)
        # FAISS fills missing results with -1 when the index holds fewer than
        # top_k vectors; passages[-1] would otherwise silently return the last passage.
        return [passages[i] for i in indices[0] if i >= 0]

class FlanT5LanguageModel(LanguageModel):
    """Generator wrapper for FLAN-T5 seq2seq checkpoints."""

    def generate(self, input_text, max_input_length=512, max_output_length=256):
        """Generate a response to ``input_text``.

        Args:
            input_text (str): Prompt; tokens beyond ``max_input_length`` are truncated.
            max_input_length (int): Maximum number of prompt tokens fed to the encoder.
            max_output_length (int): ``max_length`` passed to ``model.generate``.

        Returns:
            str: The decoded output with special tokens removed.
        """
        inputs = self.tokenizer(
            input_text, return_tensors='pt', truncation=True, max_length=max_input_length
        ).to(self.device)
        with torch.no_grad():
            # Pass the attention mask explicitly rather than letting generate()
            # infer it from input_ids.
            output_ids = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_output_length,
            )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

class RAGPipeline:
    """Retrieve-then-generate pipeline over a fixed list of passages.

    Passage embeddings are stored in a FAISS inner-product index written to
    ``index_file``. A previously written file is reused when its vector count
    matches ``passages``; otherwise the index is rebuilt.
    """

    def __init__(self, retriever, language_model, passages, index_file='passage_index.faiss'):
        self.retriever = retriever
        self.language_model = language_model
        self.passages = passages
        self.index_file = index_file
        index = self._load_index()
        if index is None:
            index = self._build_index(passages)
        self.index = index

    def _build_index(self, passages, batch_size=16):
        """Embed ``passages``, build an inner-product index over them and save it to ``index_file``.

        Raises:
            ValueError: If ``passages`` is empty.
        """
        if not passages:
            raise ValueError("Cannot build an index from an empty passage list")
        passage_embeddings = self.retriever.encode_passages(passages, batch_size)
        # FAISS takes C-contiguous float32 arrays.
        passage_embeddings = np.ascontiguousarray(passage_embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(passage_embeddings.shape[1])
        index.add(passage_embeddings)
        faiss.write_index(index, self.index_file)
        return index

    def _load_index(self):
        """Return the cached index from ``index_file``, or None if it is missing, unreadable or stale."""
        try:
            index = faiss.read_index(self.index_file)
        except RuntimeError:
            # faiss reports a missing or unreadable file as RuntimeError.
            return None
        # An index written for a different passage list would map search hits
        # to the wrong passages (or out of range), so force a rebuild.
        # NOTE(review): only the count is checked; a same-size but different
        # passage list would still reuse the stale file.
        if index.ntotal != len(self.passages):
            return None
        return index

    def __call__(self, query, top_k=5):
        """Answer ``query`` using the ``top_k`` retrieved passages as context."""
        retrieved_passages = self.retriever.retrieve(query, self.index, self.passages, top_k)
        context = " ".join(retrieved_passages[:top_k])
        # Put the question before the context. The generator truncates long
        # inputs (FlanT5LanguageModel keeps the first 512 tokens), so a question
        # placed after several passages would be cut off and never seen by the model.
        input_text = f"Answer the question using the context.\nQuestion: {query}\nContext: {context}"
        return self.language_model.generate(input_text)

if __name__ == "__main__":
    # Load a subset of the dataset with only 10 items
    dataset = load_dataset('squad_v2', split='train[:10]')
    # SQuAD rows are question-level, so consecutive rows typically repeat the
    # same context. Deduplicate (keeping first-seen order) so the index holds
    # distinct passages and retrieval does not return one passage top_k times.
    passages = list(dict.fromkeys(item['context'] for item in dataset))

    # Instantiate retriever and language model
    retriever = ContrieverRetriever(model_name='facebook/contriever')
    language_model = FlanT5LanguageModel(model_name='google/flan-t5-base')

    # Create the RAG pipeline. Use a dedicated cache file: an index built earlier
    # from the duplicated passage list would not line up with `passages`.
    pipeline = RAGPipeline(
        retriever=retriever,
        language_model=language_model,
        passages=passages,
        index_file='squad_v2_train10_unique.faiss',
    )

    # Test the pipeline
    query = "What is the capital of France?"
    result = pipeline(query)
    print(f"Query: {query}")
    print(f"Generated Answer: {result}")


Query: What is the capital of France?
Generated Answer: -selling girl groups of all time.
