# Adaptive RAG

In [1]:
import os
from dotenv import load_dotenv
load_dotenv()

os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")

#LLM Initiation

In [2]:
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama3-8b-8192", max_tokens=1000)

#Query Classifier

In [3]:
from pydantic import BaseModel, Field

class categories_opinions(BaseModel):
    category: str = Field(description="The category of the query, the options are: Fatual, Analytical, Opinion, or Contextual", example="Factual")

In [4]:
from langchain_core.prompts import PromptTemplate

class QueryClassifier:
    def __init__(self):
        self.llm = llm
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="""Classify the following query into one of the following categories: Factual, Analytical, Opinion or Contextual.
                \n\nQuery: {query}
                \n\nCategory:
            """
        )
        self.chain = self.prompt | self.llm.with_structured_output(categories_opinions)

    def classify(self, query):
        print("classifying query...")
        return self.chain.invoke(query).category

#Base Retrieval Strategy

In [5]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS

class BaseRetrievalStrategy:
    def __init__(self, text):
        self.embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        self.documents = self.text_splitter.create_documents(text)
        self.db = FAISS.from_documents(self.documents, self.embedding)
        self.llm = llm

    def retrieve(self, query, k=4):
        return self.db.similarity_search(query, k= k)

#Factual Retrieval strategy

In [6]:
class relevant_score(BaseModel):
    score: float = Field(description="The relevance score of the response to the query", examples=8.0)

In [7]:
class FactualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print("retrieving factual")

        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )
        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke(query).content
        print(f"enhanced query: {enhanced_query}")

        docs = self.db.similarity_search(enhanced_query, k=k * 2)

        ranking_prompt = PromptTemplate(
            input_variables=["query", "doc"],
            template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        ranked_docs = []
        print("ranking docs")
        for doc in docs:
            input_data = {"query": enhanced_query, "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

#Analytical Retriever Strategy

In [8]:
from typing_extensions import List

class SelectedIndices(BaseModel):
    indices: List[int] = Field(description="Indices of selected documents", example=[0, 1, 2, 3])

In [9]:
class SubQueries(BaseModel):
    sub_queries: List[str] = Field(description="List of sub-queries for comprehensive analysis", 
                                   example=["What is the population of New York?", "What is the GDP of New York?"])

In [10]:
class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print("retrieving analytical")
        sub_queries_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Generate {k} sub-questions for: {query}"
        )
        sub_queries_chain = sub_queries_prompt | self.llm.with_structured_output(SubQueries)
        input_data = {"query": query, "k": k}
        sub_queries = sub_queries_chain.invoke(input_data).sub_queries
        print(f'sub queries for comprehensive analysis: {sub_queries}')

        all_docs = []
        for sub_query in sub_queries:
            all_docs.extend(self.db.similarity_search(sub_query, k=2))

        diversity_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="""Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {docs}\n
            Return only the indices of selected documents as a list of integers."""
        )
        diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices_result = diversity_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        return [all_docs[i] for i in selected_indices_result if i < len(all_docs)]

#Opinion Retriever Strategy

In [11]:
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=3):
        print("retrieving opinion")
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        print(f'viewpoints: {viewpoints}')

        all_docs = []
        for viewpoint in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

        opinion_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
        )
        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices = opinion_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')

        return [all_docs[int(i)] for i in selected_indices.split() if i.isdigit() and int(i) < len(all_docs)]

#Contextual Retriever Strategy

In [12]:
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4, user_context=None):
        print("retrieving contextual")
        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
        )
        context_chain = context_prompt | self.llm
        input_data = {"query": query, "context": user_context or "No specific context provided"}
        contextualized_query = context_chain.invoke(input_data).content
        print(f'contextualized query: {contextualized_query}')

        docs = self.db.similarity_search(contextualized_query, k=k*2)

        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)
        print("ranking docs")

        ranked_docs = []
        for doc in docs:
            input_data = {"query": contextualized_query, "context": user_context or "No specific context provided", "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        ranked_docs.sort(key=lambda x: x[1], reverse=True)

        return [doc for doc, _ in ranked_docs[:k]]


#Adaptive Retriever

In [13]:
from langchain_core.documents import Document

class AdaptiveRetriever:
    def __init__(self, texts: List[str]):
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        category = self.classifier.classify(query)
        strategy = self.strategies[category]
        return strategy.retrieve(query)

#Additional Retriever

In [14]:
from langchain_core.retrievers import BaseRetriever

class PydanticAdaptiveRetriever(BaseRetriever):
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        return self.adaptive_retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        return self.get_relevant_documents(query)

  class PydanticAdaptiveRetriever(BaseRetriever):
  class PydanticAdaptiveRetriever(BaseRetriever):


#Adaptive RAG class

In [15]:
class AdaptiveRAG:
    def __init__(self, texts: List[str]):
        adaptive_retriever = AdaptiveRetriever(texts)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        self.llm = llm
        
        # Create a custom prompt
        prompt_template = """Use the following pieces of context to answer the question at the end. 
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        Answer:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        
        # Create the LLM chain
        self.llm_chain = prompt | self.llm

    def answer(self, query: str) -> str:
        docs = self.retriever.get_relevant_documents(query)
        input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
        return self.llm_chain.invoke(input_data)

#UseCase

In [16]:
texts = [
    "The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
    ]
rag_system = AdaptiveRAG(texts) 

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
print(f"Answer: {factual_result}")

analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?").content
print(f"Answer: {analytical_result}")

opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?").content
print(f"Answer: {opinion_result}")

contextual_result = rag_system.answer("How does the Earth's position in the Solar System influence its habitability?").content
print(f"Answer: {contextual_result}")

  docs = self.retriever.get_relevant_documents(query)


classifying query...
retrieving factual
enhanced query: To enhance this factual query for better information retrieval, I'd suggest rephrasing it to provide more context and specificity. Here's a revised query:

"Please provide the average distance between the Earth and the Sun, including the closest and farthest points in their elliptical orbit, and any notable variations in their distance throughout the year."

This revised query:

1. Adds specificity by asking for the average distance, as well as the closest and farthest points, which will provide a more comprehensive understanding of the Earth-Sun distance.
2. Provides context by mentioning the elliptical orbit, which will help the search engine understand the dynamic nature of the Earth-Sun distance.
3. Asks for notable variations throughout the year, which will allow the search engine to provide information on the changing distance between the Earth and the Sun due to their elliptical orbits.

By rephrasing the query in this way,

BadRequestError: Error code: 400 - {'error': {'message': "schema is not valid JSON Schema for tool relevant_score parameters: jsonschema file:///home/di/params.json compilation failed: '/properties/score/examples' does not validate with https://json-schema.org/draft/2020-12/schema#/allOf/1/$ref/properties/properties/additionalProperties/$dynamicRef/allOf/4/$ref/properties/examples/type: expected array, but got number", 'type': 'invalid_request_error'}}