# 07 - LLM Sensitivity Analysis (GPT-Neo 125M)

Trains PPO+LLM, using GPT-Neo 125M (a smaller model) for LLM reward shaping, across all four environment regimes (No Noise, Noise, Delay, Combo) to study sensitivity to LLM size.

**Output:**
- Models saved to `/content/drive/MyDrive/runs_gptneo_125M/`
- Results saved to `/content/drive/MyDrive/results_gptneo_125M.csv`

In [None]:
# =========================================================
# PPO+LLM TRAINING WITH GPT-NEO 125M (ALL FOUR ENV REGIMES)
# - True 2-agent Overcooked (no [a, a] mirroring)
# - Uses LLM shaping
# - Baseline: PPO+LLM only, over No Noise / Noise / Delay / Combo
# - Saves to /content/drive/MyDrive/runs_gptneo_125M and results_gptneo_125M.csv
# =========================================================

import os, re, json, time, csv, random
import numpy as np
import gymnasium as gym
import torch
from joblib import Parallel, delayed

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor

from transformers import AutoTokenizer, AutoModelForCausalLM
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, OvercookedGridworld
from overcooked_ai_py.mdp.actions import Action

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Overcooked layout name handed to OvercookedGridworld.from_layout_name.
LAYOUT = "asymmetric_advantages"
# Episode horizon passed to OvercookedEnv.from_mdp.
HORIZON = 400
# Torch device used for both the shaping LLM and the PPO agents.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# One training run is launched per seed for every (baseline, env) pair.
SEEDS = [1001, 2002, 3003, 4004, 5005]

# save_freq given to CheckpointCallback.
CHECKPOINT_EVERY_STEPS = 50_000
# Default number of episodes used by evaluate().
EVAL_EPISODES = 10

# Tag this run as using GPT-Neo 125M
# (the tag is embedded in output paths, run labels and result rows).
MODEL_TAG = "gptneo_125M"
# Model identifier passed to the transformers from_pretrained loaders.
LLM_MODEL_NAME = "EleutherAI/gpt-neo-125M"

# Root output directory.
# NOTE(review): looks like a Colab Google Drive mount point — confirm it is mounted before running.
BASE_DIR = "/content/drive/MyDrive"
# Per-run model/checkpoint directories live under RUNS_DIR.
RUNS_DIR = os.path.join(BASE_DIR, f"runs_{MODEL_TAG}")
# Single CSV that every run appends its final evaluation row to.
RESULTS_CSV = os.path.join(BASE_DIR, f"results_{MODEL_TAG}.csv")

# Amount added to the step reward whenever is_good() returns True.
LLM_BONUS = 0.2

# Training budget (environment steps) per baseline name.
BASELINE_STEPS = {
    "PPO+LLM": 600_000,
}

# Size of each agent's discrete action set (one entry per Action.ALL_ACTIONS item).
NUM_ACTIONS = len(Action.ALL_ACTIONS)

# Import-time side effect: make sure the runs directory exists.
os.makedirs(RUNS_DIR, exist_ok=True)

# ==========================================
# 2. LLM SETUP (LAZY LOAD)
# ==========================================
# Process-wide cache for the shaping model and its tokenizer.
# _GLOBAL_LLM is None before the first load attempt and False after a failed one.
_GLOBAL_LLM = None
_GLOBAL_TOK = None

