In [17]:
import json
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd

In [18]:
# Numeric fields read from each simulation record (one column each in the
# per-simulation table built by load_run) and summarised by the stats helpers.
METRICS = [
    # token fields
    "agent_total_tokens",
    "agent_completion_tokens",
    "agent_prompt_tokens",
    # count fields
    "agent_llm_calls",
    "num_user_turns",
    "num_tool_calls",
    # time and cost fields
    "agent_active_time",
    "duration",
    "agent_cost",
]

# Column order shared by every summary-statistics table below.
STAT_COLS = [
    "n",
    "min",
    "mean",
    "p50",
    "p90",
    "p95",
    "p99",
    "max",
    "std",
]


def _cat(task_id):
    m = re.match(r"\[([^\]]+)\]", task_id)
    return m.group(1) if m else "unknown"


def _n_actions(task):
    return len((task.get("evaluation_criteria") or {}).get("actions") or [])


def _success(sim):
    ri = sim.get("reward_info")
    return ri is not None and abs(ri.get("reward", 0) - 1.0) < 1e-6


def _reward(sim):
    ri = sim.get("reward_info")
    return ri.get("reward", 0) if ri else 0.0


def _stats_series(vals):
    if not vals:
        return {k: np.nan for k in STAT_COLS}
    a = np.array(vals, dtype=float)
    return {
        "n": len(a),
        "min": np.min(a),
        "mean": np.mean(a),
        "p50": np.median(a),
        "p90": np.percentile(a, 90),
        "p95": np.percentile(a, 95),
        "p99": np.percentile(a, 99),
        "max": np.max(a),
        "std": np.std(a, ddof=1) if len(a) > 1 else 0.0,
    }


def _flatten_rt(series):
    flat = []
    for rt in series.dropna():
        if isinstance(rt, list):
            flat.extend(rt)
    return flat


def _compute_stats_df(df, metrics=None):
    """Summary statistics for each metric column of ``df``, one row per metric.

    Parameters
    ----------
    df : pd.DataFrame
        Per-simulation table (as built by ``load_run``).
    metrics : list[str], optional
        Columns to summarise. Defaults to the METRICS entries present in ``df``.

    Returns
    -------
    pd.DataFrame
        Indexed by ``metric``, with columns STAT_COLS. If ``df`` has an
        ``agent_response_times`` column containing lists, an extra
        ``agent_response_time`` row summarises all of those values pooled
        together. When there is nothing to summarise, an empty frame with the
        same index name and columns is returned instead of raising.
    """
    if metrics is None:
        metrics = [c for c in METRICS if c in df.columns]
    rows = []
    for m in metrics:
        s = _stats_series(df[m].dropna().tolist())
        s["metric"] = m
        rows.append(s)
    if "agent_response_times" in df.columns:
        flat = _flatten_rt(df["agent_response_times"])
        if flat:
            s = _stats_series(flat)
            s["metric"] = "agent_response_time"
            rows.append(s)
    if not rows:
        # pd.DataFrame([]) has no "metric" column, so set_index would raise KeyError.
        return pd.DataFrame(index=pd.Index([], name="metric"), columns=STAT_COLS)
    return pd.DataFrame(rows).set_index("metric")[STAT_COLS]


