# Authentication and Security in FastAPI

This notebook covers authentication and security in FastAPI:
- API key authentication
- JWT tokens (JSON Web Tokens)
- Rate limiting
- CORS configuration
- Input validation and sanitization

Security is critical for production APIs. FastAPI provides excellent tools for implementing secure authentication.

## Setup

In [None]:
import asyncio
import math
import os
import secrets
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List

from fastapi import FastAPI, Depends, HTTPException, status, Security, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader, HTTPBearer, HTTPAuthorizationCredentials
from fastapi.testclient import TestClient
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel, Field, validator, EmailStr

## 1. API Key Authentication

Simple authentication using API keys in headers.

In [None]:
app = FastAPI(
    title="Secure API with Authentication",
    description="Authentication and security examples",
    version="1.0.0",
)

# In-memory stand-in for an API-key store: key -> owner metadata.
# NOTE: keys are kept in plain text here purely for demonstration.
API_KEYS = {
    key: {"user": user, "role": role, "created": datetime.now().isoformat()}
    for key, user, role in (
        ("test_key_123", "testuser", "user"),
        ("admin_key_456", "admin", "admin"),
        ("dev_key_789", "developer", "developer"),
    )
}

# Reads the key from the X-API-Key header. auto_error=False makes a missing
# header arrive as None, so get_api_key can raise its own, clearer error.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

class APIKeyInfo(BaseModel):
    """Identity resolved from a valid API key."""
    user: str
    role: str

async def get_api_key(api_key: str = Security(api_key_header)) -> APIKeyInfo:
    """Resolve the X-API-Key header value to the caller's identity.

    Returns:
        APIKeyInfo with the key owner's user name and role.

    Raises:
        HTTPException(401): the header is absent, or the key is not registered.
    """
    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key missing. Include X-API-Key header."
        )

    record = API_KEYS.get(api_key)
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )

    return APIKeyInfo(**{field: record[field] for field in ("user", "role")})

async def require_admin(api_key_info: APIKeyInfo = Depends(get_api_key)):
    """Dependency that lets only API keys with the "admin" role through.

    Returns the caller's APIKeyInfo; raises HTTPException(403) for other roles.
    """
    is_admin = api_key_info.role == "admin"
    if not is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required",
        )
    return api_key_info

# Public endpoint - no authentication
@app.get("/public/status")
async def public_status():
    """Status probe that anyone may call; no credentials are checked."""
    payload = dict(
        status="online",
        version="1.0.0",
        message="This endpoint is public",
    )
    return payload

# Protected endpoint - requires valid API key
@app.get("/protected/data")
async def protected_data(api_key_info: APIKeyInfo = Depends(get_api_key)):
    """Return sample data to any caller presenting a registered API key."""
    response = {"message": "This is protected data"}
    response["user"] = api_key_info.user
    response["role"] = api_key_info.role
    response["data"] = list(range(1, 6))
    return response

# Admin-only endpoint
@app.get("/admin/users")
async def list_users(api_key_info: APIKeyInfo = Depends(require_admin)):
    """List the metadata of every registered API key (admins only).

    Only the metadata records (user, role, created) are returned, not the keys.
    """
    return {
        "message": "Admin access granted",
        "users": [*API_KEYS.values()],
        "admin": api_key_info.user,
    }

# Generate new API key (admin only)
@app.post("/admin/keys/generate")
async def generate_api_key(
    username: str,
    role: str = "user",
    api_key_info: APIKeyInfo = Depends(require_admin)
):
    """Mint and register a fresh random API key (admins only).

    `username` and `role` are taken from the query string. The key is built
    from 32 random bytes (URL-safe base64). `role` is stored exactly as
    given; it is not checked against the known roles.
    """
    new_key = secrets.token_urlsafe(32)
    record = {"user": username, "role": role, "created": datetime.now().isoformat()}
    API_KEYS[new_key] = record
    return dict(
        message="API key generated",
        api_key=new_key,
        username=username,
        role=role,
    )

print("API key authentication endpoints added!")

## 2. JWT Token Authentication

More sophisticated authentication using JWT tokens.

