In [None]:
from __future__ import annotations

import json
from pathlib import Path
from statistics import mean
from typing import Any


def _find_project_root() -> Path:
    """Find repo root whether cwd is repo/ or repo/notebooks/."""
    cwd = Path.cwd().resolve()
    for candidate in (cwd, cwd.parent):
        if (candidate / "viz" / "data_archive").exists():
            return candidate
    return cwd


# Repository root, and the directory under it that holds the archived run files.
PROJECT_ROOT = _find_project_root()
RUNS_FOLDER = PROJECT_ROOT.joinpath("viz", "data_archive")

# Optional plotting stack (doesn't block the notebook if numpy/pandas are broken)
try:
    import pandas as pd  # type: ignore
    import seaborn as sns  # type: ignore
    import matplotlib.pyplot as plt  # type: ignore
except Exception:  # pragma: no cover
    pd = None  # type: ignore
    sns = None  # type: ignore
    plt = None  # type: ignore

In [None]:
class Benchmark:
    """Compare metrics stored in per-run JSON files, grouped by environment.

    For every environment name in ``runs_names`` the matching JSON files in
    ``runs_folder`` are loaded. Each file gets a short label (normally the
    algorithm token from its filename), and every ``(out_name, metric_key, how)``
    triple in ``compare_metrics`` is reduced to one value per label.
    """

    # Filename tokens recognised as algorithm/model names.
    ALGOS = {"ppo", "sac", "transformer", "ddpg", "rule"}

    # What float() raises for JSON values that are not numbers:
    # - TypeError: None, dicts
    # - ValueError: non-numeric strings
    # - OverflowError: ints beyond float range
    _NUMERIC_ERRORS = (TypeError, ValueError, OverflowError)

    def __init__(
        self,
        runs_folder: str | Path,
        runs_names: list[str],
        compare_metrics: list[tuple[str, str, str]],
    ) -> None:
        """Store the benchmark configuration; no files are read here.

        Args:
            runs_folder: Directory containing the run JSON files.
            runs_names: Environment names. Each is matched case-insensitively
                as a substring of the file names.
            compare_metrics: ``(out_name, metric_key, how)`` triples, where:
                - ``metric_key`` is a top-level key of each run's JSON object.
                - ``how`` is one of ``sum``, ``avg``/``mean``, ``last``,
                  ``leaf_avg``/``avg_leaf``/``mean_leaf`` or
                  ``leaf_sum``/``sum_leaf``.
        """
        self.runs_folder = Path(runs_folder)
        self.runs_names = runs_names
        self.compare_metrics = compare_metrics

    def _discover_files(self, run_name: str) -> list[Path]:
        """Match env name to files inside runs_folder.

        Prefers '*data.json'. Falls back to any '*.json' containing run_name.
        Matching is case-insensitive and substring-based; results are sorted.
        """
        run_key = run_name.lower().strip()
        data_jsons = sorted(self.runs_folder.glob("*data.json"))
        hits = [p for p in data_jsons if run_key in p.name.lower()]
        if hits:
            return hits

        all_jsons = sorted(self.runs_folder.glob("*.json"))
        return [p for p in all_jsons if run_key in p.name.lower()]

    @classmethod
    def _infer_varied_role(cls, files: list[Path]) -> str | None:
        """Infer which role is varied by looking at filename tokens.

        Example (monopolistic competition):
        - ..._gov_us_firm_ppo_...
        - ..._gov_us_firm_sac_...
        - ..._gov_us_firm_transformer_...

        Here the role right before the algo token is 'firm'. Returns None
        unless some role is paired with more than one distinct algo.
        """
        role_algos: dict[str, set[str]] = {}
        for p in files:
            tokens = p.stem.lower().split("_")
            for i, tok in enumerate(tokens):
                if tok in cls.ALGOS and i > 0:
                    role = tokens[i - 1]
                    role_algos.setdefault(role, set()).add(tok)

        if not role_algos:
            return None

        # choose role with most distinct algos (must be >1 to be "varied");
        # ties keep the first role encountered.
        best_role, best_n = None, 1
        for role, algos in role_algos.items():
            n = len(algos)
            if n > best_n:
                best_role, best_n = role, n
        return best_role

    @classmethod
    def _label_for_file(cls, p: Path, used: set[str], varied_role: str | None) -> str:
        """Return a unique display label for ``p`` and record it in ``used``.

        Duplicate labels get a ``#2``, ``#3``, ... suffix.
        """
        stem = p.stem
        # All pattern matching is done on the lowercased stem so that e.g.
        # "X_GOV_Rule" is treated like "x_gov_rule".
        stem_lower = stem.lower()
        tokens = stem_lower.split("_")

        label: str | None = None

        # If we inferred the varied role, use the token after it as model name.
        if varied_role is not None:
            for i, tok in enumerate(tokens[:-1]):
                if tok == varied_role and tokens[i + 1] in cls.ALGOS:
                    label = tokens[i + 1]
                    break

        # Fallback: common pattern: <env>_..._gov_<ALGO>_...
        if label is None and "_gov_" in stem_lower:
            label = stem_lower.split("_gov_", 1)[1].split("_", 1)[0]

        # Fallback: common pattern: <env>_<ALGO>.json
        if label is None:
            for algo in cls.ALGOS:
                if stem_lower.endswith(f"_{algo}"):
                    label = algo
                    break

        label = (label or stem).strip() or stem

        if label not in used:
            used.add(label)
            return label

        # Disambiguate duplicates
        i = 2
        while f"{label}#{i}" in used:
            i += 1
        label2 = f"{label}#{i}"
        used.add(label2)
        return label2

    @staticmethod
    def _to_number(x: Any) -> float:
        """Convert scalars or nested lists/tuples of scalars to a single float (sum).

        Raises TypeError/ValueError/OverflowError (via float()) for
        non-numeric leaves.
        """
        if isinstance(x, (list, tuple)):
            if len(x) == 0:
                return 0.0
            if len(x) == 1:
                return Benchmark._to_number(x[0])
            return float(sum(Benchmark._to_number(v) for v in x))
        return float(x)

    @staticmethod
    def _flatten_numbers(x: Any) -> list[float]:
        """Return all numeric leaves from a nested list/tuple structure.

        Non-numeric leaves (None, strings, dicts, ...) are skipped.
        """
        if isinstance(x, (list, tuple)):
            out: list[float] = []
            for v in x:
                out.extend(Benchmark._flatten_numbers(v))
            return out
        try:
            return [float(x)]
        except Benchmark._NUMERIC_ERRORS:
            return []

    @staticmethod
    def _agg(values: list[float], how: str) -> float:
        """Aggregate ``values`` by ``sum``, ``avg``/``mean`` or ``last``.

        Returns NaN for an empty list; raises ValueError for an unknown ``how``.
        """
        how = how.lower().strip()
        if not values:
            return float("nan")
        if how == "sum":
            return float(sum(values))
        if how in {"avg", "mean"}:
            return float(mean(values))
        if how == "last":
            return float(values[-1])
        raise ValueError(f"Unknown aggregation: {how!r}")

    def _calc_metric(self, run_data: dict[str, Any], metric_key: str, how: str) -> float | None:
        """Reduce ``run_data[metric_key]`` to one number.

        Returns None when the key is absent or the stored value is not
        numeric. An unknown aggregation still raises ValueError (from _agg).
        """
        if metric_key not in run_data:
            return None

        raw = run_data[metric_key]
        how_norm = how.lower().strip()

        # For deeply nested arrays (e.g. house_income: years x households x [value])
        # compute mean over ALL numeric leaves.
        if how_norm in {"leaf_avg", "avg_leaf", "mean_leaf"}:
            leaves = self._flatten_numbers(raw)
            return float(mean(leaves)) if leaves else float("nan")

        if how_norm in {"leaf_sum", "sum_leaf"}:
            leaves = self._flatten_numbers(raw)
            return float(sum(leaves)) if leaves else float("nan")

        # list of (possibly list-wrapped) numbers: [[1.0], [0.2], ...]
        if isinstance(raw, list):
            try:
                values = [self._to_number(v) for v in raw]
            except self._NUMERIC_ERRORS:
                # e.g. a JSON null inside the series: report N/A for this
                # metric instead of aborting the whole benchmark run.
                return None
            # _agg is kept outside the try so config errors still surface.
            return self._agg(values, how_norm)

        # scalar
        try:
            return self._to_number(raw)
        except self._NUMERIC_ERRORS:
            return None

    @staticmethod
    def _load_run_data(p: Path, label: str) -> dict[str, Any] | None:
        """Parse one run file; print a note and return None if it is unusable."""
        try:
            # Passing bytes lets json detect UTF-8/16/32 (and a UTF-8 BOM)
            # instead of depending on the platform's locale encoding.
            run_data = json.loads(p.read_bytes())
        except Exception as e:  # best-effort: one bad file must not stop the run
            print(f"- {label}: failed to read {p.name}: {e}")
            return None

        if not isinstance(run_data, dict):
            print(f"- {label}: unsupported JSON shape ({type(run_data)}) in {p.name}")
            return None
        return run_data

    def run(self) -> dict[str, dict[str, dict[str, float | None]]]:
        """Run metrics for all requested environments, printing a summary.

        Returns: results[env][metric_out_name][label] = value_or_None. The
        value is None when the file could not be loaded, the metric key is
        absent, or the stored value is not numeric. Environments with no
        matching files map every metric name to an empty dict.
        """
        results: dict[str, dict[str, dict[str, float | None]]] = {}

        for env in self.runs_names:
            files = self._discover_files(env)
            env_results: dict[str, dict[str, float | None]] = {
                out: {} for out, _, _ in self.compare_metrics
            }
            results[env] = env_results

            print("=" * 40)
            print(env)
            if not files:
                print(f"No files matched in {self.runs_folder}")
                continue
            print(f"Matched {len(files)} file(s)")

            varied_role = self._infer_varied_role(files)
            if varied_role is not None:
                print(f"Role: {varied_role}")

            used_labels: set[str] = set()
            for p in files:
                label = self._label_for_file(p, used_labels, varied_role)
                run_data = self._load_run_data(p, label)
                for out_name, metric_key, how in self.compare_metrics:
                    env_results[out_name][label] = (
                        None
                        if run_data is None
                        else self._calc_metric(run_data, metric_key, how)
                    )

            # Pretty-print
            for out_name, _, _ in self.compare_metrics:
                print("---------------------")
                print(out_name)
                for label, val in env_results[out_name].items():
                    print(f"{label}: {val if val is not None else 'N/A'}")

        return results

