# Lab 3.5.7 Solutions: Production RAG System

Complete solutions for building a production-ready RAG system with caching, monitoring, and error handling.

## Setup

In [None]:
import sys
sys.path.insert(0, '..')

import hashlib
import json
import logging
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

# Configure logging
# Root logger emits INFO and above; the named "ProductionRAG" logger created
# here is the module-level `logger` used by the classes defined below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ProductionRAG")

# Report whether PyTorch can see a CUDA device.
print(f"GPU available: {torch.cuda.is_available()}")

## Exercise 1 Solution: Implement Caching Layer

**Task**: Build a caching system for embeddings and query results.

In [None]:
class LRUCache:
    """
    Thread-safe LRU cache with TTL support.

    Features:
    - Least Recently Used eviction once ``max_size`` entries are held
    - Time-to-live expiration, checked lazily when an entry is read
    - All reads and writes guarded by a re-entrant lock
    - Hit/miss counters exposed through ``stats``

    Keys may be strings or any JSON-serializable value; they are reduced to
    an MD5 hex digest before storage. ``get`` returns ``None`` on a miss, so
    a stored ``None`` cannot be told apart from an absent entry.

    Args:
        max_size: Maximum number of entries. Zero or a negative value
            disables storage: ``set`` becomes a no-op.
        ttl_seconds: Age in seconds after which an entry counts as expired.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        # Insertion order doubles as recency order: oldest entry first.
        self.cache: OrderedDict = OrderedDict()
        # Hashed key -> time.time() of the last write for that key.
        self.timestamps: Dict[str, float] = {}
        self.lock = threading.RLock()
        self.hits = 0
        self.misses = 0

    def _hash_key(self, key: Any) -> str:
        """Return the MD5 hex digest used as the internal key.

        Strings are hashed directly; anything else is serialized with
        ``json.dumps(sort_keys=True)`` first so that dicts with the same
        items hash identically regardless of insertion order.

        Raises:
            TypeError: If a non-string key is not JSON-serializable.
        """
        if isinstance(key, str):
            return hashlib.md5(key.encode()).hexdigest()
        return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()

    def _is_expired(self, key: str) -> bool:
        """Return True when the hashed ``key`` has no timestamp or is older than the TTL."""
        stored_at = self.timestamps.get(key)
        if stored_at is None:
            return True
        return time.time() - stored_at > self.ttl_seconds

    def get(self, key: Any) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` on a miss.

        A hit marks the entry as most recently used. An expired entry is
        removed and counted as a miss.
        """
        hashed_key = self._hash_key(key)

        with self.lock:
            if hashed_key not in self.cache:
                self.misses += 1
                return None

            if self._is_expired(hashed_key):
                # Drop the stale entry; pop() tolerates a missing timestamp,
                # which _is_expired also reports as expired.
                del self.cache[hashed_key]
                self.timestamps.pop(hashed_key, None)
                self.misses += 1
                return None

            # Move to end (most recently used)
            self.cache.move_to_end(hashed_key)
            self.hits += 1
            return self.cache[hashed_key]

    def set(self, key: Any, value: Any) -> None:
        """Store ``value`` under ``key`` and refresh its timestamp.

        Writing an existing key updates it in place and marks it most
        recently used. Writing a new key first evicts least recently used
        entries until there is room. Does nothing when ``max_size <= 0``.
        """
        hashed_key = self._hash_key(key)

        with self.lock:
            if self.max_size <= 0:
                # No capacity: without this guard the eviction loop would try
                # to pop from an empty cache.
                return

            if hashed_key not in self.cache:
                while len(self.cache) >= self.max_size:
                    oldest, _ = self.cache.popitem(last=False)
                    self.timestamps.pop(oldest, None)

            self.cache[hashed_key] = value
            self.cache.move_to_end(hashed_key)
            self.timestamps[hashed_key] = time.time()

    def clear(self) -> None:
        """Remove every entry. Hit/miss counters are left untouched."""
        with self.lock:
            self.cache.clear()
            self.timestamps.clear()

    @property
    def stats(self) -> Dict[str, Any]:
        """Return size, capacity, hit/miss counts and hit rate as one snapshot."""
        # Read under the lock so the counters and size belong to the same moment.
        with self.lock:
            total = self.hits + self.misses
            return {
                "size": len(self.cache),
                "max_size": self.max_size,
                "hits": self.hits,
                "misses": self.misses,
                "hit_rate": self.hits / total if total > 0 else 0
            }


