# 📊 Nash Gap Analysis

**Compute Nash equilibrium gaps** to measure how close trained policies are to optimal.

## What is Nash Gap?
- Nash gap = Best Response Value - Self-Play Value
- **Lower is better** (0 = Nash equilibrium)
- Measures how much an agent could improve by deviating

⚠️ **Requires trained models from 02_training.ipynb**

In [None]:
# =========================================================
# BR-BASED NASH ANALYSIS (PARALLEL + RESUME + CRASH-SAFE)
# =========================================================

import os, csv, random
import numpy as np
import gymnasium as gym
import torch
from joblib import Parallel, delayed

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed

from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, OvercookedGridworld
from overcooked_ai_py.mdp.actions import Action

print("Imports loaded!")

## Configuration

In [None]:
# ==========================================
# CONFIGURATION
# ==========================================
# Overcooked layout name; passed to OvercookedGridworld.from_layout_name.
LAYOUT = "asymmetric_advantages"
# Episode horizon handed to OvercookedEnv.from_mdp(..., horizon=HORIZON).
HORIZON = 400
# Torch device used when loading trained models and training best responses.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Root of the trained-model tree. Models are looked up at
# <RUNS_DIR>/<baseline>/<env>/seed_<seed>/final_model.zip, with spaces in the
# baseline and env names replaced by underscores.
RUNS_DIR = "/content/drive/MyDrive/runs"

# Number of actions available to each agent (size of each discrete action dimension).
NUM_ACTIONS = len(Action.ALL_ACTIONS)

# Results file: one row is appended per finished (baseline, env, seed) job,
# and the same file is read back to skip finished jobs on resume.
BR_RESULTS_CSV = "/content/drive/MyDrive/br_nash_results.csv"
# Default number of evaluation episodes per policy.
EVAL_EPISODES = 20
BR_TRAIN_STEPS = 200_000  # Steps to train best response

# The job grid is the Cartesian product BASELINES x ENV_NAMES x SEEDS.
BASELINES = ["Baseline", "PPO+LLM", "CC_PPO", "SP_PPO", "HARL", "PBT_PPO"]
# Environment variants; lower-cased names are the keys understood by make_env.
ENV_NAMES = ["No Noise", "Noise", "Delay", "Combo"]
# Seeds identifying the trained runs (seed_<seed> directory names).
SEEDS = [1001, 2002, 3003, 4004, 5005]

print(f"Device: {DEVICE}")
print(f"Results will be saved to: {BR_RESULTS_CSV}")

## MLAM Warmup (Fixes Pickle Crash)

In [None]:
def warmup_mlam():
    """
    Build the Overcooked environment once and featurize its initial state.

    Intended to run on the main process before any parallel workers start,
    so that the MediumLevelActionManager pickle already exists when they do
    (per the original note, featurizing the state forces MLAM creation).
    """
    print("Prewarming MLAM planner...")
    warm_env = OvercookedEnv.from_mdp(
        OvercookedGridworld.from_layout_name(LAYOUT),
        horizon=HORIZON,
    )
    # The featurization result itself is not needed; only the call matters.
    warm_env.featurize_state_mdp(warm_env.state)
    print("‚úÖ MLAM prewarm complete.")


# Run the warmup immediately, before anything else touches the planner.
warmup_mlam()

## Environment Wrappers

