# Textbook Chatbot using RAG Pipeline

The goal of this project is to build a RAG (Retrieval-Augmented Generation) pipeline based on the paper: [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401).

#### **Each step can be roughly broken down into:**
- **Retrieval**: Retrieve (search for and bring back) information from a source given a query.
- **Augmented**: Use the retrieved information to modify the input to an LLM.
- **Generation**: Generate an output given an input.
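
The three steps can be sketched with toy stand-ins before any real models are involved. This is only an illustration: the `toy_*` functions and the sample documents are made up for this sketch, word overlap stands in for vector search, and the "generation" step is a placeholder rather than an LLM call.

```python
def toy_retrieve(query: str, documents: list[str], k: int = 1) -> list[str]:
    """Rank documents by how many words they share with the query."""
    query_words = set(query.lower().split())
    ranked = sorted(documents,
                    key=lambda doc: len(query_words & set(doc.lower().split())),
                    reverse=True)
    return ranked[:k]

def toy_augment(query: str, context: list[str]) -> str:
    """Insert the retrieved text into the prompt."""
    return "Context:\n" + "\n".join(context) + f"\n\nQuestion: {query}"

def toy_generate(prompt: str) -> str:
    """Placeholder for an LLM call; it only reports the prompt size."""
    return f"[An LLM would answer here from a prompt of {len(prompt)} characters]"

toy_documents = ["A treaty is a written agreement between states.",
                 "Customary law arises from consistent state practice."]
toy_query = "What is a treaty?"

toy_context = toy_retrieve(toy_query, toy_documents)
toy_prompt = toy_augment(toy_query, toy_context)
toy_answer = toy_generate(toy_prompt)
print(toy_prompt)
print(toy_answer)
```

The rest of the notebook replaces each stand-in with a real component: an embedding model for retrieval, a prompt template for augmentation, and an LLM for generation.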

#### **Step-by-step building workflow:**
1. Open a PDF document (I'm using the "handbook-of-international-law.pdf" for demo purposes.)
2. Format the text of the PDF (splitting into chunks) for an embedding model.
3. Embed all the chunks (turn the text into numerical representations that can be stored for later use.)
4. Build a retrieval system with vector search to find relevant chunks based on an input query.
5. Create a prompt that incorporates the retrieved texts.
6. Generate an answer to a query based on texts from the PDF.
7. Use a web app interface for easy usage of the system.

The workflow is similar to the one outlined on the NVIDIA blog: [RAG 101](https://developer.nvidia.com/blog/rag-101-demystifying-retrieval-augmented-generation-pipelines/)

<img src="images/work-flow.png" />

#### **Requirements and Setup**
- Local NVIDIA GPU or Google Colab with access to a GPU.
- Environment setup (as in `requirements.txt`)
- Data Source (a PDF)

> **Note:** I'm using the textbook `handbook-of-international-law.pdf` for demo purposes.

In [None]:
# Perform Google Colab installs (if running in Google Colab)
import os

if "COLAB_GPU" in os.environ:
    print("[INFO] Running in Google Colab, installing requirements.")
    !pip install --upgrade huggingface_hub # Hugging Face model hub
    !pip install -U torch # Torch 2.1.1+ (for efficient sdpa implementation)
    !pip install PyMuPDF # Reading PDFs with Python
    !pip install tqdm # Progress bars
    !pip install sentence-transformers # Embedding models
    !pip install accelerate # Quantization model loading
    !pip install bitsandbytes # Quantizing models (less storage space)
    !pip install flash-attn --no-build-isolation # Faster attention mechanism = faster LLM inference
    !pip install gradio # Web app

##### **Optional:** Hugging Face token for specific models 

In [None]:
from huggingface_hub import login
login()

In [None]:
# Determine the device to run the program
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

## 1. Data Processing and Embedding Creation

`Get the text → Split text into chunks → Embed the chunks → Use the embeddings`

### 1.1 Import PDF Document

In [None]:
import requests

# Get PDF document
path = "handbook-of-international-law.pdf"

# Download the PDF if it doesn't exist
if not os.path.exists(path):
    print("[INFO] PDF does not exist. Downloading...")
    url = "https://meddialogue.eu/wp-content/uploads/2021/04/HANDBOOK-OF-INTERNATIONAL-LAW-2009-Syrian-Network-for-Human-Rights.pdf"
    # Save the PDF
    file = path
    # Send a GET request to the URL
    response = requests.get(url)
    # Check if the download was successful
    if response.status_code == 200:
        # Open the file
        with open(file, "wb") as f:
            f.write(response.content)
        print(f"[INFO] The PDF has been saved as '{file}'")
    else:
        print(f"[INFO] Failed to download. Status code: {response.status_code}")

else:
    print(f"[INFO] File '{path}' exists.")

In [None]:
# Get data from the PDF
import fitz
from tqdm.auto import tqdm
def text_formatter(text: str) -> str:
    cleaned_text = text.replace("\n", " ").strip()
    return cleaned_text

# Open PDF and get lines/pages
def open_and_read_pdf(path: str) -> list[dict]:
    pdf = fitz.open(path)
    pages_and_texts = []
    for page_number, page in tqdm(enumerate(pdf)):
        text = page.get_text()
        text = text_formatter(text)
        pages_and_texts.append({"page_number": page_number - 34,  # adjust page numbers since our PDF starts on page 36
                                "page_char_count": len(text),
                                "page_word_count": len(text.split(" ")),
                                "page_sentence_count_raw": len(text.split(". ")),
                                "page_token_count": len(text) / 4,  # 1 token = ~4 chars
                                "text": text})
    return pages_and_texts

pages_and_texts = open_and_read_pdf(path=path)

### 1.2 Splitting texts into chunks

#### 1.2.1 Splitting pages into sentences

In [None]:
from spacy.lang.en import English

# Instantiate sentencizer pipeline
nlp = English()
nlp.add_pipe("sentencizer");

for item in tqdm(pages_and_texts):
    item["sentences"] = list(nlp(item["text"]).sents)
    # Make sure all sentences are strings
    item["sentences"] = [str(sentence) for sentence in item["sentences"]]
    # Count the sentences
    item["page_sentence_count_processed"] = len(item["sentences"])

#### 1.2.2 Chunking the sentences together

In [None]:
# Define split size to turn groups of sentences into chunks
num_sentence_chunk_size = 10

# Splits input list into sublists
def split_list(input_list: list,
               list_size: int) -> list[list[str]]:
    return [input_list[i:i + list_size] for i in range(0, len(input_list), list_size)]

# Loop through pages and texts and split sentences into chunks
for item in tqdm(pages_and_texts):
    item["sentence_chunks"] = split_list(input_list=item["sentences"],
                                         list_size=num_sentence_chunk_size)
    item["num_chunks"] = len(item["sentence_chunks"])
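
To make the chunking behaviour concrete, here is a small worked example on made-up data. It repeats the `split_list` definition so it can run on its own; the 25 sample sentences are purely illustrative.

```python
def split_list(input_list: list, list_size: int) -> list[list]:
    """Split a list into consecutive sublists of at most list_size items."""
    return [input_list[i:i + list_size] for i in range(0, len(input_list), list_size)]

# 25 made-up sentences standing in for one page of text
sample_sentences = [f"Sentence {n}." for n in range(1, 26)]
sample_sentence_chunks = split_list(sample_sentences, list_size=10)

chunk_sizes = [len(chunk) for chunk in sample_sentence_chunks]
print(chunk_sizes)  # [10, 10, 5]
```

The last chunk simply holds whatever is left over, so chunks at the end of a page can be much shorter than the target size.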

#### 1.2.3 Splitting each chunk into one unique item

In [None]:
import re
# Split each chunk into its item
pages_and_chunks = []
for item in tqdm(pages_and_texts):
    for sentence_chunk in item["sentence_chunks"]:
        chunk_dict = {}
        chunk_dict["page_number"] = item["page_number"]

        # Join the sentences together into a paragraph-like structure
        joined_sentence_chunk = "".join(sentence_chunk).replace("  ", " ").strip()
        # ".A" -> ". A"
        joined_sentence_chunk = re.sub(r'\.([A-Z])', r'. \1', joined_sentence_chunk)
        chunk_dict["sentence_chunk"] = joined_sentence_chunk

        # Get stats about the chunk
        chunk_dict["chunk_char_count"] = len(joined_sentence_chunk)
        chunk_dict["chunk_word_count"] = len(joined_sentence_chunk.split(" "))
        chunk_dict["chunk_token_count"] = len(joined_sentence_chunk) / 4 # 1 token = ~4 characters

        pages_and_chunks.append(chunk_dict)
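
The join-then-repair step is easier to see on a tiny input. The sketch below wraps the same two lines in a hypothetical helper, `join_sentences` (not part of the notebook), and also shows a side effect worth knowing about: the regex cannot tell a sentence boundary from an abbreviation.

```python
import re

def join_sentences(sentences: list[str]) -> str:
    """Join sentences without a separator, then restore the space after each full stop."""
    joined = "".join(sentences).replace("  ", " ").strip()
    # ".A" -> ". A"
    return re.sub(r'\.([A-Z])', r'. \1', joined)

normal_case = join_sentences(["Treaties bind states.", "Custom also binds them."])
print(normal_case)  # Treaties bind states. Custom also binds them.

abbreviation_case = join_sentences(["The U.S.A. signed."])
print(abbreviation_case)  # The U. S. A. signed.
```

For this demo the extra spaces inside abbreviations are probably harmless, but they do change the text that gets embedded.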

> **Note**: Since some pages contain little to no useful information, it's reasonable to filter them out to reduce the amount of information that needs to be processed.

In [None]:
import pandas as pd

# Get chunks with over 30 tokens
df = pd.DataFrame(pages_and_chunks)
min_token_length = 30
pages_and_chunks_valid = df[df["chunk_token_count"] > min_token_length].to_dict(orient="records")
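
As a rough illustration of what this filter removes, the sketch below applies the same threshold to two made-up chunks using a plain list comprehension instead of pandas. The sample texts and the `sample_*` names are assumptions for this example only.

```python
sample_min_token_length = 30

short_text = "Page 12"  # e.g. a stray page label
long_text = "A treaty is a written agreement between states governed by international law. " * 3

sample_chunks = [
    {"sentence_chunk": short_text, "chunk_token_count": len(short_text) / 4},
    {"sentence_chunk": long_text, "chunk_token_count": len(long_text) / 4},
]

# Keep only chunks whose estimated token count exceeds the threshold
sample_chunks_valid = [chunk for chunk in sample_chunks
                       if chunk["chunk_token_count"] > sample_min_token_length]

for chunk in sample_chunks_valid:
    print(f'{chunk["chunk_token_count"]:.2f} tokens | {chunk["sentence_chunk"][:40]}...')
```

Because the token count is only an estimate (characters divided by 4), the threshold is approximate as well.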

### 1.3 Embedding the chunks

In [None]:
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2",
                                      device=device)

# Create embeddings one by one on the selected device (GPU if available)
for item in tqdm(pages_and_chunks_valid):
    item["embedding"] = embedding_model.encode(item["sentence_chunk"])

In [None]:
# Organize data to a data frame
text_chunks_and_embeddings_df = pd.DataFrame(pages_and_chunks_valid)

In [None]:
# Save the embeddings to a CSV file for later use
output_csv_path = "text_chunks_and_embeddings.csv"
text_chunks_and_embeddings_df.to_csv(output_csv_path, index=False)
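
One caveat if the CSV is reloaded in a later session instead of reusing the in-memory data frame: CSV stores text, so the `embedding` column will most likely come back as strings rather than arrays. The sketch below shows one way to parse such a string. It assumes the whitespace-separated style of a printed NumPy array; `parse_embedding` and the sample value are hypothetical.

```python
def parse_embedding(text: str) -> list[float]:
    """Turn a string such as '[0.1 -0.2 0.3]' back into a list of floats."""
    return [float(value) for value in text.strip("[]").split()]

# Made-up stored value, including the line break that long arrays get when printed
stored_embedding = "[ 0.0125 -0.0431\n  0.0978]"
restored_embedding = parse_embedding(stored_embedding)
print(restored_embedding)  # [0.0125, -0.0431, 0.0978]
```

With pandas, a function like this could be applied to the column after `pd.read_csv`, for example `df["embedding"].apply(parse_embedding)`.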

## 2. Search and Answer

In [None]:
import numpy as np

# Convert texts and embedding df to a list of dicts
pages_and_chunks = text_chunks_and_embeddings_df.to_dict(orient="records")

# Convert embeddings to torch tensor and send to device (Note: NumPy defaults to float64 while torch defaults to float32, so the dtype is set explicitly)
embeddings = torch.tensor(np.array(text_chunks_and_embeddings_df["embedding"].tolist()),
                          dtype=torch.float32).to(device)

### 2.1 Embedding the query
`Embed the query → Perform similarity search between the embedded query and embeddings → Get the top scores`

In [None]:
# Define helper function to print wrapped text
import textwrap

def print_wrapped(text, wrap_length=80):
    wrapped_text = textwrap.fill(text, wrap_length)
    print(wrapped_text)

**Note:** A query is a question or statement whose answer you want to find in the document. The query should be a string.

In [None]:
from sentence_transformers import util

def retrieve(query: str,
             embeddings: torch.Tensor,
             embedding_model: SentenceTransformer = embedding_model,
             n_results: int = 5) -> tuple[torch.Tensor, torch.Tensor]:

    # 1. Embed the query to the same numerical space as the text examples
    query_embedding = embedding_model.encode(query, convert_to_tensor=True)

    # 2. Get similarity scores with the dot product
    dot_scores = util.dot_score(a=query_embedding, b=embeddings)[0]

    # 3. Get the top-k results (n_results defaults to 5)
    scores, indices = torch.topk(dot_scores, k=n_results)
    return scores, indices
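
For intuition, the scoring and top-k selection can be reproduced in plain Python on tiny vectors. This is only a sketch of the idea, not how `util.dot_score` or `torch.topk` are implemented; the helper names and the 2-dimensional vectors are made up.

```python
def dot(a: list[float], b: list[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def top_k(query_vec: list[float], chunk_vecs: list[list[float]], k: int = 2):
    """Return the k highest dot-product scores and the indices of the matching chunks."""
    all_scores = [dot(query_vec, vec) for vec in chunk_vecs]
    best = sorted(range(len(all_scores)), key=lambda i: all_scores[i], reverse=True)[:k]
    return [all_scores[i] for i in best], best

toy_query_vec = [1.0, 0.0]
toy_chunk_vecs = [[0.0, 1.0], [0.6, 0.8], [1.0, 0.0]]

toy_scores, toy_indices = top_k(toy_query_vec, toy_chunk_vecs, k=2)
print(toy_scores, toy_indices)  # [1.0, 0.6] [2, 1]
```

The returned indices are positions in the chunk list, which is why `ask` can later use them to look up the matching entries in `pages_and_chunks`. When all vectors have unit length, as in this toy example, the dot product equals cosine similarity.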

### 2.2 Getting an LLM

In [None]:
# Hugging Face token for specific models 
from huggingface_hub import login
login()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available

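# Decide whether to quantize based on the available GPU memory
# Default to no quantization so the variable is also defined when running on CPU
use_quantization = False
if device == "cuda":
    gpu_memory_gb = round(torch.cuda.get_device_properties(0).total_memory / (2**30))
    # Quantize only when the GPU has 19 GB of memory or less
    use_quantization = gpu_memory_gb <= 19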

# Create quantization config for smaller model loading
# For models that require 4-bit quantization
from transformers import BitsAndBytesConfig
quantization_fig = BitsAndBytesConfig(load_in_4bit=True,
                                      bnb_4bit_use_double_quant=True,
                                      bnb_4bit_quant_type="nf4",
                                      bnb_4bit_compute_dtype=torch.float16)

# Flash Attention 2 for faster inference, default to "sdpa" ("scaled dot product attention")
if (is_flash_attn_2_available()) and (torch.cuda.get_device_capability(0)[0] >= 8):
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"

# Pick a model we'd like to use
# Note: The model I'm using requires a Hugging Face login to access; otherwise an HTTPError or an OSError will be thrown
# You can also substitute a model that doesn't require login, such as "microsoft/Phi-3-mini-4k-instruct"
# For more details, read the README.md file
model_id = "google/gemma-2b-it"
print(f"[INFO] Using model: {model_id}")

# Instantiate tokenizer (tokenizer turns text into numbers ready for the model)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id, trust_remote_code=True)

# Instantiate the model
llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 low_cpu_mem_usage=False,
                                                 quantization_config=quantization_fig if use_quantization else None,
                                                 attn_implementation=attn_implementation)
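# 4-bit bitsandbytes models are placed on the GPU while loading and do not support .to(),
# so only move the model explicitly when it is not quantized
if not use_quantization:
    llm_model = llm_model.to(device)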
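
A back-of-the-envelope calculation shows why quantization helps on smaller GPUs. The sketch below estimates the memory taken by the weights alone. The parameter count is an illustrative round number, not an official figure for any model, and the estimate ignores activations, the KV cache and quantization overhead.

```python
def weight_memory_gb(num_parameters: float, bits_per_parameter: int) -> float:
    """Approximate memory needed to store the model weights, in GB (2**30 bytes)."""
    return num_parameters * bits_per_parameter / 8 / 2**30

assumed_parameters = 2.5e9  # illustrative assumption, not an official parameter count

float16_gb = weight_memory_gb(assumed_parameters, bits_per_parameter=16)
four_bit_gb = weight_memory_gb(assumed_parameters, bits_per_parameter=4)
print(f"float16: {float16_gb:.1f} GB, 4-bit: {four_bit_gb:.1f} GB")
# float16: 4.7 GB, 4-bit: 1.2 GB
```

Going from 16 bits to 4 bits per parameter cuts the weight storage to roughly a quarter, at some cost in output quality.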

### 2.3 Augmenting Prompt with Context Items

In [None]:
# Define a prompt formatter
def prompt_formatter(query: str, 
                     context_items: list[dict]) -> str:

    # Join context items into one bulleted list
    context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])

    # Create a base prompt with examples to help the model
    # Customizable prompt for the user, should change according to the PDF content
    base_prompt = """
Based on the following legal documents and case studies, please answer the query.
Don't include your thought process, just provide the answer in a clear and comprehensive way.
Strive to emulate the following examples for the ideal answer format:
Example 1:
Query: What are the elements of a contract?
Answer: A valid contract entails several elements: offer, acceptance, consideration, capacity, and legality. An offer signifies a willingness to enter into an agreement, outlining the proposed terms. Acceptance represents a clear and unequivocal agreement to the offer's terms. Consideration refers to the exchange of something of value between the parties, which can be a good or service. Capacity ensures both parties possess the legal authority to form a contract. Legality implies the contract's purpose adheres to the law.
Example 2:
Query: Describe the concept of negligence in tort law.
Answer: In tort law, negligence signifies a failure to exercise reasonable care, resulting in harm to another person or their property. It encompasses four key elements: duty of care, breach of duty, causation, and damages. The duty of care mandates that individuals act with a degree of caution to avoid foreseeable risks to others. A breach of duty occurs when someone fails to uphold this standard of care. Causation establishes a link between the breach of duty and the resulting harm. Damages represent the losses or injuries suffered by the plaintiff due to the defendant's negligence.
Example 3:
Query: Explain the Miranda rights in the United States.
Answer: The Miranda rights, established in the landmark Miranda v. Arizona case, safeguard individuals suspected of criminal activity during custodial interrogation. These rights encompass the right to remain silent, the right to an attorney, and the right to have an attorney present during questioning. If a suspect is not informed of these rights, any statements they make during questioning may be inadmissible in court.
Now use the following context items to answer the user query:
{context}
**Relevant passages:** <extract relevant passages from the context here>
**User query:** {query}
"""

    # Update base prompt with context items and query   
    base_prompt = base_prompt.format(context=context, query=query)

    # Create prompt template for instruction-tuned model
    dialogue_template = [
        {"role": "user",
         "content": base_prompt}
    ]

    # Apply the chat template
    prompt = tokenizer.apply_chat_template(conversation=dialogue_template,
                                          tokenize=False,
                                          add_generation_prompt=True)
    return prompt
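
To see what the context-joining line produces, here is a small standalone sketch with two made-up context items and a shortened template. The template text and the `sample_*` names are assumptions for illustration; only the joining expression mirrors `prompt_formatter`.

```python
sample_context_items = [
    {"sentence_chunk": "A treaty is a written agreement between states."},
    {"sentence_chunk": "Custom arises from consistent state practice."},
]

# Same joining logic as in prompt_formatter: one "- " bullet per retrieved chunk
sample_context = "- " + "\n- ".join(item["sentence_chunk"] for item in sample_context_items)

sample_template = "Context:\n{context}\n\nUser query: {query}"
sample_prompt = sample_template.format(context=sample_context, query="What is a treaty?")
print(sample_prompt)
```

One thing to keep in mind when customizing the prompt: `str.format` treats `{` and `}` in the template as placeholders, so literal braces in the template text need to be doubled (`{{` and `}}`). Braces inside the retrieved chunks or the query are safe, because substituted values are not parsed again.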

### 2.4 Deploy the Chatbot to Gradio Web App

In [None]:
import gradio as gr
# Define the function for the Gradio interface
# To enable the context items to be returned, set 'only_answer' to False
def ask(query, temperature=0.7, max_new_tokens=512, only_answer=True):
    # Get just the scores and indices of top related results
    scores, indices = retrieve(query=query,
                               embeddings=embeddings)

    # Create a list of context items
    context_items = [pages_and_chunks[i] for i in indices]

    # Add score to context items
    for i, item in enumerate(context_items):
        item["score"] = scores[i].cpu()  # return score back to CPU

    # Format the prompt with context items
    prompt = prompt_formatter(query=query,
                              context_items=context_items)

    # Tokenize the prompt
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate an output of tokens
    outputs = llm_model.generate(**input_ids,
                                 temperature=temperature,
                                 do_sample=True,
                                 max_new_tokens=max_new_tokens)

    # Turn the output tokens into text
    output_text = tokenizer.decode(outputs[0])

    # Remove the prompt, special tokens and boilerplate lead-in (these vary between models)
    output_text = output_text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "").replace("Sure, here is the answer to the user query:\n\n", "")
    if only_answer:
        return output_text
    
    return output_text, context_items

# Define the Gradio interface
interface = gr.Interface(
    fn=ask,
    inputs=[
        gr.Textbox(show_label=False, placeholder="Enter your query here..."),
    ],
    outputs="text",
)

# Launch the interface
interface.launch()
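
The string replacement in `ask` depends on the decoded text reproducing the prompt exactly. An alternative is to drop the prompt tokens by position before decoding. The sketch below shows the idea with plain lists standing in for token id tensors; the ids and the helper name `strip_prompt_tokens` are made up.

```python
def strip_prompt_tokens(output_ids: list[int], prompt_length: int) -> list[int]:
    """Keep only the tokens generated after the prompt."""
    return output_ids[prompt_length:]

sample_prompt_ids = [2, 106, 1645, 108]                 # made-up ids for the prompt
sample_output_ids = sample_prompt_ids + [651, 3448, 1]  # prompt ids followed by newly generated ids

sample_new_ids = strip_prompt_tokens(sample_output_ids, len(sample_prompt_ids))
print(sample_new_ids)  # [651, 3448, 1]
```

With the objects in `ask`, this would presumably correspond to slicing `outputs[0]` from `input_ids["input_ids"].shape[1]` onward and decoding with `tokenizer.decode(..., skip_special_tokens=True)`. I have not tested that here, so treat it as a suggestion rather than a drop-in replacement.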