# 🔧 Corrective RAG (CRAG) Implementation

## Self-Correcting AI Travel Assistant with Real-Time Verification

This notebook demonstrates a **Corrective RAG system** with:
- 🔍 Intelligent Document Retrieval
- 📊 Confidence Evaluation & Quality Assessment
- 🌐 Real-Time Web Search Integration
- 🧩 Query Decomposition & Re-retrieval
- 🔄 Knowledge Base Refinement
- ⚡ Adaptive Correction Mechanisms

### Key Benefits of Corrective RAG
- **Self-Correcting**: Automatically detects and fixes poor retrievals
- **Real-Time Updates**: Verifies information through web search
- **Adaptive**: Adjusts strategy based on confidence levels
- **Reliable**: Ensures high-quality, up-to-date responses

### Use Case: Smart Travel Planning Assistant
Perfect for travel applications where information freshness and accuracy are critical!

In [1]:
# Install required packages
!pip install sentence-transformers faiss-cpu google-generativeai rank-bm25 transformers scikit-learn numpy python-dotenv requests beautifulsoup4 dateparser

Collecting dateparser
  Using cached dateparser-1.2.1-py3-none-any.whl (295 kB)
Collecting tzlocal>=0.2
  Using cached tzlocal-5.3.1-py3-none-any.whl (18 kB)
