In [None]:
# Generate simulated data and benchmark batch-effect correction methods

# Setup
import pandas as pd
import numpy as np
from glycoforge import simulate
from methods import combat, add_noise_to_zero_variance_features, percentile_normalization, ratio_preserving_combat, harmony_correction, limma_style_correction, stratified_combat
from evaluation import quantify_batch_effect_impact, compare_differential_expression
from glycoforge.utils import clr, invclr
import json

# Define parameter grid
param_grid = {
  'bio_strength': [0.5, 1.0, 1.5, 2.0],
  'kappa_mu': [0.5, 1.0, 1.5, 2.0],
  'var_b': [0.3, 0.5, 0.8, 1.0]
}

# Run grid comparison.
# For every grid point: simulate once (covering all seeds), then for each seed load
# the clean / batch-affected tables, apply every correction method and record its
# batch-effect metrics plus differential-expression error counts.
results_list = []
base_config = dict(
  data_source='simulated',
  n_glycans=50,
  n_H=15,
  n_U=15,
  k_dir=100,
  n_batches=3,
  random_seeds=[42, 43, 44],
  verbose=False,
  save_csv=True,
)
grid_points = [
  (bio, kappa, batch_var)
  for bio in param_grid['bio_strength']
  for kappa in param_grid['kappa_mu']
  for batch_var in param_grid['var_b']
]
total_runs = len(grid_points)

# Output column name -> key in the dict returned by quantify_batch_effect_impact.
METRIC_COLUMNS = {
  'PVCA_batch_variance': 'pvca_batch_variance',
  'PVCA_bio_variance': 'pvca_bio_variance',
  'PVCA_residual_variance': 'pvca_residual_variance',
  'silhouette': 'silhouette',
  'kBET': 'kBET',
  'LISI': 'LISI',
  'ARI': 'ARI',
  'comp_effect': 'compositional_effect_size',
  'pca_batch': 'pca_batch_effect',
}


def _like(values, template):
  """Wrap `values` in a DataFrame carrying `template`'s index and columns."""
  return pd.DataFrame(values, index=template.index, columns=template.columns)


def _clr_frame(frame):
  """Apply clr to the transposed values of `frame`, transpose back and keep its labels."""
  return _like(clr(frame.values.T).T, frame)


