# üî¨ Ultimate Model Comparison - Visual Analysis

Comprehensive visual comparison of 4 training strategies:
- Weighted vs Unweighted loss
- Full curve vs Cropped curve

**Goal:** Determine which strategy to use for 100k dataset training

In [1]:
import numpy as np
import torch
import pickle
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy import stats
import seaborn as sns

import xrd
import helpers as h
from model_common import PARAM_NAMES, RANGES

mpl.rcParams['figure.dpi'] = 100
sns.set_style('whitegrid')

print("‚úì Imports loaded")

‚úì Imports loaded


## üì¶ Load Comparison Results

In [2]:
# Load results from compare_models.py
with open('comparison_results.pkl', 'rb') as f:
    results = pickle.load(f)

configs = results['configs']
metrics = results['metrics']
predictions = results['predictions']
metadata = results['metadata']
X_true = results['X_true']

model_names = [c.name for c in configs]
print(f"‚úì Loaded results for {len(model_names)} models")
print(f"  Models: {model_names}")
print(f"  Dataset size: {X_true.shape[0]} samples")

AttributeError: Can't get attribute 'ModelConfig' on <module '__main__'>

## üìä 1. Overall Error Distribution Comparison

In [None]:
# Compare overall error distributions
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

for i, param in enumerate(PARAM_NAMES):
    ax = axes[i]
    
    for model_name in model_names:
        abs_errors = metrics[model_name]['abs_errors'][:, i]
        ax.hist(abs_errors, bins=50, alpha=0.5, label=model_name, density=True)
    
    ax.set_xlabel(f'{param} Absolute Error')
    ax.set_ylabel('Density')
    ax.set_title(f'{param} Error Distribution')
    ax.legend(fontsize=8)
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('compare_error_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Saved: compare_error_distributions.png")

## üìà 2. MAE Comparison Bar Chart

In [None]:
# Bar chart comparing MAE across parameters
fig, ax = plt.subplots(figsize=(14, 6))

x = np.arange(len(PARAM_NAMES))
width = 0.2

colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

for i, model_name in enumerate(model_names):
    mae_values = metrics[model_name]['mae']
    offset = (i - len(model_names)/2 + 0.5) * width
    bars = ax.bar(x + offset, mae_values, width, label=model_name, color=colors[i], alpha=0.8)
    
    # Add value labels on top of bars
    for j, bar in enumerate(bars):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2e}',
                ha='center', va='bottom', fontsize=7, rotation=0)

