# 8.1 Industry Best Practices for RL Systems

## Learning Objectives
- Design robust RL pipelines for production
- Implement safety and reliability patterns
- Handle real-world challenges (distribution shift, catastrophic forgetting)
- Build maintainable RL systems

## Production RL Architecture

```
┌────────────────────────────────────────────────────────────────────────┐
│                         Production RL System                            │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐               │
│  │   Data      │───▶│  Training   │───▶│  Validation │               │
│  │  Pipeline   │    │  Pipeline   │    │  Pipeline   │               │
│  └─────────────┘    └─────────────┘    └─────────────┘               │
│         │                  │                  │                       │
│         ▼                  ▼                  ▼                       │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐               │
│  │  Feature    │    │   Model     │    │   Safety    │               │
│  │   Store     │    │  Registry   │    │   Checks    │               │
│  └─────────────┘    └─────────────┘    └─────────────┘               │
│                            │                  │                       │
│                            ▼                  ▼                       │
│                     ┌─────────────────────────────┐                  │
│                     │    Serving Infrastructure   │                  │
│                     │    (A/B Testing, Canary)    │                  │
│                     └─────────────────────────────┘                  │
│                                   │                                   │
│                                   ▼                                   │
│                     ┌─────────────────────────────┐                  │
│                     │   Monitoring & Alerting     │                  │
│                     │   (Drift, Performance)      │                  │
│                     └─────────────────────────────┘                  │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
from typing import Dict, List, Optional, Callable
from dataclasses import dataclass
from enum import Enum
import json
from datetime import datetime

## 1. Reward Engineering Best Practices

In [None]:
class RewardShaper:
    """
    Best practices for reward shaping in production RL.
    """
    
    def __init__(self):
        self.reward_history = []
        self.component_history = {}
    
    def compute_shaped_reward(
        self,
        raw_reward: float,
        components: Dict[str, float],
        weights: Dict[str, float],
        clip_range: tuple = (-10, 10),
        normalize: bool = True
    ) -> float:
        """
        Compute shaped reward from multiple components.
        
        Best practices:
        1. Use multiple reward components for interpretability
        2. Clip rewards to prevent extreme values
        3. Normalize for stable training
        4. Log components for debugging
        """
        # Combine weighted components
        shaped_reward = raw_reward
        for name, value in components.items():
            weight = weights.get(name, 1.0)
            shaped_reward += weight * value
            
            # Track component history
            if name not in self.component_history:
                self.component_history[name] = []
            self.component_history[name].append(value)
        
        # Clip extreme values
        shaped_reward = float(np.clip(shaped_reward, clip_range[0], clip_range[1]))
        
        # Record the clipped, pre-normalization reward so the running
        # statistics below are never computed from already-normalized values
        self.reward_history.append(shaped_reward)
        
        # Optional normalization against the most recent 1000 rewards
        if normalize and len(self.reward_history) > 100:
            recent = self.reward_history[-1000:]
            shaped_reward = (shaped_reward - np.mean(recent)) / (np.std(recent) + 1e-8)
        
        return shaped_reward
    
    def get_diagnostics(self) -> Dict:
        """Get reward diagnostics for debugging."""
        return {
            "total_reward_mean": np.mean(self.reward_history[-100:]) if self.reward_history else 0,
            "total_reward_std": np.std(self.reward_history[-100:]) if self.reward_history else 0,
            "component_means": {
                name: np.mean(vals[-100:]) 
                for name, vals in self.component_history.items()
            }
        }

# Example: Trading reward shaping
shaper = RewardShaper()

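# Simulate reward computation. Normalization only starts once more than 100
# rewards have been recorded, so run 200 steps to exercise that path.
for _ in range(200):
    # Raw P&L; at this scale most values saturate the default clip range (-10, 10)
    raw_pnl = np.random.normal(0, 100)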
    components = {
        "sharpe_bonus": np.random.uniform(0, 1),     # Risk-adjusted return
        "drawdown_penalty": -np.random.uniform(0, 0.5),  # Drawdown penalty
        "turnover_cost": -np.random.uniform(0, 0.1),     # Transaction costs
    }
    weights = {"sharpe_bonus": 0.5, "drawdown_penalty": 1.0, "turnover_cost": 2.0}
    
    shaped = shaper.compute_shaped_reward(raw_pnl, components, weights)

print("Reward diagnostics:")
print(json.dumps(shaper.get_diagnostics(), indent=2))
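
Recomputing the mean and standard deviation over a list slice on every step costs time proportional to the window size. As a rough sketch of an alternative, the hypothetical `RunningRewardNormalizer` below keeps running statistics with Welford's algorithm in constant time per update. Unlike the 1000-reward window used in `RewardShaper`, it averages over the full history, so it reacts more slowly when the reward scale changes.

```python
import math


class RunningRewardNormalizer:
    """Track mean and variance incrementally with Welford's algorithm."""

    def __init__(self, epsilon: float = 1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # Sum of squared deviations from the running mean
        self.epsilon = epsilon

    def update(self, reward: float) -> None:
        """Fold one new reward into the running statistics."""
        self.count += 1
        delta = reward - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (reward - self.mean)

    @property
    def std(self) -> float:
        """Population standard deviation of the rewards seen so far."""
        if self.count < 2:
            return 1.0  # Not enough data yet; leave rewards unscaled
        return math.sqrt(self.m2 / self.count)

    def normalize(self, reward: float) -> float:
        return (reward - self.mean) / (self.std + self.epsilon)


normalizer = RunningRewardNormalizer()
for r in [1.0, 2.0, 3.0, 4.0]:
    normalizer.update(r)

print(normalizer.mean)            # 2.5
print(normalizer.normalize(2.5))  # 0.0
```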

## 2. Safety Constraints and Action Filtering

In [None]:
class SafetyLayer:
    """
    Safety layer for production RL policies.
    
    Implements:
    1. Action masking (disallow unsafe actions)
    2. Action clipping (bound action magnitude)
    3. Fallback policies (when primary fails)
    4. Constraint monitoring
    """
    
    def __init__(
        self,
        action_bounds: Optional[tuple] = None,
        constraint_functions: Optional[List[Callable]] = None,
        fallback_policy: Optional[Callable] = None
    ):
        self.action_bounds = action_bounds
        self.constraint_functions = constraint_functions or []
        self.fallback_policy = fallback_policy
        
        # Tracking
        self.violations = []
        self.fallback_count = 0
    
    def is_action_safe(self, action, state) -> tuple:
        """Check if action satisfies all constraints."""
        for constraint_fn in self.constraint_functions:
            is_safe, reason = constraint_fn(action, state)
            if not is_safe:
                return False, reason
        return True, None
    
    def filter_action(self, action, state) -> np.ndarray:
        """
        Filter action through safety layer.
        
        Returns safe action, potentially modified or from fallback.
        """
        original_action = action.copy() if hasattr(action, 'copy') else action
        
        # 1. Clip to bounds
        if self.action_bounds:
            action = np.clip(action, self.action_bounds[0], self.action_bounds[1])
        
        # 2. Check constraints
        is_safe, violation_reason = self.is_action_safe(action, state)
        
        if not is_safe:
            self.violations.append({
                "timestamp": datetime.now().isoformat(),
                "original_action": str(original_action),
                "reason": violation_reason,
            })
            
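            # 3. Use fallback if available; otherwise refuse to return an
            #    action that is known to violate a constraint
            if self.fallback_policy is None:
                raise RuntimeError(
                    f"Unsafe action and no fallback policy configured: {violation_reason}"
                )
            self.fallback_count += 1
            action = self.fallback_policy(state)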
        
        return action
    
    def get_safety_stats(self) -> Dict:
        """Get safety statistics."""
        return {
            "total_violations": len(self.violations),
            "fallback_count": self.fallback_count,
            "recent_violations": self.violations[-10:],
        }

# Example: Trading safety constraints
def max_position_constraint(action, state):
    """Ensure position doesn't exceed maximum."""
    current_position = state.get("position", 0)
    new_position = current_position + action
    max_position = 100
    
    if abs(new_position) > max_position:
        return False, f"Position {new_position} exceeds max {max_position}"
    return True, None

def risk_limit_constraint(action, state):
    """Ensure risk limits are maintained."""
    current_risk = state.get("var", 0)
    max_risk = 10000
    
    if current_risk > max_risk:
        return False, f"Risk {current_risk} exceeds limit {max_risk}"
    return True, None

# Create safety layer
safety = SafetyLayer(
    action_bounds=(-10, 10),
    constraint_functions=[max_position_constraint, risk_limit_constraint],
    fallback_policy=lambda s: np.array([0])  # Do nothing fallback
)

# Test safety layer
state = {"position": 95, "var": 5000}
risky_action = np.array([10])  # Would exceed position limit

safe_action = safety.filter_action(risky_action, state)
print(f"Original: {risky_action}, Safe: {safe_action}")
print(f"Safety stats: {safety.get_safety_stats()}")
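
The `SafetyLayer` docstring lists action masking, but the example above only exercises clipping and fallback. For discrete action spaces, masking can be sketched as follows: scores of disallowed actions are set to negative infinity before the arg-max, so they can never be selected. The function name, the action encoding, and the mask are illustrative assumptions, not part of any library API.

```python
import numpy as np


def masked_greedy_action(logits: np.ndarray, allowed: np.ndarray) -> int:
    """Pick the highest-scoring action among those marked as allowed."""
    if not allowed.any():
        raise ValueError("No action is allowed in this state")
    # Disallowed actions get -inf so argmax can never choose them
    masked = np.where(allowed, logits, -np.inf)
    return int(np.argmax(masked))


# Hypothetical discrete trading actions: 0 = sell, 1 = hold, 2 = buy
logits = np.array([0.1, 0.5, 2.0])
allowed = np.array([True, True, False])  # Buying would breach a position limit

action = masked_greedy_action(logits, allowed)
print(action)  # 1
```

Masking before selection differs from the fallback approach: the policy still chooses its best permitted action instead of being replaced wholesale.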

## 3. Distribution Shift Detection

In [None]:
from scipy import stats

class DistributionShiftDetector:
    """
    Detect distribution shift in observations.
    
    Critical for production RL where environment may change.
    """
    
    def __init__(self, reference_size: int = 10000, window_size: int = 1000):
        self.reference_size = reference_size
        self.window_size = window_size
        self.reference_data = []
        self.current_window = []
        self.reference_stats = None
        self.shift_history = []
    
    def update_reference(self, observations: np.ndarray):
        """Update reference distribution from training data."""
        self.reference_data = observations[-self.reference_size:].tolist()
        self._compute_reference_stats()
    
    def _compute_reference_stats(self):
        """Compute statistics of reference distribution."""
        data = np.array(self.reference_data)
        self.reference_stats = {
            "mean": np.mean(data, axis=0),
            "std": np.std(data, axis=0),
            "percentiles": {
                "5": np.percentile(data, 5, axis=0),
                "95": np.percentile(data, 95, axis=0),
            }
        }
    
    def add_observation(self, obs: np.ndarray):
        """Add new observation to current window."""
        self.current_window.append(obs)
        if len(self.current_window) > self.window_size:
            self.current_window.pop(0)
    
    def check_shift(self, threshold: float = 0.05) -> Dict:
        """
        Check for distribution shift using statistical tests.
        
        Returns shift metrics and whether retraining is recommended.
        """
        if len(self.current_window) < self.window_size // 2:
            return {"status": "insufficient_data"}
        
        if self.reference_stats is None:
            return {"status": "no_reference"}
        
        current_data = np.array(self.current_window)
        reference_data = np.array(self.reference_data[:len(self.current_window)])
        
        # Kolmogorov-Smirnov test per dimension
        ks_results = []
        for dim in range(current_data.shape[1] if len(current_data.shape) > 1 else 1):
            curr = current_data[:, dim] if len(current_data.shape) > 1 else current_data
            ref = reference_data[:, dim] if len(reference_data.shape) > 1 else reference_data
            
            ks_stat, p_value = stats.ks_2samp(curr, ref)
            ks_results.append({"statistic": ks_stat, "p_value": p_value})
        
        # Check mean shift
        current_mean = np.mean(current_data, axis=0)
        mean_shift = np.abs(current_mean - self.reference_stats["mean"]) / (self.reference_stats["std"] + 1e-8)
        
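        # Determine if shift detected. Bonferroni-correct the KS threshold so
        # that testing several dimensions does not inflate the false-alarm rate.
        corrected_threshold = threshold / len(ks_results)
        shift_detected = bool(
            any(r["p_value"] < corrected_threshold for r in ks_results)
            or np.any(mean_shift > 3)
        )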
        
        result = {
            "shift_detected": shift_detected,
            "ks_tests": ks_results,
            "mean_shift_zscore": mean_shift.tolist() if hasattr(mean_shift, 'tolist') else mean_shift,
            "recommendation": "retrain" if shift_detected else "continue",
        }
        
        self.shift_history.append({
            "timestamp": datetime.now().isoformat(),
            **result
        })
        
        return result

# Example usage
detector = DistributionShiftDetector(reference_size=500, window_size=100)

# Simulate reference data (training distribution)
reference_obs = np.random.normal(0, 1, (500, 4))
detector.update_reference(reference_obs)

# Simulate production data (same distribution)
for _ in range(100):
    obs = np.random.normal(0, 1, 4)
    detector.add_observation(obs)

print("No shift expected:")
print(json.dumps(detector.check_shift(), indent=2, default=str))

# Simulate distribution shift
for _ in range(100):
    obs = np.random.normal(2, 1.5, 4)  # Shifted distribution
    detector.add_observation(obs)

print("\nShift expected:")
print(json.dumps(detector.check_shift(), indent=2, default=str))
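
When a KS test is run on every observation dimension, the chance that at least one of them fires by accident grows with the number of dimensions. The sketch below quantifies that effect and shows the Bonferroni correction, one common remedy. Both helper names are illustrative, and the calculation assumes the dimensions are independent.

```python
def familywise_false_alarm_rate(alpha: float, num_tests: int) -> float:
    """Probability of at least one false alarm across independent tests."""
    return 1.0 - (1.0 - alpha) ** num_tests


def bonferroni_threshold(alpha: float, num_tests: int) -> float:
    """Per-test threshold that keeps the family-wise rate at or below alpha."""
    return alpha / num_tests


# Four observation dimensions, each tested at the 0.05 level
uncorrected = familywise_false_alarm_rate(0.05, 4)
corrected = familywise_false_alarm_rate(bonferroni_threshold(0.05, 4), 4)

print(f"uncorrected: {uncorrected:.3f}")  # uncorrected: 0.185
print(f"corrected:   {corrected:.3f}")    # corrected:   0.049
```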

## 4. Experiment Tracking and Reproducibility

In [None]:
@dataclass
class ExperimentConfig:
    """Complete experiment configuration for reproducibility."""
    
    # Environment
    env_name: str
    env_config: Dict
    
    # Algorithm
    algorithm: str
    algorithm_config: Dict
    
    # Training
    num_iterations: int
    checkpoint_frequency: int
    
    # Reproducibility
    seed: int
    framework: str  # torch or tf
    
    # Metadata
    experiment_name: str
    description: str
    tags: List[str]
    
    def to_dict(self) -> Dict:
        return {
            "env_name": self.env_name,
            "env_config": self.env_config,
            "algorithm": self.algorithm,
            "algorithm_config": self.algorithm_config,
            "num_iterations": self.num_iterations,
            "checkpoint_frequency": self.checkpoint_frequency,
            "seed": self.seed,
            "framework": self.framework,
            "experiment_name": self.experiment_name,
            "description": self.description,
            "tags": self.tags,
        }
    
    def save(self, path: str):
        with open(path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)
    
    @classmethod
    def load(cls, path: str) -> 'ExperimentConfig':
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)

# Example experiment config
config = ExperimentConfig(
    env_name="CartPole-v1",
    env_config={},
    algorithm="PPO",
    algorithm_config={
        "lr": 3e-4,
        "gamma": 0.99,
        "train_batch_size": 4000,
        "num_sgd_iter": 10,
        "clip_param": 0.2,
    },
    num_iterations=100,
    checkpoint_frequency=10,
    seed=42,
    framework="torch",
    experiment_name="cartpole_ppo_baseline",
    description="Baseline PPO experiment on CartPole",
    tags=["baseline", "ppo", "cartpole"]
)

print("Experiment config:")
print(json.dumps(config.to_dict(), indent=2))
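
Storing a seed in the config does not by itself make a run reproducible; the seed still has to be applied, and the config needs a stable identity to be versioned. A minimal sketch of both steps follows. `set_global_seeds` and `config_fingerprint` are hypothetical helpers, and only the Python and NumPy generators are seeded here; a deep learning framework keeps its own generator that must be seeded separately.

```python
import hashlib
import json
import random

import numpy as np


def set_global_seeds(seed: int) -> None:
    """Seed the Python and NumPy global random number generators."""
    random.seed(seed)
    np.random.seed(seed)


def config_fingerprint(config_dict: dict) -> str:
    """Return a short, stable hash of a JSON-serializable config."""
    # sort_keys makes the hash independent of dictionary insertion order
    canonical = json.dumps(config_dict, sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]


# Re-seeding reproduces the same random draws
set_global_seeds(42)
first = np.random.rand(3)
set_global_seeds(42)
second = np.random.rand(3)
print(np.array_equal(first, second))  # True

# Key order does not change the fingerprint
fp_a = config_fingerprint({"lr": 3e-4, "gamma": 0.99})
fp_b = config_fingerprint({"gamma": 0.99, "lr": 3e-4})
print(fp_a == fp_b)  # True
```

The fingerprint can be attached to checkpoints and logs so that each artifact traces back to the exact configuration that produced it.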

## 5. Offline RL and Human Feedback Integration

In [None]:
class HumanFeedbackCollector:
    """
    Collect and incorporate human feedback for RL.
    
    Patterns:
    1. Trajectory ratings
    2. Action corrections
    3. Preference comparisons
    """
    
    def __init__(self):
        self.trajectory_ratings = []
        self.action_corrections = []
        self.preferences = []
    
    def record_trajectory_rating(self, trajectory_id: str, rating: float, comments: str = ""):
        """Record human rating for a trajectory (1-5 scale)."""
        if not 1 <= rating <= 5:
            raise ValueError(f"Rating must be between 1 and 5, got {rating}")
        self.trajectory_ratings.append({
            "trajectory_id": trajectory_id,
            "rating": rating,
            "comments": comments,
            "timestamp": datetime.now().isoformat()
        })
    
    def record_action_correction(self, state, policy_action, human_action, reason: str = ""):
        """Record when human corrects policy action."""
        self.action_corrections.append({
            "state": state.tolist() if hasattr(state, 'tolist') else state,
            "policy_action": policy_action,
            "human_action": human_action,
            "reason": reason,
            "timestamp": datetime.now().isoformat()
        })
    
    def record_preference(self, trajectory_a: str, trajectory_b: str, preferred: str):
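        """Record human preference between two trajectories."""
        if preferred not in (trajectory_a, trajectory_b):
            raise ValueError("preferred must be one of the two compared trajectories")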
        self.preferences.append({
            "trajectory_a": trajectory_a,
            "trajectory_b": trajectory_b,
            "preferred": preferred,
            "timestamp": datetime.now().isoformat()
        })
    
    def get_correction_dataset(self) -> List[Dict]:
        """Get corrections as training data for behavior cloning."""
        return [
            {"state": c["state"], "action": c["human_action"]}
            for c in self.action_corrections
        ]
    
    def get_feedback_summary(self) -> Dict:
        """Summarize collected feedback."""
        ratings = [r["rating"] for r in self.trajectory_ratings]
        return {
            "total_ratings": len(self.trajectory_ratings),
            "avg_rating": np.mean(ratings) if ratings else 0,
            "total_corrections": len(self.action_corrections),
            "total_preferences": len(self.preferences),
        }

# Example usage
feedback = HumanFeedbackCollector()

# Simulate feedback collection
feedback.record_trajectory_rating("traj_001", 4.5, "Good risk management")
feedback.record_trajectory_rating("traj_002", 2.0, "Too aggressive")

feedback.record_action_correction(
    state=np.array([0.5, 0.1, -0.2, 0.3]),
    policy_action=1,  # Buy
    human_action=0,   # Hold
    reason="Market too volatile"
)

feedback.record_preference("traj_001", "traj_002", "traj_001")

print("Feedback summary:")
print(json.dumps(feedback.get_feedback_summary(), indent=2))
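
The collector stores preference comparisons but does not show how they might be used. One widely used option, swapped in here purely for illustration, is the Bradley-Terry model, which assigns each trajectory a scalar score such that the probability of preferring A over B is `sigmoid(score_A - score_B)`. The sketch below fits those scores by plain gradient ascent; the function name, step count, and learning rate are assumptions.

```python
import numpy as np


def fit_preference_scores(preferences, num_steps=500, lr=0.1):
    """Fit one scalar score per trajectory from pairwise preferences.

    Bradley-Terry model: P(a preferred over b) = sigmoid(s_a - s_b).
    """
    ids = sorted({t for p in preferences
                  for t in (p["trajectory_a"], p["trajectory_b"])})
    index = {t: i for i, t in enumerate(ids)}
    scores = np.zeros(len(ids))

    for _ in range(num_steps):
        grad = np.zeros_like(scores)
        for p in preferences:
            a, b = p["trajectory_a"], p["trajectory_b"]
            winner = index[p["preferred"]]
            loser = index[b if p["preferred"] == a else a]
            # Gradient of log sigmoid(s_winner - s_loser)
            p_win = 1.0 / (1.0 + np.exp(-(scores[winner] - scores[loser])))
            grad[winner] += 1.0 - p_win
            grad[loser] -= 1.0 - p_win
        scores += lr * grad  # Gradient ascent on the log-likelihood
    return dict(zip(ids, scores.tolist()))


# Same record format as HumanFeedbackCollector.preferences (timestamps omitted)
prefs = [
    {"trajectory_a": "traj_001", "trajectory_b": "traj_002", "preferred": "traj_001"},
    {"trajectory_a": "traj_002", "trajectory_b": "traj_003", "preferred": "traj_002"},
    {"trajectory_a": "traj_001", "trajectory_b": "traj_003", "preferred": "traj_001"},
]
scores = fit_preference_scores(prefs)
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)  # ['traj_001', 'traj_002', 'traj_003']
```

With perfectly consistent preferences the scores keep growing as training continues, so a real implementation would typically add regularization or early stopping.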

## 6. Production Checklist

In [None]:
class ProductionChecklist:
    """
    Checklist for deploying RL to production.
    """
    
    CHECKLIST = {
        "training": [
            "Experiment config saved and versioned",
            "Random seeds set for reproducibility",
            "Training curves show convergence",
            "Multiple seeds tested for robustness",
            "Hyperparameters tuned on validation set",
        ],
        "evaluation": [
            "Evaluated on held-out test environments",
            "Compared against baseline policies",
            "Tested on edge cases and adversarial inputs",
            "Performance metrics meet requirements",
            "Latency requirements verified",
        ],
        "safety": [
            "Action constraints implemented and tested",
            "Fallback policy configured",
            "Safety bounds verified",
            "Human override mechanism in place",
            "Rollback procedure documented",
        ],
        "monitoring": [
            "Logging infrastructure set up",
            "Distribution shift detection enabled",
            "Performance metrics dashboards created",
            "Alerting thresholds configured",
            "A/B testing framework ready",
        ],
        "deployment": [
            "Model versioned in registry",
            "Serving infrastructure tested",
            "Canary deployment plan ready",
            "Documentation complete",
            "On-call procedures established",
        ],
    }
    
    def __init__(self):
        self.completed = {category: set() for category in self.CHECKLIST}
    
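    def mark_complete(self, category: str, item: str):
        # Reject unknown categories and items so a typo cannot silently
        # be ignored or inflate the completion count
        if category not in self.CHECKLIST:
            raise KeyError(f"Unknown checklist category: {category}")
        if item not in self.CHECKLIST[category]:
            raise ValueError(f"Unknown item in '{category}': {item}")
        self.completed[category].add(item)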
    
    def get_status(self) -> Dict:
        status = {}
        for category, items in self.CHECKLIST.items():
            completed = len(self.completed[category])
            total = len(items)
            status[category] = {
                "completed": completed,
                "total": total,
                "percentage": completed / total * 100,
                "missing": [i for i in items if i not in self.completed[category]]
            }
        return status
    
    def is_ready_for_production(self) -> bool:
        status = self.get_status()
        return all(s["percentage"] == 100 for s in status.values())
    
    def print_checklist(self):
        print("=" * 60)
        print("PRODUCTION READINESS CHECKLIST")
        print("=" * 60)
        
        for category, items in self.CHECKLIST.items():
            print(f"\n{category.upper()}")
            print("-" * 40)
            for item in items:
                status = "[x]" if item in self.completed[category] else "[ ]"
                print(f"  {status} {item}")
        
        print("\n" + "=" * 60)
        ready = self.is_ready_for_production()
        print(f"PRODUCTION READY: {'YES' if ready else 'NO'}")
        print("=" * 60)

# Example usage
checklist = ProductionChecklist()

# Mark some items complete
checklist.mark_complete("training", "Experiment config saved and versioned")
checklist.mark_complete("training", "Random seeds set for reproducibility")
checklist.mark_complete("safety", "Action constraints implemented and tested")

checklist.print_checklist()

## Key Industry Patterns Summary

### 1. Always Have a Fallback
- Rule-based policy as backup
- Previous production model
- Human override capability

### 2. Monitor Everything
- Input distribution
- Action distribution
- Reward signals
- Latency

### 3. Deploy Incrementally
- Shadow mode first
- Small traffic percentage
- Gradual rollout
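
Sending a small traffic percentage to a new policy can be sketched with deterministic hashing, so that the same request or user ID always lands on the same side of the split. The function name and the ID format below are illustrative assumptions.

```python
import hashlib


def route_to_canary(request_id: str, canary_fraction: float) -> bool:
    """Deterministically assign a request to the canary policy."""
    digest = hashlib.sha256(request_id.encode("utf-8")).digest()
    # Map the first 8 bytes of the hash to a value in [0, 1)
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return bucket < canary_fraction


# Roughly 5% of hypothetical user IDs should reach the canary
assignments = [route_to_canary(f"user-{i}", 0.05) for i in range(10_000)]
canary_share = sum(assignments) / len(assignments)
print(f"canary share: {canary_share:.3f}")
```

Raising `canary_fraction` over time gives a gradual rollout: IDs already assigned to the canary stay there, because their bucket value does not change.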

### 4. Design for Failure
- Graceful degradation
- Circuit breakers
- Automatic rollback
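
A circuit breaker can be sketched as a small wrapper that stops calling a failing primary policy after a number of consecutive errors and serves the fallback instead. This is a simplified illustration with hypothetical names; it omits the half-open state that production breakers use to probe for recovery, so it must be reset explicitly.

```python
class CircuitBreaker:
    """Switch to a fallback after too many consecutive primary failures."""

    def __init__(self, failure_threshold: int = 3):
        self.failure_threshold = failure_threshold
        self.consecutive_failures = 0

    @property
    def is_open(self) -> bool:
        return self.consecutive_failures >= self.failure_threshold

    def call(self, primary, fallback, state):
        if self.is_open:
            return fallback(state)  # Skip the primary entirely while open
        try:
            result = primary(state)
        except Exception:
            self.consecutive_failures += 1
            return fallback(state)
        self.consecutive_failures = 0  # Any success resets the count
        return result

    def reset(self) -> None:
        self.consecutive_failures = 0


def failing_policy(state):
    raise RuntimeError("model server unavailable")


def hold_policy(state):
    return 0  # Do-nothing action


breaker = CircuitBreaker(failure_threshold=2)
actions = [breaker.call(failing_policy, hold_policy, state={}) for _ in range(3)]
print(actions, breaker.is_open)  # [0, 0, 0] True
```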

### 5. Maintain Reproducibility
- Version everything
- Document experiments
- Track lineage

## Congratulations!

You've completed the Ray RLlib tutorial series. You now have knowledge spanning:

1. **Fundamentals**: RL concepts, MDPs, Q-learning
2. **Deep RL**: DQN, Policy Gradients, Actor-Critic
3. **RLlib**: Setup, algorithms, configuration
4. **Custom Environments**: Gymnasium interface, registration
5. **Distributed Training**: Scaling, multi-GPU, clusters
6. **Hyperparameter Tuning**: Ray Tune, PBT, ASHA
7. **Production Deployment**: Serving, A/B testing, monitoring
8. **Industry Patterns**: Safety, reliability, best practices

### Next Steps
- Apply these patterns to your specific use case
- Explore advanced algorithms (offline RL, multi-agent)
- Contribute to the RLlib community
- Stay up to date with the latest research