In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import yaml
import pickle
from collections import defaultdict

# Give the base font, axis labels, titles, legend and tick labels the same
# 18-point size for every figure created after this point.
plt.rcParams.update(
    dict.fromkeys(
        (
            'font.size',
            'axes.labelsize',
            'axes.titlesize',
            'legend.fontsize',
            'xtick.labelsize',
            'ytick.labelsize',
        ),
        18,
    )
)

# Load the plotting configuration from the current working directory.
# The rest of this script reads the keys "batches", "experiments", "trials"
# and "base_path" from it.
plotting_dir = Path().resolve()  # absolute form of the current working directory
config_dir = plotting_dir / "ppo_config.yaml"  # NOTE: despite the name, this is the file path

# Decode explicitly as UTF-8 so that non-ASCII content in the YAML file is
# read the same way regardless of the platform's default locale encoding.
with open(config_dir, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# Number of mask files (disabled-unit configurations)
num_masks = 9

# Key of the reward series that is read from every evaluation record
reward_name = "dist_rewards"


def _new_experiment_store():
    """Return an empty experiment -> n_agents -> disabled_units -> list mapping.

    Every level is a defaultdict, so indexing with an unseen key creates the
    next level on demand; the innermost level holds plain lists.
    """
    def _per_agent_count():
        return defaultdict(list)

    def _per_experiment():
        return defaultdict(_per_agent_count)

    return defaultdict(_per_experiment)


# One independent store per configured batch:
# batch -> experiment -> n_agents -> disabled_units -> [rewards]
batch_data = {batch_name: _new_experiment_store() for batch_name in config["batches"]}

# Pool the rewards of every available trial into batch_data, so that each
# (batch, experiment, n_agents, disabled_units) cell holds the rewards of all
# trials concatenated.
evaluation_file_name = "disabled_mask_eval"

for batch in config["batches"]:
    for experiment in config["experiments"]:
        for trial in config["trials"]:
            checkpoint_path = Path(
                f"{config['base_path']}/{batch}/{experiment}/{trial}/logs/{evaluation_file_name}.dat"
            )

            # Trials that have no evaluation file are skipped without a message.
            if not checkpoint_path.is_file():
                continue

            # NOTE(review): pickle.load can execute arbitrary code while
            # deserialising; base_path should only point at trusted output.
            with open(checkpoint_path, "rb") as handle:
                trial_results = pickle.load(handle)

            # Expected layout: {n_agents: {disabled_units: {'dist_rewards': [...], ...}}}
            for n_agents, per_disabled in trial_results.items():
                for disabled_units, metrics in per_disabled.items():
                    pooled_rewards = batch_data[batch][experiment][n_agents][disabled_units]
                    pooled_rewards.extend(metrics[reward_name])

            print(f"Added trial {trial} data to {batch}/{experiment}")

# Map every configured batch to the size it was trained on.
#
# A batch is assigned 8 when its name contains "8" or it is the first entry of
# config["batches"]; otherwise 16 when its name contains "16" or it is the
# second entry. Adjust these rules to the actual batch naming convention.
#
# NOTE(review): the name checks are plain substring tests, so a name such as
# "batch_18" also matches the "8" rule - confirm this suits the real names.
batch_training_sizes = {}
for batch in config["batches"]:
    if "8" in batch or batch == config["batches"][0]:
        batch_training_sizes[batch] = 8
    elif "16" in batch or batch == config["batches"][1]:
        batch_training_sizes[batch] = 16
    else:
        # Every configured batch is looked up in this mapping further down, so
        # an unmapped batch would otherwise only surface as a bare KeyError.
        raise ValueError(
            f"No training size is defined for batch {batch!r}; "
            "add a mapping for it to batch_training_sizes."
        )

# Bucket every (batch, n_agents, experiment) combination that has data by its
# scale factor, i.e. evaluation size divided by training size.
scale_factor_data = defaultdict(list)  # {scale_factor: [(batch, n_agents, experiment), ...]}

for batch in config["batches"]:
    training_size = batch_training_sizes[batch]
    for experiment in config["experiments"]:
        agent_counts = batch_data[batch][experiment]
        for n_agents in agent_counts:
            bucket = scale_factor_data[n_agents / training_size]
            bucket.append((batch, n_agents, experiment))

# Scale factors in ascending order
sorted_scale_factors = list(scale_factor_data)
sorted_scale_factors.sort()

# Draw one figure per scale factor. Each figure has one subplot per batch that
# was evaluated at that scale factor, and each subplot shows, per experiment,
# the min-max normalised mean reward against the number of disabled units.
# Every figure is saved as a PNG in the current working directory and shown.


def _summarize_rewards(rewards_by_disabled):
    """Return per-disabled-count statistics for one experiment.

    Args:
        rewards_by_disabled: mapping from a disabled-unit count to the pooled
            list of rewards collected for it.

    Returns:
        A dict with 'disabled_units' (sorted counts), 'means' and 'errors'
        (one entry per count), or None when the mapping is empty. A count
        whose reward list is empty contributes a mean and an error of 0.
    """
    disabled_counts = sorted(rewards_by_disabled)
    if not disabled_counts:
        return None

    means = []
    errors = []
    for count in disabled_counts:
        rewards = rewards_by_disabled[count]
        if rewards:
            means.append(np.mean(rewards))
            # Standard error of the mean, based on np.std's default ddof=0.
            errors.append(np.std(rewards) / np.sqrt(len(rewards)))
        else:
            means.append(0)
            errors.append(0)

    return {'disabled_units': disabled_counts, 'means': means, 'errors': errors}


for scale_factor in sorted_scale_factors:
    batch_agent_pairs = scale_factor_data[scale_factor]

    # De-duplicate while keeping first-seen order. A set would make the
    # subplot order (and which subplot is "leftmost") vary between runs.
    unique_batches = list(dict.fromkeys(batch for batch, _, _ in batch_agent_pairs))

    if not unique_batches:
        continue

    # One column per batch, sharing the y axis
    fig, axes = plt.subplots(1, len(unique_batches), figsize=(12 * len(unique_batches), 8), sharey=True)
    if len(unique_batches) == 1:
        # plt.subplots returns a bare Axes for a single subplot
        axes = [axes]

    # Colour assigned to each experiment, shared by all subplots of the figure
    experiment_colors = {}

    # Handles and labels for the figure-level legend.
    # NOTE(review): the labels are hard-coded and paired with the handles by
    # position, so they presumably have to follow the order of
    # config["experiments"] - confirm against the configuration file.
    all_lines = []
    all_labels = ["gcn_mixed", "gat_mixed", "graph_transformer_mixed", "gcn_full", "gat_full", "graph_transformer_full"]

    for batch_idx, batch in enumerate(unique_batches):
        ax = axes[batch_idx]

        # Evaluation size of this batch at the current scale factor
        n_agents = next((n for b, n, _ in batch_agent_pairs if b == batch), None)
        if n_agents is None:
            continue

        training_size = batch_training_sizes[batch]

        ax.set_title(f"Trained on {training_size} \u2192 Evaluated on {n_agents} Salp-Units")

        # Only the leftmost subplot carries the shared y label
        if batch_idx == 0:
            ax.set_ylabel("Normalized Mean Distance Reward")

        # First pass: summarise every experiment and gather all means so the
        # whole subplot can be normalised with one common min and max.
        experiment_data = {}
        all_means = []
        for experiment in config["experiments"]:
            if n_agents not in batch_data[batch][experiment]:
                print(f"No data for {batch}/{experiment}/n_agents={n_agents}")
                continue

            summary = _summarize_rewards(batch_data[batch][experiment][n_agents])
            if summary is None:
                continue

            experiment_data[experiment] = summary
            all_means.extend(summary['means'])

        # Fall back to the identity transform when there is nothing to
        # normalise or when all means are equal (avoids dividing by zero).
        if all_means and max(all_means) != min(all_means):
            min_mean = min(all_means)
            range_mean = max(all_means) - min_mean
        else:
            min_mean = 0
            range_mean = 1

        # x positions actually plotted in this subplot; collected per subplot
        # so the tick setup below never reads an undefined or stale variable.
        subplot_ticks = set()

        # Second pass: normalise and plot
        for experiment in config["experiments"]:
            if experiment not in experiment_data:
                continue

            # Colours start at tab10 index 1 and wrap around after index 9.
            if experiment not in experiment_colors:
                experiment_colors[experiment] = plt.cm.tab10((len(experiment_colors) + 1) % 10)
            color = experiment_colors[experiment]

            summary = experiment_data[experiment]
            disabled_units_list = summary['disabled_units']
            means = summary['means']
            errors = summary['errors']

            if range_mean > 0:
                normalized_means = [(m - min_mean) / range_mean for m in means]
                normalized_errors = [e / range_mean for e in errors]
            else:
                normalized_means = means
                normalized_errors = errors

            sample_counts = [len(batch_data[batch][experiment][n_agents][d]) for d in disabled_units_list]
            print(f"{batch}/{experiment}/n_agents={n_agents}: {sample_counts} samples per disabled count")

            line = ax.errorbar(
                disabled_units_list,
                normalized_means,
                yerr=normalized_errors,
                fmt="o-",
                linewidth=2,
                elinewidth=1,
                markersize=6,
                capsize=5,
                color=color,
                ecolor=color,
                label=experiment,
            )
            subplot_ticks.update(disabled_units_list)

            # The shared legend takes its handles from the first subplot only
            if batch_idx == 0:
                all_lines.append(line[0])

        ax.set_xlabel("Number of Disabled Salp Units")
        if subplot_ticks:
            ax.set_xticks(sorted(subplot_ticks))

        ax.grid(True, alpha=0.3)

        # Reference line at 0.90 of the normalised range
        line_90 = ax.axhline(y=0.90, color='black', linestyle='--', linewidth=2)

        # The per-subplot legend lists only the reference line
        ax.legend([line_90], ['90% performance'], loc='lower right', fontsize=16)

    # Single figure-level legend for the experiments
    fig.legend(all_lines, all_labels,
               loc='lower center',
               bbox_to_anchor=(0.5, -0.08),
               ncol=len(all_labels),
               fontsize=18,
               markerscale=2)

    # Overall title; the percentage is truncated towards zero by int()
    scale_percentage = int(scale_factor * 100)
    fig.suptitle(f"Performance vs Disabled Units ({scale_percentage}% Scale)",
                 fontsize=22, y=0.98)

    # Leave room for the legend below and the title above the subplots
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.10, top=0.90)

    plt.savefig(f"performance_vs_disabled_units_scale_{scale_percentage}pct.png", dpi=300, bbox_inches="tight")
    plt.show()