In [None]:
# Setup and imports
import json
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# Configure matplotlib defaults used by every figure in this notebook
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
    'axes.grid': True,
    'grid.alpha': 0.3,
})

# Project root. The historical absolute location is kept as the default so the
# notebook behaves as before; set TRACK_EXTRAPOLATOR_ROOT to run it from
# another checkout or machine without editing this cell.
PROJECT_ROOT = Path(os.environ.get(
    'TRACK_EXTRAPOLATOR_ROOT',
    '/data/bfys/gscriven/TE_stack/Rec/Tr/TrackExtrapolators/experiments/next_generation',
))

# Put the project's module directories at the front of sys.path.
# Any existing entry is removed first, so re-running this cell keeps the same
# lookup precedence ('analysis' before 'models') without adding duplicates.
for _subdir in ('models', 'analysis'):
    _entry = str(PROJECT_ROOT / _subdir)
    if _entry in sys.path:
        sys.path.remove(_entry)
    sys.path.insert(0, _entry)

# Import analysis modules (resolved through the sys.path entries added above)
from analyze_models import TrackExtrapolatorAnalyzer
from physics_analysis import PhysicsAnalyzer
from trajectory_visualizer import TrajectoryVisualizer

# Report the runtime environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration: every location below is derived from PROJECT_ROOT
MODELS_DIR = PROJECT_ROOT.joinpath('trained_models')
# NOTE(review): the file name says "training" while later cells call this test
# data - confirm the analysed sample is held out from training.
DATA_PATH = PROJECT_ROOT.joinpath('data_generation/data/training_50M.npz')
OUTPUT_DIR = PROJECT_ROOT.joinpath('analysis/results')

# Make sure the results directory exists before any cell writes into it
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of samples for analysis (use more for final analysis)
N_ANALYSIS_SAMPLES = 100_000

# Echo the resolved configuration, one entry per line
print(
    f"Models directory: {MODELS_DIR}",
    f"Data path: {DATA_PATH}",
    f"Output directory: {OUTPUT_DIR}",
    sep="\n",
)

## 1. Load Models and Data

First, let's load all trained models and the test data.

In [None]:
# Build the main analyzer from the configured model directory and data file
analyzer = TrackExtrapolatorAnalyzer(models_dir=MODELS_DIR, data_path=DATA_PATH)

# Load N_ANALYSIS_SAMPLES samples from DATA_PATH
analyzer.load_data(n_samples=N_ANALYSIS_SAMPLES)

# Load every model matching the '*_v1' pattern
# (presumably matched against entries of MODELS_DIR - verify in analyze_models)
analyzer.load_all_models(pattern='*_v1')

n_loaded = len(analyzer.models)
print(f"\nLoaded {n_loaded} models")

In [None]:
# Names of all loaded models
model_names = [*analyzer.models.keys()]


def _name_has(name, required, forbidden=()):
    """Return True if the lower-cased name contains `required` and none of `forbidden`."""
    lowered = name.lower()
    return required in lowered and not any(token in lowered for token in forbidden)


# Categorize by type using case-insensitive substring checks on the model name.
# The groups are not guaranteed to be disjoint or exhaustive.
mlp_models = [m for m in model_names if _name_has(m, 'mlp', ('pinn', 'res'))]
resmlp_models = [m for m in model_names if _name_has(m, 'resmlp')]
pinn_models = [m for m in model_names if _name_has(m, 'pinn', ('rkpinn',))]
rkpinn_models = [m for m in model_names if _name_has(m, 'rkpinn')]

# Report how many models landed in each group
print(
    "Model Types:",
    f"  MLP: {len(mlp_models)}",
    f"  ResidualMLP: {len(resmlp_models)}",
    f"  PINN: {len(pinn_models)}",
    f"  RK-PINN: {len(rkpinn_models)}",
    sep="\n",
)

## 2. Model Performance Overview

Let's compute and display the performance statistics for all models.

In [None]:
# Compute statistics for all models
all_stats = analyzer.compute_statistical_summary(model_names)

# Rank by mean position error, ascending, so the first entry is the best model
sorted_stats = sorted(all_stats.items(), key=lambda x: x[1]['pos_mean'])

# Display top 15 models
print("="*90)
print(f"{'Rank':<5} {'Model':<30} {'Type':<12} {'Params':>10} {'Pos Err (mm)':>12} {'Slope (mrad)':>12}")
print("="*90)

for rank, (name, stats) in enumerate(sorted_stats[:15], 1):
    print(f"{rank:<5} {name:<30} {stats['model_type']:<12} {stats['parameters']:>10,} "
          f"{stats['pos_mean']:>12.4f} {stats['slope_mean_mrad']:>12.4f}")