Installing collected packages: tzlocal, dateparser
Successfully installed dateparser-1.2.1 tzlocal-5.3.1
You should consider upgrading via the '/home/mohdasimkhan/.pyenv/versions/3.10.2/envs/rags/bin/python -m pip install --upgrade pip' command.[0m


In [2]:
# Import libraries
import numpy as np
import json
import re
import os
import time
import uuid
import requests
from typing import List, Dict, Tuple, Optional, Any, Union
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from enum import Enum
from datetime import datetime, timedelta
import dateparser

from sentence_transformers import SentenceTransformer
import faiss
import google.generativeai as genai
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

load_dotenv()
print("📚 Libraries imported successfully!")

  from .autonotebook import tqdm as notebook_tqdm


📚 Libraries imported successfully!


## 🏗️ Core Data Structures

Define the foundational structures for our CRAG system:

In [3]:
# Enums for system configuration
class ConfidenceLevel(Enum):
    HIGH = "high"          # Use retrieved documents directly
    MEDIUM = "medium"      # Decompose query and re-retrieve
    LOW = "low"            # Trigger web search or knowledge refinement

class CorrectionAction(Enum):
    NONE = "none"                    # No correction needed
    DECOMPOSE_QUERY = "decompose"    # Break down query
    WEB_SEARCH = "web_search"        # Search web for current info
    REFINE_KNOWLEDGE = "refine"      # Update knowledge base
    HYBRID_SEARCH = "hybrid"         # Multiple strategies

class TravelDomain(Enum):
    HOTELS = "hotels"
    FLIGHTS = "flights"
    RESTAURANTS = "restaurants"
    ATTRACTIONS = "attractions"
    TRANSPORTATION = "transportation"
    GENERAL = "general"

class RetrievalStrategy(Enum):
    SEMANTIC = "semantic"
    KEYWORD = "keyword"
    HYBRID = "hybrid"

# Core data structures
@dataclass
class TravelDocument:
    id: str
    title: str
    content: str
    domain: TravelDomain
    location: str
    last_updated: datetime
    keywords: List[str] = field(default_factory=list)
    embedding: Optional[np.ndarray] = None
    confidence_indicators: Dict[str, Any] = field(default_factory=dict)

@dataclass
class TravelQuery:
    id: str
    text: str
    domain: TravelDomain
    location: Optional[str] = None
    date_range: Optional[Tuple[datetime, datetime]] = None
    user_id: Optional[str] = None
    session_id: Optional[str] = None

@dataclass
class RetrievalResult:
    document: TravelDocument
    score: float
    rank: int
    strategy_used: RetrievalStrategy
    relevance_indicators: Dict[str, float] = field(default_factory=dict)

@dataclass
class ConfidenceAssessment:
    level: ConfidenceLevel
    score: float
    reasons: List[str]
    recommended_action: CorrectionAction
    indicators: Dict[str, float] = field(default_factory=dict)

@dataclass
class CorrectionResult:
    action_taken: CorrectionAction
    new_documents: List[RetrievalResult]
    success: bool
    processing_time: float
    details: Dict[str, Any] = field(default_factory=dict)

@dataclass
class CRAGResponse:
    query: TravelQuery
    initial_retrieval: List[RetrievalResult]
    confidence_assessment: ConfidenceAssessment
    correction_result: Optional[CorrectionResult]
    final_documents: List[RetrievalResult]
    generated_answer: str
    overall_confidence: float
    processing_pipeline: List[str]
    processing_time: float
    correction_applied: bool

print("🏗️ Data structures defined!")

🏗️ Data structures defined!


## 📚 Travel Knowledge Base

Create a comprehensive travel knowledge base:

In [4]:
# Travel knowledge base with various freshness levels
travel_knowledge_base = [
    {
        "id": "hotel_001",
        "title": "Grand Palace Hotel Bangkok - Luxury Accommodation",
        "domain": "hotels",
        "location": "Bangkok, Thailand",
        "content": "The Grand Palace Hotel Bangkok offers luxury accommodation in the heart of Bangkok. Features include 24-hour room service, fitness center, spa, outdoor pool, and multiple dining options. Located near major attractions like the Grand Palace and Wat Pho temple. Rooms range from deluxe to presidential suites with city or river views. Average rate: $150-400 per night. Check-in: 3 PM, Check-out: 12 PM.",
        "last_updated": datetime(2024, 1, 15),
        "keywords": ["hotel", "luxury", "Bangkok", "palace", "spa", "pool"],
        "confidence_indicators": {"price_freshness": 0.6, "availability_freshness": 0.3}
    },
    {
        "id": "flight_001",
        "title": "Bangkok to Tokyo Flight Routes - Airlines and Schedules",
        "domain": "flights",
        "location": "Bangkok to Tokyo",
        "content": "Multiple airlines operate between Bangkok (BKK) and Tokyo (NRT/HND). Thai Airways, ANA, JAL, and budget carriers like Scoot offer direct flights. Flight duration: 6-7 hours. Typical prices: $300-800 economy, $1200-2500 business class. Peak seasons: March-May, July-August, December. Book 2-3 months in advance for better rates. Check visa requirements for Thailand and Japan.",
        "last_updated": datetime(2023, 12, 10),
        "keywords": ["flight", "Bangkok", "Tokyo", "airlines", "schedule", "price"],
        "confidence_indicators": {"price_freshness": 0.4, "schedule_freshness": 0.5}
    },
    {
        "id": "restaurant_001",
        "title": "Gaggan Restaurant Bangkok - Progressive Indian Cuisine",
        "domain": "restaurants",
        "location": "Bangkok, Thailand",
        "content": "Gaggan is a world-renowned progressive Indian restaurant in Bangkok, previously ranked #1 in Asia's 50 Best Restaurants. Chef Gaggan Anand creates innovative molecular gastronomy interpretations of Indian cuisine. Tasting menu: $200-300 per person. Reservations essential, book 2-3 months ahead. Open Tuesday-Sunday, closed Mondays. Located in a beautiful colonial house with garden seating.",
        "last_updated": datetime(2024, 2, 1),
        "keywords": ["restaurant", "Gaggan", "Bangkok", "Indian", "fine dining", "molecular"],
        "confidence_indicators": {"menu_freshness": 0.8, "price_freshness": 0.7}
    },
    {
        "id": "attraction_001",
        "title": "Wat Pho Temple - Temple of the Reclining Buddha",
        "domain": "attractions",
        "location": "Bangkok, Thailand",
        "content": "Wat Pho is one of Bangkok's oldest and most important temples, famous for the 46-meter long Reclining Buddha statue. The temple complex houses over 1,000 Buddha images and is considered the first public university in Thailand. Traditional Thai massage school operates here. Entry fee: 200 THB for foreigners. Open daily 8 AM - 6:30 PM. Dress code: Cover shoulders and knees. Allow 2-3 hours for visit.",
        "last_updated": datetime(2024, 3, 5),
        "keywords": ["temple", "Wat Pho", "Bangkok", "Buddha", "massage", "attraction"],
        "confidence_indicators": {"price_freshness": 0.9, "hours_freshness": 0.8}
    },
    {
        "id": "transport_001",
        "title": "Bangkok Public Transportation - BTS, MRT, and Taxis",
        "domain": "transportation",
        "location": "Bangkok, Thailand",
        "content": "Bangkok offers various transportation options: BTS Skytrain (elevated rail), MRT subway, buses, taxis, and tuk-tuks. BTS/MRT: 16-52 THB per trip, operates 6 AM - midnight. Taxis: meter starts at 35 THB, traffic can be heavy. Grab ride-hailing app widely available. Airport Rail Link connects Suvarnabhumi Airport to city center (45 THB). Consider getting a Rabbit Card for BTS/MRT convenience.",
        "last_updated": datetime(2023, 11, 20),
        "keywords": ["transportation", "BTS", "MRT", "taxi", "Bangkok", "public transport"],
        "confidence_indicators": {"price_freshness": 0.5, "schedule_freshness": 0.6}
    },
    {
        "id": "hotel_002",
        "title": "Budget Hostels in Khao San Road - Backpacker Haven",
        "domain": "hotels",
        "location": "Bangkok, Thailand",
        "content": "Khao San Road is Bangkok's famous backpacker street with numerous budget accommodations. Hostels offer dorm beds ($8-15) and private rooms ($20-40). Popular options include Mad Monkey Hostel, Lub d Bangkok Siam, and NapPark Hostel. Most include free WiFi, common areas, and tour booking services. Area is lively with street food, bars, and shops. Can be noisy - bring earplugs. Book ahead during peak season.",
        "last_updated": datetime(2024, 1, 8),
        "keywords": ["hostel", "budget", "Khao San Road", "backpacker", "Bangkok"],
        "confidence_indicators": {"price_freshness": 0.7, "availability_freshness": 0.4}
    },
    {
        "id": "food_001",
        "title": "Bangkok Street Food Guide - Must-Try Local Dishes",
        "domain": "restaurants",
        "location": "Bangkok, Thailand",
        "content": "Bangkok street food offers incredible variety and flavors. Must-try dishes: Pad Thai (60-80 THB), Som Tam (green papaya salad, 40-60 THB), Mango Sticky Rice (80-120 THB), Tom Yum soup (80-150 THB). Best street food areas: Chatuchak Weekend Market, Chinatown, Khao San Road. Food courts in malls offer air-conditioned dining with similar prices. Always choose busy stalls with high turnover for freshness.",
        "last_updated": datetime(2024, 2, 20),
        "keywords": ["street food", "Bangkok", "Pad Thai", "local cuisine", "markets"],
        "confidence_indicators": {"price_freshness": 0.8, "location_freshness": 0.9}
    },
    {
        "id": "weather_001",
        "title": "Bangkok Weather and Best Time to Visit",
        "domain": "general",
        "location": "Bangkok, Thailand",
        "content": "Bangkok has a tropical climate with three seasons: Cool (Nov-Feb): 20-30°C, dry and pleasant, peak tourist season. Hot (Mar-May): 30-35°C, very hot and humid. Rainy (Jun-Oct): 25-32°C, daily afternoon showers, fewer crowds, lower prices. Best time to visit: November to February for comfortable weather. Pack light, breathable clothing, umbrella during rainy season. Air conditioning is essential in hotels.",
        "last_updated": datetime(2023, 10, 15),
        "keywords": ["weather", "Bangkok", "climate", "seasons", "best time"],
        "confidence_indicators": {"seasonal_accuracy": 0.9, "current_conditions": 0.3}
    }
]

print(f"📚 Travel knowledge base created with {len(travel_knowledge_base)} documents")
print(f"🌍 Domains: {set(doc['domain'] for doc in travel_knowledge_base)}")
print(f"📍 Locations: {set(doc['location'] for doc in travel_knowledge_base)}")

📚 Travel knowledge base created with 8 documents
🌍 Domains: {'attractions', 'general', 'restaurants', 'flights', 'transportation', 'hotels'}
📍 Locations: {'Bangkok, Thailand', 'Bangkok to Tokyo'}


## 🧩 Base Module Classes

Define abstract base classes for CRAG architecture:

In [5]:
# Base module class for CRAG
class CRAGModule(ABC):
    def __init__(self, name: str):
        self.name = name
        self.call_count = 0
        self.success_count = 0
        self.created_at = datetime.now()
    
    @abstractmethod
    def process(self, input_data: Any) -> Any:
        pass
    
    def update_stats(self, success: bool = True):
        self.call_count += 1
        if success:
            self.success_count += 1
    
    def get_info(self):
        success_rate = (self.success_count / max(self.call_count, 1)) * 100
        return {
            'name': self.name,
            'calls': self.call_count,
            'success_rate': success_rate,
            'created': self.created_at
        }

print("🧩 Base CRAG module class defined!")

🧩 Base CRAG module class defined!


## 🔍 Query Processing Module

Intelligent travel query understanding and classification:

In [6]:
class TravelQueryModule(CRAGModule):
    def __init__(self):
        super().__init__("TravelQueryModule")
        
        # Domain detection patterns
        self.domain_patterns = {
            TravelDomain.HOTELS: ["hotel", "accommodation", "stay", "room", "booking", "lodge"],
            TravelDomain.FLIGHTS: ["flight", "airline", "airport", "fly", "ticket", "aviation"],
            TravelDomain.RESTAURANTS: ["restaurant", "food", "dining", "eat", "cuisine", "meal"],
            TravelDomain.ATTRACTIONS: ["attraction", "temple", "museum", "sightseeing", "tour", "visit"],
            TravelDomain.TRANSPORTATION: ["transport", "taxi", "bus", "train", "metro", "bts", "mrt"],
        }
        
        # Location patterns
        self.location_patterns = [
            r"in ([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)",
            r"at ([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)",
            r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*) (?:hotel|restaurant|attraction)"
        ]
    
    def process(self, query_text: str) -> TravelQuery:
        self.update_stats()
        
        query_id = str(uuid.uuid4())[:8]
        domain = self._detect_domain(query_text)
        location = self._extract_location(query_text)
        date_range = self._extract_dates(query_text)
        
        return TravelQuery(
            id=query_id,
            text=query_text,
            domain=domain,
            location=location,
            date_range=date_range
        )
    
    def _detect_domain(self, text: str) -> TravelDomain:
        text_lower = text.lower()
        domain_scores = {}
        
        for domain, keywords in self.domain_patterns.items():
            score = sum(1 for keyword in keywords if keyword in text_lower)
            domain_scores[domain] = score
        
        if max(domain_scores.values()) > 0:
            return max(domain_scores, key=domain_scores.get)
        return TravelDomain.GENERAL
    
    def _extract_location(self, text: str) -> Optional[str]:
        for pattern in self.location_patterns:
            match = re.search(pattern, text)
            if match:
                return match.group(1)
        return None
    
    def _extract_dates(self, text: str) -> Optional[Tuple[datetime, datetime]]:
        # Simple date extraction - can be enhanced
        date_patterns = [
            r"(\d{1,2}/\d{1,2}/\d{4})",
            r"(\d{4}-\d{2}-\d{2})",
            r"(next week|next month|tomorrow)"
        ]
        
        for pattern in date_patterns:
            matches = re.findall(pattern, text.lower())
            if matches:
                try:
                    parsed_date = dateparser.parse(matches[0])
                    if parsed_date:
                        end_date = parsed_date + timedelta(days=7)  # Default 7-day trip
                        return (parsed_date, end_date)
                except:
                    pass
        return None

