In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pandas as pd


In [6]:
def preprocess_data(dataset):
  """Preprocesses the SQuAD v2 dataset.

  Args:
    dataset: The loaded SQuAD v2 dataset.

  Returns:
    A preprocessed dataset.
  """

  def preprocess_function(examples):
    # Basic preprocessing:
    # - Lowercase questions and contexts
    # - Remove extra spaces
    questions = [q.strip().lower() for q in examples["question"]]
    contexts = [c.strip().lower() for c in examples["context"]]
    answers = [{"answer_start": a["answer_start"], "text": a["text"]} for a in examples["answers"]]
    return {"question": questions, "answers": answers, "context": contexts}

  # Apply preprocessing
  dataset = dataset.map(preprocess_function)

  # Create a DataFrame for efficient handling
  df = pd.DataFrame(dataset['train'])

  # Add a unique identifier for each context
  df['id'] = range(len(df))

  return df


In [7]:
def create_embeddings(df, model_name='all-MiniLM-L6-v2'):
  """Creates embeddings for contexts.

  Args:
    df: The preprocessed dataset as a DataFrame.
    model_name: The name of the SentenceTransformer model.

  Returns:
    A dictionary containing context embeddings and their corresponding IDs.
  """

  model = SentenceTransformer(model_name)
  context_embeddings = model.encode(df['context'].tolist())

  return {'embeddings': context_embeddings, 'ids': df['id'].tolist()}


In [8]:
def create_faiss_index(embeddings):
  """Creates a FAISS index for efficient retrieval.

  Args:
    embeddings: A dictionary containing context embeddings and their corresponding IDs.

  Returns:
    A FAISS index.
  """

  dimension = embeddings['embeddings'].shape[1]
  index = faiss.IndexFlatL2(dimension)
  index.add(embeddings['embeddings'])

  return index


In [9]:
def retrieve_relevant_contexts(query, index, embeddings, top_k=5):
  """Retrieves relevant contexts for a given query.

  Args:
    query: The query string.
    index: The FAISS index.
    embeddings: A dictionary containing context embeddings and their corresponding IDs.
    top_k: The number of top results to return.

  Returns:
    A list of relevant context IDs.
  """

  model = SentenceTransformer('all-MiniLM-L6-v2')
  query_embedding = model.encode([query])[0]
  distances, indices = index.search(query_embedding.reshape(1, -1), top_k)
  relevant_ids = embeddings['ids'][indices[0]]
  return relevant_ids


In [10]:
def load_language_model(model_name='facebook/bart-large-cnn'):
  """Loads a language model for answer generation.

  Args:
    model_name: The name of the language model.

  Returns:
    A tuple of tokenizer and model.
  """

  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
  return tokenizer, model


In [11]:
def generate_answer(query, relevant_contexts, df, tokenizer, model):
  """Generates an answer based on the query and relevant contexts.

  Args:
    query: The query string.
    relevant_contexts: A list of relevant context IDs.
    df: The preprocessed dataset as a DataFrame.
    tokenizer: The tokenizer for the language model.
    model: The language model.

  Returns:
    The generated answer.
  """

  # Retrieve relevant contexts from the DataFrame
  relevant_texts = df[df['id'].isin(relevant_contexts)]['context'].tolist()

  # Combine query and relevant contexts into a single input
  input_text = f"{query} {' '.join(relevant_texts)}"
  input_ids = tokenizer.encode(input_text, return_tensors="pt")

  # Generate answer using the language model
  output = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
  generated_answer = tokenizer.decode(output[0], skip_special_tokens=True)
  return generated_answer


In [12]:
def rag(query, df, index, embeddings, tokenizer, model, top_k=5):
  """Implements the RAG pipeline.

  Args:
    query: The query string.
    df: The preprocessed dataset as a DataFrame.
    index: The FAISS index.
    embeddings: A dictionary containing context embeddings and their corresponding IDs.
    tokenizer: The tokenizer for the language model.
    model: The language model.
    top_k: The number of top results to return.

  Returns:
    The generated answer.
  """

  relevant_ids = retrieve_relevant_contexts(query, index, embeddings, top_k)
  answer = generate_answer(query, relevant_ids, df, tokenizer, model)
  return answer


In [13]:
import pandas as pd
import numpy as np

# Assuming necessary imports and data loading

# Example DataFrame
df = pd.DataFrame({'id': [1, 2, 3], 'context': ['context1', 'context2', 'context3']})

# Example embeddings
embeddings = {'embeddings': np.random.rand(3, 768), 'ids': [1, 2, 3]}

# Example FAISS index
index = faiss.IndexFlatL2(768)
index.add(embeddings['embeddings'])

# Example tokenizer and model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")


In [14]:
query = "Who is the capital of France?"
answer = rag(query, df, index, embeddings, tokenizer, model)
print(answer)

AssertionError: 