# GP Method Comparison

Compare different GP training approaches, focusing on:
- Training time cost
- Prediction accuracy (test_plot style)

All plots are saved to `save_dir`.

In [1]:
import jax
import numpy as np
import matplotlib.pyplot as plt
import time
from datetime import datetime
import os

# Switch JAX to x64 mode and seed NumPy's global RNG for repeatable runs.
jax.config.update("jax_enable_x64", True)
np.random.seed(42)

# All figures and summaries of this run go into one date-stamped (MMDDYY) folder.
run_date_tag = datetime.now().strftime('%m%d%y')
save_dir = f"trained_gp_models/GP_comparison_{run_date_tag}"
# exist_ok=True: re-running the cell on the same day reuses the folder.
os.makedirs(save_dir, exist_ok=True)
print(f"Results will be saved to: {save_dir}")



Results will be saved to: trained_gp_models/GP_comparison_090525


In [2]:
# Import modules
import importlib
import GP_dataloader
importlib.reload(GP_dataloader) 
from GP_dataloader import *
import train_GP
importlib.reload(train_GP) 
from train_GP import *

# Optional trainers: each import is attempted independently, and a module-level
# flag records the outcome so later cells can skip a method that is unavailable.
try:
    from src.models.jax_compat_trainer import train_simple_conditional_gp
except ImportError:
    COMPAT_AVAILABLE = False
    print("❌ JAX-compatible trainer not available")
else:
    COMPAT_AVAILABLE = True
    print("✅ JAX-compatible trainer available")

try:
    from src.models.improved_gp_trainer import train_improved_conditional_gp
except ImportError:
    IMPROVED_AVAILABLE = False
    print("❌ Improved trainer not available")
else:
    IMPROVED_AVAILABLE = True
    print("✅ Improved trainer available")

JAX devices: [CpuDevice(id=0)]
Using device: TFRT_CPU_0
JAX devices: [CpuDevice(id=0)]
Using device: TFRT_CPU_0
✅ JAX-compatible trainer available
✅ Improved trainer available


In [3]:
# Load data
sim_indices_train = np.load('data/sparse_sampling_train_indices_random.npy')

# The training cells fit every method on sim_indices_train[:20]. Evaluating on
# that same slice would only measure how well each GP reproduces its own
# training data, so the test set is the *next* block of simulations instead.
N_FIT_SIMS = 20    # must match the slice used for `sim_subset` in the training cells
N_TEST_SIMS = 20
sim_indices_test = sim_indices_train[N_FIT_SIMS:N_FIT_SIMS + N_TEST_SIMS]
assert len(sim_indices_test) > 0, "Not enough simulations for a held-out test set"

print(f"Training: {len(sim_indices_train)} sims, Testing: {len(sim_indices_test)} sims")

# Load test data
X_test, y_test, r_bins, k_bins = prepare_GP_data(sim_indices_test, 'CAP', 'gas')
print(f"Test data: X={X_test.shape}, y={y_test.shape}")

Training: 204 sims, Testing: 20 sims
Getting gas profiles with CAP filter for 20 simulations...
Finished getting profiles in 1360 halos.
Profiles shape: (1360, 21), Mass shape: (1360,), Params shape: (1360, 35), PkRatio shape: (1360, 255)
Test data: X=(1360, 291), y=(1360, 21)


## Method 1: Original NN+GP Training

In [None]:
print("=== Method 1: Original NN+GP Training ===")
nn_t0 = time.time()

# Every method is fitted on the same first 20 simulations so timings are comparable.
sim_subset = sim_indices_train[:20]
nn_train_kwargs = {'filterType': 'CAP', 'ptype': 'gas', 'save': False}
gp_models_nn, best_params_nn, model_info_nn = train_NN_gp(sim_subset, **nn_train_kwargs)

# Wall-clock duration of the whole fit, including the data preparation done inside the trainer.
nn_train_time = time.time() - nn_t0
print(f"NN+GP training time: {nn_train_time:.1f}s")

=== Method 1: Original NN+GP Training ===
Getting gas profiles with CAP filter for 20 simulations...
Finished getting profiles in 1360 halos.
Profiles shape: (1360, 21), Mass shape: (1360,), Params shape: (1360, 35), PkRatio shape: (1360, 255)


Training GP for each r_bin:   0%|          | 0/21 [00:00<?, ?it/s]

Start Adamw training for r_bin 0: Initial loss = 43335.0625