# Test travel query module
travel_query_module = TravelQueryModule()
test_query = travel_query_module.process("Find luxury hotels in Bangkok for next week")
print(f"🔍 Test query processed:")
print(f"   Text: {test_query.text}")
print(f"   Domain: {test_query.domain.value}")
print(f"   Location: {test_query.location}")
print("✅ Travel Query Module ready!")

🔍 Test query processed:
   Text: Find luxury hotels in Bangkok for next week
   Domain: hotels
   Location: Bangkok
✅ Travel Query Module ready!


## 🔎 Retrieval Module

Document retrieval with travel-specific optimizations:

In [7]:
class TravelRetrievalModule(CRAGModule):
    def __init__(self):
        super().__init__("TravelRetrievalModule")
        
        # Initialize components
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.documents = []
        self.semantic_index = None
        self.bm25_index = None
        
        print("🔎 Travel retrieval module initialized")
    
    def index_documents(self, documents: List[Dict]):
        print(f"📚 Indexing {len(documents)} travel documents...")
        
        # Convert to TravelDocument objects
        self.documents = [
            TravelDocument(
                id=doc['id'],
                title=doc['title'],
                content=doc['content'],
                domain=TravelDomain(doc['domain']),
                location=doc['location'],
                last_updated=doc['last_updated'],
                keywords=doc.get('keywords', []),
                confidence_indicators=doc.get('confidence_indicators', {})
            )
            for doc in documents
        ]
        
        # Build semantic index
        self._build_semantic_index()
        
        # Build keyword index
        self._build_keyword_index()
        
        print("✅ Travel documents indexed successfully!")
    
    def _build_semantic_index(self):
        doc_texts = [f"{doc.title} {doc.content} {doc.location}" for doc in self.documents]
        embeddings = self.embedding_model.encode(doc_texts)
        
        # Store embeddings
        for doc, embedding in zip(self.documents, embeddings):
            doc.embedding = embedding
        
        # Create FAISS index
        dimension = embeddings.shape[1]
        self.semantic_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        self.semantic_index.add(embeddings.astype('float32'))
    
    def _build_keyword_index(self):
        doc_texts = [f"{doc.title} {doc.content} {doc.location}" for doc in self.documents]
        tokenized_docs = [text.lower().split() for text in doc_texts]
        self.bm25_index = BM25Okapi(tokenized_docs)
    
    def process(self, query: TravelQuery, strategy: RetrievalStrategy = RetrievalStrategy.HYBRID, top_k: int = 5) -> List[RetrievalResult]:
        self.update_stats()
        
        if strategy == RetrievalStrategy.SEMANTIC:
            results = self._semantic_search(query, top_k)
        elif strategy == RetrievalStrategy.KEYWORD:
            results = self._keyword_search(query, top_k)
        else:  # HYBRID
            results = self._hybrid_search(query, top_k)
        
        # Apply travel-specific filtering
        return self._apply_travel_filters(query, results)
    
    def _semantic_search(self, query: TravelQuery, top_k: int) -> List[RetrievalResult]:
        search_text = f"{query.text} {query.location or ''}"
        query_embedding = self.embedding_model.encode([search_text])
        faiss.normalize_L2(query_embedding)
        
        scores, indices = self.semantic_index.search(query_embedding.astype('float32'), top_k * 2)
        
        results = []
        for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
            if idx < len(self.documents):  # Valid index check
                relevance_indicators = self._calculate_relevance_indicators(query, self.documents[idx], float(score))
                results.append(RetrievalResult(
                    document=self.documents[idx],
                    score=float(score),
                    rank=i + 1,
                    strategy_used=RetrievalStrategy.SEMANTIC,
                    relevance_indicators=relevance_indicators
                ))
        return results[:top_k]
    
    def _keyword_search(self, query: TravelQuery, top_k: int) -> List[RetrievalResult]:
        search_text = f"{query.text} {query.location or ''}"
        query_tokens = search_text.lower().split()
        scores = self.bm25_index.get_scores(query_tokens)
        top_indices = np.argsort(scores)[::-1][:top_k]
        
        results = []
        for i, idx in enumerate(top_indices):
            relevance_indicators = self._calculate_relevance_indicators(query, self.documents[idx], float(scores[idx]))
            results.append(RetrievalResult(
                document=self.documents[idx],
                score=float(scores[idx]),
                rank=i + 1,
                strategy_used=RetrievalStrategy.KEYWORD,
                relevance_indicators=relevance_indicators
            ))
        return results
    
    def _hybrid_search(self, query: TravelQuery, top_k: int) -> List[RetrievalResult]:
        # Get results from both strategies
        semantic_results = self._semantic_search(query, top_k * 2)
        keyword_results = self._keyword_search(query, top_k * 2)
        
        # Normalize scores
        self._normalize_scores(semantic_results)
        self._normalize_scores(keyword_results)
        
        # Combine with travel-aware weights
        semantic_weight = 0.7
        keyword_weight = 0.3
        
        combined_scores = {}
        
        for result in semantic_results:
            doc_id = result.document.id
            combined_scores[doc_id] = {
                'document': result.document,
                'semantic_score': result.score,
                'keyword_score': 0.0,
                'relevance_indicators': result.relevance_indicators
            }
        
        for result in keyword_results:
            doc_id = result.document.id
            if doc_id in combined_scores:
                combined_scores[doc_id]['keyword_score'] = result.score
            else:
                combined_scores[doc_id] = {
                    'document': result.document,
                    'semantic_score': 0.0,
                    'keyword_score': result.score,
                    'relevance_indicators': result.relevance_indicators
                }
        
        # Calculate final scores
        final_results = []
        for doc_id, scores in combined_scores.items():
            final_score = (semantic_weight * scores['semantic_score'] + 
                          keyword_weight * scores['keyword_score'])
            
            final_results.append(RetrievalResult(
                document=scores['document'],
                score=final_score,
                rank=0,
                strategy_used=RetrievalStrategy.HYBRID,
                relevance_indicators=scores['relevance_indicators']
            ))
        
        # Sort and assign ranks
        final_results.sort(key=lambda x: x.score, reverse=True)
        for i, result in enumerate(final_results[:top_k]):
            result.rank = i + 1
        
        return final_results[:top_k]
    
    def _calculate_relevance_indicators(self, query: TravelQuery, document: TravelDocument, base_score: float) -> Dict[str, float]:
        indicators = {
            'base_score': base_score,
            'domain_match': 1.0 if query.domain == document.domain else 0.5,
            'location_match': 0.0,
            'freshness_score': self._calculate_freshness_score(document),
            'keyword_overlap': self._calculate_keyword_overlap(query, document)
        }
        
        # Location matching
        if query.location and document.location:
            if query.location.lower() in document.location.lower():
                indicators['location_match'] = 1.0
            elif any(word in document.location.lower() for word in query.location.lower().split()):
                indicators['location_match'] = 0.7
        
        return indicators
    
    def _calculate_freshness_score(self, document: TravelDocument) -> float:
        days_old = (datetime.now() - document.last_updated).days
        
        # Different domains have different freshness requirements
        if document.domain in [TravelDomain.FLIGHTS, TravelDomain.HOTELS]:
            # Pricing info becomes stale quickly
            return max(0.0, 1.0 - days_old / 90)  # 90 days for full decay
        elif document.domain == TravelDomain.RESTAURANTS:
            # Menu and pricing moderate freshness needs
            return max(0.0, 1.0 - days_old / 180)  # 180 days
        else:
            # Attractions and general info stay fresh longer
            return max(0.0, 1.0 - days_old / 365)  # 365 days
    
    def _calculate_keyword_overlap(self, query: TravelQuery, document: TravelDocument) -> float:
        query_words = set(query.text.lower().split())
        doc_words = set((document.title + " " + document.content).lower().split())
        
        if not query_words:
            return 0.0
        
        overlap = len(query_words.intersection(doc_words))
        return overlap / len(query_words)
    
    def _apply_travel_filters(self, query: TravelQuery, results: List[RetrievalResult]) -> List[RetrievalResult]:
        # Apply travel-specific filtering and boosting
        filtered_results = []
        
        for result in results:
            # Boost domain matches
            if query.domain == result.document.domain:
                result.score *= 1.2
            
            # Boost location matches
            if result.relevance_indicators.get('location_match', 0) > 0.8:
                result.score *= 1.15
            
            # Apply freshness penalty for time-sensitive domains
            freshness = result.relevance_indicators.get('freshness_score', 1.0)
            if result.document.domain in [TravelDomain.FLIGHTS, TravelDomain.HOTELS] and freshness < 0.5:
                result.score *= 0.8  # Penalty for stale pricing info
            
            filtered_results.append(result)
        
        # Re-sort by adjusted scores
        filtered_results.sort(key=lambda x: x.score, reverse=True)
        for i, result in enumerate(filtered_results):
            result.rank = i + 1
        
        return filtered_results
    
    def _normalize_scores(self, results: List[RetrievalResult]):
        if not results:
            return
        
        scores = [result.score for result in results]
        min_score, max_score = min(scores), max(scores)
        
        if max_score > min_score:
            for result in results:
                result.score = (result.score - min_score) / (max_score - min_score)

