In [None]:
import matplotlib.pyplot as plt
import sys, os, json
import seaborn as sns
import numpy as np

from pathlib import Path

# NOTE: despite its name, `cur_dir` is the *parent* of the current working
# directory; `project_root` is two levels above that (three above the cwd).
cur_dir = Path.cwd().parent
project_root = cur_dir.parent.parent

# Expose `project_root` on the import path. The membership guard keeps the
# cell idempotent: re-running it no longer stacks duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))


In [2]:
# Shared colour palette for the plots in this notebook (hex RGB strings).
COLORS = [
    "#ea5545",  # red
    "#ede15b",  # yellow
    "#87bc45",  # green
    "#27aeef",  # blue
    "#b33dc6",  # purple
]

In [46]:
def _padded_ylim(scores):
    """Return ``(y_min, y_max)`` limits that frame the positive values in ``scores``.

    The lower limit sits 20% of the value spread below the smallest positive
    score (clamped at 0) and the upper limit 30% of the spread above the
    largest one. Returns ``None`` when ``scores`` holds no positive value, in
    which case the caller should leave matplotlib's automatic limits alone.
    """
    positive = [s for s in scores if s > 0]
    if not positive:
        return None
    lowest, highest = min(positive), max(positive)
    spread = highest - lowest
    if spread == 0:
        # All positive scores are identical: padding by a zero spread would
        # give y_min == y_max (singular limits, bars of zero visible height),
        # so pad relative to the value itself instead.
        spread = highest
    return max(0, lowest - spread * 0.2), highest + spread * 0.3