for run_idx, (bio_strength, kappa_mu, var_b) in enumerate(grid_points, start=1):
  print(f"[{run_idx}/{total_runs}] bio={bio_strength}, kappa={kappa_mu}, var_b={var_b}")
  output_dir = f'results/grid_{bio_strength}_{kappa_mu}_{var_b}'
  config = {**base_config, 'bio_strength': bio_strength, 'kappa_mu': kappa_mu,
            'var_b': var_b, 'output_dir': output_dir}
  simulate(**config)
  for seed in config['random_seeds']:
    Y_clean = pd.read_csv(f"{output_dir}/1_Y_clean_seed{seed}.csv", index_col=0)
    Y_clean_clr = pd.read_csv(f"{output_dir}/1_Y_clean_clr_seed{seed}.csv", index_col=0)  # loaded but not used below
    Y_with_batch = pd.read_csv(f"{output_dir}/2_Y_with_batch_seed{seed}.csv", index_col=0)
    Y_with_batch_clr = pd.read_csv(f"{output_dir}/2_Y_with_batch_clr_seed{seed}.csv", index_col=0)
    with open(f"{output_dir}/metadata_seed{seed}.json") as f:
      metadata = json.load(f)
    sample_info = metadata['sample_info']
    batch_labels = np.array(sample_info['batch_labels'])
    bio_labels = np.array(sample_info['bio_labels'])
    bio_groups = sample_info['bio_groups']

    # ComBat corrects the CLR table after a tiny (1e-10) jitter is added to
    # zero-variance features; its output is mapped back per column with invclr.
    Y_with_batch_clr_fixed = add_noise_to_zero_variance_features(Y_with_batch_clr, noise_level=1e-10, random_seed=seed)
    Y_combat_clr = _like(combat(Y_with_batch_clr_fixed.values, batch_labels, mod=bio_labels), Y_with_batch_clr)
    Y_combat = pd.DataFrame(
      {sample: invclr(Y_combat_clr[sample].values) for sample in Y_combat_clr.columns},
      index=Y_combat_clr.index, columns=Y_combat_clr.columns,
    )

    # The remaining methods correct the non-CLR table; their CLR view is derived
    # right after each call (call order kept as-is).
    Y_percentile = _like(percentile_normalization(Y_with_batch.values, batch_labels), Y_with_batch)
    Y_percentile_clr = _clr_frame(Y_percentile)
    Y_ratio = _like(ratio_preserving_combat(Y_with_batch.values, batch_labels, mod=bio_labels), Y_with_batch)
    Y_ratio_clr = _clr_frame(Y_ratio)
    Y_harmony = _like(harmony_correction(Y_with_batch.values, batch_labels), Y_with_batch)
    Y_harmony_clr = _clr_frame(Y_harmony)
    Y_limma = _like(limma_style_correction(Y_with_batch.values, batch_labels, mod=bio_labels), Y_with_batch)
    Y_limma_clr = _clr_frame(Y_limma)
    Y_stratified = _like(stratified_combat(Y_with_batch.values, batch_labels, bio_labels), Y_with_batch)
    Y_stratified_clr = _clr_frame(Y_stratified)

    corrected_versions = (
      ('combat', Y_combat, Y_combat_clr),
      ('percentile', Y_percentile, Y_percentile_clr),
      ('ratio', Y_ratio, Y_ratio_clr),
      ('harmony', Y_harmony, Y_harmony_clr),
      ('limma', Y_limma, Y_limma_clr),
      ('stratified', Y_stratified, Y_stratified_clr),
    )
    for method_name, Y_corrected, Y_corrected_clr in corrected_versions:
      # Batch metrics use the CLR view; DE compares clean (1), batch (2) and corrected (3).
      metrics = quantify_batch_effect_impact(Y_corrected_clr, batch_labels, bio_groups, verbose=False)
      de = compare_differential_expression(dataset1=Y_clean, dataset2=Y_with_batch, dataset3=Y_corrected, verbose=False)
      de_results = de['results']
      errors = de_results['compare_1v3']['after_correction_errors']
      row = {'bio_strength': bio_strength, 'kappa_mu': kappa_mu, 'var_b': var_b,
             'seed': seed, 'method': method_name}
      row.update((column, metrics[key]) for column, key in METRIC_COLUMNS.items())
      row.update(
        tp=errors['tp_count'],
        fp=errors['fp_count'],
        fn=errors['fn_count'],
        sig_clean=de_results['dataset1']['significant_count'],
        sig_corrected=de_results['dataset3']['significant_count'],
      )
      results_list.append(row)

# Create results DataFrame
results_df = pd.DataFrame(results_list)
# DE rates relative to the calls made on the clean data (dataset1): TP/FN are divided
# by the clean significant count, FP by the remaining features. The feature count
# comes from the simulation config instead of a hard-coded 50 (assumes DE is tested
# on all n_glycans features — TODO confirm). Zero denominators become NaN rather than
# inf so the later per-condition means are not poisoned.
n_features = base_config['n_glycans']
n_sig_clean = results_df['sig_clean'].replace(0, np.nan)
n_nonsig_clean = (n_features - results_df['sig_clean']).replace(0, np.nan)
results_df['tp_rate'] = results_df['tp'] / n_sig_clean
results_df['fp_rate'] = results_df['fp'] / n_nonsig_clean
results_df['fn_rate'] = results_df['fn'] / n_sig_clean
f1_denominator = (2 * results_df['tp'] + results_df['fp'] + results_df['fn']).replace(0, np.nan)
results_df['f1'] = 2 * results_df['tp'] / f1_denominator

# Aggregate across seeds: the mean of every score per (grid point, method).
seed_averaged_columns = [
  'PVCA_batch_variance', 'PVCA_bio_variance', 'PVCA_residual_variance',
  'silhouette', 'kBET', 'LISI', 'ARI', 'comp_effect', 'pca_batch',
  'tp_rate', 'fp_rate', 'fn_rate', 'f1',
]
condition_keys = ['bio_strength', 'kappa_mu', 'var_b', 'method']
agg_results = (
  results_df
  .groupby(condition_keys)
  .agg(dict.fromkeys(seed_averaged_columns, 'mean'))
  .reset_index()
)