class EmbeddingCache:
    """
    Embedding store layered on LRUCache, optionally keeping vectors on the GPU.

    Entries live for 24 hours. When GPU storage is requested and CUDA is
    available, vectors are held as CUDA tensors and converted back to NumPy
    arrays on read; otherwise they are stored exactly as given.
    """

    def __init__(
        self,
        max_embeddings: int = 10000,
        store_on_gpu: bool = False
    ):
        # Only probe CUDA when GPU storage was actually requested.
        self.store_on_gpu = torch.cuda.is_available() if store_on_gpu else store_on_gpu
        self.cache = LRUCache(max_size=max_embeddings, ttl_seconds=86400)  # 24h TTL

    def get_embedding(self, text: str) -> Optional[np.ndarray]:
        """Return the embedding cached for ``text``, or ``None`` if absent."""
        stored = self.cache.get(text)
        if stored is None or not self.store_on_gpu:
            return stored
        # GPU-resident entries are tensors; hand back a host-side array.
        return stored.cpu().numpy()

    def set_embedding(self, text: str, embedding: np.ndarray) -> None:
        """Store ``embedding`` for ``text``, moving it to CUDA when enabled."""
        payload = torch.tensor(embedding, device="cuda") if self.store_on_gpu else embedding
        self.cache.set(text, payload)

    def get_or_compute(
        self,
        text: str,
        compute_fn: Callable[[str], np.ndarray]
    ) -> np.ndarray:
        """Return the cached embedding, computing and caching it on a miss."""
        vector = self.get_embedding(text)
        if vector is None:
            vector = compute_fn(text)
            self.set_embedding(text, vector)
        return vector

print("Caching classes defined")

## Exercise 2 Solution: Implement Error Handling

**Task**: Build robust error handling with retries and fallbacks.

In [None]:
import functools
from typing import TypeVar, Callable
import traceback

T = TypeVar('T')

@dataclass
class RetryConfig:
    """Configuration for retry behavior, consumed by the ``with_retry`` decorator."""
    # Number of retries after the first attempt (with_retry loops max_retries + 1 times).
    max_retries: int = 3
    # Delay before the first retry, in seconds; doubled per attempt when backing off.
    base_delay: float = 1.0  # seconds
    # Upper bound, in seconds, applied to the exponential backoff delay.
    max_delay: float = 30.0
    # True: delay = base_delay * 2**attempt (capped); False: flat base_delay.
    exponential_backoff: bool = True
    # Exception types that trigger a retry; anything else propagates immediately.
    retry_exceptions: tuple = (Exception,)


def with_retry(config: Optional[RetryConfig] = None):
    """
    Decorator factory for automatic retry with optional exponential backoff.

    The wrapped function is called once and, on any exception listed in
    ``config.retry_exceptions``, up to ``config.max_retries`` more times.
    Between attempts the wrapper logs a warning and sleeps for
    ``base_delay * 2**attempt`` seconds capped at ``max_delay`` (or a flat
    ``base_delay`` when ``exponential_backoff`` is false).

    Args:
        config: Retry settings. A new default ``RetryConfig`` is created when
            omitted, so decorated functions never share one mutable default.

    Returns:
        A decorator that preserves the wrapped function's metadata.

    Raises:
        The exception from the final failed attempt, re-raised with its
        original traceback. Exceptions not listed in ``retry_exceptions``
        propagate immediately.
    """
    if config is None:
        config = RetryConfig()

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            # A negative max_retries still gets one attempt; previously the
            # loop body never ran and the wrapper ended in `raise None`.
            retries = max(config.max_retries, 0)
            attempt = 0

            while True:
                try:
                    return func(*args, **kwargs)
                except config.retry_exceptions as exc:
                    if attempt >= retries:
                        # All retries exhausted: bare raise keeps the traceback.
                        raise

                    if config.exponential_backoff:
                        delay = min(
                            config.base_delay * (2 ** attempt),
                            config.max_delay
                        )
                    else:
                        delay = config.base_delay

                    logger.warning(
                        "Attempt %d failed: %s. Retrying in %.1fs...",
                        attempt + 1, exc, delay
                    )
                    time.sleep(delay)

                attempt += 1

        return wrapper
    return decorator


