In [1]:
# Open matplotlib figures in separate interactive Tk windows instead of rendering them inline.
%matplotlib tk

In [2]:
from pathlib import Path
import pickle
import gc
import re
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt

from IPython.display import display
import ipywidgets as widgets

from plot_utils import plot_dbg_step, inspect_debug_trace_xy, decode_obs, plot_seeker_trajectory

# Notebooks have no __file__, so locate eval_results/ relative to the working directory:
# it is expected to be a sibling of the directory the notebook is started from.
SCRIPT_DIR = Path.cwd()
RESULTS_DIR = SCRIPT_DIR.parent.joinpath("eval_results").resolve()
assert RESULTS_DIR.exists(), f"Results dir not found: {RESULTS_DIR}"

# Full evaluation dumps (non-recursive, so the slim/ subfolder is not included).
# Step numbers in the names are zero-padded, so lexicographic order follows the step order.
pkl_files = list(RESULTS_DIR.glob("*.pkl"))
pkl_files.sort()

print("RESULTS_DIR:", RESULTS_DIR)
print("Found PKLs:", len(pkl_files))
pkl_files[:5]


RESULTS_DIR: C:\Users\timok\Desktop\rl_challenge\rl-comptetition\eval_results
Found PKLs: 50


[WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000020000.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000040000.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000060000.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000080000.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000100000.pkl')]

In [3]:
# Slimmed copies of the eval pickles go into a subfolder, so the non-recursive
# "*.pkl" glob over RESULTS_DIR does not pick them up.
SLIM_DIR = RESULTS_DIR.joinpath("slim")
SLIM_DIR.mkdir(exist_ok=True)

# Reward values used to classify how an episode ended from its reward stream.
GOAL_R = 100.0    # last-step reward treated as "reached the goal"
CRASH_R = -100.0  # last-step reward treated as "crashed"
TOL = 1e-3        # absolute tolerance for float reward comparisons

# Obstacle count passed to decode_obs.
# NOTE(review): must match the env config that produced the pickles — TODO confirm.
num_obstacles = 10

def slim_path_for(p: Path) -> Path:
    """Return where the slim counterpart of eval pickle ``p`` is stored.

    ``<anything>/<name>.pkl`` maps to ``SLIM_DIR/<name>_slim.pkl``.
    """
    slim_name = f"{p.stem}_slim.pkl"
    return SLIM_DIR.joinpath(slim_name)

def load_eval_pkl(p: Path) -> dict:
    """Deserialize and return the object pickled at ``p``.

    Security: ``pickle.load`` can execute arbitrary code. Only use this on
    trusted, locally generated evaluation files.
    """
    with open(p, mode="rb") as fh:
        payload = pickle.load(fh)
    return payload

def _decode_static(state0, *, num_obstacles: int):
    """Decode the per-episode constants (dim, goal, obstacles, coin) from the first state."""
    _, goal, obstacles, coin, dim = decode_obs(np.asarray(state0), num_obstacles=num_obstacles)

    dim = int(dim)  # 2 or 3
    goal_pos = [float(goal[i]) for i in range(dim)]

    obstacles_arr = np.asarray(obstacles, dtype=float)
    # Store x,y,(z),r: the first `dim` coordinate columns plus the last column,
    # which is assumed to be the radius.
    if obstacles_arr.ndim == 2 and obstacles_arr.shape[1] >= dim + 1:
        obstacles_slim = obstacles_arr[:, list(range(dim)) + [-1]].tolist()
    else:
        obstacles_slim = obstacles_arr.tolist()

    coin_pos = None if coin is None else [float(coin[i]) for i in range(dim)]
    return dim, goal_pos, obstacles_slim, coin_pos


def _decode_agent_positions(states, *, dim: int, num_obstacles: int) -> list:
    """Return the agent position (first `dim` coordinates) for every state, as plain floats."""
    positions = []
    for s in states:
        agent, _, _, _, _ = decode_obs(np.asarray(s), num_obstacles=num_obstacles)
        positions.append([float(agent[i]) for i in range(dim)])
    return positions


def _classify_outcome(rewards: list, *, goal_r: float, crash_r: float, tol: float) -> dict:
    """Derive episode length/return, terminal flags and the coin heuristic from the rewards.

    The terminal type is read from the LAST reward only: ``goal_r`` -> goal,
    ``crash_r`` -> crash, anything else (including an empty episode) -> timeout.
    """
    ep_len = len(rewards)  # number of actions taken
    total_return = float(sum(rewards))
    last_reward = float(rewards[-1]) if rewards else 0.0

    terminated_goal = abs(last_reward - goal_r) <= tol
    terminated_crash = abs(last_reward - crash_r) <= tol
    terminated_timeout = not (terminated_goal or terminated_crash)

    # Coin heuristic from the total return, assuming the coin is worth +100:
    #   goal + coin    -> 200
    #   timeout + coin -> 100 (last reward is not the goal reward)
    #   crash + coin   -> 0   (-100 + 100)
    # NOTE(review): with coin == goal reward, a coin picked up on the very last
    # step of a timed-out episode is indistinguishable from reaching the goal:
    # it is classified as terminated_goal with got_coin False.
    got_coin = (
        abs(total_return - 200.0) <= tol
        or (abs(total_return - 100.0) <= tol and not terminated_goal)
        or (abs(total_return) <= tol and terminated_crash)
    )

    return {
        "ep_len": int(ep_len),
        "total_return": float(total_return),
        "last_reward": float(last_reward),

        "terminated_goal": bool(terminated_goal),
        "terminated_crash": bool(terminated_crash),
        "terminated_timeout": bool(terminated_timeout),
        "got_coin": bool(got_coin),
    }


def make_slim_trace(tr: dict, *, num_obstacles: int) -> dict:
    """Reduce one seed trace to the lightweight fields needed for analysis/plotting.

    Args:
        tr: a full trace dict with at least ``"states"``, ``"rewards"`` and ``"seed"``.
        num_obstacles: obstacle count forwarded to ``decode_obs``.

    Returns:
        A dict with, in order: ``seed``, ``dim``, ``ep_len``, ``total_return``,
        ``last_reward``, the ``terminated_*`` flags, ``got_coin``, ``rewards``,
        ``agent_pos`` (one entry per state, i.e. ``len(rewards) + 1`` for a full
        episode), and the static ``goal_pos``, ``obstacles``, ``coin_pos``.

    Raises:
        ValueError: if the trace has no states (nothing to decode).
    """
    states = tr["states"]
    if len(states) == 0:
        raise ValueError(f"trace for seed {tr.get('seed')!r} has no states; cannot decode episode")
    rewards = [float(r) for r in tr["rewards"]]

    dim, goal_pos, obstacles_slim, coin_pos = _decode_static(states[0], num_obstacles=num_obstacles)
    agent_pos = _decode_agent_positions(states, dim=dim, num_obstacles=num_obstacles)
    outcome = _classify_outcome(rewards, goal_r=GOAL_R, crash_r=CRASH_R, tol=TOL)

    return {
        "seed": int(tr["seed"]),
        "dim": dim,

        **outcome,

        "rewards": rewards,
        "agent_pos": agent_pos,

        # static per episode
        "goal_pos": goal_pos,
        "obstacles": obstacles_slim,
        "coin_pos": coin_pos,
    }

def make_slim(data: dict, *, num_obstacles: int) -> dict:
    """Build the slim version of one full evaluation dump.

    Keeps the ``run``/``checkpoint`` identifiers and the optional ``stats``
    entry. Every trace is replaced by ``make_slim_trace`` output, keyed by
    integer seed.
    """
    slim_traces = {
        int(seed): make_slim_trace(trace, num_obstacles=num_obstacles)
        for seed, trace in data["traces"].items()
    }
    return {
        "run": data["run"],
        "checkpoint": data["checkpoint"],
        "stats": data.get("stats"),  # optional, small
        "traces": slim_traces,
    }


In [4]:
# Convert each full eval pickle into a slim one, skipping files that were already converted.
expected = [slim_path_for(p) for p in pkl_files]
is_done = [slim_p.exists() for slim_p in expected]  # stat each file only once
existing = [slim_p for slim_p, done in zip(expected, is_done) if done]
missing = [p for p, done in zip(pkl_files, is_done) if not done]

print(f"Slim dir: {SLIM_DIR}")
print(f"Expected slim files: {len(expected)}")
print(f"Existing slim files: {len(existing)}")
print(f"Missing slim files:  {len(missing)}")

# Only process missing ones
for idx, p in enumerate(missing, 1):
    slim_p = slim_path_for(p)
    print(f"[{idx}/{len(missing)}] slim -> {slim_p.name}")

    data = load_eval_pkl(p)
    slim = make_slim(data, num_obstacles=num_obstacles)

    # Write to a temporary sibling first, then rename it into place (os.replace
    # semantics). If the dump is interrupted, no truncated *_slim.pkl is left
    # behind. Such a file would otherwise pass the existence check above and be
    # skipped forever. The ".tmp" name is not matched by a "*_slim.pkl" glob.
    tmp_p = slim_p.with_name(slim_p.name + ".tmp")
    with open(tmp_p, "wb") as f:
        pickle.dump(slim, f, protocol=pickle.HIGHEST_PROTOCOL)
    tmp_p.replace(slim_p)

    # free memory
    del data, slim
    gc.collect()

print("Done.")


Slim dir: C:\Users\timok\Desktop\rl_challenge\rl-comptetition\eval_results\slim
Expected slim files: 50
Existing slim files: 50
Missing slim files:  0
Done.


In [5]:
# Slim pickles produced by the conversion cell (now or in an earlier session).
slim_files = list(SLIM_DIR.glob("*_slim.pkl"))
slim_files.sort()
n_slim = len(slim_files)
print("Slim files:", n_slim)
slim_files[:5]


Slim files: 50


[WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/slim/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000020000_slim.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/slim/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000040000_slim.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/slim/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000060000_slim.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/slim/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000080000_slim.pkl'),
 WindowsPath('C:/Users/timok/Desktop/rl_challenge/rl-comptetition/eval_results/slim/ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000100000_slim.pkl')]

In [6]:
with open(slim_files[0], "rb") as fh:
    slim0 = pickle.load(fh)

print("run:", slim0["run"])
print("checkpoint:", slim0["checkpoint"])

# Peek at one trace to sanity-check the slim schema.
some_seed = next(iter(slim0["traces"]))
example_trace = slim0["traces"][some_seed]
print("example seed:", some_seed)
print("keys:", example_trace.keys())
print("dim:", example_trace["dim"])
# agent_pos holds one entry per state and rewards one per action, so a full
# episode has len(agent_pos) == len(rewards) + 1.
print("len(agent_pos):", len(example_trace["agent_pos"]))
print("len(rewards):", len(example_trace["rewards"]))


run: ckpt_3D_exploration_AC_stepwise_base
checkpoint: ckpt_step_000020000.pt
example seed: 1000
keys: dict_keys(['seed', 'dim', 'ep_len', 'total_return', 'last_reward', 'terminated_goal', 'terminated_crash', 'terminated_timeout', 'got_coin', 'rewards', 'agent_pos', 'goal_pos', 'obstacles', 'coin_pos'])
dim: 3
len(agent_pos): 151
len(rewards): 150


In [7]:
# stats[run][checkpoint] = {...metrics...}, aggregated from the slim pickles.

def summarize_slim(slim: dict, file) -> dict:
    """Aggregate the per-episode slim traces of one checkpoint into summary metrics.

    Args:
        slim: slim eval dump whose ``"traces"`` maps seed -> slim trace dict
            (with ``total_return``, ``ep_len`` and the outcome flags).
        file: path of the slim pickle, stored as a string for later lookups.

    Returns:
        dict with the episode count, outcome counts/rates, and return/length
        statistics. Rates and statistics fall back to 0 when there are no episodes.
    """
    episodes = list(slim["traces"].values())
    n = len(episodes)

    totals = np.asarray([ep["total_return"] for ep in episodes], dtype=float)
    lengths = np.asarray([ep["ep_len"] for ep in episodes], dtype=float)

    counts = {
        "goal": sum(int(ep["terminated_goal"]) for ep in episodes),
        "crash": sum(int(ep["terminated_crash"]) for ep in episodes),
        "timeout": sum(int(ep["terminated_timeout"]) for ep in episodes),
        "coin": sum(int(ep["got_coin"]) for ep in episodes),
    }

    return {
        "file": str(file),
        "n_episodes": n,

        # number of episodes with any non-zero total return
        "count_non_zero_total_return": int(np.count_nonzero(totals)),

        "counts": counts,
        "rates": {
            "goal_rate": float(counts["goal"] / n) if n else 0.0,
            "crash_rate": float(counts["crash"] / n) if n else 0.0,
            "timeout_rate": float(counts["timeout"] / n) if n else 0.0,
            "coin_rate": float(counts["coin"] / n) if n else 0.0,
        },
        "return_stats": {
            "mean": float(totals.mean()) if n else 0.0,
            "std": float(totals.std()) if n else 0.0,
            "min": float(totals.min()) if n else 0.0,
            "max": float(totals.max()) if n else 0.0,
        },
        "length_stats": {
            "mean": float(lengths.mean()) if n else 0.0,
            "min": int(lengths.min()) if n else 0,
            "max": int(lengths.max()) if n else 0,
        },
    }


stats = {}

for p in slim_files:
    with open(p, "rb") as f:
        slim_dump = pickle.load(f)
    run, ckpt = slim_dump["run"], slim_dump["checkpoint"]

    by_ckpt = stats.setdefault(run, {})
    if ckpt in by_ckpt:
        # Two slim files with the same (run, checkpoint) would otherwise be
        # overwritten silently; keep the later one but make the clash visible.
        print(f"WARNING: duplicate {run}/{ckpt}: {p.name} replaces {Path(by_ckpt[ckpt]['file']).name}")
    by_ckpt[ckpt] = summarize_slim(slim_dump, p)
    del slim_dump  # don't keep the last (possibly large) dump alive in the kernel

print("Runs:", len(stats))
print("Example run keys:", list(stats.keys())[:5])

Runs: 5
Example run keys: ['ckpt_3D_exploration_AC_stepwise_base', 'ckpt_3D_exploration_AC_stepwise_biggernet', 'ckpt_3D_exploration_AC_stepwise_cpuct05', 'ckpt_3D_exploration_AC_stepwise_gamma095', 'ckpt_3D_exploration_AC_stepwise_numsim400']


In [8]:
def parse_step(ckpt_name: str) -> int:
    """Extract the training step from a checkpoint name such as ``ckpt_step_000140000.pt``.

    ``re`` comes from the imports cell; it is not re-imported here.

    Raises:
        ValueError: if ``ckpt_name`` contains no ``step_<digits>`` token.
    """
    m = re.search(r"step_(\d+)", ckpt_name)
    if not m:
        raise ValueError(f"Could not parse step from: {ckpt_name}")
    return int(m.group(1))


In [9]:
# Single best (run, checkpoint) by goal rate, as (value, run, ckpt).
# max() returns the first maximal candidate, so ties go to the one seen first;
# None if there are no checkpoints at all.
best = max(
    (
        (ckpt_stats["rates"]["goal_rate"], run_name, ckpt_name)
        for run_name, ckpts in stats.items()
        for ckpt_name, ckpt_stats in ckpts.items()
    ),
    key=lambda cand: cand[0],
    default=None,
)

best


(0.06, 'ckpt_3D_exploration_AC_stepwise_base', 'ckpt_step_000140000.pt')

In [10]:
val, run, ckpt = best
d = stats[run][ckpt]

print(f"BEST goal_rate = {val:.3f}")
best_ckpt_report = [
    ("run:", run),
    ("ckpt:", ckpt, "step:", parse_step(ckpt)),
    ("counts:", d["counts"]),
    ("return mean:", d["return_stats"]["mean"]),
    ("length mean:", d["length_stats"]["mean"]),
    ("slim file:", d["file"]),
]
for fields in best_ckpt_report:
    print(*fields)


BEST goal_rate = 0.060
run: ckpt_3D_exploration_AC_stepwise_base
ckpt: ckpt_step_000140000.pt step: 140000
counts: {'goal': 3, 'crash': 0, 'timeout': 47, 'coin': 0}
return mean: 6.0
length mean: 145.02
slim file: C:\Users\timok\Desktop\rl_challenge\rl-comptetition\eval_results\slim\ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000140000_slim.pkl


In [11]:
# One row per (run, checkpoint): (goal_rate, mean return, run, ckpt, step).
items = [
    (d["rates"]["goal_rate"], d["return_stats"]["mean"], run, ckpt, parse_step(ckpt))
    for run, by_ckpt in stats.items()
    for ckpt, d in by_ckpt.items()
]

# Plain tuple ordering: goal rate, then mean return. Remaining ties fall back to
# the run/ckpt names in descending order, which is arbitrary.
top10 = sorted(items, reverse=True)[:10]

for rank, row in enumerate(top10, start=1):
    goal_rate, ret_mean, run, ckpt, step = row
    print(f"{rank:>2}. goal_rate={goal_rate:.3f}  ret_mean={ret_mean:>7.2f}  step={step:<9d}  run={run}  ckpt={ckpt}")


 1. goal_rate=0.060  ret_mean=   6.00  step=40000      run=ckpt_3D_exploration_AC_stepwise_numsim400  ckpt=ckpt_step_000040000.pt
 2. goal_rate=0.060  ret_mean=   6.00  step=20000      run=ckpt_3D_exploration_AC_stepwise_numsim400  ckpt=ckpt_step_000020000.pt
 3. goal_rate=0.060  ret_mean=   6.00  step=140000     run=ckpt_3D_exploration_AC_stepwise_base  ckpt=ckpt_step_000140000.pt
 4. goal_rate=0.040  ret_mean=   4.00  step=80000      run=ckpt_3D_exploration_AC_stepwise_gamma095  ckpt=ckpt_step_000080000.pt
 5. goal_rate=0.040  ret_mean=   4.00  step=160000     run=ckpt_3D_exploration_AC_stepwise_base  ckpt=ckpt_step_000160000.pt
 6. goal_rate=0.040  ret_mean=   4.00  step=120000     run=ckpt_3D_exploration_AC_stepwise_base  ckpt=ckpt_step_000120000.pt
 7. goal_rate=0.040  ret_mean=   4.00  step=20000      run=ckpt_3D_exploration_AC_stepwise_base  ckpt=ckpt_step_000020000.pt
 8. goal_rate=0.020  ret_mean=   2.00  step=180000     run=ckpt_3D_exploration_AC_stepwise_numsim400  ckpt=ckpt

In [12]:
# For each run, the checkpoint with the highest goal rate (first one wins on ties).
best_per_run = {}

for run, by_ckpt in stats.items():
    best_ckpt, best_stats = max(
        by_ckpt.items(),
        key=lambda kv: kv[1]["rates"]["goal_rate"],  # change metric here
    )
    best_val = best_stats["rates"]["goal_rate"]
    best_per_run[run] = (best_val, best_ckpt, parse_step(best_ckpt))

# print sorted by best_val (sorted() is stable, so equal values keep run order)
ranked_runs = sorted(best_per_run.items(), key=lambda kv: kv[1][0], reverse=True)
for run, (val, ckpt, step) in ranked_runs:
    print(f"run={run:40s}  best_goal_rate={val:.3f}  step={step:<9d}  ckpt={ckpt}")


run=ckpt_3D_exploration_AC_stepwise_base      best_goal_rate=0.060  step=140000     ckpt=ckpt_step_000140000.pt
run=ckpt_3D_exploration_AC_stepwise_numsim400  best_goal_rate=0.060  step=20000      ckpt=ckpt_step_000020000.pt
run=ckpt_3D_exploration_AC_stepwise_gamma095  best_goal_rate=0.040  step=80000      ckpt=ckpt_step_000080000.pt
run=ckpt_3D_exploration_AC_stepwise_biggernet  best_goal_rate=0.020  step=40000      ckpt=ckpt_step_000040000.pt
run=ckpt_3D_exploration_AC_stepwise_cpuct05   best_goal_rate=0.020  step=20000      ckpt=ckpt_step_000020000.pt


In [13]:
run_summary = {}  # run -> aggregated metrics over ckpts

for run, by_ckpt in stats.items():
    per_ckpt = list(by_ckpt.values())
    goal_rates = [s["rates"]["goal_rate"] for s in per_ckpt]
    crash_rates = [s["rates"]["crash_rate"] for s in per_ckpt]
    coin_rates = [s["rates"]["coin_rate"] for s in per_ckpt]
    ret_means = [s["return_stats"]["mean"] for s in per_ckpt]

    n_ckpt = len(per_ckpt)
    if n_ckpt == 0:
        continue  # nothing to average

    # Plain sum()/n (not np.mean) so the averages are bit-identical to a manual mean.
    run_summary[run] = dict(
        n_ckpt=n_ckpt,
        goal_rate_avg=sum(goal_rates) / n_ckpt,
        crash_rate_avg=sum(crash_rates) / n_ckpt,
        coin_rate_avg=sum(coin_rates) / n_ckpt,
        return_mean_avg=sum(ret_means) / n_ckpt,
        goal_rate_min=min(goal_rates),
        goal_rate_max=max(goal_rates),
    )

print("Runs summarized:", len(run_summary))
list(run_summary.items())[:1]


Runs summarized: 5


[('ckpt_3D_exploration_AC_stepwise_base',
  {'n_ckpt': 10,
   'goal_rate_avg': 0.026000000000000002,
   'crash_rate_avg': 0.0,
   'coin_rate_avg': 0.0,
   'return_mean_avg': 2.6,
   'goal_rate_min': 0.0,
   'goal_rate_max': 0.06})]

In [14]:
# Runs ordered from lowest to highest average goal rate (at most 10 shown).
worst_rows = [
    (summary["goal_rate_avg"], run, summary["n_ckpt"], summary["goal_rate_min"], summary["goal_rate_max"])
    for run, summary in run_summary.items()
]
worst_rows.sort(key=lambda row: row[0])  # stable: equal averages keep run order
worst10 = worst_rows[:10]

for rank, row in enumerate(worst10, start=1):
    avg_goal, run, n_ckpt, min_goal, max_goal = row
    print(f"{rank:>2}. avg_goal={avg_goal:.3f}  min={min_goal:.3f}  max={max_goal:.3f}  n_ckpt={n_ckpt:<3d}  run={run}")


 1. avg_goal=0.004  min=0.000  max=0.020  n_ckpt=10   run=ckpt_3D_exploration_AC_stepwise_biggernet
 2. avg_goal=0.012  min=0.000  max=0.020  n_ckpt=10   run=ckpt_3D_exploration_AC_stepwise_cpuct05
 3. avg_goal=0.014  min=0.000  max=0.040  n_ckpt=10   run=ckpt_3D_exploration_AC_stepwise_gamma095
 4. avg_goal=0.020  min=0.000  max=0.060  n_ckpt=10   run=ckpt_3D_exploration_AC_stepwise_numsim400
 5. avg_goal=0.026  min=0.000  max=0.060  n_ckpt=10   run=ckpt_3D_exploration_AC_stepwise_base


In [15]:
# Best run = highest goal rate averaged over all of its checkpoints (first wins on ties).
# NOTE: this rebinds `best`, which earlier held a (goal_rate, run, ckpt) tuple;
# from here on it is (run_name, summary_dict).
best = max(run_summary.items(), key=lambda kv: kv[1]["goal_rate_avg"])
best_run, d = best

print("BEST RUN (avg over ckpts):")
best_run_report = [
    (" run:", best_run),
    (" n_ckpt:", d["n_ckpt"]),
    (" goal_rate_avg:", round(d["goal_rate_avg"], 4)),
    (" goal_rate_min/max:", round(d["goal_rate_min"], 4), "/", round(d["goal_rate_max"], 4)),
    (" crash_rate_avg:", round(d["crash_rate_avg"], 4)),
    (" coin_rate_avg:", round(d["coin_rate_avg"], 4)),
    (" return_mean_avg:", round(d["return_mean_avg"], 4)),
]
for fields in best_run_report:
    print(*fields)


BEST RUN (avg over ckpts):
 run: ckpt_3D_exploration_AC_stepwise_base
 n_ckpt: 10
 goal_rate_avg: 0.026
 goal_rate_min/max: 0.0 / 0.06
 crash_rate_avg: 0.0
 coin_rate_avg: 0.0
 return_mean_avg: 2.6


In [16]:
# Checkpoints with at least one coin episode: (n_coin, coin_rate, run, ckpt, step).
coin_hits = [
    (d["counts"]["coin"], d["rates"]["coin_rate"], run, ckpt, parse_step(ckpt))
    for run, by_ckpt in stats.items()
    for ckpt, d in by_ckpt.items()
    if d["counts"]["coin"] > 0
]

coin_hits_sorted = sorted(coin_hits, reverse=True)

print("Total checkpoints with any coin:", len(coin_hits_sorted))

# show top 20 by #coin episodes
for rank, hit in enumerate(coin_hits_sorted[:20], start=1):
    n_coin, coin_rate, run, ckpt, step = hit
    print(f"{rank:>2}. coins={n_coin:<3d}  coin_rate={coin_rate:.3f}  step={step:<9d}  run={run}  ckpt={ckpt}")


Total checkpoints with any coin: 0


In [17]:
if coin_hits_sorted:
    # coin_hits_sorted is in descending order, so the head has the most coin episodes.
    n_coin, coin_rate, run, ckpt, step = coin_hits_sorted[0]
    print("Coin WAS collected.")
    print("Best coin checkpoint (most coin episodes):")
    print(" run:", run)
    print(" ckpt:", ckpt, "step:", step)
    print(" coins:", n_coin, "coin_rate:", coin_rate)
else:
    print("No coin was collected in any run/checkpoint.")


No coin was collected in any run/checkpoint.


In [18]:
# Checkpoints with at least one crash: (n_crash, crash_rate, run, ckpt, step).
crash_hits = [
    (d["counts"]["crash"], d["rates"]["crash_rate"], run, ckpt, parse_step(ckpt))
    for run, by_ckpt in stats.items()
    for ckpt, d in by_ckpt.items()
    if d["counts"]["crash"] > 0
]

crash_hits_sorted = sorted(crash_hits, reverse=True)

print("Total checkpoints with any crash:", len(crash_hits_sorted))


Total checkpoints with any crash: 0


In [19]:
# Top 20 checkpoints by number of crash episodes.
for rank, hit in enumerate(crash_hits_sorted[:20], start=1):
    n_crash, crash_rate, run, ckpt, step = hit
    print(f"{rank:>2}. crashes={n_crash:<3d}  crash_rate={crash_rate:.3f}  step={step:<9d}  run={run}  ckpt={ckpt}")


In [20]:
if crash_hits_sorted:
    # crash_hits_sorted is in descending order, so the head has the most crash episodes.
    n_crash, crash_rate, run, ckpt, step = crash_hits_sorted[0]
    print("Crashes DID occur.")
    print("Worst crash checkpoint (most crash episodes):")
    print(" run:", run)
    print(" ckpt:", ckpt, "step:", step)
    print(" crashes:", n_crash, "crash_rate:", crash_rate)
else:
    print("No crashes occurred in any run/checkpoint.")


No crashes occurred in any run/checkpoint.


In [21]:
# Display `best` as last bound: after the "BEST RUN (avg over ckpts)" cell it is a
# (run_name, summary_dict) tuple taken from run_summary.
best

('ckpt_3D_exploration_AC_stepwise_base',
 {'n_ckpt': 10,
  'goal_rate_avg': 0.026000000000000002,
  'crash_rate_avg': 0.0,
  'coin_rate_avg': 0.0,
  'return_mean_avg': 2.6,
  'goal_rate_min': 0.0,
  'goal_rate_max': 0.06})

In [22]:
# Recompute the best run from run_summary (same criterion as the
# "BEST RUN (avg over ckpts)" cell) instead of using best[0]. The name `best` is
# also bound by the earlier checkpoint-level cell as (goal_rate, run, ckpt).
# Re-running that cell would make best[0] a float and break the lookups below.
best_run = max(run_summary, key=lambda r: run_summary[r]["goal_rate_avg"])

# last checkpoint name (by step parsed from ckpt filename)
ckpt_last = max(stats[best_run], key=parse_step)

# slim path (stored in stats)
slim_path = Path(stats[best_run][ckpt_last]["file"])

# Full ("big") pickle: the slim file lives in <results>/slim/<name>_slim.pkl and
# the original in <results>/<name>.pkl. Strip only the trailing suffix.
big_path = slim_path.parent.parent / (slim_path.name.removesuffix("_slim.pkl") + ".pkl")

print("Best run:", best_run)
print("Last checkpoint:", ckpt_last)
print("BIG PKL:", big_path)
assert big_path.exists(), f"Full eval pickle not found: {big_path}"

# load slim + list goal seeds
with open(slim_path, "rb") as f:
    slim = pickle.load(f)

goal_seeds = [seed for seed, tr in slim["traces"].items() if tr.get("terminated_goal", False)]
print("Goal-reaching seeds in last checkpoint:", goal_seeds)


Best run: ckpt_3D_exploration_AC_stepwise_base
Last checkpoint: ckpt_step_000200000.pt
BIG PKL: C:\Users\timok\Desktop\rl_challenge\rl-comptetition\eval_results\ckpt_3D_exploration_AC_stepwise_base_ckpt_step_000200000.pkl
Goal-reaching seeds in last checkpoint: [1043]


In [23]:
def max_tree_depth(root) -> int:
    """Return the maximum depth, counted in edges, reachable below ``root``.

    The walk follows ``node.children`` (an iterable of edges). Each edge exposes
    the next node as ``edge.child_node``; edges without one (or with ``None``)
    are ignored. Nodes without a ``children`` attribute are leaves. A leaf root,
    or ``None``, has depth 0.

    Nodes are de-duplicated by ``id()``, so a cyclic graph cannot loop forever.
    NOTE(review): if the structure is a DAG (a node reachable via several paths),
    only the first path that reaches a node is counted. The result can then
    underestimate the true maximum depth. For a strict tree it is exact.
    """
    if root is None:
        return 0

    deepest = 0
    pending = [(root, 0)]  # LIFO work list of (node, depth)
    visited_ids = set()

    while pending:
        node, depth = pending.pop()
        key = id(node)
        if key in visited_ids:
            continue
        visited_ids.add(key)
        deepest = max(deepest, depth)

        # Children are pushed in iteration order, so the visit order matches a
        # plain append-per-child loop.
        pending.extend(
            (child, depth + 1)
            for edge in getattr(node, "children", [])
            if (child := getattr(edge, "child_node", None)) is not None
        )

    return deepest


def max_depth_over_pickle(big_pkl_path: str | Path):
    big_pkl_path = Path(big_pkl_path)

    with open(big_pkl_path, "rb") as f:
        data = pickle.load(f)

    traces = data["traces"]

    per_seed_max = {}
    global_max = 0

    for seed, tr in traces.items():
        roots = tr.get("roots", [])
        seed_max = 0

        for root in roots:
            d = max_tree_depth(root)
            if d > seed_max:
                seed_max = d

        per_seed_max[seed] = seed_max
        if seed_max > global_max:
            global_max = seed_max

    return global_max, per_seed_max

In [25]:
global_max_depth, per_seed = max_depth_over_pickle(big_path)
print("Global max depth:", global_max_depth)

# top 10 seeds by max depth; sorted() is stable even with reverse=True, so
# equal depths keep their original seed order
ranked_seeds = sorted(per_seed.items(), key=lambda item: item[1], reverse=True)
top10 = ranked_seeds[:10]
print("Top 10 seeds:", top10)

Global max depth: 8
Top 10 seeds: [(1000, 8), (1001, 8), (1007, 8), (1008, 8), (1010, 8), (1017, 8), (1027, 8), (1032, 8), (1040, 8), (1049, 8)]
