<a href="https://colab.research.google.com/github/MayssenBHA/Corrective-RAG/blob/main/corrective_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Corrective RAG Implementation

This notebook implements a corrective RAG (Retrieval-Augmented Generation) system that:
1. Retrieves candidate document chunks from a local vector database
2. Grades each chunk's relevance to the question using an LLM
3. If too few chunks are relevant (fewer than two), rewrites the query and searches the web
4. Generates the final answer from the relevant chunks and any web results

**Tech Stack:**
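- **LLM:** Mistral API (free tier available)
- **Embeddings:** HuggingFace sentence-transformers (`all-MiniLM-L6-v2`, runs locally)
- **Vector DB:** ChromaDB (local, no external services needed)
- **Web Search:** RapidAPI (Google Search endpoint)
- **Orchestration:** LangGraph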

## 1. Install Required Dependencies

In [None]:
!pip install langchain langchain-community langchain-mistralai chromadb pypdf2 beautifulsoup4 langgraph pydantic typing-extensions nest-asyncio tenacity pypdf

Collecting langchain-community
  Using cached langchain_community-0.3.29-py3-none-any.whl.metadata (2.9 kB)
Collecting langchain-mistralai
  Using cached langchain_mistralai-0.2.11-py3-none-any.whl.metadata (2.0 kB)
Collecting chromadb
  Using cached chromadb-1.0.20-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Collecting pypdf2
  Using cached pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Collecting langgraph
  Using cached langgraph-0.6.6-py3-none-any.whl.metadata (6.8 kB)
Collecting pypdf
  Using cached pypdf-6.0.0-py3-none-any.whl.metadata (7.1 kB)
Collecting langchain-core<1.0.0,>=0.3.72 (from langchain)
  Using cached langchain_core-0.3.75-py3-none-any.whl.metadata (5.7 kB)
Collecting requests<3,>=2 (from langchain)
  Using cached requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting dataclasses-json<0.7,>=0.6.7 (from langchain-community)
  Using cached dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)

## 2. Import Required Libraries

In [None]:
import os
import json
import tempfile
import pprint
from typing import Dict, TypedDict, List
import nest_asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
import requests  # Added for RapidAPI

# LangChain imports
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader, WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langgraph.graph import END, StateGraph

# Enable nested asyncio for Jupyter
nest_asyncio.apply()

print("✅ All libraries imported successfully!")



✅ All libraries imported successfully!


## 3. Configuration and API Keys

Set up your API keys and document path here. Get the keys from:
- **Mistral API:** https://console.mistral.ai/
- **RapidAPI:** https://rapidapi.com/

In [None]:
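# API keys: taken from environment variables if set; otherwise replace the placeholders below.
# The placeholders must match the strings checked in the validation below.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "your_mistral_api_key_here")
RAPIDAPI_KEY = os.environ.get("RAPIDAPI_KEY", "your_rapidapi_key_here")

# Document to load: update this with the path of the file you uploaded to Colab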
DOC_PATH = "/content/2411.15146v1.pdf"

# Validate API keys
if MISTRAL_API_KEY == "your_mistral_api_key_here":
    print("⚠️ Please set your Mistral API key")
else:
    print("✅ Mistral API key configured")

if RAPIDAPI_KEY == "your_rapidapi_key_here":
    print("⚠️ Please set your RapidAPI key")
else:
    print("✅ RapidAPI key configured")

✅ Mistral API key configured
✅ RapidAPI key configured


## 4. Initialize Models and Vector Store

In [None]:
# Initialize Mistral LLM
llm = ChatMistralAI(
    model="mistral-small-latest",
    mistral_api_key=MISTRAL_API_KEY,
    temperature=0,
    max_tokens=1000
)

# Initialize embeddings (using free HuggingFace embeddings)
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

print("✅ Models initialized successfully!")

✅ Models initialized successfully!


## 5. Document Loading and Processing

In [None]:
def load_documents(file_path: str) -> List[Document]:
    """Load documents from local file path."""
    try:
        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            return []

        file_extension = os.path.splitext(file_path)[1].lower()

        if file_extension == '.pdf':
            loader = PyPDFLoader(file_path)
        elif file_extension in ['.txt', '.md']:
            loader = TextLoader(file_path, encoding='utf-8')
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

        return loader.load()
    except Exception as e:
        print(f"Error loading document: {str(e)}")
        return []

# Load and process documents
print("📄 Loading documents...")
print(f"📂 Document path: {DOC_PATH}")

# Check if file exists
if not os.path.exists(DOC_PATH):
    print(f"⚠️ Document not found at {DOC_PATH}")
    print("Please upload your document to Colab and update the DOC_PATH variable")
    docs = []
else:
    docs = load_documents(DOC_PATH)

if docs:
    # Split documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100
    )
    all_splits = text_splitter.split_documents(docs)

    # Create ChromaDB vectorstore
    vectorstore = Chroma.from_documents(
        documents=all_splits,
        embedding=embeddings,
        persist_directory="./chroma_db"  # Local storage
    )

    # Create retriever
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

    print(f"✅ Successfully processed {len(all_splits)} document chunks")
else:
    print("❌ No documents loaded")
    retriever = None

📄 Loading documents...
📂 Document path: /content/2411.15146v1.pdf
✅ Successfully processed 123 document chunks
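`RecursiveCharacterTextSplitter` first tries to break on paragraph breaks, then newlines, then spaces, and only then on single characters, so real chunk boundaries follow the text's structure. The effect of `chunk_size=500` and `chunk_overlap=100` on its own is easier to see with a simplified fixed-window splitter. This is a sketch for intuition, not LangChain's algorithm: `fixed_window_chunks` is a hypothetical helper that ignores separators and advances 400 characters per window.

```python
def fixed_window_chunks(text: str, chunk_size: int = 500, chunk_overlap: int = 100) -> list[str]:
    """Split text into windows of at most chunk_size characters.

    Consecutive windows share chunk_overlap characters. Simplified: no separator
    handling, unlike RecursiveCharacterTextSplitter.
    """
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):  # this window already reaches the end
            break
    return chunks


sample = "".join(str(i % 10) for i in range(1200))  # 1,200 characters of digits
chunks = fixed_window_chunks(sample)
print([len(c) for c in chunks])  # [500, 500, 400]
```

Neighbouring windows share 100 characters, which is how a sentence cut at one boundary can still appear intact in an adjacent chunk.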


## 6. Define Graph State

In [None]:
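from typing import Any, Dict, TypedDict

class GraphState(TypedDict):
    """State of the corrective RAG graph: a single "keys" dict passed between nodes."""
    keys: Dict[str, Any]  # typing.Any, not the builtin any()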

print("✅ Graph state defined")

✅ Graph state defined
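Because `GraphState` has a single `keys` field and no reducer is declared for it, each node's returned `{"keys": {...}}` replaces the previous dictionary instead of merging into it (our reading of LangGraph's default behaviour for state keys without a reducer). That is why every node copies `question` and `documents` forward. The plain-Python sketch below mimics that overwrite rule with hypothetical node stand-ins; it does not use LangGraph itself.

```python
from typing import Any, Dict

State = Dict[str, Dict[str, Any]]


def apply_update(state: State, update: State) -> State:
    """Mimic a state key without a reducer: the returned value overwrites the old one."""
    new_state = dict(state)
    new_state.update(update)  # the whole "keys" dict is replaced, not merged
    return new_state


def good_node(state: State) -> State:
    """Hypothetical node that carries the question forward."""
    keys = state["keys"]
    return {"keys": {"question": keys["question"], "documents": ["doc-1"]}}


def forgetful_node(state: State) -> State:
    """Hypothetical node that forgets to copy the question."""
    return {"keys": {"documents": ["doc-1"]}}


start: State = {"keys": {"question": "What is TIMBRE?"}}
after_good = apply_update(start, good_node(start))
after_forgetful = apply_update(start, forgetful_node(start))
print(after_good["keys"])       # question survives
print(after_forgetful["keys"])  # question is gone
```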


## 7. Define RAG Workflow Nodes

### Node 1: Retrieve Documents

In [None]:
def retrieve(state):
    """Retrieve documents based on user question."""
    print("🔍 STEP: Retrieving documents...")
    state_dict = state["keys"]
    question = state_dict["question"]

    if retriever is None:
        print("❌ No retriever available")
        return {"keys": {"documents": [], "question": question}}

    # .invoke() is the current retriever API; get_relevant_documents() is deprecated
    documents = retriever.invoke(question)
    print(f"📋 Retrieved {len(documents)} documents")

    return {"keys": {"documents": documents, "question": question}}
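The retriever call in this node embeds the question and returns the `k=4` nearest chunks from Chroma. Conceptually that is a nearest-neighbour search; the sketch below does it by brute force with cosine similarity over toy 3-dimensional vectors. Chroma uses an approximate index and its own configured distance metric, so treat this as an illustration of the idea rather than of Chroma's internals. The vectors and chunk names are made up.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: list[float], index: dict[str, list[float]], k: int = 4) -> list[str]:
    """Return the ids of the k stored vectors most similar to the query."""
    ranked = sorted(index, key=lambda doc_id: cosine(query_vec, index[doc_id]), reverse=True)
    return ranked[:k]


# Toy "embeddings" for four chunks (hypothetical values)
index = {
    "intro": [1.0, 0.0, 0.0],
    "method": [0.7, 0.7, 0.0],
    "results": [0.0, 1.0, 0.0],
    "appendix": [0.0, 0.0, 1.0],
}
print(top_k([0.1, 0.9, 0.0], index, k=2))  # ['results', 'method']
```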

### Node 2: Grade Document Relevance

In [None]:
def grade_documents(state):
    """Grade whether retrieved documents are relevant to the question and decide if web search is needed."""
    print("⚖️ STEP: Grading document relevance...")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Grading prompt
    prompt = PromptTemplate(
        template="""You are grading the relevance of a retrieved document to a user question.
Return ONLY a JSON object with a "score" field that is either "yes" or "no".
Do not include any other text or explanation.

Document: {context}
Question: {question}

Rules:
- Check for related keywords or semantic meaning
- Use lenient grading to only filter clear mismatches
- Return exactly like this example: {{"score": "yes"}} or {{"score": "no"}}""",
        input_variables=["context", "question"]
    )

    chain = prompt | llm | StrOutputParser()

    filtered_docs = []
    relevant_count = 0
    relevant_threshold = 2 # Require at least 2 relevant documents to skip web search

    import re  # used to pull the JSON object out of each grader reply

    for doc in documents:
        try:
            response = chain.invoke({"question": question, "context": doc.page_content})

            # Extract the JSON object, even if the model wrapped it in extra text or newlines
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                response = json_match.group()

            score = json.loads(response)

            if score.get("score") == "yes":
                print("✅ Document relevant")
                filtered_docs.append(doc)
                relevant_count += 1
            else:
                print("❌ Document not relevant")

        except Exception as e:
            print(f"⚠️ Error grading document: {str(e)}")
            # On error, keep the document and treat as relevant for safety
            filtered_docs.append(doc)
            relevant_count += 1 # Assume relevant on error to be safe
            continue

    # Decide if web search is needed based on the relevant document count
    search_needed = "Yes" if relevant_count < relevant_threshold else "No"
    print(f"📊 Relevant documents: {relevant_count}/{len(documents)}")
    print(f"🤔 Relevant threshold for skipping web search: {relevant_threshold}")


    return {"keys": {"documents": filtered_docs, "question": question, "run_web_search": search_needed}}
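The grading node does two separable things: it pulls a `{"score": ...}` object out of the model's reply, and it compares the number of relevant grades with `relevant_threshold`. Moving both into pure functions makes them testable without calling Mistral. The names `parse_grade` and `needs_web_search` are hypothetical. The error policy mirrors the notebook's (an unparseable reply counts as relevant), while the score check here is case-insensitive, slightly more lenient than the notebook's exact `== "yes"`.

```python
import json
import re
from typing import Optional


def parse_grade(reply: str) -> Optional[bool]:
    """Return True/False for a {"score": "yes"|"no"} reply, or None if it cannot be parsed."""
    match = re.search(r"\{.*\}", reply, re.DOTALL)  # tolerate text around the JSON
    try:
        data = json.loads(match.group() if match else reply)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    return str(data.get("score", "")).strip().lower() == "yes"


def needs_web_search(grades: list[Optional[bool]], threshold: int = 2) -> bool:
    """Same rule as the notebook: unparseable grades (None) are counted as relevant."""
    relevant = sum(1 for g in grades if g is None or g)
    return relevant < threshold


replies = ['{"score": "yes"}', 'Sure! {"score": "no"}', "not json at all", '{"score": "YES"}']
grades = [parse_grade(r) for r in replies]
print(grades)                    # [True, False, None, True]
print(needs_web_search(grades))  # False: 3 counted as relevant, threshold is 2
```

Counting failures as relevant biases the graph toward skipping web search; flipping that choice is a one-line change in `needs_web_search`.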

### Node 3: Transform Query (for better web search)

In [None]:
def transform_query(state):
    """Transform the query to produce a better question for web search."""
    print("🔄 STEP: Transforming query for web search...")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Query transformation prompt
    prompt = PromptTemplate(
        template="""Generate a search-optimized version of this question by
analyzing its core semantic meaning and intent.
\n ------- \n
{question}
\n ------- \n
Return only the improved question with no additional text:""",
        input_variables=["question"],
    )

    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})

    print(f"📝 Original: {question}")
    print(f"🎯 Improved: {better_question}")

    return {"keys": {"documents": documents, "question": better_question}}
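The rewritten question goes straight into the search API, so it can help to normalise the model's reply first: chat models sometimes wrap the answer in quotes, prefix a label such as `Improved question:`, or add a line of commentary. `clean_rewritten_query` below is a hypothetical addition, not part of the notebook; it strips those wrappers and falls back to the original question when nothing usable is left.

```python
import re


def clean_rewritten_query(reply: str, original: str) -> str:
    """Normalise an LLM-rewritten search query, falling back to the original."""
    text = reply.strip()
    # Drop a leading label such as "Improved question:" (case-insensitive)
    text = re.sub(r"^(?:improved|rewritten|optimi[sz]ed)\s+(?:question|query)\s*:\s*",
                  "", text, flags=re.IGNORECASE)
    # Keep only the first non-empty line in case the model added commentary
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    text = lines[0] if lines else ""
    # Remove surrounding straight or curly quotes
    text = text.strip("\"'“” ")
    return text or original


print(clean_rewritten_query('Improved question: "What is the TIMBRE temporal graph recommender?"',
                            "what is TIMBRE?"))
```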

### Node 4: Web Search

In [None]:
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def execute_rapidapi_search(query, limit=3):
    """Query the RapidAPI Google Search endpoint; @retry re-runs it up to 3 times on errors."""
    url = "https://google-search74.p.rapidapi.com/"
    headers = {
        "x-rapidapi-host": "google-search74.p.rapidapi.com",
        "x-rapidapi-key": RAPIDAPI_KEY
    }
    params = {
        "query": query,
        "limit": limit,
        "related_keywords": "true"
    }
    # A timeout stops a stalled request from hanging the whole graph
    response = requests.get(url, headers=headers, params=params, timeout=15)
    response.raise_for_status()
    return response.json()

def web_search(state):
    """Web search based on the transformed question using RapidAPI Google Search."""
    print("🌐 STEP: Performing web search via RapidAPI...")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    try:
        # Validate RapidAPI key
        if not RAPIDAPI_KEY or RAPIDAPI_KEY == "your_rapidapi_key_here":
            print("⚠️ RapidAPI key not provided - skipping web search")
            return {"keys": {"documents": documents, "question": question}}

        # Execute search with retry logic
        search_results = execute_rapidapi_search(question, limit=3)

        if not search_results or "results" not in search_results:
            print("⚠️ No search results found")
            return {"keys": {"documents": documents, "question": question}}

        # Process results
        web_results = []
        for result in search_results["results"]:
            content = (
                f"Title: {result.get('title', 'No title')}\n"
                f"Content: {result.get('description', 'No content')}\n"
                f"Link: {result.get('link', '')}\n"
            )
            web_results.append(content)

        # Create document from results
        web_document = Document(
            page_content="\n\n".join(web_results),
            metadata={
                "source": "rapidapi_google_search",
                "query": question,
                "result_count": len(web_results)
            }
        )
        documents.append(web_document)

        print(f"✅ Added {len(web_results)} web search results via RapidAPI")

    except Exception as error:
        print(f"❌ Web search error: {str(error)}")

    return {"keys": {"documents": documents, "question": question}}
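The web-search node flattens the API's `results` list into a single `Document`. The sketch below isolates that formatting step so it can be checked without a network call. The payload shape (`results` entries with `title`, `description` and `link`) is the one the notebook's code assumes, not a documented schema, and the sample values are invented. Unlike the notebook, `results_to_text` also slices to `limit` defensively.

```python
from typing import Any


def results_to_text(payload: dict[str, Any], limit: int = 3) -> str:
    """Turn a search response into the Title/Content/Link text used by web_search."""
    blocks = []
    for result in payload.get("results", [])[:limit]:
        blocks.append(
            f"Title: {result.get('title', 'No title')}\n"
            f"Content: {result.get('description', 'No content')}\n"
            f"Link: {result.get('link', '')}\n"
        )
    return "\n\n".join(blocks)


sample_payload = {
    "results": [
        {"title": "Example A", "description": "First snippet", "link": "https://example.com/a"},
        {"title": "Example B", "link": "https://example.com/b"},  # no description
    ]
}
print(results_to_text(sample_payload))
```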

### Node 5: Generate Answer

In [None]:
def generate(state):
    """Generate answer using Mistral model."""
    print("✨ STEP: Generating final answer...")
    state_dict = state["keys"]
    question, documents = state_dict["question"], state_dict["documents"]

    try:
        # Create prompt template
        prompt = PromptTemplate(
            template="""Based on the following context, please answer the question.
Context: {context}
Question: {question}
Answer:""",
            input_variables=["context", "question"]
        )

        # Combine all document content into one context string
        context = "\n\n".join(doc.page_content for doc in documents)

        # Build the chain and fill the template with the context and question
        rag_chain = prompt | llm | StrOutputParser()
        generation = rag_chain.invoke({"context": context, "question": question})
        print("✅ Answer generated successfully")

        return {
            "keys": {
                "documents": documents,
                "question": question,
                "generation": generation
            }
        }

    except Exception as e:
        error_msg = f"Error in generate function: {str(e)}"
        print(f"❌ {error_msg}")
        return {
            "keys": {
                "documents": documents,
                "question": question,
                "generation": "Sorry, I encountered an error while generating the response."
            }
        }
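`generate` uses the simplest "stuff" strategy: every document's text is joined with blank lines and placed into one prompt. With four 500-character chunks plus a web result this stays small, but the prompt grows with every added document. The sketch below shows the same assembly with an optional character budget; `build_prompt` and `max_context_chars` are hypothetical, and characters are only a rough proxy for tokens.

```python
from typing import Optional

# Same wording as the notebook's generation template
TEMPLATE = """Based on the following context, please answer the question.
Context: {context}
Question: {question}
Answer:"""


def build_prompt(question: str, texts: list[str], max_context_chars: Optional[int] = None) -> str:
    """Join document texts into one context and fill the template."""
    context = "\n\n".join(texts)
    if max_context_chars is not None and len(context) > max_context_chars:
        context = context[:max_context_chars]  # crude cut; may split mid-sentence
    return TEMPLATE.format(context=context, question=question)


print(build_prompt("What is TIMBRE?", ["chunk one", "chunk two"]))
```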

### Node 6: Decision Logic

In [None]:
def decide_to_generate(state):
    """Decide whether to generate directly or search the web first."""
    print("🤔 STEP: Deciding next action...")
    state_dict = state["keys"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        print("➡️ Decision: Transform query and run web search")
        return "transform_query"
    else:
        print("➡️ Decision: Generate answer directly")
        return "generate"

## 8. Build the Corrective RAG Workflow Graph

In [None]:
# Create workflow graph
workflow = StateGraph(GraphState)

# Add nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)

# Build graph connections
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile the app
app = workflow.compile()

print("✅ Corrective RAG workflow graph created successfully!")

✅ Corrective RAG workflow graph created successfully!
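The compiled graph has one branch point. Stripped of LangGraph, its control flow is roughly the function below. The callables passed in are hypothetical stubs returning canned data, so the routing can be exercised without Mistral, Chroma or RapidAPI; in the notebook, `app.stream(...)` drives the real nodes. As in the notebook, the rewritten question is the one used for generation when web search runs.

```python
from typing import Any, Callable, Dict


def run_crag(question: str,
             retrieve: Callable[[str], list],
             is_relevant: Callable[[str, Any], bool],
             rewrite: Callable[[str], str],
             search: Callable[[str], list],
             answer: Callable[[str, list], str],
             threshold: int = 2) -> Dict[str, Any]:
    """Mirror retrieve -> grade_documents -> [transform_query -> web_search] -> generate."""
    path = ["retrieve", "grade_documents"]
    docs = retrieve(question)
    relevant = [d for d in docs if is_relevant(question, d)]
    if len(relevant) < threshold:  # decide_to_generate returns "transform_query"
        question = rewrite(question)
        relevant = relevant + search(question)
        path += ["transform_query", "web_search"]
    path.append("generate")
    return {"question": question, "documents": relevant,
            "generation": answer(question, relevant), "path": path}


corpus = ["TIMBRE builds a temporal graph", "Appendix table", "Results on job data"]
stub = dict(
    retrieve=lambda q: corpus,
    is_relevant=lambda q, d: "TIMBRE" in d or "Results" in d,
    rewrite=lambda q: q + " (rewritten)",
    search=lambda q: ["web snippet"],
    answer=lambda q, docs: f"{len(docs)} docs used",
)
print(run_crag("What is TIMBRE?", **stub)["path"])
print(run_crag("Capital of France?", **{**stub, "is_relevant": lambda q, d: False})["path"])
```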


## 9. Helper Functions for Output Formatting

In [None]:
def format_document(doc: Document) -> str:
    """Format document for display."""
    return f"""
Source: {doc.metadata.get('source', 'Unknown')}
Content: {doc.page_content[:200]}...
"""

def format_state(state: dict) -> dict:
    """Format state for pretty printing."""
    formatted = {}

    for key, value in state.items():
        if key == "documents":
            formatted[key] = [format_document(doc) for doc in value]
        else:
            formatted[key] = value

    return formatted

print("✅ Helper functions defined")

✅ Helper functions defined


## 10. Test the Corrective RAG System

In [None]:
# Test question
test_question = "What are the experiment results and ablation studies in this research paper?"

print(f"🎯 Question: {test_question}")
print("="*80)

# Run the corrective RAG pipeline
inputs = {
    "keys": {
        "question": test_question,
    }
}

# Execute workflow and show step-by-step progress
final_result = None
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"\n📋 Step '{key}' completed:")
        print("-" * 40)
        # Show formatted state for this step
        formatted_state = format_state(value["keys"])
        for state_key, state_value in formatted_state.items():
            if state_key == "generation":
                print(f"\n🎯 Final Answer Preview: {state_value[:100]}...")
            elif state_key == "documents":
                print(f"\n📄 Documents: {len(state_value)} items")
            else:
                print(f"\n{state_key}: {state_value}")
        final_result = value

print("\n" + "="*80)
print("🏆 FINAL ANSWER:")
print("="*80)
if final_result and 'generation' in final_result['keys']:
    print(final_result['keys']['generation'])
else:
    print("No final generation produced.")

🎯 Question: What are the experiment results and ablation studies in this research paper?
🔍 STEP: Retrieving documents...
📋 Retrieved 4 documents

📋 Step 'retrieve' completed:
----------------------------------------

📄 Documents: 4 items

question: What are the experiment results and ablation studies in this research paper?
⚖️ STEP: Grading document relevance...
✅ Document relevant
✅ Document relevant
✅ Document relevant
✅ Document relevant
📊 Relevant documents: 4/4
🤔 Relevant threshold for skipping web search: 2
🤔 STEP: Deciding next action...
➡️ Decision: Generate answer directly

📋 Step 'grade_documents' completed:
----------------------------------------

📄 Documents: 4 items

question: What are the experiment results and ablation studies in this research paper?

run_web_search: No
✨ STEP: Generating final answer...
✅ Answer generated successfully

📋 Step 'generate' completed:
----------------------------------------

📄 Documents: 4 items

question: What are the experiment results and ablation studies in this research paper?

🎯 Final Answer Preview: The experiment results and ablation studies in this research paper focus on evaluating the performan...

🏆 FINAL ANSWER:
The experiment results and ablation studies in this research paper focus on evaluating the performance of the **TIMBRE** system (or similar model) in job recommendation, part

## 11. Interactive Question-Answer Function

In [None]:
def ask_corrective_rag(question: str, verbose: bool = True) -> str:
    """Ask a question to the corrective RAG system."""
    if not question.strip():
        return "Please provide a valid question."

    inputs = {"keys": {"question": question}}

    final_result = None
    for output in app.stream(inputs):
        for key, value in output.items():
            if verbose:
                print(f"\n📋 Step '{key}' completed")
            final_result = value

    return final_result['keys'].get('generation', 'No answer generated.')

print("✅ Interactive function ready! Use ask_corrective_rag('your question') to ask questions.")

✅ Interactive function ready! Use ask_corrective_rag('your question') to ask questions.


## 12. Example Usage

In [None]:
# Example questions to test
example_questions = [
    "What is the main topic of the paper?",
    "What is the capital of France?",
    "Can you summarize the abstract of this paper?",
    "What are the key limitations discussed?",
    "Who is the author of 'Pride and Prejudice'?",
    "What is the purpose of a vector database in RAG?"
]

print("🧪 Testing with example questions...\n")

for i, question in enumerate(example_questions, 1):
    print(f"\n{'='*60}")
    print(f"Question {i}: {question}")
    print(f"{'='*60}")

    answer = ask_corrective_rag(question, verbose=False)
    print(f"\n🎯 Answer: {answer}")
    print("\n" + "-"*60)

🧪 Testing with example questions...


Question 1: What is the main topic of the paper?
🔍 STEP: Retrieving documents...
📋 Retrieved 4 documents
⚖️ STEP: Grading document relevance...
❌ Document not relevant
❌ Document not relevant
✅ Document relevant
✅ Document relevant
📊 Relevant documents: 2/4
🤔 Relevant threshold for skipping web search: 2
🤔 STEP: Deciding next action...
➡️ Decision: Generate answer directly
✨ STEP: Generating final answer...
✅ Answer generated successfully

🎯 Answer: The main topic of the paper is **temporal recommendation systems**, specifically focusing on how to integrate and structure diverse temporal and heterogeneous data (such as user interactions, notes, and external knowledge bases) into a unified graph. The paper then leverages this graph, along with a specialized graph neural network (GNN), to generate recommendations while emphasizing temporal dynamics and relationships. The inclusion of a "reification node" (shortlist node) and the adaptation of the gra

## 13. Custom Question Input

In [None]:
# Ask your own question
your_question = input("Enter your question: ")

if your_question:
    print(f"\n🎯 Your Question: {your_question}")
    print("="*80)

    answer = ask_corrective_rag(your_question, verbose=True)

    print("\n" + "="*80)
    print("🏆 FINAL ANSWER:")
    print("="*80)
    print(answer)

Enter your question: what is TIMBRE?

🎯 Your Question: what is TIMBRE?
🔍 STEP: Retrieving documents...
📋 Retrieved 4 documents

📋 Step 'retrieve' completed
⚖️ STEP: Grading document relevance...
✅ Document relevant
❌ Document not relevant
✅ Document relevant
✅ Document relevant
📊 Relevant documents: 3/4
🤔 Relevant threshold for skipping web search: 2
🤔 STEP: Deciding next action...
➡️ Decision: Generate answer directly

📋 Step 'grade_documents' completed
✨ STEP: Generating final answer...
✅ Answer generated successfully

📋 Step 'generate' completed

🏆 FINAL ANSWER:
TIMBRE is a **temporal graph-based recommender system** designed to improve recommendation performance by integrating data from multiple sources and structuring it for better exploitation. Specifically, it is introduced to address challenges in recommendation systems that require extracting and organizing information from diverse sources (e.g., resumes, job descriptions, recruiter notes, and external knowledge bases) to enha

## Key Features of This Corrective RAG Implementation:

### 🔄 **Corrective Mechanism:**
1. **Initial Retrieval:** Fetches the top 4 chunks from the local ChromaDB vector store
2. **Relevance Grading:** The LLM grades each retrieved chunk as relevant or not
3. **Correction:** If fewer than two chunks are graded relevant, it rewrites the query and adds web search results
4. **Generation:** Answers from the relevant chunks plus any web results

### 🛠️ **Tech Stack:**
- **LLM:** Mistral API (free tier available)
- **Vector DB:** ChromaDB (local, no external services)
- **Embeddings:** HuggingFace sentence-transformers (`all-MiniLM-L6-v2`, free, runs locally)
- **Web Search:** RapidAPI
- **Orchestration:** LangGraph for workflow management

### 🎯 **Benefits:**
- **Self-Correcting:** Falls back to web search when local retrieval looks weak
- **Hybrid Approach:** Combines local documents with web search results
- **Robust:** Still produces an answer when the local document does not cover the question
- **Cost-Effective:** Uses free/affordable APIs

### 📝 **Next Steps:**
1. Set your API keys and `DOC_PATH` in the configuration cell
2. Run all cells in order
3. Test with the example questions
4. Ask your own questions using the interactive function