## Adaptive RAG

## Import necessary libraries

In [1]:
import os
import sys
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate

from langchain_core.retrievers import BaseRetriever
from typing import Dict, Any
from langchain.docstore.document import Document
from langchain.document_loaders import DirectoryLoader
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field

# Load environment variables from a .env file
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks
from helper_functions import *
from evaluation.evaluate_rag import *



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


## Define the Query classifier class


In [2]:
class CategoryOption(BaseModel):
    category: str = Field(description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual", example="Factual")


class QueryClassifier:
    def __init__(self, temperature=0.2, model_name="gpt-4o-2024-08-06", max_tokens=4000):
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name, max_tokens=max_tokens)
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="<Role>You are a query classifier for a RAG system.</Role>\
                <Instructions>Classify the following query into one of these categories: Factual, Analytical, Opinion, or Contextual.</Instructions>\
                <Query>{query}</Query>\
                <Category>Category:</Category>\
            "
        )
        self.llm_chain = self.prompt | self.llm.with_structured_output(CategoryOption)


    def classify(self, query):
        print("Starting to classify query")
        return self.llm_chain.invoke(query).category


## Define the base retrieval class

In [3]:
class BaseRetrievalStrategy:
    def __init__(self, texts, embedding_model="text-embedding-3-large", chunk_size=1000, chunk_overlap=0, temperature=0.2, model_name="gpt-4o-2024-08-06", max_tokens=4000):
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.documents = text_splitter.create_documents(texts)
        self.db = FAISS.from_documents(self.documents, self.embeddings)
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name, max_toskens=max_tokens)


    def retrieve(self, query, k=4):
        return self.db.similarity_search(query, k=k)

## Define the factual retriever strategy

In [4]:
class RelevantScore(BaseModel):
    score: float = Field(description="The relevance score of the document to the query", example=8.0)

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print("retrieving factual")
        # Use LLM to enhance the query
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )
        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke(query).content
        print(f'enhande query: {enhanced_query}')

        # Retrieve documents using the enhanced query
        docs = self.db.similarity_search(enhanced_query, k=k*2)

        # Use LLM to rank the relevance of retrieved documents
        ranking_prompt = PromptTemplate(
            input_variables=["query", "doc"],
            template="<Role>You are a relevance ranker for a RAG system.</Role>\
                <Task>Rank the relevance of the following document to the query on a scale of 1-10.</Task>\
                <Query>{query}</Query>\
                <Document>{document}</Document>\
                <RelevanceScore>Relevance score:</RelevanceScore>"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(RelevantScore)

        ranked_docs = []
        print("ranking docs")
        for doc in docs:
            input_data = {"query": enhanced_query, "document": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        # Sort by relevance score and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

## Define Analytical Retriever Strategy

In [5]:
class SelectedIndices(BaseModel):
    indices: List[int] = Field(description="Indices of selected documents", example=[0, 1, 2, 3])

class SubQueries(BaseModel):
    sub_queries: List[str] = Field(description="List of sub-queries for comprehensive analysis", example=["What is the population of New York?", "What is the GDP of New York?"])

class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print("retrieving analytical")
        # Use LLM to generate sub-queries for comprehensive analysis
        sub_queries_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Generate {k} sub-questions for: {query}"
        )

        sub_queries_chain = sub_queries_prompt | self.llm.with_structured_output(SubQueries)

        input_data = {"query": query, "k": k}
        sub_queries = sub_queries_chain.invoke(input_data).sub_queries
        print(f'sub queries for comprehensive analysis: {sub_queries}')

        all_docs = []
        for sub_query in sub_queries:
            all_docs.extend(self.db.similarity_search(sub_query, k=2))

        # Use LLM to ensure diversity and relevance
        diversity_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="<Role>You are a document selector for a RAG system.</Role>\
                <Task>Select the most diverse and relevant set of {k} documents for the query: '{query}'.</Task>\
                <Documents>{docs}</Documents>\
                <SelectedIndices>Return only the indices of selected documents as a list of integers.</SelectedIndices>"
        )
        diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices_result = diversity_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        return [all_docs[i] for i in selected_indices_result if i < len(all_docs)]