# Find winner for each metric and condition
def find_winners(df, metric, lower_is_better=True):
  """Return the best-scoring method for every (bio_strength, kappa_mu, var_b) condition.

  Parameters
  ----------
  df : pandas.DataFrame
      Must contain 'bio_strength', 'kappa_mu', 'var_b', 'method' and `metric` columns.
  metric : str
      Column used to rank the methods within a condition.
  lower_is_better : bool
      If True the method with the minimum value wins, otherwise the maximum.

  Returns
  -------
  pandas.DataFrame
      One row per condition with columns 'bio_strength', 'kappa_mu', 'var_b',
      'winner'. NaN scores are skipped; 'winner' is None when `metric` is NaN
      for every method in a condition.
  """
  winners = []
  for (bio, kappa, var_b), group in df.groupby(['bio_strength', 'kappa_mu', 'var_b']):
    scores = group[metric]
    if scores.isna().all():
      # idxmin/idxmax have no valid label for an all-NaN column; record "no winner".
      best = None
    elif lower_is_better:
      best = group.loc[scores.idxmin(), 'method']
    else:
      best = group.loc[scores.idxmax(), 'method']
    winners.append({'bio_strength': bio, 'kappa_mu': kappa, 'var_b': var_b, 'winner': best})
  # Explicit columns keep 'winner' present even when df is empty.
  return pd.DataFrame(winners, columns=['bio_strength', 'kappa_mu', 'var_b', 'winner'])

# Overall winner counts: for each metric, how many grid conditions each method wins.
# Each entry is (metric column, lower_is_better).
print("=== WINNER COUNTS BY METRIC ===")
winner_criteria = [
  ('PVCA_batch_variance', True),
  ('silhouette', True),
  ('kBET', True),
  ('ARI', True),
  ('comp_effect', True),
  ('pca_batch', True),
  ('LISI', False),
  ('f1', False),
  ('tp_rate', False),
  ('fp_rate', True),
  ('fn_rate', True),
]
for metric, lower_better in winner_criteria:
  counts = find_winners(agg_results, metric, lower_better)['winner'].value_counts()
  direction = 'lower' if lower_better else 'higher'
  print(f"\n{metric} ({direction} is better):")
  print(counts)

# Visualize by condition severity
# Batch severity = kappa_mu + var_b. These edges/names are the single definition of
# the severity levels used by both the summary table and the recommendations below.
# The open outer edges mean every grid point lands in exactly one level:
# weak <= 1.0 < moderate <= 1.5 < strong <= 2.0 < extreme.
SEVERITY_EDGES = [-np.inf, 1.0, 1.5, 2.0, np.inf]
SEVERITY_LEVELS = ['weak', 'moderate', 'strong', 'extreme']
agg_results['batch_severity'] = agg_results['kappa_mu'] + agg_results['var_b']
agg_results['bio_effect_strength'] = agg_results['bio_strength']
severity_level = pd.cut(agg_results['batch_severity'], bins=SEVERITY_EDGES, labels=SEVERITY_LEVELS)
# observed=False keeps every severity level in the table, even empty ones.
severity_summary = agg_results.groupby(['method', severity_level], observed=False).agg({
  'PVCA_batch_variance': 'mean',
  'PVCA_bio_variance': 'mean',
  'f1': 'mean',
  'kBET': 'mean',
  'silhouette': 'mean'
}).round(4)
print("\n=== PERFORMANCE BY BATCH SEVERITY ===")
print(severity_summary)


def _best_method(frame, metric, lower_is_better):
  """Return the method with the best `metric` in `frame`, or 'n/a' if it is all-NaN."""
  scores = frame[metric]
  if scores.isna().all():
    return 'n/a'
  best_idx = scores.idxmin() if lower_is_better else scores.idxmax()
  return frame.loc[best_idx, 'method']


# Best method recommendation by scenario
print("\n=== RECOMMENDATION BY SCENARIO ===")
for bio in param_grid['bio_strength']:
  for severity in SEVERITY_LEVELS:
    conditions = agg_results[(agg_results['bio_strength'] == bio) & (severity_level == severity)]
    if len(conditions) > 0:
      best_f1 = _best_method(conditions, 'f1', lower_is_better=False)
      best_batch = _best_method(conditions, 'kBET', lower_is_better=True)
      print(f"Bio={bio}, Batch={severity:8s}: Best F1={best_f1:10s}, Best batch removal={best_batch:10s}")