In [None]:
class OCWrapper(gym.Env):
    """
    Gymnasium wrapper around a single OvercookedEnv controlled by one joint action.

    The action is a pair of indices into ``Action.ALL_ACTIONS`` (one per
    agent). The observation is the flattened float32 version of the first
    element returned by ``featurize_state_mdp``; the second element is
    discarded.
    """

    def __init__(self, layout):
        super().__init__()
        self.oc = OvercookedEnv.from_mdp(
            OvercookedGridworld.from_layout_name(layout),
            horizon=HORIZON,
        )
        # Probe the initial state once to learn the flattened observation shape.
        self.obs_shape = self._obs().shape
        self.observation_space = gym.spaces.Box(
            -np.inf, np.inf, self.obs_shape, dtype=np.float32
        )
        self.action_space = gym.spaces.MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS])

    def _obs(self):
        """Return the current state's first featurization, flattened to float32."""
        first_feats, _ = self.oc.featurize_state_mdp(self.oc.state)
        return first_feats.flatten().astype(np.float32)

    def reset(self, seed=None, options=None):
        """Reset the underlying env; when a seed is given, reseed the global RNGs first."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        self.oc.reset()
        return self._obs(), {}

    def step(self, action):
        """
        Apply a joint action given as two action indices.

        Returns the gymnasium 5-tuple; the env's ``done`` flag is reported as
        ``terminated`` and ``truncated`` is always False.
        """
        joint_action = [Action.ALL_ACTIONS[int(idx)] for idx in (action[0], action[1])]
        _, reward, done, info = self.oc.step(joint_action)
        return self._obs(), float(reward), bool(done), False, info


class OCWrapperNoise(OCWrapper):
    """OCWrapper variant that adds Gaussian noise (std 0.01) to every step observation."""

    def step(self, action):
        """Step the base wrapper and return the observation with additive noise."""
        obs, reward, terminated, truncated, info = super().step(action)
        noise = np.random.normal(0, 0.01, obs.shape)
        # np.random.normal yields float64, which would silently promote the
        # observation; cast back so it matches the float32 observation_space.
        noisy_obs = (obs + noise).astype(np.float32)
        return noisy_obs, reward, terminated, truncated, info


class OCWrapperDelay(OCWrapper):
    """
    OCWrapper variant that randomly penalises the step reward.

    On each step, with probability ``noise_prob``, ``delay_penalty`` is
    subtracted from the reward. Observations are passed through unchanged.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob, self.delay_penalty = noise_prob, delay_penalty

    def step(self, action):
        """Step the base wrapper, then possibly subtract the penalty from the reward."""
        obs, reward, terminated, truncated, info = super().step(action)
        penalised = np.random.rand() < self.noise_prob
        if penalised:
            reward = reward - self.delay_penalty
        return obs, reward, terminated, truncated, info


class OCWrapperCombo(OCWrapper):
    """
    OCWrapper variant combining observation noise and a random reward penalty.

    Each step adds Gaussian noise (std 0.01) to the observation and, with
    probability ``noise_prob``, subtracts ``delay_penalty`` from the reward.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob = noise_prob
        self.delay_penalty = delay_penalty

    def step(self, action):
        """Step the base wrapper, then perturb the observation and the reward."""
        obs, reward, terminated, truncated, info = super().step(action)
        # Draw the observation noise first, then the penalty coin flip, so the
        # order of RNG consumption is unchanged.
        noise = np.random.normal(0, 0.01, obs.shape)
        # float32 + float64 noise promotes to float64; cast back so the
        # observation matches the declared float32 observation_space.
        obs = (obs + noise).astype(np.float32)
        if np.random.rand() < self.noise_prob:
            reward -= self.delay_penalty
        return obs, reward, terminated, truncated, info


def make_env(env_name, layout):
    """
    Build the Monitor-wrapped environment variant named by ``env_name``.

    The lookup is case-insensitive; recognised names are "no noise",
    "noise", "delay" and "combo". Any other name raises KeyError.
    """
    wrappers_by_name = {
        "no noise": OCWrapper,
        "noise": OCWrapperNoise,
        "delay": OCWrapperDelay,
        "combo": OCWrapperCombo,
    }
    wrapper_cls = wrappers_by_name[env_name.lower()]
    return Monitor(wrapper_cls(layout))

print("Environment wrappers defined!")

## Evaluation Functions

In [None]:
def set_global_seed(seed):
    """Seed Python's, NumPy's and torch's RNGs (plus CUDA if available) and SB3's helper."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Stable-Baselines3's own seeding utility is applied last.
    set_random_seed(seed)


def eval_joint_policy(model, env, episodes=EVAL_EPISODES):
    """
    Roll out ``model`` deterministically in ``env`` and summarise the returns.

    Args:
        model: object exposing ``predict(obs, deterministic=True)``.
        env: gymnasium-style env whose ``step`` returns a 5-tuple.
        episodes: number of episodes to run.

    Returns:
        ``(mean, std)`` of the undiscounted episode returns, as floats.
    """
    episode_returns = []
    for _ in range(episodes):
        obs, _info = env.reset()
        total_reward = 0
        while True:
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _info = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        episode_returns.append(total_reward)
    mean_return = float(np.mean(episode_returns))
    std_return = float(np.std(episode_returns))
    return mean_return, std_return

print("Evaluation functions defined!")

## Best Response Environment

