# 🏋️ PPO-LLM Strategy Shaping: Training

**Train all baseline methods** across different environment perturbations.

## Baselines:
- **Baseline** - Vanilla PPO
- **PPO+LLM** - PPO with LLM reward shaping (trained for 600K steps; the other baselines use 1M)
- **CC_PPO** - Centralized Critic PPO
- **SP_PPO** - Self-Play PPO
- **HARL** - Heterogeneous-Agent RL
- **PBT_PPO** - Population-Based Training (simplified: a single model whose learning rate is randomly perturbed when returns plateau)

## Environments:
- **No Noise** - Clean observations
- **Noise** - Gaussian observation noise (σ=0.01)
- **Delay** - on each step, a 20% chance of a simulated delay that subtracts a 0.5 reward penalty (the action itself is still executed)
- **Combo** - Both noise and delay

⚠️ **Run `01_setup.ipynb` first!**

In [None]:
# =========================================================
# TRAINING SCRIPT (PARALLEL + FIXED STOPPING + 2-AGENT ENV)
# =========================================================
import os, re, gc, json, time, csv, random, math, pickle
import numpy as np
import gymnasium as gym
import torch, pandas as pd
import torch.nn as nn
from joblib import Parallel, delayed

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor

from transformers import AutoTokenizer, AutoModelForCausalLM
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, OvercookedGridworld
from overcooked_ai_py.mdp.actions import Action

print("Imports loaded!")  # reached only if every import above succeeded

## Configuration

In [None]:
# ==========================================
# CONFIGURATION
# ==========================================
LAYOUT = "asymmetric_advantages"  # Overcooked layout, passed to OvercookedGridworld.from_layout_name
HORIZON = 400                     # per-episode step cap, passed to OvercookedEnv.from_mdp
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
SEEDS = [1001, 2002, 3003, 4004, 5005]  # one training run per seed

CHECKPOINT_EVERY_STEPS = 50_000   # CheckpointCallback save_freq
LOG_EVERY_STEPS = 2_048
EVAL_EPISODES = 10                # episodes averaged by the post-training evaluation

# Shared results table (one row appended per finished run) and per-run output root.
RESULTS_CSV = "/content/drive/MyDrive/results_combined_new.csv"
RUNS_DIR = "/content/drive/MyDrive/runs"
LLM_BONUS = 0.2                   # reward added on steps where the LLM answers "good"

# Training budget in environment steps, per baseline. PPO+LLM gets a reduced budget.
BASELINE_STEPS = {
    **dict.fromkeys(("Baseline", "CC_PPO", "SP_PPO", "HARL"), 1_000_000),
    "PPO+LLM": 600_000,
    "PBT_PPO": 1_000_000,
}

NUM_ACTIONS = len(Action.ALL_ACTIONS)  # discrete actions available to each agent

print(f"Device: {DEVICE}\nLayout: {LAYOUT}\nSeeds: {SEEDS}")

## LLM Setup (Lazy Loading)

In [None]:
# ==========================================
# LLM SETUP (LAZY LOAD)
# ==========================================
_GLOBAL_LLM = None  # None = not loaded yet; False = load failed; otherwise the model
_GLOBAL_TOK = None


def get_llm():
    """Return the process-wide ``(model, tokenizer)`` pair, loading it on first use.

    The GPT-Neo 1.3B checkpoint is loaded lazily, so code paths that never ask
    the LLM anything do not pay for it. A failed load is cached as
    ``(False, None)``: later calls return immediately instead of retrying the
    load each time (``is_good`` calls this once per query).

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(False, None)`` if loading
        failed.
    """
    global _GLOBAL_LLM, _GLOBAL_TOK
    if _GLOBAL_LLM is None:
        try:
            tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
            llm = AutoModelForCausalLM.from_pretrained(
                "EleutherAI/gpt-neo-1.3B",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ).to(DEVICE).eval()
        except Exception as err:
            # Deliberately best-effort: training continues without LLM shaping,
            # but the reason is reported instead of being silently swallowed.
            print(f"[get_llm] LLM load failed; LLM shaping disabled: {err!r}")
            # Reset both globals so a tokenizer loaded before the model failed
            # is not left behind half-initialised.
            _GLOBAL_LLM, _GLOBAL_TOK = False, None
        else:
            _GLOBAL_LLM, _GLOBAL_TOK = llm, tok
    return _GLOBAL_LLM, _GLOBAL_TOK