In [None]:
# JWT Configuration
# The signing key is read from the JWT_SECRET_KEY environment variable when it
# is set and non-empty. The hard-coded fallback exists only so the notebook
# runs out of the box; it must never be used in production.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY") or "your-secret-key-keep-it-secret-in-production"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Password hashing (bcrypt via passlib)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Simulated user database: username -> stored record.
# Demo passwords are hashed once at import time; only the hashes are kept.
USERS_DB = {
    "john": {
        "username": "john",
        "full_name": "John Doe",
        "email": "john@example.com",
        "hashed_password": pwd_context.hash("secret123"),
        "role": "user"
    },
    "alice": {
        "username": "alice",
        "full_name": "Alice Admin",
        "email": "alice@example.com",
        "hashed_password": pwd_context.hash("admin456"),
        "role": "admin"
    }
}

# Pydantic models
class Token(BaseModel):
    """Response body returned by the /jwt/login endpoint."""
    access_token: str  # the encoded, signed JWT
    token_type: str  # token scheme; login() sets "bearer"

class TokenData(BaseModel):
    """Claims pulled out of a decoded JWT by get_current_user()."""
    username: Optional[str] = None  # from the "sub" claim
    role: Optional[str] = None  # from the "role" claim; informational only, the returned User's role comes from USERS_DB

class User(BaseModel):
    """Public user profile (no password hash); response model of /jwt/me."""
    username: str
    email: EmailStr
    full_name: Optional[str] = None
    role: str  # "user" or "admin" in the demo USERS_DB

class UserInDB(User):
    """User record as stored in USERS_DB, including the bcrypt password hash.

    Do not return this model from an endpoint; convert it to User first.
    """
    hashed_password: str

class LoginRequest(BaseModel):
    """JSON body accepted by /jwt/login."""
    username: str
    password: str  # plain text; only ever compared against the stored hash

# JWT helper functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when `plain_password` matches the stored `hashed_password`."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match

def get_user(username: str) -> Optional[UserInDB]:
    """Look up `username` in USERS_DB; return a UserInDB, or None if unknown."""
    record = USERS_DB.get(username)
    return None if record is None else UserInDB(**record)

def authenticate_user(username: str, password: str) -> Optional[UserInDB]:
    """Check a username/password pair.

    Returns:
        The matching UserInDB on success, or None when the user does not exist
        or the password is wrong.

    Both failure paths perform a password-hash computation, so response
    timing does not reveal which usernames exist.
    """
    user = get_user(username)
    if user is None:
        # Spend roughly the time a real verify() would take, so that
        # "unknown user" and "wrong password" look identical to a timing probe.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Encode `data` as a signed JWT carrying an "exp" claim.

    Args:
        data: claims to embed (e.g. "sub", "role"). The dict is copied, not mutated.
        expires_delta: token lifetime. Defaults to 15 minutes only when None;
            an explicit ``timedelta(0)`` is honoured rather than treated as missing.

    Returns:
        The encoded token, signed with SECRET_KEY using ALGORITHM.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

# JWT security scheme: extracts the token from an "Authorization: Bearer <token>" header
bearer_scheme = HTTPBearer()

async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> User:
    """Resolve the bearer JWT to the current User.

    Raises:
        HTTPException(401): the token fails to decode or verify, has no "sub"
            claim, or names a user that is not in USERS_DB.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
        subject = claims.get("sub")
        if subject is None:
            raise unauthorized
        token_data = TokenData(username=subject, role=claims.get("role"))
    except JWTError:
        raise unauthorized

    # The returned profile (including role) comes from the user store, not the token.
    stored_user = get_user(username=token_data.username)
    if stored_user is None:
        raise unauthorized
    return User(**stored_user.dict())

async def require_admin_jwt(current_user: User = Depends(get_current_user)):
    """Let only JWT-authenticated users whose role is "admin" through.

    Returns the user unchanged; raises HTTPException(403) for any other role.
    """
    if current_user.role == "admin":
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin access required"
    )

# JWT endpoints
@app.post("/jwt/login", response_model=Token)
async def login(login_request: LoginRequest):
    """Exchange a username/password JSON body for a bearer JWT.

    The token carries "sub" (username) and "role" claims and expires after
    ACCESS_TOKEN_EXPIRE_MINUTES. Bad credentials yield HTTPException(401).
    """
    user = authenticate_user(login_request.username, login_request.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    claims = {"sub": user.username, "role": user.role}
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return {
        "access_token": create_access_token(data=claims, expires_delta=lifetime),
        "token_type": "bearer",
    }

@app.get("/jwt/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Echo back the profile of the user identified by the bearer token."""
    me = current_user
    return me

