# Latency Analysis

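Measure LLM inference latency per environment step: the wall-clock time of a single forward pass of GPT-Neo 1.3B on a fixed prompt, summarised as mean, standard deviation, and P50/P95/P99 percentiles.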

In [None]:
import os, time
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
from overcooked_ai_py.mdp.actions import Action

In [None]:
# Benchmark configuration.
if torch.cuda.is_available():
    DEVICE = "cuda"  # benchmark on the GPU whenever one is visible
else:
    DEVICE = "cpu"

WARMUP = 10  # calls run first whose timings are discarded
N_SAMPLES = 200  # timed calls that feed the reported statistics
LATENCY_CSV = "/content/drive/MyDrive/latency_results.csv"  # summary-statistics output

In [None]:
# Load the GPT-Neo 1.3B checkpoint once. Weights are kept in half precision
# on the GPU and in full precision on the CPU.
_checkpoint = "EleutherAI/gpt-neo-1.3B"
if DEVICE == "cuda":
    _dtype = torch.float16
else:
    _dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForCausalLM.from_pretrained(_checkpoint, torch_dtype=_dtype)
model = model.to(DEVICE)
model = model.eval()  # inference mode for modules such as dropout

In [None]:
@torch.no_grad()
def measure_latency(prompt, *, llm=None, tok=None, device=None):
    """Return the wall-clock time, in milliseconds, of one forward pass on *prompt*.

    Tokenization and moving the encoded inputs to *device* happen before the
    timer starts, so only the ``llm(**enc)`` call is timed. On CUDA devices
    the device is synchronised before and after the call. CUDA kernels are
    launched asynchronously, so without the final sync the timer could stop
    before the GPU has finished.

    Args:
        prompt: Text to encode and feed to the model.
        llm: Model to time. Defaults to the notebook-level ``model``.
        tok: Tokenizer matching ``llm``. Defaults to the notebook-level
            ``tokenizer``.
        device: Device the inputs are moved to (e.g. ``"cuda"``,
            ``"cuda:1"``, ``"cpu"``). Defaults to the notebook-level ``DEVICE``.

    Returns:
        Elapsed time of the forward pass in milliseconds, as a float.
    """
    llm = model if llm is None else llm
    tok = tokenizer if tok is None else tok
    device = DEVICE if device is None else device
    # Compare the device *type* so indexed devices such as "cuda:0" are still
    # synchronised. A plain `== "cuda"` check would silently skip them.
    is_cuda = torch.device(device).type == "cuda"

    enc = tok(prompt, return_tensors="pt").to(device)

    if is_cuda:
        torch.cuda.synchronize(device)  # drain pending work before starting the clock
    start = time.perf_counter()
    llm(**enc)
    if is_cuda:
        torch.cuda.synchronize(device)  # wait for the forward pass to actually finish
    return (time.perf_counter() - start) * 1000

In [None]:
# Fixed probe prompt reused for every call so that only timing noise varies.
test_prompt = "In cooperative cooking, are joint actions 'N' and 'E' helpful? Answer good or bad."

# Warm-up calls: timings are thrown away, presumably so one-off start-up costs
# (lazy initialisation, allocator growth, etc.) do not skew the statistics.
for _ in range(WARMUP):
    measure_latency(test_prompt)

# Timed calls, in milliseconds.
latencies = [measure_latency(test_prompt) for _ in tqdm(range(N_SAMPLES))]

# Summary statistics. np.std defaults to the population std (ddof=0).
mean_ms, std_ms = np.mean(latencies), np.std(latencies)
p50, p95, p99 = np.percentile(latencies, [50, 95, 99])
p99 = np.percentile(latencies, 99)

In [None]:
# Persist the run summary as a one-row CSV. to_csv's default write mode
# replaces any file already at LATENCY_CSV.
results = dict(
    device=DEVICE,
    n_samples=N_SAMPLES,
    mean_ms=mean_ms,
    std_ms=std_ms,
    p50_ms=p50,
    p95_ms=p95,
    p99_ms=p99,
)
summary_df = pd.DataFrame([results])
summary_df.to_csv(LATENCY_CSV, index=False)

# Human-readable report of the same numbers.
print(f"Mean: {mean_ms:.2f} ms")
print(f"Std:  {std_ms:.2f} ms")
print(f"P50:  {p50:.2f} ms")
print(f"P95:  {p95:.2f} ms")
print(f"P99:  {p99:.2f} ms")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
hist_ax, trace_ax = axes

# Left panel: histogram of per-call latency with dashed mean and P95 markers.
hist_ax.hist(latencies, bins=30, edgecolor="black", alpha=0.7)
for x_pos, colour, text in (
    (mean_ms, "red", f"Mean: {mean_ms:.1f}ms"),
    (p95, "orange", f"P95: {p95:.1f}ms"),
):
    hist_ax.axvline(x_pos, color=colour, linestyle="--", label=text)
hist_ax.set_xlabel("Latency (ms)")
hist_ax.set_ylabel("Count")
hist_ax.legend()

# Right panel: each timed sample in call order, with the mean as a reference line.
trace_ax.plot(latencies, marker=".", alpha=0.5, linewidth=0)
trace_ax.axhline(mean_ms, color="red", linestyle="--")
trace_ax.set_xlabel("Sample")
trace_ax.set_ylabel("Latency (ms)")

fig.tight_layout()
plt.show()