# 7.1 Production Model Serving

## Learning Objectives
- Export trained models for production
- Deploy with Ray Serve for scalable inference
- Implement A/B testing and canary deployments
- Monitor model performance in production

In [None]:
import json
import time
from collections import Counter
from typing import Dict, List, Optional

import numpy as np
import ray
import requests
from ray import serve
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig, PPO

# Start (or attach to) the Ray runtime; ignore_reinit_error lets this cell be re-run
# in the same kernel without raising.
ray.init(ignore_reinit_error=True)

## Production RL Pipeline

```
┌─────────────┐    ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│   Train     │───▶│   Export    │───▶│   Deploy    │───▶│   Monitor   │
│   Model     │    │   Model     │    │   Serve     │    │   Metrics   │
└─────────────┘    └─────────────┘    └─────────────┘    └─────────────┘
                          │                   │                  │
                          ▼                   ▼                  ▼
                    Checkpoint           REST API           Dashboards
                    ONNX Export          gRPC              Alerting
                    TorchScript          Batching           A/B Tests
```

## Training and Exporting a Model

In [None]:
# Train a model
# NOTE(review): later cells use old-API-stack methods (algo.get_policy(),
# compute_single_action, compute_actions). On RLlib versions where the new API
# stack is the default, the old stack may need to be enabled explicitly —
# verify against the installed Ray version.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .env_runners(num_env_runners=2)
    .training(
        lr=3e-4,
        train_batch_size=4000,
    )
)

algo = config.build()

# Train for a few iterations.
# The mean-episode metric is reported as "episode_return_mean" on newer RLlib
# versions and "episode_reward_mean" on older ones, so read whichever is present
# instead of hard-coding one key (a missing key would raise KeyError).
for i in range(10):
    result = algo.train()
    runner_metrics = result.get("env_runners", {})
    mean_reward = runner_metrics.get(
        "episode_return_mean",
        runner_metrics.get("episode_reward_mean", float("nan")),
    )
    print(f"Iter {i+1}: Reward = {mean_reward:.2f}")

# Save checkpoint
checkpoint_path = algo.save()
print(f"\nModel saved to: {checkpoint_path}")

In [None]:
# Export policy for inference
policy = algo.get_policy()
model = policy.model

# For PyTorch models, you can export to TorchScript
# (this cell only inspects the model; it does not perform the export).
import torch

# Get a sample observation for tracing
import gymnasium as gym
env = gym.make("CartPole-v1")
sample_obs, _ = env.reset()
# Only the observation is needed from here on, so release the env's resources.
env.close()

print(f"Model type: {type(model)}")
print(f"Sample observation shape: {sample_obs.shape}")

## Ray Serve Deployment

In [None]:
@serve.deployment(
    num_replicas=2,
    ray_actor_options={"num_cpus": 1},
)
class RLPolicyServer:
    """
    Ray Serve deployment for RL policy inference.

    Each replica restores the algorithm from ``checkpoint_path`` and answers
    requests whose JSON body contains an ``"observation"`` field with a dict
    holding the chosen ``"action"`` and the server-side ``"latency_ms"``.

    The request/latency counters live on the replica instance, so with
    ``num_replicas=2`` each replica tracks only the requests it handled.
    """

    def __init__(self, checkpoint_path: str):
        # Load the trained algorithm
        self.algo = Algorithm.from_checkpoint(checkpoint_path)
        self.request_count = 0
        self.total_latency = 0

    async def __call__(self, request) -> Dict:
        """Handle one inference request.

        Args:
            request: incoming HTTP request; its JSON body must contain an
                ``"observation"`` key (presumably a flat list of floats, as
                sent by the test cell — TODO confirm for other envs).

        Returns:
            dict with a JSON-serializable ``"action"`` and ``"latency_ms"``.
        """
        # perf_counter is monotonic, so durations are not skewed by wall-clock
        # adjustments (NTP sync, manual changes) the way time.time() deltas are.
        start_time = time.perf_counter()

        # Parse request
        data = await request.json()
        observation = np.array(data["observation"])

        # Get action from policy
        action = self.algo.compute_single_action(observation)

        # Track metrics
        latency = time.perf_counter() - start_time
        self.request_count += 1
        self.total_latency += latency

        return {
            # np.asarray(...).tolist() turns numpy scalars, Python numbers and
            # arrays alike into plain Python values the HTTP layer can encode.
            "action": np.asarray(action).tolist(),
            "latency_ms": latency * 1000,
        }

    def get_metrics(self) -> Dict:
        """Return this replica's request count and mean latency (ms)."""
        # max(1, ...) avoids division by zero before the first request.
        avg_latency = self.total_latency / max(1, self.request_count)
        return {
            "request_count": self.request_count,
            "avg_latency_ms": avg_latency * 1000,
        }

