In [1]:
import matplotlib.pyplot as plt
import sys, os, json
import seaborn as sns
import numpy as np

from sklearn.metrics import confusion_matrix, matthews_corrcoef
from pathlib import Path

cur_dir = Path(os.path.abspath("")).parent
project_root = cur_dir.parent.parent
sys.path.append(str(project_root))

from definitions.odinw import ODinWDefinitions
from src.eval_utils import get_precision_per_class, get_recall_per_class, get_f1_per_class, get_macro_precision, get_macro_recall, get_macro_f1, calculate_tp_fp_fn_counts, get_micro_precision_from_counts, get_micro_recall_from_counts, get_micro_f1

# Cross-Model Comparison

In [2]:
# Bar colours, handed out to models in the order they appear in `models`.
COLORS = [
    "#ea5545",  # red
    "#ede15b",  # yellow
    "#87bc45",  # green
    "#27aeef",  # blue
    "#b33dc6",  # purple
]

In [3]:
def sample_from_preds_and_gt(data, sample_size, rng=None):
    """Draw one bootstrap resample of paired predictions and ground-truth labels.

    Indices are drawn uniformly *with replacement*, and the same indices are applied
    to both arrays, so every sampled prediction stays paired with its own label.

    Args:
        data (dict): Dictionary with array-like 'preds' and 'gt_actions' of equal length.
        sample_size (int): Number of pairs to draw.
        rng (numpy.random.Generator, optional): Source of randomness. Pass a seeded
            generator (e.g. ``np.random.default_rng(0)``) for reproducible bootstraps.
            When None (default), NumPy's global legacy RNG is used, so ``np.random.seed``
            still controls the result.

    Returns:
        dict: {'preds': ndarray, 'gt_actions': ndarray}, each of length ``sample_size``.

    Raises:
        ValueError: If 'preds' and 'gt_actions' differ in length, or if they are
            empty while ``sample_size`` > 0.
    """
    # asarray avoids re-copying inputs that are already ndarrays; this runs once per
    # bootstrap iteration, so the copy would otherwise be repeated many times.
    preds = np.asarray(data['preds'])
    gt_actions = np.asarray(data['gt_actions'])

    n_samples = len(preds)
    if len(gt_actions) != n_samples:
        raise ValueError(
            f"'preds' ({n_samples}) and 'gt_actions' ({len(gt_actions)}) must have the same length"
        )
    if n_samples == 0 and sample_size > 0:
        raise ValueError("Cannot bootstrap-sample from empty 'preds'/'gt_actions'")

    # Fall back to the global RNG so existing np.random.seed(...) usage keeps working.
    chooser = np.random if rng is None else rng
    indices = chooser.choice(n_samples, size=sample_size, replace=True)

    return {
        'preds': preds[indices],
        'gt_actions': gt_actions[indices],
    }

In [4]:
def get_preds_and_gt_from_dict(model, data, dataset):
    """Extract flattened predictions and ground-truth labels for one model/dataset.

    GPT-family results (any model name containing 'gpt') store them under
    'preds' / 'gt_actions'; pi0 results store them under 'all_preds' / 'all_gt'.

    Args:
        model (str): Model name; selects which key layout to read.
        data (dict): Results, indexed as ``data[model][dataset]``.
        dataset (str): Sub-dataset name.

    Returns:
        dict: {'preds': 1-D ndarray, 'gt_actions': 1-D ndarray}.

    Raises:
        ValueError: If ``model`` matches neither known result layout.
        KeyError: If the model, dataset, or expected keys are missing from ``data``.
    """
    if 'gpt' in model:
        preds_key, gt_key = 'preds', 'gt_actions'
    elif model == 'pi0':
        preds_key, gt_key = 'all_preds', 'all_gt'
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unknown model '{model}': expected a name containing 'gpt' or equal to 'pi0'")

    entry = data[model][dataset]
    return {
        "preds": np.array(entry[preds_key]).flatten(),
        "gt_actions": np.array(entry[gt_key]).flatten(),
    }

