# Part IV: Structuring Large Applications

## Chapter 10: Middleware and Events

Middleware is a powerful mechanism that sits between the server and your application code, allowing you to process requests before they reach your endpoints and responses before they're sent to clients. Combined with lifespan events for startup and shutdown operations, these tools enable you to implement cross-cutting concerns like authentication, logging, error handling, and performance monitoring consistently across your entire application.

---

### 10.1 Middleware Basics: Intercepting Requests and Responses

Middleware in FastAPI (powered by Starlette) is a function or class that processes every request before it reaches your path operation and every response before it leaves your application. This enables global behavior that applies to all endpoints.

#### Understanding Middleware Flow

```
┌─────────────────────────────────────────────────────────────────┐
│                     Request Processing Flow                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   ┌─────────┐    ┌─────────────┐    ┌─────────────┐             │
│   │ Client  │───▶│ Middleware  │───▶│ Middleware  │───┐         │
│   │ Request │    │    #1       │    │    #2       │   │         │
│   └─────────┘    └─────────────┘    └─────────────┘   │         │
│                                                        │         │
│                                          ┌─────────────▼───┐     │
│                                          │  Path Operation │     │
│                                          │   (Endpoint)    │     │
│                                          └─────────────┬───┘     │
│                                                        │         │
│   ┌──────────┐   ┌─────────────┐    ┌─────────────┐   │         │
│   │ Client   │◀──│ Middleware  │◀───│ Middleware  │◀──┘         │
│   │ Response │   │    #2       │    │    #1       │             │
│   └──────────┘   └─────────────┘    └─────────────┘             │
│                                                                  │
│   Note: Response travels in reverse order through middleware    │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘
```
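The onion shape in the diagram can be modeled in plain Python. This is a minimal sketch, not Starlette's implementation. Plain dicts stand in for `Request` and `Response`, and `make_middleware`, `wrap`, and `build_chain` are hypothetical helpers. It shows that code before `call_next` runs outer-to-inner, and code after it runs inner-to-outer.

```python
import asyncio
from typing import Awaitable, Callable

# Plain dicts stand in for Starlette's Request and Response objects
Handler = Callable[[dict], Awaitable[dict]]
Middleware = Callable[[dict, Handler], Awaitable[dict]]

trace: list[str] = []


def make_middleware(name: str) -> Middleware:
    async def middleware(request: dict, call_next: Handler) -> dict:
        trace.append(f"{name} pre")          # before the endpoint runs
        response = await call_next(request)  # hand off to the next layer
        trace.append(f"{name} post")         # after the endpoint returns
        return response
    return middleware


async def endpoint(request: dict) -> dict:
    trace.append("endpoint")
    return {"status": 200, "path": request["path"]}


def wrap(middleware: Middleware, next_handler: Handler) -> Handler:
    # Bind one middleware to whatever sits inside it
    async def handler(request: dict) -> dict:
        return await middleware(request, next_handler)
    return handler


def build_chain(middlewares: list[Middleware], app: Handler) -> Handler:
    # The first middleware in the list ends up as the outermost layer
    handler = app
    for middleware in reversed(middlewares):
        handler = wrap(middleware, handler)
    return handler


chain = build_chain([make_middleware("#1"), make_middleware("#2")], endpoint)
response = asyncio.run(chain({"path": "/"}))
print(trace)  # ['#1 pre', '#2 pre', 'endpoint', '#2 post', '#1 post']
```

The printed trace follows the same path as the diagram: down through #1 and #2, into the endpoint, then back out through #2 and #1.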

#### The `@app.middleware` Decorator

The simplest way to create middleware is using the `@app.middleware("http")` decorator. This registers a function that will be called for every HTTP request.

```python
# middleware_basics.py
import time
from fastapi import FastAPI, Request, Response

app = FastAPI()


@app.middleware("http")
async def log_requests_middleware(request: Request, call_next) -> Response:
    """
    Middleware that logs every request and its processing time.
    
    Args:
        request: The incoming HTTP request object
        call_next: A function that passes the request to the next handler
    
    Returns:
        Response: The HTTP response from the endpoint
    """
    # === PRE-PROCESSING ===
    # This code runs BEFORE the endpoint is called
    
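    start_time = time.perf_counter()
    # Demo-only ID derived from a monotonic clock; it is not guaranteed to be
    # unique under concurrent load (the next example uses uuid4 instead)
    request_id = f"req-{int(start_time * 1000)}"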
    
    # Log the incoming request
    print(f"[{request_id}] --> {request.method} {request.url.path}")
    
    # Store request ID in state for use in endpoints
    request.state.request_id = request_id
    
    # === CALL NEXT HANDLER ===
    # This passes control to the next middleware or endpoint
    # The endpoint processes the request and returns a response
    response = await call_next(request)
    
    # === POST-PROCESSING ===
    # This code runs AFTER the endpoint returns a response
    
    # Calculate processing time
    process_time = time.perf_counter() - start_time
    process_time_ms = round(process_time * 1000, 2)
    
    # Log the response
    print(f"[{request_id}] <-- {response.status_code} ({process_time_ms}ms)")
    
    # Add custom headers to response
    response.headers["X-Process-Time"] = f"{process_time_ms}ms"
    response.headers["X-Request-ID"] = request_id
    
    return response


@app.get("/")
async def root():
    """Simple endpoint to demonstrate middleware."""
    return {"message": "Hello, World!"}


@app.get("/slow")
async def slow_endpoint():
    """Endpoint that simulates slow processing."""
    import asyncio
    await asyncio.sleep(1)  # Simulate slow operation
    return {"message": "This was slow"}
```

**Key Points:**

1. **`@app.middleware("http")`**: Registers the function as HTTP middleware. It will be called for every request.

2. **`request` parameter**: Contains all information about the incoming request (method, URL, headers, body, etc.).

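3. **`call_next` function**: This is crucial. It passes the request to the next middleware or the actual endpoint and returns that handler's response. If you never call it, the request never reaches your endpoint. Middleware that deliberately rejects a request early (for example, after a failed auth check) must return its own `Response` instead.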

4. **Pre-processing**: Code before `call_next` runs before the endpoint. Use this for request logging, authentication checks, rate limiting, etc.

5. **Post-processing**: Code after `call_next` runs after the endpoint returns. Use this for response modification, logging, cleanup, etc.

#### Middleware for Request Timing and Headers

A common use case is measuring request duration and adding performance headers:

```python
# request_timing_middleware.py
import time
from uuid import uuid4
from fastapi import FastAPI, Request, Response

app = FastAPI()


@app.middleware("http")
async def add_timing_headers(request: Request, call_next) -> Response:
    """
    Middleware that adds timing and correlation headers to all responses.
    
    Headers added:
    - X-Request-ID: Unique identifier for request tracking
    - X-Process-Time: Time taken to process the request in seconds
    - X-Response-Time: Total response time including network
    """
    # Generate unique request ID
    request_id = str(uuid4())[:8]
    
    # Record start time with high precision
    start_time = time.perf_counter()
    
    # Attach request ID to request state (accessible in endpoints)
    request.state.request_id = request_id
    
    # Process the request
    response = await call_next(request)
    
    # Calculate timing
    end_time = time.perf_counter()
    process_time_seconds = end_time - start_time
    process_time_ms = round(process_time_seconds * 1000, 3)
    
    # Add headers to response
    response.headers["X-Request-ID"] = request_id
    response.headers["X-Process-Time"] = f"{process_time_seconds:.6f}"
    response.headers["X-Process-Time-Ms"] = f"{process_time_ms}"
    
    return response


@app.get("/api/data")
async def get_data(request: Request):
    """
    Example endpoint showing access to request state.
    The request_id set by middleware is available here.
    """
    # Access the request ID from middleware
    request_id = getattr(request.state, "request_id", "unknown")
    
    return {
        "data": "Some data",
        "request_id": request_id,
    }
```

#### Middleware for Error Handling and Logging

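This middleware catches unhandled exceptions globally and ensures consistent error responses. `HTTPException` and request validation errors never reach it, because FastAPI's exception handlers convert them into responses further down the stack: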

```python
# error_handling_middleware.py
import logging
import traceback
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()


@app.middleware("http")
async def error_handling_middleware(request: Request, call_next) -> Response:
    """
    Global error handling middleware.
    
    This middleware:
    1. Catches all unhandled exceptions
    2. Logs the error with full traceback
    3. Returns a consistent JSON error response
    4. Hides internal error details from clients (security)
    """
    try:
        # Attempt to process the request normally
        response = await call_next(request)
        return response
    
    except Exception as exc:
        # Log the full error for debugging
        logger.error(
            f"Unhandled exception processing {request.method} {request.url}:\n"
            f"{traceback.format_exc()}"
        )
        
        # Return a generic error to the client
        # In development, you might include more details
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal Server Error",
                "message": "An unexpected error occurred. Please try again later.",
                # Only include details in development
                # "details": str(exc) if DEBUG else None,
            },
        )


# Example endpoint that raises an error
@app.get("/error")
async def trigger_error():
    """Endpoint that intentionally raises an error for testing."""
    raise ValueError("This is a simulated error!")


# Example endpoint that works normally
@app.get("/normal")
async def normal_endpoint():
    """Normal endpoint that returns successfully."""
    return {"status": "ok"}
```

#### Request Body Access in Middleware

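Reading the request body in middleware requires special handling. The body arrives through the ASGI `receive` channel, which delivers each chunk only once, so the endpoint cannot read it again unless the middleware replays it: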

```python
# body_access_middleware.py
import json
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse

app = FastAPI()


@app.middleware("http")
async def log_request_body_middleware(request: Request, call_next) -> Response:
    """
    Middleware that logs request body for debugging.
    
    IMPORTANT: Request body can only be read once, so we must
    store it and replace it for the endpoint to use.
    """
    # Only process body for certain content types
    if request.method in ["POST", "PUT", "PATCH"]:
        content_type = request.headers.get("content-type", "")
        
        if "application/json" in content_type:
            # Read the body (this consumes the stream)
            body_bytes = await request.body()
            
            try:
                # Parse and log the JSON body
                body_json = json.loads(body_bytes)
                print(f"Request Body: {json.dumps(body_json, indent=2)}")
            except (json.JSONDecodeError, UnicodeDecodeError):
                print(f"Request Body (raw): {body_bytes.decode(errors='replace')}")
            
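            # Replay the body for the endpoint, since we already consumed the
            # stream. Hand back the cached body once, then defer to the
            # original receive so disconnect detection keeps working.
            original_receive = request._receive
            body_sent = False

            async def receive():
                nonlocal body_sent
                if not body_sent:
                    body_sent = True
                    return {"type": "http.request", "body": body_bytes, "more_body": False}
                return await original_receive()

            request._receive = receive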
    
    # Continue processing
    response = await call_next(request)
    return response


@app.post("/items")
async def create_item(request: Request):
    """Endpoint that receives JSON data."""
    body = await request.json()  # This works because we restored the body
    return {"received": body}
```
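The contract behind a replacement `receive()` can be checked without a server. This sketch uses a hypothetical `replay_body_receive` helper that delivers the cached body exactly once and then forwards to the original `receive`. The forwarding matters because code that waits for `http.disconnect` could otherwise keep receiving the body message in a tight loop.

```python
import asyncio
from typing import Awaitable, Callable

Message = dict
Receive = Callable[[], Awaitable[Message]]


def replay_body_receive(body: bytes, original_receive: Receive) -> Receive:
    """Return an ASGI receive callable that replays `body` once.

    After the cached body has been delivered, further calls go to the
    original receive, which in a real server blocks until the client
    disconnects.
    """
    body_sent = False

    async def receive() -> Message:
        nonlocal body_sent
        if not body_sent:
            body_sent = True
            return {"type": "http.request", "body": body, "more_body": False}
        return await original_receive()

    return receive


async def demo() -> list[Message]:
    # Fake server-side receive: the body is already consumed, so the next
    # message a server would deliver is a disconnect
    async def original_receive() -> Message:
        return {"type": "http.disconnect"}

    receive = replay_body_receive(b'{"name": "widget"}', original_receive)
    return [await receive(), await receive()]


messages = asyncio.run(demo())
print(messages[0]["body"])  # b'{"name": "widget"}'
print(messages[1]["type"])  # http.disconnect
```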

---

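> **Industry Standard:** When reading the request body in middleware, make sure the endpoint can still read it afterwards. Older Starlette releases required replaying the consumed body as shown above. Newer releases cache a body read inside `BaseHTTPMiddleware` so downstream handlers can read it again. Check the behavior of the version you deploy.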

---

### 10.2 Built-in Middleware: CORS, GZip, Trusted Host, and HTTPS Redirect

FastAPI re-exports several of Starlette's middleware classes under `fastapi.middleware` for common security and performance needs. Each one is registered with `app.add_middleware()` and configured through keyword arguments.

#### CORS (Cross-Origin Resource Sharing) Middleware

CORS controls which origins (scheme, host, and port) may call your API from JavaScript running in a browser. This is essential for APIs consumed by web front ends served from a different origin:

```python
# cors_middleware.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# CORS Middleware Configuration
app.add_middleware(
    CORSMiddleware,
    # === ALLOWED ORIGINS ===
    # Specifies which domains can make requests
    # Use ["*"] for development only - be specific in production
    allow_origins=[
        "http://localhost:3000",       # Local React/Vue dev server
        "http://localhost:8080",       # Another local frontend
        "https://example.com",          # Production frontend
        "https://www.example.com",      # www subdomain
    ],
    
    # === CREDENTIALS ===
    # Allows cookies and authorization headers
    # Required for authenticated requests with cookies
    # When True, allow_origins cannot be ["*"]
    allow_credentials=True,
    
    # === ALLOWED METHODS ===
    # Which HTTP methods are allowed
    # ["*"] allows all methods
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
    
    # === ALLOWED HEADERS ===
    # Which request headers are allowed
    # ["*"] allows all headers
    allow_headers=[
        "Content-Type",
        "Authorization",
        "X-Requested-With",
        "X-Request-ID",
    ],
    
    # === EXPOSED HEADERS ===
    # Which response headers the browser can access
    # By default, browsers only see standard headers
    expose_headers=[
        "X-Total-Count",      # Pagination total
        "X-Request-ID",       # Request tracking
        "X-Process-Time",     # Performance metrics
    ],
    
    # === MAX AGE ===
    # How long (seconds) browsers cache CORS preflight responses
    # Preflight = OPTIONS request browsers send before actual request
    max_age=3600,  # 1 hour
)


@app.get("/api/data")
async def get_data():
    """
    This endpoint is now accessible from allowed origins.
    Browsers will check CORS headers before allowing requests.
    """
    return {"data": "This is accessible cross-origin"}


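# === CONFIGURATION FROM ENVIRONMENT ===
# In production, load allowed origins from configuration.
# pydantic-settings parses list fields from JSON, e.g.
# ALLOWED_ORIGINS='["https://example.com", "https://www.example.com"]'

from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    allowed_origins: list[str] = ["http://localhost:3000"]
    
    # Pydantic v2 style; replaces the older inner `class Config`
    model_config = SettingsConfigDict(env_file=".env")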

# settings = Settings()
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=settings.allowed_origins,
#     ...
# )
```

**How CORS Works:**

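1. **Simple Requests**: `GET`, `HEAD`, and `POST` requests that use only CORS-safelisted headers (and a form or plain-text `Content-Type`) are sent directly; the browser then checks the response's `Access-Control-Allow-Origin` header
2. **Preflight Requests**: For anything else (other methods, custom headers such as `Authorization`, or `Content-Type: application/json`), the browser first sends an `OPTIONS` request and proceeds only if the response allows it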
3. **Credentials**: When enabled, origins must be explicit (not `*`)
4. **Exposed Headers**: Custom headers won't be visible to JavaScript unless exposed
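The preflight rule can be approximated in a few lines. `needs_preflight` below is a hypothetical, simplified reading of the Fetch standard's CORS-safelisted methods and headers. It ignores details such as header-value length limits, so treat it as a mental model rather than a browser-accurate check.

```python
# Simplified subset of the CORS-safelisting rules (assumption: ignores
# value-length limits and other edge cases browsers apply)
SAFELISTED_METHODS = {"GET", "HEAD", "POST"}
SAFELISTED_HEADERS = {"accept", "accept-language", "content-language", "content-type"}
SAFELISTED_CONTENT_TYPES = {
    "application/x-www-form-urlencoded",
    "multipart/form-data",
    "text/plain",
}


def needs_preflight(method: str, headers: dict[str, str]) -> bool:
    """Roughly predict whether a browser sends an OPTIONS preflight first."""
    if method.upper() not in SAFELISTED_METHODS:
        return True
    for name, value in headers.items():
        name = name.lower()
        if name not in SAFELISTED_HEADERS:
            return True  # e.g. Authorization or X-Request-ID
        if name == "content-type":
            # Compare only the MIME type, ignoring parameters like charset
            mime = value.split(";")[0].strip().lower()
            if mime not in SAFELISTED_CONTENT_TYPES:
                return True
    return False


print(needs_preflight("GET", {}))                                     # False
print(needs_preflight("POST", {"Content-Type": "application/json"}))  # True
print(needs_preflight("GET", {"Authorization": "Bearer abc"}))        # True
print(needs_preflight("DELETE", {}))                                  # True
```

This is why a typical JSON API call from a single-page app triggers a preflight even for `POST`: the JSON content type alone is enough.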

#### GZip Compression Middleware

GZip compression reduces response size, improving transfer speed for large responses:

```python
# gzip_middleware.py
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

# Add GZip compression middleware
app.add_middleware(
    GZipMiddleware,
    # Minimum response size to compress (bytes)
    # Responses smaller than this won't be compressed
    # Starlette's default is 500 bytes; 1000 is used here
    minimum_size=1000,
    
    # Compression level (1-9, default 9)
    # Higher = better compression but more CPU
    # 6 is often a reasonable balance for dynamic API responses
    compresslevel=6,
)


@app.get("/large-data")
async def get_large_data():
    """
    This response will be compressed if larger than 1000 bytes.
    The browser automatically decompresses it.
    """
    # Generate large dataset
    items = [
        {
            "id": i,
            "name": f"Item {i}",
            "description": f"This is a detailed description for item {i}. " * 10,
            "tags": [f"tag-{j}" for j in range(10)],
        }
        for i in range(100)
    ]
    
    return {
        "items": items,
        "total": len(items),
        "message": "This response will be GZip compressed",
    }


@app.get("/small-data")
async def get_small_data():
    """
    Small responses (< 1000 bytes) won't be compressed.
    Compression overhead would outweigh benefits.
    """
    return {"status": "ok", "data": "small"}
```

**GZip Configuration Tips:**

| Setting | Recommended Value | Reason |
|---------|-------------------|--------|
| `minimum_size` | 1000-5000 bytes | Smaller responses don't benefit much |
| `compresslevel` | 4-6 | Good balance of compression vs CPU |
| For APIs | Enable | JSON compresses well (60-80% reduction) |
| For static files | Enable | CSS, JS, HTML compress excellently |

#### Trusted Host Middleware

Protects against Host header attacks by validating the `Host` header:

```python
# trusted_host_middleware.py
from fastapi import FastAPI
from fastapi.middleware.trustedhost import TrustedHostMiddleware

app = FastAPI()

# Add trusted host middleware
app.add_middleware(
    TrustedHostMiddleware,
    # List of allowed host values
    # Requests with other Host headers will be rejected with 400
    allowed_hosts=[
        "example.com",           # Production domain
        "www.example.com",       # www subdomain
        "api.example.com",       # API subdomain
        "localhost",             # Local development
        "127.0.0.1",             # Local IP
        "*.example.com",         # Wildcard for any subdomain
    ],
)


@app.get("/")
async def root():
    """
    Only accessible if the Host header matches allowed_hosts.
    
    Example:
    - curl -H "Host: example.com" http://localhost:8000/  ✓
    - curl -H "Host: evil.com" http://localhost:8000/     ✗ (400 Bad Request)
    """
    return {"message": "Host header is valid"}


# === DYNAMIC CONFIGURATION ===
# Load allowed hosts from environment, e.g.
# ALLOWED_HOSTS="example.com,*.example.com"
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    # Kept as a plain string: pydantic-settings tries to JSON-decode
    # list-typed env vars, so a comma-separated value would fail to parse
    allowed_hosts: str = "localhost,127.0.0.1"

    @property
    def allowed_hosts_list(self) -> list[str]:
        return [h.strip() for h in self.allowed_hosts.split(",") if h.strip()]

# settings = Settings()
# app.add_middleware(TrustedHostMiddleware, allowed_hosts=settings.allowed_hosts_list)
```

**Why Trusted Host Matters:**

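- Prevents cache poisoning via forged `Host` headers
- Stops attackers from injecting their own domain into absolute URLs your app builds from the `Host` header (for example, links in password-reset emails)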
- Blocks certain SSRF (Server-Side Request Forgery) attacks
- Essential when your app is behind a reverse proxy
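The host check itself is simple string matching. The function below is a simplified model of the wildcard rules used in the example above, not Starlette's source; `is_host_allowed` is a hypothetical name, and IPv6 literals in the `Host` header are not handled.

```python
def is_host_allowed(host_header: str, allowed_hosts: list[str]) -> bool:
    """Simplified model of trusted-host matching (hypothetical helper)."""
    # Drop any port; note this naive split does not handle IPv6 literals
    host = host_header.split(":")[0]
    for pattern in allowed_hosts:
        if pattern == "*" or pattern == host:
            return True
        # "*.example.com" matches subdomains, but not "example.com" itself
        if pattern.startswith("*.") and host.endswith(pattern[1:]):
            return True
    return False


allowed = ["example.com", "localhost", "*.example.com"]
for candidate in ["api.example.com", "localhost:8000", "evil.com", "evilexample.com"]:
    print(candidate, is_host_allowed(candidate, allowed))
```

Matching on the leading dot (`.example.com`) is what keeps a lookalike such as `evilexample.com` from slipping through the wildcard.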

#### HTTPS Redirect Middleware

Automatically redirects HTTP requests to HTTPS:

```python
# https_redirect_middleware.py
from fastapi import FastAPI
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

app = FastAPI()

# Add HTTPS redirect middleware
# Requests arriving over plain http:// (or ws://) are redirected to https:// (wss://)
app.add_middleware(HTTPSRedirectMiddleware)


@app.get("/")
async def root():
    """
    Any HTTP request will be redirected to HTTPS.
    
    Example:
    - http://example.com/ → https://example.com/ (307 Temporary Redirect)
    """
    return {
        "message": "You are now on HTTPS!",
        "security": "All requests must use HTTPS",
    }


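# === CONDITIONAL HTTPS REDIRECT ===
# In development you may not want the redirect. In that case, replace the
# unconditional add_middleware() call above with a check like this one.
# Behind a TLS-terminating proxy, also make sure the server trusts the
# proxy's forwarded headers (e.g. uvicorn --proxy-headers); otherwise the
# app sees every request as http:// and redirects in a loop.

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    environment: str = "development"
    debug: bool = True

# settings = Settings()
# if settings.environment == "production":
#     # Only enable HTTPS redirect in production
#     app.add_middleware(HTTPSRedirectMiddleware)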
```

#### Combining Multiple Middleware

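Middleware is executed in **reverse order** of addition (LIFO - Last In, First Out): the middleware added last becomes the outermost layer, so it sees the request first and the response last: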

```python
# combined_middleware.py
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
import time

app = FastAPI()


# === EXECUTION ORDER ===
# Request flows through middleware in this order:
# 1. Custom middleware (added last, executed first)
# 2. TrustedHostMiddleware
# 3. GZipMiddleware
# 4. CORSMiddleware (added first, executed last for request)

# Response flows in reverse:
# 1. CORSMiddleware
# 2. GZipMiddleware
# 3. TrustedHostMiddleware
# 4. Custom middleware


# CORS - Added first (executed last for requests)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# GZip Compression - Added second
app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,
    compresslevel=6,
)

# Trusted Host - Added third
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["localhost", "127.0.0.1", "*.example.com"],
)

# Custom Timing Middleware - Added last (executed first for requests)
@app.middleware("http")
async def timing_middleware(request: Request, call_next) -> Response:
    """Custom middleware that runs first for requests."""
    start = time.perf_counter()
    
    # Request flows to next middleware (TrustedHost -> GZip -> CORS -> Endpoint)
    response = await call_next(request)
    
    process_time = time.perf_counter() - start
    response.headers["X-Process-Time"] = f"{process_time:.6f}"
    
    return response


@app.get("/")
async def root():
    return {"message": "All middleware configured"}
```
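The ordering rule can be modeled with a single list. The `add_middleware` function here is a hypothetical stand-in that roughly mimics how Starlette records middleware (newest at the front, which becomes the outermost layer); it does not build a real stack.

```python
# Hypothetical model of the ordering rule only, not a real middleware stack
registered: list[str] = []


def add_middleware(name: str) -> None:
    # Newest entry goes to the front, i.e. becomes the outermost layer
    registered.insert(0, name)


# Same registration order as combined_middleware.py
for name in ["CORS", "GZip", "TrustedHost", "timing"]:
    add_middleware(name)

request_path = registered + ["endpoint"]
response_path = ["endpoint"] + registered[::-1]
print(" -> ".join(request_path))   # timing -> TrustedHost -> GZip -> CORS -> endpoint
print(" -> ".join(response_path))  # endpoint -> CORS -> GZip -> TrustedHost -> timing
```

One consequence worth checking in your own stack: a response generated by an outer layer (for example, `TrustedHostMiddleware`'s 400) never passes through `CORSMiddleware` when CORS is added first, so it carries no CORS headers and the browser reports a CORS error instead of the real status.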

---

### 10.3 Custom Middleware: Creating Your Own Middleware Classes

For reusable, complex middleware logic, create dedicated middleware classes that inherit from Starlette's base middleware.

#### Basic Custom Middleware Class

```python
# custom_middleware.py
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()


class TimingMiddleware(BaseHTTPMiddleware):
    """
    Custom middleware class for measuring request processing time.
    
    This middleware:
    1. Records the start time before processing
    2. Passes request to next handler
    3. Records end time after response
    4. Logs and adds timing headers to response
    
    Inherit from BaseHTTPMiddleware for HTTP-specific middleware.
    """
    
    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """
        The main middleware method that processes each request.
        
        Args:
            request: The incoming HTTP request
            call_next: The next middleware or endpoint to call
        
        Returns:
            Response: The HTTP response with added headers
        """
        # Pre-processing: Record start time
        start_time = time.perf_counter()
        
        # Call the next handler (middleware or endpoint)
        response = await call_next(request)
        
        # Post-processing: Calculate and record time
        process_time = time.perf_counter() - start_time
        process_time_ms = round(process_time * 1000, 4)
        
        # Add timing header
        response.headers["X-Process-Time-Ms"] = str(process_time_ms)
        
        # Log the request timing
        logger.info(
            f"{request.method} {request.url.path} - "
            f"{response.status_code} - {process_time_ms}ms"
        )
        
        return response


# Add the custom middleware to the app
app.add_middleware(TimingMiddleware)


@app.get("/")
async def root():
    return {"message": "Custom middleware is active"}
```
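Starlette also supports middleware written directly against the ASGI interface, without `BaseHTTPMiddleware`. The class below is a hypothetical sketch (`PureASGITimingMiddleware` is not a library class). It wraps `send` to add a header to the `http.response.start` message, so it measures time until the response starts rather than until the body finishes. Fake `receive`/`send` callables exercise it without a server.

```python
import asyncio
import time


class PureASGITimingMiddleware:
    """Timing middleware written against the raw ASGI interface."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Pass lifespan and websocket traffic through untouched
            await self.app(scope, receive, send)
            return

        start = time.perf_counter()

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                elapsed_ms = (time.perf_counter() - start) * 1000
                headers = list(message.get("headers", []))
                headers.append((b"x-process-time-ms", f"{elapsed_ms:.3f}".encode()))
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_wrapper)


# A tiny ASGI app standing in for FastAPI
async def hello_app(scope, receive, send):
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": b"hello"})


async def run_once():
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    app = PureASGITimingMiddleware(hello_app)
    await app({"type": "http", "method": "GET", "path": "/"}, receive, send)
    return sent


messages = asyncio.run(run_once())
print(sorted(dict(messages[0]["headers"])))
```

Because the class only needs `app` as its first constructor argument, it can be registered the same way as the class-based example above, with `app.add_middleware(PureASGITimingMiddleware)`.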

#### Rate Limiting Middleware

A practical example: rate limiting middleware to prevent abuse:

```python
# rate_limit_middleware.py
import asyncio
import math
import time
from collections import defaultdict, deque

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint

app = FastAPI()


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware using the sliding window (log) algorithm.
    
    Limits the number of requests per client IP within a time window.
    Uses an in-memory store that is per-process and lost on restart
    (for demonstration - use Redis in production).
    
    Configuration:
    - requests_per_minute: Maximum requests allowed per window
    - window_seconds: Time window in seconds
    """
    
    def __init__(
        self,
        app,
        requests_per_minute: int = 60,
        window_seconds: int = 60,
    ):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.window_seconds = window_seconds
        
        # Store: IP -> deque of request timestamps (Unix epoch seconds)
        # In production, use Redis or similar
        self.request_history: dict[str, deque[float]] = defaultdict(deque)
        
        # Serializes access between concurrent coroutines on the event loop
        # (asyncio.Lock is not a thread lock)
        self._lock = asyncio.Lock()
    
    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """
        Process request with rate limiting.
        
        Steps:
        1. Extract client IP address
        2. Check request count within window
        3. Allow or reject based on limit
        4. Add rate limit headers to response
        """
        # Skip rate limiting for health checks
        if request.url.path in ["/health", "/metrics"]:
            return await call_next(request)
        
        # Get client IP (handle reverse proxy)
        client_ip = self._get_client_ip(request)
        
        # Check and update rate limit
        allowed, remaining, reset_time = await self._check_rate_limit(client_ip)
        
        if not allowed:
            # Seconds until the oldest request in the window expires
            retry_after = max(1, math.ceil(reset_time - time.time()))
            # Return 429 Too Many Requests
            return Response(
                content='{"detail": "Rate limit exceeded. Please slow down."}',
                status_code=429,
                media_type="application/json",
                headers={
                    "X-RateLimit-Limit": str(self.requests_per_minute),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(reset_time),
                    "Retry-After": str(retry_after),
                },
            )
        
        # Process request
        response = await call_next(request)
        
        # Add rate limit headers
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_time)
        
        return response
    
    def _get_client_ip(self, request: Request) -> str:
        """
        Extract client IP from request.
        
        Handles common reverse proxy headers:
        - X-Forwarded-For: Standard header for proxied IPs
        - X-Real-IP: Alternative header
        - Direct connection: Uses request.client.host
        
        Security: only trust these headers when the app runs behind a proxy
        that sets or overwrites them. Otherwise any client can send a fake
        X-Forwarded-For value and evade the limit.
        """
        # Check X-Forwarded-For (most common)
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # Return first IP (original client)
            return forwarded_for.split(",")[0].strip()
        
        # Check X-Real-IP
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip.strip()
        
        # Direct connection
        if request.client:
            return request.client.host
        
        return "unknown"
    
    async def _check_rate_limit(
        self,
        client_ip: str,
    ) -> tuple[bool, int, int]:
        """
        Check if client has exceeded rate limit.
        
        Uses sliding window algorithm:
        1. Drop timestamps older than the window
        2. Count the remaining ones
        3. If under limit, record the current timestamp
        
        Returns:
            Tuple of (allowed: bool, remaining: int, reset_time: int),
            where reset_time is the Unix time at which the oldest request
            in the window expires
        """
        async with self._lock:
            now = time.time()
            window_start = now - self.window_seconds
            
            # Get request history for this IP
            # (note: keys for idle IPs are never evicted in this demo)
            history = self.request_history[client_ip]
            
            # Remove timestamps outside the window (sliding window)
            while history and history[0] <= window_start:
                history.popleft()
            
            # Check if under limit
            if len(history) >= self.requests_per_minute:
                oldest = history[0] if history else now
                return False, 0, math.ceil(oldest + self.window_seconds)
            
            # Add current request timestamp
            history.append(now)
            
            remaining = self.requests_per_minute - len(history)
            reset_time = math.ceil(history[0] + self.window_seconds)
            return True, remaining, reset_time


# Add rate limiting middleware
app.add_middleware(
    RateLimitMiddleware,
    requests_per_minute=10,  # Low for testing
    window_seconds=60,
)


@app.get("/")
async def root():
    return {"message": "Rate limiting is active"}


@app.get("/health")
async def health():
    """Health check - skips rate limiting."""
    return {"status": "healthy"}
```
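The sliding-window logic is easier to test once it is separated from Starlette. `SlidingWindowLimiter` and `FakeClock` below are hypothetical names. Injecting the clock lets a test move time forward instead of sleeping for a full minute.

```python
import time
from collections import deque
from typing import Callable


class SlidingWindowLimiter:
    """Sliding-window log limiter with an injectable clock."""

    def __init__(self, limit: int, window_seconds: float,
                 clock: Callable[[], float] = time.monotonic):
        self.limit = limit
        self.window = window_seconds
        self.clock = clock
        self.history: dict[str, deque] = {}

    def allow(self, key: str) -> bool:
        now = self.clock()
        timestamps = self.history.setdefault(key, deque())
        # Evict timestamps that have slid out of the window
        while timestamps and timestamps[0] <= now - self.window:
            timestamps.popleft()
        if len(timestamps) >= self.limit:
            return False  # rejected requests are not recorded
        timestamps.append(now)
        return True


class FakeClock:
    """Manually advanced clock for tests."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now


clock = FakeClock()
limiter = SlidingWindowLimiter(limit=3, window_seconds=60, clock=clock)

results = [limiter.allow("1.2.3.4") for _ in range(4)]
print(results)  # [True, True, True, False]

clock.now = 60.0  # the three requests recorded at t=0 fall out of the window
after_window = limiter.allow("1.2.3.4")
print(after_window)  # True
```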

#### Request ID Middleware with Context

For distributed tracing, add unique request IDs that propagate through your system:

```python
# request_id_middleware.py
import contextvars
from uuid import uuid4
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Optional

app = FastAPI()

# Context variable to store request ID (accessible anywhere in the request)
request_id_context: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")


def get_request_id() -> Optional[str]:
    """
    Get the current request ID from context.
    Can be called from anywhere during request processing.
    """
    try:
        return request_id_context.get()
    except LookupError:
        return None


class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Middleware that assigns a unique ID to each request.
    
    The request ID is:
    1. Generated if not provided in X-Request-ID header
    2. Stored in context for access anywhere
    3. Added to response headers
    4. Available in all log messages
    
    This enables:
    - Distributed tracing across services
    - Log correlation
    - Customer support ticket reference
    """
    
    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        # Check for existing request ID (from load balancer, etc.)
        request_id = request.headers.get("X-Request-ID")
        
        if not request_id:
            # Generate new request ID
            request_id = str(uuid4())
        
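        # Store in context for this request; keep the token so the
        # previous value can be restored afterwards
        token = request_id_context.set(request_id)
        
        # Also store in request state for access in endpoints
        request.state.request_id = request_id
        
        try:
            # Process request
            response = await call_next(request)
        finally:
            # Restore the previous context value so it cannot leak
            request_id_context.reset(token)
        
        # Add request ID to response headers
        response.headers["X-Request-ID"] = request_id
        
        return response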


app.add_middleware(RequestIDMiddleware)


# Custom logger that includes request ID
import logging

class RequestIDFormatter(logging.Formatter):
    """Log formatter that includes request ID."""
    
    def format(self, record):
        # Add request ID to log record
        record.request_id = get_request_id() or "no-request-id"
        return super().format(record)


# Configure logging with request ID
handler = logging.StreamHandler()
handler.setFormatter(
    RequestIDFormatter(
        "%(asctime)s [%(request_id)s] %(levelname)s: %(message)s"
    )
)
logging.basicConfig(level=logging.INFO, handlers=[handler])
logger = logging.getLogger(__name__)


@app.get("/")
async def root(request: Request):
    # Access request ID from request state
    request_id = request.state.request_id
    
    # Or from context
    context_request_id = get_request_id()
    
    # Log with request ID (automatically included)
    logger.info("Processing root endpoint")
    
    return {
        "message": "Hello",
        "request_id": request_id,
        "context_request_id": context_request_id,
    }
```
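Why a `ContextVar` rather than a module-level global? Each asyncio task runs with its own copy of the context, so concurrent requests cannot see each other's IDs. The self-contained sketch below uses a hypothetical `handle` coroutine as a stand-in for a request. It interleaves two "requests" and shows each one reading back its own value.

```python
import asyncio
import contextvars

request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "request_id", default="no-request-id"
)


def current_request_id() -> str:
    # Deep helper code reads the ID without it being passed in
    return request_id_var.get()


async def handle(request_id: str, delay: float) -> str:
    token = request_id_var.set(request_id)
    try:
        await asyncio.sleep(delay)  # let the other "request" interleave
        return current_request_id()
    finally:
        request_id_var.reset(token)


async def main() -> list[str]:
    # gather() wraps each coroutine in a Task with its own context copy
    return await asyncio.gather(handle("req-a", 0.02), handle("req-b", 0.01))


seen = asyncio.run(main())
print(seen)                  # ['req-a', 'req-b']
print(current_request_id())  # no-request-id
```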

#### Authentication Middleware

A middleware that performs authentication before requests reach endpoints:

```python
# auth_middleware.py
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Optional
import jwt

app = FastAPI()

# Configuration
# Placeholder only: load the real key from an environment variable or secrets manager
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"

# Paths that don't require authentication
PUBLIC_PATHS = {
    "/",
    "/health",
    "/docs",
    "/openapi.json",
    "/auth/login",
    "/auth/register",
}


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """
    Middleware that validates JWT tokens for protected routes.
    
    This middleware:
    1. Checks if the path requires authentication
    2. Extracts the JWT token from Authorization header
    3. Validates the token
    4. Attaches user info to request.state
    5. Returns 401 for invalid/missing tokens
    
    Note: For fine-grained control, use Depends() with OAuth2PasswordBearer
          instead of middleware. Middleware is for global protection.
    """
    
    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        # Skip authentication for public paths
        if request.url.path in PUBLIC_PATHS:
            return await call_next(request)
        
        # Also skip OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)
        
        # Extract token from Authorization header
        token = self._extract_token(request)
        
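        if not token:
            # Return the 401 directly: an HTTPException raised inside middleware
            # bypasses FastAPI's exception handlers and surfaces as a 500 error
            return Response(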
                content='{"detail": "Not authenticated"}',
                status_code=401,
                media_type="application/json",
                headers={"WWW-Authenticate": "Bearer"},
            )
        
        # Validate token and get user info
        user_info = self._validate_token(token)
        
        if not user_info:
            return Response(
                content='{"detail": "Invalid or expired token"}',
                status_code=401,
                media_type="application/json",
                headers={"WWW-Authenticate": "Bearer"},
            )
        
        # Attach user info to request state for use in endpoints
        request.state.user = user_info
        
        # Continue processing
        return await call_next(request)
    
    def _extract_token(self, request: Request) -> Optional[str]:
        """Extract JWT token from Authorization header."""
        auth_header = request.headers.get("Authorization")
        
        if not auth_header:
            return None
        
        # Expected format: "Bearer <token>"
        parts = auth_header.split()
        
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None
        
        return parts[1]
    
    def _validate_token(self, token: str) -> Optional[dict]:
        """Validate JWT token and return user info."""
        try:
            payload = jwt.decode(
                token,
                SECRET_KEY,
                algorithms=[ALGORITHM],
            )
            return payload
        except jwt.ExpiredSignatureError:
            return None
        except jwt.InvalidTokenError:
            return None


app.add_middleware(AuthenticationMiddleware)


@app.get("/")
async def public():
    """Public endpoint - no authentication required."""
    return {"message": "Public endpoint"}


@app.get("/protected")
async def protected(request: Request):
    """Protected endpoint - requires valid token."""
    user = request.state.user
    return {
        "message": "Authenticated!",
        "user": user,
    }


@app.post("/auth/login")
async def login():
    """Login endpoint - generates token."""
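    # In a real app, validate credentials first. This demo token has no "exp"
    # claim, so it never expires and the ExpiredSignatureError branch never runs.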
    token = jwt.encode(
        {"sub": "user123", "username": "alice"},
        SECRET_KEY,
        algorithm=ALGORITHM,
    )
    return {"access_token": token, "token_type": "bearer"}
```
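Because `PUBLIC_PATHS` is matched exactly, requests such as `/docs/oauth2-redirect` (the Swagger UI OAuth2 redirect page FastAPI serves by default) or `/health/` with a trailing slash fall through to the token check. A minimal sketch of a prefix-aware alternative follows; `PUBLIC_PREFIXES` and `is_public_path` are hypothetical names you would call at the top of `dispatch` in place of the set lookup.

```python
from typing import FrozenSet, Tuple

# Hypothetical configuration: exact public paths plus whole public subtrees
PUBLIC_PATHS: FrozenSet[str] = frozenset(
    {"/", "/health", "/openapi.json", "/auth/login", "/auth/register"}
)
PUBLIC_PREFIXES: Tuple[str, ...] = ("/docs", "/redoc")


def is_public_path(path: str) -> bool:
    """Return True if the path should bypass authentication."""
    # Normalize a trailing slash so "/health/" matches "/health" (keep "/" itself)
    normalized = path.rstrip("/") or "/"
    if normalized in PUBLIC_PATHS:
        return True
    # Match prefixes only on a segment boundary: "/docs/x" yes, "/docsx" no
    return any(
        normalized == prefix or normalized.startswith(prefix + "/")
        for prefix in PUBLIC_PREFIXES
    )


print(is_public_path("/health/"))
print(is_public_path("/docs/oauth2-redirect"))
print(is_public_path("/docsx"))
print(is_public_path("/protected"))
```

Normalizing the trailing slash matters here because the middleware runs before FastAPI's router, so the router's automatic slash redirect never gets a chance to run for an unauthenticated request.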

---

### 10.4 Lifespan Events: Handling Startup and Shutdown Logic

Lifespan events allow you to run code when your application starts up and shuts down. This is essential for initializing resources (database connections, ML models, cache clients) and cleaning them up properly.

#### The Modern Approach: `lifespan` Context Manager

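In current FastAPI versions, pass an async context manager to the `lifespan` parameter of `FastAPI()` instead of using the deprecated `@app.on_event("startup")` and `@app.on_event("shutdown")` decorators. Once a `lifespan` handler is provided, `on_event` handlers are no longer called, so don't mix the two: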

```python
# lifespan_events.py
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === RESOURCE SIMULATION ===
# These represent resources that need initialization

class DatabaseConnection:
    """Simulated database connection."""
    
    def __init__(self):
        self.connected = False
    
    async def connect(self):
        """Establish database connection."""
        logger.info("Connecting to database...")
        # Simulate connection delay
        import asyncio
        await asyncio.sleep(0.5)
        self.connected = True
        logger.info("Database connected!")
    
    async def disconnect(self):
        """Close database connection."""
        logger.info("Disconnecting from database...")
        self.connected = False
        logger.info("Database disconnected!")
    
    async def query(self, sql: str):
        """Execute a query."""
        if not self.connected:
            raise RuntimeError("Database not connected")
        return [{"id": 1, "name": "Result"}]


class CacheClient:
    """Simulated cache client (Redis-like)."""
    
    def __init__(self):
        self.connected = False
        self._cache = {}
    
    async def connect(self):
        """Connect to cache server."""
        logger.info("Connecting to cache...")
        self.connected = True
        logger.info("Cache connected!")
    
    async def disconnect(self):
        """Disconnect from cache."""
        logger.info("Disconnecting from cache...")
        self._cache.clear()
        self.connected = False
        logger.info("Cache disconnected!")
    
    async def get(self, key: str):
        return self._cache.get(key)
    
    async def set(self, key: str, value: str):
        self._cache[key] = value


class MachineLearningModel:
    """Simulated ML model that needs loading."""
    
    def __init__(self):
        self.loaded = False
    
    async def load(self):
        """Load the model into memory."""
        logger.info("Loading ML model (this may take a while)...")
        import asyncio
        await asyncio.sleep(1)  # Simulate loading time
        self.loaded = True
        logger.info("ML model loaded!")
    
    async def unload(self):
        """Unload the model from memory."""
        logger.info("Unloading ML model...")
        self.loaded = False
        logger.info("ML model unloaded!")
    
    async def predict(self, data):
        if not self.loaded:
            raise RuntimeError("Model not loaded")
        return {"prediction": "positive", "confidence": 0.95}


# Global resources
db = DatabaseConnection()
cache = CacheClient()
ml_model = MachineLearningModel()


# === LIFESPAN CONTEXT MANAGER ===

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Application lifespan context manager.
    
    Handles startup and shutdown events for the entire application.
    
    Code BEFORE yield runs at STARTUP (before any requests)
    Code AFTER yield runs at SHUTDOWN (after all requests complete)
    
    This is the modern replacement for @app.on_event("startup")
    and @app.on_event("shutdown").
    """
    # ========================================
    # STARTUP - Runs once when application starts
    # ========================================
    
    logger.info("=" * 50)
    logger.info("APPLICATION STARTUP")
    logger.info("=" * 50)
    
    # Initialize database
    await db.connect()
    
    # Initialize cache
    await cache.connect()
    
    # Load ML model
    await ml_model.load()
    
    # Store resources in app state for access in endpoints
    app.state.db = db
    app.state.cache = cache
    app.state.ml_model = ml_model
    
    logger.info("All resources initialized. Ready to serve requests!")
    logger.info("=" * 50)
    
    # ========================================
    # YIELD - Application runs here
    # ========================================
    # The application serves requests while inside this yield block
    
    yield
    
    # ========================================
    # SHUTDOWN - Runs once when application stops
    # ========================================
    
    logger.info("=" * 50)
    logger.info("APPLICATION SHUTDOWN")
    logger.info("=" * 50)
    
    # Clean up in reverse order of initialization
    await ml_model.unload()
    await cache.disconnect()
    await db.disconnect()
    
    logger.info("All resources cleaned up. Goodbye!")
    logger.info("=" * 50)


# Create application with lifespan
app = FastAPI(
    title="Lifespan Demo",
    description="Demonstrates startup/shutdown event handling",
    lifespan=lifespan,
)


# === ENDPOINTS ===

@app.get("/")
async def root():
    """Root endpoint showing resource status."""
    return {
        "message": "Application is running",
        "resources": {
            "database": "connected" if db.connected else "disconnected",
            "cache": "connected" if cache.connected else "disconnected",
            "ml_model": "loaded" if ml_model.loaded else "unloaded",
        },
    }


@app.get("/db/query")
async def db_query():
    """Endpoint using database connection."""
    results = await db.query("SELECT * FROM items")
    return {"results": results}


@app.get("/cache/{key}")
async def cache_get(key: str):
    """Endpoint using cache."""
    value = await cache.get(key)
    return {"key": key, "value": value}


@app.post("/cache/{key}")
async def cache_set(key: str, value: str):
    """Set a value in cache."""
    await cache.set(key, value)
    return {"key": key, "value": value}


@app.post("/predict")
async def predict(data: dict):
    """Endpoint using ML model."""
    prediction = await ml_model.predict(data)
    return prediction
```
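The startup/shutdown ordering comes from `contextlib.asynccontextmanager` itself: entering the context runs the code before `yield`, and exiting runs the code after it. The stdlib-only sketch below drives a lifespan-style function by hand with `async with`, roughly what the ASGI server does around serving requests; `events`, `simulate_server`, and `fake_app` are illustrative stand-ins, not FastAPI objects. It also wraps `yield` in `try`/`finally`, a variant the example above does not use, so cleanup still runs if the context exits with an exception.

```python
import asyncio
from contextlib import asynccontextmanager
from types import SimpleNamespace
from typing import AsyncGenerator, List

events: List[str] = []


@asynccontextmanager
async def lifespan(app: SimpleNamespace) -> AsyncGenerator[None, None]:
    # Startup: runs when the context is entered
    events.append("startup: connect db")
    app.state.db = "db-connection"  # stand-in for a real resource
    try:
        yield  # a server would handle requests while suspended here
    finally:
        # Shutdown: runs on exit, even if the body raised
        events.append("shutdown: disconnect db")
        app.state.db = None


async def simulate_server(app: SimpleNamespace, fail: bool) -> None:
    async with lifespan(app):
        events.append(f"serving with {app.state.db}")
        if fail:
            raise RuntimeError("crash while serving")


fake_app = SimpleNamespace(state=SimpleNamespace())
asyncio.run(simulate_server(fake_app, fail=False))
try:
    asyncio.run(simulate_server(fake_app, fail=True))
except RuntimeError:
    events.append("error propagated")

print(events)
```

Without the `try`/`finally`, the exception would be thrown into the generator at `yield` and the shutdown lines would be skipped in the failing run.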

#### Complete Production Example: Application with Real Resources

```python
# app/main.py - Complete production setup
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import logging

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
import redis.asyncio as redis

from app.core.config import get_settings
from app.api.v1.router import api_router

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

settings = get_settings()


# === RESOURCE MANAGERS ===

class DatabaseManager:
    """Manages database connections and sessions."""
    
    def __init__(self, database_url: str):
        self.database_url = database_url
        self.engine = None
        self.session_factory = None
    
    async def connect(self):
        """Initialize database engine and session factory."""
        logger.info(f"Connecting to database: {self.database_url.split('@')[-1]}")
        
        self.engine = create_async_engine(
            self.database_url,
            pool_size=5,
            max_overflow=10,
            echo=settings.debug,
        )
        
        self.session_factory = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
        
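        # Verify connection. SQLAlchemy 2.0 rejects plain SQL strings in
        # execute(), so use exec_driver_sql() (or execute(text("SELECT 1")))
        async with self.engine.begin() as conn:
            await conn.exec_driver_sql("SELECT 1")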
        
        logger.info("Database connected successfully")
    
    async def disconnect(self):
        """Close database connections."""
        if self.engine:
            logger.info("Closing database connections...")
            await self.engine.dispose()
            logger.info("Database connections closed")
    
    def get_session(self) -> AsyncSession:
        """Get a new database session."""
        return self.session_factory()


class RedisManager:
    """Manages Redis connection."""
    
    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self.client = None
    
    async def connect(self):
        """Connect to Redis."""
        # Log the host only; the URL may embed a password
        logger.info(f"Connecting to Redis: {self.redis_url.split('@')[-1]}")
        
        self.client = redis.from_url(
            self.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
        
        # Verify connection
        await self.client.ping()
        
        logger.info("Redis connected successfully")
    
    async def disconnect(self):
        """Disconnect from Redis."""
        if self.client:
            logger.info("Closing Redis connection...")
            # aclose() requires redis-py >= 5.0.1; older versions use close()
            await self.client.aclose()
            logger.info("Redis connection closed")


# Global resource instances
db_manager = DatabaseManager(settings.database_url)
redis_manager = RedisManager(settings.redis_url)


# === LIFESPAN ===

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Application lifespan manager.
    
    Handles:
    - Database connection pooling
    - Redis cache connection
    - Graceful shutdown
    """
    # STARTUP
    logger.info("=" * 60)
    logger.info(f"Starting {settings.app_name} v{settings.app_version}")
    logger.info(f"Environment: {settings.environment}")
    logger.info("=" * 60)
    
    # Initialize resources in order of dependency
    try:
        # 1. Database (most critical)
        await db_manager.connect()
        
        # 2. Redis cache
        await redis_manager.connect()
        
        # 3. Store in app state
        app.state.db = db_manager
        app.state.redis = redis_manager.client
        
        logger.info("All resources initialized successfully")
        
    except Exception:
        # logger.exception() also records the traceback
        logger.exception("Failed to initialize resources")
        # Clean up any partially initialized resources
        await db_manager.disconnect()
        await redis_manager.disconnect()
        raise
    
    # Signal ready
    logger.info("=" * 60)
    logger.info("Application ready to serve requests")
    logger.info("=" * 60)
    
    yield  # Application runs here
    
    # SHUTDOWN
    logger.info("=" * 60)
    logger.info("Shutting down application...")
    logger.info("=" * 60)
    
    # Clean up in reverse order
    await redis_manager.disconnect()
    await db_manager.disconnect()
    
    logger.info("Shutdown complete")


# === CREATE APPLICATION ===

def create_application() -> FastAPI:
    """Create and configure the FastAPI application."""
    app = FastAPI(
        title=settings.app_name,
        version=settings.app_version,
        description=settings.app_description,
        openapi_url=f"{settings.api_v1_prefix}/openapi.json",
        docs_url=f"{settings.api_v1_prefix}/docs",
        redoc_url=f"{settings.api_v1_prefix}/redoc",
        lifespan=lifespan,
    )
    
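    # Middleware order: each add_middleware() call wraps the stack built so far,
    # so the middleware added LAST runs FIRST on requests (and last on responses).
    # Request order here: TrustedHost (production only) -> GZip -> CORS -> routes.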
    
    # CORS - Handle cross-origin requests
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.allowed_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
        allow_headers=["*"],
        expose_headers=["X-Total-Count", "X-Request-ID"],
    )
    
    # GZip compression
    app.add_middleware(
        GZipMiddleware,
        minimum_size=1000,
        compresslevel=6,
    )
    
    # Trusted hosts (security)
    if settings.environment == "production":
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=settings.allowed_hosts,
        )
    
    # Include API routers
    app.include_router(api_router, prefix=settings.api_v1_prefix)
    
    # Health check endpoints (no authentication required)
    @app.get("/health", tags=["Health"])
    async def health_check():
        """Basic health check."""
        return {
            "status": "healthy",
            "version": settings.app_version,
            "environment": settings.environment,
        }
    
    @app.get("/ready", tags=["Health"])
    async def readiness_check(request: Request):
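        """
        Readiness check for orchestration.
        
        Verifies all dependencies are available. This always returns HTTP 200;
        orchestrators such as Kubernetes judge probes by status code (200-399
        is success), so return 503 when a check fails if the probe should take
        the instance out of rotation.
        """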
        checks = {}
        
        # Check database
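        try:
            # Borrow a pooled connection; SQLAlchemy 2.0 needs text() or
            # exec_driver_sql() for raw SQL strings
            async with request.app.state.db.engine.connect() as conn:
                await conn.exec_driver_sql("SELECT 1")
            checks["database"] = "ok"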
        except Exception as e:
            checks["database"] = f"error: {str(e)}"
        
        # Check Redis
        try:
            await request.app.state.redis.ping()
            checks["redis"] = "ok"
        except Exception as e:
            checks["redis"] = f"error: {str(e)}"
        
        all_healthy = all(v == "ok" for v in checks.values())
        
        return {
            "status": "ready" if all_healthy else "not_ready",
            "checks": checks,
        }
    
    return app


# Create app instance
app = create_application()


if __name__ == "__main__":
    import uvicorn
    
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8000,
        reload=settings.debug,
        log_level="info",
    )
```
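The manual `try`/`except` in the lifespan above disconnects every manager on failure, which works only because each `disconnect()` tolerates never having connected. One alternative swaps in `contextlib.AsyncExitStack`: each cleanup is registered only after its resource connects, and the stack unwinds the registered callbacks in reverse order on both startup failure and normal shutdown. A minimal sketch, assuming a hypothetical `Resource` class standing in for `DatabaseManager` and `RedisManager`; the `make_lifespan` factory exists only so the sketch can inject a failing resource.

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from types import SimpleNamespace
from typing import AsyncGenerator, List

log: List[str] = []


class Resource:
    """Hypothetical stand-in for DatabaseManager / RedisManager."""

    def __init__(self, name: str, fail_on_connect: bool = False) -> None:
        self.name = name
        self.fail_on_connect = fail_on_connect

    async def connect(self) -> None:
        if self.fail_on_connect:
            raise ConnectionError(f"{self.name} unavailable")
        log.append(f"connect {self.name}")

    async def disconnect(self) -> None:
        log.append(f"disconnect {self.name}")


def make_lifespan(resources: List[Resource]):
    """Build a lifespan function for the given (hypothetical) resources."""

    @asynccontextmanager
    async def lifespan(app) -> AsyncGenerator[None, None]:
        async with AsyncExitStack() as stack:
            for resource in resources:  # connect in dependency order
                await resource.connect()
                # Registered only after connect() succeeds
                stack.push_async_callback(resource.disconnect)
            yield  # the application serves requests here
        # Leaving the stack awaited the callbacks in reverse (LIFO) order

    return lifespan


async def run(resources: List[Resource]) -> None:
    async with make_lifespan(resources)(SimpleNamespace()):
        log.append("serving")


# Normal startup and shutdown
asyncio.run(run([Resource("db"), Resource("redis")]))

# Startup failure: redis never connects, so only db is cleaned up
try:
    asyncio.run(run([Resource("db"), Resource("redis", fail_on_connect=True)]))
except ConnectionError as exc:
    log.append(f"startup failed: {exc}")

print(log)
```

In an application you would pass the returned function as `FastAPI(lifespan=make_lifespan([...]))`, or write the `AsyncExitStack` directly inside a plain `lifespan(app)` function.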

---

### Summary

In this chapter, you've learned to manage requests and application lifecycle:

1. **Middleware Basics**: Understanding how middleware intercepts requests and responses, using the `@app.middleware("http")` decorator for simple middleware, and the flow of request processing.

2. **Built-in Middleware**: Using `CORSMiddleware` for cross-origin requests, `GZipMiddleware` for compression, `TrustedHostMiddleware` for Host-header validation, and `HTTPSRedirectMiddleware` for redirecting plain HTTP traffic to HTTPS.

3. **Custom Middleware**: Creating reusable middleware classes by inheriting from `BaseHTTPMiddleware`, implementing rate limiting, request ID tracking, and authentication.

4. **Lifespan Events**: Using the `lifespan` context manager for startup and shutdown logic, initializing resources like database connections and ML models, and ensuring proper cleanup.

---

### Exercises

1. **Request Logging Middleware**: Create middleware that logs all requests with:
   - Timestamp, method, path, status code, processing time
   - Request body (for POST/PUT/PATCH)
   - Response body (truncated for large responses)
   - Write logs to a file in JSON format

2. **Rate Limiter with Redis**: Implement a production-ready rate limiter that:
   - Uses Redis for distributed rate limiting
   - Supports different limits per endpoint
   - Handles authenticated users differently
   - Returns proper rate limit headers

3. **Caching Middleware**: Create middleware that:
   - Caches GET responses in Redis
   - Uses request path + query params as cache key
   - Respects Cache-Control headers
   - Allows cache invalidation via headers

4. **Graceful Shutdown**: Enhance the lifespan handler to:
   - Accept SIGTERM and SIGINT signals
   - Stop accepting new requests
   - Wait for existing requests to complete
   - Timeout and force-close after 30 seconds

---

### What's Next?

**Chapter 11: Security Fundamentals** will cover:
- Understanding security schemes in OpenAPI
- Implementing HTTP Basic authentication
- Password hashing and secure storage
- Security best practices for production APIs

