# Nash Analysis

Best-response Nash gap analysis with MLAM warmup.

In [None]:
import os, time, random, csv
import numpy as np
import torch
import gymnasium as gym
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, OvercookedGridworld
from overcooked_ai_py.mdp.actions import Action

In [None]:
# --- Experiment configuration -------------------------------------------
# Overcooked layout name handed to OvercookedGridworld.from_layout_name.
LAYOUT = "asymmetric_advantages"
# Episode horizon passed to OvercookedEnv.from_mdp.
HORIZON = 400
# Torch device string used when loading checkpoints and training PPO.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Run seeds to analyse; each one selects a `seed_<n>` checkpoint directory.
SEEDS = [1001, 2002, 3003, 4004, 5005]

# Checkpoints are looked up at
# <RUNS_DIR>/<baseline>/<env>/seed_<n>/final_model.zip.
RUNS_DIR = "/content/drive/MyDrive/runs"
# Destination CSV for the per-run Nash-gap results.
BR_RESULTS_CSV = "/content/drive/MyDrive/br_nash_results.csv"

# Default number of PPO timesteps for each best-response training run.
BR_TRAIN_STEPS = 200_000
# Number of episodes averaged for each value estimate.
BR_EVAL_EPISODES = 20
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether the MLAM warmup is applied elsewhere.
MLAM_WARMUP_STEPS = 50_000

# Baseline and environment-variant names; spaces are replaced by underscores
# when they are turned into checkpoint directory names.
BASELINES = ["Baseline", "PPO+LLM", "CC_PPO", "SP_PPO", "HARL", "PBT_PPO"]
ENV_NAMES = ["No Noise", "Noise", "Delay", "Combo"]
# Size of each per-player discrete action space.
NUM_ACTIONS = len(Action.ALL_ACTIONS)

In [None]:
class SingleAgentBREnv(gym.Env):
    """Single-agent view of Overcooked for training a best response.

    The learner controls one player through a ``Discrete(NUM_ACTIONS)``
    action space; the other player's action is taken from ``fixed_policy``,
    which is queried deterministically on every step.

    Observations are the first array returned by
    ``OvercookedEnv.featurize_state_mdp``, flattened and cast to float32
    (the same observation for either value of ``agent_idx``).

    Args:
        layout: Layout name passed to ``OvercookedGridworld.from_layout_name``.
        fixed_policy: Object exposing ``predict(obs, deterministic=True)``
            whose returned action can be indexed by player index.
        agent_idx: Index (0 or 1) of the player controlled by the learner.
    """
    # Gymnasium reads the ``render_modes`` key; the legacy ``render.modes``
    # spelling is ignored by it.
    metadata = {"render_modes": []}

    def __init__(self, layout, fixed_policy, agent_idx):
        super().__init__()
        mdp = OvercookedGridworld.from_layout_name(layout)
        self.oc = OvercookedEnv.from_mdp(mdp, horizon=HORIZON)
        self.fixed = fixed_policy
        self.idx = agent_idx

        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=self._get_obs().shape, dtype=np.float32
        )
        self.action_space = gym.spaces.Discrete(NUM_ACTIONS)

    def reset(self, seed=None, options=None):
        """Reset the underlying game and return ``(observation, info)``."""
        # Let gymnasium seed ``self.np_random`` as its reset contract expects.
        super().reset(seed=seed)
        self.oc.reset()
        return self._get_obs(), {}

    def step(self, action):
        """Advance one step with the learner's action and the partner's reply.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            always ``False``; the underlying ``done`` flag is reported as
            ``terminated``.
        """
        obs = self._get_obs()
        partner_action, _ = self.fixed.predict(
            self._partner_obs(obs), deterministic=True
        )
        # The fixed policy emits one action per player; keep the partner's.
        partner_a = int(partner_action[1 - self.idx])

        learner_move = Action.ALL_ACTIONS[int(action)]
        partner_move = Action.ALL_ACTIONS[partner_a]
        if self.idx == 0:
            joint = [learner_move, partner_move]
        else:
            joint = [partner_move, learner_move]

        state, r, done, info = self.oc.step(joint)
        return self._featurize(state), float(r), bool(done), False, info

    def _partner_obs(self, obs):
        """Return ``obs`` in the shape the fixed policy expects.

        The policy is normally fed the plain observation, matching what
        ``OCWrapper`` produces. Only when the policy advertises an
        observation space exactly twice as long is the observation
        duplicated, which keeps policies trained on that doubled input
        working.
        """
        expected = getattr(
            getattr(self.fixed, "observation_space", None), "shape", None
        )
        if expected is not None and tuple(expected) == (2 * obs.shape[0],):
            return np.concatenate([obs, obs])
        return obs

    def _featurize(self, state):
        """Flatten the first featurization of ``state`` into float32."""
        o0, _ = self.oc.featurize_state_mdp(state)
        return o0.flatten().astype(np.float32)

    def _get_obs(self):
        """Return the observation for the environment's current state."""
        return self._featurize(self.oc.state)