def _grouped_stats_df(df, group_col, metrics=None):
    """Summary statistics per distinct value of ``group_col``.

    For every group this emits:
    - one row per metric, with the same statistics as ``_stats_series``;
    - an ``agent_response_time`` row pooling the group's response-time lists,
      when any exist;
    - a ``success_rate`` row when ``df`` has a ``success`` column. Its ``n`` is
      the group size, its ``mean`` is the fraction of successful simulations,
      and its other columns are NaN.

    Parameters
    ----------
    df : pd.DataFrame
        Per-simulation table (as built by ``load_run``).
    group_col : str
        Column to group on.
    metrics : list[str], optional
        Columns to summarise. Defaults to the METRICS entries present in ``df``.

    Returns
    -------
    pd.DataFrame
        Indexed by (``group``, ``metric``), with columns STAT_COLS. Rows whose
        ``group_col`` value is missing form their own NaN group instead of being
        silently dropped. An empty frame with the same layout is returned when
        there is nothing to summarise.
    """
    if metrics is None:
        metrics = [c for c in METRICS if c in df.columns]
    rows = []
    # dropna=False: groupby's default would silently discard simulations whose
    # group value is None/NaN (e.g. a missing termination_reason).
    for gval, gdf in df.groupby(group_col, dropna=False):
        for m in metrics:
            s = _stats_series(gdf[m].dropna().tolist())
            s["group"] = gval
            s["metric"] = m
            rows.append(s)
        if "agent_response_times" in gdf.columns:
            flat = _flatten_rt(gdf["agent_response_times"])
            if flat:
                s = _stats_series(flat)
                s["group"] = gval
                s["metric"] = "agent_response_time"
                rows.append(s)
        if "success" in gdf.columns:
            n = len(gdf)
            p = gdf["success"].sum()
            rows.append(
                {
                    "group": gval,
                    "metric": "success_rate",
                    "n": n,
                    "mean": p / n if n else np.nan,
                    **{k: np.nan for k in STAT_COLS if k not in ["n", "mean"]},
                }
            )
    if not rows:
        # pd.DataFrame([]) has no group/metric columns, so set_index would raise KeyError.
        empty_index = pd.MultiIndex.from_arrays([[], []], names=["group", "metric"])
        return pd.DataFrame(index=empty_index, columns=STAT_COLS)
    out = pd.DataFrame(rows).set_index(["group", "metric"])
    return out[STAT_COLS]


class RunAnalysis:
    """One simulation run: run-level metadata plus a per-simulation table.

    Attributes
    ----------
    info : dict
        Run-level metadata (the ``"info"`` section when built by ``load_run``).
    sims : pd.DataFrame
        One row per simulation. The ``by_*`` shortcuts group on its
        ``category``, ``success``, ``n_expected_actions`` and
        ``termination_reason`` columns.
    """

    def __init__(self, info: dict, sims: pd.DataFrame):
        self.info = info
        self.sims = sims

    def __repr__(self):
        # `or {}` also covers sections that are present but null in the JSON,
        # where the original `.get(..., {})` returned None and `.get` then failed.
        info = self.info or {}
        ai = info.get("agent_info") or {}
        domain = (info.get("environment_info") or {}).get("domain_name", "?")
        n = len(self.sims)
        p = int(self.sims["success"].sum()) if "success" in self.sims else "?"
        return f"RunAnalysis({ai.get('llm', '?')}, domain={domain}, n={n}, pass={p})"

    def overall(self) -> pd.DataFrame:
        """Statistics over all simulations, indexed by metric."""
        return _compute_stats_df(self.sims)

    def by(self, col: str) -> pd.DataFrame:
        """Statistics per distinct value of ``col``, indexed by (group, metric)."""
        return _grouped_stats_df(self.sims, col)

    def by_category(self) -> pd.DataFrame:
        """Statistics per task category (the bracketed task-id prefix)."""
        return self.by("category")

    def by_outcome(self) -> pd.DataFrame:
        """Statistics for successful vs. unsuccessful simulations."""
        return self.by("success")

    def by_actions(self) -> pd.DataFrame:
        """Statistics per number of expected actions in the task's criteria."""
        return self.by("n_expected_actions")

    def by_termination(self) -> pd.DataFrame:
        """Statistics per simulation termination reason."""
        return self.by("termination_reason")