In [None]:
class OvercookedBREnv(gym.Env):
    """
    Single-agent view of the two-player env, used to train a best response.

    The learner controls the agent at ``agent_idx``. The other agent's action
    is taken from the ``opponent_idx`` slot of the joint action that the fixed
    ``opponent_model`` predicts (deterministically) from the most recent
    observation.
    """

    def __init__(self, env_name, layout, opponent_model, agent_idx):
        super().__init__()
        self.agent_idx = agent_idx
        self.opponent_idx = 1 - agent_idx
        self.opponent_model = opponent_model

        # make_env returns a Monitor; use the wrapped env underneath it.
        monitored = make_env(env_name, layout)
        self.base_env = monitored.env
        self.observation_space = self.base_env.observation_space
        # The learner picks one agent's action only.
        self.action_space = gym.spaces.Discrete(NUM_ACTIONS)
        # Observation the opponent model is queried with on the next step.
        self._last_obs = None

    def reset(self, seed=None, options=None):
        """Reset the base env (seeding globally first when a seed is given)."""
        if seed is not None:
            set_global_seed(seed)
        first_obs, info = self.base_env.reset(seed=seed)
        self._last_obs = first_obs
        return first_obs, info

    def step(self, action):
        """Combine the learner's action with the opponent's and step the base env."""
        own_action = int(action)
        predicted_joint, _ = self.opponent_model.predict(
            self._last_obs, deterministic=True
        )
        opponent_action = int(predicted_joint[self.opponent_idx])

        # Place each action in its agent's slot of the joint action.
        joint_action = np.zeros(2, dtype=np.int64)
        joint_action[self.agent_idx] = own_action
        joint_action[self.opponent_idx] = opponent_action

        next_obs, reward, terminated, truncated, info = self.base_env.step(joint_action)
        self._last_obs = next_obs
        return next_obs, float(reward), bool(terminated), bool(truncated), info

print("BR environment defined!")

## Best Response Training

In [None]:
def train_best_response(env_name, layout, opponent_model, agent_idx, seed):
    """
    Train a PPO best response against a fixed opponent and evaluate it.

    Args:
        env_name: environment variant name understood by make_env.
        layout: Overcooked layout name.
        opponent_model: fixed joint policy supplying the other agent's action.
        agent_idx: index (0 or 1) of the agent the best response controls.
        seed: seed applied globally and passed to PPO.

    Returns:
        ``(mean, std)`` of the best response's episode returns over
        EVAL_EPISODES deterministic evaluation episodes.
    """
    set_global_seed(seed)
    br_env = Monitor(OvercookedBREnv(env_name, layout, opponent_model, agent_idx))

    br = PPO(
        "MlpPolicy",
        br_env,
        n_steps=2048,
        batch_size=2048,
        learning_rate=3e-4,
        gamma=0.99,
        verbose=0,
        device=DEVICE,
        seed=seed,
    )
    br.learn(total_timesteps=BR_TRAIN_STEPS)

    # Evaluate BR with the shared rollout helper instead of a duplicated loop;
    # it performs the same deterministic rollouts over EVAL_EPISODES episodes.
    return eval_joint_policy(br, br_env)

print("BR training function defined!")

## Nash Gap Computation

In [None]:
def compute_nash_gap_for_model(baseline, env_name, seed):
    """
    Compute the best-response gap for one trained joint policy.

    The model is loaded from
    ``<RUNS_DIR>/<baseline>/<env_name>/seed_<seed>/final_model.zip`` (spaces
    in the names replaced by underscores). Its self-play return is compared
    with the return of a best response trained for agent 0 against it.

    Args:
        baseline: baseline name (a BASELINES entry).
        env_name: environment variant name (an ENV_NAMES entry).
        seed: seed identifying the trained run.

    Returns:
        Dict with keys baseline, env, seed, V_self_mean, V_self_std,
        V_BR_mean, V_BR_std and delta (V_BR_mean - V_self_mean), or None
        when the model file does not exist.
    """
    safe_base = baseline.replace(" ", "_")
    safe_env = env_name.replace(" ", "_")
    model_path = f"{RUNS_DIR}/{safe_base}/{safe_env}/seed_{seed}/final_model.zip"

    if not os.path.exists(model_path):
        print(f"‚ùå Missing: {model_path}")
        return None

    model = PPO.load(model_path, device=DEVICE)

    # Self-play value. Seed first: the noisy env variants draw from the global
    # NumPy RNG, so an unseeded evaluation is not reproducible across runs,
    # whereas the best-response side below is already seeded.
    set_global_seed(seed)
    env_self = make_env(env_name, LAYOUT)
    v_self_m, v_self_s = eval_joint_policy(model, env_self)

    # Best response value (agent 0 deviates; a distinct seed is used for BR training).
    v_br_m, v_br_s = train_best_response(env_name, LAYOUT, model, 0, seed + 999)

    delta = v_br_m - v_self_m
    print(f"‚úÖ {baseline}|{env_name}|{seed}: V_self={v_self_m:.2f}, V_BR={v_br_m:.2f}, Œî={delta:.2f}")

    return {
        "baseline": baseline,
        "env": env_name,
        "seed": seed,
        "V_self_mean": v_self_m,
        "V_self_std": v_self_s,
        "V_BR_mean": v_br_m,
        "V_BR_std": v_br_s,
        "delta": delta,
    }

