# Module 11: LLM Integration and Model Selection for RAG

## Learning Objectives
By the end of this module, you will:
- Understand different LLM categories and their RAG suitability
- Learn practical integration techniques with LangChain
- Implement cost-performance optimization strategies
- Build multi-model RAG systems
- Handle model-specific requirements and limitations

## Key Concepts

### 1. LLM Categories for RAG Systems

#### API-Based Models (Recommended for Production)
- **GPT-4/GPT-4 Turbo**: Excellent reasoning, complex queries
- **GPT-3.5-Turbo**: Fast, cost-effective for simple RAG tasks
- **Claude-3 (Opus/Sonnet/Haiku)**: Strong analytical capabilities
- **Gemini Pro/Ultra**: Google's multimodal model family

#### Open Source Models (Self-Hosted)
- **Llama 2/3**: Meta's open models, various sizes
- **Mistral 7B / Mixtral 8x7B**: Efficient European alternative (Mixtral is a sparse mixture-of-experts model)
- **Code Llama**: Specialized for code-related RAG
- **Falcon**: Open model from the UAE's Technology Innovation Institute

### 2. RAG-Specific Model Considerations

#### Context Window Size
- **GPT-4 Turbo**: 128K tokens
- **Claude-3**: 200K tokens
- **Gemini 1.5 Pro**: 1M tokens (initially in limited preview); Gemini 1.0 Pro supports roughly 32K
- **Open Source**: Usually 2K-32K tokens
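Whatever the window size, a RAG pipeline must decide which retrieved chunks actually fit once the instructions and room for the answer are accounted for. The sketch below assumes chunks arrive with relevance scores and uses the rough heuristic of about 4 characters per token for English text; that is an approximation, so use a real tokenizer such as `tiktoken` for OpenAI models when exact counts matter. `pack_context`, `approx_tokens`, and `reserve_for_output` are illustrative names, not a library API.

```python
from typing import List, Tuple

def approx_tokens(text: str) -> int:
    """Rough token estimate: ~4 characters per token for English text (heuristic)."""
    return max(1, len(text) // 4)

def pack_context(chunks: List[Tuple[float, str]], window: int,
                 prompt_overhead: int, reserve_for_output: int) -> List[str]:
    """Greedily keep the highest-scoring chunks that fit in the remaining token budget."""
    budget = window - prompt_overhead - reserve_for_output
    selected: List[str] = []
    used = 0
    # Visit chunks by descending relevance; skip any chunk that would overflow the budget
    for _score, text in sorted(chunks, key=lambda c: c[0], reverse=True):
        cost = approx_tokens(text)
        if used + cost <= budget:
            selected.append(text)
            used += cost
    return selected

# Usage: 8K-token window, 500 tokens of instructions, 1,000 tokens reserved for the answer
chunks = [(0.91, "a" * 8000), (0.85, "b" * 20000), (0.40, "c" * 4000)]
kept = pack_context(chunks, window=8000, prompt_overhead=500, reserve_for_output=1000)
print([len(c) for c in kept])  # → [8000, 4000]
```

The 20,000-character chunk (about 5,000 tokens) is skipped because it would exceed the 6,500-token budget, while the smaller, less relevant chunk still fits. Greedy packing is simple but not optimal; it trades a little relevance for predictability.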

#### Instruction Following
- Critical for RAG prompt adherence
- Citation generation accuracy
- Context utilization efficiency
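One common way to make citations checkable is to label each retrieved chunk with an ID and ask the model to cite those IDs. A minimal sketch; the `[n]` citation convention and the helper names `build_cited_prompt` and `cited_ids_are_valid` are illustrative choices, not a standard or a LangChain API.

```python
import re
from typing import List

def build_cited_prompt(query: str, chunks: List[str]) -> str:
    """Number each chunk so the model can cite it as [1], [2], ..."""
    sources = "\n".join(f"[{i}] {chunk}" for i, chunk in enumerate(chunks, start=1))
    return (
        "Answer using ONLY the numbered sources below. "
        "Cite every claim with its source number, e.g. [2].\n\n"
        f"Sources:\n{sources}\n\nQuestion: {query}\nAnswer:"
    )

def cited_ids_are_valid(answer: str, num_chunks: int) -> bool:
    """True if the answer cites at least one source and every cited ID exists."""
    ids = [int(m) for m in re.findall(r"\[(\d+)\]", answer)]
    return bool(ids) and all(1 <= i <= num_chunks for i in ids)

prompt = build_cited_prompt("What is RAG?", ["RAG retrieves documents.", "It then generates."])
print(cited_ids_are_valid("RAG retrieves documents [1] and generates [2].", 2))  # True
print(cited_ids_are_valid("It uses source [3].", 2))  # False: there is no source 3
```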

#### Cost Structure (2025 Pricing)
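- **Input tokens**: Usually cheaper per token; in RAG they carry the retrieved context
- **Output tokens**: More expensive per token (generation)
- **RAG impact**: Retrieved context makes prompts much longer than answers, so the input/output ratio is high and input pricing often dominates total cost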
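Because input and output tokens are priced separately, a per-query estimate should apply both rates. A worked sketch; the prices below are placeholder values chosen for illustration, not any provider's current rates.

```python
def query_cost(input_tokens: int, output_tokens: int,
               input_price_per_1k: float, output_price_per_1k: float) -> float:
    """Cost of one call when input and output tokens are priced separately."""
    return ((input_tokens / 1000) * input_price_per_1k
            + (output_tokens / 1000) * output_price_per_1k)

# Placeholder prices: 0.01 per 1K input tokens, 0.03 per 1K output tokens.
# A typical RAG call: 6,000 tokens of prompt plus retrieved context, 300 tokens of answer.
cost = query_cost(6000, 300, input_price_per_1k=0.01, output_price_per_1k=0.03)
print(f"${cost:.3f}")  # → $0.069
```

Input accounts for 0.060 of the 0.069 total even though each output token costs three times as much, which is the high input/output ratio described above.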

### 3. Multi-Model Strategies

#### Router-Based Approach
- Simple queries → Cheaper model
- Complex queries → Premium model
- Code queries → Specialized model

#### Fallback Systems
- Primary model fails → Secondary model
- Rate limits → Alternative provider
- Cost optimization → Model switching
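A minimal sketch of the fallback idea above, assuming each provider is wrapped as a plain callable that takes a prompt and returns text. `RateLimitError` is a hypothetical stand-in for whatever exception your client raises on HTTP 429; rate limits are retried with exponential backoff, and any other error moves straight to the next provider.

```python
import time
from typing import Callable, List, Tuple

class RateLimitError(Exception):
    """Hypothetical stand-in for a provider's rate-limit exception."""

def call_with_fallback(providers: List[Tuple[str, Callable[[str], str]]], prompt: str,
                       retries_per_provider: int = 2, base_delay: float = 0.5) -> Tuple[str, str]:
    """Try providers in order; back off on rate limits, switch provider on other errors."""
    errors: List[str] = []
    for name, call in providers:
        for attempt in range(retries_per_provider):
            try:
                return name, call(prompt)
            except RateLimitError as e:
                errors.append(f"{name}: {e}")
                # Exponential backoff (base, 2x base, 4x base, ...) unless this was the last try
                if attempt + 1 < retries_per_provider:
                    time.sleep(base_delay * (2 ** attempt))
            except Exception as e:
                errors.append(f"{name}: {e}")
                break  # non-retryable: move on to the next provider
    raise RuntimeError("All providers failed: " + "; ".join(errors))

def flaky(prompt: str) -> str:
    raise RateLimitError("429 Too Many Requests")

def backup(prompt: str) -> str:
    return f"answer to: {prompt}"

print(call_with_fallback([("primary", flaky), ("secondary", backup)], "hi", base_delay=0.01))
# → ('secondary', 'answer to: hi')
```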

---

## Setup and Imports

In [None]:
# Install required packages
!pip install langchain langchain-openai langchain-anthropic langchain-google-genai
!pip install langchain-community tiktoken chromadb
!pip install openai anthropic google-generativeai
!pip install python-dotenv pandas matplotlib seaborn plotly

In [None]:
import os
import time
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# LangChain imports (partner packages; the legacy `langchain.llms`,
# `langchain.chat_models`, and `langchain.schema` paths are deprecated)
from langchain_openai import OpenAI, ChatOpenAI, OpenAIEmbeddings
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_community.callbacks import get_openai_callback
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA

# Environment setup
from dotenv import load_dotenv
load_dotenv()

print("✅ All packages imported successfully!")

## Exercise 1: Model Performance Comparison

Let's build a harness that sends the same RAG prompt to every available LLM and records latency, token usage, estimated cost, and simple heuristic quality scores.

In [None]:
@dataclass
class ModelMetrics:
    """Store performance metrics for each model"""
    model_name: str
    response_time: float
    tokens_used: int
    cost: float
    quality_score: float
    context_utilization: float
    citation_accuracy: float

class LLMComparator:
    """Compare different LLMs for RAG tasks"""
    
    def __init__(self):
        self.models = {}
        self.test_results = []
        self.setup_models()
    
    def setup_models(self):
        """Initialize each LLM whose API key is set in the environment"""
        # (env var, registry key, LangChain class, provider model ID).
        # Model IDs are dated snapshots; swap in current IDs if a provider has retired one.
        model_specs = [
            ("OPENAI_API_KEY", 'gpt-4-turbo', ChatOpenAI, "gpt-4-turbo-preview"),
            ("OPENAI_API_KEY", 'gpt-3.5-turbo', ChatOpenAI, "gpt-3.5-turbo"),
            ("ANTHROPIC_API_KEY", 'claude-3-opus', ChatAnthropic, "claude-3-opus-20240229"),
            ("ANTHROPIC_API_KEY", 'claude-3-sonnet', ChatAnthropic, "claude-3-sonnet-20240229"),
            ("GOOGLE_API_KEY", 'gemini-pro', ChatGoogleGenerativeAI, "gemini-pro"),
        ]
        
        for env_var, name, model_cls, model_id in model_specs:
            if not os.getenv(env_var):
                continue
            # Initialize models one at a time so a single failure doesn't skip the rest
            try:
                self.models[name] = model_cls(model=model_id, temperature=0.1)
            except Exception as e:
                print(f"⚠️ {name} failed to initialize: {e}")
        
        print(f"✅ Initialized {len(self.models)} models")
        if not self.models:
            print("💡 Make sure to set your API keys in .env file")
    
    def create_rag_prompt(self, query: str, context: str) -> str:
        """Create standardized RAG prompt"""
        prompt = f"""
You are a helpful assistant that answers questions based on the provided context.
Use ONLY the information from the context to answer the question.
If the context doesn't contain enough information, say "I don't have enough information to answer this question."
Always cite specific parts of the context in your answer.

Context:
{context}

Question: {query}

Answer:
"""
        return prompt
    
    def evaluate_response_quality(self, response: str, expected_keywords: List[str]) -> float:
        """Simple quality evaluation based on keyword presence"""
        score = 0.0
        response_lower = response.lower()
        
        for keyword in expected_keywords:
            if keyword.lower() in response_lower:
                score += 1.0
        
        return min(score / len(expected_keywords), 1.0) if expected_keywords else 0.0
    
    def check_citation_accuracy(self, response: str, context: str) -> float:
        """Check if citations are accurate"""
        # Simple heuristic: check if quoted text appears in context
        import re
        quotes = re.findall(r'"([^"]+)"', response)
        
        if not quotes:
            return 0.5  # No citations, neutral score
        
        accurate_citations = 0
        for quote in quotes:
            if quote.lower() in context.lower():
                accurate_citations += 1
        
        return accurate_citations / len(quotes)
    
    def test_model(self, model_name: str, query: str, context: str,
                   expected_keywords: List[str]) -> Optional[ModelMetrics]:
        """Test a single model and return its metrics, or None if it is unavailable or errors"""
        
        if model_name not in self.models:
            print(f"⚠️ Model {model_name} not available")
            return None
        
        model = self.models[model_name]
        prompt = self.create_rag_prompt(query, context)
        
        # Measure response time
        start_time = time.time()
        
        try:
            if 'gpt' in model_name:
                with get_openai_callback() as cb:
                    response = model.invoke([HumanMessage(content=prompt)])
                    response_text = response.content
                    tokens_used = cb.total_tokens
                    cost = cb.total_cost
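            else:
                response = model.invoke([HumanMessage(content=prompt)])
                response_text = response.content
                # Prefer provider-reported usage (AIMessage.usage_metadata in recent
                # langchain-core); otherwise fall back to a word count, which undercounts tokens
                usage = getattr(response, 'usage_metadata', None)
                if usage:
                    tokens_used = usage['total_tokens']
                else:
                    tokens_used = len(prompt.split()) + len(response_text.split())
                cost = self.estimate_cost(model_name, tokens_used)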
            
            response_time = time.time() - start_time
            
            # Evaluate quality metrics
            quality_score = self.evaluate_response_quality(response_text, expected_keywords)
            citation_accuracy = self.check_citation_accuracy(response_text, context)
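            # Crude proxy: answer length relative to context length, not a true
            # measure of how much of the context the model actually used
            context_utilization = min(len(response_text) / len(context), 1.0)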
            
            metrics = ModelMetrics(
                model_name=model_name,
                response_time=response_time,
                tokens_used=tokens_used,
                cost=cost,
                quality_score=quality_score,
                context_utilization=context_utilization,
                citation_accuracy=citation_accuracy
            )
            
            print(f"✅ {model_name}: {response_time:.2f}s, Quality: {quality_score:.2f}")
            return metrics
            
        except Exception as e:
            print(f"❌ Error testing {model_name}: {e}")
            return None
    
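    def estimate_cost(self, model_name: str, tokens: int) -> float:
        """Rough cost estimate using one flat per-1K-token rate per model.
        The rates approximate input-token list prices; output tokens cost more,
        so this tends to underestimate. Prices change often, so check provider pages."""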
        pricing = {
            'claude-3-opus': 0.015,      # per 1K tokens
            'claude-3-sonnet': 0.003,
            'gemini-pro': 0.0005,
            'gpt-4-turbo': 0.01,
            'gpt-3.5-turbo': 0.0015
        }
        
        rate = pricing.get(model_name, 0.001)
        return (tokens / 1000) * rate
    
    def run_comparison(self, test_cases: List[Dict]) -> pd.DataFrame:
        """Run comparison across all models and test cases"""
        results = []
        
        for i, test_case in enumerate(test_cases):
            print(f"\n🧪 Test Case {i+1}: {test_case['query'][:50]}...")
            
            for model_name in self.models.keys():
                metrics = self.test_model(
                    model_name,
                    test_case['query'],
                    test_case['context'],
                    test_case['expected_keywords']
                )
                
                if metrics:
                    result_dict = {
                        'test_case': f"Test {i+1}",
                        'model': metrics.model_name,
                        'response_time': metrics.response_time,
                        'tokens_used': metrics.tokens_used,
                        'cost': metrics.cost,
                        'quality_score': metrics.quality_score,
                        'citation_accuracy': metrics.citation_accuracy,
                        'context_utilization': metrics.context_utilization
                    }
                    results.append(result_dict)
        
        return pd.DataFrame(results)

# Initialize comparator
comparator = LLMComparator()
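The quote check in `check_citation_accuracy` only recognizes straight double quotes and requires an exact substring match, so a model that answers with typographic quotes (“…”) or quotes a phrase that spans a line break in the indented test contexts gets no credit. A possible refinement, sketched as a standalone function (`normalized_citation_accuracy` is a hypothetical name, not a method of the class above), normalizes whitespace and case on both sides before matching:

```python
import re

def _normalize(text: str) -> str:
    """Lowercase and collapse every run of whitespace to a single space."""
    return re.sub(r"\s+", " ", text).strip().lower()

def normalized_citation_accuracy(response: str, context: str) -> float:
    """Fraction of quoted spans (straight or curly double quotes) found in the context."""
    quotes = re.findall(r'["\u201c]([^"\u201c\u201d]+)["\u201d]', response)
    if not quotes:
        return 0.5  # same neutral score as the original heuristic
    ctx = _normalize(context)
    hits = sum(1 for q in quotes if _normalize(q) in ctx)
    return hits / len(quotes)

context = """Solar and wind power produce no direct
        carbon emissions during operation."""
answer = "The text says \u201cno direct carbon emissions\u201d during operation."
print(normalized_citation_accuracy(answer, context))  # → 1.0
```

On the same inputs, the original method finds no straight-quoted spans and returns the neutral 0.5.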

### Define Test Cases

In [None]:
# Create test cases for model comparison
test_cases = [
    {
        'query': 'What are the main benefits of renewable energy?',
        'context': '''
        Renewable energy sources offer numerous advantages over fossil fuels. 
        First, they significantly reduce greenhouse gas emissions, helping combat climate change. 
        Solar and wind power produce no direct carbon emissions during operation. 
        Second, renewable energy sources are inexhaustible - the sun and wind will continue 
        to provide energy for billions of years. Third, they reduce dependence on imported 
        fossil fuels, enhancing energy security. Fourth, renewable energy creates jobs in 
        manufacturing, installation, and maintenance sectors. Finally, operating costs are 
        typically lower than fossil fuels once infrastructure is in place.
        ''',
        'expected_keywords': ['greenhouse gas', 'climate change', 'inexhaustible', 'energy security', 'jobs']
    },
    {
        'query': 'How does machine learning work in recommendation systems?',
        'context': '''
        Machine learning powers modern recommendation systems through several approaches. 
        Collaborative filtering analyzes user behavior patterns to find similar users or items. 
        Content-based filtering recommends items similar to those a user previously liked. 
        Matrix factorization techniques decompose user-item interaction matrices to discover 
        latent features. Deep learning models can capture complex non-linear patterns in 
        user preferences. Hybrid systems combine multiple approaches for better accuracy. 
        Real-time learning allows systems to adapt quickly to changing user preferences.
        ''',
        'expected_keywords': ['collaborative filtering', 'content-based', 'matrix factorization', 'deep learning', 'hybrid']
    },
    {
        'query': 'What are the key principles of sustainable agriculture?',
        'context': '''
        Sustainable agriculture focuses on meeting current food needs while preserving 
        resources for future generations. Key principles include soil health management 
        through crop rotation, cover cropping, and minimal tillage. Water conservation 
        involves efficient irrigation systems and drought-resistant crops. Biodiversity 
        preservation includes maintaining diverse crop varieties and supporting beneficial 
        insects. Integrated pest management reduces chemical pesticide use through 
        biological controls. Economic viability ensures farmers can maintain profitable 
        operations while following sustainable practices.
        ''',
        'expected_keywords': ['soil health', 'water conservation', 'biodiversity', 'pest management', 'economic viability']
    }
]

print(f"📋 Created {len(test_cases)} test cases for model comparison")

### Run Model Comparison

In [None]:
# Run the comparison (this may take a few minutes)
print("🚀 Starting model comparison...")
results_df = comparator.run_comparison(test_cases)

if not results_df.empty:
    print("\n📊 Results Summary:")
    print(results_df.groupby('model').agg({
        'response_time': 'mean',
        'cost': 'mean',
        'quality_score': 'mean',
        'citation_accuracy': 'mean'
    }).round(4))
else:
    print("❌ No results generated. Please check your API keys.")

## Exercise 2: Cost-Performance Optimization

Let's build a system that automatically selects the best model based on cost and performance requirements.

In [None]:
class RAGOptimizer:
    """Optimize model selection for RAG tasks based on requirements"""
    
    def __init__(self, comparator: LLMComparator):
        self.comparator = comparator
        self.model_profiles = self.create_model_profiles()
    
    def create_model_profiles(self) -> Dict[str, Dict]:
        """Create performance profiles for each model"""
        profiles = {
            'gpt-4-turbo': {
                'speed': 0.6,        # Relative speed (0-1)
                'quality': 0.95,     # Quality score (0-1)
                'cost': 0.8,         # Cost factor (0-1, higher = more expensive)
                'context_window': 128000,
                'specialties': ['complex reasoning', 'analysis', 'code'],
                'best_for': 'high-quality responses, complex queries'
            },
            'gpt-3.5-turbo': {
                'speed': 0.9,
                'quality': 0.8,
                'cost': 0.3,
                'context_window': 16000,
                'specialties': ['general queries', 'speed'],
                'best_for': 'fast responses, simple to medium complexity'
            },
            'claude-3-opus': {
                'speed': 0.5,
                'quality': 0.98,
                'cost': 0.9,
                'context_window': 200000,
                'specialties': ['analysis', 'reasoning', 'long context'],
                'best_for': 'highest quality, complex analysis'
            },
            'claude-3-sonnet': {
                'speed': 0.7,
                'quality': 0.85,
                'cost': 0.4,
                'context_window': 200000,
                'specialties': ['balanced performance', 'long context'],
                'best_for': 'balanced cost-quality, long documents'
            },
            'gemini-pro': {
                'speed': 0.8,
                'quality': 0.82,
                'cost': 0.2,
                'context_window': 30000,
                'specialties': ['multimodal', 'cost-effective'],
                'best_for': 'cost-effective, multimodal tasks'
            }
        }
        return profiles
    
    def calculate_suitability_score(self, model_name: str, requirements: Dict) -> float:
        """Calculate how suitable a model is for given requirements"""
        if model_name not in self.model_profiles:
            return 0.0
        
        profile = self.model_profiles[model_name]
        
        # Weight factors based on requirements
        speed_weight = requirements.get('speed_priority', 0.3)
        quality_weight = requirements.get('quality_priority', 0.4)
        cost_weight = requirements.get('cost_priority', 0.3)
        
        # Calculate weighted score
        speed_score = profile['speed'] * speed_weight
        quality_score = profile['quality'] * quality_weight
        cost_score = (1 - profile['cost']) * cost_weight  # Invert cost (lower cost = higher score)
        
        total_score = speed_score + quality_score + cost_score
        
        # Apply context window requirement
        required_context = requirements.get('context_length', 0)
        if required_context > profile['context_window']:
            total_score *= 0.1  # Heavy penalty for insufficient context window
        
        return total_score
    
    def recommend_model(self, requirements: Dict) -> Tuple[Optional[str], Dict]:
        """Recommend the best model based on requirements"""
        scores = {}
        
        for model_name in self.model_profiles.keys():
            if model_name in self.comparator.models:
                score = self.calculate_suitability_score(model_name, requirements)
                scores[model_name] = score
        
        if not scores:
            return None, {}
        
        best_model = max(scores.keys(), key=lambda k: scores[k])
        recommendation = {
            'model': best_model,
            'score': scores[best_model],
            'profile': self.model_profiles[best_model],
            'all_scores': scores
        }
        
        return best_model, recommendation
    
    def create_optimization_report(self, scenarios: List[Dict]) -> pd.DataFrame:
        """Create optimization recommendations for different scenarios"""
        reports = []
        
        for scenario in scenarios:
            best_model, recommendation = self.recommend_model(scenario['requirements'])
            
            if best_model:
                reports.append({
                    'scenario': scenario['name'],
                    'recommended_model': best_model,
                    'score': recommendation['score'],
                    'reason': recommendation['profile']['best_for'],
                    'estimated_cost': self.estimate_scenario_cost(best_model, scenario),
                    'speed': recommendation['profile']['speed'],
                    'quality': recommendation['profile']['quality']
                })
        
        return pd.DataFrame(reports)
    
    def estimate_scenario_cost(self, model_name: str, scenario: Dict) -> float:
        """Estimate cost for a scenario"""
        monthly_queries = scenario.get('monthly_queries', 1000)
        avg_tokens = scenario.get('avg_tokens', 2000)
        
        cost_per_query = self.comparator.estimate_cost(model_name, avg_tokens)
        return monthly_queries * cost_per_query

# Initialize optimizer
optimizer = RAGOptimizer(comparator)
print("✅ RAG Optimizer initialized")

### Define Optimization Scenarios

In [None]:
# Define different usage scenarios
optimization_scenarios = [
    {
        'name': 'High-Volume Customer Support',
        'requirements': {
            'speed_priority': 0.5,      # High speed priority
            'quality_priority': 0.3,    # Medium quality priority
            'cost_priority': 0.2,       # Low cost priority (speed more important)
            'context_length': 4000,     # Typical support context
        },
        'monthly_queries': 10000,
        'avg_tokens': 1500
    },
    {
        'name': 'Research Analysis',
        'requirements': {
            'speed_priority': 0.2,      # Low speed priority
            'quality_priority': 0.6,    # High quality priority
            'cost_priority': 0.2,       # Low cost priority
            'context_length': 50000,    # Long research documents
        },
        'monthly_queries': 500,
        'avg_tokens': 8000
    },
    {
        'name': 'Budget-Conscious Startup',
        'requirements': {
            'speed_priority': 0.3,      # Medium speed priority
            'quality_priority': 0.2,    # Low quality priority
            'cost_priority': 0.5,       # High cost priority
            'context_length': 8000,     # Standard context
        },
        'monthly_queries': 2000,
        'avg_tokens': 3000
    },
    {
        'name': 'Premium Enterprise Service',
        'requirements': {
            'speed_priority': 0.3,      # Medium speed priority
            'quality_priority': 0.7,    # Highest quality priority
            'cost_priority': 0.0,       # No cost constraints
            'context_length': 100000,   # Very long contexts
        },
        'monthly_queries': 1000,
        'avg_tokens': 12000
    },
    {
        'name': 'Real-time Chat Application',
        'requirements': {
            'speed_priority': 0.6,      # Highest speed priority
            'quality_priority': 0.3,    # Medium quality priority
            'cost_priority': 0.1,       # Very low cost priority
            'context_length': 2000,     # Short contexts for speed
        },
        'monthly_queries': 20000,
        'avg_tokens': 800
    }
]

print(f"📋 Created {len(optimization_scenarios)} optimization scenarios")

### Generate Optimization Report

In [None]:
# Generate optimization recommendations
optimization_report = optimizer.create_optimization_report(optimization_scenarios)

print("🎯 Model Optimization Recommendations:")
print("="*80)

if not optimization_report.empty:
    for _, row in optimization_report.iterrows():
        print(f"\n📊 {row['scenario']}")
        print(f"   Recommended: {row['recommended_model']}")
        print(f"   Reason: {row['reason']}")
        print(f"   Estimated Monthly Cost: ${row['estimated_cost']:.2f}")
        print(f"   Speed Score: {row['speed']:.2f} | Quality Score: {row['quality']:.2f}")
    
    # Display summary table
    print("\n📋 Summary Table:")
    display_cols = ['scenario', 'recommended_model', 'estimated_cost', 'speed', 'quality']
    print(optimization_report[display_cols].to_string(index=False))
else:
    print("❌ No optimization report generated. Check model availability.")

## Exercise 3: Multi-Model RAG System

Let's build a RAG system that classifies each query by keyword-based type, routes it to a suitable model, and falls back to other models when a call fails.

In [None]:
import re
from enum import Enum

class QueryType(Enum):
    SIMPLE = "simple"
    COMPLEX = "complex"
    ANALYTICAL = "analytical"
    CODE = "code"
    FACTUAL = "factual"
    CREATIVE = "creative"

class MultiModelRAGSystem:
    """Advanced RAG system with intelligent model routing"""
    
    def __init__(self, comparator: LLMComparator):
        self.comparator = comparator
        self.routing_rules = self.setup_routing_rules()
        self.fallback_chain = self.setup_fallback_chain()
        self.query_cache = {}
    
    def setup_routing_rules(self) -> Dict[QueryType, List[str]]:
        """Define which models to use for different query types"""
        return {
            QueryType.SIMPLE: ['gpt-3.5-turbo', 'gemini-pro', 'claude-3-sonnet'],
            QueryType.COMPLEX: ['gpt-4-turbo', 'claude-3-opus', 'claude-3-sonnet'],
            QueryType.ANALYTICAL: ['claude-3-opus', 'gpt-4-turbo', 'claude-3-sonnet'],
            QueryType.CODE: ['gpt-4-turbo', 'claude-3-opus', 'gpt-3.5-turbo'],
            QueryType.FACTUAL: ['gpt-3.5-turbo', 'gemini-pro', 'claude-3-sonnet'],
            QueryType.CREATIVE: ['gpt-4-turbo', 'claude-3-opus', 'claude-3-sonnet']
        }
    
    def setup_fallback_chain(self) -> List[str]:
        """Define fallback order if primary models fail"""
        return ['gpt-3.5-turbo', 'gemini-pro', 'claude-3-sonnet', 'gpt-4-turbo']
    
    def classify_query(self, query: str) -> QueryType:
        """Classify query to determine appropriate model"""
        query_lower = query.lower()
        
        # Code-related patterns
        code_patterns = ['code', 'function', 'algorithm', 'programming', 'debug', 'syntax']
        if any(pattern in query_lower for pattern in code_patterns):
            return QueryType.CODE
        
        # Analytical patterns
        analytical_patterns = ['analyze', 'compare', 'evaluate', 'assess', 'critique', 'implications']
        if any(pattern in query_lower for pattern in analytical_patterns):
            return QueryType.ANALYTICAL
        
        # Creative patterns
        creative_patterns = ['create', 'generate', 'write', 'design', 'brainstorm', 'imagine']
        if any(pattern in query_lower for pattern in creative_patterns):
            return QueryType.CREATIVE
        
        # Complex patterns (multiple questions, compound queries)
        if len(query.split('?')) > 2 or len(query.split(' and ')) > 2:
            return QueryType.COMPLEX
        
        # Simple factual patterns
        factual_patterns = ['what is', 'who is', 'when did', 'where is', 'how many']
        if any(pattern in query_lower for pattern in factual_patterns):
            return QueryType.FACTUAL
        
        # Default to simple for short queries
        if len(query.split()) < 10:
            return QueryType.SIMPLE
        
        return QueryType.COMPLEX
    
    def select_model(self, query_type: QueryType) -> str:
        """Select the best available model for the query type"""
        preferred_models = self.routing_rules.get(query_type, self.fallback_chain)
        
        # Find first available model
        for model in preferred_models:
            if model in self.comparator.models:
                return model
        
        # Fallback to any available model
        available_models = list(self.comparator.models.keys())
        return available_models[0] if available_models else None
    
    def query_with_fallback(self, query: str, context: str, max_retries: int = 3) -> Dict:
        """Query with automatic fallback on failures"""
        query_type = self.classify_query(query)
        
        # Try preferred models first
        attempted_models = []
        
        for attempt in range(max_retries):
            selected_model = None  # defined before the try so the except handler can always report it
            try:
                # Select model (avoid already attempted models)
                available_models = [m for m in self.routing_rules.get(query_type, self.fallback_chain) 
                                  if m in self.comparator.models and m not in attempted_models]
                
                if not available_models:
                    # Try fallback chain
                    available_models = [m for m in self.fallback_chain 
                                      if m in self.comparator.models and m not in attempted_models]
                
                if not available_models:
                    break  # No more models to try
                
                selected_model = available_models[0]
                attempted_models.append(selected_model)
                
                print(f"🎯 Routing {query_type.value} query to {selected_model}")
                
                # Execute query
                start_time = time.time()
                model = self.comparator.models[selected_model]
                prompt = self.comparator.create_rag_prompt(query, context)
                
                response = model.invoke([HumanMessage(content=prompt)])
                response_time = time.time() - start_time
                
                return {
                    'success': True,
                    'model_used': selected_model,
                    'query_type': query_type.value,
                    'response': response.content,
                    'response_time': response_time,
                    'attempt': attempt + 1
                }
                
            except Exception as e:
                print(f"⚠️ {selected_model} failed (attempt {attempt + 1}): {e}")
                continue
        
        return {
            'success': False,
            'error': 'All models failed',
            'attempted_models': attempted_models,
            'query_type': query_type.value
        }
    
    def batch_process(self, queries: List[Dict]) -> List[Dict]:
        """Process multiple queries with optimal model routing"""
        results = []
        
        for i, query_data in enumerate(queries):
            print(f"\n🔄 Processing query {i+1}/{len(queries)}")
            
            result = self.query_with_fallback(
                query_data['query'],
                query_data['context']
            )
            
            result['original_query'] = query_data['query']
            results.append(result)
        
        return results
    
    def get_routing_stats(self, results: List[Dict]) -> pd.DataFrame:
        """Generate routing statistics"""
        stats = []
        
        for result in results:
            if result['success']:
                stats.append({
                    'query_type': result['query_type'],
                    'model_used': result['model_used'],
                    'response_time': result['response_time'],
                    'attempt': result['attempt']
                })
        
        return pd.DataFrame(stats)

# Initialize multi-model system
multi_rag = MultiModelRAGSystem(comparator)
print("✅ Multi-Model RAG System initialized")

### Test Multi-Model Routing

In [None]:
# Create diverse test queries
test_queries = [
    {
        'query': 'What is photosynthesis?',
        'context': 'Photosynthesis is the process by which plants convert light energy into chemical energy. During photosynthesis, plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.'
    },
    {
        'query': 'Analyze the economic implications of renewable energy adoption and compare it with traditional fossil fuel infrastructure investments.',
        'context': 'The transition to renewable energy requires significant upfront investments but offers long-term economic benefits. Traditional fossil fuel infrastructure has high operational costs and environmental externalities.'
    },
    {
        'query': 'Write a Python function that implements binary search algorithm.',
        'context': 'Binary search is an efficient algorithm for finding an item from a sorted list of items. It works by repeatedly dividing the search space in half.'
    },
    {
        'query': 'Create a creative story about a time-traveling scientist.',
        'context': 'Science fiction often explores themes of time travel and scientific discovery. Time travel stories can examine cause and effect, paradoxes, and the nature of destiny.'
    },
    {
        'query': 'How many planets are in our solar system and when was Pluto reclassified?',
        'context': 'The solar system contains eight planets since Pluto was reclassified as a dwarf planet in 2006 by the International Astronomical Union.'
    }
]

print(f"🧪 Testing with {len(test_queries)} diverse queries")

In [None]:
# Process queries through multi-model system
print("🚀 Starting multi-model routing test...")
routing_results = multi_rag.batch_process(test_queries)

# Display results
print("\n📊 Routing Results:")
print("="*80)

for result in routing_results:
    if result['success']:
        print(f"\n✅ Query: {result['original_query'][:60]}...")
        print(f"   Type: {result['query_type']} | Model: {result['model_used']}")
        print(f"   Time: {result['response_time']:.2f}s | Attempt: {result['attempt']}")
        print(f"   Response: {result['response'][:100]}...")
    else:
        print(f"\n❌ Failed Query: {result['original_query'][:60]}...")
        print(f"   Error: {result['error']}")

# Generate routing statistics
if any(r['success'] for r in routing_results):
    stats_df = multi_rag.get_routing_stats(routing_results)
    
    print("\n📈 Routing Statistics:")
    print(stats_df.groupby(['query_type', 'model_used']).size().unstack(fill_value=0))
    
    print("\n⏱️ Average Response Times by Model:")
    print(stats_df.groupby('model_used')['response_time'].mean().round(2))
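An average can hide slow outliers, so it helps to look at the median and a tail percentile per model as well. Here is a minimal sketch. It assumes `stats_df` has the `model_used` and `response_time` columns used above. The hypothetical `sample` frame and its model labels stand in for it so the snippet runs on its own:

```python
import pandas as pd

# Hypothetical stand-in for stats_df; swap in the real DataFrame from get_routing_stats
sample = pd.DataFrame({
    "model_used": ["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-4", "gpt-4", "gpt-4"],
    "response_time": [0.8, 1.2, 2.5, 3.1, 6.0],
})


def latency_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Per-model request count plus mean, median and 95th-percentile latency (seconds)."""
    grouped = df.groupby("model_used")["response_time"]
    return pd.DataFrame({
        "count": grouped.size(),
        "mean_s": grouped.mean(),
        "median_s": grouped.median(),
        "p95_s": grouped.quantile(0.95),  # linear interpolation by default
    }).round(2)


summary = latency_summary(sample)
print(summary)
```

With only a handful of queries per model, the p95 value is a rough indicator at best. It becomes meaningful once the monitor has collected a realistic volume of traffic.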

## Exercise 4: Production-Ready Integration Patterns

Let's implement production-ready patterns for LLM integration: per-user rate limiting, a circuit breaker for fault tolerance, response caching, error handling, and monitoring.

In [None]:
import logging
import time
from collections import defaultdict, deque
from datetime import datetime
from typing import Dict, Optional

class ProductionRAGManager:
    """Production-ready RAG manager with monitoring and reliability features"""
    
    def __init__(self, multi_rag: MultiModelRAGSystem):
        self.multi_rag = multi_rag
        self.rate_limiter = RateLimiter()
        self.circuit_breaker = CircuitBreaker()
        self.monitor = RAGMonitor()
        self.cache = QueryCache()
        
        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
    
    def query(self, query: str, context: str, **kwargs) -> Dict:
        """Production query with full safety features"""
        query_id = f"q_{datetime.now().timestamp()}"
        
        try:
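            # Check cache first; return a copy tagged with this request's ID
            # so callers cannot mutate the cached entry
            cached_result = self.cache.get(query, context)
            if cached_result:
                self.logger.info(f"Cache hit for query {query_id}")
                return {**cached_result, 'query_id': query_id, 'cached': True}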
            
            # Check rate limits
            if not self.rate_limiter.can_proceed(kwargs.get('user_id', 'default')):
                return {
                    'success': False,
                    'error': 'Rate limit exceeded',
                    'query_id': query_id
                }
            
            # Check circuit breaker
            if not self.circuit_breaker.can_call():
                return {
                    'success': False,
                    'error': 'Service temporarily unavailable',
                    'query_id': query_id
                }
            
            # Execute query
            start_time = time.time()
            result = self.multi_rag.query_with_fallback(query, context)
            execution_time = time.time() - start_time
            
            # Update circuit breaker
            if result['success']:
                self.circuit_breaker.record_success()
                
                # Cache a copy so later changes to `result` (e.g. adding
                # query_id below) don't leak into the cached entry
                self.cache.set(query, context, dict(result))
            else:
                self.circuit_breaker.record_failure()
            
            # Record metrics
            self.monitor.record_query(
                query_id=query_id,
                success=result['success'],
                model_used=result.get('model_used'),
                execution_time=execution_time,
                query_type=result.get('query_type')
            )
            
            result['query_id'] = query_id
            return result
            
        except Exception as e:
            self.logger.error(f"Query {query_id} failed: {e}")
            self.circuit_breaker.record_failure()
            
            return {
                'success': False,
                'error': str(e),
                'query_id': query_id
            }
    
    def get_health_status(self) -> Dict:
        """Get system health status"""
        return {
            'circuit_breaker_state': self.circuit_breaker.state,
            'cache_hit_rate': self.cache.get_hit_rate(),
            'total_queries': self.monitor.get_total_queries(),
            'success_rate': self.monitor.get_success_rate(),
            'avg_response_time': self.monitor.get_avg_response_time()
        }

class RateLimiter:
    """Token bucket rate limiter"""
    
    def __init__(self, requests_per_minute: int = 60):
        self.requests_per_minute = requests_per_minute
        self.buckets = defaultdict(lambda: {
            'tokens': requests_per_minute,
            'last_update': datetime.now()
        })
    
    def can_proceed(self, user_id: str) -> bool:
        """Check if request can proceed"""
        now = datetime.now()
        bucket = self.buckets[user_id]
        
        # Refill tokens based on time passed
        time_passed = (now - bucket['last_update']).total_seconds()
        tokens_to_add = (time_passed / 60) * self.requests_per_minute
        
        bucket['tokens'] = min(
            self.requests_per_minute,
            bucket['tokens'] + tokens_to_add
        )
        bucket['last_update'] = now
        
        if bucket['tokens'] >= 1:
            bucket['tokens'] -= 1
            return True
        
        return False

class CircuitBreaker:
    """Circuit breaker for fault tolerance"""
    
    def __init__(self, failure_threshold: int = 5, timeout_seconds: int = 60):
        self.failure_threshold = failure_threshold
        self.timeout_seconds = timeout_seconds
        self.failure_count = 0
        self.last_failure_time = None
        self.state = 'CLOSED'  # CLOSED, OPEN, HALF_OPEN
    
    def can_call(self) -> bool:
        """Check if calls are allowed"""
        if self.state == 'CLOSED':
            return True
        
        if self.state == 'OPEN':
            if self._should_attempt_reset():
                self.state = 'HALF_OPEN'
                return True
            return False
        
        if self.state == 'HALF_OPEN':
            return True
        
        return False
    
    def record_success(self):
        """Record successful call"""
        self.failure_count = 0
        self.state = 'CLOSED'
    
    def record_failure(self):
        """Record failed call"""
        self.failure_count += 1
        self.last_failure_time = datetime.now()
        
        if self.failure_count >= self.failure_threshold:
            self.state = 'OPEN'
    
    def _should_attempt_reset(self) -> bool:
        """Check if enough time has passed to attempt reset"""
        if self.last_failure_time is None:
            return False
        
        return (datetime.now() - self.last_failure_time).total_seconds() > self.timeout_seconds

class QueryCache:
    """Simple in-memory cache with TTL"""
    
    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache = {}
        self.access_times = deque()
        self.hits = 0
        self.misses = 0
    
    def _generate_key(self, query: str, context: str) -> str:
        """Generate cache key"""
        import hashlib
        content = f"{query}:{context}"
        return hashlib.md5(content.encode()).hexdigest()
    
    def get(self, query: str, context: str) -> Optional[Dict]:
        """Get cached result"""
        key = self._generate_key(query, context)
        
        if key in self.cache:
            entry = self.cache[key]
            
            # Check TTL
            if (datetime.now() - entry['timestamp']).total_seconds() < self.ttl_seconds:
                self.hits += 1
                return entry['result']
            else:
                del self.cache[key]
        
        self.misses += 1
        return None
    
    def set(self, query: str, context: str, result: Dict):
        """Set cached result"""
        key = self._generate_key(query, context)
        
        # Remove oldest entries if cache is full
        while len(self.cache) >= self.max_size and self.access_times:
            oldest_key = self.access_times.popleft()
            self.cache.pop(oldest_key, None)
        
        self.cache[key] = {
            'result': result,
            'timestamp': datetime.now()
        }
        self.access_times.append(key)
    
    def get_hit_rate(self) -> float:
        """Get cache hit rate"""
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

class RAGMonitor:
    """Monitor RAG system performance"""
    
    def __init__(self):
        self.queries = []
        self.model_usage = defaultdict(int)
    
    def record_query(self, query_id: str, success: bool, model_used: str,
                    execution_time: float, query_type: str):
        """Record query metrics"""
        self.queries.append({
            'query_id': query_id,
            'timestamp': datetime.now(),
            'success': success,
            'model_used': model_used,
            'execution_time': execution_time,
            'query_type': query_type
        })
        
        if model_used:
            self.model_usage[model_used] += 1
    
    def get_total_queries(self) -> int:
        """Get total number of queries"""
        return len(self.queries)
    
    def get_success_rate(self) -> float:
        """Get overall success rate"""
        if not self.queries:
            return 0.0
        
        successful = sum(1 for q in self.queries if q['success'])
        return successful / len(self.queries)
    
    def get_avg_response_time(self) -> float:
        """Get average response time"""
        if not self.queries:
            return 0.0
        
        total_time = sum(q['execution_time'] for q in self.queries if q['success'])
        successful_queries = sum(1 for q in self.queries if q['success'])
        
        return total_time / successful_queries if successful_queries > 0 else 0.0
    
    def get_model_usage_stats(self) -> Dict:
        """Get model usage statistics"""
        return dict(self.model_usage)

# Initialize production manager
prod_manager = ProductionRAGManager(multi_rag)
print("✅ Production RAG Manager initialized")
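The `RateLimiter` above refills each user's bucket at `requests_per_minute / 60` tokens per second and reads the wall clock via `datetime.now()`. It does not guard its buckets against concurrent callers, such as a threaded web server. Here is a minimal sketch of the same token-bucket rule with a lock and an injectable monotonic clock. `ThreadSafeTokenBucket` and its `clock` parameter are hypothetical names, not part of any library:

```python
import threading
import time
from collections import defaultdict
from typing import Callable, Dict


class ThreadSafeTokenBucket:
    """Per-user token bucket with the same refill rule as RateLimiter, plus a lock."""

    def __init__(self, requests_per_minute: int = 60,
                 clock: Callable[[], float] = time.monotonic):
        self.capacity = float(requests_per_minute)
        self.refill_per_second = requests_per_minute / 60.0
        self.clock = clock  # monotonic by default, so system clock changes don't matter
        self._lock = threading.Lock()
        self._tokens: Dict[str, float] = defaultdict(lambda: self.capacity)
        self._last: Dict[str, float] = {}

    def can_proceed(self, user_id: str) -> bool:
        with self._lock:  # one caller at a time updates a bucket
            now = self.clock()
            last = self._last.get(user_id, now)
            # Refill in proportion to elapsed time, capped at capacity
            self._tokens[user_id] = min(
                self.capacity,
                self._tokens[user_id] + (now - last) * self.refill_per_second,
            )
            self._last[user_id] = now
            if self._tokens[user_id] >= 1:
                self._tokens[user_id] -= 1
                return True
            return False


# Usage with a fake clock so the behaviour is deterministic
fake_now = [0.0]
limiter = ThreadSafeTokenBucket(requests_per_minute=2, clock=lambda: fake_now[0])
first = [limiter.can_proceed("alice") for _ in range(3)]  # third call is rejected
fake_now[0] += 60.0  # one minute at 2 req/min refills the bucket
after_wait = limiter.can_proceed("alice")
```

Injecting the clock also makes rate-limit behaviour easy to unit-test without `sleep` calls.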

### Test Production Features

In [None]:
# Test production features
print("🧪 Testing Production RAG Manager")

# Test queries with different users
test_production_queries = [
    {
        'query': 'What is artificial intelligence?',
        'context': 'Artificial intelligence is the simulation of human intelligence in machines.',
        'user_id': 'user1'
    },
    {
        'query': 'How does machine learning work?',
        'context': 'Machine learning is a method of data analysis that automates analytical model building.',
        'user_id': 'user1'
    },
    {
        'query': 'What is artificial intelligence?',  # Same query to test caching
        'context': 'Artificial intelligence is the simulation of human intelligence in machines.',
        'user_id': 'user2'
    }
]

# Execute test queries
production_results = []
for i, query_data in enumerate(test_production_queries):
    print(f"\n📤 Query {i+1}: {query_data['query'][:40]}...")
    
    result = prod_manager.query(
        query_data['query'],
        query_data['context'],
        user_id=query_data['user_id']
    )
    
    production_results.append(result)
    
    if result['success']:
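        # Cached results still carry the original model_used, so check the cache flag explicitly
        source = 'cache' if result.get('cached') else result.get('model_used', 'unknown')
        print(f"✅ Success | Source: {source} | ID: {result['query_id']}")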
    else:
        print(f"❌ Failed | Error: {result['error']} | ID: {result['query_id']}")

# Display health status
print("\n🏥 System Health Status:")
health = prod_manager.get_health_status()
for key, value in health.items():
    print(f"   {key}: {value}")

## Visualization and Analysis

In [None]:
# Create comprehensive visualization of model performance
if not optimization_report.empty:
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Model recommendations by scenario
    model_counts = optimization_report['recommended_model'].value_counts()
    ax1.pie(model_counts.values, labels=model_counts.index, autopct='%1.1f%%')
    ax1.set_title('Model Recommendations Distribution')
    
    # Cost vs Quality scatter
    sc = ax2.scatter(optimization_report['estimated_cost'], optimization_report['quality'],
                     c=optimization_report['speed'], cmap='viridis', s=100)
    fig.colorbar(sc, ax=ax2, label='Speed Score')  # make the color encoding readable
    ax2.set_xlabel('Estimated Monthly Cost ($)')
    ax2.set_ylabel('Quality Score')
    ax2.set_title('Cost vs Quality (Color = Speed)')
    
    # Speed vs Quality by model
    for model in optimization_report['recommended_model'].unique():
        model_data = optimization_report[optimization_report['recommended_model'] == model]
        ax3.scatter(model_data['speed'], model_data['quality'], label=model, s=100)
    ax3.set_xlabel('Speed Score')
    ax3.set_ylabel('Quality Score')
    ax3.set_title('Speed vs Quality by Model')
    ax3.legend()
    
    # Cost comparison
    scenario_costs = optimization_report.set_index('scenario')['estimated_cost']
    scenario_costs.plot(kind='bar', ax=ax4)
    ax4.set_title('Estimated Monthly Costs by Scenario')
    ax4.set_ylabel('Cost ($)')
    ax4.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()

    # Display optimization insights
    print("\n💡 Key Insights:")
    print(f"   • Most recommended model: {model_counts.index[0]} ({model_counts.iloc[0]} scenarios)")
    print(f"   • Highest cost scenario: {scenario_costs.idxmax()} (${scenario_costs.max():.2f}/month)")
    print(f"   • Most cost-effective scenario: {scenario_costs.idxmin()} (${scenario_costs.min():.2f}/month)")
    
    cost_range = scenario_costs.max() - scenario_costs.min()
    print(f"   • Cost variation: ${cost_range:.2f}/month ({cost_range/scenario_costs.mean()*100:.1f}% of average)")
else:
    print("⚠️ No data available for visualization. Make sure models are properly configured.")

## Key Takeaways

### 1. Model Selection Strategy
- **Context Window**: Critical for RAG applications with long documents
- **Cost Structure**: Balance input/output token costs with query volume
- **Query Classification**: Route different query types to optimal models
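The cost trade-off can be made concrete with simple arithmetic: monthly cost ≈ queries per day × 30 × (input tokens × input price + output tokens × output price), with prices quoted per million tokens. Here is a minimal sketch. The model names and prices below are placeholders, not real provider rates, so substitute current published pricing:

```python
from dataclasses import dataclass


@dataclass
class ModelPricing:
    """Hypothetical per-million-token prices in USD (placeholders, not real rates)."""
    name: str
    input_per_million: float
    output_per_million: float


def monthly_cost(pricing: ModelPricing, queries_per_day: int,
                 input_tokens: int, output_tokens: int, days: int = 30) -> float:
    """Estimate monthly spend for a RAG workload with fixed per-query token counts."""
    per_query = (input_tokens * pricing.input_per_million
                 + output_tokens * pricing.output_per_million) / 1_000_000
    return per_query * queries_per_day * days


cheap = ModelPricing("cheap-model", input_per_million=0.50, output_per_million=1.50)
premium = ModelPricing("premium-model", input_per_million=10.00, output_per_million=30.00)

# A typical RAG query: large retrieved context in, short answer out
for model in (cheap, premium):
    cost = monthly_cost(model, queries_per_day=1_000, input_tokens=3_000, output_tokens=300)
    print(f"{model.name}: ${cost:,.2f}/month")
```

With these placeholder numbers, the workload costs $58.50/month on the cheap model and $1,170.00/month on the premium one. Because of the 10:1 input/output token ratio, input tokens make up about 77% of the bill on both models, even though they are cheaper per token. Trimming retrieved context is therefore often as effective as switching models.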

### 2. Production Considerations
- **Rate Limiting**: Prevent abuse and manage costs
- **Circuit Breaker**: Handle model failures gracefully
- **Caching**: Reduce costs and improve response times
- **Monitoring**: Track performance and costs continuously
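The `QueryCache` in Exercise 4 evicts in insertion (FIFO) order: `get` does not touch `access_times`, so a frequently hit query can be evicted as readily as one that was never read again. A least-recently-used policy keeps popular queries cached instead. Here is a minimal sketch using `collections.OrderedDict`. `TTLLRUCache` is a hypothetical name, and the `clock` parameter exists only to make expiry testable:

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Hashable, Optional


class TTLLRUCache:
    """In-memory cache with per-entry TTL and least-recently-used eviction."""

    def __init__(self, max_size: int = 1000, ttl_seconds: float = 3600,
                 clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.clock = clock
        self._data: OrderedDict = OrderedDict()  # key -> (value, stored_at)

    def get(self, key: Hashable) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if self.clock() - stored_at >= self.ttl_seconds:
            del self._data[key]  # expired: drop it and report a miss
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return value

    def set(self, key: Hashable, value: Any) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = (value, self.clock())
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry


now = [0.0]
cache = TTLLRUCache(max_size=2, ttl_seconds=10, clock=lambda: now[0])
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")     # "a" becomes most recently used
cache.set("c", 3)  # evicts "b", the least recently used
```

The keys could be the same MD5 digests that `QueryCache._generate_key` produces from the query and context.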

### 3. Multi-Model Architecture Benefits
- **Cost Optimization**: Use cheaper models for simple queries
- **Reliability**: Fallback options when primary models fail
- **Specialization**: Leverage model strengths for specific tasks
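At its core, the fallback idea means trying an ordered list of models and returning the first success. Here is a minimal, provider-agnostic sketch. Each model is a hypothetical plain callable rather than a real client object, and the result keys mirror the `success`/`model_used`/`response`/`attempt` fields printed earlier. LangChain runnables also expose a built-in `with_fallbacks` method for the same purpose.

```python
from typing import Callable, Dict, List, Tuple


def call_with_fallback(prompt: str,
                       models: List[Tuple[str, Callable[[str], str]]]) -> Dict:
    """Try each (name, call) pair in order and return the first successful answer."""
    errors = []
    for attempt, (name, call) in enumerate(models, start=1):
        try:
            answer = call(prompt)
            return {"success": True, "model_used": name,
                    "response": answer, "attempt": attempt}
        except Exception as exc:  # record the failure and move on to the next model
            errors.append(f"{name}: {exc}")
    return {"success": False, "error": "; ".join(errors), "attempt": len(models)}


def flaky_primary(prompt: str) -> str:
    raise TimeoutError("rate limited")  # simulated provider failure


def backup(prompt: str) -> str:
    return f"answer to: {prompt}"


result = call_with_fallback("What is RAG?", [("primary", flaky_primary), ("backup", backup)])
```

Catching every `Exception` keeps the sketch short. In practice you may want to fall back only on transient errors (timeouts, rate limits) and surface bugs such as malformed prompts immediately.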

### 4. Best Practices
- Start with a single reliable model, then add complexity
- Monitor costs and performance metrics continuously
- Implement proper error handling and retry logic
- Use caching strategically to reduce API calls
- Plan for model deprecation and migration
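For retry logic, exponential backoff with jitter is a common pattern. Each retry waits up to base × 2^attempt seconds (capped), randomized so that many clients do not retry in lockstep against a rate-limited API. Here is a minimal sketch. `retry_with_backoff`, its parameters, and the choice of which exceptions count as retryable are all assumptions to adapt to your provider's SDK:

```python
import random
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def retry_with_backoff(fn: Callable[[], T], max_attempts: int = 4,
                       base_delay: float = 0.5, max_delay: float = 8.0,
                       retry_on: Tuple[Type[BaseException], ...] = (Exception,),
                       sleep: Callable[[float], None] = time.sleep) -> T:
    """Call fn, retrying `retry_on` errors with capped exponential backoff and full jitter."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return fn()
        except retry_on:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            delay = min(max_delay, base_delay * (2 ** attempt))
            sleep(random.uniform(0, delay))  # "full jitter": anywhere in [0, delay]
    raise AssertionError("unreachable")  # the loop always returns or raises


# Usage: a call that fails twice with a transient error, then succeeds
calls = {"n": 0}


def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"


delays = []
value = retry_with_backoff(flaky, sleep=delays.append)  # record delays instead of sleeping
```

Retries compose with the circuit breaker from Exercise 4. Retry a few times inside a single request, and let the breaker stop traffic entirely once failures persist across requests.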

---

## Discussion Questions

1. **Cost vs Quality Trade-offs**: When is it worth paying 10x more for a premium model?

2. **Model Routing Logic**: How would you handle edge cases in query classification?

3. **Production Monitoring**: What metrics are most important for a production RAG system?

4. **Vendor Lock-in**: How do you balance using proprietary APIs vs open source models?

5. **Future Proofing**: How do you prepare for new model releases and capability changes?

## Next Steps

- Experiment with different model routing strategies
- Implement cost tracking and budgeting
- Add support for streaming responses
- Explore model fine-tuning for specific domains
- Consider edge deployment for latency-critical applications