In [None]:
# ----- Useful Links -----
# https://github.com/AkariAsai/self-rag/tree/main?tab=readme-ov-file#updates

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from datasets import load_dataset

import wikipediaapi
import re
import nltk
from nltk.tokenize import sent_tokenize
nltk.download("punkt")

import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

import pandas as pd
import json
from typing import List, Dict, Any

[nltk_data] Downloading package punkt to
[nltk_data]     /cis/home/adesilva/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
def get_wikipedia_article(title, lang="en", user_agent="ashwin"):
    """Fetch the plain-text body of a Wikipedia article.

    Args:
        title: Page title, e.g. "Australian Grand Prix".
        lang: Wikipedia language edition code (default "en").
        user_agent: User-Agent string sent to the Wikimedia API. Wikimedia asks
            clients to identify themselves descriptively (tool name + contact);
            the default "ashwin" is kept for backward compatibility.

    Returns:
        The article text, or None when the page does not exist.
    """
    wiki = wikipediaapi.Wikipedia(user_agent=user_agent, language=lang)
    page = wiki.page(title)

    if not page.exists():
        print(f"Page '{title}' does not exist.")
        return None

    return page.text

def clean_text(text):
    """Strip numeric citation markers such as "[1]" and collapse every whitespace run to one space.

    Leading/trailing whitespace is removed from the result.
    """
    without_citations = re.sub(r"\[\d+\]", "", text)
    collapsed = re.sub(r"\s+", " ", without_citations)
    return collapsed.strip()

def split_into_passages(text, sentences_per_passage=5):
    """Group consecutive sentences of `text` into passages.

    Sentence boundaries come from NLTK's `sent_tokenize`. Passages do not overlap.

    Args:
        text: Plain text to split.
        sentences_per_passage: Number of sentences per passage (>= 1). The last
            passage may contain fewer sentences.

    Returns:
        List of passages, each the space-joined sentences. Empty text gives [].

    Raises:
        ValueError: if sentences_per_passage < 1. A non-positive step would
            otherwise either return [] silently (negative) or fail inside range() (zero).
    """
    if sentences_per_passage < 1:
        raise ValueError(
            f"sentences_per_passage must be >= 1, got {sentences_per_passage}"
        )

    sentences = sent_tokenize(text)
    return [
        " ".join(sentences[start:start + sentences_per_passage])
        for start in range(0, len(sentences), sentences_per_passage)
    ]