class CircuitBreaker:
    """
    Circuit breaker that stops calling a dependency after repeated failures.

    States:
    - CLOSED: calls pass through
    - OPEN: calls are rejected until ``reset_timeout`` seconds have elapsed
      since the last recorded failure
    - HALF_OPEN: calls pass through again; a success closes the breaker
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

    def __init__(
        self,
        failure_threshold: int = 5,
        reset_timeout: float = 60.0
    ):
        self.lock = threading.Lock()
        self.state = self.CLOSED
        self.failures = 0
        self.last_failure_time = 0
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout

    def can_execute(self) -> bool:
        """Return whether a call may proceed, moving OPEN to HALF_OPEN after the timeout."""
        with self.lock:
            if self.state != self.OPEN:
                # CLOSED and HALF_OPEN both let the call through.
                return True

            cooled_down = time.time() - self.last_failure_time > self.reset_timeout
            if cooled_down:
                self.state = self.HALF_OPEN
            return cooled_down

    def record_success(self):
        """Reset the failure count and close the breaker."""
        with self.lock:
            self.state = self.CLOSED
            self.failures = 0

    def record_failure(self):
        """Count a failure, stamp its time, and open the breaker at the threshold."""
        with self.lock:
            self.last_failure_time = time.time()
            self.failures += 1

            if self.failures < self.failure_threshold:
                return

            self.state = self.OPEN
            logger.warning(f"Circuit breaker opened after {self.failures} failures")

    def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Call ``func(*args, **kwargs)`` under breaker protection.

        Raises:
            Exception: With message "Circuit breaker is open" when the call
                is rejected; otherwise whatever ``func`` raises, after the
                failure has been recorded.
        """
        allowed = self.can_execute()
        if not allowed:
            raise Exception("Circuit breaker is open")

        try:
            outcome = func(*args, **kwargs)
            self.record_success()
        except Exception:
            self.record_failure()
            raise
        return outcome

print("Error handling classes defined")

## Exercise 3 Solution: Implement Monitoring

**Task**: Build monitoring for latency, throughput, and quality metrics.

In [None]:
@dataclass
class QueryMetrics:
    """Metrics for a single query, as recorded by RAGMonitor."""
    # Short identifier assigned to the query by the caller.
    query_id: str
    # When the metrics record was created.
    timestamp: datetime
    # The question text.
    query: str
    # End-to-end latency in milliseconds.
    latency_ms: float
    # Time spent in retrieval, in milliseconds.
    retrieval_latency_ms: float
    # Time spent in generation, in milliseconds.
    generation_latency_ms: float
    # Number of context passages used for the answer.
    num_contexts: int
    # Score of the first retrieved result.
    top_score: float
    # Whether retrieval was served from the cache.
    cache_hit: bool
    # False when the query raised an error.
    success: bool
    # Error message for failed queries; None otherwise.
    error: Optional[str] = None