# Verdicts memoised per prompt. Only completed LLM evaluations are stored, so a
# query that raised is retried on the next call. The size cap keeps arbitrary
# callers from growing the cache without limit.
_IS_GOOD_CACHE = {}
_IS_GOOD_CACHE_MAX = 4096


@torch.no_grad()
def is_good(prompt: str) -> bool:
    """Ask the LLM whether *prompt* is better continued with " good" or " bad".

    Runs one forward pass and compares the next-token logits of the first token
    of " good" and of " bad". The model is in eval mode and nothing is sampled,
    so the verdict depends only on the prompt. Results are therefore memoised:
    the LLM-shaping wrapper asks about a small fixed set of prompts on every
    environment step.

    Returns:
        True if " good" scores higher. False if it does not, if the LLM could
        not be loaded, if either word has no token id, or if the forward pass
        raised.
    """
    cached = _IS_GOOD_CACHE.get(prompt)
    if cached is not None:
        return cached

    llm, tok = get_llm()
    if not llm:
        return False
    try:
        enc = tok(prompt, return_tensors="pt").to(DEVICE)
        logits = llm(**enc).logits[0, -1]  # scores for the token following the prompt
        good_ids = tok.encode(" good", add_special_tokens=False)
        bad_ids = tok.encode(" bad", add_special_tokens=False)
        verdict = bool(good_ids and bad_ids) and (
            logits[good_ids[0]].item() > logits[bad_ids[0]].item()
        )
    except Exception:
        # Best-effort shaping: a failed query yields no bonus rather than
        # crashing a long training run. Not cached, so it is retried.
        return False

    if len(_IS_GOOD_CACHE) < _IS_GOOD_CACHE_MAX:
        _IS_GOOD_CACHE[prompt] = verdict
    return verdict

print("LLM functions defined!")

## Environment Wrappers (True 2-Agent)

In [None]:
# ==========================================
# ENVIRONMENT WRAPPERS
# ==========================================

class OCWrapper(gym.Env):
    """Two-agent Overcooked exposed as a single Gymnasium environment.

    - observation: agent 0's featurized view of the state (the first array
      returned by ``featurize_state_mdp``), flattened to float32
    - action_space: ``MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS])``, one index
      into ``Action.ALL_ACTIONS`` per agent, so one policy drives both agents
    - reward: the single shared reward returned by ``OvercookedEnv.step``

    An episode that ends because the horizon was reached is reported as
    ``truncated`` (a time limit), not ``terminated``, as the Gymnasium step
    API prescribes; any other ``done`` is reported as ``terminated``.
    """
    metadata = {"render_modes": []}

    def __init__(self, layout):
        super().__init__()
        mdp = OvercookedGridworld.from_layout_name(layout)
        self.oc = OvercookedEnv.from_mdp(mdp, horizon=HORIZON)
        first_obs = self._featurize(self.oc.state)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=first_obs.shape, dtype=np.float32
        )
        self.action_space = gym.spaces.MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS])

    def _featurize(self, state):
        """Return agent 0's featurization of *state* as a flat float32 vector."""
        o0, _ = self.oc.featurize_state_mdp(state)
        return o0.flatten().astype(np.float32)

    def reset(self, seed=None, options=None):
        """Start a new Overcooked episode.

        A non-None *seed* also reseeds the global ``random``, NumPy and torch
        RNGs (the perturbation wrappers draw from ``np.random``).
        """
        super().reset(seed=seed)  # seeds self.np_random per the Gymnasium API
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        self.oc.reset()
        return self._featurize(self.oc.state), {}

    def step(self, action):
        a0, a1 = int(action[0]), int(action[1])
        joint = [Action.ALL_ACTIONS[a0], Action.ALL_ACTIONS[a1]]
        state, r, done, info = self.oc.step(joint)
        done = bool(done)
        # OvercookedEnv.step reports a single `done`. Hitting the horizon is a
        # time limit (truncation) so the learner may bootstrap; anything else
        # is a genuine termination.
        truncated = done and state.timestep >= self.oc.horizon
        terminated = done and not truncated
        return self._featurize(state), float(r), terminated, truncated, info