def load_run(path) -> RunAnalysis:
    """Load a simulation-results JSON file into a RunAnalysis.

    Parameters
    ----------
    path : str or os.PathLike
        JSON file with ``info``, ``tasks`` and ``simulations`` sections. A
        missing or null section is treated as empty.

    Returns
    -------
    RunAnalysis
        Its ``sims`` frame has one row per simulation with these columns:
        - ``task_id``;
        - ``category``: the bracketed task-id prefix;
        - ``n_expected_actions``: from the matching task's evaluation criteria;
        - ``success`` and ``reward``;
        - ``termination_reason``, ``trial`` and ``seed``;
        - ``agent_response_times``: the raw list;
        - every METRICS field, None when the simulation record lacks it.

    Raises
    ------
    KeyError
        If a task has no ``"id"`` or a simulation has no ``"task_id"``.
    """
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    # `or {}` / `or []` also cover sections that are present but null.
    info = data.get("info") or {}
    task_map = {t["id"]: t for t in data.get("tasks") or []}
    rows = []
    for sim in data.get("simulations") or []:
        task = task_map.get(sim["task_id"], {})
        row = {
            "task_id": sim["task_id"],
            "category": _cat(sim["task_id"]),
            "n_expected_actions": _n_actions(task),
            "success": _success(sim),
            "reward": _reward(sim),
            "termination_reason": sim.get("termination_reason"),
            "trial": sim.get("trial"),
            "seed": sim.get("seed"),
            "agent_response_times": sim.get("agent_response_times"),
        }
        for m in METRICS:
            row[m] = sim.get(m)
        rows.append(row)
    # Explicit columns keep the documented schema present even when the file
    # contains no simulations (pd.DataFrame([]) would have no columns at all).
    columns = [
        "task_id",
        "category",
        "n_expected_actions",
        "success",
        "reward",
        "termination_reason",
        "trial",
        "seed",
        "agent_response_times",
        *METRICS,
    ]
    return RunAnalysis(info, pd.DataFrame(rows, columns=columns))

In [19]:
# Simulation results directory. Set TAU2_SIM_DIR to read from elsewhere
# instead of editing a user-specific absolute path.
SIM_DIR = Path(os.environ.get("TAU2_SIM_DIR", Path.home() / "Code/tau2-bench/data/simulations"))
RUN_FILE = "2026-02-09T23:10:40.625891_telecom_mas_2_gpt-oss-120b_user_simulator_gpt-oss-120b.json"
run = load_run(SIM_DIR / RUN_FILE)

In [20]:
# Per-simulation table built by load_run: one row per simulation record.
run.sims

Unnamed: 0,task_id,category,n_expected_actions,success,reward,termination_reason,trial,seed,agent_response_times,agent_total_tokens,agent_completion_tokens,agent_prompt_tokens,agent_llm_calls,num_user_turns,num_tool_calls,agent_active_time,duration,agent_cost
0,[mobile_data_issue]bad_network_preference|bad_...,mobile_data_issue,3,False,0.0,user_stop,0,626729,"[1.4706675000488758, 2.1628092501778156, 1.746...",36488,543,35945,10,5,2,8.192862,13.801299,0.000773
1,[mobile_data_issue]bad_vpn|data_saver_mode_on|...,mobile_data_issue,3,False,0.0,user_stop,0,626729,"[1.6321470830589533, 2.134185832925141, 1.7710...",36550,640,35910,10,5,2,8.936691,14.163152,0.000782
2,[mobile_data_issue]data_usage_exceeded|user_ab...,mobile_data_issue,2,False,0.0,user_stop,0,626729,"[2.726437416858971, 2.759948165854439, 2.11877...",42464,594,41870,11,5,3,9.972858,15.302671,0.000897
3,[mobile_data_issue]airplane_mode_on|data_saver...,mobile_data_issue,3,False,0.0,user_stop,0,626729,"[2.5679715420119464, 2.1783245000988245, 1.434...",49216,885,48331,13,6,3,12.644782,18.098548,0.001055
4,[mobile_data_issue]airplane_mode_on|bad_networ...,mobile_data_issue,5,False,0.0,user_stop,0,626729,"[1.409363916143775, 2.0630275420844555, 1.4915...",50059,927,49132,14,7,2,13.135080,21.872907,0.001075
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
109,[mms_issue]airplane_mode_on|bad_network_prefer...,mms_issue,10,True,1.0,user_stop,0,626729,"[10.567182707833126, 2.3961115840356797, 8.705...",235030,9072,225958,39,44,46,72.771561,120.675686,0.005426
110,[mms_issue]airplane_mode_on|bad_network_prefer...,mms_issue,12,True,1.0,user_stop,0,626729,"[2.201849792152643, 10.014923583017662, 2.9710...",210644,8060,202584,36,35,34,62.720344,96.909001,0.004858
111,[mms_issue]airplane_mode_on|bad_network_prefer...,mms_issue,11,False,0.0,user_stop,0,626729,"[2.0883687089663, 7.824182708980516, 3.1526584...",312130,13526,298604,46,39,33,94.435746,150.350789,0.007325
112,[mms_issue]airplane_mode_on|bad_network_prefer...,mms_issue,11,True,1.0,user_stop,0,626729,"[1.7083764169365168, 15.809256250038743, 9.352...",332372,13707,318665,48,43,45,91.715043,139.496229,0.007744