In [None]:
def load_policy(baseline, env_name, seed):
    """Load the saved PPO checkpoint for one (baseline, env, seed) run.

    The checkpoint is expected at
    ``RUNS_DIR/<baseline>/<env_name>/seed_<seed>/final_model.zip`` with
    spaces in the two names replaced by underscores.

    Returns:
        The loaded PPO model, or ``None`` when the checkpoint path does not
        exist.
    """
    dir_names = [label.replace(" ", "_") for label in (baseline, env_name)]
    checkpoint = os.path.join(
        RUNS_DIR, *dir_names, f"seed_{seed}", "final_model.zip"
    )
    if not os.path.exists(checkpoint):
        return None
    return PPO.load(checkpoint, device=DEVICE)


def evaluate_joint(policy, env, episodes=20):
    """Return the mean undiscounted episode return of ``policy`` on ``env``.

    Args:
        policy: Object exposing ``predict(obs, deterministic=True)``.
        env: Environment with gymnasium-style ``reset``/``step``.
        episodes: Number of episodes to roll out and average.
    """

    def _rollout():
        # Play a single episode to completion and return its reward sum.
        observation, _ = env.reset()
        episode_return = 0.0
        while True:
            chosen, _ = policy.predict(observation, deterministic=True)
            observation, reward, terminated, truncated, _ = env.step(chosen)
            episode_return += reward
            if terminated or truncated:
                return episode_return

    returns = [_rollout() for _ in range(episodes)]
    return np.mean(returns)


def train_best_response(fixed_policy, agent_idx, steps=BR_TRAIN_STEPS, seed=None):
    """Train a PPO best response against ``fixed_policy``.

    Args:
        fixed_policy: Partner policy that stays frozen during training.
        agent_idx: Index (0 or 1) of the player the best response controls.
        steps: Total PPO timesteps to train for.
        seed: Optional seed forwarded to PPO so the training run can be
            reproduced. ``None`` (the default) leaves PPO unseeded, as
            before.

    Returns:
        The trained PPO model.
    """
    env = Monitor(SingleAgentBREnv(LAYOUT, fixed_policy, agent_idx))
    br = PPO("MlpPolicy", env, verbose=0, device=DEVICE, seed=seed)
    br.learn(total_timesteps=steps)
    return br

In [None]:
class OCWrapper(gym.Env):
    """Gymnasium wrapper that lets one policy drive both Overcooked players.

    The action space is ``MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS])`` (one
    action index per player). Observations are the first array returned by
    ``OvercookedEnv.featurize_state_mdp``, flattened and cast to float32.

    Args:
        layout: Layout name passed to ``OvercookedGridworld.from_layout_name``.
    """
    # Gymnasium reads the ``render_modes`` key; the legacy ``render.modes``
    # spelling is ignored by it.
    metadata = {"render_modes": []}

    def __init__(self, layout):
        super().__init__()
        mdp = OvercookedGridworld.from_layout_name(layout)
        self.oc = OvercookedEnv.from_mdp(mdp, horizon=HORIZON)
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=self._featurize(self.oc.state).shape,
            dtype=np.float32,
        )
        self.action_space = gym.spaces.MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS])

    def _featurize(self, state):
        """Flatten the first featurization of ``state`` into float32."""
        o0, _ = self.oc.featurize_state_mdp(state)
        return o0.flatten().astype(np.float32)

    def reset(self, seed=None, options=None):
        """Reset the underlying game and return ``(observation, info)``."""
        # Let gymnasium seed ``self.np_random`` as its reset contract expects.
        super().reset(seed=seed)
        self.oc.reset()
        return self._featurize(self.oc.state), {}

    def step(self, action):
        """Apply a joint action given as two action indices.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            always ``False``; the underlying ``done`` flag is reported as
            ``terminated``.
        """
        a0, a1 = int(action[0]), int(action[1])
        joint = [Action.ALL_ACTIONS[a0], Action.ALL_ACTIONS[a1]]
        state, r, done, info = self.oc.step(joint)
        return self._featurize(state), float(r), bool(done), False, info