# Save detailed results: per-seed rows and the seed-averaged table, without the index.
csv_outputs = (
  (results_df, 'results/method_comparison_detailed.csv'),
  (agg_results, 'results/method_comparison_aggregated.csv'),
)
for frame, csv_path in csv_outputs:
  frame.to_csv(csv_path, index=False)
print("\nDetailed results saved to results/method_comparison_detailed.csv")
print("Aggregated results saved to results/method_comparison_aggregated.csv")

In [None]:
# Visualize the batch-effect removal benchmark results (Figure 3)

import matplotlib.pyplot as plt
import seaborn as sns

# Load aggregated results
agg_results = pd.read_csv('results/method_comparison_aggregated.csv')

# 2x2 grid
fig = plt.figure(figsize=(18, 14))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# Batch effect reduction across methods (aggregate across all conditions)
ax1 = fig.add_subplot(gs[0, 0])
batch_metrics = ['PVCA_batch_variance', 'silhouette', 'kBET', 'ARI', 'comp_effect', 'pca_batch', 'LISI']
method_order = ['combat', 'percentile', 'ratio', 'harmony', 'limma', 'stratified']
method_labels = ['ComBat', 'Percentile', 'Ratio-ComBat', 'Harmony', 'Limma', 'Stratified']
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
method_means = agg_results.groupby('method')[batch_metrics].mean()
method_means = method_means.reindex(method_order)
# NOTE(review): PVCA values appear to be percentages (the PVCA figure labels them %),
# while the other metrics are presumably on much smaller scales; sharing one y-axis
# may make their bars hard to see — confirm.
x = np.arange(len(batch_metrics))
width = 0.13
# Centre each group of bars on its tick for any number of methods
# (for 6 methods this equals the previous (i - 2.5) * width).
group_center = (len(method_order) - 1) / 2
for i, method in enumerate(method_order):
  offset = (i - group_center) * width
  values = method_means.loc[method].values
  ax1.bar(x + offset, values, width, label=method_labels[i], color=colors[i], alpha=0.8, edgecolor='black')
ax1.set_xlabel('Batch Effect Metric', fontsize=12, fontweight='bold')
ax1.set_ylabel('Mean Value', fontsize=12, fontweight='bold')
ax1.set_title('Panel A: Batch Effect Metrics by Method\n(averaged across all conditions)', fontsize=13, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(['PVCA-Batch↓', 'Silhouette↓', 'kBET↓', 'ARI↓', 'Comp.Eff↓', 'PCA↓', 'LISI↑'], fontsize=9)
ax1.legend(fontsize=9, ncol=2)
ax1.grid(alpha=0.3, axis='y')

# F1 score comparison across batch severity levels.
# Severity is kappa_mu + var_b binned into four levels; bars show the mean F1 per method and level.
ax2 = fig.add_subplot(gs[0, 1])
summed_batch_params = agg_results['kappa_mu'] + agg_results['var_b']
agg_results['batch_severity'] = pd.cut(
  summed_batch_params,
  bins=[0, 1.0, 1.5, 2.0, 3.0],
  labels=['Weak', 'Moderate', 'Strong', 'Extreme'],
)
severity_f1 = (
  agg_results
  .groupby(['method', 'batch_severity'])['f1']
  .mean()
  .unstack()
  .reindex(method_order)
)
severity_palette = ['#90EE90', '#FFD700', '#FF8C00', '#DC143C']
severity_f1.plot(kind='bar', ax=ax2, color=severity_palette, edgecolor='black', width=0.8)
ax2.set_xlabel('Method', fontsize=12, fontweight='bold')
ax2.set_ylabel('Mean F1 Score', fontsize=12, fontweight='bold')
ax2.set_title('Panel B: F1 Score by Batch Severity\n(DE recovery quality)', fontsize=13, fontweight='bold')
ax2.set_xticklabels(method_labels, rotation=45, ha='right', fontsize=10)
ax2.legend(title='Batch Severity', fontsize=9)
ax2.grid(alpha=0.3, axis='y')
ax2.set_ylim(0, 1)

# TP vs FP rates scatter: one point per method, averaged over all conditions.
ax3 = fig.add_subplot(gs[1, 0])
method_summary = agg_results.groupby('method')[['tp_rate', 'fp_rate']].mean()
method_summary = method_summary.reindex(method_order)
for i, method in enumerate(method_order):
  ax3.scatter(method_summary.loc[method, 'fp_rate'], method_summary.loc[method, 'tp_rate'],
              s=200, c=colors[i], label=method_labels[i], alpha=0.8, edgecolors='black', linewidths=2)
ax3.set_xlabel('Mean False Positive Rate', fontsize=12, fontweight='bold')
ax3.set_ylabel('Mean True Positive Rate', fontsize=12, fontweight='bold')
ax3.set_title('Panel C: TP vs FP Trade-off\n(averaged across all conditions)', fontsize=13, fontweight='bold')
ax3.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='Random')
ax3.legend(fontsize=9)
ax3.grid(alpha=0.3)
# Series.max() skips NaN (the builtin max() over a Series containing NaN is
# order-dependent). If every method has fp_rate == 0 (or all are NaN), keep a small
# positive upper limit so the points are not pinned to the right edge of the axis.
fp_max = method_summary['fp_rate'].max()
fp_upper = fp_max * 1.1 if pd.notna(fp_max) and fp_max > 0 else 0.05
ax3.set_xlim(-0.05, fp_upper)
ax3.set_ylim(0, 1.05)