# Initialize travel retrieval module
travel_retrieval_module = TravelRetrievalModule()
travel_retrieval_module.index_documents(travel_knowledge_base)

print("✅ Travel Retrieval Module ready!")

🔎 Travel retrieval module initialized
📚 Indexing 8 travel documents...
✅ Travel documents indexed successfully!
✅ Travel Retrieval Module ready!


## 📊 Confidence Assessment Module

The core of CRAG - evaluating retrieval quality and determining correction needs:

In [8]:
class ConfidenceAssessmentModule(CRAGModule):
    def __init__(self):
        super().__init__("ConfidenceAssessmentModule")
        
        # Confidence thresholds
        self.thresholds = {
            'high_confidence': 0.75,
            'medium_confidence': 0.45,
            'min_documents': 2,
            'freshness_threshold': 0.6
        }
        
        # Domain-specific confidence requirements
        self.domain_requirements = {
            TravelDomain.FLIGHTS: {'freshness_weight': 0.4, 'accuracy_weight': 0.6},
            TravelDomain.HOTELS: {'freshness_weight': 0.3, 'accuracy_weight': 0.7},
            TravelDomain.RESTAURANTS: {'freshness_weight': 0.25, 'accuracy_weight': 0.75},
            TravelDomain.ATTRACTIONS: {'freshness_weight': 0.1, 'accuracy_weight': 0.9},
            TravelDomain.TRANSPORTATION: {'freshness_weight': 0.3, 'accuracy_weight': 0.7},
            TravelDomain.GENERAL: {'freshness_weight': 0.2, 'accuracy_weight': 0.8}
        }
    
    def process(self, query: TravelQuery, retrieval_results: List[RetrievalResult]) -> ConfidenceAssessment:
        self.update_stats()
        
        # Calculate multiple confidence indicators
        indicators = self._calculate_confidence_indicators(query, retrieval_results)
        
        # Determine overall confidence level
        confidence_score = self._calculate_overall_confidence(query, indicators)
        confidence_level = self._determine_confidence_level(confidence_score)
        
        # Generate reasons and recommended action
        reasons = self._generate_confidence_reasons(indicators, confidence_level)
        recommended_action = self._recommend_correction_action(query, confidence_level, indicators)
        
        return ConfidenceAssessment(
            level=confidence_level,
            score=confidence_score,
            reasons=reasons,
            recommended_action=recommended_action,
            indicators=indicators
        )
    
    def _calculate_confidence_indicators(self, query: TravelQuery, results: List[RetrievalResult]) -> Dict[str, float]:
        if not results:
            return {
                'retrieval_quality': 0.0,
                'result_count_score': 0.0,
                'freshness_score': 0.0,
                'relevance_score': 0.0,
                'domain_match_score': 0.0,
                'location_match_score': 0.0,
                'score_variance': 0.0
            }
        
        indicators = {}
        
        # 1. Retrieval Quality (top scores)
        top_scores = [r.score for r in results[:3]]
        indicators['retrieval_quality'] = np.mean(top_scores) if top_scores else 0.0
        
        # 2. Result Count Score
        indicators['result_count_score'] = min(1.0, len(results) / 5.0)
        
        # 3. Freshness Score
        freshness_scores = [r.relevance_indicators.get('freshness_score', 0.5) for r in results]
        indicators['freshness_score'] = np.mean(freshness_scores)
        
        # 4. Relevance Score (based on domain and location matches)
        domain_matches = [r.relevance_indicators.get('domain_match', 0.5) for r in results]
        location_matches = [r.relevance_indicators.get('location_match', 0.0) for r in results]
        indicators['domain_match_score'] = np.mean(domain_matches)
        indicators['location_match_score'] = np.mean(location_matches)
        indicators['relevance_score'] = (indicators['domain_match_score'] + indicators['location_match_score']) / 2
        
        # 5. Score Variance (consistency indicator)
        scores = [r.score for r in results]
        indicators['score_variance'] = 1.0 - (np.std(scores) / (np.mean(scores) + 1e-6)) if len(scores) > 1 else 1.0
        indicators['score_variance'] = max(0.0, min(1.0, indicators['score_variance']))
        
        # 6. Keyword Overlap Score
        keyword_overlaps = [r.relevance_indicators.get('keyword_overlap', 0.0) for r in results]
        indicators['keyword_overlap_score'] = np.mean(keyword_overlaps)
        
        return indicators
    
    def _calculate_overall_confidence(self, query: TravelQuery, indicators: Dict[str, float]) -> float:
        # Get domain-specific weights
        domain_req = self.domain_requirements.get(query.domain, self.domain_requirements[TravelDomain.GENERAL])
        
        # Base confidence calculation
        base_confidence = (
            indicators['retrieval_quality'] * 0.25 +
            indicators['relevance_score'] * 0.25 +
            indicators['result_count_score'] * 0.15 +
            indicators['score_variance'] * 0.15 +
            indicators['keyword_overlap_score'] * 0.20
        )
        
        # Apply domain-specific adjustments
        freshness_impact = indicators['freshness_score'] * domain_req['freshness_weight']
        accuracy_impact = base_confidence * domain_req['accuracy_weight']
        
        final_confidence = accuracy_impact + freshness_impact
        
        # Apply penalties for critical issues
        if indicators['result_count_score'] < 0.4:  # Very few results
            final_confidence *= 0.7
        
        if query.domain in [TravelDomain.FLIGHTS, TravelDomain.HOTELS] and indicators['freshness_score'] < 0.3:
            final_confidence *= 0.6  # Heavy penalty for stale pricing info
        
        return max(0.0, min(1.0, final_confidence))
    
    def _determine_confidence_level(self, confidence_score: float) -> ConfidenceLevel:
        if confidence_score >= self.thresholds['high_confidence']:
            return ConfidenceLevel.HIGH
        elif confidence_score >= self.thresholds['medium_confidence']:
            return ConfidenceLevel.MEDIUM
        else:
            return ConfidenceLevel.LOW
    
    def _generate_confidence_reasons(self, indicators: Dict[str, float], level: ConfidenceLevel) -> List[str]:
        reasons = []
        
        if level == ConfidenceLevel.HIGH:
            reasons.append(f"High retrieval quality (score: {indicators['retrieval_quality']:.2f})")
            if indicators['relevance_score'] > 0.8:
                reasons.append("Strong domain and location matching")
            if indicators['freshness_score'] > 0.7:
                reasons.append("Recent and fresh information")
        
        elif level == ConfidenceLevel.MEDIUM:
            reasons.append(f"Moderate confidence (score: {indicators['retrieval_quality']:.2f})")
            if indicators['result_count_score'] < 0.6:
                reasons.append("Limited number of relevant documents")
            if indicators['freshness_score'] < 0.6:
                reasons.append("Some information may be outdated")
            if indicators['relevance_score'] < 0.7:
                reasons.append("Partial domain or location matching")
        
        else:  # LOW
            reasons.append(f"Low confidence (score: {indicators['retrieval_quality']:.2f})")
            if indicators['result_count_score'] < 0.4:
                reasons.append("Very few relevant documents found")
            if indicators['freshness_score'] < 0.3:
                reasons.append("Information appears outdated")
            if indicators['relevance_score'] < 0.5:
                reasons.append("Poor domain or location matching")
            if indicators['keyword_overlap_score'] < 0.3:
                reasons.append("Low keyword overlap with query")
        
        return reasons
    
    def _recommend_correction_action(self, query: TravelQuery, level: ConfidenceLevel, indicators: Dict[str, float]) -> CorrectionAction:
        if level == ConfidenceLevel.HIGH:
            return CorrectionAction.NONE
        
        elif level == ConfidenceLevel.MEDIUM:
            # For medium confidence, try query decomposition first
            if len(query.text.split()) > 8:  # Complex query
                return CorrectionAction.DECOMPOSE_QUERY
            elif indicators['freshness_score'] < 0.5 and query.domain in [TravelDomain.FLIGHTS, TravelDomain.HOTELS]:
                return CorrectionAction.WEB_SEARCH  # Price-sensitive domains need fresh data
            else:
                return CorrectionAction.HYBRID_SEARCH
        
        else:  # LOW confidence
            # For low confidence, more aggressive correction
            if indicators['freshness_score'] < 0.3:
                return CorrectionAction.WEB_SEARCH  # Definitely need fresh data
            elif indicators['result_count_score'] < 0.3:
                return CorrectionAction.DECOMPOSE_QUERY  # Try breaking down the query
            else:
                return CorrectionAction.REFINE_KNOWLEDGE  # Knowledge base issue