print("="*90)

# Guard against an empty ranking (e.g. no model matched the load pattern)
# instead of failing with IndexError on sorted_stats[0].
if sorted_stats:
    best_name, best_stats = sorted_stats[0]
    print(f"\n🏆 Best Model: {best_name}")
    print(f"   Position Error: {best_stats['pos_mean']:.4f} ± {best_stats['pos_std']:.4f} mm")
else:
    print("\nNo model statistics available - check that models were loaded.")

In [None]:
# Three-panel overview: error vs model size (two panels) and the error trade-off
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One colour per model type; types not listed here are drawn in gray
colors = {'MLP': 'blue', 'ResidualMLP': 'green', 'PINN': 'red', 'RKPINN': 'orange'}

for name, stats in all_stats.items():
    # Fold the hyphenated spelling onto the key used in `colors`
    model_type = 'RKPINN' if stats['model_type'] == 'RK-PINN' else stats['model_type']
    color = colors.get(model_type, 'gray')
    marker_style = dict(c=color, s=80, alpha=0.7)

    n_params = stats['parameters']
    pos_err = stats['pos_mean']
    slope_err = stats['slope_mean_mrad']

    # Only the first panel carries labels; it is the one that gets the legend
    axes[0].scatter(n_params, pos_err, label=model_type, **marker_style)
    axes[1].scatter(n_params, slope_err, **marker_style)
    axes[2].scatter(pos_err, slope_err, **marker_style)

# Every point adds its own legend entry, so collapse them to one per label
handles, labels = axes[0].get_legend_handles_labels()
by_label = dict(zip(labels, handles))
axes[0].legend(by_label.values(), by_label.keys(), loc='upper right')

# (x label, y label, title, logarithmic x axis?) for each panel, left to right
panel_specs = [
    ('Parameters', 'Position Error [mm]', 'Position Error vs Model Size', True),
    ('Parameters', 'Slope Error [mrad]', 'Slope Error vs Model Size', True),
    ('Position Error [mm]', 'Slope Error [mrad]', 'Error Trade-off', False),
]
for ax, (xlabel, ylabel, title, log_x) in zip(axes, panel_specs):
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if log_x:
        ax.set_xscale('log')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Trajectory Visualization

Let's visualize how the models predict track trajectories.

In [None]:
# Select the top model from each category for visualization.
# Order matches the category order: MLP, ResidualMLP, PINN, RK-PINN.
best_models = []

for _category_models in (mlp_models, resmlp_models, pinn_models, rkpinn_models):
    # Only models that have computed statistics can be ranked
    _candidates = [n for n in _category_models if n in all_stats]
    if not _candidates:
        continue
    # Lowest mean position error wins; on ties min() keeps the first candidate
    _winner = min(_candidates, key=lambda n: all_stats[n]['pos_mean'])
    # The name-based categories can overlap, so the same model could win twice;
    # keep each model once so it is not plotted or loaded again later.
    if _winner not in best_models:
        best_models.append(_winner)

print("Best models per category:")
for m in best_models:
    print(f"  {m}: {all_stats[m]['pos_mean']:.4f} mm")

In [None]:
# Trajectory comparison for the per-category best models, 4 tracks.
# save_path is handed to the analyzer (presumably where it writes the figure).
fig = analyzer.plot_trajectory_comparison(
    best_models, n_tracks=4, save_path=OUTPUT_DIR / 'trajectory_comparison.png'
)
plt.show()

In [None]:
# Residual distributions for the per-category best models.
# save_path is handed to the analyzer (presumably where it writes the figure).
fig = analyzer.plot_trajectory_residuals(
    best_models, save_path=OUTPUT_DIR / 'residual_distributions.png'
)
plt.show()

## 4. Physics Constraint Analysis

Now let's analyze whether the models correctly learn the underlying physics, particularly:
- **ty conservation**: In a purely vertical magnetic field the Lorentz force has no y-component, so the y-slope `ty` should be (approximately) conserved
- **Charge consistency**: Opposite charges should bend in opposite directions

In [None]:
# ty Conservation Analysis for the per-category best models
ty_results = analyzer.analyze_ty_conservation(
    best_models,
    save_path=OUTPUT_DIR / 'ty_conservation.png'
)
plt.show()

print("\nty Conservation Metrics (lower = better physics learning):")
# Value cells are 12 characters plus the 4-character "mrad" suffix, so the
# header columns are 16 wide to line up with them.
header = f"{'Model':<30} {'Mean Δty':>16} {'Std Δty':>16} {'RMSE':>16}"
print(header)
print("-" * len(header))
for name, metrics in ty_results.items():
    # The stored values are scaled by 1000 for display in mrad
    # (assumes the analyzer reports them in rad - TODO confirm)
    print(f"{name:<30} {metrics['mean_dty']*1000:>12.4f}mrad {metrics['std_dty']*1000:>12.4f}mrad {metrics['rmse_dty']*1000:>12.4f}mrad")