class RAGMonitor:
    """
    Comprehensive monitoring for RAG systems.

    Tracks:
    - Latency distribution (p50, p95, p99)
    - Throughput (QPS)
    - Error rates
    - Cache performance
    - Retrieval quality indicators

    Latency, error-rate, cache and retrieval figures are computed over the
    most recent ``window_size`` records. Throughput and the total query count
    cover every record since the monitor was created, so they stay correct
    after old records have been trimmed from the window.

    Args:
        window_size: Maximum number of recent records kept in ``metrics``.
    """

    def __init__(self, window_size: int = 1000):
        self.window_size = window_size
        self.metrics: List["QueryMetrics"] = []
        # Lifetime count of record() calls; unaffected by window trimming.
        self.total_recorded = 0
        self.lock = threading.Lock()
        self.start_time = time.time()

    def record(self, metrics: "QueryMetrics"):
        """Append one query's metrics and trim the window to ``window_size``."""
        with self.lock:
            self.metrics.append(metrics)
            self.total_recorded += 1

            # Trim by count rather than slicing with [-window_size:], which
            # keeps the whole list when window_size is 0.
            overflow = len(self.metrics) - self.window_size
            if overflow > 0:
                del self.metrics[:overflow]

    def get_latency_stats(self) -> Dict[str, float]:
        """Return mean/p50/p95/p99/min/max latency (ms) of successful queries.

        Returns an empty dict when the window holds no successful query.
        """
        with self.lock:
            latencies = [m.latency_ms for m in self.metrics if m.success]
            if not latencies:
                return {}

            return {
                "mean": np.mean(latencies),
                "p50": np.percentile(latencies, 50),
                "p95": np.percentile(latencies, 95),
                "p99": np.percentile(latencies, 99),
                "min": np.min(latencies),
                "max": np.max(latencies)
            }

    def get_throughput(self) -> float:
        """Return queries per second since the monitor was created."""
        with self.lock:
            if not self.total_recorded:
                return 0.0

            elapsed = time.time() - self.start_time
            # Divide the lifetime count, not the window length, by the
            # lifetime duration; otherwise QPS decays once the window is full.
            return self.total_recorded / elapsed if elapsed > 0 else 0.0

    def get_error_rate(self) -> float:
        """Return the fraction of failed queries in the window (0 when empty)."""
        with self.lock:
            if not self.metrics:
                return 0

            errors = sum(1 for m in self.metrics if not m.success)
            return errors / len(self.metrics)

    def get_cache_stats(self) -> Dict[str, float]:
        """Return cache hit rate, hits and misses for the window ({} when empty)."""
        with self.lock:
            if not self.metrics:
                return {}

            hits = sum(1 for m in self.metrics if m.cache_hit)
            return {
                "hit_rate": hits / len(self.metrics),
                "hits": hits,
                "misses": len(self.metrics) - hits
            }

    def get_retrieval_stats(self) -> Dict[str, float]:
        """Return average context count, top score and retrieval latency.

        Only successful queries in the window are considered; returns an
        empty dict when there are none.
        """
        with self.lock:
            successful = [m for m in self.metrics if m.success]
            if not successful:
                return {}

            return {
                "avg_contexts": np.mean([m.num_contexts for m in successful]),
                "avg_top_score": np.mean([m.top_score for m in successful]),
                "avg_retrieval_ms": np.mean([m.retrieval_latency_ms for m in successful])
            }

    def get_summary(self) -> Dict[str, Any]:
        """Return every statistic in one dict.

        Each section is computed under its own lock acquisition, so under
        concurrent writes the sections may reflect slightly different moments.
        """
        with self.lock:
            total_queries = self.total_recorded

        return {
            "total_queries": total_queries,
            "qps": self.get_throughput(),
            "error_rate": self.get_error_rate(),
            "latency": self.get_latency_stats(),
            "cache": self.get_cache_stats(),
            "retrieval": self.get_retrieval_stats()
        }

    def print_summary(self):
        """Print the summary as a formatted report; empty sections are skipped."""
        summary = self.get_summary()
        rule = "=" * 60

        print(rule)
        print("RAG SYSTEM MONITORING SUMMARY")
        print(rule)
        print(f"Total Queries: {summary['total_queries']}")
        print(f"Throughput: {summary['qps']:.2f} QPS")
        print(f"Error Rate: {summary['error_rate']*100:.2f}%")

        latency = summary['latency']
        if latency:
            print("\nLatency:")
            print(f"  Mean: {latency['mean']:.2f}ms")
            print(f"  P50:  {latency['p50']:.2f}ms")
            print(f"  P95:  {latency['p95']:.2f}ms")
            print(f"  P99:  {latency['p99']:.2f}ms")

        cache = summary['cache']
        if cache:
            print("\nCache:")
            print(f"  Hit Rate: {cache['hit_rate']*100:.2f}%")

        retrieval = summary['retrieval']
        if retrieval:
            print("\nRetrieval:")
            print(f"  Avg Contexts: {retrieval['avg_contexts']:.1f}")
            print(f"  Avg Top Score: {retrieval['avg_top_score']:.3f}")

