# Batch Correction Pipeline - Two-Phase Execution

This notebook demonstrates the new two-phase batch correction pipeline:
- **Phase 1**: Data generation (calls `pipeline.py` to simulate data)
- **Phase 2**: Correction & evaluation


# Example 1: ComBat correction based on simplified-simulated data
Steps:

1. Define Dirichlet parameters directly (uniform alpha_H, heterogeneous alpha_U scaling)
2. Generate clean simulated data with biological ground truth
3. Apply batch effects 
4. Quantify batch effects before correction
5. Apply ComBat batch correction
6. Quantify batch effects after correction
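
Steps 1 to 3 can be illustrated with a minimal NumPy sketch. It is not the code in `pipeline.py`: the sizes, the scaling factors, and the log-space shift model (including the name `shift_sd`) are assumptions chosen for illustration, and it does not try to reproduce how the pipeline uses `kappa_mu` or `var_b`.

```python
import numpy as np

rng = np.random.default_rng(42)
n_glycans, n_per_group, n_batches = 20, 12, 3  # illustrative sizes (assumed)

# Step 1: Dirichlet parameters. alpha_H is uniform; alpha_U rescales a few
# glycans to encode a hypothetical biological difference between two groups.
alpha_H = np.full(n_glycans, 10.0)
scaling = np.ones(n_glycans)
scaling[:3] = 2.0    # higher in the second group (assumed factor)
scaling[3:6] = 0.5   # lower in the second group (assumed factor)
alpha_U = alpha_H * scaling

# Step 2: clean compositional data; every row sums to 1.
clean = np.vstack([
    rng.dirichlet(alpha_H, size=n_per_group),
    rng.dirichlet(alpha_U, size=n_per_group),
])

# Step 3: batch effects as per-batch shifts in log space, followed by
# re-normalisation so that each sample is a composition again.
shift_sd = 0.5  # hypothetical batch-effect strength
batch = np.arange(clean.shape[0]) % n_batches
shifts = rng.normal(0.0, shift_sd, size=(n_batches, n_glycans))
shifted = np.exp(np.log(clean) + shifts[batch])
with_batch = shifted / shifted.sum(axis=1, keepdims=True)

print(clean.shape, with_batch.shape)
```

Keeping `clean` alongside `with_batch` is what makes the later evaluation possible: the clean data serves as the biological ground truth against which the corrected data can be compared.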

Run the batch correction pipeline across different parameter combinations to evaluate:
1. Batch effect correction effectiveness (5 metrics)
2. Biological signal preservation (2 metrics) 
3. Differential expression recovery (4 metrics)

Parameter grid: 4 × 4 = 16 combinations of kappa_mu and var_b values
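
As a rough sketch of how such a grid expands into individual tasks, the snippet below enumerates every parameter combination and pairs it with each seed. The seeds are the ones printed when the configuration is loaded; for the grid values, only `kappa_mu = 0.5` and `var_b = 0.5, 1.0` are visible in the run output, so the remaining values are placeholders, and the variable names are hypothetical rather than taken from the pipeline.

```python
from itertools import product

# Placeholder grid values: only 0.5 (kappa_mu) and 0.5, 1.0 (var_b) are
# confirmed by the run output; the others are assumed for illustration.
kappa_mu_values = [0.5, 1.0, 2.0, 4.0]
var_b_values = [0.5, 1.0, 2.0, 4.0]
seeds = [42, 123, 2024, 7, 99, 56, 88, 314, 271, 1618]

# One dict per parameter combination, in row-major order.
combinations = [
    {"kappa_mu": k, "var_b": v}
    for k, v in product(kappa_mu_values, var_b_values)
]

# One task per (combination, seed) pair.
tasks = [(combo, seed) for combo in combinations for seed in seeds]

print(f"Parameter combinations: {len(combinations)}")
print(f"Total tasks: {len(tasks)}")
```

With 4 values per parameter and 10 seeds this gives 16 combinations and 160 tasks, which matches the counts the pipeline reports at the start of Phase 1.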

In [1]:
import os
import sys
import yaml

project_root = '../..'
if project_root not in sys.path:
    sys.path.append(project_root)

from use_cases.batch_correction.correction import run_correction

config_path = os.path.join(project_root, 'sample_config/simlified_mode_config.yaml')

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded from:", config_path)
print(f"Output directory: {config.get('output_dir')}")
print(f"Seeds to run: {config.get('random_seeds')}")

Configuration loaded from: ../../sample_config/simlified_mode_config.yaml
Output directory: results/simplified_mode_1
Seeds to run: [42, 123, 2024, 7, 99, 56, 88, 314, 271, 1618]


In [None]:

results = run_correction(config)

print("\nPipeline completed!")
print(f"Total results: {len(results)}")

[SYSTEM] Using temp cache: /tmp/glycoforge_cache_9kwyhvkv
Generated random batch effects:
  Batch 1: 16 glycans (10↑, 6↓)
  Batch 2: 10 glycans (5↑, 5↓)
  Batch 3: 7 glycans (5↑, 2↓)
PHASE 1: DATA GENERATION
Parameter combinations: 16
Seeds per combination: 10
Total tasks: 160

Processing combination 1/16: {'kappa_mu': 0.5, 'var_b': 0.5}
  [1/160] Generating data: seed=42
  [2/160] Generating data: seed=123
  [3/160] Generating data: seed=2024
  [4/160] Generating data: seed=7
  [5/160] Generating data: seed=99
  [6/160] Generating data: seed=56
  [7/160] Generating data: seed=88
  [8/160] Generating data: seed=314
  [9/160] Generating data: seed=271
  [10/160] Generating data: seed=1618

Processing combination 2/16: {'kappa_mu': 0.5, 'var_b': 1.0}
  [11/160] Generating data: seed=42
  [12/160] Generating data: seed=123
  [13/160] Generating data: seed=2024
  [14/160] Generating data: seed=7
  [15/160] Generating data: seed=99
  [16/160] Generating data: seed=56
  [17/160] Generating d

In [None]:
# Visualize the results from the parameter grid search
# The plotting function will automatically scan the output directory for results.
from visualization import plot_parameter_grid_metrics
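# Identify grid parameters: list-valued config entries other than the seed list
# (random_seeds is also a list, but it is not part of the parameter grid)
grid_params = {k: v for k, v in config.items()
               if isinstance(v, list) and k != 'random_seeds'}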

if grid_params:
    print(f"Plotting results for parameters: {list(grid_params.keys())}")
    
    # Ensure output directory ends with a separator for the save path prefix
    output_dir = config.get('output_dir')
    save_prefix = os.path.join(output_dir, '') if output_dir else None
    
    plot_parameter_grid_metrics(
        results_dir=output_dir,
        save_path=save_prefix
    )
else:
    print("No parameter grid found (single run configuration). Skipping grid summary plots.")

In [None]:
import json

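# Reference metrics from the paper run. The path is absolute, so it is used
# directly: os.path.join would discard project_root in front of it anyway.
paper_json_path = '/proj/naiss2024-5-630/users/x_siyhu/glycowork-batch-correction/test_batch_correction/results/unified_simulated/kappa_mu_0.5_var_b_0.5/comprehensive_metrics_seed2024.json'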
new_json_path = os.path.join(project_root, 
    'use_cases/batch_correction/results/simplified_mode/kappa_mu_0.5_var_b_0.5/comprehensive_metrics_seed2024.json')

with open(paper_json_path, 'r') as f:
    paper_data = json.load(f)

with open(new_json_path, 'r') as f:
    new_data = json.load(f)

print("Paper JSON keys:", list(paper_data.keys()))
print("\nNew JSON keys:", list(new_data.keys()))

In [None]:
def compare_json_structures(data1, data2, path="root"):
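    """Recursively compare two JSON structures.

    Returns (common_keys, same_values, different_values). Keys present in only
    one dict, and list items beyond the shorter list, are not reported.
    """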
    common_keys = []
    different_values = []
    same_values = []
    
    if isinstance(data1, dict) and isinstance(data2, dict):
        keys1 = set(data1.keys())
        keys2 = set(data2.keys())
        common = keys1 & keys2
        
        for key in common:
            new_path = f"{path}.{key}"
            common_keys.append(new_path)
            
            val1 = data1[key]
            val2 = data2[key]
            
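            # Recurse only when both sides are the same container type; a type
            # mismatch falls through and is recorded as a differing value.
            if isinstance(val1, (dict, list)) and type(val1) is type(val2):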
                sub_common, sub_same, sub_diff = compare_json_structures(val1, val2, new_path)
                common_keys.extend(sub_common)
                same_values.extend(sub_same)
                different_values.extend(sub_diff)
            else:
                if val1 == val2:
                    same_values.append((new_path, val1))
                else:
                    different_values.append((new_path, val1, val2))
    
    elif isinstance(data1, list) and isinstance(data2, list):
        min_len = min(len(data1), len(data2))
        for i in range(min_len):
            new_path = f"{path}[{i}]"
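            # Same rule as for dict values: recurse only on matching container types.
            if isinstance(data1[i], (dict, list)) and type(data1[i]) is type(data2[i]):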
                sub_common, sub_same, sub_diff = compare_json_structures(data1[i], data2[i], new_path)
                common_keys.extend(sub_common)
                same_values.extend(sub_same)
                different_values.extend(sub_diff)
            else:
                common_keys.append(new_path)
                if data1[i] == data2[i]:
                    same_values.append((new_path, data1[i]))
                else:
                    different_values.append((new_path, data1[i], data2[i]))
    
    return common_keys, same_values, different_values

common_keys, same_values, different_values = compare_json_structures(paper_data, new_data)

print("=" * 80)
print("COMPARISON SUMMARY")
print("=" * 80)
print(f"\nTotal common keys: {len(set(common_keys))}")
print(f"Same values: {len(same_values)}")
print(f"Different values: {len(different_values)}")

In [None]:
print("\n" + "=" * 80)
print("KEYS WITH SAME VALUES")
print("=" * 80)
for path, value in same_values[:20]:
    print(f"{path}: {value}")
if len(same_values) > 20:
    print(f"\n... and {len(same_values) - 20} more matching values")

In [None]:
print("\n" + "=" * 80)
print("KEYS WITH DIFFERENT VALUES")
print("=" * 80)
for path, val1, val2 in different_values:
    if isinstance(val1, (int, float, str, bool, type(None))):
        print(f"\n{path}:")
        print(f"  Paper:  {val1}")
        print(f"  New:    {val2}")
        if isinstance(val1, float) and isinstance(val2, float):
            diff = abs(val1 - val2)
            rel_diff = diff / abs(val1) if val1 != 0 else float('inf')
            print(f"  Diff:   {diff:.6f} (relative: {rel_diff:.2%})")

In [None]:
print("\n" + "=" * 80)
print("FOCUS: BATCH EFFECT METRICS AFTER CORRECTION")
print("=" * 80)

paper_after = paper_data.get('batch_effect_metrics', {}).get('after_correction')
new_after = new_data.get('batch_effect_metrics', {}).get('after_correction')

# Compare only when both files contain the after-correction metrics
if isinstance(paper_after, dict) and isinstance(new_after, dict):
    
    print("\nMetric comparison:")
    for key in paper_after.keys():
        if key in new_after:
            p_val = paper_after[key]
            n_val = new_after[key]
            if isinstance(p_val, (int, float)) and isinstance(n_val, (int, float)):
                match = "✓ MATCH" if abs(p_val - n_val) < 1e-6 else "✗ DIFFER"
                print(f"\n{key}: {match}")
                print(f"  Paper: {p_val:.6f}")
                print(f"  New:   {n_val:.6f}")
                if abs(p_val - n_val) >= 1e-6:
                    diff = abs(p_val - n_val)
                    rel_diff = diff / abs(p_val) if p_val != 0 else float('inf')
                    print(f"  Diff:  {diff:.6f} ({rel_diff:.2%})")

# Example 2: ComBat correction based on hybrid-simulated data

Steps:

1. Load real-world glycomics data (CSV)
2. Estimate biological effect sizes from the real data (robust CLR-space processing)
3. Generate clean simulated data preserving real biological signal
4. Apply batch effects
5. Quantify batch effects before correction
6. Apply ComBat batch correction
7. Quantify batch effects after correction
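
To make step 2 more concrete, here is a minimal sketch of estimating per-glycan effect sizes in CLR (centered log-ratio) space. The CLR transform itself is standard; the pseudocount, the use of group medians as the "robust" estimator, and the function names `clr` and `clr_effect_sizes` are assumptions made for this example and may differ from what the pipeline does.

```python
import numpy as np

def clr(abundances, pseudocount=1e-6):
    """Centered log-ratio transform of a samples x glycans matrix."""
    x = np.asarray(abundances, dtype=float) + pseudocount  # avoid log(0)
    x = x / x.sum(axis=1, keepdims=True)                   # close to proportions
    log_x = np.log(x)
    return log_x - log_x.mean(axis=1, keepdims=True)       # subtract per-sample log mean

def clr_effect_sizes(abundances, is_case):
    """Per-glycan difference of group medians in CLR space (assumed estimator)."""
    z = clr(abundances)
    is_case = np.asarray(is_case, dtype=bool)
    return np.median(z[is_case], axis=0) - np.median(z[~is_case], axis=0)

# Toy data: 4 samples x 3 glycans, two controls followed by two cases.
data = np.array([
    [10.0, 30.0, 60.0],
    [12.0, 28.0, 60.0],
    [30.0, 30.0, 40.0],
    [33.0, 27.0, 40.0],
])
is_case = [False, False, True, True]

effects = clr_effect_sizes(data, is_case)
print(effects)
```

A positive entry in `effects` means the glycan is relatively more abundant in the case group. Working in CLR space keeps the comparison consistent with the compositional nature of the data, since each transformed sample sums to zero.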

Run the batch correction pipeline across different parameter combinations to evaluate:
1. Batch effect correction effectiveness (5 metrics)
2. Biological signal preservation (2 metrics)
3. Differential expression recovery (4 metrics)

Parameter grid: Defined in `hybrid_mode_config.yaml`