In [None]:
# Deploy the server
serve.start()

# Create deployment with checkpoint path
deployment = RLPolicyServer.bind(checkpoint_path)
handle = serve.run(deployment, name="rl-policy")

print("Server deployed!")
print("Endpoint: http://localhost:8000/")

In [None]:
# Test the deployed server
import requests

# Send inference request. The timeout keeps the cell from hanging forever if
# the server is unreachable, and raise_for_status() surfaces HTTP errors (e.g. a
# 500 from a failing replica) instead of trying to decode an error body as JSON.
response = requests.post(
    "http://localhost:8000/",
    json={"observation": sample_obs.tolist()},
    timeout=30,
)
response.raise_for_status()

result = response.json()
print(f"Response: {result}")

## Batch Inference for High Throughput

In [None]:
@serve.deployment(
    num_replicas=2,
    max_ongoing_requests=100,
)
class BatchedPolicyServer:
    """
    Policy server with batched inference for high throughput.

    Concurrent requests are grouped by ``@serve.batch`` (up to 32, waiting at
    most 10 ms) and answered with a single ``compute_actions`` call.
    """

    def __init__(self, checkpoint_path: str):
        self.algo = Algorithm.from_checkpoint(checkpoint_path)

    @serve.batch(max_batch_size=32, batch_wait_timeout_s=0.01)
    async def batched_inference(self, observations: List[np.ndarray]) -> List[Dict]:
        """Process multiple observations in a batch.

        Args:
            observations: one observation per pending request, in request order.

        Returns:
            One result dict per observation, in the same order (``@serve.batch``
            maps the i-th result back to the i-th caller).
        """
        start_time = time.perf_counter()

        # Algorithm.compute_actions takes a dict mapping an id to ONE observation
        # each and returns a dict with the same keys. Passing the stacked batch
        # under a single key would be treated as one (mis-shaped) observation, so
        # key each observation by its position in the batch instead.
        obs_by_index = dict(enumerate(observations))
        actions_by_index = self.algo.compute_actions(obs_by_index)

        latency = time.perf_counter() - start_time
        batch_size = len(observations)

        # Return individual results, in input order.
        return [
            {
                "action": np.asarray(actions_by_index[i]).tolist(),
                "batch_size": batch_size,
                "latency_ms": latency * 1000,
            }
            for i in range(batch_size)
        ]

    async def __call__(self, request) -> Dict:
        """Parse one HTTP request and await its slot in the next batch."""
        data = await request.json()
        observation = np.array(data["observation"])
        return await self.batched_inference(observation)

print("Batched server class defined")

## A/B Testing and Canary Deployments

In [None]:
@serve.deployment
class PolicyRouter:
    """
    Route requests between multiple policy versions for A/B testing.

    Each request is forwarded to ``policy_a`` with probability
    ``traffic_split`` and to ``policy_b`` otherwise, and the response is tagged
    with ``"policy_version"``. ``policy_a`` / ``policy_b`` are expected to be
    Ray Serve deployment handles (bound deployments passed into
    ``PolicyRouter.bind``) — NOTE(review): confirm against the composing code,
    which is not shown here. Request counters live on the replica instance.
    """

    def __init__(self, policy_a, policy_b, traffic_split: float = 0.9):
        # np.random.random() is in [0, 1), so only splits in [0, 1] are meaningful.
        if not 0.0 <= traffic_split <= 1.0:
            raise ValueError(
                f"traffic_split must be between 0 and 1, got {traffic_split}"
            )
        self.policy_a = policy_a  # Production policy (traffic_split share; 90% by default)
        self.policy_b = policy_b  # Canary policy (remaining share; 10% by default)
        self.traffic_split = traffic_split

        # Metrics
        self.requests_a = 0
        self.requests_b = 0

    async def __call__(self, request) -> Dict:
        """Forward ``request`` to one policy version and tag the result."""
        # Route based on traffic split
        if np.random.random() < self.traffic_split:
            self.requests_a += 1
            handle, version = self.policy_a, "A"
        else:
            self.requests_b += 1
            handle, version = self.policy_b, "B"

        # Deployment handles are invoked via .remote(), which returns an
        # awaitable response; they are not called like plain functions.
        result = await handle.remote(request)
        result["policy_version"] = version
        return result

    def get_traffic_stats(self) -> Dict:
        """Return per-version request counts and percentages for this replica."""
        total = self.requests_a + self.requests_b
        return {
            "policy_a_requests": self.requests_a,
            "policy_b_requests": self.requests_b,
            "policy_a_percentage": self.requests_a / max(1, total) * 100,
            "policy_b_percentage": self.requests_b / max(1, total) * 100,
        }

