# Module 04: ML APIs with FastAPI

**Difficulty**: ⭐⭐ Intermediate
**Estimated Time**: 70 minutes
**Prerequisites**: 
- [Module 03: Model Serialization](03_model_serialization.ipynb)
- Basic understanding of REST APIs

## Learning Objectives
By the end of this notebook, you will be able to:
1. Create ML prediction APIs using FastAPI
2. Define request and response models with Pydantic
3. Implement input validation for ML endpoints
4. Handle errors gracefully in production APIs
5. Test API endpoints programmatically
6. Document APIs automatically with OpenAPI/Swagger

## 1. Why FastAPI for ML?

### The Challenge

You've trained a model. Now what?
- **How do web apps use your model?**
- **How do mobile apps get predictions?**
- **How do you validate inputs?**
- **How do you handle errors gracefully?**

### The Solution: REST APIs

**REST APIs** allow any application to request predictions over HTTP.

**Why FastAPI?**
1. **Fast**: High performance (based on Starlette and Pydantic)
2. **Easy**: Simple syntax, similar to Flask
3. **Automatic docs**: Interactive API documentation (Swagger UI)
4. **Type hints**: Python type hints for validation
5. **Async support**: Handle concurrent requests efficiently

### API Workflow

```
Client (Web/Mobile App)
    ↓ HTTP Request
FastAPI Server
    ↓ Validate input
    ↓ Use model (loaded once at startup)
    ↓ Make prediction
    ↓ HTTP Response
Client receives prediction
```
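
To make the first step of this workflow concrete, here is a minimal sketch of the client side using only the standard library. The address `http://localhost:8000` is an assumption (it depends on how the server is started), and `build_prediction_request` is a hypothetical helper name. The request is built but not sent, so the snippet runs without a server.

```python
import json
import urllib.request

# Hypothetical server address; adjust to wherever the API is actually served.
API_URL = "http://localhost:8000/predict"


def build_prediction_request(features: dict) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for the prediction API."""
    body = json.dumps(features).encode("utf-8")
    return urllib.request.Request(
        API_URL,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_prediction_request(
    {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}
)
print(request.get_method(), request.full_url)

# With a server running, the request could be sent like this:
# with urllib.request.urlopen(request) as response:
#     print(json.load(response))
```

Any HTTP client in any language can do the same thing, which is what makes a REST API a convenient boundary between the model and the applications that use it.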

In [None]:
# Setup: Import all required libraries
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from pathlib import Path

# Machine learning
from sklearn.datasets import load_iris, make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

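# FastAPI and Pydantic
# Note: `validator` (Section 9) and `Config.schema_extra` (Section 4) are Pydantic v1 APIs;
# Pydantic v2 deprecates them in favor of `field_validator` and `json_schema_extra`.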
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator
from typing import List, Optional
import uvicorn
from fastapi.testclient import TestClient

# Set random seed
np.random.seed(42)

# Configure plotting
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

print("Setup complete!")
print("FastAPI is ready for ML model serving")

## 2. Training and Saving a Model

First, let's train a model to serve via API.

In [None]:
# Load Iris dataset for classification
iris = load_iris()
X, y = iris.data, iris.target

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"Training set: {X_train.shape}")
print(f"Features: {iris.feature_names}")
print(f"Classes: {iris.target_names}")

In [None]:
# Create and train a pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
])

pipeline.fit(X_train, y_train)

# Evaluate
accuracy = pipeline.score(X_test, y_test)
print(f"\n✓ Model trained")
print(f"  Accuracy: {accuracy:.4f}")

# Save model
models_dir = Path("api_models")
models_dir.mkdir(exist_ok=True)
model_path = models_dir / "iris_classifier.joblib"
joblib.dump(pipeline, model_path)

print(f"✓ Model saved to {model_path}")

## 3. Creating Your First FastAPI App

Let's build a simple prediction API.

In [None]:
# Create FastAPI app
app = FastAPI(
    title="Iris Classification API",
    description="ML API for predicting iris flower species",
    version="1.0.0"
)

# Load model at startup
model = joblib.load(model_path)

@app.get("/")
def root():
    """Health check endpoint"""
    return {
        "message": "Iris Classification API",
        "status": "running",
        "model_loaded": model is not None
    }

print("✓ FastAPI app created")
print("  Available endpoints: / (health check)")

## 4. Defining Request/Response Models with Pydantic

**Pydantic** provides automatic validation using Python type hints.

In [None]:
# Define request model (input to API)
class IrisFeatures(BaseModel):
    sepal_length: float = Field(..., ge=0, le=10, description="Sepal length in cm")
    sepal_width: float = Field(..., ge=0, le=10, description="Sepal width in cm")
    petal_length: float = Field(..., ge=0, le=10, description="Petal length in cm")
    petal_width: float = Field(..., ge=0, le=10, description="Petal width in cm")
    
    class Config:
        schema_extra = {
            "example": {
                "sepal_length": 5.1,
                "sepal_width": 3.5,
                "petal_length": 1.4,
                "petal_width": 0.2
            }
        }

# Define response model (output from API)
class PredictionResponse(BaseModel):
    prediction: str = Field(..., description="Predicted iris species")
    confidence: float = Field(..., ge=0, le=1, description="Prediction confidence")
    probabilities: dict = Field(..., description="Class probabilities")

print("✓ Pydantic models defined")
print("  Request: IrisFeatures")
print("  Response: PredictionResponse")
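
The constraints declared with `Field` can be exercised without FastAPI at all. The sketch below uses a cut-down, hypothetical `Measurement` model with the same constraint style as `IrisFeatures` to show what Pydantic does with valid and invalid input.

```python
from pydantic import BaseModel, Field, ValidationError


class Measurement(BaseModel):
    # Same constraint style as IrisFeatures: required float in [0, 10]
    sepal_length: float = Field(..., ge=0, le=10)
    sepal_width: float = Field(..., ge=0, le=10)


# Valid input: numeric strings are coerced to float by default
ok = Measurement(sepal_length="5.1", sepal_width=3.5)
print(ok.sepal_length)

# Invalid input: one negative value and one missing field
try:
    Measurement(sepal_length=-1.0)
    failed_fields = []
except ValidationError as exc:
    # Each error records the location of the offending field
    failed_fields = sorted(err["loc"][0] for err in exc.errors())
    print(failed_fields)
```

Pydantic reports every failing field in one pass rather than stopping at the first problem, which is why API clients receive a list of errors.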

## 5. Creating Prediction Endpoint

Now let's add the actual prediction logic.

In [None]:
@app.post("/predict", response_model=PredictionResponse)
def predict(features: IrisFeatures):
    """
    Predict iris species from flower measurements.
    
    Returns the predicted class and confidence scores.
    """
    try:
        # Convert input to numpy array
        input_data = np.array([[
            features.sepal_length,
            features.sepal_width,
            features.petal_length,
            features.petal_width
        ]])
        
        # Make prediction
        prediction = model.predict(input_data)[0]
        probabilities = model.predict_proba(input_data)[0]
        
        # Map prediction to class name
        class_names = ["setosa", "versicolor", "virginica"]
        predicted_class = class_names[prediction]
        
        # Create probability dictionary
        prob_dict = {
            class_name: float(prob)
            for class_name, prob in zip(class_names, probabilities)
        }
        
        return PredictionResponse(
            prediction=predicted_class,
            confidence=float(probabilities.max()),
            probabilities=prob_dict
        )
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")

print("✓ Prediction endpoint created")
print("  POST /predict - Make predictions")

## 6. Testing the API

FastAPI provides a `TestClient` that calls the app in-process, so you can test endpoints without starting a server.

In [None]:
# Create test client
client = TestClient(app)

# Test health check endpoint
response = client.get("/")
print("Health Check:")
print(f"  Status: {response.status_code}")
print(f"  Response: {response.json()}")

In [None]:
# Test prediction endpoint with valid input
test_input = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
}

response = client.post("/predict", json=test_input)
print("\nPrediction Test:")
print(f"  Input: {test_input}")
print(f"  Status: {response.status_code}")
print(f"  Response: {response.json()}")

## 7. Input Validation

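Pydantic validates every request body against the model before the endpoint function runs. When validation fails, FastAPI responds with HTTP 422 and a `detail` list describing each problem. Let's test with invalid data.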

In [None]:
# Test with invalid input (negative value)
invalid_input = {
    "sepal_length": -1.0,  # Invalid: negative
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
}

response = client.post("/predict", json=invalid_input)
print("Invalid Input Test (negative value):")
print(f"  Status: {response.status_code}")
print(f"  Error: {response.json()['detail']}")

In [None]:
# Test with missing field
incomplete_input = {
    "sepal_length": 5.1,
    "sepal_width": 3.5
    # Missing petal_length and petal_width
}

response = client.post("/predict", json=incomplete_input)
print("\nIncomplete Input Test:")
print(f"  Status: {response.status_code}")
print(f"  Error: {str(response.json()['detail'])[:200]}...")  # detail is a list, so convert to str before truncating
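
The `detail` value in a 422 response is a list with one entry per failed field, and each entry carries a `loc` path and a `msg`. As an illustration, the hypothetical helper below flattens that list into short messages. The sample payload is hand-written for the example; the exact `msg` and `type` strings vary between Pydantic versions.

```python
def summarize_validation_errors(detail: list) -> list:
    """Turn a FastAPI 422 `detail` list into short 'field: message' strings."""
    summaries = []
    for error in detail:
        # `loc` is a path such as ["body", "petal_length"]; drop the "body" prefix
        path = [str(part) for part in error["loc"] if part != "body"]
        summaries.append(f"{'.'.join(path)}: {error['msg']}")
    return summaries


# Illustrative payload shaped like the response to the incomplete input above
sample_detail = [
    {"loc": ["body", "petal_length"], "msg": "field required", "type": "value_error.missing"},
    {"loc": ["body", "petal_width"], "msg": "field required", "type": "value_error.missing"},
]

summary = summarize_validation_errors(sample_detail)
for line in summary:
    print(line)
```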

## 8. Batch Predictions

Let's add an endpoint for batch predictions.

In [None]:
# Define batch request/response models
class BatchPredictionRequest(BaseModel):
    samples: List[IrisFeatures]

class BatchPredictionResponse(BaseModel):
    predictions: List[PredictionResponse]
    count: int

@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """
    Make predictions for multiple samples at once.
    """
    predictions = []
    
    for sample in request.samples:
        # Reuse the single prediction logic
        pred = predict(sample)
        predictions.append(pred)
    
    return BatchPredictionResponse(
        predictions=predictions,
        count=len(predictions)
    )

print("✓ Batch prediction endpoint created")
print("  POST /predict/batch - Predict multiple samples")

In [None]:
# Test batch prediction
batch_input = {
    "samples": [
        {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2},
        {"sepal_length": 6.2, "sepal_width": 2.9, "petal_length": 4.3, "petal_width": 1.3},
        {"sepal_length": 7.7, "sepal_width": 3.0, "petal_length": 6.1, "petal_width": 2.3}
    ]
}

response = client.post("/predict/batch", json=batch_input)
print("Batch Prediction Test:")
print(f"  Input samples: {len(batch_input['samples'])}")
print(f"  Status: {response.status_code}")

result = response.json()
print(f"\nPredictions:")
for i, pred in enumerate(result['predictions']):
    print(f"  Sample {i+1}: {pred['prediction']} (confidence: {pred['confidence']:.3f})")
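
The batch endpoint above calls the model once per sample. An alternative worth considering is to stack the samples into one array and call the model a single time, since scikit-learn estimators accept a whole matrix of rows. The sketch below assumes a stand-in model trained on the full Iris data; `batch_model`, `species`, and `predict_many` are hypothetical names, not part of the notebook's API.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Stand-in model for the example (the notebook's pipeline would work the same way)
iris_data = load_iris()
batch_model = RandomForestClassifier(n_estimators=10, random_state=42)
batch_model.fit(iris_data.data, iris_data.target)
species = [str(name) for name in iris_data.target_names]

FEATURE_ORDER = ["sepal_length", "sepal_width", "petal_length", "petal_width"]


def predict_many(samples: list) -> list:
    """Predict a list of feature dicts with a single model call."""
    # Stack all samples into one (n_samples, 4) array
    X = np.array([[sample[name] for name in FEATURE_ORDER] for sample in samples])
    probabilities = batch_model.predict_proba(X)  # one call for the whole batch
    results = []
    for row in probabilities:
        best = int(np.argmax(row))
        results.append({"prediction": species[best], "confidence": float(row[best])})
    return results


results = predict_many([
    {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2},
    {"sepal_length": 7.7, "sepal_width": 3.0, "petal_length": 6.1, "petal_width": 2.3},
])
for item in results:
    print(item)
```

For three samples the difference is negligible, but per-call overhead adds up for large batches, so the single-call form tends to scale better.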

## 9. Advanced Validation with Pydantic

Add custom validators for domain-specific rules.

In [None]:
# Create enhanced request model with custom validation
class EnhancedIrisFeatures(BaseModel):
    sepal_length: float = Field(..., ge=0, le=10)
    sepal_width: float = Field(..., ge=0, le=10)
    petal_length: float = Field(..., ge=0, le=10)
    petal_width: float = Field(..., ge=0, le=10)
    
    @validator('petal_width')
    def petal_width_must_be_less_than_length(cls, v, values):
        """Reject samples whose petal width exceeds petal length"""
        # `values` holds only fields declared earlier that passed validation,
        # so petal_length is absent here if it failed its own checks
        if 'petal_length' in values and v > values['petal_length']:
            raise ValueError('Petal width should not exceed petal length')
        return v
    
    @validator('sepal_width')
    def sepal_width_reasonable(cls, v, values):
        """Validate sepal width is reasonable compared to length"""
        if 'sepal_length' in values and v > values['sepal_length'] * 1.5:
            raise ValueError('Sepal width seems unreasonably large')
        return v

@app.post("/predict/validated", response_model=PredictionResponse)
def predict_validated(features: EnhancedIrisFeatures):
    """Prediction with enhanced validation"""
    # Convert to the standard request model and reuse the prediction logic
    standard_features = IrisFeatures(**features.dict())
    return predict(standard_features)

print("✓ Enhanced validation endpoint created")

In [None]:
# Test with biologically unrealistic data
unrealistic_input = {
    "sepal_length": 5.0,
    "sepal_width": 3.0,
    "petal_length": 2.0,
    "petal_width": 3.0  # Petal width > petal length (unusual)
}

response = client.post("/predict/validated", json=unrealistic_input)
print("Validation Test (unrealistic data):")
print(f"  Status: {response.status_code}")
if response.status_code == 422:
    print(f"  Validation caught the issue!")
    print(f"  Error: {response.json()['detail'][0]['msg']}")
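
The `@validator` decorator used above is the Pydantic v1 style. For readers on Pydantic v2, a roughly equivalent cross-field check might look like the sketch below, which assumes Pydantic v2 is installed. `PetalFeaturesV2` is a hypothetical, cut-down model used only for illustration.

```python
from pydantic import BaseModel, Field, ValidationError, field_validator


class PetalFeaturesV2(BaseModel):
    petal_length: float = Field(..., ge=0, le=10)
    petal_width: float = Field(..., ge=0, le=10)

    @field_validator("petal_width")
    @classmethod
    def petal_width_must_not_exceed_length(cls, v, info):
        # info.data holds the fields validated so far (in declaration order)
        petal_length = info.data.get("petal_length")
        if petal_length is not None and v > petal_length:
            raise ValueError("Petal width should not exceed petal length")
        return v


# A plausible sample passes
valid = PetalFeaturesV2(petal_length=2.0, petal_width=1.0)

# Petal width greater than petal length is rejected
try:
    PetalFeaturesV2(petal_length=2.0, petal_width=3.0)
    rejected = False
except ValidationError as exc:
    rejected = True
    print(exc.errors()[0]["msg"])
```

As in v1, the check relies on `petal_length` being declared before `petal_width`, because only earlier fields are available when a field validator runs.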

## 10. Error Handling

Proper error handling is crucial for production APIs.

In [None]:
from fastapi import status

@app.post("/predict/robust", response_model=PredictionResponse)
def predict_robust(features: IrisFeatures):
    """
    Prediction endpoint with comprehensive error handling.
    """
    try:
        # Validate model is loaded
        if model is None:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model not loaded. Service temporarily unavailable."
            )
        
        # Convert to array
        input_data = np.array([[
            features.sepal_length,
            features.sepal_width,
            features.petal_length,
            features.petal_width
        ]])
        
        # Check for NaN or Inf
        if not np.isfinite(input_data).all():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Input contains invalid values (NaN or Inf)"
            )
        
        # Make prediction
        prediction = model.predict(input_data)[0]
        probabilities = model.predict_proba(input_data)[0]
        
        class_names = ["setosa", "versicolor", "virginica"]
        predicted_class = class_names[prediction]
        
        prob_dict = {
            class_name: float(prob)
            for class_name, prob in zip(class_names, probabilities)
        }
        
        return PredictionResponse(
            prediction=predicted_class,
            confidence=float(probabilities.max()),
            probabilities=prob_dict
        )
    
    except HTTPException:
        raise  # Re-raise HTTP exceptions
    
    except Exception as e:
        # Log the error in production
        print(f"Unexpected error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error during prediction"
        )

print("✓ Robust prediction endpoint created")
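
The `print` call in the handler above stands in for real logging. A minimal, framework-free sketch of the same idea follows: unexpected failures are logged with a full traceback while the caller only sees a generic message. The logger name `iris_api` and the helper `run_safely` are assumptions made for this example, and mapping `ValueError` to a 400 is one possible convention rather than a rule.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("iris_api")  # hypothetical logger name


def run_safely(predict_fn, features):
    """Return (status_code, body), logging unexpected failures with a traceback."""
    try:
        return 200, {"prediction": predict_fn(features)}
    except ValueError as exc:
        # Treated here as a client-side problem, so the message is echoed back
        return 400, {"detail": str(exc)}
    except Exception:
        # logger.exception records the message plus traceback at ERROR level
        logger.exception("Unexpected error during prediction")
        return 500, {"detail": "Internal server error during prediction"}


def picky_model(features):
    raise ValueError("Input contains invalid values (NaN or Inf)")


def broken_model(features):
    raise RuntimeError("model file is corrupt")


ok = run_safely(lambda features: "setosa", [5.1, 3.5, 1.4, 0.2])
bad_input = run_safely(picky_model, [float("nan")] * 4)
crashed = run_safely(broken_model, [5.1, 3.5, 1.4, 0.2])
print(ok, bad_input, crashed, sep="\n")
```

Keeping internal error text out of the response avoids leaking implementation details, while the log still has what is needed to debug the failure.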

## 11. Exercises

Practice building ML APIs with FastAPI.

### Exercise 1: Add Model Info Endpoint

Create an endpoint that returns model metadata.

**Requirements**:
1. Create GET endpoint `/model/info`
2. Return model type, test-set accuracy, feature names, and class names
3. Test the endpoint

In [None]:
# Exercise 1: YOUR CODE HERE

In [None]:
# Exercise 1 Solution

class ModelInfo(BaseModel):
    model_type: str
    test_accuracy: float  # `accuracy` was measured on the held-out test split in Section 2
    features: List[str]
    classes: List[str]

@app.get("/model/info", response_model=ModelInfo)
def get_model_info():
    """Get information about the loaded model"""
    return ModelInfo(
        model_type="RandomForestClassifier",
        test_accuracy=accuracy,
        features=["sepal_length", "sepal_width", "petal_length", "petal_width"],
        classes=["setosa", "versicolor", "virginica"]
    )

# Test
response = client.get("/model/info")
print("✓ Model Info Endpoint:")
print(f"  Status: {response.status_code}")
print(f"  Response: {response.json()}")

### Exercise 2: Prediction with Explanation

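Enhance the prediction endpoint to include feature importance. Note that a random forest's `feature_importances_` describe the model as a whole, so the same ranking is returned for every input.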

**Requirements**:
1. Get feature importances from the model
2. Return top 3 most important features with values
3. Test with a sample input

In [None]:
# Exercise 2 Solution

class ExplainedPrediction(BaseModel):
    prediction: str
    confidence: float
    top_features: dict

@app.post("/predict/explain", response_model=ExplainedPrediction)
def predict_with_explanation(features: IrisFeatures):
    """Prediction with feature importance explanation"""
    # Make prediction
    input_data = np.array([[
        features.sepal_length, features.sepal_width,
        features.petal_length, features.petal_width
    ]])
    
    prediction = model.predict(input_data)[0]
    probabilities = model.predict_proba(input_data)[0]
    
    # Get feature importances
    classifier = model.named_steps['classifier']
    importances = classifier.feature_importances_
    feature_names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
    
    # Get top 3 features
    top_indices = np.argsort(importances)[-3:][::-1]
    top_features = {
        feature_names[i]: float(importances[i])
        for i in top_indices
    }
    
    class_names = ["setosa", "versicolor", "virginica"]
    
    return ExplainedPrediction(
        prediction=class_names[prediction],
        confidence=float(probabilities.max()),
        top_features=top_features
    )

# Test
test_input = {
    "sepal_length": 6.3,
    "sepal_width": 2.5,
    "petal_length": 5.0,
    "petal_width": 1.9
}

response = client.post("/predict/explain", json=test_input)
print("✓ Explained Prediction:")
result = response.json()
print(f"  Prediction: {result['prediction']}")
print(f"  Confidence: {result['confidence']:.3f}")
print(f"  Top Features: {result['top_features']}")

### Exercise 3: Performance Metrics Endpoint

Track API usage statistics.

**Requirements**:
1. Create a simple counter for predictions made
2. Add GET endpoint `/metrics` that returns total predictions
3. Increment counter in prediction endpoint

In [None]:
# Exercise 3 Solution

# Global counter (in production, use proper state management)
prediction_counter = {"total": 0}

class Metrics(BaseModel):
    total_predictions: int
    model_version: str

@app.get("/metrics", response_model=Metrics)
def get_metrics():
    """Get API usage metrics"""
    return Metrics(
        total_predictions=prediction_counter["total"],
        model_version="1.0.0"
    )

@app.post("/predict/tracked")
def predict_with_tracking(features: IrisFeatures):
    """Prediction with usage tracking"""
    # Increment counter
    prediction_counter["total"] += 1
    
    # Make prediction
    return predict(features)

# Test
# Make a few predictions
for _ in range(3):
    client.post("/predict/tracked", json=test_input)

# Check metrics
response = client.get("/metrics")
print("✓ API Metrics:")
print(f"  {response.json()}")
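
FastAPI runs plain `def` endpoints in a thread pool, so several requests can update a shared counter at the same time. One way to make the counter safe within a single process is to guard it with a lock, as in the sketch below. `PredictionCounter` is a hypothetical class, and the threads simulate concurrent requests.

```python
import threading


class PredictionCounter:
    """Counter that can be incremented safely from multiple threads."""

    def __init__(self):
        self._lock = threading.Lock()
        self._total = 0

    def increment(self):
        # The lock makes the read-modify-write step atomic
        with self._lock:
            self._total += 1

    @property
    def total(self):
        with self._lock:
            return self._total


counter = PredictionCounter()


def simulate_requests(n):
    for _ in range(n):
        counter.increment()


# Eight threads each record 1000 predictions
threads = [threading.Thread(target=simulate_requests, args=(1000,)) for _ in range(8)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(counter.total)  # 8000
```

This still counts per process only. With several server workers, each process keeps its own total, so shared metrics would need an external store or a metrics system.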

## 12. Summary

### Key Takeaways

1. **FastAPI makes ML APIs easy** with automatic validation and documentation

2. **Pydantic models** provide type-safe request/response validation

3. **Input validation** prevents bad data from reaching your model

4. **Error handling** ensures graceful failures in production

5. **TestClient** enables testing without running a server

6. **Batch endpoints** cut HTTP round trips when a client needs many predictions

### API Best Practices

- Always validate inputs with Pydantic
- Return structured responses (not just raw predictions)
- Include confidence scores with predictions
- Handle errors gracefully with proper HTTP status codes
- Add health check endpoints
- Document your API (FastAPI does this automatically)
- Version your APIs (/v1/predict, /v2/predict)
- Log predictions for monitoring

### What's Next?

In **Module 05**, we'll learn about:
- **Containerizing ML applications** with Docker
- **Creating Dockerfiles** for ML projects
- **Building and running** Docker images
- **Docker best practices** for ML

## 13. Additional Resources

### Documentation
- **FastAPI**: https://fastapi.tiangolo.com/
- **Pydantic**: https://docs.pydantic.dev/
- **Uvicorn**: https://www.uvicorn.org/

### Tutorials
- **FastAPI Tutorial**: https://fastapi.tiangolo.com/tutorial/
- **ML with FastAPI**: https://testdriven.io/blog/fastapi-machine-learning/

### Advanced Topics
- Async endpoints for better concurrency
- Authentication and API keys
- Rate limiting
- CORS configuration
- Deploying FastAPI to cloud platforms