In [21]:
# Statistics across all simulations: one row per metric, plus a pooled
# agent_response_time row when response-time lists are present.
run.overall()

Unnamed: 0_level_0,n,min,mean,p50,p90,p95,p99,max,std
metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
agent_total_tokens,114,24873.0,110600.763158,86658.0,198085.0,240992.95,329740.54,725204.0,88859.522872
agent_completion_tokens,114,423.0,3654.921053,2583.5,7812.0,9121.7,13683.47,21297.0,3355.613013
agent_prompt_tokens,114,24286.0,106945.842105,83518.5,190564.0,230837.0,316057.07,703907.0,85630.007408
agent_llm_calls,114,6.0,21.517544,18.0,35.7,39.7,47.74,81.0,11.541562
num_user_turns,114,2.0,17.359649,15.0,32.0,36.05,43.87,62.0,11.71802
num_tool_calls,114,1.0,17.192982,16.0,32.7,38.05,45.0,46.0,12.135574
agent_active_time,114,6.006979,34.82813,26.773434,68.294233,82.718282,103.64708,147.067059,25.905697
duration,114,8.382476,55.118449,43.331923,100.13412,124.456193,150.326295,236.958648,40.291259
agent_cost,114,0.000539,0.002504,0.001987,0.004532,0.005632,0.007689,0.016208,0.002037
agent_response_time,785,0.930501,5.033074,3.659357,10.202964,11.767127,16.711792,29.410854,3.722753


In [22]:
# Same statistics split by task category (the bracketed task-id prefix),
# including a success_rate row for each category.
run.by_category()

Unnamed: 0_level_0,Unnamed: 1_level_0,n,min,mean,p50,p90,p95,p99,max,std
group,metric,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
mms_issue,agent_total_tokens,49,35964.0,149026.836735,139954.0,238437.4,298693.6,536644.64,725204.0,113064.917567
mms_issue,agent_completion_tokens,49,560.0,5166.122449,4588.0,9420.6,12983.2,17653.8,21297.0,4163.349566
mms_issue,agent_prompt_tokens,49,35404.0,143860.714286,135366.0,228746.0,286252.0,518990.84,703907.0,109076.656528
mms_issue,agent_llm_calls,49,10.0,26.408163,27.0,39.4,45.2,65.16,81.0,13.300875
mms_issue,num_user_turns,49,5.0,23.265306,26.0,35.6,41.4,53.36,62.0,12.982
mms_issue,num_tool_calls,49,1.0,23.142857,26.0,37.6,41.6,45.52,46.0,13.242923
mms_issue,agent_active_time,49,8.48102,46.205428,45.777467,81.38531,95.40075,126.770787,147.067059,29.97531
mms_issue,duration,49,12.566665,73.424668,76.971283,122.835976,145.895917,195.386876,236.958648,45.875915
mms_issue,agent_cost,49,0.000764,0.003394,0.003205,0.005544,0.006969,0.012145,0.016208,0.002583
mms_issue,agent_response_time,416,0.930501,5.397368,4.02203,10.572514,12.68907,16.986209,21.020411,3.915508
