In [None]:
import pickle
from pathlib import Path

import pandas as pd
import yaml

# Load the run configuration from config.yaml in the current working directory.
# Later statements read the "base_path", "batches", "experiments" and "trials"
# keys from the resulting mapping.
config_path = Path("./config.yaml")
# Pin the encoding: without it, open() decodes with the platform locale, so a
# config containing non-ASCII text could load differently (or fail) between
# machines.
with config_path.open("r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

base_path = Path(config["base_path"])
runs = []  # one DataFrame per run whose checkpoint file was found
missing_paths = []  # string paths of checkpoint files that do not exist

# Enumerate every (batch, experiment, trial) combination lazily, in the same
# order that three nested for-loops would visit them.
run_keys = (
    (batch, experiment, trial)
    for batch in config["batches"]
    for experiment in config["experiments"]
    for trial in config["trials"]
)

for batch, experiment, trial in run_keys:
    run_dir = base_path / batch / experiment / trial
    stats_path = run_dir / "logs" / "training_stats_checkpoint.pkl"

    if stats_path.exists():
        # NOTE(review): pickle.load can execute arbitrary code while
        # deserialising; only point base_path at checkpoints you trust.
        with stats_path.open("rb") as stats_file:
            training_stats = pickle.load(stats_file)

        # Keep just the two series needed for the reward curve and tag every
        # row with the run it came from.
        run_df = pd.DataFrame(
            {
                "total_steps": training_stats["total_steps"],
                "reward": training_stats["reward"],
            }
        ).assign(batch=batch, experiment=experiment, trial=trial)
        runs.append(run_df)
    else:
        # Record the gap and carry on so one absent run does not abort the load.
        missing_paths.append(str(stats_path))

# Nothing was loaded at all: fail loudly instead of concatenating an empty list.
if not runs:
    raise FileNotFoundError("No training stats files found for the configured runs")


def _standard_error(rewards):
    """Return the standard error of the mean of ``rewards``.

    The sample size is the number of non-NaN values (``count()``), matching
    what ``std`` and ``mean`` use, since both skip NaN. Dividing by ``len``
    instead would shrink the error bar whenever a reward is missing.

    Returns 0.0 when fewer than two valid values are available, because a
    sample standard deviation (``ddof=1``) is undefined there.
    """
    n_valid = rewards.count()
    if n_valid < 2:
        return 0.0
    return rewards.std(ddof=1) / n_valid ** 0.5


# Stack all runs into one long frame, then summarise the reward per step value.
all_rewards_df = pd.concat(runs, ignore_index=True)
reward_summary = (
    all_rewards_df.groupby("total_steps", as_index=False)
    .agg(
        mean_reward=("reward", "mean"),
        sem_reward=("reward", _standard_error),
        # "size" counts every row at this step value, NaN rewards included.
        # Presumably that is one row per run -- confirm runs never log the
        # same step twice.
        n_runs=("reward", "size"),
    )
    .sort_values("total_steps")
)

# Report how many runs were loaded and how many batches they span.
loaded_message = (
    f"Loaded {len(runs)} run(s) across {all_rewards_df['batch'].nunique()} batch(es)."
)
print(loaded_message)

# Only mention skipped files when at least one checkpoint was missing.
if len(missing_paths) > 0:
    skipped_message = f"Skipped {len(missing_paths)} missing run file(s)."
    print(skipped_message)

# Bare trailing expression: the notebook renders it as the cell's output.
reward_summary



In [None]:
import matplotlib.pyplot as plt

# Plot the mean reward against total steps with a shaded band of +/- one SEM.
x = reward_summary["total_steps"]
y = reward_summary["mean_reward"]
sem = reward_summary["sem_reward"]

# Explicit figure/axes objects instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x, y, label="Mean reward")
ax.fill_between(x, y - sem, y + sem, alpha=0.2, label="SEM")
ax.set_xlabel("Total environment steps")
ax.set_ylabel("Reward")
# The plus-minus sign is written as an escape so the title cannot be corrupted
# by a file-encoding mix-up (it previously rendered as two garbled characters).
ax.set_title("Reward Across Batches (mean \u00b1 SEM)")
ax.grid(True)
ax.legend()
plt.show()

