# Production REST APIs for ML Services

## Overview
Building production-grade REST APIs for ML inference with:
- **FastAPI**: Modern, async Python framework
- **Authentication & Authorization**: OAuth2, JWT, API keys
- **Rate Limiting**: Protect against abuse
- **Versioning**: Ship breaking changes without disrupting existing clients
- **Monitoring**: Prometheus metrics, structured logging
- **Performance**: Caching, batch processing, load balancing

## Why Production APIs Matter
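- Inference, not training, often dominates production ML spend (AWS has cited inference at up to 90% of ML compute cost)
- Availability is revenue: even 99.9% uptime allows 8.76 hours of downtime per year
- Security breaches are expensive (IBM's 2022 Cost of a Data Breach report put the global average at $4.35M)
- Latency hurts UX (Akamai's 2017 retail performance study found a 100 ms delay can cut conversion rates by 7%)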
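The uptime figure is plain arithmetic: the yearly downtime budget is (1 - availability) × 8,760 hours. A small helper makes it easy to compare SLA targets; it is a sketch that ignores leap years and planned maintenance windows.

```python
HOURS_PER_YEAR = 365 * 24  # 8,760 hours; ignores leap years

def downtime_budget_hours(availability: float) -> float:
    """Maximum downtime per year (in hours) allowed by an availability target such as 0.999."""
    if not 0.0 <= availability <= 1.0:
        raise ValueError("availability must be between 0 and 1")
    return (1.0 - availability) * HOURS_PER_YEAR

# Each extra "nine" cuts the budget by 10x
for target in (0.99, 0.999, 0.9999):
    hours = downtime_budget_hours(target)
    print(f"{target:.2%} uptime -> {hours:.2f} h/year ({hours * 60:.0f} min)")
```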

## Interview Focus
- RESTful design principles
- Async programming patterns
- Security best practices
- API versioning strategies
- Performance optimization

In [None]:
# Installation
# pip install fastapi uvicorn "pydantic>=2" python-jose python-multipart prometheus-client redis slowapi requests torch numpy

from fastapi import FastAPI, HTTPException, Depends, status, Request, Header
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from jose import JWTError, jwt
import torch
import numpy as np
import asyncio
import time
import logging
from functools import wraps
import hashlib

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
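The `basicConfig` format above produces plain-text lines, while log pipelines such as the ELK stack or CloudWatch are easier to query when each line is a JSON object. A minimal JSON formatter needs only the standard `logging` module. This is a sketch: the extra field names (`request_id`, `user`, `model_version`, `latency_ms`) are assumptions, attached per call through `extra=`.

```python
import json
import logging
from datetime import datetime, timezone

class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    # Hypothetical convention: optional fields passed via logger.info(..., extra={...})
    EXTRA_FIELDS = ("request_id", "user", "model_version", "latency_ms")

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        for field in self.EXTRA_FIELDS:
            if hasattr(record, field):
                payload[field] = getattr(record, field)
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload, default=str)  # default=str keeps odd types serializable

handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
json_logger = logging.getLogger("ml_api.json")
json_logger.addHandler(handler)
json_logger.setLevel(logging.INFO)
json_logger.propagate = False  # avoid a duplicate plain-text line from the root handler

json_logger.info("inference complete", extra={"request_id": "abc123", "latency_ms": 104.2})
```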

## Part 1: Core API Structure

### Production-Ready FastAPI Setup

In [None]:
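import os

class Config:
    """Centralized configuration; secrets and tunables come from environment variables."""
    API_TITLE = "ML Inference API"
    API_VERSION = "v1"
    # Never hard-code secrets: set SECRET_KEY in the environment (e.g. `openssl rand -hex 32`)
    SECRET_KEY = os.getenv("SECRET_KEY", "dev-only-insecure-key")
    ALGORITHM = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
    RATE_LIMIT = "100/minute"  # Not used below; per-tier limits are defined in Part 4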
    
    # Model settings
    MODEL_PATH = "./models/"
    BATCH_SIZE = 32
    MAX_SEQUENCE_LENGTH = 512

# Initialize FastAPI with metadata
app = FastAPI(
    title=Config.API_TITLE,
    description="Production ML inference service with authentication and monitoring",
    version=Config.API_VERSION,
    docs_url=f"/{Config.API_VERSION}/docs",
    redoc_url=f"/{Config.API_VERSION}/redoc"
)

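# CORS middleware for cross-origin (browser) requests.
# Auth uses Bearer tokens in the Authorization header, not cookies, so credentials stay off:
# with allow_origins=["*"] and allow_credentials=True, Starlette reflects the caller's Origin,
# effectively letting any website make credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, list the exact frontend origins instead
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)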

print(f"✅ FastAPI app initialized: {Config.API_TITLE} {Config.API_VERSION}")
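The `Config` class keeps settings as class attributes. A slightly stricter pattern, sketched here with the standard library, loads settings once and fails fast when a production deployment forgets its secret, instead of silently signing tokens with a default key. `APP_ENV`, `Settings` and `load_settings` are hypothetical names; the `pydantic-settings` package offers a fuller version of this idea with type validation.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

@dataclass(frozen=True)
class Settings:
    secret_key: str
    access_token_expire_minutes: int
    environment: str

def load_settings(env: Optional[Mapping[str, str]] = None) -> Settings:
    """Read settings from environment variables, failing fast on unsafe production config."""
    env = os.environ if env is None else env
    environment = env.get("APP_ENV", "development")  # hypothetical variable name
    secret_key = env.get("SECRET_KEY", "")
    if not secret_key:
        if environment == "production":
            raise RuntimeError("SECRET_KEY must be set when APP_ENV=production")
        secret_key = "dev-only-insecure-key"  # acceptable only for local development
    return Settings(
        secret_key=secret_key,
        access_token_expire_minutes=int(env.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30")),
        environment=environment,
    )
```

Passing a plain dict instead of `os.environ` keeps the loader easy to unit-test.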

## Part 2: Request/Response Models

### Type-Safe Schemas with Validation

In [None]:
from datetime import timezone
from pydantic import ConfigDict, field_validator

def utc_now() -> datetime:
    """Timezone-aware UTC timestamp (datetime.utcnow() is deprecated since Python 3.12)."""
    return datetime.now(timezone.utc)

class InferenceRequest(BaseModel):
    """Input schema with validation."""
    # protected_namespaces=() silences Pydantic v2's warning about fields named "model_*"
    model_config = ConfigDict(
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "text": "What is machine learning?",
                "model_version": "v2.1",
                "temperature": 0.7,
                "max_tokens": 150
            }
        },
    )
    
    text: str = Field(..., min_length=1, max_length=5000, description="Input text for inference")
    model_version: str = Field("latest", description="Model version to use")
    temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: int = Field(100, ge=1, le=2048, description="Max output tokens")
    
    @field_validator("text")
    @classmethod
    def validate_text(cls, v: str) -> str:
        # Reject whitespace-only input and strip surrounding whitespace
        if not v.strip():
            raise ValueError("Text cannot be empty")
        return v.strip()

class InferenceResponse(BaseModel):
    """Output schema with metadata."""
    model_config = ConfigDict(protected_namespaces=())
    
    request_id: str = Field(..., description="Unique request identifier")
    result: str = Field(..., description="Model prediction")
    confidence: Optional[float] = Field(None, ge=0.0, le=1.0, description="Prediction confidence")
    latency_ms: float = Field(..., description="Inference latency in milliseconds")
    model_version: str = Field(..., description="Model version used")
    timestamp: datetime = Field(default_factory=utc_now)

class BatchInferenceRequest(BaseModel):
    """Batch processing for efficiency."""
    model_config = ConfigDict(protected_namespaces=())
    
    inputs: List[str] = Field(..., min_length=1, max_length=100)
    model_version: str = "latest"

class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    version: str
    uptime_seconds: float
    gpu_available: bool
    models_loaded: List[str]

class ErrorResponse(BaseModel):
    """Standardized error response."""
    error: str
    detail: Optional[str] = None
    request_id: Optional[str] = None
    timestamp: datetime = Field(default_factory=utc_now)

print("✅ Request/Response models defined")

## Part 3: Authentication & Authorization

### JWT-Based Authentication

In [None]:
# OAuth2 scheme
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"/{Config.API_VERSION}/token")
security = HTTPBearer()

class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None
    tier: str = "free"  # free, pro, enterprise

class UserInDB(User):
    hashed_password: str

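import hmac
import os

# PBKDF2 work factor (OWASP's password-storage guidance suggests 600,000 for PBKDF2-HMAC-SHA256)
PBKDF2_ITERATIONS = 600_000

def hash_password(password: str, salt: Optional[bytes] = None) -> str:
    """Salted PBKDF2-HMAC-SHA256 hash stored as "salt_hex$digest_hex".

    Standard-library stand-in; a dedicated library (bcrypt, Argon2) is preferable in production.
    """
    salt = salt if salt is not None else os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, PBKDF2_ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"

# Mock database (use a real DB in production)
fake_users_db = {
    "demo_user": {
        "username": "demo_user",
        "full_name": "Demo User",
        "email": "demo@example.com",
        "hashed_password": hash_password("demo_password"),
        "disabled": False,
        "tier": "pro"
    }
}

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Recompute the hash with the stored salt and compare in constant time."""
    salt_hex, digest_hex = hashed_password.split("$", 1)
    candidate = hashlib.pbkdf2_hmac(
        "sha256", plain_password.encode(), bytes.fromhex(salt_hex), PBKDF2_ITERATIONS
    )
    return hmac.compare_digest(candidate.hex(), digest_hex)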

def get_user(username: str) -> Optional[UserInDB]:
    """Retrieve user from database."""
    if username in fake_users_db:
        user_dict = fake_users_db[username]
        return UserInDB(**user_dict)
    return None

def authenticate_user(username: str, password: str) -> Optional[UserInDB]:
    """Authenticate user credentials."""
    user = get_user(username)
    if not user or not verify_password(password, user.hashed_password):
        return None
    return user

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Generate JWT access token."""
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, Config.SECRET_KEY, algorithm=Config.ALGORITHM)
    return encoded_jwt

async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Validate token and return current user."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    
    try:
        payload = jwt.decode(token, Config.SECRET_KEY, algorithms=[Config.ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    
    user = get_user(username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Check if user is active."""
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

print("✅ Authentication system configured")

## Part 4: Rate Limiting

### Protect Against Abuse

In [None]:
from collections import defaultdict
from threading import Lock

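class RateLimiter:
    """In-memory token bucket rate limiter.

    State is per process: under `uvicorn --workers 4` each worker keeps its own buckets,
    so a client can get up to 4x its limit. Share state (e.g. in Redis) across workers.
    """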
    
    def __init__(self, rate: int, per: int):
        """
        Args:
            rate: Number of requests allowed
            per: Time window in seconds
        """
        self.rate = rate
        self.per = per
        self.allowance = defaultdict(lambda: rate)
        self.last_check = defaultdict(lambda: time.time())
        self.lock = Lock()
    
    def is_allowed(self, key: str) -> bool:
        """Check if request is allowed."""
        with self.lock:
            current = time.time()
            time_passed = current - self.last_check[key]
            self.last_check[key] = current
            
            # Add tokens based on time passed
            self.allowance[key] += time_passed * (self.rate / self.per)
            
            if self.allowance[key] > self.rate:
                self.allowance[key] = self.rate
            
            if self.allowance[key] < 1.0:
                return False
            
            self.allowance[key] -= 1.0
            return True

# Create rate limiters for different tiers
rate_limiters = {
    "free": RateLimiter(rate=10, per=60),      # 10 req/min
    "pro": RateLimiter(rate=100, per=60),      # 100 req/min
    "enterprise": RateLimiter(rate=1000, per=60)  # 1000 req/min
}

import math

async def check_rate_limit(request: Request, user: User = Depends(get_current_active_user)):
    """Dependency for rate limiting."""
    limiter = rate_limiters.get(user.tier, rate_limiters["free"])
    
    if not limiter.is_allowed(user.username):
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded. Tier: {user.tier}",
            # A rejected bucket holds < 1 token, so the next token is at most per/rate seconds away
            headers={"Retry-After": str(math.ceil(limiter.per / limiter.rate))},
        )
    
    return user

print("✅ Rate limiter configured")
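`RateLimiter` reads `time.time()` directly, which makes its refill behaviour hard to observe. The sketch below restates the same token-bucket rule with an injectable clock (an assumption added for illustration) so the numbers can be checked: a free-tier client (10 requests per 60 s) can burst 10 requests, after which one token returns every 60 / 10 = 6 seconds. `TokenBucket` and its `retry_after` helper are hypothetical; `retry_after` computes the value a `Retry-After` header could carry.

```python
from typing import Callable, Dict, Tuple

class TokenBucket:
    """Per-key token bucket: capacity `rate`, refilled continuously at rate/per tokens per second."""

    def __init__(self, rate: int, per: float, clock: Callable[[], float]):
        self.rate = rate      # bucket capacity (maximum burst)
        self.per = per        # seconds needed to refill an empty bucket
        self.clock = clock    # injectable time source, e.g. time.monotonic
        self._state: Dict[str, Tuple[float, float]] = {}  # key -> (tokens, last_update)

    def _refill(self, key: str) -> Tuple[float, float]:
        now = self.clock()
        tokens, last = self._state.get(key, (float(self.rate), now))  # new keys start full
        return min(float(self.rate), tokens + (now - last) * self.rate / self.per), now

    def allow(self, key: str) -> bool:
        tokens, now = self._refill(key)
        allowed = tokens >= 1.0
        self._state[key] = (tokens - 1.0 if allowed else tokens, now)
        return allowed

    def retry_after(self, key: str) -> float:
        """Seconds until the key has one whole token (0.0 if a request would pass now)."""
        tokens, _ = self._refill(key)
        return max(0.0, (1.0 - tokens) * self.per / self.rate)

fake_now = [0.0]
free_tier = TokenBucket(rate=10, per=60, clock=lambda: fake_now[0])
print([free_tier.allow("alice") for _ in range(11)])  # ten True, then False
print(free_tier.retry_after("alice"))                 # 6.0
```

Passing `time.monotonic` as the clock in production avoids surprises when the wall clock is adjusted.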

## Part 5: ML Model Service

### Model Loading and Inference

In [None]:
class ModelService:
    """Singleton model service with lazy loading."""
    
    _instance = None
    _lock = asyncio.Lock()
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelService, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance
    
    async def initialize(self):
        """Load models asynchronously."""
        if self._initialized:
            return
        
        async with self._lock:
            if self._initialized:
                return
            
            logger.info("Loading ML models...")
            
            # Simulate model loading
            await asyncio.sleep(1)
            
            self.models = {
                "latest": "v2.1",
                "v2.1": self._create_dummy_model(),
                "v2.0": self._create_dummy_model()
            }
            
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.start_time = time.time()
            self._initialized = True
            
            logger.info(f"Models loaded on {self.device}")
    
    def _create_dummy_model(self):
        """Create a simple model for demonstration."""
        return torch.nn.Linear(10, 2)
    
    async def predict(self, text: str, model_version: str = "latest", **kwargs) -> Dict[str, Any]:
        """Run inference."""
        start_time = time.time()
        
        # Resolve version
        if model_version == "latest":
            model_version = self.models["latest"]
        
        if model_version not in self.models:
            raise ValueError(f"Model version {model_version} not found")
        
        # Simulate inference
        await asyncio.sleep(0.1)  # Simulate model processing
        
        # Mock prediction
        result = f"Processed: {text[:50]}..."
        confidence = np.random.uniform(0.8, 0.99)
        
        latency_ms = (time.time() - start_time) * 1000
        
        return {
            "result": result,
            "confidence": confidence,
            "latency_ms": latency_ms,
            "model_version": model_version
        }
    
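    async def batch_predict(self, texts: List[str], model_version: str = "latest") -> List[Dict[str, Any]]:
        """Batch inference (simulated)."""
        # Runs the per-item coroutines concurrently, so their simulated latencies overlap.
        # A real model would stack the inputs into one tensor and run a single forward pass.
        tasks = [self.predict(text, model_version) for text in texts]
        return await asyncio.gather(*tasks)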
    
    def get_health(self) -> Dict[str, Any]:
        """Return service health."""
        return {
            "status": "healthy",
            "uptime_seconds": time.time() - self.start_time,
            "gpu_available": torch.cuda.is_available(),
            "models_loaded": [k for k in self.models.keys() if k != "latest"]
        }

# Global model service
model_service = ModelService()

print("✅ Model service initialized")
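`ModelService.predict` only simulates work with `asyncio.sleep`. A real PyTorch forward pass is blocking: called directly inside an `async def` endpoint, it stalls the event loop and every other in-flight request with it. One common remedy is to offload the call with `asyncio.to_thread` (Python 3.9+), sketched below with a `time.sleep` stand-in for the model; `blocking_inference` and `predict_offloaded` are hypothetical names. Threads help here because PyTorch releases the GIL inside most native ops; pure-Python CPU-bound work would need a process pool instead.

```python
import asyncio
import time
from typing import List

def blocking_inference(text: str) -> str:
    """Stand-in for a blocking forward pass (hypothetical; a real model call goes here)."""
    time.sleep(0.2)  # simulated model latency; does NOT yield to the event loop
    return text.upper()

async def predict_offloaded(text: str) -> str:
    # Run the blocking call in the default thread pool so the event loop keeps serving requests
    return await asyncio.to_thread(blocking_inference, text)

async def main() -> List[str]:
    start = time.perf_counter()
    results = await asyncio.gather(*(predict_offloaded(t) for t in ["a", "b", "c"]))
    # The three 0.2 s sleeps overlap in worker threads, so this takes roughly 0.2 s, not 0.6 s
    print(f"{results} in {time.perf_counter() - start:.2f}s")
    return list(results)
```

In a script, run it with `asyncio.run(main())`; in a Jupyter cell, where an event loop is already running, use `await main()` instead.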

## Part 6: API Endpoints

### RESTful Routes with Versioning

In [None]:
import uuid

# Startup event
@app.on_event("startup")
async def startup_event():
    """Initialize services on startup."""
    await model_service.initialize()
    logger.info("API server started")

# Health check (no auth required)
@app.get("/health", response_model=HealthResponse, tags=["Monitoring"])
async def health_check():
    """Health check endpoint for load balancers."""
    health = model_service.get_health()
    return HealthResponse(
        version=Config.API_VERSION,
        **health
    )

# Authentication endpoint
@app.post(f"/{Config.API_VERSION}/token", tags=["Authentication"])
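def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Obtain access token.

    Declared with plain `def`: FastAPI runs sync endpoints in a threadpool, so deliberately
    slow password hashing does not block the event loop.
    """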
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    
    access_token_expires = timedelta(minutes=Config.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username, "tier": user.tier},
        expires_delta=access_token_expires
    )
    
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": Config.ACCESS_TOKEN_EXPIRE_MINUTES * 60
    }

# User info endpoint
@app.get(f"/{Config.API_VERSION}/users/me", response_model=User, tags=["Users"])
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    """Get current user information."""
    return current_user

# Single inference endpoint
@app.post(
    f"/{Config.API_VERSION}/predict",
    response_model=InferenceResponse,
    tags=["Inference"],
    status_code=status.HTTP_200_OK
)
async def predict(
    request: InferenceRequest,
    user: User = Depends(check_rate_limit)
):
    """Run ML inference on input text."""
    request_id = str(uuid.uuid4())
    
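    try:
        logger.info(f"Request {request_id} from user {user.username}")
        
        prediction = await model_service.predict(
            request.text,
            model_version=request.model_version,
            temperature=request.temperature,
            max_tokens=request.max_tokens
        )
    except ValueError as e:
        # Unknown model version is a client error, not a server fault
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except Exception:
        # Log the traceback server-side; don't leak internals to the client
        logger.exception(f"Error in request {request_id}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal error (request_id={request_id})"
        )
    
    # Outside the try, so a response-schema ValidationError (a ValueError subclass) isn't reported as 404
    return InferenceResponse(request_id=request_id, **prediction)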

# Batch inference endpoint
@app.post(
    f"/{Config.API_VERSION}/predict/batch",
    tags=["Inference"],
    status_code=status.HTTP_200_OK
)
async def batch_predict(
    request: BatchInferenceRequest,
    user: User = Depends(check_rate_limit)
):
    """Batch inference for multiple inputs."""
    request_id = str(uuid.uuid4())
    
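    try:
        logger.info(f"Batch request {request_id} with {len(request.inputs)} items")
        
        predictions = await model_service.batch_predict(
            request.inputs,
            model_version=request.model_version
        )
    except ValueError as e:
        # Unknown model version is a client error, not a server fault
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except Exception:
        # Log the traceback server-side; don't leak internals to the client
        logger.exception(f"Error in batch request {request_id}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal error (request_id={request_id})"
        )
    
    return {
        "request_id": request_id,
        "predictions": predictions,
        "count": len(predictions)
    }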

# Model info endpoint
@app.get(f"/{Config.API_VERSION}/models", tags=["Models"])
async def list_models(user: User = Depends(get_current_active_user)):
    """List available models."""
    return {
        "models": model_service.get_health()["models_loaded"],
        "default": model_service.models["latest"]
    }

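# Error handlers
from starlette.exceptions import HTTPException as StarletteHTTPException

# Registered on Starlette's base class so framework-raised errors (e.g. 404 for unknown routes)
# share the same format; fastapi.HTTPException subclasses it, so endpoint errors are covered too.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Custom error response format."""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": datetime.utcnow().isoformat()
        },
        # Preserve headers such as WWW-Authenticate and Retry-After
        headers=getattr(exc, "headers", None),
    )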

print(f"✅ API endpoints configured")
print(f"\nAPI Documentation: http://localhost:8000/{Config.API_VERSION}/docs")
print(f"\nTest credentials:")
print(f"  Username: demo_user")
print(f"  Password: demo_password")
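The `/predict/batch` endpoint only helps clients that group inputs themselves, and `batch_predict` still runs items one by one. Server-side dynamic batching (offered, for example, by NVIDIA Triton's dynamic batcher) instead coalesces concurrent single requests into one model call. Below is a minimal asyncio sketch under stated assumptions: `MicroBatcher` and its parameters are hypothetical, `batch_fn` is any blocking function mapping a list of inputs to a list of outputs in the same order, and the worker always waits `max_wait_ms` after the first request, trading a few milliseconds of latency for larger batches.

```python
import asyncio
from typing import Callable, List, Optional, Tuple

class MicroBatcher:
    """Coalesce concurrent single-item requests into batched calls to `batch_fn`."""

    def __init__(self, batch_fn: Callable[[List[str]], List[str]],
                 max_batch_size: int = 32, max_wait_ms: float = 5.0):
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.max_wait = max_wait_ms / 1000
        self._queue: Optional[asyncio.Queue] = None
        self._worker: Optional[asyncio.Task] = None

    def start(self) -> None:
        # Create loop-bound objects inside the running event loop (e.g. at app startup)
        self._queue = asyncio.Queue()
        self._worker = asyncio.create_task(self._run())

    async def stop(self) -> None:
        if self._worker is not None:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass

    async def submit(self, item: str) -> str:
        """Enqueue one input and wait for its result."""
        if self._queue is None:
            raise RuntimeError("call start() first")
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((item, future))
        return await future

    async def _run(self) -> None:
        while True:
            batch: List[Tuple[str, asyncio.Future]] = [await self._queue.get()]
            await asyncio.sleep(self.max_wait)  # let concurrent requests accumulate
            while len(batch) < self.max_batch_size and not self._queue.empty():
                batch.append(self._queue.get_nowait())
            inputs = [item for item, _ in batch]
            try:
                # One batched call, offloaded so a blocking model doesn't stall the loop
                outputs = await asyncio.to_thread(self.batch_fn, inputs)
            except Exception as exc:
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)
                continue
            for (_, future), output in zip(batch, outputs):
                if not future.done():  # the caller may have been cancelled meanwhile
                    future.set_result(output)
```

In the API, one would call `start()` from the startup hook and have the `/predict` handler `await batcher.submit(text)`; `max_batch_size` would typically match `Config.BATCH_SIZE`.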

## Part 7: Monitoring & Metrics

### Prometheus Integration

In [None]:
from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
from fastapi.responses import Response

# Define metrics
REQUEST_COUNT = Counter(
    'api_requests_total',
    'Total API requests',
    ['method', 'endpoint', 'status']
)

REQUEST_LATENCY = Histogram(
    'api_request_latency_seconds',
    'API request latency',
    ['method', 'endpoint']
)

INFERENCE_LATENCY = Histogram(
    'inference_latency_seconds',
    'Model inference latency',
    ['model_version']
)

ACTIVE_REQUESTS = Gauge(
    'api_active_requests',
    'Number of active requests'
)

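from starlette.routing import Match

def route_template(request: Request) -> str:
    """Return the matched route's path template (e.g. "/v1/predict") to keep label cardinality bounded."""
    for route in request.app.routes:
        match, _ = route.matches(request.scope)
        if match == Match.FULL:
            return route.path
    return "unmatched"  # e.g. 404s from scanners, which would otherwise create one series per URL

# Middleware for automatic metrics
@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    """Track request count, latency, and in-flight requests."""
    ACTIVE_REQUESTS.inc()
    start_time = time.perf_counter()
    status_code = 500  # reported if the handler raises
    
    try:
        response = await call_next(request)
        status_code = response.status_code
        return response
    finally:
        latency = time.perf_counter() - start_time
        endpoint = route_template(request)
        
        REQUEST_COUNT.labels(
            method=request.method,
            endpoint=endpoint,
            status=status_code
        ).inc()
        
        REQUEST_LATENCY.labels(
            method=request.method,
            endpoint=endpoint
        ).observe(latency)
        
        ACTIVE_REQUESTS.dec()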

# Metrics endpoint
@app.get("/metrics", tags=["Monitoring"])
async def metrics():
    """Prometheus metrics endpoint."""
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)

print("✅ Prometheus metrics configured")
print("Metrics available at: http://localhost:8000/metrics")
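`REQUEST_LATENCY` is a Prometheus histogram, so P95/P99 are not stored directly; they are estimated at query time from cumulative bucket counts, typically with PromQL such as `histogram_quantile(0.95, sum(rate(api_request_latency_seconds_bucket[5m])) by (le))`. The sketch below reproduces the core of that estimate in plain Python to show why bucket boundaries matter: the value is linearly interpolated inside the bucket containing the target rank, so it is only as precise as the buckets are narrow. `histogram_quantile` here is a hypothetical helper named after the PromQL function, and it simplifies some of PromQL's edge-case handling.

```python
import math
from typing import List, Tuple

def histogram_quantile(q: float, buckets: List[Tuple[float, float]]) -> float:
    """Estimate the q-quantile from cumulative buckets [(upper_bound, cumulative_count), ...]."""
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be between 0 and 1")
    buckets = sorted(buckets)  # ascending upper bounds; math.inf sorts last
    if not buckets or not math.isinf(buckets[-1][0]):
        raise ValueError("the last bucket must have upper bound +Inf")
    total = buckets[-1][1]
    if total == 0:
        return math.nan  # no observations yet
    rank = q * total
    lower_bound, lower_count = 0.0, 0.0  # assumes no negative observations (true for latencies)
    for upper_bound, count in buckets:
        if count >= rank:
            if math.isinf(upper_bound):
                return lower_bound  # can't interpolate into +Inf: return the highest finite bound
            if count == lower_count:
                return lower_bound  # empty first bucket, only reachable when rank == 0
            # Assume observations are spread evenly inside the bucket
            return lower_bound + (upper_bound - lower_bound) * (rank - lower_count) / (count - lower_count)
        lower_bound, lower_count = upper_bound, count
    return lower_bound
```

For example, with 100 requests where 50 finished under 0.1 s, 90 under 0.25 s and all under 0.5 s, the estimated P95 is 0.25 + 0.25 × (95 - 90) / (100 - 90) = 0.375 s.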

## Part 8: Running the API

### Start Server and Test

In [None]:
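# To run the server (with the cells above saved as main.py):
# Development (auto-reload, single process):
#   uvicorn main:app --reload
# Production (multiple processes; --reload and --workers are mutually exclusive):
#   uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4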

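# Example client code
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

class APIClient:
    """API client with timeouts and automatic retries on transient errors.

    Tokens expire after ACCESS_TOKEN_EXPIRE_MINUTES; create a new client (re-login) on 401.
    """
    
    def __init__(self, base_url: str, username: str, password: str, timeout: float = 10.0):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout  # requests has no default timeout, so always pass one
        self.session = requests.Session()
        # Retry 429/502/503/504 with exponential backoff, honoring Retry-After headers.
        # POST is retried too; acceptable here because inference requests have no side effects.
        retries = Retry(
            total=3,
            backoff_factor=0.5,
            status_forcelist=[429, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retries)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)
        token = self._login(username, password)
        self.session.headers["Authorization"] = f"Bearer {token}"
    
    def _login(self, username: str, password: str) -> str:
        """Authenticate and get token."""
        response = self.session.post(
            f"{self.base_url}/v1/token",
            data={"username": username, "password": password},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["access_token"]
    
    def predict(self, text: str, **kwargs) -> dict:
        """Single prediction."""
        response = self.session.post(
            f"{self.base_url}/v1/predict",
            json={"text": text, **kwargs},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()
    
    def batch_predict(self, texts: List[str], **kwargs) -> dict:
        """Batch prediction."""
        response = self.session.post(
            f"{self.base_url}/v1/predict/batch",
            json={"inputs": texts, **kwargs},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()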

# Example usage
print("\n=== Example Client Usage ===")
print("""
# Initialize client
client = APIClient(
    base_url="http://localhost:8000",
    username="demo_user",
    password="demo_password"
)

# Single prediction
result = client.predict("What is machine learning?")
print(f"Result: {result['result']}")
print(f"Latency: {result['latency_ms']:.2f}ms")

# Batch prediction
batch_result = client.batch_predict([
    "Question 1",
    "Question 2",
    "Question 3"
])
print(f"Processed {batch_result['count']} items")
""")

## Key Takeaways

### Production API Checklist:
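1. ✅ **Authentication**: JWT tokens; passwords stored with a salted, deliberately slow hash (bcrypt or Argon2 in production)
2. ✅ **Authorization**: Per-user tiers (free/pro/enterprise) that set rate limits; extend to role-based access control as needed
3. ✅ **Rate Limiting**: Tiered limits to prevent abuse
4. ✅ **Validation**: Pydantic models with type checking
5. ✅ **Versioning**: URL-based versioning (v1, v2)
6. ✅ **Monitoring**: Prometheus metrics and request logging (emit JSON lines for structured logs)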
7. ✅ **Error Handling**: Consistent error responses
8. ✅ **Performance**: Async endpoints, batch processing
9. ✅ **Documentation**: Auto-generated OpenAPI docs
10. ✅ **Health Checks**: For load balancers and orchestrators

### Deployment Considerations:
- **Load Balancing**: NGINX, HAProxy, AWS ALB
- **Auto-Scaling**: Based on CPU/memory/request count
- **SSL/TLS**: Let's Encrypt, AWS ACM
- **Caching**: Redis for embeddings/predictions
- **Database**: PostgreSQL for users, requests
- **Monitoring**: Prometheus + Grafana, DataDog
- **Logging**: ELK stack, CloudWatch
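For the Redis caching bullet, the main design decision is the cache key. A sketch, assuming exact-match caching: hash the input text together with the concrete model version and generation parameters, so a new model or a different temperature never returns a stale answer. `TTLCache` is an in-process stand-in for Redis `SET key value EX <ttl>` / `GET key`, and all names are hypothetical. Caching only makes sense for deterministic outputs (for example temperature 0); with sampling, a cache would replay the same response every time. Resolve aliases such as `latest` to a concrete version before building the key.

```python
import hashlib
import json
import time
from typing import Any, Callable, Dict, Optional, Tuple

def cache_key(text: str, model_version: str, **params: Any) -> str:
    """Same text + concrete model version + generation params -> same key."""
    payload = json.dumps(
        {"text": text, "model_version": model_version, "params": params},
        sort_keys=True,  # stable ordering, so parameter order doesn't change the key
    )
    return "pred:" + hashlib.sha256(payload.encode("utf-8")).hexdigest()

class TTLCache:
    """In-process stand-in for Redis SET key value EX ttl / GET key."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl_seconds
        self.clock = clock
        self._store: Dict[str, Tuple[float, Any]] = {}  # key -> (expires_at, value)
        self.hits = 0
        self.misses = 0

    def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None or entry[0] <= self.clock():
            self._store.pop(key, None)  # drop expired entries lazily
            self.misses += 1
            return None
        self.hits += 1
        return entry[1]

    def set(self, key: str, value: Any) -> None:
        self._store[key] = (self.clock() + self.ttl, value)

def cached_predict(cache: TTLCache, predict_fn: Callable[[str], Any],
                   text: str, model_version: str, **params: Any) -> Any:
    key = cache_key(text, model_version, **params)
    result = cache.get(key)
    if result is None:  # miss or expired: run the model and store the result
        result = predict_fn(text)
        cache.set(key, result)
    return result
```

The `hits`/`misses` counters are what a "cache hit rate" metric would be computed from.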

## Interview Questions

1. **How do you handle API versioning without breaking clients?**
   - URL-based: /v1/predict, /v2/predict
   - Header-based: Accept: application/vnd.api.v2+json
   - Deprecation policy: announce a sunset window (e.g. 6 months) and signal it with the `Sunset` response header (RFC 8594)

2. **Explain JWT vs API Keys for authentication.**
   - JWT: Stateless, carries signed (not encrypted) claims, expires automatically; cannot be revoked before expiry without a denylist
   - API Keys: Opaque, require a DB lookup per request, can be revoked immediately
   - Use JWTs for user sessions, API keys for service-to-service calls

3. **How do you optimize API latency?**
   - Caching: Redis for frequent requests
   - Batching: Group requests for GPU efficiency
   - Async: Non-blocking I/O with asyncio
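   - CDN: Cache public, cacheable GET responses at edge locations (per-user, authenticated inference results generally can't be shared this way)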
   - Connection pooling: Reuse DB/HTTP connections

4. **What metrics do you track for ML APIs?**
   - Throughput: Requests per second (QPS)
   - Latency: P50, P95, P99 response times
   - Error rate: 4xx and 5xx percentages
   - Model metrics: Inference time, cache hit rate
   - Business: API usage per tier, costs
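Header-based versioning from question 1 comes down to parsing the `Accept` header. A minimal sketch, assuming the `application/vnd.api.v2+json` media type from the example above and a hypothetical set of supported versions; it takes the first supported version in header order and ignores `q` weights, which a full implementation should honor. An explicitly requested but unsupported version maps naturally to HTTP 406 Not Acceptable.

```python
import re
from typing import Optional

# Assumed vendor media type, following the example above: application/vnd.api.v2+json
VERSIONED_MEDIA_TYPE = re.compile(r"application/vnd\.api\.v(\d+)\+json")
SUPPORTED_VERSIONS = (2, 1)  # hypothetical: versions this server can still serve
DEFAULT_VERSION = 1          # used when the client does not ask for a version

def negotiate_version(accept: Optional[str]) -> int:
    """Return the API version requested via the Accept header.

    Simplification: versions are tried in header order and q-weights are ignored.
    Raises LookupError if versions were requested but none is supported (map to HTTP 406).
    """
    requested = [int(m.group(1)) for m in VERSIONED_MEDIA_TYPE.finditer(accept or "")]
    if not requested:
        return DEFAULT_VERSION
    for version in requested:
        if version in SUPPORTED_VERSIONS:
            return version
    raise LookupError(f"none of the requested API versions {requested} is supported")
```

Defaulting unversioned requests to the oldest supported version keeps existing clients working when a new version ships.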