print("A/B testing router defined")

## Model Versioning and Registry

In [None]:
import os
from datetime import datetime

class ModelRegistry:
    """
    Simple model registry for versioning RL policies.

    Metadata for every registered version is stored as JSON in
    ``<registry_path>/metadata.json``; checkpoints are referenced by path,
    not copied. Each model name has at most one ``"production"`` version.
    """

    def __init__(self, registry_path: str = "/tmp/model_registry"):
        self.registry_path = registry_path
        os.makedirs(registry_path, exist_ok=True)
        self.metadata_file = os.path.join(registry_path, "metadata.json")
        self._load_metadata()

    def _load_metadata(self):
        """Load metadata from disk, or start an empty registry if none exists."""
        if os.path.exists(self.metadata_file):
            with open(self.metadata_file, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)
        else:
            self.metadata = {"models": []}

    def _save_metadata(self):
        """Persist metadata via a temp file renamed over the real one."""
        tmp_file = self.metadata_file + ".tmp"
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, indent=2)
        # os.replace is atomic on POSIX, so an interrupted write leaves the
        # previous metadata.json intact instead of a truncated file.
        os.replace(tmp_file, self.metadata_file)

    def register_model(
        self,
        checkpoint_path: str,
        model_name: str,
        metrics: Dict,
        tags: Optional[List[str]] = None
    ) -> str:
        """Register a new model version in ``"staging"`` status.

        Returns:
            The new version id, ``"<model_name>_v<N>"`` where N counts the
            versions already registered under ``model_name`` plus one.
        """
        version = len([m for m in self.metadata["models"] if m["name"] == model_name]) + 1
        version_id = f"{model_name}_v{version}"

        model_info = {
            "version_id": version_id,
            "name": model_name,
            "version": version,
            "checkpoint_path": checkpoint_path,
            "metrics": metrics,
            "tags": tags or [],
            "registered_at": datetime.now().isoformat(),
            "status": "staging",
        }

        self.metadata["models"].append(model_info)
        self._save_metadata()

        print(f"Registered model: {version_id}")
        return version_id

    def promote_to_production(self, version_id: str):
        """Promote a model version to production status.

        Any other version of the same model currently in production is marked
        ``"archived"``, so re-promoting an older version (a rollback) works.

        Raises:
            KeyError: if ``version_id`` is not registered.
        """
        target = next(
            (m for m in self.metadata["models"] if m["version_id"] == version_id),
            None,
        )
        if target is None:
            raise KeyError(f"Unknown model version: {version_id}")

        for model in self.metadata["models"]:
            if (
                model is not target
                and model["name"] == target["name"]
                and model["status"] == "production"
            ):
                model["status"] = "archived"
        target["status"] = "production"
        target["promoted_at"] = datetime.now().isoformat()
        self._save_metadata()
        print(f"Promoted {version_id} to production")

    def get_production_model(self, model_name: str) -> Optional[Dict]:
        """Return the current production entry for ``model_name``, or None.

        If older metadata contains several production entries, the most recently
        promoted one wins (ties go to the later-registered entry).
        """
        candidates = [
            m for m in self.metadata["models"]
            if m["name"] == model_name and m["status"] == "production"
        ]
        if not candidates:
            return None
        # ISO-8601 timestamps from datetime.now().isoformat() sort chronologically.
        return max(reversed(candidates), key=lambda m: m.get("promoted_at", ""))

    def list_models(self, model_name: Optional[str] = None) -> List[Dict]:
        """List registered models (optionally for one name) as a new list."""
        if model_name:
            return [m for m in self.metadata["models"] if m["name"] == model_name]
        return list(self.metadata["models"])

# Example usage
# NOTE: the metrics below are hard-coded illustrative values, not measured from
# the training run above; record real evaluation results in practice.
registry = ModelRegistry()
version_id = registry.register_model(
    checkpoint_path=checkpoint_path,
    model_name="cartpole_ppo",
    metrics={"mean_reward": 450.0, "episodes": 1000},
    tags=["ppo", "cartpole", "v1"]
)

