RAG Implementation

In [1]:
%pip install transformers datasets faiss-cpu

Note: you may need to restart the kernel to use updated packages.


In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import faiss
import numpy as np
import torch

# One checkpoint name feeds both the tokenizer and the seq2seq model, so the
# two can never drift apart.
MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Retrieval corpus: the first 1% of the CNN/DailyMail (config 3.0.0) train split.
dataset = load_dataset(
    'cnn_dailymail',
    '3.0.0',
    split='train[:1%]',
)

# Create a FAISS index for efficient retrieval
def embed_texts(texts, tokenizer, model, batch_size=16):
    """Embed texts as mask-aware, mean-pooled encoder hidden states.

    The texts are processed in batches. Each batch is tokenized (padded to
    the longest item in the batch, truncated by the tokenizer), passed
    through ``model.get_encoder()`` without gradient tracking, and every text
    is reduced to a single vector by averaging the hidden states of its real
    (non-padding) tokens only.

    Args:
        texts: Iterable of strings to embed. Must contain at least one item.
        tokenizer: Callable that accepts a list of strings plus the keyword
            arguments ``return_tensors``, ``padding`` and ``truncation`` and
            returns a mapping with ``input_ids`` and ``attention_mask``.
        model: Object exposing ``get_encoder()``; the encoder's output must
            have a ``last_hidden_state`` attribute.
        batch_size: Number of texts encoded per forward pass. Bounds peak
            memory; thanks to the masked mean it does not change the pooling
            of any individual text.

    Returns:
        A C-contiguous ``float32`` NumPy array with one row per input text.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is not positive.
    """
    # Materialise so that generators and lazy column objects can be sliced.
    texts = list(texts)
    if not texts:
        raise ValueError("embed_texts requires at least one text")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    encoder = model.get_encoder()
    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            hidden = encoder(**inputs).last_hidden_state

        # Zero out padding positions before averaging; otherwise a text's
        # embedding would depend on how much padding its batch required.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_sum = (hidden * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        pooled_batches.append((token_sum / token_count).cpu().numpy())

    # FAISS indexes expect contiguous float32 matrices.
    return np.ascontiguousarray(np.concatenate(pooled_batches, axis=0), dtype=np.float32)

# Materialise the article column as a plain list of strings: depending on the
# installed `datasets` version, column access may return a lazy column object
# rather than a list, while the code below relies on list semantics
# (tokenizer input, positional lookup of retrieved hits).
texts = list(dataset['article'])

# FAISS expects a C-contiguous float32 matrix with one row per document.
embeddings = np.ascontiguousarray(embed_texts(texts, tokenizer, model), dtype=np.float32)

# Exact (brute-force) L2 index; row i of the index corresponds to texts[i].
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
assert index.ntotal == len(texts), "every article must have exactly one indexed vector"

# Define the RAG pipeline
def rag_pipeline(query, tokenizer, model, index, texts, top_k=5, max_length=50, num_beams=4):
    """Retrieve the nearest texts for a query and generate a short output.

    The query is embedded with ``embed_texts``, the ``top_k`` nearest entries
    are looked up in ``index``, and the query followed by the retrieved texts
    is tokenized (with truncation) and passed to ``model.generate``.

    Note that the combined prompt is truncated by the tokenizer, so long
    retrieved texts may crowd out later ones.

    Args:
        query: Query string.
        tokenizer: Tokenizer used both for embedding and for generation.
        model: Seq2seq model used both for embedding and for generation.
        index: Search index whose ``search(vectors, k)`` returns
            ``(distances, ids)``; ids are positions into ``texts``.
        texts: Sequence of documents, aligned with the rows of ``index``.
        top_k: Number of neighbours to request from the index.
        max_length: Maximum length passed to ``model.generate``.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The decoded generation as a string, with special tokens removed.

    Raises:
        ValueError: If ``top_k`` is not a positive integer.
    """
    if top_k < 1:
        raise ValueError("top_k must be a positive integer")

    query_embedding = embed_texts([query], tokenizer, model)
    _distances, neighbor_ids = index.search(query_embedding, top_k)

    # FAISS fills missing neighbours with -1 when the index holds fewer than
    # top_k vectors; without this filter texts[-1] would silently inject the
    # last document into the prompt.
    retrieved_texts = [texts[int(i)] for i in neighbor_ids[0] if i >= 0]

    input_text = " ".join([query, *retrieved_texts])
    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True)
    summary_ids = model.generate(
        inputs['input_ids'],
        # Pass the mask explicitly so generation does not have to guess it.
        attention_mask=inputs.get('attention_mask'),
        num_beams=num_beams,
        max_length=max_length,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

: 

: 

In [None]:
# Example usage: send one sample question through rag_pipeline, using the
# tokenizer, model, FAISS index and corpus built earlier in the notebook,
# and print the generated text.
query = "What are the latest advancements in AI?"
output = rag_pipeline(query, tokenizer, model, index, texts)
print(output)