@app.get("/jwt/protected")
async def jwt_protected_endpoint(current_user: User = Depends(get_current_user)):
    """Sample resource readable by any holder of a valid JWT."""
    return dict(
        message="Access granted",
        user=current_user.username,
        role=current_user.role,
        data="Sensitive information",
    )

@app.get("/jwt/admin")
async def jwt_admin_endpoint(current_user: User = Depends(require_admin_jwt)):
    """Admin-only resource that lists every username in USERS_DB."""
    usernames = [*USERS_DB]
    return {
        "message": "Admin access granted",
        "admin": current_user.username,
        "all_users": usernames,
    }

print("JWT authentication endpoints added!")

## 3. Rate Limiting

Protect your API from abuse with rate limiting.

In [None]:
# Simple in-memory rate limiter
class RateLimiter:
    """Sliding-window-log rate limiter kept in process memory.

    For each identifier, the limiter remembers the timestamps of its accepted
    requests. A new request is accepted while fewer than ``max_requests`` of
    those timestamps fall inside the last ``window_seconds``.

    State lives in a plain dict with no lock and is not shared between processes.
    """

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # identifier -> ascending list of accepted-request timestamps (time.time()).
        self.requests = defaultdict(list)

    def is_allowed(self, identifier: str) -> tuple:
        """Check, and if allowed record, a request for ``identifier``.

        Returns:
            ``(True, info)`` when the request is accepted and recorded; ``info``
            has ``current_requests`` and ``remaining``.
            ``(False, info)`` when the limit is reached; the request is not
            recorded, and ``info["retry_after"]`` is a whole number of seconds,
            rounded up and at least 1.
        """
        now = time.time()
        window_start = now - self.window_seconds

        # Drop timestamps that have slid out of the window.
        recent = [t for t in self.requests[identifier] if t > window_start]
        self.requests[identifier] = recent

        current_requests = len(recent)

        if current_requests >= self.max_requests:
            if recent:
                # A slot frees up once the oldest request leaves the window.
                wait = recent[0] + self.window_seconds - now
            else:
                # max_requests <= 0: nothing can ever expire to make room,
                # so suggest waiting a full window instead of indexing [].
                wait = self.window_seconds
            # Round up. int() would truncate e.g. 0.4s to 0, telling clients
            # to retry immediately; Retry-After should never be below 1.
            retry_after = max(1, math.ceil(wait))

            return False, {
                "allowed": False,
                "current_requests": current_requests,
                "max_requests": self.max_requests,
                "window_seconds": self.window_seconds,
                "retry_after": retry_after
            }

        # Record the accepted request (recent is the stored list).
        recent.append(now)

        return True, {
            "allowed": True,
            "current_requests": current_requests + 1,
            "max_requests": self.max_requests,
            "remaining": self.max_requests - current_requests - 1
        }

# Create rate limiters
rate_limiter = RateLimiter(max_requests=10, window_seconds=60)  # 10 requests per minute
strict_limiter = RateLimiter(max_requests=3, window_seconds=10)  # 3 requests per 10 seconds

async def check_rate_limit(request: Request, limiter: RateLimiter = rate_limiter):
    """Apply ``limiter`` to the caller of ``request``.

    The endpoints below call this directly (it is not wired in via Depends).

    Returns:
        The limiter's info dict when the request is allowed.

    Raises:
        HTTPException(429): the caller is over the limit. The response carries
            a Retry-After header.
    """
    # Identify callers by client IP (in production, prefer user ID or API key).
    # request.client is None when the server does not report the peer address;
    # bucket such requests together instead of failing with AttributeError.
    peer = request.client
    identifier = peer.host if peer is not None else "unknown"

    allowed, info = limiter.is_allowed(identifier)

    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded. Try again in {info['retry_after']} seconds.",
            headers={"Retry-After": str(info["retry_after"])}
        )

    return info

# Rate limited endpoints
@app.get("/ratelimit/test")
async def rate_limited_endpoint(request: Request):
    """Demo endpoint guarded by the default limiter (10 requests per minute)."""
    return {
        "message": "Request successful",
        "rate_limit": await check_rate_limit(request, rate_limiter),
    }

@app.get("/ratelimit/strict")
async def strict_rate_limited(request: Request):
    """Demo endpoint guarded by the strict limiter (3 requests per 10 seconds)."""
    info = await check_rate_limit(request, strict_limiter)
    body = {"message": "Request successful", "rate_limit": info}
    body["note"] = "This endpoint has strict rate limiting"
    return body

print("Rate limiting endpoints added!")

## 4. CORS Configuration