def plot_cross_model_brier_mae(results, metric_type='brier_mae',
                               datasets_per_subplot=5, output_dir='plots'):
    """Plot a grouped bar chart comparing one MAE metric across models and save it.

    Datasets are laid out in vertically stacked subplots holding up to
    ``datasets_per_subplot`` datasets each, with one bar per model for every
    dataset. A shared title and legend are drawn in a dedicated axes above the
    subplots. The figure is written to
    ``<output_dir>/model_comparison_<metric_type>_openx.png`` and then closed.

    Args:
        results (dict): Nested mapping ``results[model][dataset][metric_key]``
            for the models ``'gpt5'`` and ``'pi0'``. Only datasets present for
            both models are plotted.
        metric_type (str): Type of metric to plot. One of:
            - 'brier_mae'
            - 'quantile_filtered_normalized_brier_mae'
            - 'normalized_brier_mae'
            - 'max_relative_mae'
        datasets_per_subplot (int): Maximum number of datasets per subplot.
        output_dir (str): Directory the PNG is written to (created if missing).

    Raises:
        KeyError: If ``metric_type`` is not one of the supported values, or a
            required model / metric key is missing from ``results``.
        ValueError: If ``datasets_per_subplot`` is smaller than 1 or the models
            have no dataset in common.
    """
    models = ['gpt5', 'pi0']
    model_colors = {
        'gpt5': COLORS[0],   # Red
        'pi0': COLORS[4],    # Purple
    }

    # Human-readable label for each metric type (y-axis label and title).
    titles = {
        'brier_mae': 'Average MAE',
        'quantile_filtered_normalized_brier_mae': 'Quantile Filtered Normalized MAE',
        'normalized_brier_mae': 'Normalized Average MAE',
        'max_relative_mae': 'Max Relative MAE'
    }
    # Per metric type, the key under which each model stores its value.
    metric_keys = {
        'brier_mae': {
            'gpt5': 'avg_dataset_amae',
            'pi0': 'avg_dataset_amae'
        },
        'quantile_filtered_normalized_brier_mae': {
            'gpt5': 'normalized_quantile_filtered_amae',
            'pi0': 'normalized_quantile_filtered_amae'
        },
        'normalized_brier_mae': {
            'gpt5': 'normalized_amae',
            'pi0': 'normalized_amae'
        },
        'max_relative_mae': {
            'gpt5': 'max_relative_mae',
            'pi0': 'max_relative_mae'
        }
    }

    if metric_type not in metric_keys:
        raise KeyError(
            f"Unknown metric_type {metric_type!r}; expected one of {sorted(metric_keys)}"
        )
    if datasets_per_subplot < 1:
        raise ValueError("datasets_per_subplot must be at least 1")

    current_metric_keys = metric_keys[metric_type]
    # Resolved once up front: it is needed for the title after the subplot loop.
    metric_label = titles[metric_type]

    # Only datasets available for every model can be compared side by side.
    subdatasets = sorted(set.intersection(*(set(results[model]) for model in models)))
    if not subdatasets:
        raise ValueError("No dataset is common to all models; nothing to plot")

    # Width of a single bar; the bars of one dataset are placed side by side.
    width = 0.2

    # Ceiling division: the last subplot may hold fewer datasets.
    num_subplots = (len(subdatasets) + datasets_per_subplot - 1) // datasets_per_subplot

    fig = plt.figure(figsize=(15, 6*num_subplots))
    # The extra, shorter first row is reserved for the title and legend.
    gs = plt.GridSpec(num_subplots + 1, 1, height_ratios=[0.3] + [1]*num_subplots, hspace=0.2)

    title_ax = fig.add_subplot(gs[0])
    title_ax.axis('off')  # Only used as a canvas for text and legend

    axs = [fig.add_subplot(gs[i+1]) for i in range(num_subplots)]

    for subplot_idx, ax in enumerate(axs):
        # Datasets shown in this subplot
        start_idx = subplot_idx * datasets_per_subplot
        current_datasets = subdatasets[start_idx:start_idx + datasets_per_subplot]

        x = np.arange(len(current_datasets))

        # Draw one bar series per model, collecting scores for the y-limits.
        all_scores = []
        for i, model in enumerate(models):
            model_scores = [
                results[model][dataset][current_metric_keys[model]]
                for dataset in current_datasets
            ]
            all_scores.extend(model_scores)
            ax.bar(x + i*width, model_scores, width,
                   label=model, color=model_colors[model], alpha=0.8)

        # Zoom the y-axis onto the range actually covered by the scores.
        limits = _padded_ylim(all_scores)
        if limits is not None:
            y_min, y_max = limits
            ax.set_ylim(y_min, y_max)

            # Add broken axis indicator if not starting from 0
            if y_min > 0:
                d = .015  # Size of diagonal lines
                kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
                ax.plot((-d, +d), (-d, +d), **kwargs)
                ax.plot((1 - d, 1 + d), (-d, +d), **kwargs)

        ax.set_ylabel(metric_label, fontsize=20)
        # Centre each tick under its group of bars.
        ax.set_xticks(x + width * (len(models)-1)/2)

        # Shorten labels: drop the 'openx_' marker and truncate long names.
        tick_labels = [ds.replace('openx_', '') for ds in current_datasets]
        tick_labels = [ds[:15]+'...' if len(ds) > 15 else ds for ds in tick_labels]
        ax.set_xticklabels(tick_labels, ha='right', fontsize=20)

        ax.tick_params(axis='x', rotation=30)

        ax.grid(True, axis='y', alpha=0.3)
        ax.tick_params(axis='both', which='major', labelsize=20)

    # Add overall title
    title_ax.text(0.5, 0.7, f'{metric_label} Comparison Across Subdatasets for OpenX',
                 fontsize=24, ha='center', va='bottom')

    # Every subplot has the same series, so the first one supplies the legend.
    handles, labels = axs[0].get_legend_handles_labels()
    title_ax.legend(handles, labels,
                   loc='center',
                   bbox_to_anchor=(0.5, 0.2),
                   ncol=len(models),
                   fontsize=20)

    # Save and close this specific figure rather than the implicit current one.
    filename = f'model_comparison_{metric_type}_openx.png'
    output_path = os.path.abspath(os.path.join(output_dir, filename))
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

In [47]:
# Collect per-model results as all_results[model][dataset] -> loaded JSON value.
models = ['gpt5', 'pi0']
all_results = {model: {} for model in models}

# pi0: a single JSON file whose top-level keys are used as the dataset names.
with Path(f"{project_root}/src/v1/results/pi0/pi0_base_openx_results.json").open('r') as f:
    all_results['pi0'] = json.load(f)

# gpt5: one JSON file per dataset, named after the pi0 dataset keys.
gpt5_results_dir = f"{project_root}/src/v1/results/genesis/gpt_5/low_reasoning/openx/"
all_results['gpt5'] = {}
for dataset in all_results['pi0']:
    results_file = f"{dataset}.json"
    with (Path(gpt5_results_dir) / results_file).open('r') as f:
        res = json.load(f)
    # Keep only the first top-level value of the file.
    # NOTE(review): presumably each file has a single top-level key wrapping
    # the metrics dict — confirm against the result writer.
    all_results['gpt5'][dataset] = next(iter(res.values()))


In [48]:
# Render and save one comparison figure per supported metric type.
for metric_type in (
    'brier_mae',
    'quantile_filtered_normalized_brier_mae',
    'normalized_brier_mae',
    'max_relative_mae',
):
    plot_cross_model_brier_mae(all_results, metric_type=metric_type)