def get_llm():
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    The model is loaded in float16 when CUDA is available and float32
    otherwise, moved to ``DEVICE`` and put in eval mode. If loading raises,
    the error is printed and the model slot is set to ``False`` so that later
    calls do not retry; callers should treat a falsy model as "no LLM".

    Returns:
        Tuple ``(model_or_False, tokenizer_or_None)``.
    """
    global _GLOBAL_LLM, _GLOBAL_TOK

    # Anything other than None means a load was already attempted.
    if _GLOBAL_LLM is not None:
        return _GLOBAL_LLM, _GLOBAL_TOK

    try:
        print(f"Loading LLM: {LLM_MODEL_NAME} on {DEVICE}")
        _GLOBAL_TOK = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
        weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_NAME,
            torch_dtype=weights_dtype,
        )
        model = model.to(DEVICE)
        _GLOBAL_LLM = model.eval()
    except Exception as e:
        print(f"Failed to load LLM {LLM_MODEL_NAME}: {e}")
        _GLOBAL_LLM = False

    return _GLOBAL_LLM, _GLOBAL_TOK

# Verdicts already computed by is_good(), keyed by the exact prompt text.
# The verdict depends only on the prompt, so repeating the forward pass for a
# prompt that was already scored is pure overhead.
_IS_GOOD_CACHE = {}


@torch.no_grad()
def is_good(prompt: str) -> bool:
    """Ask the LLM whether ``prompt`` should be answered "good" rather than "bad".

    The prompt is run through the model once and the next-token logits of the
    first token of " good" and " bad" are compared. Successful verdicts are
    memoized per prompt; failures (no model, missing token ids, inference
    errors) are not cached, so they are retried on the next call.

    Args:
        prompt: Text fed to the language model.

    Returns:
        True when the " good" logit is strictly larger than the " bad" logit;
        False otherwise, including whenever the LLM is unavailable or
        inference raises.
    """
    cached = _IS_GOOD_CACHE.get(prompt)
    if cached is not None:
        return cached

    llm, tok = get_llm()
    if not llm:
        return False
    try:
        enc = tok(prompt, return_tensors="pt").to(DEVICE)
        # Logits for the token that would follow the prompt.
        logits = llm(**enc).logits[0, -1]

        def first_token_id(text):
            ids = tok.encode(text, add_special_tokens=False)
            return ids[0] if ids else None

        good_id, bad_id = first_token_id(" good"), first_token_id(" bad")
        if good_id is None or bad_id is None:
            return False
        verdict = logits[good_id].item() > logits[bad_id].item()
    except Exception as e:
        print(f"LLM inference failed: {e}")
        return False

    _IS_GOOD_CACHE[prompt] = verdict
    return verdict

# ==========================================
# 3. ENVIRONMENT WRAPPERS (TRUE 2-AGENT)
# ==========================================
class OCWrapper(gym.Env):
    """
    Gymnasium wrapper that drives both Overcooked agents with one policy.

    - observation: first element returned by ``featurize_state_mdp``,
      flattened and cast to float32
    - action_space: MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS]), one index per agent
    - reward: the scalar reward returned by ``OvercookedEnv.step``
    - the env's ``done`` flag is reported as ``terminated``; ``truncated`` is
      always False
    """
    metadata = {"render.modes": []}

    def __init__(self, layout):
        super().__init__()
        self.oc = OvercookedEnv.from_mdp(
            OvercookedGridworld.from_layout_name(layout),
            horizon=HORIZON,
        )
        # Featurize the initial state once just to learn the observation size.
        sample_obs = self._flat_obs(self.oc.state)
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=sample_obs.shape,
            dtype=np.float32,
        )
        self.action_space = gym.spaces.MultiDiscrete([NUM_ACTIONS] * 2)

    def _flat_obs(self, state):
        """Featurize ``state`` and return the first feature array as flat float32."""
        features, _ = self.oc.featurize_state_mdp(state)
        return features.flatten().astype(np.float32)

    def reset(self, seed=None, options=None):
        """Reset the underlying env; a non-None ``seed`` re-seeds the global RNGs."""
        if seed is not None:
            for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
                seed_fn(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        self.oc.reset()
        return self._flat_obs(self.oc.state), {}

    def step(self, action):
        """Apply one joint action given as a pair of indices into Action.ALL_ACTIONS."""
        first_idx, second_idx = int(action[0]), int(action[1])
        joint_action = [Action.ALL_ACTIONS[first_idx], Action.ALL_ACTIONS[second_idx]]
        next_state, reward, finished, info = self.oc.step(joint_action)
        return self._flat_obs(next_state), float(reward), bool(finished), False, info

class OCWrapperNoise(OCWrapper):
    """OCWrapper whose step() observations get additive Gaussian noise (std 0.01)."""

    def step(self, action):
        clean_obs, reward, terminated, truncated, info = super().step(action)
        jitter = np.random.normal(0, 0.01, size=clean_obs.shape)
        noisy_obs = (clean_obs + jitter).astype(np.float32)
        return noisy_obs, reward, terminated, truncated, info

class OCWrapperDelay(OCWrapper):
    """OCWrapper that randomly subtracts a fixed penalty from the step reward.

    On each step, with probability ``noise_prob``, ``delay_penalty`` is
    subtracted from the reward; observations are passed through unchanged.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob, self.delay_penalty = noise_prob, delay_penalty

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        penalised = np.random.rand() < self.noise_prob
        if penalised:
            reward = reward - self.delay_penalty
        return obs, reward, terminated, truncated, info

