In [None]:
# Observed vs. predicted scatterplots annotated with Pearson's r for each model (e.g. MLR and XGBRF)

def figure_1_observed_vs_predicted(y_true, predictions_dict, output_path='~/Zhang-Lab/Zhang Lab Data/Saved figures/figure_1.png'):
    """
    Generate observed vs. predicted scatterplots, one panel per model.

    Each panel shows the scatter of observed against predicted values, the
    identity line (perfect prediction), a least-squares linear fit, and a
    text box with the Pearson r, R² and p-value reported by
    ``compute_metrics``.

    Parameters
    ----------
    y_true : array-like
        True target gene expression values. Converted to a flat NumPy array
        for plotting, so lists, Series and column vectors are accepted.
    predictions_dict : dict
        Dictionary with keys as model names and values as predictions
        Example: {'RNN': pred_rnn, 'XGBRFRegressor': pred_xgb, 'Linear': pred_linear}
        One subplot is created per entry, in dictionary order.
    output_path : str
        Path to save figure. A leading ``~`` is expanded to the user's home
        directory, and the parent directory is created if it is missing.

    Returns
    -------
    dict
        Maps each model name to the object returned by
        ``compute_metrics(y_true, y_pred)`` for that model. This function
        reads the keys 'pearson_r', 'r2' and 'p_value' from it.

    Raises
    ------
    ValueError
        If ``predictions_dict`` is empty (there would be nothing to plot).
    """
    # Local import so this block does not depend on the file's import section.
    import os

    model_names = list(predictions_dict)
    if not model_names:
        raise ValueError("predictions_dict must contain at least one model")

    # Flat ndarray so .min()/.max()/polyfit work for any array-like input.
    y_obs = np.asarray(y_true).ravel()

    set_publication_style()

    # Size the grid to the number of models instead of assuming exactly three.
    # FIGSIZE_TRIPLE is used as-is for three panels and its width is scaled
    # proportionally otherwise (assumes it is a (width, height) pair).
    n_models = len(model_names)
    if n_models == 3:
        figsize = FIGSIZE_TRIPLE
    else:
        figsize = (FIGSIZE_TRIPLE[0] * n_models / 3, FIGSIZE_TRIPLE[1])
    # squeeze=False keeps `axes` 2-D even for a single model, so indexing
    # below is uniform.
    fig, axes = plt.subplots(1, n_models, figsize=figsize, squeeze=False)

    metrics_summary = {}

    for ax, model_name in zip(axes[0], model_names):
        raw_pred = predictions_dict[model_name]
        # Column-vector predictions, shape (n, 1), would otherwise make
        # np.polyfit return 2-D coefficients that np.poly1d rejects.
        y_pred = np.asarray(raw_pred).ravel()
        color = MODEL_COLORS.get(model_name, '#1f77b4')

        # Compute metrics once on the caller's original objects and reuse
        # them for both the annotation and the returned summary.
        metrics = compute_metrics(y_true, raw_pred)
        metrics_summary[model_name] = metrics

        # Scatter plot
        ax.scatter(y_obs, y_pred, alpha=0.5, s=30,
                   color=color, edgecolors='none')

        # Perfect prediction diagonal line
        min_val = min(y_obs.min(), y_pred.min())
        max_val = max(y_obs.max(), y_pred.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'k--',
                lw=2, alpha=0.5, label='Perfect prediction')

        # Least-squares linear fit of predicted on observed
        # (no confidence band is drawn)
        fit = np.poly1d(np.polyfit(y_obs, y_pred, 1))
        x_line = np.linspace(y_obs.min(), y_obs.max(), 100)
        ax.plot(x_line, fit(x_line), color=color,
                lw=2.5, alpha=0.8, label='Linear fit')

        # Labels and formatting
        ax.set_xlabel('Observed Expression', fontsize=12, fontweight='bold')
        ax.set_ylabel('Predicted Expression', fontsize=12, fontweight='bold')
        ax.set_title(model_name, fontsize=13, fontweight='bold')

        # Metrics text box. The superscript two is written as an escape so
        # the label cannot be corrupted by a file-encoding round trip.
        if metrics['p_value'] < 0.001:
            p_text = "p < 0.001"
        else:
            p_text = f"p = {metrics['p_value']:.3f}"
        textstr = (
            f"r = {metrics['pearson_r']:.4f}\n"
            f"R\u00b2 = {metrics['r2']:.4f}\n"
            f"{p_text}"
        )
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        ax.legend(loc='lower right', fontsize=9)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not expand '~', so resolve it here and make sure the
    # destination directory exists before writing.
    resolved_path = os.path.expanduser(output_path)
    out_dir = os.path.dirname(resolved_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(resolved_path, dpi=DPI, bbox_inches='tight')
    print(f"Figure 1 saved to {resolved_path}")
    plt.show()

    # Return metrics for reference
    return metrics_summary