In [5]:
def get_macro_micro_metrics_std_per_dataset_from_bootstrap_sample(model, data, dataset, sample_size=20_000, n_iterations=200, metric_type='macro'):
    """Estimate the bootstrap standard deviation of precision, recall and F1 for one model/dataset.

    Each iteration resamples ``sample_size`` (prediction, ground-truth) pairs with
    replacement, recomputes the three metrics on the resample, and the spread of the
    per-iteration values is returned.

    Args:
        model (str): Model name, used as the first key of ``data``.
        data (dict): Results, indexed as ``data[model][dataset]``.
        dataset (str): ODinW sub-dataset name; also used to look up the class count
            in ``ODinWDefinitions.SUB_DATASET_CATEGORIES``.
        sample_size (int): Number of pairs drawn per bootstrap iteration.
        n_iterations (int): Number of bootstrap iterations; must be >= 1.
        metric_type (str): 'macro' -- per-class metrics aggregated with the
            ``get_macro_*`` helpers; 'micro' -- metrics computed from the TP/FP/FN
            totals returned by ``calculate_tp_fp_fn_counts``.

    Returns:
        tuple: (precision_std, recall_std, f1_std), each the population standard
        deviation (``np.std``, ddof=0) over the bootstrap iterations.

    Raises:
        ValueError: If ``metric_type`` is not 'macro'/'micro', or ``n_iterations`` < 1.
    """
    # Validate before the expensive loop so bad arguments fail fast instead of
    # surfacing mid-loop (bad metric_type) or yielding NaN stds (zero iterations).
    if metric_type not in ('macro', 'micro'):
        raise ValueError(f"Invalid metric type: {metric_type}")
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    # Extract the full prediction/label arrays once, outside the loop.
    all_res = get_preds_and_gt_from_dict(model, data, dataset)
    valid_classes = list(range(ODinWDefinitions.SUB_DATASET_CATEGORIES[dataset]))

    precision_samples, recall_samples, f1_samples = [], [], []

    for _ in range(n_iterations):
        sampled = sample_from_preds_and_gt(all_res, sample_size)
        preds, gt = sampled['preds'], sampled['gt_actions']

        if metric_type == 'macro':
            per_class_precision = get_precision_per_class(preds, gt, valid_classes)
            per_class_recall = get_recall_per_class(preds, gt, valid_classes)
            per_class_f1 = get_f1_per_class(per_class_precision, per_class_recall)
            precision = get_macro_precision(per_class_precision)
            recall = get_macro_recall(per_class_recall)
            f1 = get_macro_f1(per_class_f1)
        else:  # 'micro' (validated above)
            total_tp, total_fp, total_fn, _valid_fp, _invalid_fp = calculate_tp_fp_fn_counts(preds, gt, valid_classes)
            precision = get_micro_precision_from_counts(total_tp, total_fp)
            recall = get_micro_recall_from_counts(total_tp, total_fn)
            f1 = get_micro_f1(precision, recall)

        precision_samples.append(precision)
        recall_samples.append(recall)
        f1_samples.append(f1)

    return np.std(precision_samples), np.std(recall_samples), np.std(f1_samples)

In [None]:
def _resolve_metric_key(metric, metric_type):
    """Map (metric, metric_type) to (results-dict key, index into the bootstrap std tuple)."""
    # The bootstrap helper returns stds in (precision, recall, f1) order.
    # 'accuracy' reuses the micro-precision std (index 0).
    key_table = {
        'macro': {
            'precision': ('macro_precision', 0),
            'recall': ('macro_recall', 1),
            'f1': ('macro_f1', 2),
        },
        'micro': {
            'precision': ('precision', 0),
            'recall': ('recall', 1),
            'f1': ('f1', 2),
            'accuracy': ('exact_match_rate', 0),
        },
    }
    if metric_type not in key_table:
        raise ValueError(f"Invalid metric type: {metric_type}")
    if metric not in key_table[metric_type]:
        raise ValueError(f"Invalid metric: {metric}")
    return key_table[metric_type][metric]