# Test confidence assessment
confidence_module = ConfidenceAssessmentModule()
print("✅ Confidence Assessment Module ready!")

✅ Confidence Assessment Module ready!


## 🔧 Correction Action Modules

Modules that perform corrective actions based on confidence assessment:

In [9]:
class WebSearchModule(CRAGModule):
    """Web search module for real-time information verification"""
    
    def __init__(self):
        super().__init__("WebSearchModule")
        self.search_timeout = 10
    
    def process(self, query: TravelQuery) -> CorrectionResult:
        start_time = time.time()
        
        try:
            # Create search queries based on travel domain
            search_queries = self._create_search_queries(query)
            
            # Simulate web search (in real implementation, use actual search API)
            web_results = self._simulate_web_search(search_queries)
            
            # Convert web results to retrieval results
            new_documents = self._convert_web_results(web_results, query)
            
            processing_time = time.time() - start_time
            
            self.update_stats(True)
            return CorrectionResult(
                action_taken=CorrectionAction.WEB_SEARCH,
                new_documents=new_documents,
                success=True,
                processing_time=processing_time,
                details={'search_queries': search_queries, 'results_found': len(web_results)}
            )
            
        except Exception as e:
            processing_time = time.time() - start_time
            self.update_stats(False)
            return CorrectionResult(
                action_taken=CorrectionAction.WEB_SEARCH,
                new_documents=[],
                success=False,
                processing_time=processing_time,
                details={'error': str(e)}
            )
    
    def _create_search_queries(self, query: TravelQuery) -> List[str]:
        base_query = query.text
        location = query.location or ""
        
        search_queries = []
        
        if query.domain == TravelDomain.HOTELS:
            search_queries.extend([
                f"{base_query} {location} prices 2024",
                f"{location} hotel booking rates current",
                f"{base_query} {location} availability"
            ])
        elif query.domain == TravelDomain.FLIGHTS:
            search_queries.extend([
                f"{base_query} flight prices current",
                f"{location} flight schedules 2024",
                f"{base_query} airline deals"
            ])
        elif query.domain == TravelDomain.RESTAURANTS:
            search_queries.extend([
                f"{base_query} {location} menu prices 2024",
                f"{location} restaurant reviews recent",
                f"{base_query} {location} opening hours"
            ])
        else:
            search_queries.append(f"{base_query} {location} 2024 current")
        
        return search_queries
    
    def _simulate_web_search(self, search_queries: List[str]) -> List[Dict[str, Any]]:
        """Simulate web search results - replace with actual search API in production"""
        simulated_results = [
            {
                'title': 'Bangkok Hotels - Current Prices and Availability',
                'content': 'Updated hotel prices for Bangkok: Grand Palace Hotel currently $180-420/night, availability good for next month. New promotion: 15% off for bookings made this week. Peak season rates apply Dec-Feb.',
                'url': 'https://example-booking.com/bangkok-hotels',
                'last_updated': datetime.now(),
                'relevance_score': 0.9
            },
            {
                'title': 'Bangkok Flight Deals - Real-time Pricing',
                'content': 'Live flight prices to Bangkok: Economy from $350, Business from $1400. New route announcements from AirAsia and Thai Airways. Current fuel surcharges and seasonal adjustments applied.',
                'url': 'https://example-flights.com/bangkok',
                'last_updated': datetime.now(),
                'relevance_score': 0.85
            }
        ]
        
        # Simulate processing delay
        time.sleep(0.5)
        
        return simulated_results
    
    def _convert_web_results(self, web_results: List[Dict], query: TravelQuery) -> List[RetrievalResult]:
        """Convert web search results to RetrievalResult objects"""
        retrieval_results = []
        
        for i, result in enumerate(web_results):
            # Create a temporary document from web result
            web_doc = TravelDocument(
                id=f"web_{uuid.uuid4().hex[:8]}",
                title=result['title'],
                content=result['content'],
                domain=query.domain,
                location=query.location or "Unknown",
                last_updated=result['last_updated'],
                keywords=[],
                confidence_indicators={'freshness_score': 1.0, 'web_verified': True}
            )
            
            retrieval_results.append(RetrievalResult(
                document=web_doc,
                score=result['relevance_score'],
                rank=i + 1,
                strategy_used=RetrievalStrategy.HYBRID,
                relevance_indicators={
                    'freshness_score': 1.0,
                    'web_verified': True,
                    'domain_match': 1.0 if query.domain == web_doc.domain else 0.5
                }
            ))
        
        return retrieval_results