Enable Cross-Origin Resource Sharing for browser-based clients.

In [None]:
# Browser origins allowed to call this API cross-origin (local dev servers).
# localhost and 127.0.0.1 are distinct origins to a browser, so list each one used.
origins = [
    "http://localhost",
    "http://localhost:3000",  # React default
    "http://localhost:8080",  # Vue default
    "http://127.0.0.1:3000",
]

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,  # allow cookies / Authorization headers cross-origin
    allow_origins=origins,  # explicit allow-list; never ["*"] together with credentials
    allow_methods=["*"],  # tighten in production, e.g. ["GET", "POST"]
    allow_headers=["*"],  # tighten in production, e.g. ["Authorization", "Content-Type"]
)

@app.get("/cors/test")
async def cors_test():
    """Report which browser origins the CORS middleware accepts."""
    return dict(
        message="CORS enabled",
        allowed_origins=origins,
        note="This endpoint can be accessed from allowed origins",
    )

print("CORS middleware configured!")

## 5. Input Validation and Sanitization

Protect against injection attacks and invalid input.

In [None]:
import re
import html

class SecureInput(BaseModel):
    """Submission payload demonstrating allow-list validation and HTML escaping."""
    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr
    message: str = Field(..., max_length=1000)
    url: Optional[str] = Field(None, max_length=500)

    @validator('username')
    def validate_username(cls, v):
        """Allow only ASCII letters, digits and underscores.

        re.fullmatch is used because re.match with a '^...$' pattern also
        accepts a single trailing newline ('$' matches just before it).
        A strict allow-list like this makes a separate deny-list of XSS
        fragments ('<script', 'javascript:', 'onerror=' ...) redundant,
        since none of their '<', ':' or '=' characters can pass.
        """
        if not re.fullmatch(r'[a-zA-Z0-9_]+', v):
            raise ValueError('Username must be alphanumeric (underscores allowed)')
        return v

    @validator('message')
    def sanitize_message(cls, v):
        """HTML-escape the message and trim surrounding whitespace.

        Tags are not removed. html.escape turns <, >, & and quotes into
        entities, so the text renders literally instead of as markup.
        The max_length constraint applies to the raw input, so the escaped
        result may be longer.
        """
        return html.escape(v).strip()

    @validator('url')
    def validate_url(cls, v):
        """Accept only http:// or https:// URLs; None passes through unchanged."""
        if v is None:
            return v

        if not v.startswith(('http://', 'https://')):
            raise ValueError('URL must start with http:// or https://')

        # Defence in depth: reject an embedded javascript: scheme anywhere in the URL.
        if 'javascript:' in v.lower():
            raise ValueError('Invalid URL protocol')

        return v

class SQLQueryInput(BaseModel):
    """Example of SQL query input validation.

    Identifiers are restricted by regex; `value` is screened with a crude
    deny-list. This is a teaching example only: real code must pass values as
    bound query parameters instead of relying on a deny-list.
    """
    table_name: str = Field(..., pattern=r'^[a-zA-Z_][a-zA-Z0-9_]*$')
    column_name: str = Field(..., pattern=r'^[a-zA-Z_][a-zA-Z0-9_]*$')
    value: str = Field(..., max_length=100)

    @validator('value')
    def prevent_sql_injection(cls, v):
        """Reject values containing quotes, comment markers, ';' or risky keywords.

        Matching is case-insensitive: the value is upper-cased and compared
        against upper-case patterns. Lower-case entries such as 'xp_' could
        never match an upper-cased string, so every pattern is spelled in
        upper case.
        """
        # In production, use parameterized queries!
        dangerous = ["'", '"', ';', '--', '/*', '*/', 'XP_', 'SP_', 'DROP', 'DELETE', 'INSERT']
        v_upper = v.upper()
        if any(pattern in v_upper for pattern in dangerous):
            raise ValueError('Invalid input - potential SQL injection detected')
        return v

@app.post("/secure/submit")
async def submit_secure_data(data: SecureInput):
    """Echo back a SecureInput payload after Pydantic has validated and sanitized it."""
    validated = data.dict()
    return {
        "message": "Data received and validated",
        "data": validated,
        "note": "All input has been validated and sanitized",
    }

@app.post("/secure/query")
async def secure_query(query: SQLQueryInput):
    """Echo back a validated SQLQueryInput; no query is actually executed."""
    # A real handler would now run a parameterized query (e.g. via SQLAlchemy).
    return dict(
        message="Query validated",
        query=query.dict(),
        note="In production, use parameterized queries with SQLAlchemy or similar",
    )

