# Adaptive RAG system
This project provides a simple implementation of an Adaptive RAG pipeline. The pipeline classifies each user query and performs retrieval only when the classification calls for it. It then grades the retrieved documents and, if too few of them are relevant, runs a second retrieval pass.
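
The control flow described above can be sketched with plain callables before any model or vector store is involved. This is a simplified illustration, not the notebook's actual code: `adaptive_flow` and its `classify`, `retrieve`, `grade`, and `answer` parameters are hypothetical stand-ins, and the `0.75` threshold mirrors the value used later in the notebook.

```python
from typing import Callable, Sequence


def adaptive_flow(
    query: str,
    classify: Callable[[str], bool],
    retrieve: Callable[[str, Sequence[str]], list[str]],
    grade: Callable[[str, str], bool],
    answer: Callable[[str, Sequence[str]], str],
    threshold: float = 0.75,
) -> str:
    """Route a query through zero, one, or two retrieval passes."""
    # Queries outside the indexed domain skip retrieval entirely.
    if not classify(query):
        return answer(query, [])

    docs = retrieve(query, [])
    relevant = [doc for doc in docs if grade(query, doc)]

    # Retrieve once more when nothing came back or too few documents are relevant.
    if not docs or len(relevant) / len(docs) < threshold:
        relevant.extend(retrieve(query, relevant))

    return answer(query, relevant)
```

Passing the steps in as functions keeps the routing logic testable without API keys; in the notebook the same roles are played by the LLM and the Chroma vector store.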

## Generate API Key

### Gemini Flash
To use the ChatGoogleGenerativeAI model, you need an API key. You can generate your API key by following these steps:

1. Go to [Google AI Studio API Key Generation](https://aistudio.google.com/app/apikey).
2. Follow the instructions to generate your API key.

### Jina Embedding
1. Go to [Jina Embeddings](https://jina.ai/embeddings/).
2. Copy the API key from the API key and Billing section.

## Install Poppler and Tesseract
For handling PDFs and unstructured data, you will need to install Poppler and Tesseract. Follow the installation instructions below:

* Poppler: [Installation Instructions](https://pdf2image.readthedocs.io/en/latest/installation.html)
* Tesseract: [Installation Instructions](https://tesseract-ocr.github.io/tessdoc/Installation.html)

### For Colab or Ubuntu
Run the following commands in a notebook cell (in a regular terminal, drop the leading `!`):

```
!sudo apt-get install -y poppler-utils
!sudo apt-get install -y tesseract-ocr
```

The Installation section below already includes these commands. <br>
**Note: During installation on Colab, a "Restart required" message will pop up. Restart the runtime when prompted, as this is required to complete the installation.**
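
After installing, it can help to confirm that both tools are visible on the `PATH` before parsing any PDFs. The check below is a small sketch using only the standard library; `missing_tools` is a hypothetical helper, and it assumes that Poppler is detected through its `pdftoppm` binary (shipped with `poppler-utils`) and Tesseract through the `tesseract` command.

```python
import shutil
from typing import Callable, Iterable, Optional


def missing_tools(
    names: Iterable[str],
    which: Callable[[str], Optional[str]] = shutil.which,
) -> list[str]:
    """Return the executables from `names` that cannot be found on PATH."""
    return [name for name in names if which(name) is None]


# pdftoppm is part of poppler-utils; tesseract is the Tesseract OCR CLI.
missing = missing_tools(["pdftoppm", "tesseract"])
print("Missing tools:", missing or "none")
```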

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SS-Keval/Multimodal-RAG-Meetup/blob/main/adaptive_rag.ipynb)

## Installation

In [None]:
!pip install chromadb langchain langchain_community langchain-chroma "unstructured[all-docs]" "langchain-unstructured[local]" langchain-google-genai

In [None]:
!pip install --upgrade nltk

In [None]:
import nltk
nltk.download('punkt_tab')

In [None]:
!apt-get install poppler-utils

In [None]:
!apt-get install tesseract-ocr

## Constants

In [None]:
# Embeddings model
EMBEDDING_MODEL_NAME = "jina-embeddings-v2-base-en"
JINA_API_KEY = "[ENTER API KEY]"

# LLM
LLM_NAME = "gemini-1.5-flash"
GOOGLE_API_KEY = "[ENTER API KEY]"

# File path
FILE_PATH = "[ENTER FILE PATH]"
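
Hardcoding keys in a notebook makes them easy to leak when the file is shared. One possible alternative, sketched here under the assumption that the keys are exported as environment variables, is to read them at runtime. `read_api_key` is a hypothetical helper and the variable names are a convention, not something the libraries require.

```python
import os


def read_api_key(env_var: str, placeholder: str = "[ENTER API KEY]") -> str:
    """Return the key stored in `env_var`, or fail with a clear message."""
    value = os.environ.get(env_var, "").strip()
    # Treat an unset variable and the notebook's placeholder text the same way.
    if not value or value == placeholder:
        raise RuntimeError(
            f"Set the {env_var} environment variable before running the notebook."
        )
    return value


# Hypothetical usage in the Constants cell:
# JINA_API_KEY = read_api_key("JINA_API_KEY")
# GOOGLE_API_KEY = read_api_key("GOOGLE_API_KEY")
```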

## Parsing

In [None]:
from langchain_core.documents import Document
from unstructured.partition.pdf import partition_pdf
from unstructured.chunking.title import chunk_by_title

In [None]:
raw_pdf_elements = partition_pdf(filename=FILE_PATH)

In [None]:
chunks = chunk_by_title(
  raw_pdf_elements,
  max_characters=2000,
  overlap=100,
  multipage_sections=True,
  new_after_n_chars=1000,
  combine_text_under_n_chars=200
)

In [None]:
docs = []

for chunk in chunks:
  doc = Document(
    page_content=chunk.text,
    metadata={
      "page_no": chunk.metadata.page_number
    }
  )
  docs.append(doc)
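
Before indexing, a quick look at the chunk sizes can show whether the chunking settings behave as intended, for example by comparing the longest chunk against `max_characters=2000`. The helper below is a minimal sketch that works on plain strings; `length_summary` is a hypothetical name, and the sample texts are made up for illustration.

```python
from statistics import mean


def length_summary(texts: list[str]) -> dict[str, float]:
    """Summarize character lengths of a list of chunk texts."""
    lengths = [len(text) for text in texts]
    if not lengths:
        return {"count": 0, "min": 0, "max": 0, "mean": 0.0}
    return {
        "count": len(lengths),
        "min": min(lengths),
        "max": max(lengths),
        "mean": mean(lengths),
    }


# In the notebook this would be: length_summary([doc.page_content for doc in docs])
print(length_summary(["alpha", "be", "gamma ray"]))
```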

## Indexing

In [None]:
import chromadb
from langchain_chroma import Chroma
from langchain_community.embeddings import JinaEmbeddings

In [None]:
chroma = chromadb.PersistentClient()

In [None]:
ef = JinaEmbeddings(
  jina_api_key=JINA_API_KEY, model_name=EMBEDDING_MODEL_NAME
)

In [None]:
collection = chroma.get_or_create_collection("data")

vectordb = Chroma(
  client=chroma,
  collection_name="data",
  embedding_function=ef
)

In [None]:
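# Index in batches to keep each embedding request small
BATCH_SIZE = 32

for start_index in range(0, len(docs), BATCH_SIZE):
  end_index = min(start_index + BATCH_SIZE, len(docs))
  # Use a separate name so the parsed `chunks` from the previous section are not overwritten
  batch = docs[start_index:end_index]

  vectordb.add_documents(batch)
  print(f"Indexed chunks {start_index+1}-{end_index}")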
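
The slicing pattern used in the loop above can also be factored into a small reusable generator. This is only a sketch; `batched` is a hypothetical helper written for this example, and the batch size of 32 mirrors the value used in the notebook.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items`, each with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Hypothetical use with the notebook's objects:
# for batch in batched(docs, 32):
#     vectordb.add_documents(batch)
```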


## LLM

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

In [None]:
llm_client = ChatGoogleGenerativeAI(
  model=LLM_NAME,
  google_api_key=GOOGLE_API_KEY
)

## Adaptive RAG

Tools for query classification and document grading

In [None]:
from pydantic import BaseModel, Field


class QueryClassification(BaseModel):
  """Check whether the query is related to index"""

  is_finance_related: bool = Field(description="True if the query is related to finance domain, else False")


class DocumentGrader(BaseModel):
  """Grade document either as relevant or irrelevant based on user query"""

  is_relevant: bool = Field(description="Whether the document is useful in answering the query")


In [None]:
def _add_context_to_query(query: str, docs: list[Document]) -> str:
  """
  Helper function to add context to retriever query

  Args:
    query (str): Input user query
    docs (list[Document]): Relevant documents

  Returns:
    str: Query appended with the relevant documents
  """
  context = "\n\n".join(doc.page_content for doc in docs)
  query_with_context = query + "\n\n" + context

  return query_with_context

def classify_query(query: str) -> bool:
  """
  Classify the input query as related or unrelated to the index

  Args:
    query (str): Input user query

  Returns:
    bool: True if the query is related to index else False
  """
  classifier = llm_client.with_structured_output(QueryClassification)
  out_class = classifier.invoke(query)

  # The structured output may be empty; treat that as "not related"
  return bool(out_class and out_class.is_finance_related)

def retriever(vectorstore: Chroma, query: str, docs: list[Document] | None = None) -> list[Document]:
  """
  Fetch documents relevant to the query from the vectorstore
  Relevant documents, if found in previous iteration, are appended to the query

  Args:
    vectorstore (Chroma): VectorDB instance
    query (str): Input query
    docs (list[Document] | None): Relevant documents found in previous iteration

  Returns:
    list[Document]: Documents semantically similar to user query
  """
  retriever_query = query
  if docs:
    retriever_query = _add_context_to_query(query, docs)

  docs = vectorstore.max_marginal_relevance_search(
      query=retriever_query,
      k=5,
      fetch_k=20,
  )

  return docs

def doc_grader(query: str, docs: list[Document]) -> tuple[list[bool], list[Document]]:
  """
  Grades the documents fetched by retriever

  Args:
    query (str): Input query
    docs (list[Document]): Documents to be graded

  Returns:
    list[bool]: Grade corresponding to each document
    list[Document]: Relevant documents
  """
  grades = []
  relevant_docs = []

  grader = llm_client.with_structured_output(DocumentGrader)
  for doc in docs:
    prompt = f"""Document:
{doc.page_content}

Query:
{query}
"""
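    result = grader.invoke(prompt)
    # The structured output may be empty; treat that as "not relevant"
    is_relevant = bool(result and result.is_relevant)
    grades.append(is_relevant)
    if is_relevant:
      relevant_docs.append(doc)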

  return grades, relevant_docs

def generate_final_answer(query: str, relevant_docs: list[Document]) -> str:
  """
  Final answer to user query using fetched documents

  Args:
    query (str): Input query
    relevant_docs (list[Document]): Relevant documents fetched by the retriever

  Returns:
    str: Final answer
  """
  context = "\n\n".join(doc.page_content for doc in relevant_docs)
  messages = [
      (
          "system",
          "You are a helpful assistant that answers user queries. You will be provided with the context required to answer the user query. Answer queries using information from the provided context only."
      ),
      (
          "human",
          f"""Query:
{query}

Context:
{context}
"""
      )
  ]

  final_answer = llm_client.invoke(messages)

  return final_answer.content
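
The `retriever` function above relies on `max_marginal_relevance_search`, which fetches `fetch_k=20` candidates and keeps `k=5` of them. The general idea of maximal marginal relevance (MMR) is to prefer results that are similar to the query but not to each other. The following pure-Python sketch illustrates that selection rule on toy vectors; it is not LangChain's or Chroma's implementation, and `cosine`, `mmr_select`, the example vectors, and the default `lambda_mult=0.5` are assumptions made for illustration.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero length)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def mmr_select(
    query_vec: list[float],
    doc_vecs: list[list[float]],
    k: int,
    lambda_mult: float = 0.5,
) -> list[int]:
    """Return indices of up to `k` documents chosen greedily by MMR."""
    selected: list[int] = []
    candidates = list(range(len(doc_vecs)))
    while candidates and len(selected) < k:
        best_idx, best_score = candidates[0], -math.inf
        for i in candidates:
            relevance = cosine(query_vec, doc_vecs[i])
            # Penalize similarity to the closest document already selected.
            redundancy = max(
                (cosine(doc_vecs[i], doc_vecs[j]) for j in selected), default=0.0
            )
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
        candidates.remove(best_idx)
    return selected


# Two near-duplicate vectors (indices 0 and 1) and one distinct vector (index 2).
toy_docs = [[1.0, 0.1], [1.0, 0.12], [0.7, -0.7]]
print(mmr_select([1.0, 0.0], toy_docs, k=2))
```

With `lambda_mult=1.0` the redundancy penalty disappears and the selection reduces to a plain similarity ranking.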

In [None]:
def run_pipeline(query: str, vectorstore: Chroma) -> str:
  """
  Runs the adaptive RAG pipeline with one or two retrieval passes

  Args:
    query (str): Input query
    vectorstore (Chroma): VectorDB instance

  Returns:
    str: Response to the query
  """
  docs = retriever(vectorstore, query)
  grades, relevant_docs = doc_grader(query, docs)

  print(f"Retrieval grade - {sum(grades)}/{len(grades)}")

  # Retrieve again if nothing was returned or fewer than 75% of the documents are relevant
  if not grades or sum(grades) / len(grades) < 0.75:
    print("Performing iterative retrieval")
    new_docs = retriever(vectorstore, query, relevant_docs)
    # Skip documents that are already part of the context
    seen = {doc.page_content for doc in relevant_docs}
    relevant_docs.extend(doc for doc in new_docs if doc.page_content not in seen)

  return generate_final_answer(query, relevant_docs)
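
The retry decision and the merging of results can be pulled out into small pure functions, which makes them easy to test without calling the LLM. This is a sketch under stated assumptions: `needs_second_pass` and `merge_unique` are hypothetical helpers, documents are represented by their text only, and an empty result set is treated as a reason to retrieve again.

```python
def needs_second_pass(grades: list[bool], threshold: float = 0.75) -> bool:
    """Return True when the share of relevant documents is below `threshold`."""
    # An empty list means nothing was retrieved, so a second pass is warranted.
    if not grades:
        return True
    return sum(grades) / len(grades) < threshold


def merge_unique(first: list[str], second: list[str]) -> list[str]:
    """Concatenate two lists of document texts, dropping repeats and keeping order."""
    seen: set[str] = set()
    merged: list[str] = []
    for text in [*first, *second]:
        if text not in seen:
            seen.add(text)
            merged.append(text)
    return merged
```

Note that the comparison is strict: with the default threshold, three relevant documents out of four (exactly 0.75) do not trigger a second pass.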

In [None]:
def adaptive_rag(query: str) -> str:
  """
  Adaptive RAG pipeline to respond to user query by running 0 or more retrieval steps

  Args:
    query (str): Input query

  Returns:
    str: Response to the query (empty string if the query is empty)
  """
  query = query.strip()

  if not query:
    return ""

  # Only finance-related queries go through retrieval; all others are answered by the LLM directly
  is_finance_related = classify_query(query)
  if is_finance_related:
    return run_pipeline(query=query, vectorstore=vectordb)

  messages = [
      (
          "system",
          "You are a helpful assistant that answers user queries. Be respectful and concise in your responses."
      ),
      (
          "human",
          query
      )
  ]
  final_answer = llm_client.invoke(messages)
  return final_answer.content

In [None]:
print(adaptive_rag("What is the capital of India?"))