Training GP for each r_bin:   0%|          | 0/21 [02:21<?, ?it/s, Step=1800, Loss=5183.408203, Best=5183.408203] 

r_bin 0 in 152.91s: Final loss = 5084.053711


Training GP for each r_bin:   5%|▍         | 1/21 [02:36<52:09, 156.49s/it, Step=1800, Loss=5183.408203, Best=5183.408203]

Start Adamw training for r_bin 1: Initial loss = 3734.029296875


Training GP for each r_bin:   5%|▍         | 1/21 [04:52<52:09, 156.49s/it, Step=1800, Loss=2684.320801, Best=2683.158691]

r_bin 1 in 149.72s: Final loss = 2680.865234


Training GP for each r_bin:  10%|▉         | 2/21 [05:07<48:30, 153.19s/it, Step=1800, Loss=2684.320801, Best=2683.158691]

Start Adamw training for r_bin 2: Initial loss = 4219.56103515625


Training GP for each r_bin:  10%|▉         | 2/21 [07:23<48:30, 153.19s/it, Step=1800, Loss=2927.848145, Best=2927.848145]

r_bin 2 in 150.12s: Final loss = 2890.435547


Training GP for each r_bin:  14%|█▍        | 3/21 [07:38<45:43, 152.40s/it, Step=1800, Loss=2927.848145, Best=2927.848145]

Start Adamw training for r_bin 3: Initial loss = 7949.49853515625


Training GP for each r_bin:  14%|█▍        | 3/21 [09:52<45:43, 152.40s/it, Step=1800, Loss=3535.622314, Best=3535.622314]

r_bin 3 in 147.67s: Final loss = 3530.483887


Training GP for each r_bin:  19%|█▉        | 4/21 [10:07<42:47, 151.00s/it, Step=1800, Loss=3535.622314, Best=3535.622314]

Start Adamw training for r_bin 4: Initial loss = 16447.71484375


Training GP for each r_bin:  19%|█▉        | 4/21 [10:09<42:47, 151.00s/it, Step=0, Loss=16447.714844, Best=16447.714844] 

In [None]:
# Generate NN+GP predictions
# The training targets are needed again because each GP is conditioned on them.
X_train_nn, y_train_nn, _, _ = prepare_GP_data(sim_subset, 'CAP', 'gas')

model = build_NN_gp()

# One (mean, variance) pair per trained model; column k of y_train_nn is the
# target that model k is conditioned on.
nn_moments = []
for k in range(len(gp_models_nn)):
    conditioned_gp = model.apply(best_params_nn[k], X_test, y_train_nn[:, k])[1]
    nn_moments.append((conditioned_gp.mean, conditioned_gp.variance))

pred_means_nn = np.array([mean_k for mean_k, _var_k in nn_moments])
pred_vars_nn = np.array([var_k for _mean_k, var_k in nn_moments])
print(f"NN+GP predictions shape: {pred_means_nn.shape}")

## Method 2: Hierarchical GP Training

In [None]:
print("=== Method 2: Hierarchical GP Training ===")
hier_t0 = time.time()

# Same simulations and data selection as Method 1; only the model builder differs.
hier_train_kwargs = {
    'maxiter': 1000,
    'filterType': 'CAP',
    'ptype': 'gas',
    'save': False,
}
gp_models_hier, best_params_hier, model_info_hier = train_conditional_gp(
    sim_subset, build_hierarchical_gp, **hier_train_kwargs
)

hier_train_time = time.time() - hier_t0
print(f"Hierarchical GP training time: {hier_train_time:.1f}s")

In [None]:
# Generate Hierarchical GP predictions
# Each trained GP is conditioned on its own target column of y_train_nn
# (loaded in the NN+GP prediction cell) and evaluated at the test inputs.
hier_moments = []
for target_idx, trained_gp in enumerate(gp_models_hier):
    _, conditioned_gp = trained_gp.condition(y_train_nn[:, target_idx], X_test)
    hier_moments.append((conditioned_gp.mean, conditioned_gp.variance))

pred_means_hier = np.array([mean_k for mean_k, _var_k in hier_moments])
pred_vars_hier = np.array([var_k for _mean_k, var_k in hier_moments])
print(f"Hierarchical GP predictions shape: {pred_means_hier.shape}")

## Method 3: JAX-Compatible Trainer (if available)

In [None]:
# Defaults are set up front so these names exist (as None) whether or not
# Method 3 runs; later cells only need an `is not None` check.
compat_train_time = None
pred_means_compat = None
pred_vars_compat = None

