# PPO-LLM Strategy Shaping Experiments

This notebook provides an interactive interface for running and analyzing the PPO-LLM Strategy Shaping experiments.

## Contents
1. Setup & Configuration
2. Training All Baselines
3. Nash Gap Analysis
4. Latency Measurement
5. Robustness Analysis
6. Task Completion Metrics
7. Summary Report
8. Appendix: Environment Visualization

## 1. Setup & Configuration

In [None]:
# Install the package in editable mode if needed (uncomment on first run).
# Prefer %pip over !pip so the install targets this kernel's environment.
# %pip install -e ..

In [None]:
import sys

# Make the in-repo sources importable when the package is not installed.
# The membership check keeps this cell idempotent: re-running it does not
# prepend another copy of the same entry to sys.path.
if '../src' not in sys.path:
    sys.path.insert(0, '../src')

from ppo_llm_strategy_shaping import (
    Config,
    train_all,
    train_one_run,
    load_or_train,
    make_env,
    evaluate,
    nash_gap_analysis,
    latency_analysis,
    robustness_analysis,
    task_completion_analysis,
    generate_summary_report,
)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Apply seaborn's "whitegrid" theme for the plots in this notebook.
sns.set_theme(style="whitegrid")
# Confirmation message, reached only if every import above succeeded.
print("Imports successful!")

In [None]:
# Experiment configuration used by the training and analysis cells below.
# Settings are grouped by concern and then merged into a single Config.

# PPO hyperparameters
ppo_hyperparams = dict(
    lr=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
)

# LLM settings
llm_settings = dict(
    llm_model="EleutherAI/gpt-neo-125M",  # Use smaller model for faster experiments
    llm_reward_scale=0.1,
)

# Perturbation settings
perturbation_settings = dict(
    noise_std=0.01,
    delay_prob=0.2,
    delay_penalty=-0.5,
)

config = Config(
    layout="cramped_room",
    horizon=400,
    seeds=[42, 123, 456, 789, 1000],
    **ppo_hyperparams,
    **llm_settings,
    **perturbation_settings,
)

# Echo the key settings, including the baseline names and environment list
# that are read back from the config object.
print(f"Layout: {config.layout}")
print(f"Horizon: {config.horizon}")
print(f"Seeds: {config.seeds}")
print(f"Baselines: {list(config.baseline_steps.keys())}")
print(f"Environments: {config.all_envs}")

## 2. Training All Baselines

Train all baseline methods across all environment perturbations and seeds.

**Note**: Full training takes several hours. For quick testing, reduce seeds or use a subset of baselines.

In [None]:
# Quick test: train single configuration
# model, run_dir, train_time = train_one_run(
#     config,
#     baseline="PPO+LLM",
#     env_name="No Noise",
#     seed=42,
#     verbose=1
# )
# print(f"Training completed in {train_time:.1f}s")
# print(f"Model saved to: {run_dir}")

In [None]:
# Full training (uncomment to run)
# WARNING: This will take many hours!

# results = train_all(
#     config,
#     baselines=["Baseline", "PPO+LLM"],  # Start with subset
#     env_names=["No Noise", "Noise"],
#     seeds=[42, 123],
#     n_jobs=4,  # Parallel workers
#     verbose=1
# )
# print(f"Completed {len(results)} training runs")

## 3. Nash Gap Analysis

Compute Nash gaps to measure how close the trained policies are to a Nash equilibrium.

In [None]:
# Run Nash gap analysis (requires trained models)
# nash_results = nash_gap_analysis(
#     config,
#     n_episodes=5,
#     output_csv=config.results_root + "/nash_gap.csv"
# )

# # Display results
# for baseline, envs in nash_results.items():
#     print(f"\n{baseline}:")
#     for env_name, gap in envs.items():
#         print(f"  {env_name}: {gap:.4f}")

In [None]:
# Visualize Nash gaps
# df = pd.read_csv(config.results_root + "/nash_gap.csv")
# 
# plt.figure(figsize=(12, 6))
# sns.barplot(data=df, x="baseline", y="nash_gap", hue="env")
# plt.title("Nash Gap by Baseline and Environment")
# plt.ylabel("Nash Gap (lower is better)")
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.show()