class OCWrapperNoise(OCWrapper):
    """OCWrapper with zero-mean Gaussian noise (std ``noise_std``) on observations.

    Noise is applied to the observation returned by ``reset`` as well as by
    ``step``, so the first frame of an episode is perturbed like every other.
    """
    noise_std = 0.01

    def _add_noise(self, obs):
        """Return *obs* plus fresh Gaussian noise drawn from ``np.random``, as float32."""
        return (obs + np.random.normal(0, self.noise_std, size=obs.shape)).astype(np.float32)

    def reset(self, seed=None, options=None):
        obs, info = super().reset(seed=seed, options=options)
        return self._add_noise(obs), info

    def step(self, action):
        obs, r, term, trunc, info = super().step(action)
        return self._add_noise(obs), r, term, trunc, info


class OCWrapperDelay(OCWrapper):
    """OCWrapper whose reward is randomly penalised to model a "delay" event.

    On each step, with probability ``noise_prob`` the team reward is reduced
    by ``delay_penalty``.
    NOTE(review): despite the name, the submitted joint action is executed
    unchanged; no action is actually delayed or repeated.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob, self.delay_penalty = noise_prob, delay_penalty

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        delayed = np.random.rand() < self.noise_prob
        if delayed:
            reward = reward - self.delay_penalty
        return obs, reward, terminated, truncated, info


class OCWrapperCombo(OCWrapper):
    """OCWrapper with both perturbations: observation noise and a random reward penalty.

    - Gaussian noise (std ``noise_std``) is added to the observations returned
      by ``reset`` and by ``step``.
    - On each step, with probability ``noise_prob`` the reward is reduced by
      ``delay_penalty``; the action itself is executed unchanged.
    """
    noise_std = 0.01

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob = noise_prob
        self.delay_penalty = delay_penalty

    def _add_noise(self, obs):
        """Return *obs* plus fresh Gaussian noise drawn from ``np.random``, as float32."""
        return (obs + np.random.normal(0, self.noise_std, size=obs.shape)).astype(np.float32)

    def reset(self, seed=None, options=None):
        # Noise the first observation too, so it is consistent with step().
        obs, info = super().reset(seed=seed, options=options)
        return self._add_noise(obs), info

    def step(self, action):
        obs, r, term, trunc, info = super().step(action)
        obs = self._add_noise(obs)
        if np.random.rand() < self.noise_prob:
            r -= self.delay_penalty
        return obs, r, term, trunc, info


class OCWrapperLLM(OCWrapper):
    """Adds ``LLM_BONUS`` to the reward whenever ``is_good`` approves the joint action.

    The prompt names only the two agents' action symbols (from
    ``Action.ACTION_TO_CHAR``); it carries no information about the state.
    """

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        sym0, sym1 = (
            Action.ACTION_TO_CHAR[Action.ALL_ACTIONS[int(idx)]]
            for idx in (action[0], action[1])
        )
        question = (
            f"In cooperative cooking, are joint actions '{sym0}' and '{sym1}' "
            "helpful? Answer good or bad."
        )
        if is_good(question):
            reward += LLM_BONUS
        return obs, reward, terminated, truncated, info


# Combined wrappers for LLM + perturbations. OCWrapperLLM is listed first, so
# its step() delegates to the perturbation wrapper via super() and then adds
# the LLM bonus on top of the perturbed reward.
class _LLMNoise(OCWrapperLLM, OCWrapperNoise):
    """LLM reward shaping with noisy observations."""


class _LLMDelay(OCWrapperLLM, OCWrapperDelay):
    """LLM reward shaping with OCWrapperDelay's random reward penalty."""


class _LLMCombo(OCWrapperLLM, OCWrapperCombo):
    """LLM reward shaping with both observation noise and the random reward penalty."""


class HARL(OCWrapper):
    """OCWrapper with an order-progress bonus, used for the "HARL" baseline.

    Adds +1.0 when ``info["orders_remaining"] == 0`` and +0.5 when fewer than
    three orders remain. If the step info has no ``orders_remaining`` entry,
    the reward is left unchanged.

    NOTE(review): the stock Overcooked step info may not contain this key, in
    which case this wrapper behaves exactly like OCWrapper. Confirm which info
    field was intended.
    """

    def step(self, action):
        obs, reward, term, trunc, info = super().step(action)
        orders_remaining = info.get("orders_remaining")
        if orders_remaining is None:
            # A missing key must not be read as "all orders done": defaulting
            # it to 0 would grant the maximum bonus on every single step.
            return obs, reward, term, trunc, info
        if orders_remaining == 0:
            reward += 1.0
        elif orders_remaining < 3:
            reward += 0.5
        return obs, reward, term, trunc, info