if COMPAT_AVAILABLE:
    print("=== Method 3: JAX-Compatible Training ===")
    start_time = time.time()

    gp_models_compat, best_params_compat, model_info_compat = train_simple_conditional_gp(
        sim_subset, kernel_name='hierarchical', maxiter=500
    )

    compat_train_time = time.time() - start_time
    print(f"JAX-Compatible training time: {compat_train_time:.1f}s")

    # TODO: generate real predictions from gp_models_compat.
    # The predictions are deliberately left as None rather than aliased to the
    # hierarchical results: an alias would make any accuracy figure for this
    # method silently report Method 2's numbers.
else:
    print("Method 3 skipped - JAX-compatible trainer not available")

## Method 4: Improved Kernels (if available)

In [None]:
# Defaults: the outputs of this cell are always defined, and stay None unless
# both training and prediction succeed.
improved_train_time = None
pred_means_improved = None
pred_vars_improved = None

if IMPROVED_AVAILABLE:
    print("=== Method 4: Improved Kernels ===")
    start_time = time.time()

    try:
        gp_models_improved, best_params_improved, model_info_improved = train_improved_conditional_gp(
            sim_subset, kernel_name='multiscale', maxiter=1000
        )

        # Training time only; the prediction loop below is not included.
        improved_elapsed = time.time() - start_time
        print(f"Improved kernels training time: {improved_elapsed:.1f}s")

        # Generate predictions into temporaries so a failure part-way through
        # cannot leave half-filled lists behind under the public names.
        improved_means = []
        improved_vars = []
        for i, gp_model in enumerate(gp_models_improved):
            _, cond_gp = gp_model.condition(y_train_nn[:, i], X_test)
            improved_means.append(cond_gp.mean)
            improved_vars.append(cond_gp.variance)

        # Publish results only once everything succeeded.
        pred_means_improved = np.array(improved_means)
        pred_vars_improved = np.array(improved_vars)
        improved_train_time = improved_elapsed

    except Exception as e:
        # Best effort: this method is optional, so report the failure and let
        # the comparison continue with the remaining methods.
        print(f"Improved kernels failed: {e}")
else:
    print("Method 4 skipped - Improved trainer not available")

## Training Time Comparison

In [None]:
# Training time comparison plot
# The two core methods are always shown; optional methods are appended only
# when they produced a timing.
methods = ['NN+GP', 'Hierarchical']
times = [nn_train_time, hier_train_time]
colors = ['blue', 'red']

optional_methods = (
    ('JAX-Compatible', compat_train_time, 'green'),
    # .get(): tolerate the Method 4 cell not having been run at all.
    ('Improved Kernels', locals().get('improved_train_time'), 'orange'),
)
for method_label, elapsed, bar_color in optional_methods:
    if elapsed is None:
        continue
    methods.append(method_label)
    times.append(elapsed)
    colors.append(bar_color)

plt.figure(figsize=(10, 6))
bars = plt.bar(methods, times, color=colors, alpha=0.7)
plt.ylabel('Training Time [s]', fontsize=14)
plt.title('GP Training Time Comparison', fontsize=16)
plt.xticks(rotation=45, ha='right')

