# Lab 3.3.7 Solutions: Production API with FastAPI

Complete solutions to all exercises from the production API lab.

## Setup

In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path("../scripts").resolve()))

import asyncio
import time
import json
from dataclasses import dataclass, field
from typing import List, Dict, Optional, AsyncGenerator
from datetime import datetime, timedelta
import hashlib

---

## Exercise 1: Implement Rate Limiting

**Task**: Create a production-ready rate limiter with multiple strategies.

In [None]:
from collections import defaultdict
import threading


@dataclass
class RateLimitConfig:
    """Per-client limits enforced by ``TokenBucketRateLimiter``.

    Attributes:
        requests_per_minute: Nominal request budget per rolling 60 s window.
        requests_per_hour: Request budget per rolling 3600 s window.
        tokens_per_minute: Token budget per rolling 60 s window.
        concurrent_requests: Maximum number of acquired concurrency slots.
        burst_multiplier: Multiplied by ``requests_per_minute`` (then
            truncated to int) to get the hard per-minute request cap.
    """
    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    tokens_per_minute: int = 100000
    concurrent_requests: int = 10
    burst_multiplier: float = 1.5


class TokenBucketRateLimiter:
    """
    Sliding-window rate limiter with support for:
    - Multiple time windows (per-minute, per-hour)
    - Token-based limits (for LLM usage)
    - Burst allowance
    - Concurrent request limiting

    Despite the class name this is not a refilling token bucket: it keeps a
    log of timestamps per client and counts the entries inside each window.
    Every public method holds ``self._lock`` while touching shared state.

    Typical per-request flow: ``check_rate_limit`` -> ``acquire_concurrent``
    -> serve -> ``record_request`` -> ``release_concurrent``.
    ``check_rate_limit`` only inspects state; nothing is counted until
    ``record_request`` is called.
    """

    _MINUTE_SECONDS = 60
    _HOUR_SECONDS = 3600

    def __init__(self, config: Optional[RateLimitConfig] = None):
        self.config = config or RateLimitConfig()
        self._lock = threading.Lock()

        # Per-client tracking, keyed by client_id:
        #   _request_counts: timestamps of recorded requests (pruned to 1 hour)
        #   _token_counts:   (timestamp, tokens_used) pairs (pruned to 1 minute)
        #   _concurrent:     number of currently acquired concurrency slots
        self._request_counts: Dict[str, List[float]] = defaultdict(list)
        self._token_counts: Dict[str, List[tuple]] = defaultdict(list)
        self._concurrent: Dict[str, int] = defaultdict(int)

    def _clean_old_entries(self, entries: list, window_seconds: int,
                           now: Optional[float] = None) -> list:
        """Return only the entries newer than ``window_seconds`` ago.

        Entries are bare timestamps or tuples whose first item is the
        timestamp. ``now`` lets callers use one consistent clock reading.
        """
        cutoff = (time.time() if now is None else now) - window_seconds
        return [e for e in entries if (e[0] if isinstance(e, tuple) else e) > cutoff]

    def check_rate_limit(self, client_id: str,
                         estimated_tokens: int = 0) -> dict:
        """
        Check whether a request from ``client_id`` should be allowed.

        Does not record the request; call ``record_request`` once served.

        Args:
            client_id: Identifier of the caller being limited.
            estimated_tokens: Tokens the request is expected to consume,
                counted against ``tokens_per_minute``.

        Returns:
            On rejection: dict with ``allowed=False``, ``reason``, ``message``
            and ``retry_after`` (whole seconds, at least 1).
            On success: dict with ``allowed=True``, ``remaining_requests``
            and ``remaining_tokens`` (both >= 0).
        """
        with self._lock:
            now = time.time()
            minute_ago = now - self._MINUTE_SECONDS

            # Prune expired entries so each list holds only its window.
            requests = self._clean_old_entries(
                self._request_counts[client_id], self._HOUR_SECONDS, now
            )
            token_entries = self._clean_old_entries(
                self._token_counts[client_id], self._MINUTE_SECONDS, now
            )
            self._request_counts[client_id] = requests
            self._token_counts[client_id] = token_entries

            # Check concurrent requests
            if self._concurrent[client_id] >= self.config.concurrent_requests:
                return {
                    "allowed": False,
                    "reason": "concurrent_limit",
                    "message": f"Max {self.config.concurrent_requests} concurrent requests",
                    "retry_after": 1
                }

            # Check requests per minute; the hard cap includes the burst allowance.
            recent = [r for r in requests if r > minute_ago]
            requests_last_minute = len(recent)
            burst_limit = int(self.config.requests_per_minute * self.config.burst_multiplier)
            if requests_last_minute >= burst_limit:
                # Capacity frees when the oldest request in the window expires.
                # `recent` is empty when burst_limit <= 0 (e.g. a zero limit).
                retry_after = (self._MINUTE_SECONDS - (now - min(recent))
                               if recent else self._MINUTE_SECONDS)
                return {
                    "allowed": False,
                    "reason": "rate_limit_minute",
                    "message": f"Rate limit: {self.config.requests_per_minute}/min",
                    "retry_after": max(1, int(retry_after))
                }

            # Check requests per hour
            if len(requests) >= self.config.requests_per_hour:
                retry_after = (self._HOUR_SECONDS - (now - min(requests))
                               if requests else self._HOUR_SECONDS)
                return {
                    "allowed": False,
                    "reason": "rate_limit_hour",
                    "message": f"Rate limit: {self.config.requests_per_hour}/hour",
                    "retry_after": max(1, int(retry_after))
                }

            # Check token limit
            tokens_last_minute = sum(t[1] for t in token_entries)
            if tokens_last_minute + estimated_tokens > self.config.tokens_per_minute:
                retry_after = (self._MINUTE_SECONDS - (now - min(t[0] for t in token_entries))
                               if token_entries else self._MINUTE_SECONDS)
                return {
                    "allowed": False,
                    "reason": "token_limit",
                    "message": f"Token limit: {self.config.tokens_per_minute}/min",
                    "retry_after": max(1, int(retry_after))
                }

            # Request allowed. remaining_requests is relative to the nominal
            # rate and clamped, so burst-admitted requests report 0, not < 0.
            return {
                "allowed": True,
                "remaining_requests": max(
                    0, self.config.requests_per_minute - requests_last_minute - 1
                ),
                "remaining_tokens": self.config.tokens_per_minute - tokens_last_minute - estimated_tokens
            }

    def record_request(self, client_id: str, tokens_used: int = 0):
        """Record a completed request and, if positive, the tokens it used."""
        with self._lock:
            now = time.time()
            self._request_counts[client_id].append(now)
            if tokens_used > 0:
                self._token_counts[client_id].append((now, tokens_used))

    def acquire_concurrent(self, client_id: str) -> bool:
        """Take a concurrency slot; return False if the client is at its limit."""
        with self._lock:
            if self._concurrent[client_id] < self.config.concurrent_requests:
                self._concurrent[client_id] += 1
                return True
            return False

    def release_concurrent(self, client_id: str):
        """Release a concurrency slot (never drops below zero)."""
        with self._lock:
            self._concurrent[client_id] = max(0, self._concurrent[client_id] - 1)

    def get_client_stats(self, client_id: str) -> dict:
        """Return current usage for ``client_id`` together with the configured limits."""
        with self._lock:
            now = time.time()
            minute_ago = now - self._MINUTE_SECONDS
            hour_ago = now - self._HOUR_SECONDS

            # .get() so querying an unknown client does not create entries.
            request_times = self._request_counts.get(client_id, [])
            token_entries = self._token_counts.get(client_id, [])

            return {
                "requests_last_minute": sum(1 for r in request_times if r > minute_ago),
                # Filter explicitly: the stored list is only pruned inside
                # check_rate_limit, so it may still hold entries older than 1 h.
                "requests_last_hour": sum(1 for r in request_times if r > hour_ago),
                "tokens_last_minute": sum(n for ts, n in token_entries if ts > minute_ago),
                "concurrent_requests": self._concurrent.get(client_id, 0),
                "limits": {
                    "requests_per_minute": self.config.requests_per_minute,
                    "requests_per_hour": self.config.requests_per_hour,
                    "tokens_per_minute": self.config.tokens_per_minute,
                    "concurrent_requests": self.config.concurrent_requests
                }
            }


# Demonstrate rate limiter
print("🚦 Rate Limiter Demo")
print("=" * 50)

config = RateLimitConfig(
    requests_per_minute=10,
    tokens_per_minute=1000,
    concurrent_requests=3
)
limiter = TokenBucketRateLimiter(config)

client_id = "user_123"

# Simulate requests. Each is estimated at 100 tokens, so the 1000 tokens/min
# budget admits 10 requests; the rest are rejected with reason "token_limit"
# before the burst cap (int(10 * 1.5) = 15/min) is reached.
print("\nSimulating requests...")
for i in range(15):
    result = limiter.check_rate_limit(client_id, estimated_tokens=100)

    if result["allowed"]:
        limiter.record_request(client_id, tokens_used=100)
        print(f"  Request {i+1}: ✓ Allowed (remaining: {result['remaining_requests']})")
    else:
        print(f"  Request {i+1}: ✗ Rejected ({result['reason']})")
        print(f"              Retry after: {result['retry_after']}s")

# Show stats
stats = limiter.get_client_stats(client_id)
print("\n📊 Client Stats:")
print(f"   Requests (last minute): {stats['requests_last_minute']}")
print(f"   Tokens (last minute): {stats['tokens_last_minute']}")

---

## Exercise 2: Implement SSE Streaming

**Task**: Create a robust Server-Sent Events implementation for token streaming.

In [None]:
@dataclass
class StreamChunk:
    """Typed record of one streamed chat-completion chunk.

    The fields mirror the top-level keys of the chunk payloads that
    ``SSEStreamManager.format_chunk`` builds as plain dicts; nothing in this
    lab constructs a ``StreamChunk`` directly.
    """

    # Identifier shared by every chunk of one response.
    id: str
    # Payload type tag, e.g. "chat.completion.chunk".
    object: str
    # Creation time as an integer timestamp.
    created: int
    # Name of the model that produced the chunk.
    model: str
    # Per-choice payloads (delta content, finish reason, ...).
    choices: List[dict]
    # Optional usage block; absent (None) on ordinary content chunks.
    usage: Optional[dict] = None


class SSEStreamManager:
    """
    Manage Server-Sent Events streaming for LLM responses.

    Features:
    - OpenAI-compatible ``chat.completion.chunk`` format
    - SSE comment lines (": heartbeat") interleaved with tokens
    - ``format_error`` for reporting failures inside the stream
    - Approximate usage tracking (one completion token per non-empty chunk)

    One instance represents one response: the request id, the ``created``
    timestamp and the token counters are set at construction and shared by
    every chunk it formats.
    """

    def __init__(self, model: str = "gpt-3.5-turbo"):
        self.model = model
        self.request_id = self._generate_id()
        self.created = int(time.time())
        # Completion-token counter, incremented by format_chunk.
        self.total_tokens = 0
        # Never updated by this class; reported as-is by get_usage.
        self.prompt_tokens = 0

    def _generate_id(self) -> str:
        """Return a unique request ID of the form ``chatcmpl-<24 hex chars>``.

        Uses random UUID bits. Hashing the wall-clock time gave identical,
        predictable IDs to managers created within the same clock tick.
        """
        import uuid  # local import keeps this block self-contained

        return f"chatcmpl-{uuid.uuid4().hex[:24]}"

    def format_chunk(self,
                     content: str = "",
                     finish_reason: Optional[str] = None,
                     role: Optional[str] = None) -> str:
        """
        Format one ``chat.completion.chunk`` as an SSE ``data:`` event.

        ``role`` and non-empty ``content`` are placed in the delta; each
        non-empty ``content`` counts as one completion token.
        """
        delta = {}
        if role:
            delta["role"] = role
        if content:
            delta["content"] = content
            self.total_tokens += 1  # Approximate: one chunk == one token

        chunk = {
            "id": self.request_id,
            "object": "chat.completion.chunk",
            "created": self.created,
            "model": self.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }

        return f"data: {json.dumps(chunk)}\n\n"

    def format_done(self) -> str:
        """Format the terminal ``[DONE]`` sentinel event."""
        return "data: [DONE]\n\n"

    def format_error(self, error: str, code: str = "internal_error") -> str:
        """Format an error payload as an SSE ``data:`` event."""
        error_chunk = {
            "error": {
                "message": error,
                "type": "api_error",
                "code": code
            }
        }
        return f"data: {json.dumps(error_chunk)}\n\n"

    def format_heartbeat(self) -> str:
        """Format an SSE comment line used to keep the connection alive."""
        return ": heartbeat\n\n"

    async def stream_response(self,
                              tokens: List[str],
                              delay_ms: float = 50) -> AsyncGenerator[str, None]:
        """
        Stream ``tokens`` as SSE events.

        Emits a role chunk, one content chunk per token (sleeping
        ``delay_ms`` before each), a heartbeat after every 10th token
        following the first, a ``finish_reason="stop"`` chunk, and ``[DONE]``.
        """
        # Send initial role chunk
        yield self.format_chunk(role="assistant")

        # Stream content tokens
        for i, token in enumerate(tokens):
            await asyncio.sleep(delay_ms / 1000)
            yield self.format_chunk(content=token)

            # Heartbeat after tokens 11, 21, 31, ...
            if i > 0 and i % 10 == 0:
                yield self.format_heartbeat()

        # Send final chunk with finish_reason
        yield self.format_chunk(finish_reason="stop")

        # Send done
        yield self.format_done()

    def get_usage(self) -> dict:
        """Return OpenAI-style token usage for this response."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.total_tokens,
            "total_tokens": self.prompt_tokens + self.total_tokens
        }


# Demonstrate SSE streaming
print("üì° SSE Streaming Demo")
print("=" * 50)

manager = SSEStreamManager(model="llama-3.1-8b")

# Simulate a response
tokens = ["Hello", "!", " I'm", " Claude", ",", " an", " AI", " assistant", "."]

print("\nStreaming response:")
print("-" * 40)

async def demo_stream():
    async for chunk in manager.stream_response(tokens, delay_ms=10):
        # Parse and display
        if chunk.startswith("data: {"):
            data = json.loads(chunk[6:-2])  # Remove "data: " and "\n\n"
            if "choices" in data:
                delta = data["choices"][0].get("delta", {})
                if "content" in delta:
                    print(delta["content"], end="", flush=True)
                if delta.get("role"):
                    print(f"[{delta['role']}] ", end="")
        elif "[DONE]" in chunk:
            print("\n[DONE]")

await demo_stream()

print(f"\nüìä Usage: {manager.get_usage()}")

---

## Exercise 3: Implement Health Checks and Monitoring

**Task**: Create a comprehensive health monitoring system.

In [None]:
from enum import Enum
from dataclasses import dataclass, field
from typing import Callable


class HealthStatus(Enum):
    """Three-level health state used by individual checks and the overall result."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"


@dataclass
class HealthCheckResult:
    """Result of a single health check."""
    name: str
    status: HealthStatus
    latency_ms: float
    message: str = ""
    details: dict = field(default_factory=dict)


class HealthMonitor:
    """
    Health monitoring for a production LLM API.

    Built-in checks (simulated in this lab):
    - Inference engine connectivity
    - GPU memory and utilization
    - Request error rate (from metrics fed to ``record_request``)

    Extra checks (e.g. queue depth) can be added with ``register_check``;
    ``run_all_checks`` runs them after the built-in ones.
    """

    def __init__(self):
        self.checks: Dict[str, Callable] = {}
        self.metrics = {
            "requests_total": 0,
            "requests_failed": 0,
            "latency_sum_ms": 0,
            "tokens_generated": 0,
        }
        # Currently unused: no result caching is implemented.
        self._last_check_time = 0
        self._cached_status = None

    def register_check(self, name: str, check_fn: Callable):
        """Register an extra health check.

        ``check_fn`` is called with no arguments by ``run_all_checks``. It may
        be sync or async. It should return a ``HealthCheckResult``; any other
        return value is interpreted by truthiness (truthy -> HEALTHY). An
        exception raised by the check yields an UNHEALTHY result.
        """
        self.checks[name] = check_fn

    def record_request(self, latency_ms: float, tokens: int, success: bool):
        """Accumulate cumulative metrics for one request."""
        self.metrics["requests_total"] += 1
        self.metrics["latency_sum_ms"] += latency_ms
        self.metrics["tokens_generated"] += tokens
        if not success:
            self.metrics["requests_failed"] += 1

    async def check_inference_engine(self, url: str = "http://localhost:8000") -> HealthCheckResult:
        """Check if the inference engine is responsive (simulated; ``url`` is not contacted)."""
        start = time.perf_counter()
        try:
            # Simulate health check (would be actual HTTP request)
            await asyncio.sleep(0.01)  # Simulate network latency
            latency = (time.perf_counter() - start) * 1000

            # Simulated response
            engine_healthy = True

            if engine_healthy:
                return HealthCheckResult(
                    name="inference_engine",
                    status=HealthStatus.HEALTHY,
                    latency_ms=latency,
                    message="Engine responding normally"
                )
            return HealthCheckResult(
                name="inference_engine",
                status=HealthStatus.UNHEALTHY,
                latency_ms=latency,
                message="Engine not responding"
            )
        except Exception as e:
            return HealthCheckResult(
                name="inference_engine",
                status=HealthStatus.UNHEALTHY,
                latency_ms=(time.perf_counter() - start) * 1000,
                message=str(e)
            )

    def check_gpu_health(self) -> HealthCheckResult:
        """Grade (simulated) GPU memory usage: >95% unhealthy, >85% degraded."""
        start = time.perf_counter()

        # Simulated GPU metrics
        gpu_memory_used_gb = 45.2
        gpu_memory_total_gb = 128.0
        gpu_utilization = 0.65

        memory_ratio = gpu_memory_used_gb / gpu_memory_total_gb

        if memory_ratio > 0.95:
            status = HealthStatus.UNHEALTHY
            message = "GPU memory critically low"
        elif memory_ratio > 0.85:
            status = HealthStatus.DEGRADED
            message = "GPU memory running low"
        else:
            status = HealthStatus.HEALTHY
            message = "GPU memory OK"

        return HealthCheckResult(
            name="gpu",
            status=status,
            latency_ms=(time.perf_counter() - start) * 1000,
            message=message,
            details={
                "memory_used_gb": gpu_memory_used_gb,
                "memory_total_gb": gpu_memory_total_gb,
                "utilization": gpu_utilization
            }
        )

    def check_error_rate(self) -> HealthCheckResult:
        """Grade the cumulative error rate: >10% unhealthy, >5% degraded."""
        start = time.perf_counter()

        total = self.metrics["requests_total"]
        failed = self.metrics["requests_failed"]

        error_rate = failed / total if total else 0

        if error_rate > 0.1:
            status = HealthStatus.UNHEALTHY
            message = f"High error rate: {error_rate:.1%}"
        elif error_rate > 0.05:
            status = HealthStatus.DEGRADED
            message = f"Elevated error rate: {error_rate:.1%}"
        else:
            status = HealthStatus.HEALTHY
            message = f"Error rate normal: {error_rate:.1%}"

        return HealthCheckResult(
            name="error_rate",
            status=status,
            latency_ms=(time.perf_counter() - start) * 1000,
            message=message,
            details={
                "total_requests": total,
                "failed_requests": failed,
                "error_rate": error_rate
            }
        )

    async def _run_registered_check(self, name: str, check_fn: Callable) -> HealthCheckResult:
        """Run one registered check, converting exceptions into UNHEALTHY results."""
        import inspect  # local import: only needed for registered checks

        start = time.perf_counter()
        try:
            outcome = check_fn()
            if inspect.isawaitable(outcome):
                outcome = await outcome
        except Exception as e:  # a failing probe means "unhealthy", not a crashed endpoint
            return HealthCheckResult(
                name=name,
                status=HealthStatus.UNHEALTHY,
                latency_ms=(time.perf_counter() - start) * 1000,
                message=f"{type(e).__name__}: {e}"
            )

        if isinstance(outcome, HealthCheckResult):
            return outcome
        return HealthCheckResult(
            name=name,
            status=HealthStatus.HEALTHY if outcome else HealthStatus.UNHEALTHY,
            latency_ms=(time.perf_counter() - start) * 1000,
            message="Check passed" if outcome else "Check failed"
        )

    async def run_all_checks(self) -> dict:
        """Run built-in and registered checks and return the combined status.

        The overall status is the worst individual status. ``metrics`` in the
        result is a copy, so callers cannot mutate the monitor's counters.
        """
        results = [
            await self.check_inference_engine(),
            self.check_gpu_health(),
            self.check_error_rate(),
        ]

        # Snapshot the registry in case a check registers another while awaited.
        for name, check_fn in list(self.checks.items()):
            results.append(await self._run_registered_check(name, check_fn))

        # Determine overall status
        statuses = {r.status for r in results}
        if HealthStatus.UNHEALTHY in statuses:
            overall = HealthStatus.UNHEALTHY
        elif HealthStatus.DEGRADED in statuses:
            overall = HealthStatus.DEGRADED
        else:
            overall = HealthStatus.HEALTHY

        return {
            "status": overall.value,
            "timestamp": datetime.now().isoformat(),
            "checks": [
                {
                    "name": r.name,
                    "status": r.status.value,
                    "latency_ms": r.latency_ms,
                    "message": r.message,
                    "details": r.details
                }
                for r in results
            ],
            "metrics": dict(self.metrics)
        }


# Demonstrate health monitoring
print("üè• Health Monitor Demo")
print("=" * 50)

monitor = HealthMonitor()

# Simulate some requests
for i in range(100):
    success = i % 20 != 0  # 5% failure rate
    monitor.record_request(
        latency_ms=50 + (i % 10) * 5,
        tokens=100 + i,
        success=success
    )

# Run health checks
async def run_health_check():
    health = await monitor.run_all_checks()
    
    print(f"\nüìä Overall Status: {health['status'].upper()}")
    print(f"   Timestamp: {health['timestamp']}")
    
    print("\n   Individual Checks:")
    for check in health['checks']:
        icon = "‚úì" if check['status'] == 'healthy' else "‚ö†" if check['status'] == 'degraded' else "‚úó"
        print(f"   {icon} {check['name']}: {check['status']} - {check['message']}")
    
    print(f"\n   Metrics:")
    print(f"   ‚Ä¢ Total requests: {health['metrics']['requests_total']}")
    print(f"   ‚Ä¢ Failed requests: {health['metrics']['requests_failed']}")
    print(f"   ‚Ä¢ Tokens generated: {health['metrics']['tokens_generated']}")

await run_health_check()

---

## Exercise 4: Implement Request Logging and Tracing

**Task**: Create a production logging system with distributed tracing support.

In [None]:
import uuid
import logging


@dataclass
class RequestTrace:
    """One span of a trace: identifiers, timing, attributes and events."""
    trace_id: str
    span_id: str
    parent_span_id: Optional[str] = None
    operation: str = ""
    start_time: float = field(default_factory=time.time)
    end_time: Optional[float] = None
    attributes: dict = field(default_factory=dict)
    events: List[dict] = field(default_factory=list)
    status: str = "OK"

    @property
    def duration_ms(self) -> float:
        """Elapsed milliseconds; measured up to now if the span is unfinished."""
        end = self.end_time if self.end_time is not None else time.time()
        return (end - self.start_time) * 1000

    def add_event(self, name: str, attributes: Optional[dict] = None):
        """Append a timestamped event; attributes are copied so later caller edits don't leak in."""
        self.events.append({
            "name": name,
            "timestamp": time.time(),
            "attributes": dict(attributes) if attributes else {}
        })

    def finish(self, status: str = "OK"):
        """Stamp the end time and final status."""
        self.end_time = time.time()
        self.status = status


class RequestLogger:
    """
    Request logger with tracing support.

    Features:
    - Structured JSON log records (printed to stdout by ``_log``)
    - Trace/span IDs in W3C Trace Context format (32 / 16 lowercase hex chars)
    - Request/response logging
    - Per-span timing

    NOTE(review): ``self.traces`` is never pruned, so a long-running process
    accumulates every span it has started.
    """

    def __init__(self, service_name: str = "llm-api"):
        self.service_name = service_name
        self.traces: Dict[str, RequestTrace] = {}

        # Configured here but not used by _log, which prints instead.
        self.logger = logging.getLogger(service_name)
        self.logger.setLevel(logging.INFO)

    def start_trace(self, operation: str,
                    trace_id: Optional[str] = None,
                    parent_span_id: Optional[str] = None) -> RequestTrace:
        """Start a new span, joining ``trace_id`` if given or starting a new trace."""
        import secrets  # local import keeps this block self-contained

        trace = RequestTrace(
            # 16 random bytes / 8 random bytes as lowercase hex, the W3C sizes.
            trace_id=trace_id or secrets.token_hex(16),
            span_id=secrets.token_hex(8),
            parent_span_id=parent_span_id,
            operation=operation
        )
        self.traces[trace.span_id] = trace
        return trace

    @staticmethod
    def _content_chars(content) -> int:
        """Count characters in a message ``content`` value.

        Handles plain strings, ``None`` and lists of parts whose text parts
        carry a ``"text"`` string (the multi-part chat message shape);
        anything else counts as 0.
        """
        if isinstance(content, str):
            return len(content)
        if isinstance(content, list):
            return sum(len(part["text"]) for part in content
                       if isinstance(part, dict) and isinstance(part.get("text"), str))
        return 0

    def log_request(self, trace: RequestTrace, request: dict):
        """Attach request attributes to ``trace`` and log it as received."""
        trace.attributes.update({
            "request.model": request.get("model", "unknown"),
            "request.max_tokens": request.get("max_tokens", 0),
            "request.stream": request.get("stream", False),
        })

        # Calculate prompt size; content may be None (e.g. tool-call turns).
        messages = request.get("messages") or []
        prompt_chars = sum(self._content_chars(m.get("content")) for m in messages)
        trace.attributes["request.prompt_chars"] = prompt_chars

        trace.add_event("request_received", {"prompt_chars": prompt_chars})

        self._log("info", "Request received", trace)

    def log_response(self, trace: RequestTrace, response: dict):
        """Attach usage attributes, finish ``trace`` as OK and log it."""
        usage = response.get("usage", {})
        trace.attributes.update({
            "response.prompt_tokens": usage.get("prompt_tokens", 0),
            "response.completion_tokens": usage.get("completion_tokens", 0),
            "response.total_tokens": usage.get("total_tokens", 0),
        })

        trace.add_event("response_sent", usage)
        trace.finish("OK")

        self._log("info", "Request completed", trace)

    def log_error(self, trace: RequestTrace, error: str, error_type: str = "unknown"):
        """Record the error on ``trace``, finish it as ERROR and log it."""
        trace.attributes["error.message"] = error
        trace.attributes["error.type"] = error_type
        trace.add_event("error", {"message": error, "type": error_type})
        trace.finish("ERROR")

        self._log("error", f"Request failed: {error}", trace)

    def _log(self, level: str, message: str, trace: RequestTrace):
        """Print one structured JSON record prefixed by the upper-cased level."""
        log_data = {
            "service": self.service_name,
            "trace_id": trace.trace_id,
            "span_id": trace.span_id,
            "operation": trace.operation,
            "duration_ms": trace.duration_ms,
            "status": trace.status,
            "attributes": trace.attributes,
            "message": message
        }

        # In production, this would go to a log aggregator
        print(f"[{level.upper()}] {json.dumps(log_data, indent=2)}")

    def get_trace_summary(self, trace_id: str) -> dict:
        """Summarize all spans of ``trace_id``.

        ``total_duration_ms`` is wall-clock time from the earliest span start
        to the latest span end (now, for unfinished spans), so nested spans
        are not double-counted. It is 0 when no spans match.
        """
        spans = [t for t in self.traces.values() if t.trace_id == trace_id]

        if spans:
            now = time.time()
            first_start = min(s.start_time for s in spans)
            last_end = max(s.end_time if s.end_time is not None else now for s in spans)
            total_duration_ms = (last_end - first_start) * 1000
        else:
            total_duration_ms = 0.0

        return {
            "trace_id": trace_id,
            "span_count": len(spans),
            "total_duration_ms": total_duration_ms,
            "spans": [
                {
                    "span_id": s.span_id,
                    "operation": s.operation,
                    "duration_ms": s.duration_ms,
                    "status": s.status,
                    "events": len(s.events)
                }
                for s in spans
            ]
        }


# Demonstrate logging
print("📝 Request Logger Demo")
print("=" * 50)

logger = RequestLogger(service_name="llm-api")

# Simulate a request: one root span covering the whole chat completion.
trace = logger.start_trace("chat.completion")

request = {
    "model": "llama-3.1-8b",
    "messages": [{"role": "user", "content": "What is machine learning?"}],
    "max_tokens": 256,
    "stream": True
}

print("\n1. Logging request:")
logger.log_request(trace, request)

# Simulate processing; the sleeps stand in for queueing and inference time.
time.sleep(0.1)
trace.add_event("inference_start")
time.sleep(0.2)
trace.add_event("inference_complete")

response = {
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 156,
        "total_tokens": 168
    }
}

print("\n2. Logging response:")
logger.log_response(trace, response)

print("\n3. Trace summary:")
summary = logger.get_trace_summary(trace.trace_id)
print(json.dumps(summary, indent=2))

---

## Key Takeaways

1. **Rate Limiting**:
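   - Use a sliding-window limiter (a per-client log of request timestamps, which is what `TokenBucketRateLimiter` implements despite its name) for flexible rate limiting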
   - Track multiple windows (minute, hour)
   - Include token-based limits for LLMs
   - Support burst traffic with multipliers

2. **SSE Streaming**:
   - OpenAI-compatible format for drop-in replacement
   - Include heartbeats for connection keep-alive
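   - Handle errors gracefully inside the stream (e.g. emit an error payload via `format_error` rather than silently dropping the connection)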

3. **Health Monitoring**:
   - Check inference engine, GPU, and error rates
   - Use three-level status (healthy/degraded/unhealthy)
   - Expose metrics for alerting

4. **Logging & Tracing**:
   - Structured JSON logs for aggregation
   - Distributed tracing with trace/span IDs
   - Track request lifecycle with events