# ⏱️ Latency Analysis

**Measure per-step inference latency** for each trained model (one policy decision plus one environment transition).

## What this measures:
- Average time per decision step (`model.predict` + `env.step`)
- Reported in milliseconds
- Indicates whether a policy is fast enough for real-time deployment

⚠️ **Requires trained models from `02_training.ipynb`**

In [None]:
import os
import time
import csv
import random
import numpy as np
import gymnasium as gym
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor

from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, OvercookedGridworld
from overcooked_ai_py.mdp.actions import Action

# Confirmation marker: only reached if every import above succeeded.
print("Imports loaded!")

## Configuration

In [None]:
# ==========================================
# CONFIGURATION
# ==========================================
# Input/output locations under /content/drive/MyDrive (presumably a
# Colab-mounted Google Drive): trained runs are read from RUNS_DIR and the
# latency table is written to LATENCY_CSV.
RUNS_DIR = "/content/drive/MyDrive/runs"
LATENCY_CSV = "/content/drive/MyDrive/latency_results.csv"

# Overcooked layout used for every evaluation env.
LAYOUT = "asymmetric_advantages"

# Run models on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Episode horizon passed to OvercookedEnv, and the per-player action count.
HORIZON = 400
NUM_ACTIONS = len(Action.ALL_ACTIONS)

# Sweep grid: every (baseline, env, seed) combination is measured.
BASELINES = ["Baseline", "PPO+LLM", "CC_PPO", "SP_PPO", "HARL", "PBT_PPO"]
ENV_NAMES = ["No Noise", "Noise", "Delay", "Combo"]
SEEDS = [1001, 2002, 3003, 4004, 5005]

print(f"Device: {DEVICE}")
print(f"Results will be saved to: {LATENCY_CSV}")

## Environment Wrappers

In [None]:
class OCWrapper(gym.Env):
    """True 2-agent Overcooked wrapper.

    A single Gymnasium env drives both Overcooked players. The action is a
    pair of indices into ``Action.ALL_ACTIONS`` (one per player). The
    observation is the first element returned by ``featurize_state_mdp``
    (presumably player 0's features), flattened and cast to float32.
    """

    # Gymnasium looks up the ``render_modes`` key (``render.modes`` is the
    # legacy gym spelling). Rendering is not supported.
    metadata = {"render_modes": []}

    def __init__(self, layout):
        super().__init__()
        mdp = OvercookedGridworld.from_layout_name(layout)
        self.oc = OvercookedEnv.from_mdp(mdp, horizon=HORIZON)
        # Size the observation space from the initial state's features.
        o0, _ = self.oc.featurize_state_mdp(self.oc.state)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=o0.flatten().shape, dtype=np.float32
        )
        self.action_space = gym.spaces.MultiDiscrete([NUM_ACTIONS, NUM_ACTIONS])

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(obs, info)``.

        ``super().reset(seed=seed)`` seeds ``self.np_random`` as the
        Gymnasium env contract requires. When a seed is given, the global
        ``random``/NumPy/torch RNGs are seeded as well. The noise/penalty
        wrapper variants in this notebook draw from the global ``np.random``.
        ``options`` is accepted for API compatibility and ignored.
        """
        super().reset(seed=seed)
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
        self.oc.reset()
        o0, _ = self.oc.featurize_state_mdp(self.oc.state)
        return o0.flatten().astype(np.float32), {}

    def step(self, action):
        """Apply the joint action and return the Gymnasium 5-tuple.

        ``action`` holds two action indices, one per player. Overcooked's
        ``done`` flag is reported as ``terminated``; ``truncated`` is always
        False.
        """
        a0, a1 = int(action[0]), int(action[1])
        joint = [Action.ALL_ACTIONS[a0], Action.ALL_ACTIONS[a1]]
        state, r, done, info = self.oc.step(joint)
        o0, _ = self.oc.featurize_state_mdp(state)
        return o0.flatten().astype(np.float32), float(r), bool(done), False, info


class OCWrapperNoise(OCWrapper):
    """OCWrapper variant that perturbs observations returned by ``step``.

    Each observation gets element-wise Gaussian noise (mean 0, std 0.01)
    drawn from NumPy's global RNG, and is cast back to float32. Rewards and
    done flags pass through unchanged. ``reset`` is inherited, so the
    initial observation is not perturbed.
    """

    def step(self, action):
        observation, reward, terminated, truncated, info = super().step(action)
        jitter = np.random.normal(0, 0.01, size=observation.shape)
        noisy_observation = (observation + jitter).astype(np.float32)
        return noisy_observation, reward, terminated, truncated, info


class OCWrapperDelay(OCWrapper):
    """OCWrapper variant that randomly penalizes the step reward.

    No action or observation is actually delayed. On each ``step``, with
    probability ``noise_prob``, the reward is reduced by ``delay_penalty``
    (presumably modelling the cost of a delay; TODO confirm intent). The
    draw uses NumPy's global RNG.

    Args:
        layout: Overcooked layout name forwarded to ``OCWrapper``.
        noise_prob: per-step probability of applying the penalty.
        delay_penalty: amount subtracted from the reward when applied.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        super().__init__(layout)
        self.noise_prob = noise_prob
        self.delay_penalty = delay_penalty

    def step(self, action):
        observation, reward, terminated, truncated, info = super().step(action)
        apply_penalty = np.random.rand() < self.noise_prob
        if apply_penalty:
            reward -= self.delay_penalty
        return observation, reward, terminated, truncated, info


