In [None]:
%run '/home/christianl/Zhang-Lab/Zhang Lab Code/Boilerplate/Fig_config_utilities.py'

In [None]:
def bland_altman_plot(data1, data2, ax, model1_name, model2_name, confidence_interval=1.96,
                      ddof=0):
    """
    Draw a single Bland-Altman (mean-difference) plot on the given axis.

    Each point is one paired observation: x is the mean of the two
    predictions and y is their difference ``data1 - data2``. Horizontal
    lines mark the mean difference and the limits
    ``mean_diff +/- confidence_interval * sd``.

    Parameters
    ----------
    data1, data2 : array-like
        Predictions from two models. Must have identical shapes; inputs of
        any dimensionality are flattened and compared element-wise.
    ax : matplotlib axis
        Axis to plot on.
    model1_name, model2_name : str
        Names of models being compared (used in the labels and title).
    confidence_interval : float
        Z-score multiplier for the limits (default 1.96 for 95% CI).
    ddof : int
        Delta degrees of freedom passed to ``np.std``. The default 0
        (population SD) keeps the historical numbers; use 1 for the
        sample (n - 1) SD.

    Returns
    -------
    md : float
        Mean of the differences.
    sd : float
        Standard deviation of the differences.
    ci_low, ci_high : float
        Lower and upper limits, ``md -/+ confidence_interval * sd``.

    Raises
    ------
    ValueError
        If the two inputs differ in shape or are empty.
    """
    # Cast to float so unsigned/integer inputs cannot wrap around on subtraction.
    data1 = np.asarray(data1, dtype=float)
    data2 = np.asarray(data2, dtype=float)
    if data1.shape != data2.shape:
        # Refuse to broadcast: pairing predictions of different shapes is a caller bug.
        raise ValueError(
            f"data1 and data2 must have the same shape, got {data1.shape} and {data2.shape}"
        )
    if data1.size == 0:
        raise ValueError("data1 and data2 must not be empty")
    data1 = data1.ravel()
    data2 = data2.ravel()

    mean = (data1 + data2) / 2.0
    diff = data1 - data2

    md = np.mean(diff)
    sd = np.std(diff, ddof=ddof)

    ci_low = md - confidence_interval * sd
    ci_high = md + confidence_interval * sd

    # Scatter plot
    ax.scatter(mean, diff, alpha=0.5, s=30, color='steelblue', edgecolors='none')

    # Mean difference line
    ax.axhline(md, color='black', linestyle='-', lw=2.5, label=f'Mean: {md:.4f}')

    # Confidence interval lines
    ax.axhline(ci_high, color='gray', linestyle='--', lw=2, alpha=0.7)
    ax.axhline(ci_low, color='gray', linestyle='--', lw=2, alpha=0.7)

    # Label each limit just outside the right spine. The y-axis transform uses
    # axes-fraction x and data y, so placement does not depend on the sign or
    # magnitude of the plotted means.
    label_transform = ax.get_yaxis_transform()
    for sign, level in (('+', ci_high), ('-', ci_low)):
        ax.text(1.01, level, f'{sign}{confidence_interval}σ\n{level:.4f}',
                transform=label_transform, fontsize=9,
                horizontalalignment='left', verticalalignment='center')

    ax.set_xlabel(f'Mean of {model1_name} and {model2_name}', fontsize=11, fontweight='bold')
    ax.set_ylabel(f'{model1_name} - {model2_name}', fontsize=11, fontweight='bold')
    ax.set_title(f'{model1_name} vs {model2_name}', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right', fontsize=9)
    ax.grid(True, alpha=0.3)

    return md, sd, ci_low, ci_high


def figure_3_bland_altman(predictions_dict, output_path='figure_3.png'):
    """
    Generate Bland-Altman plots for every pairwise model comparison.

    One panel is drawn per unordered pair of models, following the
    insertion order of ``predictions_dict`` (for models A, B, C the panels
    are A-B, A-C, B-C). The figure is saved to ``output_path`` and shown.

    Parameters
    ----------
    predictions_dict : dict
        Mapping of model name -> predictions (array-like). At least two
        models are required; the prediction arrays are expected to share
        a shape.
    output_path : str
        Path to save figure.

    Raises
    ------
    ValueError
        If fewer than two models are supplied.
    """
    from itertools import combinations

    set_publication_style()
    model_names = list(predictions_dict)
    if len(model_names) < 2:
        raise ValueError(
            "Need at least two models for a Bland-Altman comparison, "
            f"got {len(model_names)}"
        )

    # All unordered pairs in insertion order; for 3 models this yields
    # (0, 1), (0, 2), (1, 2).
    comparisons = list(combinations(model_names, 2))
    n_panels = len(comparisons)

    # FIGSIZE_TRIPLE is laid out for three panels; scale its width otherwise.
    if n_panels == 3:
        figsize = FIGSIZE_TRIPLE
    else:
        base_width, height = FIGSIZE_TRIPLE
        figsize = (base_width * n_panels / 3, height)

    # squeeze=False keeps `axes` 2-D even when there is a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=figsize, squeeze=False)

    for ax, (model1, model2) in zip(axes[0], comparisons):
        bland_altman_plot(
            predictions_dict[model1],
            predictions_dict[model2],
            ax, model1, model2,
        )

    fig.tight_layout()
    fig.savefig(output_path, dpi=DPI, bbox_inches='tight')
    print(f"Figure 3 saved to {output_path}")
    plt.show()