In [None]:
# Charge Consistency Analysis for the per-category best models
charge_results = analyzer.analyze_charge_consistency(
    best_models,
    save_path=OUTPUT_DIR / 'charge_consistency.png'
)
plt.show()

print("\nCharge Consistency Metrics:")
# The error cells are 12 characters plus the 2-character "mm" suffix, so those
# header columns are 14 wide to line up with them.
header = f"{'Model':<30} {'Asymmetry':>12} {'q+ Err':>14} {'q- Err':>14}"
print(header)
print("-" * len(header))
for name, metrics in charge_results.items():
    print(f"{name:<30} {metrics['asymmetry']:>12.4f} {metrics['pos_err_mean']:>12.4f}mm {metrics['neg_err_mean']:>12.4f}mm")

## 5. Momentum-Dependent Performance

Low momentum tracks bend more in the magnetic field and are harder to extrapolate accurately. Let's analyze how model performance varies with momentum.

In [None]:
# Momentum dependence of the per-category best models, using 20 bins.
# save_path is handed to the analyzer (presumably where it writes the figure).
momentum_results = analyzer.analyze_momentum_dependence(
    best_models, n_bins=20, save_path=OUTPUT_DIR / 'momentum_dependence.png'
)
plt.show()

## 6. Advanced Physics Analysis: Lorentz Force

The Lorentz force $\vec{F} = q(\vec{v} \times \vec{B})$ dictates how charged particles bend in a magnetic field:
- $\frac{d(tx)}{dz} \propto \frac{q}{p}$ (for vertical B field)
- $\frac{d(ty)}{dz} \approx 0$ (ty conserved)

Let's test whether models correctly capture this physics.

In [None]:
# Separate analyzer for the physics checks, fed the same data file and sample count
physics = PhysicsAnalyzer()
physics.load_data(DATA_PATH, n_samples=N_ANALYSIS_SAMPLES)

# Load only the per-category best models, addressed by their path under MODELS_DIR
for model_name in best_models:
    model_path = MODELS_DIR / model_name
    physics.load_model(model_path)

n_physics_models = len(physics.models)
print(f"Loaded {n_physics_models} models for physics analysis")

In [None]:
# Lorentz Force Analysis for the per-category best models
lorentz_results = physics.analyze_lorentz_force(
    best_models,
    save_path=OUTPUT_DIR / 'lorentz_force.png'
)
plt.show()

print("\nLorentz Force Learning Metrics:")
# The last value cell is 12 characters plus the 4-character "mrad" suffix, so
# its header column is 16 wide to line up with it.
header = f"{'Model':<30} {'Slope Ratio':>12} {'R²':>12} {'Δty Std':>16}"
print(header)
print("-" * len(header))
for name, metrics in lorentz_results.items():
    # dty_std is scaled by 1000 for display in mrad
    # (assumes the analyzer reports it in rad - TODO confirm)
    print(f"{name:<30} {metrics['slope_ratio']:>12.4f} {metrics['dtx_vs_qop_r2']:>12.4f} {metrics['dty_std']*1000:>12.4f}mrad")

In [None]:
# Phase-space analysis for the per-category best models; the return value is not kept.
# save_path is handed to the analyzer (presumably where it writes the figure).
physics.analyze_phase_space(best_models, save_path=OUTPUT_DIR / 'phase_space.png')
plt.show()

In [None]:
# Systematic Error Analysis for the per-category best models
systematic_results = physics.analyze_systematic_errors(
    best_models,
    save_path=OUTPUT_DIR / 'systematic_errors.png'
)
plt.show()

print("\nSystematic Error Analysis:")
# The bias and random cells are 12 characters plus the 2-character "mm"
# suffix, so those header columns are 14 wide to line up with them.
header = f"{'Model':<30} {'X Bias':>14} {'X Random':>14} {'Bias/Random':>12}"
print(header)
print("-" * len(header))
for name, metrics in systematic_results.items():
    print(f"{name:<30} {metrics['dx_bias']:>12.4f}mm {metrics['dx_random']:>12.4f}mm {metrics['bias_to_random_x']:>12.4f}")

## 7. PINN vs MLP Deep Comparison

Let's specifically examine whether physics-informed constraints (ty conservation, charge consistency) actually improve model performance.

In [None]:
# PINN constraint analysis for the per-category best models.
# save_path is handed to the analyzer (presumably where it writes the figure).
pinn_constraint_results = physics.analyze_pinn_constraints(
    best_models, save_path=OUTPUT_DIR / 'pinn_constraints.png'
)
plt.show()