class OCWrapperCombo(OCWrapperDelay, OCWrapperNoise):
    """Combined perturbation: observation noise plus random reward penalty.

    Instead of duplicating both perturbations, this class composes the two
    wrappers through the MRO (Combo -> Delay -> Noise -> OCWrapper).
    ``OCWrapperDelay.step`` calls ``super().step``, which here resolves to
    ``OCWrapperNoise.step``, which in turn calls ``OCWrapper.step``.

    Per step, NumPy's global RNG is used in this order:
      1. Gaussian observation noise (std 0.01) is added;
      2. with probability ``noise_prob`` the reward is reduced by
         ``delay_penalty``.

    Args:
        layout: Overcooked layout name forwarded to ``OCWrapper``.
        noise_prob: per-step probability of applying the reward penalty.
        delay_penalty: amount subtracted from the reward when applied.
    """

    def __init__(self, layout, noise_prob=0.2, delay_penalty=0.5):
        # Resolves to OCWrapperDelay.__init__, which stores the penalty
        # settings. Its super().__init__ passes through OCWrapperNoise (which
        # defines no __init__) to OCWrapper.__init__.
        super().__init__(layout, noise_prob=noise_prob, delay_penalty=delay_penalty)


def make_env(env_name: str, layout: str):
    """Evaluation env factory.

    Args:
        env_name: one of "No Noise", "Noise", "Delay", "Combo". Matching is
            case-insensitive, ignores surrounding whitespace, and treats
            underscores as spaces. Filesystem-safe spellings such as
            "No_Noise" are therefore accepted as well.
        layout: Overcooked layout name passed to the wrapper.

    Returns:
        The selected wrapper instance, wrapped in ``Monitor``.

    Raises:
        KeyError: if ``env_name`` does not name a known setting. The message
            lists the valid choices.
    """
    mapping = {
        "no noise": OCWrapper,
        "noise": OCWrapperNoise,
        "delay": OCWrapperDelay,
        "combo": OCWrapperCombo,
    }
    key = env_name.strip().lower().replace("_", " ")
    try:
        wrapper_cls = mapping[key]
    except KeyError:
        raise KeyError(
            f"Unknown env_name {env_name!r}; expected one of {sorted(mapping)}"
        ) from None
    return Monitor(wrapper_cls(layout))

print("Environment wrappers defined!")

## Latency Measurement Functions

In [None]:
def measure_step_latency(agent, env, episodes=5, *, include_env_step=True):
    """
    Measures average latency per decision step, in milliseconds.

    By default each timed step covers ``agent.predict`` plus ``env.step``
    (decision + transition). With ``include_env_step=False`` only
    ``agent.predict`` is timed, i.e. pure policy-inference latency. The env
    is still stepped, so episodes progress normally.

    Args:
        agent: object with ``predict(obs, deterministic=True)`` returning
            ``(action, state)``, e.g. a Stable-Baselines3 model.
        env: env with ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
        episodes: number of complete episodes to run.
        include_env_step: whether ``env.step`` is inside the timed region.

    Returns:
        Mean milliseconds per step, or NaN if no step was executed
        (e.g. ``episodes <= 0``).
    """
    total_time = 0.0
    total_steps = 0

    for _ in range(episodes):
        obs, _ = env.reset()
        done = False

        while not done:
            t0 = time.perf_counter()
            action, _ = agent.predict(obs, deterministic=True)
            if include_env_step:
                # Decision + transition are both inside the timed region.
                obs, _, term, trunc, _ = env.step(action)
                total_time += time.perf_counter() - t0
            else:
                # Stop the clock before stepping: inference only.
                total_time += time.perf_counter() - t0
                obs, _, term, trunc, _ = env.step(action)

            total_steps += 1
            done = term or trunc

    if total_steps == 0:
        return float("nan")

    # seconds → milliseconds
    return (total_time / total_steps) * 1000.0

print("Measurement function defined!")