## 4. Latency Measurement

Measure inference latency for each baseline to assess real-time deployment feasibility.

In [None]:
# Run latency analysis (requires trained models)
# latency_results = latency_analysis(
#     config,
#     n_steps=1000,
#     output_csv=config.results_root + "/latency.csv"
# )

# # Display results
# for baseline, latency in latency_results.items():
#     print(f"{baseline}: {latency['mean_ms']:.2f}ms (p95: {latency['p95_ms']:.2f}ms)")

In [None]:
# Visualize latency
# df = pd.read_csv(config.results_root + "/latency.csv")
# 
# fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 
# axes[0].bar(df['baseline'], df['mean_ms'], yerr=df['std_ms'], capsize=5)
# axes[0].set_title("Mean Inference Latency")
# axes[0].set_ylabel("Latency (ms)")
# axes[0].tick_params(axis='x', rotation=45)
# 
# axes[1].bar(df['baseline'], df['p95_ms'])
# axes[1].set_title("95th Percentile Latency")
# axes[1].set_ylabel("Latency (ms)")
# axes[1].tick_params(axis='x', rotation=45)
# 
# plt.tight_layout()
# plt.show()

## 5. Robustness Analysis

Test how well policies trained in one environment generalize to others.

In [None]:
# Run robustness analysis (requires trained models)
# robustness_results = robustness_analysis(
#     config,
#     n_episodes=10,
#     output_csv=config.results_root + "/robustness.csv"
# )

In [None]:
# Visualize robustness as heatmaps
# df = pd.read_csv(config.results_root + "/robustness.csv")
#
# baselines_to_plot = ["Baseline", "PPO+LLM"]
#
# # squeeze=False always returns a 2-D array of Axes, so the loop below also
# # works when baselines_to_plot holds a single entry.
# fig, axes = plt.subplots(
#     1, len(baselines_to_plot), figsize=(6*len(baselines_to_plot), 5), squeeze=False
# )
#
# for ax, baseline in zip(axes.flat, baselines_to_plot):
#     subset = df[df['baseline'] == baseline].groupby(['train_env', 'test_env'])['mean_reward'].mean().unstack()
#     sns.heatmap(subset, annot=True, fmt='.1f', cmap='YlGnBu', ax=ax)
#     ax.set_title(f"{baseline} Robustness")
#     ax.set_xlabel("Test Environment")
#     ax.set_ylabel("Train Environment")
#
# plt.tight_layout()
# plt.show()

## 6. Task Completion Metrics

Measure actual task performance (dishes served) in Overcooked.

In [None]:
# Run task completion analysis (requires trained models)
# task_results = task_completion_analysis(
#     config,
#     n_episodes=10,
#     output_csv=config.results_root + "/task_completion.csv"
# )

# # Display results
# for baseline, envs in task_results.items():
#     print(f"\n{baseline}:")
#     for env_name, metrics in envs.items():
#         print(f"  {env_name}: {metrics['mean_dishes']:.2f} dishes (rate: {metrics['completion_rate']:.1%})")

In [None]:
# Visualize task completion
# df = pd.read_csv(config.results_root + "/task_completion.csv")
# 
# plt.figure(figsize=(12, 6))
# sns.barplot(data=df, x="baseline", y="mean_dishes", hue="env")
# plt.title("Task Completion by Baseline and Environment")
# plt.ylabel("Mean Dishes Served")
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.show()

## Summary Report

Generate a comprehensive summary of all analyses.

In [None]:
# Generate full summary report (requires all analyses to be run)
# report = generate_summary_report(config)
# print(report)

---

## Appendix: Environment Visualization

In [None]:
# Visualize a single episode (requires trained model)
# from ppo_llm_strategy_shaping import make_env
# from stable_baselines3 import PPO
# 
# env = make_env(config, "PPO+LLM", "No Noise", seed=42)
# model = PPO.load("path/to/model.zip", env=env)
# 
# obs, _ = env.reset()
# for _ in range(100):
#     action, _ = model.predict(obs, deterministic=True)
#     obs, reward, done, truncated, info = env.step(action)
#     # env.render()  # Uncomment if render is available
#     if done or truncated:
#         break