In [None]:
# Env names are matched case-insensitively as substrings of the JSON file
# names in RUNS_FOLDER; each metric is (output name, JSON key, aggregation).
# Both lists keep their order, which is also the printed order.
bench = Benchmark(
    compare_metrics=[
        ("accumulative_gov_reward", "gov_reward", "sum"),
        ("accumulative_firm_reward", "firm_reward", "sum"),
        ("accumulative_bank_reward", "bank_reward", "sum"),
        ("social_welfare_avg", "social_welfare", "avg"),
        ("wealth_gini_avg", "wealth_gini", "avg"),
        ("gdp", "GDP", "avg"),
        # house_income format: years x households x [value]
        ("house_income_avg", "house_income", "leaf_avg"),
    ],
    runs_names=[
        "estate_tax", "monopolistic_competition", "inflation_control",
        "optimal_monetary_gov", "optimal_monetary_bank", "optimal_tax",
        "oligopoly", "delayed_retirement", "work_hard",
        "work_life_well_being", "pension_gap", "universal_basic_income",
    ],
    runs_folder=RUNS_FOLDER,
)

In [None]:
# Load every matched run file, print a per-environment summary, and keep the
# nested mapping results[env][metric_out_name][label] -> value (None when a
# file could not be read or the metric key is absent).
results = bench.run()