In [None]:
def load_and_measure_latency(baseline, env_name, seed, episodes=5, *, warmup_steps=10):
    """
    Loads a trained PPO model and measures its per-step latency.

    The model is read from
    ``{RUNS_DIR}/{baseline}/{env_name}/seed_{seed}/final_model.zip``, with
    spaces in ``baseline`` and ``env_name`` replaced by underscores.

    Args:
        baseline: baseline name (an entry of BASELINES in the sweep).
        env_name: evaluation env name, passed to ``make_env``.
        seed: training seed whose model should be loaded.
        episodes: number of timed episodes passed to ``measure_step_latency``.
        warmup_steps: untimed predict+step calls run first to stabilize
            GPU / cache effects.

    Returns:
        Mean ms/step (NaN if no step was timed), or None if the model file
        does not exist.
    """
    safe_base = baseline.replace(" ", "_")
    safe_env = env_name.replace(" ", "_")
    model_path = f"{RUNS_DIR}/{safe_base}/{safe_env}/seed_{seed}/final_model.zip"

    if not os.path.exists(model_path):
        print(f"[WARN] Missing model: {model_path}")
        return None

    env = make_env(env_name, LAYOUT)
    try:
        agent = PPO.load(model_path, env=env, device=DEVICE)

        # Warmup (stabilize GPU / cache effects): consecutive untimed steps
        # within one episode, resetting only when that episode ends.
        obs, _ = env.reset()
        for _ in range(warmup_steps):
            action, _ = agent.predict(obs, deterministic=True)
            obs, _, term, trunc, _ = env.step(action)
            if term or trunc:
                obs, _ = env.reset()

        latency_ms = measure_step_latency(agent, env, episodes=episodes)
    finally:
        # Release the env even if loading, warmup or measurement raises.
        env.close()

    print(f"{baseline} | {env_name} | seed={seed}: {latency_ms:.4f} ms/step")

    return latency_ms

print("Load and measure function defined!")

## 🚀 Run Latency Sweep

In [None]:
def run_latency_sweep(*, baselines=None, env_names=None, seeds=None, episodes=5,
                      out_csv=None):
    """
    For every (baseline, env, seed), measure latency and save results.

    Writes a CSV with columns ``baseline, env, seed_or_stat, latency_ms``.
    Each valid per-seed measurement gets a row keyed by its integer seed.
    Every (baseline, env) pair with at least one valid measurement also gets
    ``mean_over_seeds`` and ``std_over_seeds`` rows. The std is ``np.std``
    with its default ``ddof=0``, i.e. the population std.

    Args:
        baselines: baselines to sweep (default: BASELINES).
        env_names: evaluation envs to sweep (default: ENV_NAMES).
        seeds: training seeds to sweep (default: SEEDS).
        episodes: timed episodes per model.
        out_csv: output CSV path (default: LATENCY_CSV).
    """
    baselines = BASELINES if baselines is None else baselines
    env_names = ENV_NAMES if env_names is None else env_names
    seeds = SEEDS if seeds is None else seeds
    out_csv = LATENCY_CSV if out_csv is None else out_csv

    rows = []

    print("=== Measuring per-step latency for all models ===")
    for b in baselines:
        for e in env_names:
            seed_latencies = []
            for s in seeds:
                lat = load_and_measure_latency(b, e, s, episodes=episodes)
                # Missing models (None) and NaN measurements are skipped.
                if lat is not None and not np.isnan(lat):
                    rows.append([b, e, s, lat])
                    seed_latencies.append(lat)

            if seed_latencies:
                mean_lat = float(np.mean(seed_latencies))
                std_lat = float(np.std(seed_latencies))
                print(f"[AGG] {b} | {e}: {mean_lat:.4f} ± {std_lat:.4f} ms/step")
                rows.append([b, e, "mean_over_seeds", mean_lat])
                rows.append([b, e, "std_over_seeds", std_lat])

    # os.makedirs("") raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["baseline", "env", "seed_or_stat", "latency_ms"])
        writer.writerows(rows)

    print(f"\n🎉 Latency sweep complete. Saved to: {out_csv}")

# Run the sweep
run_latency_sweep()

## 📈 View Results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv(LATENCY_CSV)

# Select the aggregate rows that run_latency_sweep writes, by their label in
# the seed_or_stat column.
df_mean = df[df["seed_or_stat"] == "mean_over_seeds"].copy()
df_std = df[df["seed_or_stat"] == "std_over_seeds"].copy()

print("Mean Latency by Baseline and Environment (ms):")
pivot = df_mean.pivot(index="baseline", columns="env", values="latency_ms")
display(pivot.round(4))

# Spread across seeds for the same (baseline, env) cells.
print("Std of Latency over Seeds by Baseline and Environment (ms):")
pivot_std = df_std.pivot(index="baseline", columns="env", values="latency_ms")
display(pivot_std.round(4))

In [None]:
# Visualization
fig, ax = plt.subplots(figsize=(12, 6))

# Aggregate across environments
latency_by_baseline = df_mean.groupby("baseline")["latency_ms"].mean().sort_values()

colors = plt.cm.viridis(np.linspace(0, 0.8, len(latency_by_baseline)))
bars = ax.barh(latency_by_baseline.index, latency_by_baseline.values, color=colors)

ax.set_xlabel("Latency (ms/step)")
ax.set_title("Mean Inference Latency by Baseline")

# Add value labels
for bar, val in zip(bars, latency_by_baseline.values):
    ax.text(val + 0.01, bar.get_y() + bar.get_height()/2, 
            f"{val:.3f}", va="center", fontsize=10)

plt.tight_layout()
plt.show()