print(f"\nRegistered models: {registry.list_models()}")

## Production Monitoring

In [None]:
from collections import deque
import statistics

class PolicyMonitor:
    """
    Monitor RL policy performance in production.

    Keeps the most recent ``window_size`` requests in bounded deques and
    derives latency percentiles, error rate, mean reward and the action
    distribution from that sliding window.
    """

    def __init__(self, window_size: int = 1000):
        self.window_size = window_size
        self.latencies = deque(maxlen=window_size)
        self.rewards = deque(maxlen=window_size)
        self.actions = deque(maxlen=window_size)
        self.errors = deque(maxlen=window_size)

        # Alerting thresholds
        self.latency_threshold_ms = 100
        self.error_rate_threshold = 0.01

    def record_request(self, latency_ms: float, action: int, reward: Optional[float] = None, error: bool = False):
        """Record a single request; ``reward`` is only stored when provided."""
        self.latencies.append(latency_ms)
        self.actions.append(action)
        self.errors.append(error)
        if reward is not None:
            self.rewards.append(reward)

    def get_metrics(self) -> Dict:
        """Get current monitoring metrics over the sliding window.

        Returns ``{"status": "no_data"}`` before any request was recorded.
        """
        if not self.latencies:
            return {"status": "no_data"}

        n = len(self.latencies)
        ordered = sorted(self.latencies)
        # Nearest-rank p99: the ceil(0.99 * n)-th smallest value (1-based).
        # Integer arithmetic avoids float rounding; for n <= 100 this is the max.
        p99_rank = -(-99 * n // 100)

        metrics = {
            "request_count": n,
            "latency_p50_ms": statistics.median(ordered),
            "latency_p99_ms": ordered[p99_rank - 1],
            "latency_mean_ms": statistics.mean(self.latencies),
            "error_rate": sum(self.errors) / len(self.errors),
        }

        if self.rewards:
            metrics["mean_reward"] = statistics.mean(self.rewards)

        # Action distribution. numpy scalar actions (e.g. np.int64 from
        # np.random.choice) are converted to plain Python values, because
        # json.dumps rejects numpy integers as dict keys.
        action_counts = Counter(
            a.item() if isinstance(a, np.generic) else a for a in self.actions
        )
        metrics["action_distribution"] = dict(action_counts)

        return metrics

    def check_alerts(self) -> List[str]:
        """Return alert messages for p99 latency / error rate above thresholds."""
        alerts = []
        metrics = self.get_metrics()

        # .get(..., 0) makes the "no_data" case produce no alerts.
        if metrics.get("latency_p99_ms", 0) > self.latency_threshold_ms:
            alerts.append(f"HIGH_LATENCY: p99 latency {metrics['latency_p99_ms']:.1f}ms > {self.latency_threshold_ms}ms")

        if metrics.get("error_rate", 0) > self.error_rate_threshold:
            alerts.append(f"HIGH_ERROR_RATE: {metrics['error_rate']:.2%} > {self.error_rate_threshold:.2%}")

        return alerts

# Example usage
monitor = PolicyMonitor()

# Simulate some requests. numpy samplers return numpy scalars; cast them to
# plain Python types so the recorded actions (used as dict keys in the action
# distribution) can be serialized by json.dumps below — np.int64 keys raise
# TypeError.
for _ in range(100):
    monitor.record_request(
        latency_ms=float(np.random.exponential(5)),
        action=int(np.random.choice([0, 1])),
        reward=float(np.random.uniform(-1, 1)),
        error=bool(np.random.random() < 0.005),
    )

print("Monitoring metrics:")
print(json.dumps(monitor.get_metrics(), indent=2))
print(f"\nAlerts: {monitor.check_alerts()}")

## Clean Up

In [None]:
# Stop serve and cleanup
serve.shutdown()
algo.stop()
ray.shutdown()

print("Cleanup complete")

## Key Takeaways

1. **Ray Serve** provides scalable, production-ready model serving

2. **Batch inference** (`@serve.batch`) groups concurrent requests into a single policy call, which improves throughput

3. **A/B testing** and canary deployments reduce risk

4. **Model registry** tracks versions and promotes to production

5. **Monitoring** is essential for production RL systems

## Next Steps

In the final section, we'll cover industry patterns and best practices.