# 🏭 Enhanced RL Training for Perishable Inventory MDP

**Professional Training Pipeline with State-of-the-Art Improvements**

This notebook implements an end-to-end PPO training pipeline (Stable-Baselines3) featuring:

| Feature | Description |
|---------|-------------|
| **5M Training Steps** | Extended training for better convergence |
| **Learning Rate Annealing** | Linear decay from 3e-4 → 0 |
| **Entropy Decay** | From 0.01 → 0.001 for exploration/exploitation balance |
| **Curriculum Learning** | Simple → Moderate → Complex → Extreme |
| **Cost-Aware Observations** | Enhanced state with supplier costs |
| **Asymmetric Actions** | Favor cheap supplier ordering |
| **TBS Benchmarking** | Periodic comparison against the Tailored Base-Surge (TBS) baseline |
| **100+ Environments** | 105-environment benchmark suite spanning four complexity levels |

---

**Objective**: Train an RL agent that outperforms the Tailored Base-Surge (TBS) policy on complex environments while matching performance on simple ones.

## 📋 Table of Contents

1. [Setup & Installation](#setup)
2. [Environment Suite](#env-suite)
3. [Training Configuration](#config)
4. [Model Training](#training)
5. [Evaluation & Benchmarking](#evaluation)
6. [Results Analysis](#results)
7. [Model Export](#export)

---
## 1️⃣ Setup & Installation <a name="setup"></a>

Install dependencies and clone the repository.

In [None]:
# Install dependencies
!pip install "stable-baselines3[extra]" gymnasium numpy scipy matplotlib pandas tensorboard -q
print("✅ Dependencies installed")

In [None]:
import os
import sys

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone repository
    REPO_URL = "https://github.com/MahmoudZah/Multi-Supplier-Perishable-Inventory.git"
    REPO_DIR = "Multi-Supplier-Perishable-Inventory"
    
    if not os.path.exists(REPO_DIR):
        print(f"📥 Cloning repository...")
        !git clone {REPO_URL} {REPO_DIR}
    
    os.chdir(REPO_DIR)
    sys.path.insert(0, os.getcwd())
    print(f"📂 Working directory: {os.getcwd()}")
else:
    print("🖥️ Running locally")

print("✅ Repository ready")

In [None]:
# Core imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, List, Optional, Any
import time
import json
from pathlib import Path

# RL imports
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.callbacks import CallbackList, EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from gymnasium.wrappers import TimeLimit

# Project imports
from colab_training.gym_env import (
    PerishableInventoryGymWrapper,
    RewardConfig,
    create_gym_env
)
from colab_training.environment_suite import (
    EnvironmentSuite,
    EnvironmentConfig,
    create_environment_suite,
    build_environment_from_config,
    get_canonical_suite
)
from colab_training.callbacks import (
    ScheduleCallback,
    CurriculumCallback,
    BenchmarkCallback,
    create_lr_schedule,
    create_entropy_schedule
)
from colab_training.benchmark import (
    evaluate_policy,
    compare_policies,
    get_tbs_policy_for_env,
    get_basestock_policy_for_env,
    generate_performance_report,
    visualize_comparison,
    ComparisonReport
)
from perishable_inventory_mdp.policies import (
    TailoredBaseSurgePolicy,
    BaseStockPolicy,
    DoNothingPolicy
)

print("✅ All imports successful!")

In [None]:
# Check GPU availability
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    print(f"🚀 GPU available: {gpu_name}")
    print(f"   CUDA version: {torch.version.cuda}")
else:
    device = torch.device("cpu")
    print("⚠️ No GPU available, using CPU")
    print("   PPO with an MLP policy runs mainly on CPU in SB3, so this is usually fine")

---
## 2️⃣ Environment Suite <a name="env-suite"></a>

Load the canonical 105-environment benchmark suite with varying complexity levels.

In [None]:
# Load environment suite
suite = get_canonical_suite()

print("📊 Environment Suite Summary")
print("=" * 40)
print(f"Total environments: {len(suite)}")
print()

summary = suite.get_summary()
for complexity, count in sorted(summary.items()):
    bar = "█" * (count // 2)
    print(f"  {complexity.capitalize():10s} │ {count:3d} │ {bar}")

print()
print("Complexity Progression:")
print("  Simple   → TBS-optimal scenarios (baseline)")
print("  Moderate → Some seasonality/stochasticity")
print("  Complex  → Composite demand, crisis dynamics")
print("  Extreme  → Hardest scenarios, where RL is expected to gain most over TBS")

In [None]:
# Preview sample environment
sample_config = suite.get_by_complexity("simple")[0]
print("📦 Sample Environment Configuration (Simple)")
print("=" * 50)
print(f"  Environment ID: {sample_config.env_id}")
print(f"  Shelf life: {sample_config.shelf_life}")
print(f"  Mean demand: {sample_config.mean_demand:.1f}")
print(f"  Suppliers: {sample_config.num_suppliers}")
print(f"  Lead times: {sample_config.lead_times}")
print(f"  Unit costs: {sample_config.unit_costs}")
print(f"  Demand type: {sample_config.demand_type}")

# Create and test environment
test_env = create_gym_env(
    shelf_life=sample_config.shelf_life,
    mean_demand=sample_config.mean_demand,
    fast_lead_time=sample_config.lead_times[0],
    slow_lead_time=sample_config.lead_times[1],
    fast_cost=sample_config.unit_costs[0],
    slow_cost=sample_config.unit_costs[1]
)

print(f"\n🎮 Gym Environment")
print(f"  Observation space: {test_env.observation_space}")
print(f"  Action space: {test_env.action_space}")

obs_info = test_env.get_observation_space_info()
print(f"\n📊 Observation Components:")
for name, (start, end) in obs_info.items():
    print(f"    {name}: [{start}:{end}] ({end-start} dims)")
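
The `(start, end)` pairs printed above are half-open slice bounds into the flat observation vector. A minimal sketch of splitting an observation into its named components; `split_observation` is a hypothetical helper, and the example layout is made up for illustration rather than taken from the wrapper:

```python
from typing import Dict, Sequence, Tuple


def split_observation(
    obs: Sequence[float],
    layout: Dict[str, Tuple[int, int]],
) -> Dict[str, list]:
    """Slice a flat observation vector into named components.

    `layout` maps component name -> (start, end), a half-open index range,
    in the same form as the dict printed by get_observation_space_info().
    """
    parts = {}
    for name, (start, end) in layout.items():
        if not 0 <= start <= end <= len(obs):
            raise ValueError(f"{name}: range [{start}:{end}] outside obs of length {len(obs)}")
        parts[name] = list(obs[start:end])
    return parts


# Hypothetical layout, for illustration only (not the wrapper's real layout)
example_layout = {"inventory": (0, 3), "pipeline": (3, 5), "supplier_costs": (5, 7)}
example_obs = [4.0, 2.0, 1.0, 6.0, 3.0, 1.0, 0.8]
print(split_observation(example_obs, example_layout))
```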

---
## 3️⃣ Training Configuration <a name="config"></a>

Configure the training hyperparameters. The defaults below target the full 5M-step run; reduce `total_timesteps` for quick tests.
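
For a quick end-to-end check before a multi-hour run, one option is a hypothetical helper, `make_smoke_test_config` (not part of `colab_training`), that derives a shortened copy of a config dict like `TRAINING_CONFIG` and caps the evaluation, checkpoint and benchmark frequencies so each still fires at least once:

```python
import copy
from typing import Any, Dict


def make_smoke_test_config(config: Dict[str, Any], total_timesteps: int = 50_000) -> Dict[str, Any]:
    """Return a scaled-down deep copy of a training config for quick pipeline checks."""
    smoke = copy.deepcopy(config)  # never mutate the full-run config
    smoke["total_timesteps"] = total_timesteps
    # Cap frequencies (in total env steps) at half the run so each callback fires
    for key in ("eval_freq", "checkpoint_freq", "benchmark_freq"):
        if key in smoke:
            smoke[key] = min(smoke[key], max(1, total_timesteps // 2))
    return smoke
```

After the configuration cell has run, this could be applied as `TRAINING_CONFIG = make_smoke_test_config(TRAINING_CONFIG, 50_000)`.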

In [None]:
# 🎛️ TRAINING CONFIGURATION
# Adjust these based on available compute time

TRAINING_CONFIG = {
    # Core training
    "total_timesteps": 5_000_000,       # 5M for full training (reduce for testing)
    "n_envs": 8,                         # Parallel environments
    "episode_length": 500,               # Steps per episode
    
    # Learning rate schedule
    "initial_learning_rate": 3e-4,
    "final_learning_rate": 0.0,
    
    # Entropy coefficient schedule
    "initial_entropy_coef": 0.01,
    "final_entropy_coef": 0.001,
    
    # Curriculum learning
    "curriculum_enabled": True,
    "curriculum_thresholds": {
        "simple": -5.0,
        "moderate": -8.0,
        "complex": -12.0
    },
    "min_episodes_per_level": 50,
    
    # Evaluation & checkpointing (frequencies in total environment steps)
    "eval_freq": 50_000,
    "checkpoint_freq": 100_000,
    "benchmark_freq": 100_000,
    "n_eval_episodes": 10,
    
    # Model architecture
    "policy_kwargs": {
        "net_arch": [256, 256]           # Two hidden layers of 256 units (separate policy and value nets in SB3 >= 1.8)
    },
    
    # Random seed for reproducibility
    "seed": 42
}

# Reward shaping configuration
REWARD_CONFIG = RewardConfig(
    alpha=0.5,       # Procurement cost weight
    beta=0.3,        # Holding + spoilage weight
    gamma=0.2,       # Shortage penalty weight
    delta=0.1,       # Service bonus
    target_fill_rate=0.95,
    normalize=True,
    normalization_scale=10.0
)

print("⚙️ Training Configuration")
print("=" * 50)
print(f"  Total timesteps: {TRAINING_CONFIG['total_timesteps']:,}")
print(f"  Parallel envs: {TRAINING_CONFIG['n_envs']}")
print(f"  Episode length: {TRAINING_CONFIG['episode_length']}")
print()
print("📉 Learning Rate Schedule")
print(f"  {TRAINING_CONFIG['initial_learning_rate']} → {TRAINING_CONFIG['final_learning_rate']}")
print()
print("🎲 Entropy Schedule")
print(f"  {TRAINING_CONFIG['initial_entropy_coef']} → {TRAINING_CONFIG['final_entropy_coef']}")
print()
print("📚 Curriculum Learning")
print(f"  Enabled: {TRAINING_CONFIG['curriculum_enabled']}")
print(f"  Thresholds: {TRAINING_CONFIG['curriculum_thresholds']}")

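# Estimate training time (rough: throughput depends on hardware, n_envs and callback overhead)
steps_per_second = 1000  # Assumed throughput; measure on your own setup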
estimated_hours = TRAINING_CONFIG['total_timesteps'] / steps_per_second / 3600
print(f"\n⏱️ Estimated training time: {estimated_hours:.1f} hours")
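
The `RewardConfig` weights suggest a weighted-cost reward, but the wrapper's reward code is not shown in this notebook. A sketch of one plausible reading of the field names, assuming reward = -(α·procurement + β·(holding + spoilage) + γ·shortage) / scale + δ whenever the step's fill rate meets the target. `StepCosts` and `shaped_reward` are hypothetical:

```python
from dataclasses import dataclass


@dataclass
class StepCosts:
    """Hypothetical per-step cost breakdown (field names are illustrative)."""
    procurement: float
    holding: float
    spoilage: float
    shortage: float
    fill_rate: float  # fraction of demand served this step, in [0, 1]


def shaped_reward(
    c: StepCosts,
    alpha: float = 0.5,   # procurement cost weight
    beta: float = 0.3,    # holding + spoilage weight
    gamma: float = 0.2,   # shortage penalty weight
    delta: float = 0.1,   # service bonus
    target_fill_rate: float = 0.95,
    normalize: bool = True,
    normalization_scale: float = 10.0,
) -> float:
    """ASSUMED reward form: negative weighted cost plus a service bonus.

    A reading of RewardConfig's field names, not the wrapper's implementation.
    """
    weighted_cost = (
        alpha * c.procurement
        + beta * (c.holding + c.spoilage)
        + gamma * c.shortage
    )
    if normalize:
        weighted_cost /= normalization_scale
    bonus = delta if c.fill_rate >= target_fill_rate else 0.0
    return -weighted_cost + bonus


print(shaped_reward(StepCosts(procurement=20, holding=5, spoilage=5, shortage=10, fill_rate=0.97)))
```

Under this reading the example works out to -(0.5·20 + 0.3·10 + 0.2·10)/10 + 0.1 ≈ -1.4. A negative, cost-driven reward of this kind would be consistent with the negative curriculum thresholds (-5, -8, -12) in `TRAINING_CONFIG`.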

---
## 4️⃣ Model Training <a name="training"></a>

Train the PPO agent with curriculum learning and periodic TBS benchmarking.
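
`CurriculumCallback` comes from `colab_training.callbacks`, and its exact advancement rule is not shown here. A minimal sketch of a threshold rule consistent with its parameters (`thresholds`, `window_size`, `min_episodes_per_level`), assuming the agent moves up a level once enough episodes are completed and the rolling mean episode reward reaches that level's threshold. `ThresholdCurriculum` is a hypothetical stand-in, not the project's class:

```python
from collections import deque
from typing import Dict, List, Optional


class ThresholdCurriculum:
    """Hypothetical threshold-based curriculum (not colab_training's CurriculumCallback)."""

    def __init__(self, levels: List[str], thresholds: Dict[str, float],
                 window_size: int = 100, min_episodes: int = 50):
        self.levels = levels
        self.thresholds = thresholds
        self.window_size = window_size
        self.min_episodes = min_episodes
        self.index = 0
        self._reset_level_stats()

    def _reset_level_stats(self) -> None:
        self.rewards = deque(maxlen=self.window_size)  # rolling window of episode rewards
        self.episodes_at_level = 0

    @property
    def level(self) -> str:
        return self.levels[self.index]

    def record_episode(self, episode_reward: float) -> Optional[str]:
        """Record one finished episode; return the new level name if we advanced."""
        self.rewards.append(episode_reward)
        self.episodes_at_level += 1
        threshold = self.thresholds.get(self.level)
        if threshold is None or self.index == len(self.levels) - 1:
            return None  # final level (or no threshold): stay put
        mean_reward = sum(self.rewards) / len(self.rewards)
        if self.episodes_at_level >= self.min_episodes and mean_reward >= threshold:
            self.index += 1
            self._reset_level_stats()  # judge the new level on its own episodes only
            return self.level
        return None
```

Clearing the reward window on each promotion is a design choice in this sketch: it keeps rewards earned on an easier level from inflating the rolling mean on the next one.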

In [None]:
def create_env_from_config(
    env_config: EnvironmentConfig,
    reward_config: Optional[RewardConfig] = None,
    episode_length: int = 500
) -> PerishableInventoryGymWrapper:
    """Create gym environment from EnvironmentConfig."""
    mdp = build_environment_from_config(env_config)
    
    env = PerishableInventoryGymWrapper(
        mdp=mdp,
        reward_config=reward_config or REWARD_CONFIG
    )
    
    env = TimeLimit(env, max_episode_steps=episode_length)
    env = Monitor(env)
    
    # Store mdp reference on wrapper for TBS policy creation
    env.mdp = mdp
    
    return env


def make_curriculum_env_factory(n_envs: int, seed: int):
    """Create factory function for curriculum environments."""
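    def env_factory(complexity: str) -> "SubprocVecEnv | DummyVecEnv":
        configs = suite.get_by_complexity(complexity)
        
        if not configs:
            raise ValueError(f"No environments for complexity: {complexity}")
        
        # Pick distinct configs first; sample indices so NumPy never has to
        # convert the config objects themselves into an array
        rng = np.random.RandomState(seed)
        idx = rng.choice(len(configs), size=min(n_envs, len(configs)), replace=False)
        selected = [configs[i] for i in idx]
        
        # Top up with repeats when a level has fewer configs than parallel envs
        while len(selected) < n_envs:
            selected.append(configs[rng.randint(len(configs))])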
        
        def make_env(cfg):
            def _init():
                return create_env_from_config(
                    cfg, 
                    REWARD_CONFIG, 
                    TRAINING_CONFIG['episode_length']
                )
            return _init
        
        env_fns = [make_env(cfg) for cfg in selected]
        return SubprocVecEnv(env_fns) if n_envs > 1 else DummyVecEnv(env_fns)
    
    return env_factory


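def create_eval_env(complexity: str = "simple"):
    """Create a single-environment evaluation VecEnv for the given complexity.

    Uses the first config of that level, so EvalCallback selects `best_model`
    on this one environment only.
    """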
    configs = suite.get_by_complexity(complexity)
    config = configs[0] if configs else suite.configs[0]
    
    def _init():
        return create_env_from_config(
            config, 
            REWARD_CONFIG, 
            TRAINING_CONFIG['episode_length']
        )
    
    return DummyVecEnv([_init])

print("✅ Environment factories defined")
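
The factory above picks distinct configs first and tops up with repeats when a level has fewer configs than `n_envs`. The same selection logic in standalone form, swapping NumPy's `RandomState` for the standard library's `random.Random` and sampling indices rather than config objects; `sample_env_configs` is a hypothetical helper:

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_env_configs(configs: Sequence[T], n_envs: int, seed: int) -> List[T]:
    """Pick `n_envs` configs: distinct ones first, then top up with repeats."""
    if not configs:
        raise ValueError("configs must be non-empty")
    rng = random.Random(seed)  # seeded, so each call with the same seed is reproducible
    k = min(n_envs, len(configs))
    selected = [configs[i] for i in rng.sample(range(len(configs)), k)]
    while len(selected) < n_envs:
        selected.append(configs[rng.randrange(len(configs))])
    return selected
```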

In [None]:
# Setup directories
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
(log_dir / "checkpoints").mkdir(exist_ok=True)
(log_dir / "best_model").mkdir(exist_ok=True)
(log_dir / "benchmark").mkdir(exist_ok=True)

# Create environment factory
env_factory = make_curriculum_env_factory(
    n_envs=TRAINING_CONFIG['n_envs'],
    seed=TRAINING_CONFIG['seed']
)

# Create initial training environment (simple)
print("🏗️ Creating training environment...")
train_env = env_factory("simple")
print(f"   Starting with 'simple' complexity")

# Create evaluation environment
eval_env = create_eval_env("simple")
print(f"   Evaluation environment ready")

# Create learning rate schedule
lr_schedule = create_lr_schedule(
    TRAINING_CONFIG['initial_learning_rate'],
    TRAINING_CONFIG['final_learning_rate']
)

# Create PPO model
print("\n🧠 Creating PPO model...")
model = PPO(
    "MlpPolicy",
    train_env,
    learning_rate=lr_schedule,
    ent_coef=TRAINING_CONFIG['initial_entropy_coef'],
    verbose=1,
    tensorboard_log=str(log_dir / "tensorboard"),
    seed=TRAINING_CONFIG['seed'],
    device="auto",
    policy_kwargs=TRAINING_CONFIG['policy_kwargs']
)

print(f"   Policy: {model.policy}")
print(f"   Device: {model.device}")
print("\n✅ Model ready for training")
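
Stable-Baselines3 accepts a callable for `learning_rate` and calls it with `progress_remaining`, which falls from 1.0 at the start of training to 0.0 at the end. `create_lr_schedule` is project code not shown here; a minimal sketch of a linear schedule in that convention (`linear_schedule` is a hypothetical name):

```python
from typing import Callable


def linear_schedule(initial_value: float, final_value: float = 0.0) -> Callable[[float], float]:
    """Linear interpolation in SB3's convention (progress_remaining: 1.0 -> 0.0)."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining = 1.0 returns initial_value; 0.0 returns final_value
        return final_value + progress_remaining * (initial_value - final_value)
    return schedule


lr = linear_schedule(3e-4, 0.0)
print(lr(1.0), lr(0.5), lr(0.0))
```

PPO's `ent_coef` is a plain float rather than a schedule, so annealing it from 0.01 to 0.001 needs something like a callback that updates `model.ent_coef` during training; presumably that is part of what `ScheduleCallback` handles.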

In [None]:
# Setup callbacks
callbacks = []

# 1. Schedule callback (logs LR/entropy to tensorboard)
schedule_callback = ScheduleCallback(
    initial_lr=TRAINING_CONFIG['initial_learning_rate'],
    final_lr=TRAINING_CONFIG['final_learning_rate'],
    initial_ent_coef=TRAINING_CONFIG['initial_entropy_coef'],
    final_ent_coef=TRAINING_CONFIG['final_entropy_coef'],
    log_freq=5000,
    verbose=1
)
callbacks.append(schedule_callback)
print("📉 Schedule callback: LR/entropy annealing")

# 2. Curriculum callback
if TRAINING_CONFIG['curriculum_enabled']:
    curriculum_callback = CurriculumCallback(
        env_factory=env_factory,
        thresholds=TRAINING_CONFIG['curriculum_thresholds'],
        window_size=100,
        min_episodes_per_level=TRAINING_CONFIG['min_episodes_per_level'],
        verbose=1
    )
    callbacks.append(curriculum_callback)
    print("📚 Curriculum callback: complexity progression")

# 3. Evaluation callback
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=str(log_dir / "best_model"),
    log_path=str(log_dir / "eval"),
    eval_freq=TRAINING_CONFIG['eval_freq'] // TRAINING_CONFIG['n_envs'],
    n_eval_episodes=TRAINING_CONFIG['n_eval_episodes'],
    deterministic=True,
    verbose=1
)
callbacks.append(eval_callback)
print("📊 Evaluation callback: best model tracking")

# 4. Checkpoint callback
checkpoint_callback = CheckpointCallback(
    save_freq=TRAINING_CONFIG['checkpoint_freq'] // TRAINING_CONFIG['n_envs'],
    save_path=str(log_dir / "checkpoints"),
    name_prefix="ppo_perishable"
)
callbacks.append(checkpoint_callback)
print("💾 Checkpoint callback: periodic saves")

# 5. Benchmark callback (TBS comparison)
try:
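    # DummyVecEnv does not forward attributes such as `.mdp`, so pass the
    # underlying (Monitor-wrapped) env, as the evaluation section does
    tbs_policy = get_tbs_policy_for_env(eval_env.envs[0])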
    benchmark_callback = BenchmarkCallback(
        eval_env=eval_env,
        benchmark_freq=TRAINING_CONFIG['benchmark_freq'] // TRAINING_CONFIG['n_envs'],
        n_eval_episodes=TRAINING_CONFIG['n_eval_episodes'],
        baseline_policies={"TBS": tbs_policy},
        save_path=str(log_dir / "benchmark"),
        verbose=1
    )
    callbacks.append(benchmark_callback)
    print("🏆 Benchmark callback: TBS comparison")
except Exception as e:
    print(f"⚠️ Could not create TBS baseline: {e}")

callback_list = CallbackList(callbacks)
print(f"\n✅ {len(callbacks)} callbacks configured")
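
The `// TRAINING_CONFIG['n_envs']` divisions in the cell above exist because SB3's `EvalCallback` and `CheckpointCallback` count callback calls, and each call corresponds to one `step()` of the vectorized env, i.e. `n_envs` environment steps. A small hypothetical helper that makes the conversion explicit and clamps the result so small test configurations cannot produce a frequency of 0:

```python
def per_env_freq(total_steps_freq: int, n_envs: int) -> int:
    """Convert a frequency in total environment steps into SB3 callback calls."""
    if n_envs < 1:
        raise ValueError("n_envs must be >= 1")
    # One callback call == one VecEnv step == n_envs environment steps
    return max(1, total_steps_freq // n_envs)


print(per_env_freq(50_000, 8))   # eval_freq for 8 parallel envs
print(per_env_freq(100_000, 8))  # checkpoint/benchmark freq for 8 parallel envs
```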

In [None]:
# 🚀 START TRAINING
print("=" * 60)
print("🚀 STARTING TRAINING")
print("=" * 60)
print(f"Total timesteps: {TRAINING_CONFIG['total_timesteps']:,}")
print(f"Parallel environments: {TRAINING_CONFIG['n_envs']}")
print(f"Curriculum learning: {TRAINING_CONFIG['curriculum_enabled']}")
print("=" * 60)

start_time = time.time()

model.learn(
    total_timesteps=TRAINING_CONFIG['total_timesteps'],
    callback=callback_list,
    progress_bar=True
)

training_time = time.time() - start_time
print("\n" + "=" * 60)
print("✅ TRAINING COMPLETE")
print("=" * 60)
print(f"Training time: {training_time/3600:.2f} hours")
print(f"Steps per second: {TRAINING_CONFIG['total_timesteps']/training_time:.0f}")

In [None]:
# Save final model
final_model_path = log_dir / "final_model"
model.save(str(final_model_path))
print(f"💾 Final model saved to: {final_model_path}")

# Cleanup
train_env.close()
eval_env.close()
print("✅ Environments closed")

---
## 🔄 Resume Training (If Interrupted)

Use this cell if training was interrupted (kernel crash, timeout, etc.).
It loads the checkpoint with the highest step count and continues training up to `total_timesteps`; progress made after that checkpoint was saved is lost.
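
Checkpoint filenames written by `CheckpointCallback` embed the step count (`ppo_perishable_{steps}_steps.zip`), so "latest" should be decided numerically: a plain string sort ranks `..._900000_steps.zip` after `..._1000000_steps.zip`. A minimal sketch with hypothetical helpers `checkpoint_step` and `latest_checkpoint`:

```python
import re
from pathlib import Path
from typing import Iterable, Optional

_STEPS_RE = re.compile(r"_(\d+)_steps$")


def checkpoint_step(path: Path) -> int:
    """Parse the step count from names like 'ppo_perishable_1000000_steps.zip'."""
    match = _STEPS_RE.search(path.stem)
    if match is None:
        raise ValueError(f"Not a step checkpoint: {path.name}")
    return int(match.group(1))


def latest_checkpoint(paths: Iterable[Path]) -> Optional[Path]:
    """Return the checkpoint with the highest step count (numeric, not lexicographic)."""
    return max(paths, key=checkpoint_step, default=None)
```

Usage would look like `latest_checkpoint(Path("logs/checkpoints").glob("ppo_perishable_*_steps.zip"))`. With SB3's `learn(..., reset_num_timesteps=False)`, the requested steps are added to the model's restored `num_timesteps`, so passing `total_timesteps - checkpoint_step(...)` ends at the originally planned total.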

In [None]:
# 🔄 RESUME TRAINING FROM CHECKPOINT
# Run this cell ONLY if training was interrupted and you want to continue

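# Find latest checkpoint. Sort numerically by the step count in the filename;
# a plain string sort would rank "..._900000_steps" after "..._1000000_steps".
checkpoint_dir = Path("logs/checkpoints")
checkpoints = sorted(
    checkpoint_dir.glob("ppo_perishable_*_steps.zip"),
    key=lambda p: int(p.stem.split("_")[-2]),
)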

if not checkpoints:
    print("❌ No checkpoints found. Please run training from the beginning.")
else:
    latest_checkpoint = checkpoints[-1]
    print(f"📂 Found {len(checkpoints)} checkpoints")
    print(f"   Latest: {latest_checkpoint.name}")
    
    # Extract step count from filename
    checkpoint_steps = int(latest_checkpoint.stem.split("_")[-2])
    remaining_steps = TRAINING_CONFIG["total_timesteps"] - checkpoint_steps
    
    print(f"\n📊 Progress:")
    print(f"   Completed: {checkpoint_steps:,} steps")
    print(f"   Remaining: {remaining_steps:,} steps")
    
    if remaining_steps <= 0:
        print("\n✅ Training already complete! Proceed to evaluation.")
    else:
        # Environments and callbacks must exist (run the setup cells first if not done)
        try:
            train_env, callback_list
        except NameError:
            print("\n⚠️ Environments or callbacks not initialized. Run these cells first:")
            print("   1. Imports")
            print("   2. Training Configuration")
            print("   3. Environment factories")
            print("   4. Setup training (directories, environments, model)")
            print("   5. Setup callbacks")
            print("   Then come back here.")
        else:
            # Load model with environment
            print(f"\n🧠 Loading model from checkpoint...")
            model = PPO.load(str(latest_checkpoint), env=train_env)
            print(f"   Model loaded successfully")
            
            # Resume training
            print(f"\n🚀 Resuming training for {remaining_steps:,} more steps...")
            print("=" * 60)
            
            start_time = time.time()
            
            model.learn(
                total_timesteps=remaining_steps,
                callback=callback_list,
                reset_num_timesteps=False,  # Critical: continue step counter
                progress_bar=True
            )
            
            training_time = time.time() - start_time
            print("\n" + "=" * 60)
            print("✅ RESUMED TRAINING COMPLETE")
            print("=" * 60)
            print(f"Additional training time: {training_time/3600:.2f} hours")
            
            # Save final model
            final_model_path = log_dir / "final_model"
            model.save(str(final_model_path))
            print(f"💾 Final model saved to: {final_model_path}")

---
## 5️⃣ Evaluation & Benchmarking <a name="evaluation"></a>

Evaluate the trained model against the TBS baseline on a sample of up to five environments from each complexity level.

In [None]:
# Load best model
best_model_path = log_dir / "best_model" / "best_model.zip"
if best_model_path.exists():
    model = PPO.load(str(best_model_path))
    print(f"✅ Loaded best model from: {best_model_path}")
else:
    print("⚠️ Best model not found, using final model")
    model = PPO.load(str(log_dir / "final_model.zip"))

In [None]:
# Comprehensive evaluation across all complexity levels
print("📊 Comprehensive Evaluation")
print("=" * 60)

report = ComparisonReport()
n_eval = 5  # Episodes per environment
episode_length = TRAINING_CONFIG['episode_length']

for complexity in ["simple", "moderate", "complex", "extreme"]:
    configs = suite.get_by_complexity(complexity)
    n_samples = min(5, len(configs))  # Evaluate on a subset for speed
    
    print(f"\n🔍 Evaluating {complexity.upper()} environments ({n_samples} samples)...")
    
    for i, config in enumerate(configs[:n_samples]):
        env = create_env_from_config(config, REWARD_CONFIG, episode_length)
        
        # Evaluate RL
        rl_result = evaluate_policy(
            policy=model,
            env=env,
            n_episodes=n_eval,
            max_steps=episode_length,
            policy_name="RL",
            env_id=config.env_id,
            complexity=complexity
        )
        report.add_result(rl_result)
        
        # Evaluate TBS (report failures instead of silently skipping them)
        try:
            tbs = get_tbs_policy_for_env(env)
            tbs_result = evaluate_policy(
                policy=tbs,
                env=env,
                n_episodes=n_eval,
                max_steps=episode_length,
                policy_name="TBS",
                env_id=config.env_id,
                complexity=complexity
            )
            report.add_result(tbs_result)
        except Exception as e:
            print(f"   ⚠️ TBS evaluation failed on {config.env_id}: {e}")
        
        env.close()
        print(f"   [{i+1}/{n_samples}] {config.env_id}: RL cost={rl_result.mean_cost:.1f}")

print("\n✅ Evaluation complete")
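
The loop above always evaluates the first few configs of each level. If the suite is ordered (for example by a parameter sweep), those may not be representative. A sketch of a seeded, uniformly random subset per level instead; `sample_per_level` is a hypothetical helper:

```python
import random
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def sample_per_level(
    configs_by_level: Dict[str, Sequence[T]], k: int, seed: int = 0
) -> Dict[str, List[T]]:
    """Draw up to `k` configs per complexity level, uniformly without replacement."""
    rng = random.Random(seed)  # one RNG, consumed in dict order, so results are reproducible
    return {
        level: rng.sample(list(configs), min(k, len(configs)))
        for level, configs in configs_by_level.items()
    }
```

It could be fed with `{c: suite.get_by_complexity(c) for c in ["simple", "moderate", "complex", "extreme"]}`.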

---
## 6️⃣ Results Analysis <a name="results"></a>

Analyze and visualize the performance comparison.

In [None]:
# Generate summary report
print(generate_performance_report(report))

In [None]:
# Show detailed results
df = report.to_dataframe()
print("\n📋 Detailed Results")
display(df.head(20))

In [None]:
# Summary by complexity
summary = report.get_summary_by_complexity()
print("\n📊 Summary by Complexity Level")
display(summary)

In [None]:
# Visualize comparison
fig = visualize_comparison(report, save_path=str(log_dir / "comparison.png"))
plt.show()

In [None]:
# RL vs TBS cost ratio
ratio = report.get_rl_vs_tbs_ratio()
print("\n🏆 RL vs TBS Cost Ratio (lower is better for RL)")
print("=" * 50)
display(ratio)

print("\nInterpretation:")
print("  Ratio < 1.0 → RL outperforms TBS")
print("  Ratio = 1.0 → Equal performance")
print("  Ratio > 1.0 → TBS outperforms RL")
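
How `get_rl_vs_tbs_ratio()` aggregates is not shown here, and the choice matters: a ratio of mean costs weights expensive environments more heavily than a mean of per-environment ratios. A sketch of the ratio-of-means version from `(policy, complexity, mean_cost)` tuples; `cost_ratio_by_level` is a hypothetical helper, and the tuple layout is an assumption rather than the schema of `report.to_dataframe()`:

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Iterable, Tuple


def cost_ratio_by_level(results: Iterable[Tuple[str, str, float]]) -> Dict[str, float]:
    """Mean RL cost divided by mean TBS cost, per complexity level.

    Levels missing either policy are skipped. Assumes positive costs,
    so a ratio below 1.0 means RL is cheaper on average.
    """
    costs = defaultdict(lambda: defaultdict(list))
    for policy, level, cost in results:
        costs[level][policy].append(cost)
    ratios = {}
    for level, by_policy in costs.items():
        if by_policy.get("RL") and by_policy.get("TBS"):
            ratios[level] = mean(by_policy["RL"]) / mean(by_policy["TBS"])
    return ratios
```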

---
## 7️⃣ Model Export <a name="export"></a>

Export the trained model and results for deployment or further analysis.

In [None]:
# Save evaluation report
report_path = log_dir / "evaluation_report.json"
report.save(str(report_path))
print(f"📄 Evaluation report saved: {report_path}")

# Save training config
config_path = log_dir / "training_config.json"
with open(config_path, 'w') as f:
    json.dump(TRAINING_CONFIG, f, indent=2)
print(f"📄 Training config saved: {config_path}")

In [None]:
# Download files (Colab only)
if IN_COLAB:
    from google.colab import files
    
    # Zip logs directory
    !zip -r logs.zip logs/
    
    print("\n📥 Download trained model and results:")
    files.download('logs.zip')
    print("\n✅ Download started")
else:
    print(f"\n📁 All files saved in: {log_dir.absolute()}")

---
## 🎉 Training Complete!

### Summary

You have successfully trained a PPO agent on the Perishable Inventory MDP with:

- ✅ 5M training steps with learning-rate and entropy annealing
- ✅ Curriculum learning through 4 complexity levels
- ✅ Cost-aware observations and asymmetric action space
- ✅ Benchmarking against the TBS baseline during and after training
- ✅ Evaluation on up to 5 environments per complexity level from the 105-environment suite

### Next Steps

1. **Analyze results**: Review the RL vs TBS comparison by complexity
2. **Fine-tune**: Adjust hyperparameters if needed and retrain
3. **Deploy**: Use the trained model for inventory optimization
4. **Extend**: Add more complex scenarios or multi-item support

### Files Generated

| File | Description |
|------|-------------|
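| `logs/final_model.zip` | Final trained model |
| `logs/best_model/best_model.zip` | Best model by mean evaluation reward (single evaluation environment) |
| `logs/checkpoints/` | Periodic checkpoints (`ppo_perishable_*_steps.zip`) |
| `logs/evaluation_report.json` | Full evaluation results |
| `logs/comparison.png` | Visualization chart |
| `logs/training_config.json` | Training hyperparameters |
| `logs/eval/` | `EvalCallback` evaluation history |
| `logs/tensorboard/` | TensorBoard logs (learning curves) |
| `logs.zip` | Zipped `logs/` directory (Colab download) |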