print("Input validation endpoints added!")

## 6. Running the App with TestClient

In [None]:
# Create TestClient
# TestClient calls the ASGI app in-process, so no server has to be started.
client = TestClient(app)

print("\n✅ Client ready for testing")
print("📚 The app would normally run at http://127.0.0.1:8000/docs in production")

## 7. Testing Authentication and Security

In [None]:
# Test 1: API Key Authentication
print("Test 1 - API Key Authentication:")

# No X-API-Key header: get_api_key rejects the request with 401
missing = client.get("/protected/data")
print(f"Without API key: {missing.status_code} - {missing.json()['detail']}")

# A registered user key is accepted
headers = {"X-API-Key": "test_key_123"}
response = client.get("/protected/data", headers=headers)
print(f"With valid API key: {response.status_code}")
print(f"Response: {response.json()}")
print()

In [None]:
# Test 2: Admin access with API key
print("Test 2 - Admin access:")

# A "user"-role key must be refused by the admin-only route (403)
headers = {"X-API-Key": "test_key_123"}
response = client.get("/admin/users", headers=headers)
denied = response.json()
print(f"User key on admin endpoint: {response.status_code} - {denied['detail']}")

# An "admin"-role key is let through
headers = {"X-API-Key": "admin_key_456"}
response = client.get("/admin/users", headers=headers)
granted = response.json()
print(f"Admin key on admin endpoint: {response.status_code}")
print(f"Response: {granted['message']}")
print()

In [None]:
# Test 3: JWT Login and Token Usage
print("Test 3 - JWT Authentication:")

# Exchange credentials for a bearer token
login_data = {"username": "john", "password": "secret123"}
token_data = client.post("/jwt/login", json=login_data).json()
access_token = token_data["access_token"]
print("Login successful, token received")
print(f"Token (first 50 chars): {access_token[:50]}...")

# Present the token as "Authorization: Bearer <token>"
headers = {"Authorization": f"Bearer {access_token}"}

profile = client.get("/jwt/me", headers=headers).json()
print(f"\nUser info: {profile}")

response = client.get("/jwt/protected", headers=headers)
print(f"Protected endpoint: {response.json()['message']}")
print()

In [None]:
# Test 4: JWT Admin Access
print("Test 4 - JWT Admin Access:")

# Regular user: obtain a token, then expect 403 from the admin route
login_data = {"username": "john", "password": "secret123"}
user_token = client.post("/jwt/login", json=login_data).json()["access_token"]
headers = {"Authorization": f"Bearer {user_token}"}
response = client.get("/jwt/admin", headers=headers)
print(f"User accessing admin endpoint: {response.status_code} - {response.json()['detail']}")

# Admin user: the same route succeeds
login_data = {"username": "alice", "password": "admin456"}
admin_token = client.post("/jwt/login", json=login_data).json()["access_token"]
headers = {"Authorization": f"Bearer {admin_token}"}
response = client.get("/jwt/admin", headers=headers)
print(f"Admin accessing admin endpoint: {response.status_code}")
print(f"Response: {response.json()['message']}")
print()

In [None]:
# Test 5: Rate Limiting
print("Test 5 - Rate Limiting:")

# The strict limiter allows 3 requests per 10 s, so 5 rapid calls are enough to hit it.
for attempt in range(1, 6):
    response = client.get("/ratelimit/strict")
    if response.status_code != 200:
        print(f"Request {attempt}: Rate limited! {response.json()['detail']}")
        break
    remaining = response.json()["rate_limit"]["remaining"]
    print(f"Request {attempt}: Success - {remaining} remaining")
print()

In [None]:
# Test 6: Input Validation
print("Test 6 - Input Validation:")

# Valid input: accepted, with the HTML in "message" returned escaped
valid_data = {
    "username": "john_doe",
    "email": "john@example.com",
    "message": "This is a <b>safe</b> message",
    "url": "https://example.com"
}
response = client.post("/secure/submit", json=valid_data)
print(f"Valid input: {response.status_code}")
print(f"Sanitized message: {response.json()['data']['message']}")
print()