# Add value labels on bars, lifted by 1% of the tallest bar so they clear the top edge.
label_offset = max(times) * 0.01
for bar_patch, elapsed in zip(bars, times):
    label_x = bar_patch.get_x() + bar_patch.get_width() / 2
    label_y = bar_patch.get_height() + label_offset
    plt.text(label_x, label_y, f'{elapsed:.1f}s',
             ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig(f'{save_dir}/training_time_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("Training Time Summary:")
for method_label, elapsed in zip(methods, times):
    print(f"  {method_label}: {elapsed:.1f}s")

## Test Plot Comparison (test_plot Format)

In [None]:
# Compute ground truth statistics: median and interquartile range of the test
# profiles in each radius bin (axis 0 runs over halos).
lower = np.quantile(y_test, 0.25, axis=0)   # 25th percentile
upper = np.quantile(y_test, 0.75, axis=0)   # 75th percentile
median = np.median(y_test, axis=0)

# matplotlib's errorbar expects yerr as [distance below, distance above].
# Because lower <= median <= upper, both distances are non-negative and no
# abs() is needed. The quantiles were previously swapped, which mirrored the
# asymmetric error bars.
yerr_lower = median - lower
yerr_upper = upper - median
yerr_truth = [yerr_lower, yerr_upper]

print(f"Ground truth computed for {len(r_bins)} radius bins")

In [None]:
# Main comparison plot (test_plot format)

def _summarize_predictions(pred_means, pred_vars):
    """Collapse per-point GP predictions into one summary per radius bin.

    Parameters
    ----------
    pred_means, pred_vars : array-like
        Predictive means and variances with one row per radius bin; axis 1
        (the test points) is reduced.

    Returns
    -------
    lower_q, center, upper_q : np.ndarray
        25th percentile, median and 75th percentile of the predicted means.
    sigma : np.ndarray
        Mean predictive standard deviation per radius bin.
    """
    pred_means = np.asarray(pred_means)
    lower_q = np.quantile(pred_means, 0.25, axis=1)
    # Use the median (not the mean) so the centre is the same statistic as the
    # ground-truth `median` it is compared against, and always lies between
    # the two quartiles.
    center = np.median(pred_means, axis=1)
    upper_q = np.quantile(pred_means, 0.75, axis=1)
    sigma = np.mean(np.sqrt(np.asarray(pred_vars)), axis=1)
    return lower_q, center, upper_q, sigma


def percent_error(pred):
    """Signed percentage deviation of `pred` from the ground-truth median."""
    return 100 * (pred - median) / median


# NN+GP predictions
lower_pred_nn, median_pred_nn, upper_pred_nn, sigma_nn = _summarize_predictions(
    pred_means_nn, pred_vars_nn)
# errorbar wants [distance below, distance above]; both are >= 0 by construction.
yerr_nn = [median_pred_nn - lower_pred_nn, upper_pred_nn - median_pred_nn]

# Hierarchical GP predictions
lower_pred_hier, median_pred_hier, upper_pred_hier, sigma_hier = _summarize_predictions(
    pred_means_hier, pred_vars_hier)
yerr_hier = [median_pred_hier - lower_pred_hier, upper_pred_hier - median_pred_hier]

# (label, colour, marker, centre, IQR error bars, mean 1-sigma) for each method.
method_styles = [
    ('NN+GP', 'blue', 's', median_pred_nn, yerr_nn, sigma_nn),
    ('Hierarchical GP', 'red', '^', median_pred_hier, yerr_hier, sigma_hier),
]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

# Plot 1: Predictions vs Data
ax1.errorbar(r_bins, median, yerr=yerr_truth, fmt='o', capsize=5, capthick=2,
             linewidth=2, markersize=6, color='black', label='Ground Truth')

for label, color, marker, center, yerr, sigma in method_styles:
    ax1.errorbar(r_bins, center, yerr=yerr, fmt=marker, capsize=5,
                 capthick=2, linewidth=2, markersize=6, color=color, label=label)
    ax1.fill_between(r_bins, center - sigma, center + sigma,
                     color=color, alpha=0.2, label=f'{label} 1σ')

ax1.set_yscale('log')
ax1.legend()
ax1.set_xlabel('Radius [Mpc/h]', fontsize=14)
ax1.set_ylabel(r'CAP gas Profile [$M_\odot$/h]', fontsize=14)
ax1.set_title('GP Method Comparison: Predictions vs Ground Truth', fontsize=16)

# Plot 2: Percentage Error
for label, color, marker, center, yerr, sigma in method_styles:
    ax2.plot(r_bins, percent_error(center), marker=marker, linestyle='-',
             color=color, linewidth=2, markersize=6, label=f'{label} Error')
    ax2.fill_between(r_bins,
                     percent_error(center - sigma),
                     percent_error(center + sigma),
                     color=color, alpha=0.2)

ax2.axhline(0, color='k', linestyle='--', linewidth=1)
ax2.set_xlabel('Radius [Mpc/h]', fontsize=14)
ax2.set_ylabel('Percentage Error [%]', fontsize=14)
ax2.set_title('GP Prediction Percentage Error', fontsize=16)
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{save_dir}/test_plot_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## Summary Statistics

In [None]:
# Compute summary metrics
def compute_summary_metrics(pred_means, pred_vars, ground_truth, method_name,
                            n_high_r_bins=5):
    """Summarise how well one method's predictions match the ground truth.

    Parameters
    ----------
    pred_means : array-like
        Predictive means with one row per radius bin; the median over axis 1
        is compared against `ground_truth`.
    pred_vars : array-like
        Predictive variances. Accepted for a uniform call signature; not used
        by any of the metrics.
    ground_truth : array-like
        One reference value per radius bin.
    method_name : str
        Label stored in the result (and used in the failure message).
    n_high_r_bins : int, optional
        Number of trailing bins that make up the high-radius MAPE (default 5).

    Returns
    -------
    dict or str
        A dict with keys 'method', 'mse', 'mae', 'mape', 'high_radius_mape'
        and 'n_valid', holding plain Python floats/ints so it can be written
        to JSON as numbers. If no bin has both a non-NaN prediction and a
        non-NaN ground truth, the string "<method_name>: No valid predictions"
        is returned instead.
    """
    # Median, not mean: the ground truth passed in is a per-bin median, so the
    # prediction must be summarised with the same statistic.
    pred_median = np.median(np.asarray(pred_means), axis=1)
    ground_truth = np.asarray(ground_truth)

    # Only use valid (non-NaN) data points
    valid_mask = ~(np.isnan(pred_median) | np.isnan(ground_truth))

    if not np.any(valid_mask):
        return f"{method_name}: No valid predictions"

    residual = pred_median - ground_truth

    def _mape(mask):
        """MAPE in percent over `mask`, skipping bins whose ground truth is 0."""
        usable = mask & (ground_truth != 0)  # relative error is undefined at 0
        if not np.any(usable):
            return float('nan')
        return float(np.mean(np.abs(residual[usable] / ground_truth[usable])) * 100)

    # High radius performance: the valid bins among the last n_high_r_bins.
    high_r_mask = np.zeros(valid_mask.shape, dtype=bool)
    if n_high_r_bins > 0:
        high_r_mask[-n_high_r_bins:] = True
    high_r_mask &= valid_mask

    return {
        'method': method_name,
        'mse': float(np.mean(residual[valid_mask] ** 2)),
        'mae': float(np.mean(np.abs(residual[valid_mask]))),
        'mape': _mape(valid_mask),
        'high_radius_mape': _mape(high_r_mask),
        # int(): np.int64 is not JSON-serialisable as a number.
        'n_valid': int(np.sum(valid_mask)),
    }

banner = "=" * 60
print("\n" + banner)
print("SUMMARY METRICS")
print(banner)

# Compute metrics for all methods, always against the ground-truth median profile.
metrics_nn = compute_summary_metrics(pred_means_nn, pred_vars_nn, median, 'NN+GP')
metrics_hier = compute_summary_metrics(pred_means_hier, pred_vars_hier, median, 'Hierarchical GP')

all_metrics = [metrics_nn, metrics_hier]

print(f"{'Method':<15} {'Valid':<5} {'MSE':<10} {'MAE':<10} {'MAPE%':<8} {'HighR%':<8}")
print("-" * 70)

for entry in all_metrics:
    # compute_summary_metrics returns a plain message instead of a dict when a
    # method has no valid predictions; print that message as-is.
    if not isinstance(entry, dict):
        print(entry)
        continue
    table_row = (f"{entry['method']:<15} {entry['n_valid']:<5} "
                 f"{entry['mse']:<10.2e} {entry['mae']:<10.2e} "
                 f"{entry['mape']:<8.1f} {entry['high_radius_mape']:<8.1f}")
    print(table_row)

print(f"\nTraining times:")
for method_label, elapsed in zip(methods, times):
    print(f"  {method_label}: {elapsed:.1f}s")

print(f"\n📁 All plots saved to: {save_dir}/")

In [None]:
# Save summary data
import json


def _to_jsonable(obj):
    """Fallback encoder for values the json module cannot serialise itself.

    NumPy scalars and arrays are converted to native Python numbers/lists so
    they are stored as JSON numbers rather than strings; anything else falls
    back to its string form.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


summary = {
    'timestamp': datetime.now().isoformat(),
    'training_times': dict(zip(methods, times)),
    'metrics': {
        # A non-dict value is the "no valid predictions" message; keep it under 'error'.
        'nn_gp': metrics_nn if isinstance(metrics_nn, dict) else {'error': str(metrics_nn)},
        'hierarchical': metrics_hier if isinstance(metrics_hier, dict) else {'error': str(metrics_hier)}
    },
    'data_info': {
        'n_train_sims': len(sim_subset),
        'n_test_sims': len(sim_indices_test),
        'n_radius_bins': len(r_bins),
        'filter_type': 'CAP',
        'particle_type': 'gas'
    }
}

with open(f'{save_dir}/comparison_summary.json', 'w') as f:
    json.dump(summary, f, indent=2, default=_to_jsonable)

print(f"Summary saved to: {save_dir}/comparison_summary.json")
print("\n✅ GP method comparison complete!")
print("\n✅ GP method comparison complete!")