class _HARLNoise(HARL, OCWrapperNoise):
    """HARL order shaping with noisy observations."""


class _HARLDelay(HARL, OCWrapperDelay):
    """HARL order shaping with OCWrapperDelay's random reward penalty."""


class _HARLCombo(HARL, OCWrapperCombo):
    """HARL order shaping with both observation noise and the random reward penalty."""


class SP_PPO(OCWrapper):
    """Self-play env: identical to OCWrapper, where one policy already controls both agents."""


print("Environment wrappers defined!")

## Environment Factories

In [None]:
# ==========================================
# ENVIRONMENT FACTORIES
# ==========================================

def make_env(env_name, layout):
    """Return a Monitor-wrapped evaluation env for the named perturbation.

    *env_name* is matched case-insensitively against "No Noise", "Noise",
    "Delay" and "Combo"; any other name raises KeyError. No baseline-specific
    reward shaping is added.
    """
    key = env_name.lower()
    env_cls = {
        "no noise": OCWrapper,
        "noise": OCWrapperNoise,
        "delay": OCWrapperDelay,
        "combo": OCWrapperCombo,
    }[key]
    return Monitor(env_cls(layout))


def make_train_env(baseline, layout, env_name):
    """Build the Monitor-wrapped training env for *baseline* under *env_name*.

    PPO+LLM and HARL get their reward-shaping wrapper combined with the
    requested perturbation. Every other baseline (Baseline, CC_PPO, SP_PPO,
    PBT_PPO, or any unrecognised name) trains on the same env that
    ``make_env`` builds for evaluation.

    Raises:
        KeyError: if *env_name* is not one of "No Noise", "Noise", "Delay",
            "Combo" (case-insensitive).
    """
    shaped_wrappers = {
        "ppo+llm": {
            "no noise": OCWrapperLLM,
            "noise": _LLMNoise,
            "delay": _LLMDelay,
            "combo": _LLMCombo,
        },
        "harl": {
            "no noise": HARL,
            "noise": _HARLNoise,
            "delay": _HARLDelay,
            "combo": _HARLCombo,
        },
    }
    by_env = shaped_wrappers.get(baseline.lower())
    if by_env is not None:
        return Monitor(by_env[env_name.lower()](layout))

    # SP_PPO deliberately takes this path too. Its wrapper adds nothing over
    # OCWrapper, and routing it here makes it train under the requested
    # perturbation like every other baseline, instead of always on the clean env.
    return make_env(env_name, layout)

print("Factories defined!")

## Training Callbacks

In [None]:
# ==========================================
# CALLBACKS
# ==========================================