print("RAGMonitor class defined")

## Exercise 4 Solution: Build Production RAG System

**Task**: Combine all components into a production-ready system.

In [None]:
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import uuid


@dataclass
class RAGConfig:
    """Configuration for production RAG system (consumed by ProductionRAG)."""
    # Retrieval
    # Number of results requested from the vector store per query.
    k: int = 5
    # NOTE(review): not read anywhere in the visible code - confirm whether
    # retrieval results are meant to be filtered by this value.
    score_threshold: float = 0.3
    
    # Caching
    # When False, ProductionRAG creates neither the query nor the embedding cache.
    enable_cache: bool = True
    # Capacity passed to both the query cache and the embedding cache.
    cache_size: int = 10000
    # Query-cache time-to-live, passed on as ttl_seconds.
    cache_ttl: int = 3600
    
    # Error handling
    # Stored by ProductionRAG as RetryConfig(max_retries=...).
    max_retries: int = 3
    # Failure count at which the LLM circuit breaker opens.
    circuit_breaker_threshold: int = 5
    
    # Models
    # Model name handed to HuggingFaceEmbeddings.
    embedding_model: str = "BAAI/bge-large-en-v1.5"
    # Model name handed to ollama.chat.
    llm_model: str = "llama3.1:8b"
    
    # Monitoring
    # When False, no RAGMonitor is created and no metrics are recorded.
    enable_monitoring: bool = True


class ProductionRAG:
    """
    Production-ready RAG system.

    Features:
    - LRU caching of retrieval results (an embedding cache is also created
      when caching is enabled)
    - Automatic retries with exponential backoff around generation,
      configured through ``RAGConfig.max_retries``
    - Circuit breaker around the LLM call
    - Per-query metrics recorded in a RAGMonitor
    - GPU acceleration for the embedding model when CUDA is available

    Args:
        config: System settings. A fresh default ``RAGConfig`` is created per
            instance when omitted.
    """

    def __init__(self, config: Optional[RAGConfig] = None):
        # Build the default here rather than in the signature, where a single
        # mutable RAGConfig would be shared by every instance.
        self.config = config if config is not None else RAGConfig()

        # Initialize components
        self._init_caching()
        self._init_error_handling()
        self._init_monitoring()
        self._init_models()

        # Populated by index_documents(); queries fail until then.
        self.vectorstore = None
        logger.info("ProductionRAG initialized")

    def _init_caching(self):
        """Create the query and embedding caches, or set both to None when disabled."""
        if self.config.enable_cache:
            self.query_cache = LRUCache(
                max_size=self.config.cache_size,
                ttl_seconds=self.config.cache_ttl
            )
            self.embedding_cache = EmbeddingCache(
                max_embeddings=self.config.cache_size
            )
        else:
            self.query_cache = None
            self.embedding_cache = None

    def _init_error_handling(self):
        """Create the LLM circuit breaker and the retry settings used by _generate."""
        self.llm_circuit_breaker = CircuitBreaker(
            failure_threshold=self.config.circuit_breaker_threshold
        )
        self.retry_config = RetryConfig(max_retries=self.config.max_retries)

    def _init_monitoring(self):
        """Create the RAGMonitor, or set it to None when monitoring is disabled."""
        if self.config.enable_monitoring:
            self.monitor = RAGMonitor()
        else:
            self.monitor = None

    def _init_models(self):
        """Load the embedding model on CUDA when available, otherwise on CPU."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.config.embedding_model,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True, "batch_size": 64}
        )
        logger.info(f"Embedding model loaded on {device}")

    def index_documents(self, documents_path: str):
        """Load every ``*.md`` file in ``documents_path``, chunk it and build the vector store.

        Args:
            documents_path: Directory containing the markdown files.

        Raises:
            ValueError: If the directory yields no ``*.md`` files. Path.glob
                returns nothing for a missing directory, so a mistyped path
                would otherwise pass an empty list to the vector store.
        """
        # Sorted so that indexing order does not depend on the filesystem.
        documents = [
            Document(
                page_content=file_path.read_text(encoding='utf-8'),
                metadata={"source": file_path.name}
            )
            for file_path in sorted(Path(documents_path).glob("*.md"))
        ]
        if not documents:
            raise ValueError(f"No .md documents found in {documents_path!r}")

        # Chunk
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        chunks = splitter.split_documents(documents)

        # Create vector store
        self.vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=self.embedding_model,
            persist_directory="./production_chroma_db"
        )

        logger.info(f"Indexed {len(chunks)} chunks from {len(documents)} documents")

    def _retrieve(self, query: str) -> Tuple[List[Dict], bool]:
        """Return ``(results, cache_hit)`` for ``query``.

        Each result is a dict with ``content``, ``metadata`` and ``score``.

        Raises:
            RuntimeError: If no documents have been indexed yet.
        """
        cache_key = f"retrieve:{query}"

        # Check cache
        if self.query_cache is not None:
            cached = self.query_cache.get(cache_key)
            if cached is not None:
                return cached, True

        if self.vectorstore is None:
            raise RuntimeError("No documents indexed; call index_documents() first")

        # Perform retrieval
        results = self.vectorstore.similarity_search_with_score(
            query, k=self.config.k
        )

        formatted = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                # Assumes the store returns a non-negative distance (lower is
                # better); this maps it into (0, 1] - TODO confirm.
                "score": 1 / (1 + score)
            }
            for doc, score in results
        ]

        # Cache results
        if self.query_cache is not None:
            self.query_cache.set(cache_key, formatted)

        return formatted, False

    def _generate(self, query: str, contexts: List[str]) -> str:
        """Generate an answer from the first five contexts, with retries and circuit breaker.

        Retries follow ``self.retry_config`` (built from
        ``RAGConfig.max_retries``) instead of a hard-coded count.
        """
        def _call_llm():
            import ollama

            context_str = "\n\n".join(contexts[:5])
            prompt = f"""Answer based ONLY on the context. If unsure, say so.