def _collect_scores(results, models, subdatasets, metric_key, metric_type, metric_std_index):
    """Return per-model score and bootstrap-std lists, index-aligned with ``subdatasets``."""
    model_scores, model_scores_std = {}, {}
    for model in models:
        scores, scores_std = [], []
        for dataset in subdatasets:
            try:
                std = get_macro_micro_metrics_std_per_dataset_from_bootstrap_sample(
                    model, results, dataset, metric_type=metric_type)[metric_std_index]
            except Exception as e:
                # Best-effort: one bad dataset should not abort the whole figure.
                print(f"Error in {model}: {e}")
                std = 0.0
            scores_std.append(std)
            try:
                scores.append(results[model][dataset][metric_key])
            except (KeyError, TypeError):
                # Missing score (model lacks this dataset/key) is drawn as an empty bar.
                scores.append(0)
        model_scores[model] = scores
        model_scores_std[model] = scores_std
    return model_scores, model_scores_std


def plot_cross_model_macro_micro_metric(results, models, metric='recall', metric_type='macro'):
    """Save a grouped bar chart comparing one metric across models for every ODinW sub-dataset.

    Sub-datasets are laid out four per subplot (as many stacked subplots as needed),
    with one bar per model in each group and bootstrap standard deviations as error bars.

    Args:
        results (dict): ``results[model][dataset]`` -> per-dataset result dict holding the
            metric value under the key selected by ``metric``/``metric_type``.
        models (list[str]): Models to plot, in legend/colour order. The dataset order is
            taken from ``results[models[0]]`` and used for every model.
        metric (str): 'precision', 'recall', 'f1', or (micro only) 'accuracy'.
        metric_type (str): 'macro' or 'micro'.

    Side effects:
        Writes a 300-dpi PNG under ``plots/`` (relative to the current working directory)
        and closes the figure, so nothing is displayed inline. Missing scores are drawn
        as 0; bootstrap failures are printed and drawn with std 0.

    Raises:
        ValueError: On an unknown ``metric``/``metric_type``, or when there are no
            models or no datasets to plot.
    """
    metric_key, metric_std_index = _resolve_metric_key(metric, metric_type)

    if not models:
        raise ValueError("No models to plot")
    # One shared dataset order for all models, so bars in the same x slot always
    # refer to the same dataset as the tick label underneath them.
    subdatasets = list(results[models[0]].keys())
    if not subdatasets:
        raise ValueError(f"No datasets found for model '{models[0]}'")

    model_scores, model_scores_std = _collect_scores(
        results, models, subdatasets, metric_key, metric_type, metric_std_index)

    # Split datasets into groups of 4 per subplot.
    num_datasets = len(subdatasets)
    datasets_per_subplot = 4
    num_subplots = (num_datasets + datasets_per_subplot - 1) // datasets_per_subplot

    fig = plt.figure(figsize=(15, 6 * num_subplots))
    # Extra short row at the top holds the title and the shared legend.
    gs = fig.add_gridspec(num_subplots + 1, 1, height_ratios=[0.3] + [1] * num_subplots, hspace=0.2)
    title_ax = fig.add_subplot(gs[0])
    title_ax.axis('off')
    axs = [fig.add_subplot(gs[i + 1]) for i in range(num_subplots)]

    bar_width = 0.15
    opacity = 0.8

    for subplot_idx, ax in enumerate(axs):
        start_idx = subplot_idx * datasets_per_subplot
        end_idx = min(start_idx + datasets_per_subplot, num_datasets)
        current_datasets = [ds[:15] + '...' if len(ds) > 15 else ds
                            for ds in subdatasets[start_idx:end_idx]]
        x = np.arange(len(current_datasets))

        # Highest bar top (score + std) among non-zero bars in this subplot.
        max_y = max(
            (score + std
             for model in models
             for score, std in zip(model_scores[model][start_idx:end_idx],
                                   model_scores_std[model][start_idx:end_idx])
             if score > 0),
            default=0,
        )
        # 30% headroom for the rotated value labels; fall back to 1.0 when nothing is
        # positive so set_ylim never receives identical bounds (0, 0).
        y_limit = max_y * 1.3 if max_y > 0 else 1.0

        for i, model in enumerate(models):
            current_scores = model_scores[model][start_idx:end_idx]
            current_stds = model_scores_std[model][start_idx:end_idx]
            bar_x = x + i * bar_width

            ax.bar(bar_x, current_scores, bar_width,
                   alpha=opacity, color=COLORS[i % len(COLORS)], label=model,
                   yerr=current_stds, capsize=5, error_kw={'elinewidth': 2})

            for xpos, value, std in zip(bar_x, current_scores, current_stds):
                if value > 0:  # only annotate non-zero bars
                    label_height = value + std
                    va = 'bottom'
                    if label_height > y_limit * 0.85:
                        # Too close to the top edge: anchor the label below the error bar instead.
                        label_height = value - std
                        va = 'top'
                    ax.text(xpos, label_height, f'{value:.2f}', ha='center', va=va,
                            rotation=45, fontsize=16)

        if metric != 'accuracy':
            ax.set_ylabel(f'{metric_type.capitalize()} {metric.capitalize()}', fontsize=20)
        else:
            ax.set_ylabel('Accuracy', fontsize=20)
        # Centre each tick under its group of bars.
        ax.set_xticks(x + bar_width * (len(models) - 1) / 2)
        ax.set_xticklabels(current_datasets, ha='right', fontsize=20)
        ax.grid(True, axis='y', alpha=0.3)
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_ylim(0, y_limit)

    if metric != 'accuracy':
        title = f'Model {metric_type.capitalize()} {metric.capitalize()} Comparison for ODinW'
    else:
        title = 'Model Accuracy Comparison for ODinW'
    title_ax.text(0.5, 0.7, title, fontsize=24, ha='center', va='bottom')

    # A single legend shared by all subplots, placed in the title row.
    handles, labels = axs[0].get_legend_handles_labels()
    title_ax.legend(handles, labels,
                    loc='center',
                    bbox_to_anchor=(0.5, 0.2),
                    ncol=len(models),
                    fontsize=20)

    if metric != 'accuracy':
        filename = f'model_comparison_{metric_type}_{metric_key}_with_invalids_odinw.png'
    else:
        filename = 'model_comparison_accuracy_odinw.png'
    output_path = os.path.abspath(os.path.join("plots", filename))
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    

In [31]:
# Load GPT-5 (low reasoning) and pi0 results for every ODinW sub-dataset into
# all_results[model][dataset].
models = ['gpt5', 'pi0']
GPT5_RESULTS_DIR = Path(project_root) / "src/v1/results/genesis/gpt_5/low_reasoning/odinw"
PI0_RESULTS_DIR = Path(project_root) / "src/v1/results/pi0/odinw"

all_results = {model: {} for model in models}
for dataset in ODinWDefinitions.SUB_DATASET_CATEGORIES:
    # GPT-5 result files nest the payload under the dataset name; pi0 files are the payload itself.
    with open(GPT5_RESULTS_DIR / f"{dataset}_results.json", 'r') as f:
        all_results['gpt5'][dataset] = json.load(f)[dataset]
    with open(PI0_RESULTS_DIR / f"pi0_hf_odinw_{dataset}_inference_results.json", 'r') as f:
        all_results['pi0'][dataset] = json.load(f)


In [26]:
plot_cross_model_macro_micro_metric(results=all_results, models=models, metric_type='macro', metric='recall')

In [32]:
plot_cross_model_macro_micro_metric(results=all_results, models=models, metric_type='micro', metric='accuracy')