# Heatmap of best method by condition using composite score
ax4 = fig.add_subplot(gs[1, 1])
agg_results['batch_severity_score'] = agg_results['kappa_mu'] + agg_results['var_b']
# Composite score: minimize batch, maximize bio preservation, maximize DE recovery
# Normalize each component to [0,1] scale within each condition


def _minmax(values, higher_is_better):
  """Min-max scale `values` within one condition so that 1 marks the best method."""
  span = values.max() - values.min() + 1e-6  # epsilon avoids 0/0 when all methods tie
  if higher_is_better:
    return (values - values.min()) / span
  return (values.max() - values) / span


pivot_data = []
bio_values = sorted(agg_results['bio_strength'].unique())
severity_bins = [(0, 1.0), (1.0, 1.5), (1.5, 2.0), (2.0, 3.0)]
severity_labels = ['Weak', 'Moderate', 'Strong', 'Extreme']
for bio in bio_values:
  row = []
  for (low, high), label in zip(severity_bins, severity_labels):
    subset = agg_results[(agg_results['bio_strength'] == bio) & (agg_results['batch_severity_score'] > low) & (agg_results['batch_severity_score'] <= high)]
    if len(subset) > 0:
      composite_score = (0.5 * _minmax(subset['PVCA_batch_variance'], higher_is_better=False)
                         + 0.3 * _minmax(subset['PVCA_bio_variance'], higher_is_better=True)
                         + 0.2 * _minmax(subset['f1'], higher_is_better=True))
      best_method = subset.loc[composite_score.idxmax(), 'method']
      row.append(method_order.index(best_method))
    else:
      row.append(-1)  # no grid condition falls in this cell
  pivot_data.append(row)
pivot_df = pd.DataFrame(pivot_data, index=[f'Bio={b}' for b in bio_values], columns=severity_labels)
# Colour cells with the same per-method colours as Panels A and C: one colormap entry
# per method, with the value range centred on the integer indices. Empty cells (-1)
# fall below vmin and are drawn with the 'under' colour instead of a method's colour.
from matplotlib.colors import ListedColormap
method_cmap = ListedColormap(colors[:len(method_order)])
method_cmap.set_under('lightgrey')
im = ax4.imshow(pivot_df.values, cmap=method_cmap, aspect='auto', vmin=-0.5, vmax=len(method_order) - 0.5)
ax4.set_xticks(range(len(severity_labels)))
ax4.set_yticks(range(len(bio_values)))
ax4.set_xticklabels(severity_labels, fontsize=10)
ax4.set_yticklabels([f'Bio={b}' for b in bio_values], fontsize=10)
ax4.set_xlabel('Batch Severity', fontsize=12, fontweight='bold')
ax4.set_ylabel('Biological Signal Strength', fontsize=12, fontweight='bold')
ax4.set_title('Panel D: Best Method by Condition\n(composite: 50% batch removal, 30% bio preservation, 20% DE recovery)', fontsize=13, fontweight='bold')
for i in range(len(bio_values)):
  for j in range(len(severity_labels)):
    method_idx = pivot_df.iloc[i, j]
    if method_idx >= 0:
      ax4.text(j, i, method_labels[method_idx], ha='center', va='center', color='white', fontsize=9, fontweight='bold')
