In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import yaml
import pickle
from collections import defaultdict

plt.rcParams.update({'font.size': 16})

# Paths are resolved against the current working directory (in Jupyter this is
# typically the notebook's own directory), not against this file's location.
plotting_dir = Path().resolve()
config_dir = plotting_dir / "ppo_config.yaml"  # path to the YAML config *file*

# Decode explicitly as UTF-8 so loading does not depend on the OS locale
# (e.g. cp1252 on Windows would mis-read non-ASCII names in the config).
with open(config_dir, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# One full-size axes that every experiment's curve is drawn onto.
fig, ax = plt.subplots(figsize=(12, 7))

# Group data by experiment first:
#   experiment_data["<batch>-<experiment>"][n_agents] -> list of reward samples,
# with the samples of every trial of that experiment pooled into one list.
experiment_data = defaultdict(lambda: defaultdict(list))
reward_name = "dist_rewards"

# First, collect all data points across trials
base_path = Path(config["base_path"])
for batch in config["batches"]:
    for experiment in config["experiments"]:
        exp_key = f"{batch}-{experiment}"

        for trial in config["trials"]:
            # str(): YAML may parse batch/experiment/trial names as ints, and
            # Path / int raises TypeError.
            checkpoint_path = (
                base_path / str(batch) / str(experiment) / str(trial)
                / "logs" / "evaluation.dat"
            )

            if not checkpoint_path.is_file():
                # Report instead of silently skipping, so missing runs are visible.
                print(f"Missing {checkpoint_path}; skipping trial {trial} of {exp_key}")
                continue

            # NOTE: pickle.load can execute arbitrary code; only load evaluation
            # files produced by our own training runs.
            with open(checkpoint_path, "rb") as handle:
                data = pickle.load(handle)

            # Presumably data maps n_agents -> {metric_name: iterable of samples};
            # TODO confirm against the evaluation writer.
            for n, metrics in data.items():
                experiment_data[exp_key][n].extend(metrics[reward_name])

            print(f"Added trial {trial} data to {exp_key}")

def _mean_and_sem(samples):
    """Return ``(mean, standard error of the mean)`` of *samples*.

    The standard error uses the sample standard deviation (``ddof=1``). With
    fewer than two samples the spread is undefined, so the error is reported
    as ``0.0`` rather than NaN.
    """
    values = np.asarray(samples)
    mean = float(np.mean(values))
    if len(samples) < 2:
        return mean, 0.0
    return mean, float(np.std(values, ddof=1) / np.sqrt(len(samples)))


# Use a different color for each experiment
experiment_colors = {}

# Legend labels: the experiment name, plus the batch when more than one batch
# is plotted (otherwise same-named experiments from different batches would get
# identical, indistinguishable legend entries). Built from the config rather
# than by splitting exp_key on "-", which breaks names that contain "-".
multiple_batches = len(config["batches"]) > 1
experiment_labels = {
    f"{batch}-{experiment}": (
        f"{experiment} ({batch})" if multiple_batches else str(experiment)
    )
    for batch in config["batches"]
    for experiment in config["experiments"]
}

# Union of agent counts across all experiments, used for the x ticks.
all_n_agents = set()

# Now plot aggregated data with standard error
for color_idx, (exp_key, agent_data) in enumerate(experiment_data.items()):
    color = experiment_colors.setdefault(exp_key, plt.cm.tab10(color_idx % 10))

    # Agent counts that actually have samples, ascending (assumes numeric keys;
    # string keys would sort lexicographically — TODO confirm key type).
    n_agents = sorted(n for n, samples in agent_data.items() if samples)
    if not n_agents:
        print(f"{exp_key}: no samples, skipping")
        continue
    all_n_agents.update(n_agents)

    # Mean and standard error over all samples pooled across trials
    stats = [_mean_and_sem(agent_data[n]) for n in n_agents]
    means = [mean for mean, _ in stats]
    errors = [sem for _, sem in stats]

    # Print how many samples we have for each agent count
    print(f"{exp_key}: {[len(agent_data[n]) for n in n_agents]} samples per agent count")

    # Plot with error bars
    ax.errorbar(
        n_agents,
        means,
        yerr=errors,
        fmt="o-",
        linewidth=2,
        elinewidth=1,
        markersize=6,
        capsize=5,
        color=color,
        ecolor=color,
        label=experiment_labels.get(exp_key, exp_key),
    )

if all_n_agents:
    # Tick every agent count seen by any experiment, not just the last one plotted.
    ax.set_xticks(sorted(all_n_agents))
    ax.legend(loc='best')
else:
    print("No evaluation data found; nothing to plot.")

ax.set_xlabel("Number of Salp Units in Chain")
ax.set_ylabel(f"Mean {reward_name.replace('_', ' ').title()}")
# Samples from all trials are pooled before computing the SE, so the error bars
# are not a between-trial SE; the title says so.
ax.set_title("Performance vs. Number of Agents (mean ± SE, samples pooled across trials)")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()