In [10]:
class RAG:
    """Minimal retrieval-augmented generation (RAG) pipeline over one Wikipedia article.

    Typical use:
        rag = RAG()
        rag.get_article_passages("Australian Grand Prix")
        rag.create_embeddings()
        answer, hits = rag.rag_response("Who won in 2022?")

    Passages are embedded with an attention-masked mean-pooled encoder, ranked
    against the query by cosine similarity, and the top hits are given to a
    Llama-2 chat model as context.
    """

    def __init__(self, language_model="meta-llama/Llama-2-7b-chat-hf",
                 embedding_model="facebook/contriever",
                 device=None):
        """
        Initialize the RAG system.

        Args:
            language_model: Hugging Face model id of the generator (causal LM).
            embedding_model: Hugging Face model id of the passage/query encoder.
            device: torch device spec ('cpu', 'cuda', 'cuda:1', 1, ...). When None,
                'cuda' is used if available, otherwise 'cpu'.

        Loading errors are printed rather than raised; the affected attributes are
        left as None and the methods that need them raise ValueError later.
        """
        # Set device (GPU if available, otherwise CPU)
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f"Using device: {self.device}")

        # Load Llama model and tokenizer
        try:
            print(f"Loading Llama model from {language_model}...")
            # fp16 is a GPU memory optimisation; on CPU it is typically slow or
            # unsupported for some ops, so fall back to fp32 there.
            dtype = torch.float32 if str(self.device) == 'cpu' else torch.float16
            self.tokenizer = AutoTokenizer.from_pretrained(language_model)
            self.model = AutoModelForCausalLM.from_pretrained(
                language_model,
                torch_dtype=dtype,
            ).to(self.device)
            print("Llama model loaded successfully")
        except Exception as e:
            print(f"Error loading Llama model: {str(e)}")
            print("You may need to login with 'huggingface-cli login' and request access to Llama models")
            self.tokenizer = None
            self.model = None

        # Load embedding model
        try:
            print(f"Loading embedding model {embedding_model}...")
            self.embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model)
            self.embedding_model = AutoModel.from_pretrained(embedding_model).to(self.device)
            print("Embedding model loaded successfully")
        except Exception as e:
            print(f"Error loading embedding model: {str(e)}")
            self.embedding_tokenizer = None
            self.embedding_model = None

        # Passage vectors (np.ndarray, one row per passage) and the passages
        # themselves; filled in by get_article_passages / create_embeddings.
        self.embeddings = None
        self.passages = None

    def get_article_passages(self, title, passage_length=3, sentences_per_passage=5):
        """
        Fetch a Wikipedia article, clean it, and split it into passages.

        The passages are stored in self.passages and also returned. Embeddings
        computed for a previous article are discarded; call create_embeddings()
        again afterwards.

        Args:
            title: Wikipedia page title.
            passage_length: Number of passages printed as a preview (default 3).
            sentences_per_passage: Sentences grouped into each passage (default 5).

        Returns:
            The list of passages; empty if the article could not be fetched.
        """
        # Old vectors no longer line up with the new passages; never reuse them.
        self.embeddings = None

        article_text = get_wikipedia_article(title)
        if not article_text:
            # Missing/empty page: store an empty list rather than leaving passages unset.
            self.passages = []
            return self.passages

        cleaned_text = clean_text(article_text)
        passages = split_into_passages(cleaned_text, sentences_per_passage=sentences_per_passage)

        # Preview the first few passages
        for i, passage in enumerate(passages[:passage_length]):
            print(f"Passage {i+1}:\n{passage}\n{'-'*50}")

        self.passages = passages
        return passages

    def embed(self, text, batch_size=32):
        """
        Embed text with the encoder using attention-masked mean pooling.

        Args:
            text: A string or a list of strings.
            batch_size: Number of texts encoded per forward pass; bounds device
                memory when a whole article is embedded at once.

        Returns:
            np.ndarray with one row per input text (a single string gives one row).

        Raises:
            ValueError: if the embedding model is not loaded, `text` is empty,
                or batch_size < 1.
        """
        if self.embedding_model is None:
            raise ValueError("Embedding model not loaded")

        # A bare string must not be sliced character-by-character below.
        texts = [text] if isinstance(text, str) else list(text)
        if not texts:
            raise ValueError("No text to embed")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        chunks = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = self.embedding_tokenizer(
                batch, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                hidden = self.embedding_model(**inputs).last_hidden_state

            # Mean-pool over real tokens only: zero the padded positions and
            # divide by each sequence's true length.
            mask = inputs["attention_mask"]
            hidden = hidden.masked_fill(~mask[..., None].bool(), 0.)
            pooled = hidden.sum(dim=1) / mask.sum(dim=1)[..., None]
            chunks.append(pooled.cpu().numpy())

        return np.concatenate(chunks, axis=0)

    def create_embeddings(self):
        """
        Embed every passage in self.passages and cache the result in self.embeddings.

        Raises:
            ValueError: if the embedding model is not loaded or there are no passages.
        """
        if self.embedding_model is None:
            raise ValueError("Embedding model not loaded")
        if not self.passages:
            raise ValueError("No passages to embed; call get_article_passages first")

        self.embeddings = self.embed(self.passages)

    def retrieve_passages(self, query, top_k=3):
        """
        Retrieve the passages most similar to `query`.

        Args:
            query: Query string.
            top_k: Maximum number of passages to return; values <= 0 return [].

        Returns:
            List of (passage, cosine_similarity) tuples, best match first.

        Raises:
            ValueError: if the embedding model or passage embeddings are missing.
        """
        if self.embedding_model is None:
            raise ValueError("Embedding model not loaded")
        if self.embeddings is None:
            raise ValueError("Passage embeddings not created")

        # Clamp: a bare `[-0:]` slice below would return *every* passage for top_k=0.
        top_k = min(top_k, len(self.embeddings))
        if top_k <= 0:
            return []

        # Generate query embedding
        query_embedding = self.embed([query])

        # Calculate similarity scores
        similarities = cosine_similarity(query_embedding, self.embeddings)[0]

        # Highest scores first
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        top_passages = [self.passages[i] for i in top_indices]
        top_scores = [similarities[i] for i in top_indices]

        return list(zip(top_passages, top_scores))

    def generate_response(self, query, context, max_length=512):
        """
        Generate an answer to `query` with the Llama model, conditioned on `context`.

        Args:
            query: The user question.
            context: Retrieved text as a single string, or an iterable of passage
                strings (joined with blank lines).
            max_length: Maximum number of new tokens to generate (prompt excluded).

        Returns:
            The generated answer text, without the prompt.

        Raises:
            ValueError: if the language model or tokenizer failed to load.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("language model not loaded")

        # Accept a list of passages as well as a pre-joined string; otherwise a
        # list would be rendered into the prompt as its Python repr.
        if not isinstance(context, str):
            context = "\n\n".join(context)

        # Llama-2 chat format. No literal "<s>": the Llama tokenizer prepends BOS by
        # default, so a literal one would duplicate it. Built without source
        # indentation so no stray leading spaces reach the model.
        prompt = (
            "[INST] <<SYS>>\n"
            "You are a helpful assistant. Use the following context to answer the question at the end.\n"
            "<</SYS>>\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {query} [/INST]"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_length,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                # Llama has no pad token; reuse EOS explicitly instead of relying on a warning fallback.
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded text by the
        # decoded prompt's length breaks whenever decoding does not round-trip exactly.
        return self.tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True).strip()

    def rag_response(self, query, top_k=3, max_length=512):
        """
        End-to-end RAG: retrieve relevant passages, then generate a response.

        Args:
            query: The user question.
            top_k: Number of passages to retrieve as context.
            max_length: Maximum number of new tokens to generate.

        Returns:
            (response, retrieved_passages) where retrieved_passages is the list of
            (passage, score) tuples from retrieve_passages.

        Raises:
            ValueError: if passages or embeddings have not been prepared.
        """
        if self.passages is None:
            raise ValueError("passages are not ready")

        if self.embeddings is None:
            raise ValueError("embeddings are not ready")

        retrieved_passages = self.retrieve_passages(query, top_k=top_k)

        # Combine retrieved passages as context
        context = "\n\n".join(p for p, _ in retrieved_passages)

        response = self.generate_response(query, context, max_length=max_length)

        return response, retrieved_passages

In [11]:
# Initialize the RAG system.
# Device 1 (the second GPU) was the original choice; on machines with fewer than
# two GPUs that ordinal does not exist, so fall back to RAG's automatic choice
# (first GPU if available, otherwise CPU).
DEVICE = 1 if torch.cuda.device_count() > 1 else None
rag = RAG(device=DEVICE)

Using device: 1
Loading Llama model from meta-llama/Llama-2-7b-chat-hf...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Llama model loaded successfully
Loading embedding model facebook/contriever...
Embedding model loaded successfully


In [12]:
# Fetch, clean and chunk the article; the passages end up on rag.passages.
article_title = "Australian Grand Prix"
rag.get_article_passages(article_title)
passage_count = len(rag.passages)
print(f"Retrieved {passage_count} passages from '{article_title}'")

# Embed every passage once up front so repeated queries reuse the same vectors.
rag.create_embeddings()

Passage 1:
The Australian Grand Prix is an annual Formula One motor racing event, taking place in Melbourne, Victoria. The event is contracted to be held at least until 2035. One of the oldest surviving motorsport competitions held in Australia, the Grand Prix has moved frequently with 23 different venues having been used since it was first run at Phillip Island in 1928. The race became part of the Formula One World Championship in 1985. Since 1996, it has been held at the Albert Park Circuit in Melbourne, with the exceptions of 2020 and 2021, when the races were cancelled due to the COVID-19 pandemic.
--------------------------------------------------
Passage 2:
Before that, it was held in Adelaide. History Pre-war While an event called the Australian Grand Prix was staged in 1927 at the grass surface Goulburn Racecourse held as a series of sprints, it is generally accepted that the Australian Grand Prix began as the 100 Miles Road Race held at the Phillip Island road circuit in 1928.

In [13]:
# Rank the article's passages against the query and show the best matches with their scores.
query = "Who is Carlos Sainz Jr.?"
retrieved_passages = rag.retrieve_passages(query, top_k=3)

print("\nRetrieved passages:")
for rank, (passage, score) in enumerate(retrieved_passages, start=1):
    print(f"\nPassage {rank} (similarity: {score:.4f}):\n{passage}")


Retrieved passages:

Passage 1 (similarity: 0.2346):
Several other corners were reprofiled to encourage overtaking, most notably the old turn 13, which was widened to create additional racing lines. Positive camber was also added to allow drivers to carry more speed through the corner. The main straight and pit lane were also redesigned, with the pit lane wall moved two metres closer to the circuit so that the edge of the circuit sat directly next to the wall. The 2022 Grand Prix saw Ferrari's Charles Leclerc achieve his first career grand slam, having started in pole position, set the fastest lap, led every lap, and won the race ahead of Red Bull's Sergio Pérez and Mercedes' George Russell. It was the first grand slam for an individual Ferrari driver since Fernando Alonso's at the 2010 Singapore Grand Prix.

Passage 2 (similarity: 0.2276):
The 2022 edition set a new attendance record at the circuit for the weekend, with a reported 419,114 attendees, including 128,294 on race day; the

In [7]:
# Join the retrieved passages (similarity scores dropped) into one context string.
context = "\n\n".join(passage for passage, _score in retrieved_passages)

In [8]:
# Answer the query from the retrieved context. Pass the joined context string itself:
# wrapping it in a one-element list is unnecessary, and with a plain f-string prompt
# the list's repr (brackets, quotes, escaped newlines) would leak into the prompt.
response = rag.generate_response(query, context)

print("\nGenerated response:")
print(response)


Generated response:
Carlos Sainz Jr. is a Formula One driver who won the 2024 Australian Grand Prix. He is a Spanish driver who has been competing in the Formula One World Championship since 2015, driving for teams such as Toro Rosso and Ferrari. Sainz Jr. is the son of two-time Formula One World Champion Carlos Sainz Sr. and has had a successful career in the sport, achieving several podium finishes and setting fast lap times throughout his career.


In [14]:
# Same pipeline in a single call: retrieve the top passages, then generate from them.
response, _ = rag.rag_response(query)
print("\nEnd-to-end RAG response:", response, sep="\n")


End-to-end RAG response:
Carlos Sainz Jr. is a Spanish racing driver who won the 2024 Australian Grand Prix.