cbar = plt.colorbar(im, ax=ax4, ticks=range(len(method_order)))
cbar.set_ticklabels(method_labels)

# savefig does not create missing directories.
import os
os.makedirs('results/figures', exist_ok=True)
plt.savefig('results/figures/fig3_method_comparison.pdf', dpi=300, bbox_inches='tight')
plt.show()

# PVCA variance decomposition
import os

fig_pvca, (ax_pvca1, ax_pvca2) = plt.subplots(1, 2, figsize=(16, 6))
# Left: Stacked bar showing PVCA decomposition by method
pvca_decomp = agg_results.groupby('method')[['PVCA_batch_variance', 'PVCA_bio_variance', 'PVCA_residual_variance']].mean()
pvca_decomp = pvca_decomp.reindex(method_order)
pvca_decomp.plot(kind='bar', stacked=True, ax=ax_pvca1, color=['#DC143C', '#4169E1', '#D3D3D3'], edgecolor='black', width=0.7)
ax_pvca1.set_xlabel('Method', fontsize=12, fontweight='bold')
ax_pvca1.set_ylabel('Variance Explained (%)', fontsize=12, fontweight='bold')
ax_pvca1.set_title('PVCA Variance Decomposition by Method\n(averaged across all conditions)', fontsize=13, fontweight='bold')
ax_pvca1.set_xticklabels(method_labels, rotation=45, ha='right', fontsize=10)
ax_pvca1.legend(['Batch', 'Biological', 'Residual'], fontsize=10, loc='upper right')
ax_pvca1.set_ylim(0, 100)
ax_pvca1.grid(alpha=0.3, axis='y')
# Right: Bio vs Batch variance scatter
method_pvca = agg_results.groupby('method')[['PVCA_batch_variance', 'PVCA_bio_variance']].mean()
method_pvca = method_pvca.reindex(method_order)
for i, method in enumerate(method_order):
  ax_pvca2.scatter(method_pvca.loc[method, 'PVCA_batch_variance'], method_pvca.loc[method, 'PVCA_bio_variance'],
                   s=200, c=colors[i], label=method_labels[i], alpha=0.8, edgecolors='black', linewidths=2)
ax_pvca2.plot([0, max(method_pvca['PVCA_batch_variance'])*1.1], [0, max(method_pvca['PVCA_batch_variance'])*1.1], 'k--', alpha=0.3, label='Equal variance')
ax_pvca2.set_xlabel('Batch Variance (%)', fontsize=12, fontweight='bold')
ax_pvca2.set_ylabel('Biological Variance (%)', fontsize=12, fontweight='bold')
ax_pvca2.set_title('PVCA: Biological vs Batch Variance\n(upper-left = ideal: low batch, high bio)', fontsize=13, fontweight='bold')
ax_pvca2.grid(alpha=0.3)
ax_pvca2.axvline(x=10, color='orange', linestyle=':', alpha=0.5, label='Batch<10% threshold')
ax_pvca2.axhline(y=20, color='green', linestyle=':', alpha=0.5, label='Bio>20% threshold')
# Build the legend only after every labelled artist exists, so the threshold lines appear in it.
ax_pvca2.legend(fontsize=9)
plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('results/figures', exist_ok=True)
plt.savefig('results/figures/fig3_pvca_decomposition.pdf', dpi=300, bbox_inches='tight')
plt.show()