In [None]:
# PINN vs MLP comparison; no model list is passed, so the analyzer chooses
# which models to compare.
comparison = analyzer.compare_pinn_vs_mlp(save_path=OUTPUT_DIR / 'pinn_vs_mlp.png')
plt.show()

## 8. Summary and Conclusions

Let's summarize the key findings from our analysis.

In [None]:
# Generate summary report
print("="*80)
print("TRACK EXTRAPOLATOR MODEL ANALYSIS SUMMARY")
print("="*80)

# Best overall model: first entry of the ranking sorted by mean position error.
# Guarded so an empty ranking produces a message instead of an IndexError.
if sorted_stats:
    best_overall = sorted_stats[0]
    best_name, best_stats = best_overall
    print(f"\n🏆 BEST OVERALL MODEL: {best_name}")
    print(f"   Type: {best_stats['model_type']}")
    print(f"   Parameters: {best_stats['parameters']:,}")
    print(f"   Position Error: {best_stats['pos_mean']:.4f} ± {best_stats['pos_std']:.4f} mm")
    print(f"   Slope Error: {best_stats['slope_mean_mrad']:.4f} ± {best_stats['slope_std_mrad']:.4f} mrad")
else:
    print("\nNo model statistics available - check that models were loaded.")

# Best by category: lowest mean position error among models that have statistics
print("\n📊 BEST BY CATEGORY:")
for category, models in [('MLP', mlp_models), ('ResidualMLP', resmlp_models),
                          ('PINN', pinn_models), ('RK-PINN', rkpinn_models)]:
    stats_list = [(n, all_stats[n]['pos_mean']) for n in models if n in all_stats]
    if stats_list:
        best = min(stats_list, key=lambda x: x[1])
        print(f"   {category}: {best[0]} ({best[1]:.4f} mm)")

# Physics learning assessment.
# Looked up through globals() so the summary still prints when the Lorentz
# cell was skipped, mirroring the tolerant lookups in the save-results cell.
print("\n🔬 PHYSICS LEARNING ASSESSMENT:")
lorentz_summary = globals().get('lorentz_results')
if lorentz_summary:
    # Best Lorentz learning: slope ratio closest to 1
    best_lorentz = min(lorentz_summary.items(), key=lambda x: abs(1 - x[1]['slope_ratio']))
    print(f"   Best Lorentz Force Learning: {best_lorentz[0]} (slope ratio: {best_lorentz[1]['slope_ratio']:.4f})")

    # Best ty conservation: smallest spread of the ty change
    best_ty = min(lorentz_summary.items(), key=lambda x: x[1]['dty_std'])
    print(f"   Best ty Conservation: {best_ty[0]} (σ: {best_ty[1]['dty_std']*1000:.4f} mrad)")

print("\n" + "="*80)

In [None]:
# Save all results to JSON
def _json_default(obj):
    """Convert objects the json module cannot encode natively.

    Array-like objects exposing ``tolist()`` (numpy scalars and arrays, torch
    tensors) become plain Python numbers or nested lists, so integers stay
    integers and multi-element arrays no longer fail. Objects exposing only
    ``item()`` are unwrapped to a Python scalar, and ``Path`` objects become
    strings.

    Raises:
        TypeError: for any other type, as ``json.dump`` expects. Returning
            the object unchanged would instead trigger json's misleading
            "Circular reference detected" error.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    if hasattr(obj, 'item'):
        return obj.item()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Optional result sets are read from the notebook namespace, so a skipped
# analysis cell yields an empty section rather than a NameError.
_namespace = globals()
full_results = {
    'model_statistics': all_stats,
    'ty_conservation': _namespace.get('ty_results', {}),
    'charge_consistency': _namespace.get('charge_results', {}),
    'lorentz_force': _namespace.get('lorentz_results', {}),
    'systematic_errors': _namespace.get('systematic_results', {}),
}

results_path = OUTPUT_DIR / 'full_analysis_results.json'
with open(results_path, 'w') as f:
    json.dump(full_results, f, indent=2, default=_json_default)

print(f"Results saved to {results_path}")

---

## Analysis Complete!

All plots and the JSON results file have been saved to the output directory. Key points to review:

1. **Model Performance**: Review the ranking table to see which architectures perform best
2. **Physics Constraints**: Check whether PINN models better preserve ty conservation and charge consistency
3. **Momentum Dependence**: Low momentum tracks are harder - verify models handle this correctly
4. **Systematic Errors**: Good models should have low bias and uncorrelated residuals

For production use, select the model with the best trade-off between:
- Position accuracy
- Physics consistency
- Model size (for inference speed)