class OCWrapperCombo(OCWrapper):
    """OCWrapper with both perturbations: noisy observations and random reward penalty.

    Each step first adds Gaussian noise (std 0.01) to the observation, then,
    with probability ``noise_prob``, subtracts ``delay_penalty`` from the reward.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob, self.delay_penalty = noise_prob, delay_penalty

    def step(self, action):
        clean_obs, reward, terminated, truncated, info = super().step(action)
        # Draw the observation noise before the penalty coin flip so the
        # sequence of RNG draws stays the same.
        jitter = np.random.normal(0, 0.01, size=clean_obs.shape)
        noisy_obs = (clean_obs + jitter).astype(np.float32)
        penalised = np.random.rand() < self.noise_prob
        if penalised:
            reward = reward - self.delay_penalty
        return noisy_obs, reward, terminated, truncated, info

class OCWrapperLLM(OCWrapper):
    """OCWrapper that adds LLM_BONUS to the reward when the LLM approves the joint action."""

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)

        # Describe each agent's action by its display character for the prompt.
        first_idx, second_idx = int(action[0]), int(action[1])
        first_char = Action.ACTION_TO_CHAR[Action.ALL_ACTIONS[first_idx]]
        second_char = Action.ACTION_TO_CHAR[Action.ALL_ACTIONS[second_idx]]
        question = (
            f"In cooperative cooking, are joint actions '{first_char}' and '{second_char}' helpful? "
            f"Answer good or bad."
        )

        if is_good(question):
            reward = reward + LLM_BONUS
        return obs, reward, terminated, truncated, info

# LLM-shaped variants of the perturbed environments.
# OCWrapperLLM is listed first, so its step() runs first and its super().step()
# call resolves to the regime wrapper next in the MRO; the LLM bonus is
# therefore added on top of that wrapper's already-perturbed output.
class _LLMNoise(OCWrapperLLM, OCWrapperNoise):
    pass

# LLM shaping on top of the random reward penalty.
class _LLMDelay(OCWrapperLLM, OCWrapperDelay):
    pass

# LLM shaping on top of both observation noise and the random reward penalty.
class _LLMCombo(OCWrapperLLM, OCWrapperCombo):
    pass

# ---------------------------------------------------------
# FACTORIES
# ---------------------------------------------------------
def make_env(env_name, layout):
    """Build the Monitor-wrapped environment (no LLM shaping) for a regime.

    Args:
        env_name: Regime name, matched case-insensitively against
            "no noise", "noise", "delay" and "combo".
        layout: Overcooked layout name forwarded to the wrapper class.

    Returns:
        A ``Monitor`` around the selected wrapper.

    Raises:
        KeyError: If ``env_name`` is not one of the known regimes.
    """
    regime = env_name.lower()
    wrapper_by_regime = {
        "no noise": OCWrapper,
        "noise": OCWrapperNoise,
        "delay": OCWrapperDelay,
        "combo": OCWrapperCombo,
    }
    wrapper_cls = wrapper_by_regime[regime]
    return Monitor(wrapper_cls(layout))

def make_train_env(baseline, layout, env_name):
    """Build the Monitor-wrapped, LLM-shaped training environment for a regime.

    Args:
        baseline: Baseline name; only "PPO+LLM" (case-insensitive) is
            supported in this script.
        layout: Overcooked layout name forwarded to the wrapper class.
        env_name: Regime name, matched case-insensitively against
            "no noise", "noise", "delay" and "combo".

    Returns:
        A ``Monitor`` around the LLM-shaped wrapper for the regime.

    Raises:
        ValueError: If ``baseline`` is not "PPO+LLM".
        KeyError: If ``env_name`` is not one of the known regimes.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if baseline.lower() != "ppo+llm":
        raise ValueError(
            f"make_train_env only supports the 'PPO+LLM' baseline, got {baseline!r}"
        )

    regime = env_name.lower()
    wrapper_by_regime = {
        "no noise": OCWrapperLLM,
        "noise": _LLMNoise,
        "delay": _LLMDelay,
        "combo": _LLMCombo,
    }
    try:
        wrapper_cls = wrapper_by_regime[regime]
    except KeyError:
        raise KeyError(
            f"Unknown env_name {env_name!r}; expected one of {sorted(wrapper_by_regime)}"
        ) from None
    return Monitor(wrapper_cls(layout))