## Define Opinion Retriever Strategy

In [6]:
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=3):
        print("retrieving opinion")
        # Use LLM to identify potential viewpoints
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="<Role>You are a viewpoint identifier for a RAG system.</Role>\
                <Task>Identify {k} distinct viewpoints or perspectives on the topic: {query}.</Task>\
                <Query>{query}</Query>\
                <Viewpoints>Return the viewpoints as a list of strings.</Viewpoints>"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        print(f'viewpoints: {viewpoints}')

        all_docs = []
        for viewpoint in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

        # Use LLM to classify and select diverse opinions
        opinion_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="<Role>You are a document selector for a RAG system.</Role>\
                <Task>Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints.</Task>\
                <Documents>{docs}</Documents>\
                <SelectedIndices>Return the indices of selected documents as a list of integers.</SelectedIndices>"
        )
        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)

        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices = opinion_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        return [all_docs[int(i)] for i in selected_indices.split() if i.isdigit() and int(i) < len(all_docs)]

## Define Contextual Retriever Strategy

In [7]:
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4, user_context=None):
        print("retrieving contextual")
        # Use LLM to incorporate user context into the query
        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="<Role>You are a query reformulator for a RAG system.</Role>\
                <Task>Given the user context: {context}, reformulate the query to best address the user's needs: {query}.</Task>\
                <Query>{query}</Query>\
                <Context>{context}</Context>\
                <ContextualizedQuery>Return the reformulated query as a string.</ContextualizedQuery>"
        )
        context_chain = context_prompt | self.llm
        input_data = {"query": query, "context": user_context or "No specific context provided"}
        contextualized_query = context_chain.invoke(input_data).content
        print(f'contextualized query: {contextualized_query}')

        # Retrieve documents using the contextualized query
        docs = self.db.similarity_search(contextualized_query, k=k*2)

        # Use LLM to rank the relevance of retrieved documents considering the user context
        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="<Role>You are a relevance ranker for a RAG system.</Role>\
                <Task>Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10.</Task>\
                <Query>{query}</Query>\
                <Context>{context}</Context>\
                <Document>{document}</Document>\
                <RelevanceScore>Relevance score:</RelevanceScore>"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(RelevantScore)
        print("ranking docs")

        ranked_docs = []
        for doc in docs:
            input_data = {"query": contextualized_query, "context": user_context or "No specific context provided", "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))


        # Sort by relevance score and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)

        return [doc for doc, _ in ranked_docs[:k]]

## Define the Adaptive Retriever Class

In [8]:
class AdaptiveRetriever:
    def __init__(self, texts: List[str]):
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        category = self.classifier.classify(query)
        strategy = self.strategies[category]
        return strategy.retrieve(query)

## Define additional retriever that inherits from langchain BaseRetriever

In [9]:
class StructuredAdaptiveRetriever(BaseRetriever):
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        return self._get_relevant_documents(query)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        return self.adaptive_retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        return self._get_relevant_documents(query)

  class StructuredAdaptiveRetriever(BaseRetriever):
  class StructuredAdaptiveRetriever(BaseRetriever):


## Define the Adaptive RAG Class

In [14]:
class AdaptiveRAG:
    def __init__(self, texts: List[str], temperature=0.2, model_name="gpt-4o-2024-08-06", max_tokens=4000):
        adaptive_retriever = AdaptiveRetriever(texts)
        self.retriever = StructuredAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        self.llm = ChatOpenAI(temperature=temperature, model_name=model_name, max_tokens=max_tokens)

        # Create a custom prompt
        prompt_template = """<Role>You are a helpful assistant.</Role>\
            <Task>Use the following pieces of context to answer the question at the end.</Task>\
            <Instructions>If you don't know the answer, just say that you don't know, don't try to make up an answer.</Instructions>\
            <Context>{context}</Context>\
            <Question>{question}</Question>\
            <Answer>Answer:</Answer>"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # Create the LLM chain
        self.llm_chain = prompt | self.llm


    def answer(self, query: str) -> str:
        docs = self.retriever.invoke(query)
        input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
        return self.llm_chain.invoke(input_data)

## Load the documents from the directory

In [11]:
# Load the documents with the type of files: .txt, .md, .pdf, .csv, .json
loader = DirectoryLoader("data", show_progress=True, use_multithreading=True)
documents = loader.load()

print(f"documents length: {len(documents)}")
print(f"documents [0]: {documents[0]}")

  from cryptography.hazmat.primitives.ciphers.algorithms import AES, ARC4
The PDF <_io.BufferedReader name='data/multi-region-application-architecture.pdf'> contains a metadata field indicating that it should not allow text extraction. Ignoring this field and proceeding. Use the check_extractable if you want to raise an error in this case
100%|██████████| 1/1 [00:14<00:00, 14.69s/it]

documents length: 1
documents [0]: page_content='Multi-Region Application Architecture

AWS Implementation Guide

George Bearden

Eric Quinones

June 2020

Copyright (c) 2020 by Amazon.com, Inc. or its affiliates.

Multi-Region Application Architecture is licensed under the terms of the Apache License Version 2.0 available at https://www.apache.org/licenses/LICENSE-2.0

Amazon Web Services – Multi-Region Application Architecture

June 2020

Table of Contents

About This Guide .................................................................................................................. 3

Overview ................................................................................................................................... 3

Cost ....................................................................................................................................... 3

Architecture Overview ...........................................................................................




## Demonstrate use-case of the method

In [12]:
texts = [doc.page_content for doc in documents]
rag_system = AdaptiveRAG(texts)

                max_toskens was transferred to model_kwargs.
                Please confirm that max_toskens is what you intended.
  adaptive_retriever = AdaptiveRetriever(texts)
                max_toskens was transferred to model_kwargs.
                Please confirm that max_toskens is what you intended.
  adaptive_retriever = AdaptiveRetriever(texts)
                max_toskens was transferred to model_kwargs.
                Please confirm that max_toskens is what you intended.
  adaptive_retriever = AdaptiveRetriever(texts)
                max_toskens was transferred to model_kwargs.
                Please confirm that max_toskens is what you intended.
  adaptive_retriever = AdaptiveRetriever(texts)


# Showcase the four different types of queries

## Factual Result

In [13]:
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
print(f"Answer: {factual_result}")

  docs = self.retriever.get_relevant_documents(query)


RecursionError: maximum recursion depth exceeded

In [25]:
"""
Showcase the four different types of queries
"""

factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
print(f"Answer: {factual_result}")

analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?").content
print(f"Answer: {analytical_result}")

opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?").content
print(f"Answer: {opinion_result}")

contextual_result = rag_system.answer("How does the Earth's position in the Solar System influence its habitability?").content
print(f"Answer: {contextual_result}")

  warn_deprecated(


clasiffying query
retrieving factual
enhande query: What is the average distance between the Earth and the Sun, and how does this distance vary throughout the year due to the elliptical nature of Earth's orbit?
ranking docs
Answer: I don't know.
clasiffying query
retrieving analytical
sub queries for comprehensive analysis: ["What is the relationship between the Earth's distance from the Sun and its average temperature?", "How do variations in the Earth's orbit (eccentricity) influence seasonal climate changes?", "What role does solar radiation play in the Earth's climate system?", "How have historical changes in the Earth's distance from the Sun impacted past climate events?"]
selected diverse and relevant documents
Answer: I don't know.
clasiffying query
retrieving analytical
sub queries for comprehensive analysis: ['What is the primordial soup theory?', 'What is the panspermia hypothesis?', 'How does the hydrothermal vent hypothesis explain the origin of life?', 'What role do clay m