ax.set_xlabel('Parameter', fontsize=12, fontweight='bold')
ax.set_ylabel('Mean Absolute Error (MAE)', fontsize=12, fontweight='bold')
ax.set_title('MAE Comparison Across All Parameters', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(PARAM_NAMES)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
ax.set_yscale('log')

plt.tight_layout()
plt.savefig('compare_mae_barchart.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Saved: compare_mae_barchart.png")

## üéØ 3. Win Rate Heatmap

Shows which model performs best for each sample

In [None]:
# For each sample, determine which model has lowest error
n_samples = X_true.shape[0]
n_params = len(PARAM_NAMES)

# Calculate mean error per sample across all parameters
sample_errors = np.zeros((n_samples, len(model_names)))
for i, model_name in enumerate(model_names):
    # Mean absolute error per sample (averaged across 7 parameters)
    sample_errors[:, i] = np.mean(metrics[model_name]['abs_errors'], axis=1)

# Find best model for each sample
best_model_per_sample = np.argmin(sample_errors, axis=1)

# Count wins per model
win_counts = np.bincount(best_model_per_sample, minlength=len(model_names))
win_percentages = win_counts / n_samples * 100

# Create bar chart
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(model_names, win_percentages, color=colors, alpha=0.8)

# Add percentage labels
for bar, count, pct in zip(bars, win_counts, win_percentages):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{pct:.1f}%\n({count} samples)',
            ha='center', va='bottom', fontsize=10, fontweight='bold')

ax.set_ylabel('Win Rate (%)', fontsize=12, fontweight='bold')
ax.set_title('Sample-wise Win Rate: Which Model Produces Lowest Overall Error', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('compare_win_rate.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Saved: compare_win_rate.png")

## üîç 4. Per-Parameter Win Rate Heatmap

In [None]:
# For each parameter, count wins per model
param_wins = np.zeros((n_params, len(model_names)))

for param_idx in range(n_params):
    param_errors = np.zeros((n_samples, len(model_names)))
    
    for model_idx, model_name in enumerate(model_names):
        param_errors[:, model_idx] = metrics[model_name]['abs_errors'][:, param_idx]
    
    best_per_sample = np.argmin(param_errors, axis=1)
    param_wins[param_idx] = np.bincount(best_per_sample, minlength=len(model_names))

# Convert to percentages
param_wins_pct = param_wins / n_samples * 100

# Create heatmap
fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(param_wins_pct, cmap='RdYlGn', aspect='auto', vmin=0, vmax=100)

# Set ticks
ax.set_xticks(np.arange(len(model_names)))
ax.set_yticks(np.arange(n_params))
ax.set_xticklabels(model_names, rotation=45, ha='right')
ax.set_yticklabels(PARAM_NAMES)

# Add text annotations
for i in range(n_params):
    for j in range(len(model_names)):
        text = ax.text(j, i, f'{param_wins_pct[i, j]:.1f}%',
                       ha="center", va="center", color="black", fontweight='bold')

ax.set_title('Per-Parameter Win Rate Heatmap\n(% of samples where model has lowest error)',
             fontsize=14, fontweight='bold', pad=20)
fig.colorbar(im, ax=ax, label='Win Rate (%)')

plt.tight_layout()
plt.savefig('compare_param_win_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Saved: compare_param_win_heatmap.png")

## üé® 5. Side-by-Side Curve Comparison for Same Samples

Visually compare how different models predict the same curve

In [None]:
def plot_model_comparison_for_sample(sample_idx, dl=100e-8):
    """
    Compare all 4 models' predictions for a single sample.
    Shows rocking curves and deformation profiles.
    """
    true_params = X_true[sample_idx].numpy()
    
    # Generate true curve and profile
    true_curve, true_profile = xrd.compute_curve_and_profile(true_params.tolist(), dl=dl)
    
    fig, axes = plt.subplots(2, len(model_names) + 1, figsize=(20, 8))
    
    # Column 0: True
    axes[0, 0].plot(true_curve.X_DeltaTeta, true_curve.Y_R_vseZ, 'k-', linewidth=2, label='True')
    axes[0, 0].set_xlabel('ŒîŒò (arcsec)')
    axes[0, 0].set_ylabel('Intensity')
    axes[0, 0].set_title('TRUE\nRocking Curve', fontweight='bold')
    axes[0, 0].set_yscale('log')
    axes[0, 0].grid(True, alpha=0.3)
    
    axes[1, 0].plot(true_profile.X, true_profile.total_Y, 'k-', linewidth=2, label='True')
    axes[1, 0].set_xlabel('Depth (m)')
    axes[1, 0].set_ylabel('Deformation')
    axes[1, 0].set_title('TRUE\nDeformation Profile', fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Add true params
    true_str = h.fparam(arr=true_params)
    axes[1, 0].text(0.5, -0.25, f"TRUE: {true_str}",
                    transform=axes[1, 0].transAxes, fontsize=7,
                    verticalalignment='top', horizontalalignment='center',
                    bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3),
                    family='monospace')
    
    # Columns 1-4: Model predictions
    for col_idx, model_name in enumerate(model_names, start=1):
        pred_params = predictions[model_name][sample_idx].numpy()
        pred_curve, pred_profile = xrd.compute_curve_and_profile(pred_params.tolist(), dl=dl)
        
        # Rocking curve
        axes[0, col_idx].plot(true_curve.X_DeltaTeta, true_curve.Y_R_vseZ, 
                              'k-', linewidth=2, alpha=0.3, label='True')
        axes[0, col_idx].plot(pred_curve.X_DeltaTeta, pred_curve.Y_R_vseZ,
                              color=colors[col_idx-1], linestyle='--', linewidth=2, label='Predicted')
        axes[0, col_idx].set_xlabel('ŒîŒò (arcsec)')
        axes[0, col_idx].set_ylabel('Intensity')
        axes[0, col_idx].set_title(f'{model_name}\nRocking Curve', fontweight='bold')
        axes[0, col_idx].set_yscale('log')
        axes[0, col_idx].legend(fontsize=8)
        axes[0, col_idx].grid(True, alpha=0.3)
        
        # Deformation profile
        axes[1, col_idx].plot(true_profile.X, true_profile.total_Y,
                              'k-', linewidth=2, alpha=0.3, label='True')
        axes[1, col_idx].plot(pred_profile.X, pred_profile.total_Y,
                              color=colors[col_idx-1], linestyle='--', linewidth=2, label='Predicted')
        axes[1, col_idx].set_xlabel('Depth (m)')
        axes[1, col_idx].set_ylabel('Deformation')
        axes[1, col_idx].set_title(f'{model_name}\nDeformation Profile', fontweight='bold')
        axes[1, col_idx].legend(fontsize=8)
        axes[1, col_idx].grid(True, alpha=0.3)
        
        # Calculate errors
        errors = pred_params - true_params
        rel_errors = errors / (np.abs(true_params) + 1e-12) * 100
        mae = np.mean(np.abs(errors))
        
        # Add prediction info
        pred_str = h.fparam(arr=pred_params)
        info_text = f"PRED: {pred_str}\nMAE: {mae:.3e}"
        
        axes[1, col_idx].text(0.5, -0.25, info_text,
                              transform=axes[1, col_idx].transAxes, fontsize=7,
                              verticalalignment='top', horizontalalignment='center',
                              bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3),
                              family='monospace')
    
    plt.suptitle(f'Model Comparison - Sample #{sample_idx}', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    filename = f'compare_sample_{sample_idx:05d}.png'
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"‚úì Saved: {filename}")

# Show comparison for a few random samples
np.random.seed(42)
sample_indices = np.random.choice(n_samples, size=3, replace=False)

for idx in sample_indices:
    plot_model_comparison_for_sample(idx)

## üéØ 6. Best and Worst Cases per Model

In [None]:
# For each model, find its best and worst predictions
for model_name in model_names:
    sample_errors = np.mean(metrics[model_name]['abs_errors'], axis=1)
    
    best_idx = np.argmin(sample_errors)
    worst_idx = np.argmax(sample_errors)
    
    print(f"\n{'='*80}")
    print(f"Model: {model_name}")
    print(f"  Best case:  sample #{best_idx:5d}, MAE = {sample_errors[best_idx]:.6e}")
    print(f"  Worst case: sample #{worst_idx:5d}, MAE = {sample_errors[worst_idx]:.6e}")
    print(f"{'='*80}")
    
    # Plot best case
    plot_model_comparison_for_sample(best_idx)
    
    # Plot worst case
    plot_model_comparison_for_sample(worst_idx)

## üìä 7. Error Correlation Between Models

Do models fail on the same samples?

In [None]:
# Calculate correlation of errors between models
error_matrix = np.zeros((len(model_names), len(model_names)))

for i, model1 in enumerate(model_names):
    errors1 = np.mean(metrics[model1]['abs_errors'], axis=1)
    for j, model2 in enumerate(model_names):
        errors2 = np.mean(metrics[model2]['abs_errors'], axis=1)
        correlation = np.corrcoef(errors1, errors2)[0, 1]
        error_matrix[i, j] = correlation

# Plot correlation heatmap
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(error_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)

ax.set_xticks(np.arange(len(model_names)))
ax.set_yticks(np.arange(len(model_names)))
ax.set_xticklabels(model_names, rotation=45, ha='right')
ax.set_yticklabels(model_names)

# Add correlation values
for i in range(len(model_names)):
    for j in range(len(model_names)):
        text = ax.text(j, i, f'{error_matrix[i, j]:.3f}',
                       ha="center", va="center", 
                       color="white" if abs(error_matrix[i, j]) > 0.5 else "black",
                       fontweight='bold', fontsize=12)

ax.set_title('Error Correlation Between Models\n(Do they fail on the same samples?)',
             fontsize=14, fontweight='bold', pad=20)
fig.colorbar(im, ax=ax, label='Correlation')

plt.tight_layout()
plt.savefig('compare_error_correlation.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Saved: compare_error_correlation.png")
print("\nInterpretation:")
print("  High correlation (>0.9): Models fail on similar samples")
print("  Low correlation (<0.7): Models have different failure modes")

## üí° 8. Final Recommendation

In [None]:
# Calculate overall ranking based on multiple metrics
print("\n" + "="*100)
print("üí° FINAL RECOMMENDATION FOR 100K TRAINING")
print("="*100)

rankings = []
for model_name in model_names:
    avg_mae = np.mean(metrics[model_name]['mae'])
    avg_mape = np.mean(metrics[model_name]['mape'])
    max_mape = np.max(metrics[model_name]['mape'])
    val_loss = metadata[model_name]['val_loss']
    
    # Calculate win rate
    sample_errors = np.mean(metrics[model_name]['abs_errors'], axis=1)
    wins = np.sum(sample_errors[:, None] <= sample_errors[None, :]) / n_samples
    
    rankings.append({
        'model': model_name,
        'avg_mae': avg_mae,
        'avg_mape': avg_mape,
        'max_mape': max_mape,
        'val_loss': val_loss,
    })

# Sort by avg_mape (lower is better)
rankings.sort(key=lambda x: x['avg_mape'])

print("\nRanking (by Average MAPE):")
print("-"*100)
for rank, r in enumerate(rankings, 1):
    medal = "ü•á" if rank == 1 else "ü•à" if rank == 2 else "ü•â" if rank == 3 else "  "
    print(f"{medal} {rank}. {r['model']:<25} | Avg MAPE: {r['avg_mape']:>6.2f}% | "
          f"Max MAPE: {r['max_mape']:>6.2f}% | Val Loss: {r['val_loss']:.6f}")

print("-"*100)

# Recommendation
best_model = rankings[0]['model']
print(f"\nüéØ RECOMMENDED STRATEGY: {best_model}")
print("\nReasons:")
print("  ‚úì Lowest average MAPE across all parameters")
print("  ‚úì Best validation loss")
print("  ‚úì Most consistent performance")

# Extract configuration
if 'unweighted' in best_model:
    print("\n‚öôÔ∏è  Configuration for 100k training:")
    print("    WEIGHTED_TRAINING = False")
else:
    print("\n‚öôÔ∏è  Configuration for 100k training:")
    print("    WEIGHTED_TRAINING = True")

if 'full' in best_model:
    print("    FULL_CURVE_TRAINING = True")
else:
    print("    FULL_CURVE_TRAINING = False")

print("\n" + "="*100)

# Additional insights
print("\nüìà Key Insights:")
if rankings[0]['model'].startswith('v3_unweighted'):
    print("  ‚Ä¢ Unweighted loss performs better - current loss weights may be suboptimal")
    print("  ‚Ä¢ Consider: loss weights might need retuning OR unweighted is inherently better")
    
if 'full' in rankings[0]['model']:
    print("  ‚Ä¢ Full curve training shows improvement despite extra computation")
    print("  ‚Ä¢ The cropped region [50:701] may be losing important information")
else:
    print("  ‚Ä¢ Cropped training sufficient - cropped region [50:701] contains key features")

print("\n" + "="*100)

## üîç 9. Loss Weights Analysis

Investigate if loss weights are helping or hurting

In [None]:
# Compare weighted vs unweighted for same training mode (full or cropped)
LOSS_WEIGHTS = np.array([1.0, 1.2, 1.0, 1.0, 1.5, 2.0, 2.5])

print("\n" + "="*100)
print("‚öñÔ∏è  LOSS WEIGHTS ANALYSIS")
print("="*100)
print(f"\nCurrent loss weights: {LOSS_WEIGHTS}")
print(f"Parameters:           {PARAM_NAMES}")

# Compare full curve: weighted vs unweighted
print("\nüìä FULL CURVE: Weighted vs Unweighted")
print("-"*100)
weighted_full_mae = metrics['v3_full']['mae']
unweighted_full_mae = metrics['v3_unweighted_full']['mae']

print(f"{'Parameter':<10} {'Weighted':<15} {'Unweighted':<15} {'Difference':<15} {'Winner':<15}")
print("-"*100)
for i, param in enumerate(PARAM_NAMES):
    diff = weighted_full_mae[i] - unweighted_full_mae[i]
    winner = "‚úì Unweighted" if diff > 0 else "‚úì Weighted"
    print(f"{param:<10} {weighted_full_mae[i]:<15.6e} {unweighted_full_mae[i]:<15.6e} "
          f"{diff:+15.6e} {winner:<15}")

# Compare cropped: weighted vs unweighted
print("\nüìä CROPPED: Weighted vs Unweighted")
print("-"*100)
weighted_crop_mae = metrics['v3']['mae']
unweighted_crop_mae = metrics['v3_unweighted']['mae']

print(f"{'Parameter':<10} {'Weighted':<15} {'Unweighted':<15} {'Difference':<15} {'Winner':<15}")
print("-"*100)
for i, param in enumerate(PARAM_NAMES):
    diff = weighted_crop_mae[i] - unweighted_crop_mae[i]
    winner = "‚úì Unweighted" if diff > 0 else "‚úì Weighted"
    print(f"{param:<10} {weighted_crop_mae[i]:<15.6e} {unweighted_crop_mae[i]:<15.6e} "
          f"{diff:+15.6e} {winner:<15}")

# Analysis
print("\nüí° Analysis:")
print("-"*100)

# Count how many parameters benefit from weights
full_wins_weighted = np.sum(unweighted_full_mae > weighted_full_mae)
crop_wins_weighted = np.sum(unweighted_crop_mae > weighted_crop_mae)

print(f"  Full curve:  Weighted wins {full_wins_weighted}/7 parameters")
print(f"  Cropped:     Weighted wins {crop_wins_weighted}/7 parameters")

if full_wins_weighted < 4 and crop_wins_weighted < 4:
    print("\n  ‚ùå CONCLUSION: Loss weights are HURTING performance")
    print("     Recommendation: Use WEIGHTED_TRAINING = False for 100k training")
    print("\n  üí° Possible reasons:")
    print("     ‚Ä¢ Current weights over-emphasize harder parameters at expense of easy ones")
    print("     ‚Ä¢ Weights create imbalanced gradients leading to suboptimal convergence")
    print("     ‚Ä¢ Natural loss balance (unweighted) works better for this problem")
    print("\n  üîß If you want to try weighted training again:")
    print("     ‚Ä¢ Try smaller weight differences (e.g., [1.0, 1.1, 1.0, 1.0, 1.2, 1.3, 1.5])")
    print("     ‚Ä¢ Or use dynamic weighting based on validation performance")
else:
    print("\n  ‚úì CONCLUSION: Loss weights are HELPING performance")
    print("     Recommendation: Use WEIGHTED_TRAINING = True for 100k training")

print("\n" + "="*100)