# Corrective RAG (CRAG)
---

### What is Corrective RAG?

Corrective RAG (CRAG) extends Retrieval-Augmented Generation (RAG) with a step that evaluates the retrieved documents and refines the knowledge before generation. The pipeline checks the search results first and, if they are judged insufficient, runs an auxiliary search so that the final answer rests on better context.

- Retrieval Grader: Evaluates the relevance of retrieved documents and assigns a score to each document.
- Web Search Integration: If the quality of the retrieved documents is low, CRAG augments the retrieval results with a web search. The query is rewritten first to improve the web search results.
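
The two components above can be sketched as one control flow. This is a minimal illustration, not the implementation from the paper or from this notebook: `corrective_rag` and its callable parameters are hypothetical names, the threshold of 3.0 mirrors the value used later in this notebook, and generating with the original query after the web search is an assumption.

```python
from typing import Callable


def corrective_rag(
    query: str,
    retrieve: Callable[[str], str],
    grade: Callable[[str, str], float],
    rewrite: Callable[[str], str],
    web_search: Callable[[str], str],
    generate: Callable[[str, str], str],
    threshold: float = 3.0,
) -> str:
    """Grade the retrieved context and fall back to web search when it scores low."""
    context = retrieve(query)
    if grade(query, context) >= threshold:
        # Retrieved context is good enough: generate directly.
        return generate(query, context)
    # Corrective branch: rewrite the query, search the web, then generate.
    web_context = web_search(rewrite(query))
    return generate(query, web_context)


# Toy stand-ins for the real components (all hypothetical).
def demo(score: float) -> str:
    return corrective_rag(
        "newest hotels",
        retrieve=lambda q: "index context",
        grade=lambda q, c: score,
        rewrite=lambda q: q + " 2024",
        web_search=lambda q: f"web results for {q}",
        generate=lambda q, c: f"answer from: {c}",
    )


print(demo(4.0))
print(demo(2.0))
```

Passing the components as callables keeps the routing decision testable without any cloud service.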

**Reference**

- [Corrective RAG paper](https://arxiv.org/pdf/2401.15884)  

In [None]:
from dotenv import load_dotenv
import os
import json
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.models import VectorizableTextQuery
from azure.ai.evaluation import GroundednessEvaluator, RelevanceEvaluator, RetrievalEvaluator
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage, AssistantMessage
from autogen_core import MessageContext, RoutedAgent, SingleThreadedAgentRuntime, TopicId, message_handler, type_subscription
from pydantic import BaseModel
from typing import List
from dataclasses import dataclass


load_dotenv(override=True)

In [None]:
# Get the environment variables
azure_ai_search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
search_credential = AzureKeyCredential(os.getenv("AZURE_AI_SEARCH_API_KEY", "")) if len(os.getenv("AZURE_AI_SEARCH_API_KEY", "")) > 0 else DefaultAzureCredential()
index_name = os.getenv("AZURE_SEARCH_INDEX_NAME", "hotels-sample-index")

azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_openai_key = os.getenv("AZURE_OPENAI_API_KEY", "") if len(os.getenv("AZURE_OPENAI_API_KEY", "")) > 0 else None
azure_openai_chat_deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
azure_openai_embedding_deployment_name = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME", "text-embedding-ada-002")
azure_openai_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-06-01")

bing_subscription_key = os.getenv("BING_SUBSCRIPTION_KEY", "") if len(os.getenv("BING_SUBSCRIPTION_KEY", "")) > 0 else None

model_config = {
    "azure_endpoint": azure_openai_endpoint,
    "api_key": azure_openai_key,
    "azure_deployment": azure_openai_chat_deployment_name,
    "api_version": azure_openai_api_version,
    "type": "azure_openai",
}

## 🧪 Step 1. Test and Construct each module
---

Before building the entire pipeline, we will construct and test each module separately.

- **SearchClient(Retrieval)**
- **Retrieval Grader**
- **Answer Generator**
- **Question Re-writer**
- **Web Search Tool**

### Construct Retrieval Client based on the hotels sample index
- We use the hotels-sample-index, which can be created in minutes and runs on any search service tier. This index is created by a wizard using built-in sample data.
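
Search hits from this index arrive as dictionary-like documents. The sketch below shows one way to turn them into a context string while tolerating missing fields. `format_sources` and the sample hits are hypothetical; only the field names `HotelName`, `Description`, and `Tags` come from this notebook.

```python
from typing import Any, Iterable, Mapping


def format_sources(documents: Iterable[Mapping[str, Any]]) -> str:
    """Render each hit as 'HotelName: Description (tags: a, b)', one hit per line."""
    lines = []
    for doc in documents:
        name = doc.get("HotelName", "Unknown hotel")
        description = doc.get("Description", "")
        tags = doc.get("Tags") or []  # Tags may be missing or None
        tag_text = ", ".join(str(tag) for tag in tags)
        lines.append(f"{name}: {description} (tags: {tag_text})")
    return "\n".join(lines)


# Hypothetical hits shaped like the fields selected in this notebook.
sample_hits = [
    {"HotelName": "Hotel A", "Description": "Free breakfast daily.", "Tags": ["breakfast", "wifi"]},
    {"HotelName": "Hotel B", "Description": "Rooftop pool.", "Tags": None},
]
formatted = format_sources(sample_hits)
print(formatted)
```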

In [None]:
azure_ai_search_endpoint = os.getenv("AZURE_AI_SEARCH_ENDPOINT")
azure_search_admin_key = os.getenv("AZURE_AI_SEARCH_API_KEY", "")
# Note: a semantic configuration is a query-time option passed to search(), not a SearchClient constructor argument.
search_client = SearchClient(
    endpoint=azure_ai_search_endpoint,
    index_name=index_name,
    credential=AzureKeyCredential(azure_search_admin_key),
)

# Query is the question being asked. It's sent to the search engine and the LLM.
query="Can you recommend a few hotels with complimentary breakfast?"

# Name of the vector field in the index; adjust it to match your index schema.
fields = "descriptionVector"
# exhaustive=True forces a full (non-approximate) k-NN scan. That is fine for this small sample index, but avoid it on large indexes.
vector_query = VectorizableTextQuery(text=query, k_nearest_neighbors=2, fields=fields, exhaustive=True)

# Hybrid search: keyword search on the query text combined with the vector query.
# Return the top 3 matches, restricted to the selected fields.
search_results = search_client.search(
    search_text=query,
    vector_queries=[vector_query],
    select=["Description", "HotelName", "Tags"],
    top=3,
)
sources_formatted = "\n".join([f'{document["HotelName"]}:{document["Description"]}:{document["Tags"]}' for document in search_results])

print(sources_formatted)

### Define your LLM

This hands-on uses only `gpt-4o-mini`, but you can use multiple models in the pipeline.

In [None]:
# aoai_client = AzureOpenAI(
#     azure_endpoint=azure_openai_endpoint,
#     api_key=azure_openai_key,
#     api_version=azure_openai_api_version,
# )

# AutoGen chat completion client used to call the Azure OpenAI chat API.
# It is a different object from the raw AzureOpenAI client commented out above.
# Assumption: the deployment name matches the model name (for example, gpt-4o-mini).
autogen_aoai_client = AzureOpenAIChatCompletionClient(
    azure_endpoint=azure_openai_endpoint,
    azure_deployment=azure_openai_chat_deployment_name,
    model=azure_openai_chat_deployment_name,
    api_version=azure_openai_api_version,
    api_key=azure_openai_key,
)

### Question-Retrieval Grader

Construct a retrieval grader that evaluates the relevance of the retrieved documents to the input question. The retrieval grader should take the input question and the retrieved documents as input and output a relevance score for each document.<br>
Note that the retrieval grader should be able to handle **multiple documents** as input.
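
One way to handle multiple documents is to score each document separately and keep the scores next to the documents. The sketch below is an illustration under assumptions: `grade_documents` is a hypothetical helper, and `overlap_score` is a toy word-overlap scorer standing in for an LLM-based evaluator, mapped onto a 1 to 5 scale.

```python
from typing import Callable, List, Sequence, Tuple


def overlap_score(query: str, document: str) -> float:
    """Toy stand-in scorer: map the share of query words found in the document to 1-5."""
    query_words = set(query.lower().split())
    doc_words = set(document.lower().split())
    if not query_words:
        return 1.0
    overlap = len(query_words & doc_words) / len(query_words)
    return 1.0 + 4.0 * overlap


def grade_documents(
    query: str,
    documents: Sequence[str],
    scorer: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Return (document, score) pairs, one per input document, in input order."""
    return [(doc, scorer(query, doc)) for doc in documents]


docs = ["hotel with free breakfast", "airport parking rates"]
graded = grade_documents("hotel breakfast", docs, overlap_score)
for doc, score in graded:
    print(f"{score:.1f}  {doc}")
```

Swapping `overlap_score` for a call to a real evaluator keeps the per-document structure unchanged.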

In [None]:
retrieval_eval = RetrievalEvaluator(model_config)

query_response = dict(
    query=query,
    context=sources_formatted
)

relevance_score = retrieval_eval(
    **query_response
)
print(relevance_score)
relevance_score['retrieval']


### Answer Generator

Construct an LLM generation node. This is a naive RAG chain that generates an answer based on the retrieved documents.

For production, we recommend using a more advanced RAG chain.
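
The shape of such a naive chain can be sketched without any service dependency: build a grounded prompt, call a model, and parse the structured reply defensively. Everything here is an assumption for illustration: `naive_rag`, `PROMPT_TEMPLATE`, and `fake_llm` are hypothetical, and the JSON layout (`recommendation`, `hotel_name`, `description`) follows the response model defined in the next cell.

```python
import json
from typing import Callable, Dict, List

PROMPT_TEMPLATE = "Answer using only the context below.\nQuery: {query}\nContext:\n{context}"


def naive_rag(query: str, context: str, llm: Callable[[str], str]) -> List[Dict[str, str]]:
    """Build the prompt, call the model, and return the parsed recommendations."""
    prompt = PROMPT_TEMPLATE.format(query=query, context=context)
    raw = llm(prompt)
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return []  # the model did not return valid JSON
    if not isinstance(payload, dict):
        return []
    items = payload.get("recommendation", [])
    return [item for item in items if isinstance(item, dict) and "hotel_name" in item]


def fake_llm(prompt: str) -> str:
    """Hypothetical stand-in for a chat model; always returns one hotel."""
    return json.dumps(
        {"recommendation": [{"hotel_name": "Hotel A", "description": "Free breakfast."}]}
    )


result = naive_rag("hotels with breakfast", "Hotel A: Free breakfast.", fake_llm)
print(result)
```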

In [None]:
from pydantic import BaseModel
from typing import List

class HotelInfo(BaseModel):
    hotel_name: str
    description: str

class RecommendationList(BaseModel):
    recommendation: List[HotelInfo]

In [None]:
# This prompt provides instructions to the model
GROUNDED_PROMPT="""
You are a friendly assistant that recommends hotels based on activities and amenities.
Answer the query using only the context provided below in a friendly and concise bulleted manner.
Answer ONLY with the facts listed in the list of context below.
If there isn't enough information below, say you don't know.
Generate a response that includes the top 3 results.
Do not generate answers that don't use the context below.
Query: {query}
Context:\n{context}
"""

# Send the search results and the query to the LLM to generate a response based on the prompt.
response = await autogen_aoai_client.create(
    messages=[
        UserMessage(content=GROUNDED_PROMPT.format(query=query, context=sources_formatted), source="user"),
    ],
    extra_create_args={"response_format": RecommendationList},
)

response_content = json.loads(response.content)
for recommendation in response_content['recommendation']:
    print(recommendation)

### Keyword Re-writer

Construct a `keyword_rewriter` agent that rewrites the question into search keywords.
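
A model may echo the prompt label, wrap its answer in quotes, or return nothing at all. The sketch below cleans the rewritten query before it is sent to a search engine. `clean_rewritten_query` is a hypothetical helper, and the 200-character cap is an arbitrary assumption.

```python
import re


def clean_rewritten_query(raw: str, fallback: str, max_length: int = 200) -> str:
    """Normalize a rewritten query; fall back to the original when nothing usable remains."""
    text = raw.strip()
    # Drop the prompt label if the model repeats it.
    text = re.sub(r"^revised web search query:\s*", "", text, flags=re.IGNORECASE)
    # Remove surrounding quotes and collapse runs of whitespace.
    text = text.strip().strip("\"'")
    text = re.sub(r"\s+", " ", text).strip()
    if not text:
        return fallback
    return text[:max_length]


cleaned = clean_rewritten_query(
    'Revised web search query: "hotels  free breakfast"',
    fallback="Can you recommend a few hotels with complimentary breakfast?",
)
print(cleaned)
```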

In [None]:
query="Can you recommend a few hotels with complimentary breakfast?"

# This prompt provides instructions to the model
KEYWORD_REWRITE_PROMPT="""
You are a keyword re-writer that converts an input question to a better version that is optimized for search.
Generate search keyword from a user query 
to be more specific, detailed, and likely to retrieve relevant information, allowing for a more accurate response through web search.
Don't include the additional context from the user question.

Query: {query}
Revised web search query:
"""

# Send the query to the LLM to rewrite it into search keywords.
response = await autogen_aoai_client.create(
    messages=[
        UserMessage(content=KEYWORD_REWRITE_PROMPT.format(query=query), source="user"),
    ]
)


# Here is the response from the chat model.
print(response.content)

### Web Search Tool

The web search tool is used to enhance the context. <br>

It is used when none of the retrieved documents meets the relevance threshold or when the evaluator is not confident.
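
That trigger condition can be written as a small predicate. This is a sketch under assumptions: `needs_web_search` is a hypothetical helper, a score of `None` stands for "the evaluator returned no usable score", and the default threshold of 3.0 mirrors the value used by the agents later in this notebook.

```python
from typing import Optional, Sequence


def needs_web_search(scores: Sequence[Optional[float]], threshold: float = 3.0) -> bool:
    """Return True when web search should augment the retrieved context."""
    if not scores:
        return True  # nothing was retrieved
    if any(score is None for score in scores):
        return True  # the evaluator was not confident for at least one document
    # Fall back to the web only when every document is below the threshold.
    return all(score < threshold for score in scores)


print(needs_web_search([1.0, 2.0]))
print(needs_web_search([1.0, 4.0]))
```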

In [None]:
from azure_genai_utils.tools import BingSearch

WEB_SEARCH_FORMAT_OUTPUT = False

web_search_tool = BingSearch(
    max_results=3,
    locale="en-US",
    include_news=False,
    include_entity=False,
    format_output=WEB_SEARCH_FORMAT_OUTPUT,
)

In [None]:
query = "Newest Openings Hotels in NYC 2024 2025?"
results = web_search_tool.invoke({"query": query})
print(results[0].get("content", "No content"))

<br>

## 🧪 Step 2. Define the Agentic Architecture
- Before building the agentic pipeline, we need to design the messages, topics, agents, and message routing logic.
- You should also define the termination condition for the pipeline.
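
To make the termination condition concrete, the sketch below simulates topic routing in plain Python. It is not the AutoGen runtime: `run_pipeline`, the handler signature, and the hop budget are hypothetical. The run ends when a handler publishes nothing further, and the hop budget guards against accidental routing loops.

```python
from collections import deque
from typing import Callable, Deque, Dict, List, Optional, Tuple

# A handler receives a payload and returns (next_topic, payload), or None to stop.
Handler = Callable[[str], Optional[Tuple[str, str]]]


def run_pipeline(
    handlers: Dict[str, Handler],
    start_topic: str,
    payload: str,
    max_hops: int = 10,
) -> List[str]:
    """Deliver messages until the queue is empty; return the topics visited in order."""
    queue: Deque[Tuple[str, str]] = deque([(start_topic, payload)])
    visited: List[str] = []
    while queue:
        if len(visited) >= max_hops:
            raise RuntimeError("hop budget exceeded; possible routing loop")
        topic, body = queue.popleft()
        visited.append(topic)
        outcome = handlers[topic](body)
        if outcome is not None:
            queue.append(outcome)
    return visited


handlers: Dict[str, Handler] = {
    "grade": lambda body: ("generate", body),
    "generate": lambda body: ("user", body.upper()),
    "user": lambda body: None,  # terminal handler: publishes nothing
}
trace = run_pipeline(handlers, "grade", "hello")
print(trace)
```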

### Message, Topic, Agent Definition

```python
# Message definition
@dataclass
class Message:
    query: str = None
    context: str = None
    response: str = None
    source: str = None


# Topic definition
user_query_topic_type = "UserQueryTopic"
keyword_rewrite_topic_type = "KeywordRewriteAgent"
generate_topic_type = "GenerateAgent"
web_search_topic_type = "WebSearchAgent"
user_topic_type = "UserAgent"

# Agent definition (the bodies are implemented in the cells that follow)
class RAGGraderAgent(RoutedAgent): ...
class KeywordRewriteAgent(RoutedAgent): ...
class GenerateAgent(RoutedAgent): ...
class WebSearchAgent(RoutedAgent): ...
class UserAgent(RoutedAgent): ...
```

Visualizing the abstract architecture of the pipeline will help you understand the message flow and the agent's role in the pipeline.

In [None]:
from azure_genai_utils.graphs import visualize_agents

agents = [
    "Start",
    "RetrievalGraderAgent",
    "KeywordRewriteAgent",
    "GenerateAgent",
    "WebSearchAgent",
    "UserAgent",
]
interactions = [
    ("Start", "RetrievalGraderAgent"),
    ("RetrievalGraderAgent", "GenerateAgent", "Generates Response"),
    ("RetrievalGraderAgent", "KeywordRewriteAgent", "Rewrites as keyword for bing search"),
    ("KeywordRewriteAgent", "WebSearchAgent"),
    ("WebSearchAgent", "GenerateAgent"),
    ("GenerateAgent", "UserAgent"),
]

visualize_agents(agents, interactions)

This is an example of the visualized pipeline:

!["corrective-RAG"](../../images/corrective-RAG.png)
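
The same flow can be checked in plain Python by listing every route from `Start` to `UserAgent`. The edge list copies the interactions used for the visualization; `all_paths` is a hypothetical helper written for this sketch.

```python
from typing import Dict, List, Tuple

edges: Dict[str, List[str]] = {
    "Start": ["RetrievalGraderAgent"],
    "RetrievalGraderAgent": ["GenerateAgent", "KeywordRewriteAgent"],
    "KeywordRewriteAgent": ["WebSearchAgent"],
    "WebSearchAgent": ["GenerateAgent"],
    "GenerateAgent": ["UserAgent"],
    "UserAgent": [],
}


def all_paths(
    graph: Dict[str, List[str]],
    node: str,
    goal: str,
    prefix: Tuple[str, ...] = (),
) -> List[List[str]]:
    """Depth-first enumeration of simple paths from node to goal."""
    path = prefix + (node,)
    if node == goal:
        return [list(path)]
    found: List[List[str]] = []
    for nxt in graph.get(node, []):
        if nxt not in path:  # skip nodes already on the path to avoid cycles
            found.extend(all_paths(graph, nxt, goal, path))
    return found


paths = all_paths(edges, "Start", "UserAgent")
for p in paths:
    print(" -> ".join(p))
```

The two routes correspond to the direct generation branch and the corrective branch through the keyword rewrite and web search.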

In [None]:
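from typing import Optional

@dataclass
class Message:
    # Every field is optional because each agent fills in only what it knows.
    query: Optional[str] = None
    context: Optional[str] = None
    response: Optional[str] = None
    source: Optional[str] = None

    def set_source(self, source: str) -> "Message":
        self.source = source
        return self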

# Topic Definition
user_query_topic_type = "UserQueryTopic"
keyword_rewrite_topic_type = "KeywordRewriteAgent"
generate_topic_type = "GenerateAgent"
web_search_topic_type = "WebSearchAgent"
user_topic_type = "UserAgent"

In [None]:

@type_subscription(topic_type=user_query_topic_type)
class RAGGraderAgent(RoutedAgent):

    def __init__(
            self, 
            azure_ai_search_endpoint:str, 
            azure_search_admin_key:str,
            index_name: str,
            retrieval_evaluator: RetrievalEvaluator,
            ) -> None:
        
        super().__init__("RAG Grader Agent")
        self.index_name = index_name
        self.azure_ai_search_endpoint = azure_ai_search_endpoint
        self.azure_search_admin_key = azure_search_admin_key
        self.retrieval_evaluator = retrieval_evaluator

    def config_search(self) -> SearchClient:
        service_endpoint = self.azure_ai_search_endpoint
        key = self.azure_search_admin_key
        index_name = self.index_name
        credential = AzureKeyCredential(key)
        return SearchClient(endpoint=service_endpoint, index_name=index_name, credential=credential)

    async def do_search(self, query: str) -> str:
        """Search the index on Azure AI Search with a hybrid (keyword + vector) query."""
        aia_search_client = self.config_search()

        # Name of the vector field in the index; adjust it to match your index schema.
        fields = "descriptionVector"
        # exhaustive=True forces a full (non-approximate) k-NN scan; avoid it on large indexes.
        vector_query = VectorizableTextQuery(text=query, k_nearest_neighbors=1, fields=fields, exhaustive=True)

        search_results = aia_search_client.search(
            search_text=query,
            vector_queries=[vector_query],
            select=["Description", "HotelName", "Tags"],  # adjust to the fields in your index
            top=3,  # number of results passed on as context
        )
        answer = "\n".join([f'{document["HotelName"]}:{document["Description"]}:{document["Tags"]}' for document in search_results])  
        return answer
    
    @message_handler
    async def handle_message(self, message: Message, ctx: MessageContext) -> None:
        print(f"\n{'-'*80}\n{self.id.type} received a message:\n")
        context_from_ai_search = await self.do_search(message.query)
        print(context_from_ai_search)

        # Grade the retrieved context against the incoming query, not a notebook-level variable.
        retrieval_score = self.retrieval_evaluator(
            query=message.query,
            context=context_from_ai_search,
        )

        print(f"retrieval_score: {retrieval_score['retrieval']}")

        # Scores of 3.0 or higher go straight to generation; lower scores trigger the keyword rewrite and web search.
        next_topic = generate_topic_type if retrieval_score["retrieval"] >= 3.0 else keyword_rewrite_topic_type
        await self.publish_message(
            Message(query=message.query, context=context_from_ai_search, source=message.source),
            topic_id=TopicId(type=next_topic, source=message.source),
        )

In [None]:
KEYWORD_REWRITE_PROMPT="""
You are a keyword re-writer that converts an input question to a better version that is optimized for search.
Generate search keyword from a user query 
to be more specific, detailed, and likely to retrieve relevant information, allowing for a more accurate response through web search.
Don't include the additional context from the user question.

Query: {query}
Revised web search query:
"""

@type_subscription(topic_type=keyword_rewrite_topic_type)
class KeywordRewriteAgent(RoutedAgent):
    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("Query Rewrite Agent")
        self._system_message = SystemMessage(
            content=(
                """
                    You are a helper agent that can rewrite the query.
                """
            )
        )
        self._model_client = model_client

    @message_handler
    async def handle_message(self, message: Message, ctx: MessageContext) -> None:
        print(f"\n{'-'*80}\n{self.id.type} received a message:\n")
        llm_result = await self._model_client.create(
            messages=[self._system_message, 
                        UserMessage(content=KEYWORD_REWRITE_PROMPT.format(query=message.query), source=message.source),
                      ],
            cancellation_token=ctx.cancellation_token,
        )
        response = llm_result.content
        assert isinstance(response, str)
        print(f"{'-'*80}\n{self.id.type}:\n{response}")
        
        await self.publish_message(Message(query=response, context=message.context, source=message.source), topic_id=TopicId(type=web_search_topic_type, source=message.source))

In [None]:
INCORRECT_ANSWER="""
Hello, and thank you for bringing this to our attention! I may have provided an inaccurate or misleading response, and I sincerely apologize for the confusion.
As an AI, I aim to deliver helpful and accurate information, but sometimes I might misinterpret or generate an incorrect response. Your feedback is invaluable and helps me improve.

If you'd like, feel free to share more details or clarify your question, and I’ll do my best to assist you further. Thank you for your understanding and patience! 😊
"""

@type_subscription(topic_type=web_search_topic_type)
class WebSearchAgent(RoutedAgent):

    def __init__(
            self, 
            web_search_tool: BingSearch,
            retrieval_evaluator: RetrievalEvaluator,
            ) -> None:
        
        super().__init__("WebSearch Agent")
        self.web_search_tool = web_search_tool
        self.retrieval_evaluator = retrieval_evaluator

    @message_handler
    async def handle_message(self, message: Message, ctx: MessageContext) -> None:
        print(f"\n{'-'*80}\n{self.id.type} received a message:\n")
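        # Use the tool injected into this agent and the rewritten query carried by the message.
        search_results = self.web_search_tool.invoke({"query": message.query})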
        print(search_results)
        try:
            contents = []
            items = list(search_results)
            for i in range(min(3, len(items))):
                doc = items[i]
                contents.append(doc.get("content", "No content"))
            content = "\n".join(contents)
        except Exception as e:
            print(f"Error: {e}")
            content = "No content"
        
        # Grade the web results against the rewritten query.
        retrieval_score = self.retrieval_evaluator(
            query=message.query,
            context=content,
        )
        print(f"retrieval_score: {retrieval_score['retrieval']}")
        if retrieval_score["retrieval"] < 3.0:
            # The web results are not relevant either: send the fallback answer and stop here.
            await self.publish_message(AssistantMessage(content=INCORRECT_ANSWER, source=message.source), topic_id=TopicId(type=user_topic_type, source=message.source))
            return

        await self.publish_message(Message(query=message.query, context=content, source=message.source), topic_id=TopicId(type=generate_topic_type, source=message.source))

In [None]:
# This prompt provides instructions to the model
GROUNDED_PROMPT="""
Answer the query using only the context provided below in a friendly and concise bulleted manner.
Answer ONLY with the facts listed in the list of context below.
If there isn't enough information below, say you don't know.
Do not generate answers that don't use the context below.
Query: {query}
Context:\n{context}
"""

@type_subscription(topic_type=generate_topic_type)
class GenerateAgent(RoutedAgent):
    def __init__(self, model_client: ChatCompletionClient) -> None:
        super().__init__("Generate Agent")
        self._system_message = SystemMessage(
            content=(
                """
                    You are a friendly assistant that recommends hotels based on activities and amenities.
                """
            )
        )
        self._model_client = model_client

    @message_handler
    async def handle_message(self, message: Message, ctx: MessageContext) -> None:
        print(f"\n{'-'*80}\n{self.id.type} received a message:\n")
        llm_result = await self._model_client.create(
            messages=[self._system_message, 
                        UserMessage(content=GROUNDED_PROMPT.format(query=message.query, context=message.context), source=message.source),
                      ],
            extra_create_args={"response_format": RecommendationList},
            cancellation_token=ctx.cancellation_token,
        )

        response_content = llm_result.content
        print(response_content)
        await self.publish_message(AssistantMessage(content=response_content, source=message.source), topic_id=TopicId(type=user_topic_type, source=message.source))

In [None]:
@type_subscription(topic_type=user_topic_type)
class UserAgent(RoutedAgent):
    def __init__(self) -> None:
        super().__init__("A user agent that outputs the final copy to the user.")

    @message_handler
    async def handle_final_copy(self, message: AssistantMessage, ctx: MessageContext) -> None:
        print(f"\n{'-'*80}\n{self.id.type} received final copy:\n")
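        assert isinstance(message.content, str)
        try:
            response_content = json.loads(message.content)
        except json.JSONDecodeError:
            # Plain-text replies (for example, the fallback answer) are printed as-is.
            print(message.content)
            return
        for recommendation in response_content['recommendation']:
            print(recommendation)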

<br>

## 🧪 Step 3. Execute the Workflow
---

### Execute the workflow

In [None]:
runtime = SingleThreadedAgentRuntime()

await RAGGraderAgent.register(runtime, type=user_query_topic_type, factory=lambda: RAGGraderAgent(
    azure_ai_search_endpoint=azure_ai_search_endpoint,
    azure_search_admin_key=azure_search_admin_key,
    index_name=index_name,
    retrieval_evaluator=RetrievalEvaluator(model_config),
    ))

await KeywordRewriteAgent.register(runtime, type=keyword_rewrite_topic_type, factory=lambda: KeywordRewriteAgent(model_client=autogen_aoai_client))

await GenerateAgent.register(runtime, type=generate_topic_type, factory=lambda: GenerateAgent(model_client=autogen_aoai_client))

WEB_SEARCH_FORMAT_OUTPUT = False

await WebSearchAgent.register(runtime, type=web_search_topic_type, factory=lambda: WebSearchAgent(
    web_search_tool=BingSearch(
        max_results=3,
        locale="en-US",
        include_news=False,
        include_entity=False,
        format_output=WEB_SEARCH_FORMAT_OUTPUT,
    ),
    retrieval_evaluator=RetrievalEvaluator(model_config),
    ))


await UserAgent.register(runtime, type=user_topic_type, factory=lambda: UserAgent())



In [None]:
import time

start_time = time.perf_counter()

runtime.start()

await runtime.publish_message(Message(query="Can you recommend the newest Openings Hotels in Manhattan Midtown 2024?", source="user"), topic_id=TopicId(type=user_query_topic_type, source="user"))

await runtime.stop_when_idle()


end_time = time.perf_counter()
print(f"Elapsed time: {end_time - start_time} seconds")