# ==========================================
# 4. CALLBACKS
# ==========================================
class StopTrainingOnMaxSteps(BaseCallback):
    """Callback that ends training once ``num_timesteps`` reaches ``max_steps``."""

    def __init__(self, max_steps: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_steps = max_steps

    def _on_step(self) -> bool:
        """Return False (stop) when the step budget is reached, True otherwise."""
        budget_reached = self.num_timesteps >= self.max_steps
        if budget_reached and self.verbose > 0:
            print(
                f"Stopping training because the number of steps "
                f"{self.num_timesteps} reached the limit {self.max_steps}"
            )
        return not budget_reached

# ==========================================
# 5. UTILS
# ==========================================
def set_global_seed(seed):
    """Seed Python's ``random``, NumPy, torch (CPU and, if present, CUDA) and SB3."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    set_random_seed(seed)

def evaluate(agent, env, episodes=EVAL_EPISODES):
    """Roll out ``agent`` deterministically and summarise the episode returns.

    Args:
        agent: Object exposing ``predict(obs, deterministic=True)``.
        env: Environment following the five-tuple ``step`` / two-tuple
            ``reset`` API.
        episodes: Number of episodes to run.

    Returns:
        Tuple ``(mean_return, std_return)`` as Python floats.
    """

    def run_episode():
        obs, _ = env.reset()
        total = 0.0
        while True:
            action, _ = agent.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = env.step(action)
            total += float(reward)
            if terminated or truncated:
                return total

    returns = [run_episode() for _ in range(episodes)]
    return float(np.mean(returns)), float(np.std(returns))

# ==========================================
# 6. MAIN TRAINING LOOP
# ==========================================
def _load_latest_checkpoint(ckpt_dir, train_env, label):
    """Return a PPO agent restored from the newest loadable checkpoint, or None."""
    if not os.path.isdir(ckpt_dir):
        return None

    # Rank candidates by the first run of digits in the file name (the step
    # count in CheckpointCallback's "ppo_<steps>_steps.zip" naming). Files
    # without any digits are ignored rather than crashing the sort.
    ranked = []
    for fname in os.listdir(ckpt_dir):
        if not (fname.startswith("ppo_") and fname.endswith(".zip")):
            continue
        match = re.search(r"\d+", fname)
        if match is None:
            continue
        ranked.append((int(match.group()), fname))
    ranked.sort(key=lambda item: item[0], reverse=True)

    # Try newest checkpoints first.
    for _, fname in ranked:
        ckpt_path = os.path.join(ckpt_dir, fname)
        try:
            print(f"Trying checkpoint {ckpt_path} for {label}")
            agent = PPO.load(ckpt_path, env=train_env, device=DEVICE, verbose=1)
        except Exception as e:
            # NOTE(review): any load failure deletes the file, including
            # failures that are not caused by a corrupt checkpoint.
            print(f"Failed to load checkpoint {ckpt_path}: {e}. Removing it.")
            try:
                os.remove(ckpt_path)
            except OSError:
                pass
        else:
            print(f"Resumed {label} from {ckpt_path}")
            return agent
    return None


def train_one_run(baseline, env_name, seed, steps_total):
    """Train (or resume) one PPO run, save it, evaluate it and log the result.

    Output goes under ``RUNS_DIR/<baseline>/<env>/seed_<seed>``: periodic
    checkpoints, ``final_model.zip`` and ``meta.json``. One row with the final
    evaluation is appended to ``RESULTS_CSV``.

    Args:
        baseline: Baseline name, forwarded to ``make_train_env``.
        env_name: Environment regime name.
        seed: Seed used for global RNGs and for a freshly created PPO agent.
        steps_total: Total timestep budget for the run.

    Returns:
        ``(baseline, env_name, seed, mean_eval_return)``, or ``None`` when the
        run already has a ``final_model.zip`` and is skipped.
    """
    set_global_seed(seed)
    label = f"{baseline}|{env_name}|{seed}|{MODEL_TAG}"
    safe_base = baseline.replace(" ", "_")
    safe_env = env_name.replace(" ", "_")
    run_dir = os.path.join(RUNS_DIR, safe_base, safe_env, f"seed_{seed}")
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Check for a finished run before building any environment.
    final_path = os.path.join(run_dir, "final_model.zip")
    if os.path.exists(final_path):
        print(f"Skipping {label}, already has final_model.zip")
        return None

    train_env = make_train_env(baseline, LAYOUT, env_name)
    eval_env = None
    try:
        eval_env = make_env(env_name, LAYOUT)

        agent = _load_latest_checkpoint(ckpt_dir, train_env, label)
        # Keep the timestep counter when resuming so the step limit below
        # applies to the whole run, not just this session.
        reset_flag = agent is None
        if agent is None:
            print(f"Starting fresh run {label}")
            agent = PPO(
                "MlpPolicy",
                train_env,
                n_steps=2048,
                batch_size=2048,
                learning_rate=3e-4,
                gamma=0.99,
                verbose=1,
                seed=seed,
                device=DEVICE,
            )

        callbacks = [
            CheckpointCallback(save_freq=CHECKPOINT_EVERY_STEPS, save_path=ckpt_dir, name_prefix="ppo"),
            StopTrainingOnMaxSteps(max_steps=steps_total, verbose=1),
        ]

        # Wall-clock time of this learn() call only; earlier sessions of a
        # resumed run are not included.
        t0 = time.time()
        agent.learn(
            total_timesteps=steps_total,
            callback=callbacks,
            reset_num_timesteps=reset_flag,
        )
        minutes = (time.time() - t0) / 60.0

        agent.save(final_path)
        with open(os.path.join(run_dir, "meta.json"), "w") as f:
            json.dump(
                {
                    "baseline": baseline,
                    "env": env_name,
                    "seed": seed,
                    "steps": steps_total,
                    "model_tag": MODEL_TAG,
                    "llm_model": LLM_MODEL_NAME,
                },
                f,
            )

        mean, std = evaluate(agent, eval_env, episodes=EVAL_EPISODES)
        with open(RESULTS_CSV, "a", newline="") as f:
            csv.writer(f).writerow(
                [
                    baseline,
                    env_name,
                    seed,
                    "final",
                    round(mean, 4),
                    round(std, 4),
                    round(minutes, 3),
                    MODEL_TAG,
                    LLM_MODEL_NAME,
                ]
            )
    finally:
        # Release the environments whether the run succeeded or failed.
        train_env.close()
        if eval_env is not None:
            eval_env.close()

    print(f"✅ [{label}] FINISH in {minutes:.2f} min | eval {mean:.2f}±{std:.2f}")
    return (baseline, env_name, seed, mean)

# ==========================================
# 7. PARALLEL EXECUTION ENTRY POINT
# ==========================================
if __name__ == "__main__":
    # Write the CSV header once; training runs only ever append rows.
    if not os.path.exists(RESULTS_CSV):
        with open(RESULTS_CSV, "w", newline="") as f:
            csv.writer(f).writerow(
                [
                    "baseline",
                    "env",
                    "seed",
                    "phase",
                    "mean_return",
                    "std_dev",
                    "train_minutes",
                    "model_tag",
                    "llm_model",
                ]
            )

    baselines = ["PPO+LLM"]
    env_names = ["No Noise", "Noise", "Delay", "Combo"]

    # One job per (baseline, env, seed), in that nesting order.
    all_jobs = [
        (b, e, s, BASELINE_STEPS.get(b, 600_000))
        for b in baselines
        for e in env_names
        for s in SEEDS
    ]

    print(f"🚀 Launching PPO+LLM jobs for {MODEL_TAG} over {len(env_names)} envs and {len(SEEDS)} seeds")
    print(f"Total runs: {len(all_jobs)}")  # 1 baseline x 4 envs x 5 seeds = 20

    # n_jobs matches the 20 runs produced by the default configuration above.
    Parallel(n_jobs=20)(
        delayed(train_one_run)(b, e, s, steps)
        for b, e, s, steps in all_jobs
    )

    print("\n🎉 ALL RUNS COMPLETE FOR GPT-NEO 125M.")