


# Adaptive Retrieval-Augmented Generation (RAG) System

## Overview

This system implements a Retrieval-Augmented Generation (RAG) approach that adapts its retrieval strategy to the type of query. By using large language models (LLMs) at several stages, it aims to provide more accurate, relevant, and context-aware responses to user queries.

## Motivation

Traditional RAG systems often use a one-size-fits-all approach to retrieval, which can be suboptimal for different types of queries. Our adaptive system is motivated by the understanding that different types of questions require different retrieval strategies. For example, a factual query might benefit from precise, focused retrieval, while an analytical query might require a broader, more diverse set of information.

## Key Components

1. **Query Classifier**: Determines the type of query (Factual, Analytical, Opinion, or Contextual).

2. **Adaptive Retrieval Strategies**: Four distinct strategies tailored to different query types:
   - Factual Strategy
   - Analytical Strategy
   - Opinion Strategy
   - Contextual Strategy

3. **LLM Integration**: LLMs are used throughout the process to enhance retrieval and ranking.

4. **Generator LLM**: Generates the final response using the retrieved documents as context. The code in this notebook uses a Llama 3 model served through Ollama.

## Method Details

### 1. Query Classification

The system begins by classifying the user's query into one of four categories:
- Factual: Queries seeking specific, verifiable information.
- Analytical: Queries requiring comprehensive analysis or explanation.
- Opinion: Queries about subjective matters or seeking diverse viewpoints.
- Contextual: Queries that depend on user-specific context.
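
LLM replies to a classification prompt are not always a single clean label, so the label has to be extracted from free text. The sketch below is one possible way to do that, assuming the reply mentions the category as a whole word. The helper name `parse_category` and its `default` fallback are illustrative assumptions, not part of the notebook, whose classifier raises an error when no category is found.

```python
import re

CATEGORIES = ["Factual", "Analytical", "Opinion", "Contextual"]


def parse_category(raw_output: str, default: str = "Factual") -> str:
    """Return the category mentioned earliest in the LLM reply, or `default`.

    `parse_category` and `default` are hypothetical names used for illustration.
    """
    best_pos, best_cat = None, None
    for cat in CATEGORIES:
        # Whole-word, case-insensitive match, so "contextually" does not count.
        match = re.search(rf"\b{cat}\b", raw_output, flags=re.IGNORECASE)
        if match and (best_pos is None or match.start() < best_pos):
            best_pos, best_cat = match.start(), cat
    return best_cat if best_cat is not None else default


print(parse_category("The category is: opinion."))
```

Picking the earliest mention rather than the first entry in a fixed list avoids misclassifying a reply such as "Contextual, not Factual".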

### 2. Adaptive Retrieval Strategies

Each query type triggers a specific retrieval strategy:

#### Factual Strategy
- Enhances the original query using an LLM for better precision.
- Retrieves documents based on the enhanced query.
- Uses an LLM to rank documents by relevance.
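
The ranking step depends on turning a free-text LLM reply into a number. The following is a minimal sketch of that step under the assumption that the first number in the reply is the score. `parse_score`, `rank_top_k`, and the `score_fn` callable (a stand-in for the LLM call) are hypothetical names introduced here.

```python
import re
from typing import Callable, List


def parse_score(text: str, low: float = 1.0, high: float = 10.0) -> float:
    """Take the first number in an LLM reply and clamp it to [low, high].

    Returns 0.0 when the reply contains no number, so such documents rank last.
    """
    match = re.search(r"\d+(?:\.\d+)?", text)
    if match is None:
        return 0.0
    return max(low, min(high, float(match.group())))


def rank_top_k(docs: List[str], score_fn: Callable[[str], str], k: int) -> List[str]:
    """Score every document with score_fn (the LLM stand-in) and keep the best k."""
    scored = [(doc, parse_score(score_fn(doc))) for doc in docs]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # stable sort keeps ties in order
    return [doc for doc, _ in scored[:k]]


# Canned replies stand in for real LLM output.
replies = {"a": "3", "b": "Score: 9", "c": "n/a"}
print(rank_top_k(["a", "b", "c"], lambda doc: replies[doc], k=2))
```

A reply that restates the scale before the score (for example "on a scale of 1-10 I would say 8") would be misread by this heuristic, so instructing the model to answer with the number only remains worthwhile.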

#### Analytical Strategy
- Generates multiple sub-queries using an LLM to cover different aspects of the main query.
- Retrieves documents for each sub-query.
- Ensures diversity in the final document selection using an LLM.
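
Diversity across sub-queries can also be approximated without an extra LLM call. The sketch below swaps in a simple round-robin merge with deduplication in place of the LLM-based selection described above: it takes one result per sub-query in turn, so that no single sub-query dominates the final set. The function name `round_robin_select` and the use of plain strings for documents are assumptions made for illustration.

```python
from itertools import zip_longest
from typing import List


def round_robin_select(result_lists: List[List[str]], k: int) -> List[str]:
    """Interleave per-sub-query results, skipping duplicates, until k are chosen."""
    if k <= 0:
        return []
    selected: List[str] = []
    seen = set()
    # Each "tier" holds the i-th result of every sub-query (None where a list is shorter).
    for tier in zip_longest(*result_lists):
        for doc in tier:
            if doc is None or doc in seen:
                continue
            seen.add(doc)
            selected.append(doc)
            if len(selected) == k:
                return selected
    return selected


# Three sub-queries; the first two share a top result.
print(round_robin_select([["a", "b"], ["a", "c"], ["d"]], k=3))
```

With LangChain `Document` objects, the same idea would apply by deduplicating on `page_content` instead of on the object itself.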

#### Opinion Strategy
- Identifies different viewpoints on the topic using an LLM.
- Retrieves documents representing each viewpoint.
- Uses an LLM to select a diverse range of opinions from the retrieved documents.
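
When an LLM is asked for several viewpoints, it usually answers with a numbered or bulleted list surrounded by introductory and closing sentences. The sketch below shows one way to keep only the list entries, assuming each entry sits on its own line behind a marker such as `1.`, `2)`, `-`, or `*`. The helper name `parse_list_items` is hypothetical.

```python
import re
from typing import List


def parse_list_items(llm_output: str, max_items: int) -> List[str]:
    """Extract up to max_items entries from a numbered or bulleted LLM reply."""
    items = []
    for line in llm_output.splitlines():
        # Accept "1.", "2)", "-", "*" or "•" as list markers; ignore all other lines.
        match = re.match(r"^\s*(?:\d+[.)]|[-*•])\s+(.*\S)", line)
        if match:
            items.append(match.group(1))
    return items[:max_items]


reply = (
    "Here are three viewpoints:\n"
    "1. Abiogenesis\n"
    "2) Panspermia\n"
    "- Hydrothermal vents\n"
    "\n"
    "Let me know if you need more!"
)
print(parse_list_items(reply, max_items=3))
```

If the reply contains no list markers at all, this returns an empty list, so a caller may want to fall back to the original query in that case.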

#### Contextual Strategy
- Incorporates user-specific context into the query using an LLM.
- Performs retrieval based on the contextualized query.
- Ranks documents considering both relevance and user context.
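
The first step, folding user context into the query, comes down to building a reformulation prompt. The sketch below is a hedged illustration of such a prompt builder. The function name `build_contextual_prompt` and the exact wording of the prompt are assumptions, and the resulting string would still need to be sent to an LLM.

```python
from typing import Optional


def build_contextual_prompt(query: str, user_context: Optional[str] = None) -> str:
    """Build a prompt asking an LLM to rewrite `query` in light of `user_context`."""
    # Treat None and whitespace-only context the same way.
    context = (user_context or "").strip() or "No specific context provided"
    return (
        f"Given the user context: {context}\n"
        f"Reformulate the query to best address the user's needs: {query}\n"
        "Return only the reformulated query."
    )


print(build_contextual_prompt("How do I get started with FAISS?", "I am a beginner in Python"))
```

Asking for the reformulated query only helps keep explanatory text out of the string that is later used for similarity search.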

### 3. LLM-Enhanced Ranking

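After retrieval, a strategy can use an LLM to perform a final ranking of the documents, so that the most relevant and appropriate ones are passed to the next stage. In the simplified implementation in this notebook, only the factual strategy scores documents with the LLM; the other strategies return the first `k` retrieved documents.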

### 4. Response Generation

The final set of retrieved documents is passed to the generator LLM (a Llama 3 model served through Ollama in this notebook), which generates a response based on the query and the provided context.
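
Assembling the context is mostly string handling. The sketch below shows one possible way to number the retrieved passages and cap the total context length before filling a question-answering prompt. The `max_chars` limit, the `[i]` numbering, and the name `build_answer_prompt` are assumptions added for illustration.

```python
from typing import List

PROMPT_TEMPLATE = (
    "Use the following context to answer the question.\n"
    "{context}\n\n"
    "Question: {question}\n"
    "Answer:"
)


def build_answer_prompt(question: str, passages: List[str], max_chars: int = 4000) -> str:
    """Number the passages and add them until max_chars would be exceeded.

    The first passage is always kept, even if it alone is longer than max_chars.
    """
    parts, used = [], 0
    for i, passage in enumerate(passages, start=1):
        block = f"[{i}] {passage.strip()}"
        if parts and used + len(block) > max_chars:
            break
        parts.append(block)
        used += len(block)
    return PROMPT_TEMPLATE.format(context="\n\n".join(parts), question=question)


print(build_answer_prompt("Which planet is third from the Sun?",
                          ["The Earth is the third planet from the Sun."]))
```

Numbering the passages makes it easier to ask the model to cite which passage supports its answer, if that is desired later.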

## Benefits of This Approach

1. **Improved Accuracy**: By tailoring the retrieval strategy to the query type, the system can provide more accurate and relevant information.

2. **Flexibility**: The system adapts to different types of queries, handling a wide range of user needs.

3. **Context-Awareness**: Especially for contextual queries, the system can incorporate user-specific information for more personalized responses.

4. **Diverse Perspectives**: For opinion-based queries, the system actively seeks out and presents multiple viewpoints.

5. **Comprehensive Analysis**: The analytical strategy ensures a thorough exploration of complex topics.

## Conclusion

This adaptive RAG system extends the traditional single-strategy RAG pipeline. By choosing a retrieval strategy for each query type and using LLMs throughout the process, it aims to provide more accurate, relevant, and nuanced responses to a wide variety of user queries.

<div style="text-align: center;">

<img src="../images/adaptive_retrieval.svg" alt="adaptive retrieval" style="width:100%; height:auto;">
</div>

# Package Installation and Imports

The cell below installs the packages required to run this notebook. The code also assumes a local Ollama server with the `llama3` and `nomic-embed-text` models available.


In [1]:
# Install required packages
!pip install faiss-cpu langchain langchain-community langchain-ollama langchain-text-splitters python-dotenv


In [1]:
from typing import List

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_text_splitters import CharacterTextSplitter

In [None]:
# !pip3 install \
#   langchain==0.2.16 \
#   langchain-core==0.2.38 \
#   langchain-community==0.2.16 \
#   langchain-ollama==0.1.3


### Define the query classifier class

In [None]:
# ----------------------------
# Query Classifier
# ----------------------------
class QueryClassifier:
    def __init__(self, model_name="llama3"):
        self.llm = OllamaLLM(model=model_name)
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template=(
                "Classify the following query into one of these categories: "
                "Factual, Analytical, Opinion, or Contextual.\nQuery: {query}\nAnswer only with the category:"
            )
        )

    def classify(self, query: str) -> str:
        print("Classifying query...")
        formatted_prompt = self.prompt.format(query=query)
        raw_output = self.llm.invoke(formatted_prompt)
        # Extract only the first matching category
        for cat in ["Factual", "Analytical", "Opinion", "Contextual"]:
            if cat.lower() in raw_output.lower():
                return cat
        raise ValueError(f"Could not classify query. LLM output: {raw_output}")

### Define the base retrieval strategy class that the specialized strategies inherit from

In [None]:
# ----------------------------
# Base Retrieval Strategy
# ----------------------------
class BaseRetrievalStrategy:
    def __init__(self, texts: List[str]):
        self.embeddings = OllamaEmbeddings(model="nomic-embed-text")
        text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        self.documents = text_splitter.create_documents(texts)
        self.db = FAISS.from_documents(self.documents, self.embeddings)
        self.llm = OllamaLLM(model="llama3")

    def retrieve(self, query: str, k: int = 4) -> List[Document]:
        return self.db.similarity_search(query, k=k)

### Define Factual retriever strategy

In [None]:
# ----------------------------
# Factual Strategy
# ----------------------------
import re


class FactualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query: str, k: int = 4) -> List[Document]:
        print("Retrieving factual documents...")
        # Step 1: Enhance the query with the LLM
        prompt = f"Enhance this factual query for better information retrieval: {query}"
        enhanced_query = self.llm.invoke(prompt).strip()
        print(f"Enhanced query: {enhanced_query}")

        # Step 2: Over-retrieve (2k candidates) so the ranking step has something to filter
        docs = self.db.similarity_search(enhanced_query, k=k * 2)

        # Step 3: Ask the LLM to score each candidate
        ranking_prompt_template = (
            "On a scale of 1-10, how relevant is this document to the query: '{query}'?\n"
            "Document: {doc}\nRelevance score:"
        )
        ranked_docs = []
        for doc in docs:
            input_text = ranking_prompt_template.format(query=enhanced_query, doc=doc.page_content)
            score_str = self.llm.invoke(input_text).strip()
            # The LLM may reply with extra text (e.g. "8/10"), so take the first number
            # in the reply instead of calling float() on the whole string.
            match = re.search(r"\d+(?:\.\d+)?", score_str)
            score = float(match.group()) if match else 0.0
            ranked_docs.append((doc, score))

        # Step 4: Return the k highest-scoring documents
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

### Define Analytical retriever strategy

In [None]:
# ----------------------------
# Analytical Strategy
# ----------------------------
class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query: str, k: int = 4) -> List[Document]:
        print("Retrieving analytical documents...")
        # Step 1: Generate sub-queries
        prompt = f"Generate {k} sub-questions for: {query}"
        sub_queries_text = self.llm.invoke(prompt).strip()
        sub_queries = [sq.strip() for sq in sub_queries_text.split("\n") if sq.strip()]
        print(f"Sub-queries: {sub_queries}")

        # Step 2: Retrieve for each sub-query
        all_docs = []
        for sq in sub_queries:
            all_docs.extend(self.db.similarity_search(sq, k=2))

        # Step 3: Select top k manually (just take first k for simplicity)
        return all_docs[:k]

### Define Opinion retriever strategy

In [None]:
# ----------------------------
# Opinion Strategy
# ----------------------------
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query: str, k: int = 3) -> List[Document]:
        print("Retrieving opinion documents...")
        # Step 1: Generate viewpoints
        prompt = f"Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        viewpoints_text = self.llm.invoke(prompt).strip()
        viewpoints = [vp.strip() for vp in viewpoints_text.split("\n") if vp.strip()]
        print(f"Viewpoints: {viewpoints}")

        # Step 2: Retrieve for each viewpoint
        all_docs = []
        for vp in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {vp}", k=2))

        # Step 3: Return first k
        return all_docs[:k]

### Define Contextual retriever strategy

In [None]:
# ----------------------------
# Contextual Strategy
# ----------------------------
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query: str, k: int = 4, user_context: str = None) -> List[Document]:
        print("Retrieving contextual documents...")
        context_prompt = f"Given the user context: {user_context or 'No specific context provided'}\nReformulate the query: {query}"
        contextual_query = self.llm.invoke(context_prompt).strip()
        print(f"Contextualized query: {contextual_query}")

        docs = self.db.similarity_search(contextual_query, k=k*2)
        return docs[:k]

### Define the Adaptive retriever class

In [None]:
# ----------------------------
# Adaptive Retriever
# ----------------------------
class AdaptiveRetriever:
    def __init__(self, texts: List[str]):
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        category = self.classifier.classify(query)
        print(f"Query category: {category}")
        strategy = self.strategies[category]
        return strategy.retrieve(query)

### Define the Adaptive RAG class

In [None]:
# ----------------------------
# Adaptive RAG System
# ----------------------------
class AdaptiveRAG:
    def __init__(self, texts: List[str]):
        self.retriever = AdaptiveRetriever(texts)
        self.llm = OllamaLLM(model="llama3")
        prompt_template = """Use the following context to answer the question.
{context}

Question: {question}
Answer:"""
        self.prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

    def answer(self, query: str) -> str:
        docs = self.retriever.get_relevant_documents(query)
        context_text = "\n".join([doc.page_content for doc in docs])
        input_data = self.prompt.format(context=context_text, question=query)
        return self.llm.invoke(input_data).strip()

### Demonstrate use of this model

In [2]:
# ----------------------------
# Usage Example
# ----------------------------
texts = [
    "The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
]

rag_system = AdaptiveRAG(texts)

# Factual
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?")
print("Factual answer:", factual_result)

# Analytical
analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?")
print("Analytical answer:", analytical_result)

# Opinion
opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?")
print("Opinion answer:", opinion_result)

# Contextual
contextual_result = rag_system.answer("How does the Earth's position in the Solar System influence its habitability?")
print("Contextual answer:", contextual_result)


Classifying query...
Query category: Factual
Retrieving factual documents...
Enhanced query: Here's an enhanced version of your factual query:

"What is the average distance between the center of the Earth and the center of the Sun, considering various methods of measurement and accounting for slight variations due to the elliptical shape of both bodies' orbits?"

This revised query aims to provide more specific and accurate information by:

1. Specifying that you're looking for the average distance: This helps to distinguish from other queries that might be asking about the closest or farthest points in their orbits.
2. Providing context about measurement methods: By mentioning various methods of measurement, you're giving search algorithms a better idea of what types of information are relevant to your query (e.g., orbital mechanics, astronomical units, etc.).
3. Accounting for orbit shape: The Earth and Sun's orbits aren't perfect circles; they're elliptical, which means their dista