In [None]:
def compute_nash_gap(baseline, env_name, seed):
    """Estimate the Nash gap of one saved joint policy.

    For each player a PPO best response is trained against the saved policy,
    and the gap is the sum over both players of
    ``(best-response return - joint return)``, each return averaged over
    ``BR_EVAL_EPISODES`` episodes.

    Args:
        baseline: Baseline name used to locate the checkpoint.
        env_name: Environment-variant name used to locate the checkpoint.
        seed: Seed of the saved run; also used to seed this analysis.

    Returns:
        The estimated Nash gap, or ``None`` when no checkpoint is found.
    """
    policy = load_policy(baseline, env_name, seed)
    if policy is None:
        return None

    # Seed Python, NumPy and torch so best-response training and the
    # evaluation rollouts can be reproduced for this run.
    set_random_seed(seed)

    v_joint = evaluate_joint(policy, OCWrapper(LAYOUT), BR_EVAL_EPISODES)

    nash_gap = 0.0
    for agent_idx in (0, 1):
        br = train_best_response(policy, agent_idx=agent_idx)
        # Evaluate on a fresh environment rather than the training one.
        br_env = SingleAgentBREnv(LAYOUT, policy, agent_idx)
        v_br = np.mean(
            [sum_episode(br, br_env) for _ in range(BR_EVAL_EPISODES)]
        )
        # Unilateral improvement this player gains by deviating.
        nash_gap += v_br - v_joint
    return nash_gap


def sum_episode(policy, env):
    """Play one episode with ``policy`` on ``env`` and return its reward sum.

    Actions are chosen deterministically; the episode ends as soon as the
    environment reports termination or truncation.
    """
    observation, _ = env.reset()
    episode_return = 0.0
    while True:
        chosen, _ = policy.predict(observation, deterministic=True)
        observation, reward, terminated, truncated, _ = env.step(chosen)
        episode_return += reward
        if terminated or truncated:
            return episode_return

In [None]:
# Sweep every (baseline, environment, seed) run and record its Nash gap.
results = []

for baseline in tqdm(BASELINES, desc="Baselines"):
    for env_name in ENV_NAMES:
        for seed in SEEDS:
            gap = compute_nash_gap(baseline, env_name, seed)
            if gap is None:
                # compute_nash_gap returns None when no checkpoint exists.
                continue
            results.append({
                "baseline": baseline,
                "env": env_name,
                "seed": seed,
                "nash_gap": gap,
            })

# Name the columns explicitly so the frame (and the CSV header) has the
# expected schema even when no checkpoint was found and `results` is empty.
df = pd.DataFrame(results, columns=["baseline", "env", "seed", "nash_gap"])
df.to_csv(BR_RESULTS_CSV, index=False)

In [None]:
# Mean and standard deviation of the Nash gap across seeds, per
# (baseline, environment) pair.
gap_by_run = df.groupby(["baseline", "env"])["nash_gap"]
agg = gap_by_run.agg(nash_gap_mean="mean", nash_gap_std="std").reset_index()

print(agg.round(2))

In [None]:
# Grouped bar chart: one group per baseline, one bar per environment.
pivot = agg.pivot(index="baseline", columns="env", values="nash_gap_mean")

fig, ax = plt.subplots(figsize=(10, 6))
pivot.plot.bar(ax=ax)
ax.set(ylabel="Nash Gap", title="Nash Gap by Baseline and Environment")
# Tilt the baseline names so they do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45)
fig.tight_layout()
plt.show()