class StopTrainingOnMaxSteps(BaseCallback):
    """Ends ``learn()`` once the model's ``num_timesteps`` reaches ``max_steps``.

    Returning False from ``_on_step`` is SB3's signal to stop training. When
    ``learn()`` is called with ``reset_num_timesteps=False`` (resuming from a
    checkpoint), ``num_timesteps`` keeps counting from the checkpoint, so this
    still caps the run at ``max_steps`` in total.
    """

    def __init__(self, max_steps: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_steps = max_steps

    def _on_step(self) -> bool:
        below_limit = self.num_timesteps < self.max_steps
        if not below_limit and self.verbose > 0:
            print(f"Stopping: {self.num_timesteps} >= {self.max_steps}")
        return below_limit


class PBTCallback(BaseCallback):
    """Plateau-triggered learning-rate perturbation used by the PBT_PPO baseline.

    Every ``check_freq`` callback calls, the mean episode return over the
    model's ``ep_info_buffer`` is compared with the best seen so far. A check
    that fails to beat the best by more than 0.5 adds one to ``patience``.
    After two such checks the learning rate is multiplied by 0.8 or 1.2
    (chosen at random) and ``patience`` resets.

    NOTE(review): only one model is trained; there is no population and no
    weight-copying step, so this is a simplified stand-in for PBT.
    """

    def __init__(self, check_freq=5000, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.best_mean_reward = -np.inf
        self.patience = 0

    def _set_learning_rate(self, new_lr):
        """Make *new_lr* the model's learning rate for all future updates.

        SB3 re-applies ``model.lr_schedule`` to the optimizer at the start of
        every ``train()`` call. Editing ``param_groups`` alone would therefore
        be overwritten before it took effect, so the schedule is replaced with
        a constant as well.
        """
        self.model.learning_rate = new_lr
        self.model.lr_schedule = lambda _progress_remaining: new_lr
        for pg in self.model.policy.optimizer.param_groups:
            pg["lr"] = new_lr

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq != 0:
            return True

        if len(self.model.ep_info_buffer) > 0:
            mean_reward = np.mean([ep["r"] for ep in self.model.ep_info_buffer])
        else:
            mean_reward = -np.inf  # no finished episodes yet: counts as no progress

        if mean_reward <= self.best_mean_reward + 0.5:
            self.patience += 1
        else:
            self.best_mean_reward = mean_reward
            self.patience = 0

        if self.patience >= 2:
            old_lr = float(self.model.learning_rate)
            new_lr = float(old_lr * np.random.choice([0.8, 1.2]))
            self._set_learning_rate(new_lr)
            if self.verbose > 0:
                print(f"[PBT] learning rate {old_lr:.3g} -> {new_lr:.3g}")
            self.patience = 0
        return True

print("Callbacks defined!")

## Training Functions

In [None]:
# ==========================================
# TRAINING FUNCTIONS
# ==========================================

def set_global_seed(seed):
    """Seed Python ``random``, NumPy's global generator, torch (CPU, plus every
    CUDA device when available) and Stable-Baselines3 via ``set_random_seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    set_random_seed(seed)


def evaluate(agent, env, episodes=EVAL_EPISODES):
    """Run *agent* greedily (``deterministic=True``) on *env* for *episodes* episodes.

    An episode ends when the env reports either ``terminated`` or
    ``truncated``. Returns ``(mean, std)`` of the per-episode summed rewards,
    as Python floats.
    """
    episode_returns = []
    for _episode in range(episodes):
        obs, _info = env.reset()
        running_total = 0.0
        while True:
            action, _hidden = agent.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _info = env.step(action)
            running_total += float(reward)
            if terminated or truncated:
                break
        episode_returns.append(running_total)
    return float(np.mean(episode_returns)), float(np.std(episode_returns))


# Model snapshots written by CheckpointCallback(name_prefix="ppo") are named
# "ppo_<num_timesteps>_steps.zip".
_CKPT_NAME_RE = re.compile(r"^ppo_(\d+)_steps\.zip$")


def _latest_checkpoint(ckpt_dir):
    """Return the path of the highest-step ``ppo_<N>_steps.zip`` in *ckpt_dir*, or None."""
    if not os.path.isdir(ckpt_dir):
        return None
    found = []
    for name in os.listdir(ckpt_dir):
        match = _CKPT_NAME_RE.match(name)
        if match:  # unrelated/odd file names are ignored instead of crashing the sort
            found.append((int(match.group(1)), name))
    if not found:
        return None
    return os.path.join(ckpt_dir, max(found)[1])


def _append_result_row(row):
    """Append *row* to RESULTS_CSV.

    A failed write is reported rather than raised, so a run whose model was
    already saved is not turned into a crash.
    """
    try:
        with open(RESULTS_CSV, "a", newline="") as f:
            csv.writer(f).writerow(row)
    except Exception as err:
        print(f"[train_one_run] could not append result row to {RESULTS_CSV}: {err!r}")


def train_one_run(baseline, env_name, seed, steps_total):
    """Train, save and evaluate one (baseline, env, seed) configuration.

    Outputs go to ``RUNS_DIR/<baseline>/<env>/seed_<seed>/``:
    ``checkpoints/``, ``final_model.zip`` and ``meta.json``. A run whose
    ``final_model.zip`` already exists is skipped. Otherwise training resumes
    from the newest checkpoint when one exists, and is capped at
    *steps_total* environment steps in total.

    Returns:
        ``(baseline, env_name, seed, mean_eval_return)`` for a completed run,
        or None if the run was skipped.
    """
    set_global_seed(seed)
    label = f"{baseline}|{env_name}|{seed}"
    safe_base, safe_env = baseline.replace(" ", "_"), env_name.replace(" ", "_")
    run_dir = os.path.join(RUNS_DIR, safe_base, safe_env, f"seed_{seed}")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    final_path = os.path.join(run_dir, "final_model.zip")
    # Checked before any env is built, so finished runs don't construct
    # (and leak) two environments.
    if os.path.exists(final_path):
        print(f"‚è≠Ô∏è Skipping {label} - already done")
        return None

    train_env = make_train_env(baseline, LAYOUT, env_name)
    eval_env = make_env(env_name, LAYOUT)
    try:
        ckpt = _latest_checkpoint(ckpt_dir)
        if ckpt:
            agent = PPO.load(ckpt, env=train_env, device=DEVICE, verbose=1)
        else:
            agent = PPO(
                "MlpPolicy", train_env,
                n_steps=2048, batch_size=2048, learning_rate=3e-4,
                gamma=0.99, verbose=1, seed=seed, device=DEVICE,
            )

        callbacks = [
            CheckpointCallback(save_freq=CHECKPOINT_EVERY_STEPS, save_path=ckpt_dir, name_prefix="ppo"),
            # Enforces the absolute budget; on resume the step counter is not
            # reset, so this stops at steps_total overall.
            StopTrainingOnMaxSteps(max_steps=steps_total, verbose=1),
        ]
        if baseline == "PBT_PPO":
            callbacks.append(PBTCallback(check_freq=10000))

        t0 = time.time()
        agent.learn(total_timesteps=steps_total, callback=callbacks, reset_num_timesteps=ckpt is None)
        minutes = (time.time() - t0) / 60.0

        agent.save(final_path)
        with open(os.path.join(run_dir, "meta.json"), "w") as f:
            json.dump({"baseline": baseline, "env": env_name, "seed": seed, "steps": steps_total}, f)

        mean, std = evaluate(agent, eval_env, episodes=EVAL_EPISODES)
    finally:
        train_env.close()
        eval_env.close()

    _append_result_row([baseline, env_name, seed, "final", round(mean, 4), round(std, 4), round(minutes, 3)])

    print(f"‚úÖ [{label}] FINISH in {minutes:.2f} min | eval {mean:.2f}¬±{std:.2f}")
    return (baseline, env_name, seed, mean)

print("Training functions defined!")

## 🚀 Run Training

**Warning**: Full training takes several hours on a GPU. For quick tests, trim `baselines`, `env_names`, or `SEEDS`.

In [None]:
# Create the shared results table with its header on first use; later runs
# (including re-runs of this notebook) append to the existing file.
if os.path.exists(RESULTS_CSV):
    print(f"Appending to existing {RESULTS_CSV}")
else:
    with open(RESULTS_CSV, "w", newline="") as f:
        csv.writer(f).writerow(
            ["baseline", "env", "seed", "phase", "mean_return", "std_dev", "train_minutes"]
        )
    print(f"Created {RESULTS_CSV}")

In [None]:
# ==========================================
# LAUNCH PARALLEL TRAINING
# ==========================================

baselines = ["Baseline", "PPO+LLM", "CC_PPO", "SP_PPO", "HARL", "PBT_PPO"]
env_names = ["No Noise", "Noise", "Delay", "Combo"]

# One job per (baseline, env, seed), baseline-major; the step budget comes
# from BASELINE_STEPS, falling back to 1M for unlisted baselines.
all_jobs = [
    (b, e, s, BASELINE_STEPS.get(b, 1_000_000))
    for b in baselines
    for e in env_names
    for s in SEEDS
]

print(f"Total jobs: {len(all_jobs)}")
print(f"üöÄ Launching {len(all_jobs)} parallel jobs...")

# Each job is an independent train_one_run call; up to 20 workers run at once.
# NOTE(review): each worker running PPO+LLM loads its own copy of the LLM via
# get_llm, so check GPU memory before raising n_jobs.
Parallel(n_jobs=20)(delayed(train_one_run)(*job) for job in all_jobs)

print("\nüéâ ALL RUNS COMPLETE.")