class QueryDecompositionModule(CRAGModule):
    """Breaks down complex queries into simpler sub-queries"""
    
    def __init__(self):
        super().__init__("QueryDecompositionModule")
    
    def process(self, query: TravelQuery, retrieval_module: TravelRetrievalModule) -> CorrectionResult:
        start_time = time.time()
        
        try:
            # Decompose query into sub-queries
            sub_queries = self._decompose_query(query)
            
            # Retrieve documents for each sub-query
            all_results = []
            for sub_query in sub_queries:
                sub_results = retrieval_module.process(sub_query, RetrievalStrategy.HYBRID, top_k=3)
                all_results.extend(sub_results)
            
            # Remove duplicates and re-rank
            unique_results = self._deduplicate_and_rerank(all_results)
            
            processing_time = time.time() - start_time
            
            self.update_stats(True)
            return CorrectionResult(
                action_taken=CorrectionAction.DECOMPOSE_QUERY,
                new_documents=unique_results,
                success=True,
                processing_time=processing_time,
                details={'sub_queries': [sq.text for sq in sub_queries], 'total_results': len(all_results)}
            )
            
        except Exception as e:
            processing_time = time.time() - start_time
            self.update_stats(False)
            return CorrectionResult(
                action_taken=CorrectionAction.DECOMPOSE_QUERY,
                new_documents=[],
                success=False,
                processing_time=processing_time,
                details={'error': str(e)}
            )
    
    def _decompose_query(self, query: TravelQuery) -> List[TravelQuery]:
        """Decompose complex query into simpler sub-queries"""
        sub_queries = []
        text = query.text.lower()
        
        # Pattern-based decomposition
        if "and" in text:
            parts = text.split(" and ")
            for part in parts:
                sub_queries.append(TravelQuery(
                    id=f"{query.id}_sub_{len(sub_queries)}",
                    text=part.strip(),
                    domain=query.domain,
                    location=query.location,
                    date_range=query.date_range
                ))
        
        # Domain-specific decomposition
        elif query.domain == TravelDomain.HOTELS:
            if "luxury" in text and "bangkok" in text:
                sub_queries.extend([
                    TravelQuery(f"{query.id}_sub_0", "luxury hotels", query.domain, query.location),
                    TravelQuery(f"{query.id}_sub_1", f"hotels {query.location}", query.domain, query.location),
                    TravelQuery(f"{query.id}_sub_2", "hotel prices", query.domain, query.location)
                ])
        
        # If no decomposition patterns match, create topic-based sub-queries
        if not sub_queries:
            keywords = text.split()
            if len(keywords) > 4:
                mid = len(keywords) // 2
                sub_queries.extend([
                    TravelQuery(f"{query.id}_sub_0", " ".join(keywords[:mid]), query.domain, query.location),
                    TravelQuery(f"{query.id}_sub_1", " ".join(keywords[mid:]), query.domain, query.location)
                ])
            else:
                # Add domain-specific context
                sub_queries.append(TravelQuery(
                    f"{query.id}_sub_0", 
                    f"{text} {query.domain.value}", 
                    query.domain, 
                    query.location
                ))
        
        return sub_queries or [query]  # Return original if no decomposition
    
    def _deduplicate_and_rerank(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Remove duplicate documents and re-rank by relevance"""
        seen_docs = set()
        unique_results = []
        
        # Sort by score first
        results.sort(key=lambda x: x.score, reverse=True)
        
        for result in results:
            doc_id = result.document.id
            if doc_id not in seen_docs:
                seen_docs.add(doc_id)
                unique_results.append(result)
        
        # Re-assign ranks
        for i, result in enumerate(unique_results):
            result.rank = i + 1
        
        return unique_results[:7]  # Return top 7 unique results


class KnowledgeRefinementModule(CRAGModule):
    """Refines and updates knowledge base based on correction needs"""
    
    def __init__(self):
        super().__init__("KnowledgeRefinementModule")
    
    def process(self, query: TravelQuery, retrieval_module: TravelRetrievalModule) -> CorrectionResult:
        start_time = time.time()
        
        try:
            # Apply different retrieval strategies
            semantic_results = retrieval_module.process(query, RetrievalStrategy.SEMANTIC, top_k=5)
            keyword_results = retrieval_module.process(query, RetrievalStrategy.KEYWORD, top_k=5)
            
            # Combine and enhance results
            enhanced_results = self._enhance_results(semantic_results, keyword_results)
            
            processing_time = time.time() - start_time
            
            self.update_stats(True)
            return CorrectionResult(
                action_taken=CorrectionAction.REFINE_KNOWLEDGE,
                new_documents=enhanced_results,
                success=True,
                processing_time=processing_time,
                details={
                    'semantic_results': len(semantic_results),
                    'keyword_results': len(keyword_results),
                    'enhanced_count': len(enhanced_results)
                }
            )
            
        except Exception as e:
            processing_time = time.time() - start_time
            self.update_stats(False)
            return CorrectionResult(
                action_taken=CorrectionAction.REFINE_KNOWLEDGE,
                new_documents=[],
                success=False,
                processing_time=processing_time,
                details={'error': str(e)}
            )
    
    def _enhance_results(self, semantic_results: List[RetrievalResult], keyword_results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Enhance results by combining different retrieval strategies"""
        all_results = semantic_results + keyword_results
        
        # Remove duplicates while preserving best scores
        doc_scores = {}
        for result in all_results:
            doc_id = result.document.id
            if doc_id not in doc_scores or result.score > doc_scores[doc_id].score:
                doc_scores[doc_id] = result
        
        # Enhance scores based on multiple strategy agreement
        enhanced_results = list(doc_scores.values())
        
        # Boost scores for documents that appear in both strategies
        semantic_doc_ids = {r.document.id for r in semantic_results}
        keyword_doc_ids = {r.document.id for r in keyword_results}
        
        for result in enhanced_results:
            if result.document.id in semantic_doc_ids and result.document.id in keyword_doc_ids:
                result.score *= 1.2  # Boost for multi-strategy agreement
                result.relevance_indicators['multi_strategy_match'] = True
        
        # Sort by enhanced scores
        enhanced_results.sort(key=lambda x: x.score, reverse=True)
        
        # Re-assign ranks
        for i, result in enumerate(enhanced_results):
            result.rank = i + 1
        
        return enhanced_results[:5]


# Initialize correction modules
web_search_module = WebSearchModule()
query_decomposition_module = QueryDecompositionModule()
knowledge_refinement_module = KnowledgeRefinementModule()

print("✅ Correction Action Modules ready!")

✅ Correction Action Modules ready!


## 🤖 Adaptive Generation Module

Context-aware answer generation with correction awareness:

In [10]:
class CRAGGenerationModule(CRAGModule):
    def __init__(self):
        super().__init__("CRAGGenerationModule")
        
        # Try to initialize Gemini
        api_key = os.getenv('GEMINI_API_KEY')
        if api_key:
            try:
                genai.configure(api_key=api_key)
                self.model = genai.GenerativeModel('gemini-1.5-flash')
                self.has_llm = True
                print("🤖 Gemini API configured for CRAG")
            except Exception as e:
                print(f"⚠️ Gemini error: {e}")
                self.has_llm = False
        else:
            print("⚠️ No Gemini API key. Using template generation.")
            self.has_llm = False
    
    def process(self, query: TravelQuery, final_documents: List[RetrievalResult], 
                confidence_assessment: ConfidenceAssessment, 
                correction_applied: bool) -> Tuple[str, float]:
        self.update_stats()
        
        if self.has_llm:
            return self._generate_with_llm(query, final_documents, confidence_assessment, correction_applied)
        else:
            return self._generate_with_template(query, final_documents, confidence_assessment, correction_applied)
    
    def _generate_with_llm(self, query: TravelQuery, documents: List[RetrievalResult], 
                          confidence: ConfidenceAssessment, correction_applied: bool) -> Tuple[str, float]:
        # Prepare context
        context_parts = []
        for result in documents:
            doc = result.document
            freshness_info = f"(Last updated: {doc.last_updated.strftime('%Y-%m-%d')})"
            web_verified = "✓ Web-verified" if result.relevance_indicators.get('web_verified', False) else ""
            context_parts.append(f"Title: {doc.title}\nLocation: {doc.location}\nContent: {doc.content} {freshness_info} {web_verified}")
        
        context = "\n\n".join(context_parts)
        
        # Create CRAG-aware prompt
        correction_note = ""
        if correction_applied:
            correction_note = f"\n\nNote: This response includes corrected information. Confidence was initially {confidence.level.value}, so corrective action was taken to improve accuracy."
        
        prompt = f"""You are a professional travel assistant with access to verified information. 
Use the provided context to answer the travel query accurately and helpfully.

Context (includes verification status):
{context}

Travel Query: {query.text}
Location Focus: {query.location or 'Not specified'}
Domain: {query.domain.value}
Confidence Level: {confidence.level.value} ({confidence.score:.2f})
{correction_note}

Instructions:
- Provide practical, actionable travel advice
- Include current pricing when available
- Mention any booking or timing recommendations
- Note if information has been recently verified
- Be specific about locations and logistics

Travel Assistant Response:"""
        
        try:
            response = self.model.generate_content(prompt)
            
            # Calculate confidence based on correction status and initial confidence
            final_confidence = confidence.score
            if correction_applied:
                final_confidence = min(0.95, confidence.score + 0.15)  # Boost confidence after correction
            
            return response.text, final_confidence
            
        except Exception as e:
            return f"Sorry, I encountered an error generating your travel advice: {str(e)}", 0.0
    
    def _generate_with_template(self, query: TravelQuery, documents: List[RetrievalResult], 
                               confidence: ConfidenceAssessment, correction_applied: bool) -> Tuple[str, float]:
        if not documents:
            return "I couldn't find relevant travel information to answer your question.", 0.1
        
        # Create template-based travel response
        answer = f"Based on {'verified and corrected' if correction_applied else 'available'} travel information for {query.domain.value}"
        if query.location:
            answer += f" in {query.location}"
        answer += ":\n\n"
        
        for i, result in enumerate(documents[:3], 1):
            doc = result.document
            freshness = "(Recently verified)" if result.relevance_indicators.get('web_verified', False) else f"(Updated: {doc.last_updated.strftime('%Y-%m-%d')})"
            
            answer += f"**{i}. {doc.title}** {freshness}\n"
            answer += f"{doc.content[:300]}{'...' if len(doc.content) > 300 else ''}\n\n"
        
        if correction_applied:
            answer += f"\n*Note: This information has been enhanced through corrective search to ensure accuracy and freshness.*"
        
        # Adjust confidence based on correction and document count
        base_confidence = min(0.8, len(documents) / 3.0)
        if correction_applied:
            base_confidence = min(0.9, base_confidence + 0.15)
        
        return answer, base_confidence

# Initialize CRAG generation module
crag_generation_module = CRAGGenerationModule()
print("✅ CRAG Generation Module ready!")

🤖 Gemini API configured for CRAG
✅ CRAG Generation Module ready!


## 🧩 Complete CRAG System

Integrate all modules into the complete Corrective RAG system:

In [11]:
class CorrectiveRAGSystem:
    def __init__(self):
        print("🔧 Initializing Corrective RAG (CRAG) System...")
        
        # Initialize all modules
        self.modules = {
            'query': travel_query_module,
            'retrieval': travel_retrieval_module,
            'confidence': confidence_module,
            'web_search': web_search_module,
            'decomposition': query_decomposition_module,
            'refinement': knowledge_refinement_module,
            'generation': crag_generation_module
        }
        
        # System metrics
        self.total_queries = 0
        self.corrections_applied = 0
        self.avg_confidence_improvement = 0.0
        self.processing_times = []
        
        print("✅ CRAG System initialized!")
        print(f"🔧 Active modules: {list(self.modules.keys())}")
    
    def process_query(self, query_text: str, user_id: str = None) -> CRAGResponse:
        start_time = time.time()
        pipeline = []
        
        try:
            print(f"\n🔧 CRAG Processing: '{query_text}'")
            print("=" * 60)
            
            # Step 1: Query Processing
            print("📝 Step 1: Travel query analysis...")
            query = self.modules['query'].process(query_text)
            query.user_id = user_id
            pipeline.append('query_processing')
            
            print(f"   Domain: {query.domain.value}, Location: {query.location or 'Not specified'}")
            
            # Step 2: Initial Retrieval
            print("🔍 Step 2: Initial document retrieval...")
            initial_results = self.modules['retrieval'].process(query, RetrievalStrategy.HYBRID, top_k=5)
            pipeline.append('initial_retrieval')
            
            print(f"   Retrieved {len(initial_results)} documents")
            for i, result in enumerate(initial_results[:3], 1):
                freshness = result.relevance_indicators.get('freshness_score', 0.5)
                print(f"   {i}. {result.document.title} (Score: {result.score:.3f}, Freshness: {freshness:.2f})")
            
            # Step 3: Confidence Assessment
            print("📊 Step 3: Confidence assessment...")
            confidence_assessment = self.modules['confidence'].process(query, initial_results)
            pipeline.append('confidence_assessment')
            
            print(f"   Confidence: {confidence_assessment.level.value} ({confidence_assessment.score:.2f})")
            print(f"   Recommended action: {confidence_assessment.recommended_action.value}")
            print(f"   Reasons: {', '.join(confidence_assessment.reasons[:2])}")
            
            # Step 4: Corrective Action (if needed)
            correction_result = None
            final_documents = initial_results
            correction_applied = False
            
            if confidence_assessment.recommended_action != CorrectionAction.NONE:
                print(f"🔧 Step 4: Applying correction - {confidence_assessment.recommended_action.value}...")
                
                if confidence_assessment.recommended_action == CorrectionAction.WEB_SEARCH:
                    correction_result = self.modules['web_search'].process(query)
                elif confidence_assessment.recommended_action == CorrectionAction.DECOMPOSE_QUERY:
                    correction_result = self.modules['decomposition'].process(query, self.modules['retrieval'])
                elif confidence_assessment.recommended_action == CorrectionAction.REFINE_KNOWLEDGE:
                    correction_result = self.modules['refinement'].process(query, self.modules['retrieval'])
                else:  # HYBRID_SEARCH
                    correction_result = self.modules['refinement'].process(query, self.modules['retrieval'])
                
                if correction_result and correction_result.success:
                    final_documents = self._merge_results(initial_results, correction_result.new_documents)
                    correction_applied = True
                    self.corrections_applied += 1
                    
                    print(f"   Correction successful: {len(correction_result.new_documents)} new documents")
                    print(f"   Processing time: {correction_result.processing_time:.2f}s")
                else:
                    print(f"   Correction failed or not needed")
                
                pipeline.append(f'correction_{confidence_assessment.recommended_action.value}')
            else:
                print("✅ Step 4: No correction needed - high confidence")
                pipeline.append('no_correction_needed')
            
            # Step 5: Answer Generation
            print("🤖 Step 5: Travel advice generation...")
            generated_answer, final_confidence = self.modules['generation'].process(
                query, final_documents, confidence_assessment, correction_applied
            )
            pipeline.append('generation')
            
            print(f"   Generated response (Final confidence: {final_confidence:.2f})")
            
            # Create response
            processing_time = time.time() - start_time
            
            response = CRAGResponse(
                query=query,
                initial_retrieval=initial_results,
                confidence_assessment=confidence_assessment,
                correction_result=correction_result,
                final_documents=final_documents,
                generated_answer=generated_answer,
                overall_confidence=final_confidence,
                processing_pipeline=pipeline,
                processing_time=processing_time,
                correction_applied=correction_applied
            )
            
            # Update system metrics
            self._update_metrics(response, confidence_assessment.score, final_confidence)
            
            print(f"\n✅ CRAG query processed in {processing_time:.2f}s")
            return response
            
        except Exception as e:
            processing_time = time.time() - start_time
            print(f"❌ Error: {str(e)}")
            
            # Return error response
            error_query = TravelQuery("error", query_text, TravelDomain.GENERAL, None, None, user_id)
            error_assessment = ConfidenceAssessment(ConfidenceLevel.LOW, 0.0, ["Error occurred"], CorrectionAction.NONE)
            
            return CRAGResponse(
                query=error_query,
                initial_retrieval=[],
                confidence_assessment=error_assessment,
                correction_result=None,
                final_documents=[],
                generated_answer=f"Sorry, I encountered an error processing your travel query: {str(e)}",
                overall_confidence=0.0,
                processing_pipeline=['error'],
                processing_time=processing_time,
                correction_applied=False
            )
    
    def _merge_results(self, initial_results: List[RetrievalResult], 
                      correction_results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Merge initial and correction results, prioritizing corrected information"""
        
        # Start with correction results (higher priority due to freshness/accuracy)
        merged_results = correction_results.copy()
        
        # Add initial results that don't conflict
        correction_doc_ids = {r.document.id for r in correction_results}
        
        for result in initial_results:
            if result.document.id not in correction_doc_ids:
                # Slightly reduce score to show it's less fresh/verified
                result.score *= 0.9
                merged_results.append(result)
        
        # Sort by score and limit to top results
        merged_results.sort(key=lambda x: x.score, reverse=True)
        
        # Re-assign ranks
        for i, result in enumerate(merged_results[:7]):
            result.rank = i + 1
        
        return merged_results[:7]
    
    def _update_metrics(self, response: CRAGResponse, initial_confidence: float, final_confidence: float):
        """Update system performance metrics"""
        self.total_queries += 1
        self.processing_times.append(response.processing_time)
        
        if response.correction_applied:
            confidence_improvement = final_confidence - initial_confidence
            self.avg_confidence_improvement = ((self.avg_confidence_improvement * (self.corrections_applied - 1)) + 
                                             confidence_improvement) / self.corrections_applied
    
    def get_system_status(self) -> Dict[str, Any]:
        """Get comprehensive system status and metrics"""
        module_stats = {name: module.get_info() for name, module in self.modules.items()}
        
        return {
            'total_queries': self.total_queries,
            'corrections_applied': self.corrections_applied,
            'correction_rate': (self.corrections_applied / max(self.total_queries, 1)) * 100,
            'avg_confidence_improvement': self.avg_confidence_improvement,
            'avg_processing_time': np.mean(self.processing_times) if self.processing_times else 0.0,
            'module_stats': module_stats,
            'correction_effectiveness': self.avg_confidence_improvement * 100
        }

# Initialize complete CRAG system
corrective_rag = CorrectiveRAGSystem()
print("\n🚀 Complete Corrective RAG System ready!")

🔧 Initializing Corrective RAG (CRAG) System...
✅ CRAG System initialized!
🔧 Active modules: ['query', 'retrieval', 'confidence', 'web_search', 'decomposition', 'refinement', 'generation']

🚀 Complete Corrective RAG System ready!


In [12]:
# Process a travel query
response = corrective_rag.process_query("Find luxury hotels in Bangkok for next week")

# Get system status
status = corrective_rag.get_system_status()


🔧 CRAG Processing: 'Find luxury hotels in Bangkok for next week'
📝 Step 1: Travel query analysis...
   Domain: hotels, Location: Bangkok
🔍 Step 2: Initial document retrieval...
   Retrieved 5 documents
   1. Grand Palace Hotel Bangkok - Luxury Accommodation (Score: 1.104, Freshness: 0.00)
   2. Bangkok Street Food Guide - Must-Try Local Dishes (Score: 0.907, Freshness: 0.00)
   3. Bangkok Public Transportation - BTS, MRT, and Taxis (Score: 0.864, Freshness: 0.00)
📊 Step 3: Confidence assessment...
   Confidence: low (0.33)
   Recommended action: web_search
   Reasons: Low confidence (score: 0.96), Information appears outdated
🔧 Step 4: Applying correction - web_search...
   Correction successful: 2 new documents
   Processing time: 0.50s
🤖 Step 5: Travel advice generation...
   Generated response (Final confidence: 0.48)

✅ CRAG query processed in 3.65s
