In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict
from pathlib import Path
import numpy as np
import pickle
import yaml

# One shared size for every text element of the figure (base font, axis
# labels, titles, legend entries and tick labels).
fontsize = 22
plt.rcParams.update(
    dict.fromkeys(
        (
            'font.size',
            'axes.labelsize',
            'axes.titlesize',
            'legend.fontsize',
            'xtick.labelsize',
            'ytick.labelsize',
        ),
        fontsize,
    )
)

def moving_average(x, w):
    """Return the simple moving average of ``x`` over windows of ``w`` samples.

    Only windows that lie completely inside ``x`` are averaged, so the result
    has ``len(x) - w + 1`` entries.

    Args:
        x: One-dimensional sequence of numeric samples.
        w: Window length in samples; must be at least 1.

    Returns:
        A 1-D array of window means. It is empty when ``x`` holds fewer than
        ``w`` samples (including when ``x`` itself is empty).

    Raises:
        ValueError: If ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError(f"window size must be at least 1, got {w}")

    samples = np.asarray(x)
    # np.convolve swaps its operands when the kernel is longer than the
    # signal, so 'valid' mode would return w - len(x) + 1 meaningless values
    # instead of none. Report "no complete window" explicitly.
    if samples.size < w:
        return np.empty(0, dtype=float)

    return np.convolve(samples, np.ones(w), 'valid') / w

# The YAML configuration is looked up in the current working directory.
plotting_dir = Path(".").resolve()
config_dir = plotting_dir.joinpath("ppo_config.yaml")

with config_dir.open("r") as file:
    config = yaml.safe_load(file)

# Two vertically stacked subplots; sharex=True keeps their x-axes on one scale
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(14, 14))

# batch name -> experiment name -> list of per-trial series
batch_data = {name: defaultdict(list) for name in config["batches"]}

# Length of the longest smoothed series loaded across all batches
max_len = 0

# Smoothing window, read once instead of on every trial
window_size = config["moving_avg_window_size"]

# First pass: collect data for each experiment grouped by batch
for i, batch in enumerate(config["batches"]):
    # Cap on how many episodes of each run are used for this batch
    datapoints = config["datapoints"][i]
    for experiment in config["experiments"]:
        for trial in config["trials"]:
            checkpoint_path = Path(f"{config['base_path']}/{batch}/{experiment}/{trial}/logs/train.dat")

            # Report missing logs instead of dropping them silently, so a
            # mistyped path does not just produce an empty plot.
            if not checkpoint_path.is_file():
                print(f"Warning: {batch}/{experiment}/{trial} log file not found: {checkpoint_path}")
                continue

            # NOTE(review): pickle.load can execute code embedded in the file;
            # only point base_path at training output you trust.
            with open(checkpoint_path, "rb") as handle:
                data = pickle.load(handle)

            data = data["rewards_per_episode"][:datapoints]

            # A run is kept only when it is strictly longer than the window
            if len(data) <= window_size:
                print(f"Warning: {batch}/{experiment}/{trial} has too few data points for smoothing")
                continue

            # Apply moving average
            smoothed_data = moving_average(data, window_size)
            batch_data[batch][experiment].append(smoothed_data)
            max_len = max(max_len, len(smoothed_data))

# Second pass: pad the trials of each experiment to a common length so they
# can be stacked into one 2D array.
# Padding stops at the longest trial of that experiment. Padding further only
# appends columns that are NaN for every trial, and np.nanmean / np.nanstd
# emit a RuntimeWarning for each such all-NaN column.
for batch in batch_data:
    for exp, trials in batch_data[batch].items():
        if not trials:
            continue

        exp_len = max(len(trial_data) for trial_data in trials)

        # Replacing the value of an existing key is safe while iterating
        batch_data[batch][exp] = [
            trial_data
            if len(trial_data) == exp_len
            # Pad with NaN
            else np.concatenate([trial_data, np.full(exp_len - len(trial_data), np.nan)])
            for trial_data in trials
        ]

# Legend display names.
# NOTE(review): assumed to be listed in the same order as
# config["experiments"]; each name is bound to the experiment at that position.
display_names = ["mlp", "gcn_mixed", "gat_mixed", "graph_transformer_mixed", "gcn_full", "gat_full", "graph_transformer_full"]

# Experiment names in config order (duplicates dropped). An experiment's
# position fixes its colour and legend name, so both stay the same in every
# subplot even when an experiment has no data for one of the batches.
experiment_order = list(dict.fromkeys(config["experiments"]))
legend_names = dict(zip(experiment_order, display_names))

# First line drawn for each experiment, reused as its handle in the shared legend
legend_handles = {}

titles = ["8 Salp-Unit Policy Learning Curve", "16 Salp-Unit Policy Learning Curve"]

# Plot each batch in its respective subplot
for batch_idx, batch in enumerate(config["batches"]):
    ax = axes[batch_idx]

    # Title each subplot from the hard-coded list above
    ax.set_title(titles[batch_idx])

    # First pass: stack the trials of every experiment that has data, so the
    # batch-wide min/max is known before anything is drawn
    trial_arrays = {}
    for exp in experiment_order:
        trials = batch_data[batch].get(exp)
        if trials:
            trial_arrays[exp] = np.array(trials)

    # Calculate batch-specific normalization parameters
    if trial_arrays:
        batch_combined = np.concatenate([arr.flatten() for arr in trial_arrays.values()])
        batch_min = np.nanmin(batch_combined)
        batch_max = np.nanmax(batch_combined)
        batch_range = batch_max - batch_min
    else:
        batch_min = 0
        batch_range = 1

    # Second pass: plot with batch-specific normalization
    for exp_idx, exp in enumerate(experiment_order):
        if exp not in trial_arrays:
            print(f"No data for {batch}/{exp}")
            continue

        # 2D array: one row per trial, one column per smoothed step
        raw_array = trial_arrays[exp]
        valid = ~np.isnan(raw_array)

        # Normalize using batch-specific min/max (left as-is when the batch
        # has no spread, which would otherwise divide by zero)
        if batch_range > 0:
            data_array = (raw_array - batch_min) / batch_range
        else:
            data_array = raw_array

        # Calculate mean and standard error across trials
        mean_rewards = np.nanmean(data_array, axis=0)

        # Standard Error = StdDev / sqrt(n)
        n_trials = np.sum(valid, axis=0)  # Count non-NaN values at each step
        std_rewards = np.nanstd(data_array, axis=0)
        se_rewards = std_rewards / np.sqrt(np.maximum(n_trials, 1))  # Avoid division by zero

        # X axis
        x = np.arange(len(mean_rewards))

        # Colour follows the experiment's position in the config, not the
        # number of lines drawn so far, so it matches across subplots
        color = plt.cm.tab10(exp_idx % 10)

        # Plot mean line
        line = ax.plot(x, mean_rewards, linewidth=2, label=exp, color=color)

        # Keep one handle per experiment for the combined legend; taking the
        # first occurrence avoids duplicates without losing experiments that
        # only have data in a later subplot
        legend_handles.setdefault(exp, line[0])

        # Plot standard error band
        ax.fill_between(
            x,
            mean_rewards - se_rewards,
            mean_rewards + se_rewards,
            alpha=0.3,
            color=color
        )

        non_nan_count = np.sum(valid)

        print(f"{batch}/{exp}: {len(raw_array)} trials, {non_nan_count} datapoints")

    # Add labels and grid to each subplot
    # Only add x-label to bottom subplot
    if batch_idx == len(axes) - 1:
        ax.set_xlabel("Episodes")

    ax.set_ylabel("Normalized SCLD Reward")
    ax.grid(True, alpha=0.3)

# Handles and labels for the combined legend, aligned entry by entry and
# listed in config order
all_lines = [legend_handles[exp] for exp in experiment_order if exp in legend_handles]
all_labels = [legend_names.get(exp, exp) for exp in experiment_order if exp in legend_handles]

# One legend for the whole figure, centred below the subplots and laid out in
# three columns
leg = fig.legend(
    all_lines,
    all_labels,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.06),
    ncol=3,
    markerscale=2,
)

# Draw the legend's line samples thicker than the plotted curves
for legend_line in leg.get_lines():
    legend_line.set_linewidth(6)

# Tighten the layout first, then reserve space at the bottom for the legend
fig.tight_layout()
fig.subplots_adjust(bottom=0.12)

plt.savefig("learning_curves_normalized.png", dpi=800, bbox_inches="tight")
plt.show()