# Print summary statistics
print("\n=== FIGURE 3 SUMMARY ===")
print("\nPVCA Variance Decomposition (mean % across all conditions)")
pvca_stats = agg_results.groupby('method')[['PVCA_batch_variance', 'PVCA_bio_variance', 'PVCA_residual_variance']].mean()
pvca_stats = pvca_stats.reindex(method_order)
pvca_stats.index = method_labels
print(pvca_stats.round(2))
print("\nPVCA Signal-to-Batch Ratio (Bio/Batch, higher = better)")
pvca_stats['bio_to_batch_ratio'] = pvca_stats['PVCA_bio_variance'] / (pvca_stats['PVCA_batch_variance'] + 1e-6)
print(pvca_stats['bio_to_batch_ratio'].round(3))
print("\nMean batch effect metrics")
print(method_means.round(3))
print("\nF1 by severity")
print(severity_f1.round(3))
print("\nTP vs FP rates")
print(method_summary.round(3))
# pivot_df (Panel D) stores indices into method_order, with -1 for empty cells, and
# ranks by the composite score, not PVCA alone. Map indices to names so the counts are readable.
print("\nWinner distribution (composite score, Panel D)")
winner_names = {-1: 'none', **dict(enumerate(method_labels))}
print(pd.Series(pivot_df.values.flatten()).map(winner_names).value_counts())

# Sensitivity analysis: Winner consistency across different criteria
print("\n=== SENSITIVITY ANALYSIS: WINNER CONSISTENCY ===")
# The weight sign encodes direction: negative = lower is better, positive = higher is
# better; |weight| is the metric's share of the score. Each metric is min-max scaled
# within a condition (as in Panel D) before weighting. This lets metrics on different
# scales (PVCA in %, F1 in [0, 1]) contribute in proportion to their weights, and makes
# 'Composite (main)' reproduce the Panel D composite exactly.
scoring_schemes = {
  'PVCA-only': {'PVCA_batch_variance': -1.0},
  'PVCA+Bio': {'PVCA_batch_variance': -0.6, 'PVCA_bio_variance': 0.4},
  'Composite (main)': {'PVCA_batch_variance': -0.5, 'PVCA_bio_variance': 0.3, 'f1': 0.2},
  'F1-only': {'f1': 1.0}
}


def _directional_scale(values, weight):
  """Return |weight| times `values` min-max scaled within the condition, oriented so higher = better."""
  span = values.max() - values.min() + 1e-6  # epsilon avoids 0/0 when all methods tie
  if weight < 0:
    scaled = (values.max() - values) / span
  else:
    scaled = (values - values.min()) / span
  return abs(weight) * scaled


sensitivity_results = {}
for scheme_name, weights in scoring_schemes.items():
  pivot_data = []
  for bio in bio_values:
    row = []
    for (low, high), label in zip(severity_bins, severity_labels):
      subset = agg_results[(agg_results['bio_strength'] == bio) & (agg_results['batch_severity_score'] > low) & (agg_results['batch_severity_score'] <= high)]
      if len(subset) == 0:
        row.append('none')
        continue
      score = sum(_directional_scale(subset[metric], weight) for metric, weight in weights.items())
      # All-NaN scores leave idxmax without a valid label; treat the cell as empty.
      row.append(subset.loc[score.idxmax(), 'method'] if score.notna().any() else 'none')
    pivot_data.append(row)
  sensitivity_results[scheme_name] = pd.DataFrame(pivot_data, index=[f'Bio={b}' for b in bio_values], columns=severity_labels)
# Check consistency
print("\nWinner counts by scoring scheme:")
for scheme_name, df in sensitivity_results.items():
  counts = pd.Series(df.values.flatten()).value_counts()
  print(f"\n{scheme_name}:")
  print(counts)
# Calculate agreement between schemes
from sklearn.metrics import cohen_kappa_score
print("\nAgreement between scoring schemes (Cohen's kappa):")
schemes = list(scoring_schemes.keys())
for i, scheme1 in enumerate(schemes):
  for scheme2 in schemes[i+1:]:
    flat1 = sensitivity_results[scheme1].values.flatten()
    flat2 = sensitivity_results[scheme2].values.flatten()
    # Compare only cells where both schemes picked a method.
    mask = (flat1 != 'none') & (flat2 != 'none')
    if mask.sum() > 0:
      kappa = cohen_kappa_score(flat1[mask], flat2[mask])
      print(f"  {scheme1} vs {scheme2}: κ = {kappa:.3f}")