# Malicious payloads: each is rejected by a SecureInput validator
rejection_cases = [
    ("Invalid username (XSS)", {
        "username": "user<script>alert('xss')</script>",
        "email": "test@example.com",
        "message": "Test"
    }),
    ("Invalid URL (JavaScript)", {
        "username": "testuser",
        "email": "test@example.com",
        "message": "Test",
        "url": "javascript:alert('xss')"
    }),
]
for label, invalid_data in rejection_cases:
    response = client.post("/secure/submit", json=invalid_data)
    print(f"{label}: {response.status_code}")
    print(f"Error: {response.json()['detail'][0]['msg']}")
    print()

In [None]:
# Test 7: SQL Injection Prevention
print("Test 7 - SQL Injection Prevention:")

# Same identifiers for both requests; only the value differs
base_query = {"table_name": "users", "column_name": "username"}

# A plain value passes the deny-list
valid_query = {**base_query, "value": "john"}
response = client.post("/secure/query", json=valid_query)
print(f"Valid query: {response.status_code} - Success")

# A classic quote-based tautology is rejected by prevent_sql_injection
injection_query = {**base_query, "value": "john' OR '1'='1"}
response = client.post("/secure/query", json=injection_query)
print(f"SQL injection attempt: {response.status_code}")
print(f"Error: {response.json()['detail'][0]['msg']}")
print()

## 8. Key Takeaways

### What we learned:

1. **API Key Authentication**:
   - Simple authentication using headers
   - Good for service-to-service communication
   - Use `APIKeyHeader` security scheme
   - Implement role-based access control
   - Store API keys securely: keep only hashes server-side (this demo stores them in plain text for simplicity)

2. **JWT Authentication**:
   - Stateless authentication with tokens
   - Tokens contain user info and expiration
   - Use `HTTPBearer` security scheme
   - Hash passwords with bcrypt
   - Set appropriate token expiration times

3. **Rate Limiting**:
   - Protect against abuse and DoS
   - Sliding-window log algorithm (one timestamp stored per accepted request)
   - Return 429 status code when limited
   - Include Retry-After header
   - Different limits for different endpoints

4. **CORS**:
   - Enable cross-origin requests from browsers
   - Specify allowed origins explicitly in production
   - Configure allowed methods and headers
   - Enable credentials if needed

5. **Input Validation**:
   - Always validate and sanitize user input
   - Use Pydantic validators
   - Escape HTML to prevent XSS
   - Validate URLs and prevent dangerous protocols
   - Use parameterized queries for SQL

### Security Best Practices:

1. **Never**:
   - Store passwords in plain text
   - Trust user input without validation
   - Use string concatenation for SQL queries
   - Expose sensitive data in error messages
   - Store secrets in code (use environment variables)

2. **Always**:
   - Use HTTPS in production
   - Hash passwords with bcrypt or argon2
   - Validate all input
   - Use parameterized queries
   - Set appropriate token expiration
   - Implement rate limiting
   - Log authentication failures
   - Keep dependencies updated

3. **Production Considerations**:
   - Store secrets in environment variables or secret manager
   - Use Redis or a database for rate limiting (in-memory state is per-process and lost on restart)
   - Implement request signing for APIs
   - Add monitoring and alerting
   - Regular security audits
   - Implement account lockout after failed attempts
   - Use refresh tokens for long-lived sessions

### Authentication Flow Comparison:

**API Keys:**
- Best for: Service-to-service, internal APIs
- Pros: Simple, stateless
- Cons: No built-in expiration; a leaked key stays valid until it is explicitly revoked

**JWT Tokens:**
- Best for: User authentication, mobile apps
- Pros: Stateless, contains user info, expiration
- Cons: Cannot be revoked before expiration without extra server-side state (e.g. a token denylist), larger size

**Sessions:**
- Best for: Traditional web apps
- Pros: Can revoke immediately, server controls state
- Cons: Requires server storage, not stateless

In [None]:
# Closing summary (emoji restored from mis-decoded UTF-8).
print("\n🎉 Congratulations! You've completed Authentication and Security!\n")
print("You now know how to:")
print("  ✓ Implement API key authentication")
print("  ✓ Use JWT tokens for user authentication")
print("  ✓ Add rate limiting to protect your API")
print("  ✓ Configure CORS for browser clients")
print("  ✓ Validate and sanitize user input")
print("\nYou're ready to build secure production APIs with FastAPI!")
print("\nIn production, run: uvicorn app:app --host 0.0.0.0 --port 8000")
print("\nTest credentials:")
print("  Username: john, Password: secret123 (user role)")
print("  Username: alice, Password: admin456 (admin role)")
print("  API Key: test_key_123 (user), admin_key_456 (admin)")