print("Nash gap function defined!")

## Resume Logic

In [None]:
def load_completed_set(path=None):
    """
    Load the keys of already-computed results so they can be skipped.

    Args:
        path: CSV file to read. Defaults to BR_RESULTS_CSV when omitted.

    Returns:
        Set of ``(baseline, env, seed)`` tuples taken from the first three
        columns of every data row. All values are strings, exactly as stored
        in the file. A missing or empty file yields an empty set.
    """
    csv_path = BR_RESULTS_CSV if path is None else path
    completed_keys = set()
    if not os.path.exists(csv_path):
        return completed_keys

    # newline="" lets the csv module handle line endings itself.
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        # Skip the header; the default keeps an empty file from raising StopIteration.
        next(reader, None)
        for row in reader:
            # Blank or truncated lines (e.g. from an interrupted write) carry
            # no usable key, so ignore them instead of raising IndexError.
            if len(row) >= 3:
                completed_keys.add((row[0], row[1], row[2]))
    return completed_keys

print("Resume logic defined!")

## 🚀 Run Nash Gap Analysis

In [None]:
# Initialize CSV
# The header is written only when the results file does not exist yet, so rows
# left by an earlier session are kept and can be used to resume.
if not os.path.exists(BR_RESULTS_CSV):
    with open(BR_RESULTS_CSV, "w", newline="") as f:
        writer = csv.writer(f)
        # Column order matches the rows appended later by wrapper().
        writer.writerow([
            "baseline", "env", "seed",
            "V_self_mean", "V_self_std",
            "V_BR_mean", "V_BR_std", "delta"
        ])
    print(f"Created {BR_RESULTS_CSV}")

# Keys (first three CSV columns, as strings) of the runs already recorded.
completed = load_completed_set()
print(f"Already completed: {len(completed)} runs")

In [None]:
# Build the list of (baseline, env, seed) jobs still to run.
# Keys loaded from the CSV are strings, so the seed is converted with str()
# before checking whether a combination has already been completed.
jobs = [
    (baseline_name, env_label, run_seed)
    for baseline_name in BASELINES
    for env_label in ENV_NAMES
    for run_seed in SEEDS
    if (baseline_name, env_label, str(run_seed)) not in completed
]

print(f"Remaining jobs: {len(jobs)}")

In [None]:
def wrapper(job):
    """
    Run one (baseline, env, seed) job and append its result to the CSV.

    Returns the result dict from compute_nash_gap_for_model, or None when
    that function returned None (in which case nothing is written).

    NOTE(review): the row is appended without any file locking; confirm this
    is acceptable when several workers write to the same file.
    """
    result = compute_nash_gap_for_model(*job)
    if result is not None:
        # Identification columns are written as-is; metrics are rounded to 2 decimals.
        metric_keys = ("V_self_mean", "V_self_std", "V_BR_mean", "V_BR_std", "delta")
        row = [result["baseline"], result["env"], result["seed"]]
        row.extend(round(result[key], 2) for key in metric_keys)
        with open(BR_RESULTS_CSV, "a", newline="") as f:
            csv.writer(f).writerow(row)
    return result

# Run parallel analysis
# Each remaining job is dispatched to wrapper() through joblib with 6 workers.
# wrapper() appends its own CSV row as soon as a job finishes, so the list
# returned by Parallel is not needed and is discarded here.
print("üöÄ Running Nash gap analysis...")
Parallel(n_jobs=6)(delayed(wrapper)(job) for job in jobs)

print(f"\nüéâ DONE! Results saved to: {BR_RESULTS_CSV}")

## 📈 View Results

In [None]:
import pandas as pd

# Load every row recorded so far (one per finished baseline/env/seed job).
df = pd.read_csv(BR_RESULTS_CSV)
print(f"Total results: {len(df)}")

# Aggregate by baseline and env
# For each (baseline, env) group: mean of V_self_mean and V_BR_mean, and
# mean and std of delta over the group's rows, rounded to 2 decimals.
agg = df.groupby(["baseline", "env"]).agg({
    "V_self_mean": "mean",
    "V_BR_mean": "mean",
    "delta": ["mean", "std"]
}).round(2)

print("\nüìä Nash Gap Summary:")
# `display` is not imported in this file; presumably the notebook-provided
# IPython display helper — confirm when running outside a notebook.
display(agg)