Context:
{context_str}

Question: {query}

Answer:"""

            response = ollama.chat(
                model=self.config.llm_model,
                messages=[{"role": "user", "content": prompt}]
            )
            return response["message"]["content"]

        # Decorate at call time so the instance's own retry settings apply.
        # NOTE: the breaker rejects calls with a plain Exception, so a
        # rejected call is retried like any other failure.
        @with_retry(self.retry_config)
        def _guarded_call():
            return self.llm_circuit_breaker.execute(_call_llm)

        return _guarded_call()

    def query(self, question: str) -> Dict[str, Any]:
        """Answer ``question`` and record metrics; never raises.

        Returns:
            On success: ``query_id``, ``question``, ``answer``, ``contexts``,
            ``scores``, ``latency_ms``, ``retrieval_ms``, ``generation_ms``
            and ``cache_hit``. On failure: ``query_id``, ``question``, a
            fallback ``answer``, ``error``, ``latency_ms`` and ``cache_hit``.
        """
        query_id = str(uuid.uuid4())[:8]
        start_time = time.time()
        cache_hit = False
        # Defaults let the failure record report whatever was measured
        # before the error occurred.
        retrieval_time = 0.0
        generation_time = 0.0

        try:
            # Retrieval
            retrieval_start = time.time()
            retrieved, cache_hit = self._retrieve(question)
            retrieval_time = (time.time() - retrieval_start) * 1000

            contexts = [r["content"] for r in retrieved]

            # Generation
            generation_start = time.time()
            answer = self._generate(question, contexts)
            generation_time = (time.time() - generation_start) * 1000

            total_time = (time.time() - start_time) * 1000

            result = {
                "query_id": query_id,
                "question": question,
                "answer": answer,
                "contexts": contexts,
                "scores": [r["score"] for r in retrieved],
                "latency_ms": total_time,
                "retrieval_ms": retrieval_time,
                "generation_ms": generation_time,
                "cache_hit": cache_hit
            }

            # Record metrics
            if self.monitor is not None:
                self.monitor.record(QueryMetrics(
                    query_id=query_id,
                    timestamp=datetime.now(),
                    query=question,
                    latency_ms=total_time,
                    retrieval_latency_ms=retrieval_time,
                    generation_latency_ms=generation_time,
                    num_contexts=len(contexts),
                    top_score=retrieved[0]["score"] if retrieved else 0,
                    cache_hit=cache_hit,
                    success=True
                ))

            return result

        except Exception as e:
            # Top-level boundary: convert any failure into an error response.
            error = str(e)
            total_time = (time.time() - start_time) * 1000
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Query failed: %s", error)

            # Record failure
            if self.monitor is not None:
                self.monitor.record(QueryMetrics(
                    query_id=query_id,
                    timestamp=datetime.now(),
                    query=question,
                    latency_ms=total_time,
                    retrieval_latency_ms=retrieval_time,
                    generation_latency_ms=generation_time,
                    num_contexts=0,
                    top_score=0,
                    cache_hit=cache_hit,
                    success=False,
                    error=error
                ))

            return {
                "query_id": query_id,
                "question": question,
                "answer": "I'm sorry, I encountered an error processing your request.",
                "error": error,
                "latency_ms": total_time,
                "cache_hit": cache_hit
            }

print("ProductionRAG class defined")

In [None]:
# Initialize production RAG
# Overrides only the cache size; every other RAGConfig field keeps its default.
config = RAGConfig(
    k=5,
    enable_cache=True,
    cache_size=1000,
    enable_monitoring=True
)

# Constructing the system loads the embedding model; indexing then reads the
# *.md files from the sample-documents directory.
rag = ProductionRAG(config)
rag.index_documents("../data/sample_documents")

In [None]:
# Test queries
# The first question appears twice so the second occurrence can be served
# from the retrieval cache.
test_queries = [
    "What is the memory capacity of DGX Spark?",
    "How does LoRA work?",
    "Explain transformer attention",
    "What is the memory capacity of DGX Spark?",  # Repeat for cache
    "What quantization methods are available?"
]

print("Running test queries...\n")
for query in test_queries:
    result = rag.query(query)
    print(f"Q: {query}")
    # Only the first 200 characters of the answer are shown.
    print(f"A: {result['answer'][:200]}...")
    # .get() with defaults: error responses may not carry these keys.
    print(f"Latency: {result.get('latency_ms', 0):.0f}ms, Cache: {result.get('cache_hit', False)}")
    print()

In [None]:
# Print monitoring summary
# rag.monitor is None when monitoring was disabled in the config.
if rag.monitor:
    rag.monitor.print_summary()

## Production Checklist

In [None]:
# Deployment checklist printed at the end of the lab: [x] marks items the
# notebook presents as covered, [ ] marks items left open.
# NOTE(review): "Embedding caching" is checked, but the visible ProductionRAG
# code creates embedding_cache without reading from it - confirm.
checklist = """
PRODUCTION RAG DEPLOYMENT CHECKLIST
====================================

PERFORMANCE
-----------
[x] GPU acceleration for embeddings
[x] Query result caching (LRU with TTL)
[x] Embedding caching
[x] Batch processing support
[ ] Async/concurrent query handling
[ ] Connection pooling for vector store

RELIABILITY
-----------
[x] Automatic retries with backoff
[x] Circuit breaker pattern
[x] Graceful error handling
[ ] Health check endpoints
[ ] Graceful shutdown

MONITORING
----------
[x] Latency tracking (p50, p95, p99)
[x] Throughput monitoring
[x] Error rate tracking
[x] Cache hit rates
[ ] Export to Prometheus/Grafana
[ ] Alerting rules

SECURITY
--------
[ ] Input validation and sanitization
[ ] Rate limiting
[ ] Authentication/authorization
[ ] Audit logging

QUALITY
-------
[ ] Automated evaluation pipeline
[ ] A/B testing framework
[ ] User feedback collection
[ ] Continuous monitoring of RAGAS metrics

DEPLOYMENT
----------
[ ] Containerization (Docker)
[ ] Kubernetes deployment configs
[ ] Load balancer configuration
[ ] Auto-scaling rules
"""

print(checklist)