# RT-TDDFT ML Model Results Analysis

This notebook provides tools for analyzing trained RT-TDDFT ML models:

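1. **Training Analysis** - Loss curves, learning rate, curriculum progress
2. **Model and Data Loading** - Load the trained checkpoint and the test trajectories
3. **Model Evaluation** - Metrics on test trajectories
4. **Error Analysis** - Error accumulation over rollout steps
5. **Physics Constraints** - Trace, Hermiticity, idempotency violations
6. **Spectral Analysis** - Absorption spectrum comparison
7. **Density Visualization** - Predicted vs true density matrices
8. **Model Comparison** - Metrics across multiple checkpoints
9. **Export Results** - Save the evaluation summary to JSON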

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import glob
import json

# Set device: use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
# The resulting string is passed to the model/tensor loading helpers further down.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

In [None]:
# Import project modules
from src.utils import (
    # Metrics
    compute_trajectory_metrics,
    compute_step_errors,
    compute_absorption_spectrum,
    spectrum_overlap,
    MetricsAccumulator,
    # Visualization
    plot_training_curves,
    plot_loss_components,
    plot_error_accumulation,
    plot_physics_violations,
    plot_absorption_spectrum,
    plot_density_matrix,
    plot_density_comparison,
    plot_dipole_trajectory,
    plot_metrics_comparison,
    plot_curriculum_progress,
    create_training_dashboard,
    close_all,
)
from src.inference import Predictor, RolloutConfig
from src.data import Trajectory

## Configuration

Set paths to your checkpoint and data:

In [None]:
# Configuration - modify these paths
# (relative paths are resolved against the notebook's working directory)
CHECKPOINT_PATH = "../checkpoints/phase1_h2p/best.pt"  # Path to trained model
DATA_PATH = "../data/processed"  # Path to test data
MOLECULE = "h2"  # Molecule to analyze; matched as a substring of the trajectory file names

# Analysis settings
MAX_TRAJECTORIES = 10  # Number of trajectories to analyze
MAX_ROLLOUT_STEPS = 100  # Maximum rollout steps (shorter trajectories cap this further)
DT = 0.1  # Time step in atomic units

---
## 1. Training Analysis

Load and visualize training history from checkpoint.

In [None]:
def load_training_history(checkpoint_path):
    """Load training history from checkpoint.

    Args:
        checkpoint_path: Location (str or Path) of a checkpoint file written
            with ``torch.save``. The checkpoint is expected to be a dict.

    Returns:
        ``None`` (after printing a message) when the file does not exist.
        Otherwise a dict with the keys ``'train_losses'``, ``'val_losses'``,
        ``'epoch'``, ``'global_step'`` and ``'best_val_loss'`` (filled with
        ``[]``, ``[]``, ``0``, ``0`` and ``None`` when the checkpoint lacks
        them), plus ``'loss_components'`` (from the checkpoint's
        ``'loss_history'``) and ``'lr_history'`` when those were saved.
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        print(f"Checkpoint not found: {checkpoint_path}")
        return None

    # The same checkpoint file may carry a fully pickled model object, which the
    # weights-only unpickler (the default since PyTorch 2.6) refuses to load, so
    # full unpickling is requested explicitly.
    # NOTE(security): full unpickling can execute arbitrary code; only open
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    history = {
        'train_losses': checkpoint.get('train_losses', []),
        'val_losses': checkpoint.get('val_losses', []),
        'epoch': checkpoint.get('epoch', 0),
        'global_step': checkpoint.get('global_step', 0),
        'best_val_loss': checkpoint.get('best_val_loss', None),
    }

    # Optional extras: only present when the training loop recorded them.
    if 'loss_history' in checkpoint:
        history['loss_components'] = checkpoint['loss_history']

    if 'lr_history' in checkpoint:
        history['lr_history'] = checkpoint['lr_history']

    return history

In [None]:
# Read the training history stored in the checkpoint and print a short summary.
history = load_training_history(CHECKPOINT_PATH)

if history:
    # (label, value) pairs, printed one per line in this order.
    summary_rows = (
        ("Epochs trained", history['epoch']),
        ("Global steps", history['global_step']),
        ("Best validation loss", history['best_val_loss']),
        ("Final training loss", history['train_losses'][-1] if history['train_losses'] else 'N/A'),
    )
    for row_label, row_value in summary_rows:
        print(f"{row_label}: {row_value}")

In [None]:
# Draw the train/validation loss curves when the checkpoint recorded any losses.
if not (history and history['train_losses']):
    print("No training history available")
else:
    curve_kwargs = {
        'train_losses': history['train_losses'],
        'val_losses': history.get('val_losses'),
        'title': "Training Progress",
    }
    fig = plot_training_curves(**curve_kwargs)
    plt.show()

In [None]:
# Per-component loss curves; skipped when the checkpoint has no component history.
if history and 'loss_components' in history:
    component_history = history['loss_components']
    fig = plot_loss_components(loss_history=component_history, title="Loss Components Over Training")
    plt.show()

In [None]:
# Combined training dashboard; requires at least the training-loss series.
if history and history['train_losses']:
    # Optional series are looked up with .get so missing ones are passed as None.
    dashboard_inputs = {
        'train_losses': history['train_losses'],
        'val_losses': history.get('val_losses'),
        'loss_components': history.get('loss_components'),
        'lr_history': history.get('lr_history'),
    }
    fig = create_training_dashboard(**dashboard_inputs)
    plt.show()

---
## 2. Load Model and Data

In [None]:
def load_model(checkpoint_path, device='cpu'):
    """Load model from checkpoint.

    Two checkpoint layouts are accepted: a dict holding the complete model
    object under the ``'model'`` key, or the pickled model object itself.

    Args:
        checkpoint_path: Location (str or Path) of a file written with
            ``torch.save``.
        device: Device the tensors are mapped to while loading and the model
            is moved to afterwards.

    Returns:
        The model, moved to ``device`` and switched to eval mode, or ``None``
        (after printing a message) when the file does not exist.

    Raises:
        ValueError: If the checkpoint is a dict without a ``'model'`` entry
            (for example one that only stores ``'model_state_dict'``); the
            architecture cannot be rebuilt from weights alone here.
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        print(f"Checkpoint not found: {checkpoint_path}")
        return None

    # A full model object is expected, which the weights-only unpickler (the
    # default since PyTorch 2.6) refuses to load, so full unpickling is
    # requested explicitly.
    # NOTE(security): full unpickling can execute arbitrary code; only open
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            model = checkpoint['model']
        elif 'model_state_dict' in checkpoint:
            raise ValueError("Checkpoint contains only state_dict. Need full model.")
        else:
            # A dict that is not a model wrapper (e.g. a bare state_dict) would
            # otherwise fail later with an opaque AttributeError on `.to`.
            raise ValueError("Checkpoint dict has no 'model' entry. Need full model.")
    else:
        # The file stores the model object directly. Membership tests such as
        # `'model' in checkpoint` are not supported by a plain nn.Module, so
        # this case must be separated out before looking at keys.
        model = checkpoint

    model = model.to(device)
    model.eval()

    return model

In [None]:
def load_trajectories(data_path, molecule=None, max_trajectories=None):
    """Load trajectory data from HDF5 files.

    Args:
        data_path: Directory that is searched (non-recursively) for ``*.h5``
            and ``*.hdf5`` files.
        molecule: Optional substring a file name must contain to be selected.
            When falsy, every HDF5 file in the directory is selected.
        max_trajectories: Optional cap on the number of files that are opened.
            A falsy value (``None`` or ``0``) means no cap.

    Returns:
        List of the objects returned by ``Trajectory.load`` for each file that
        loaded successfully. Files that fail to load are reported on stdout
        and skipped, so the list can be shorter than the number of files found.
    """
    data_path = Path(data_path)
    trajectories = []

    # Find HDF5 files
    if molecule:
        patterns = [f"*{molecule}*.h5", f"*{molecule}*.hdf5"]
    else:
        patterns = ["*.h5", "*.hdf5"]

    files = []
    for pattern in patterns:
        files.extend(data_path.glob(pattern))

    # glob yields entries in arbitrary, filesystem-dependent order. Sort before
    # truncating so the same subset is analysed on every run and every machine.
    files.sort()

    if max_trajectories:
        files = files[:max_trajectories]

    print(f"Found {len(files)} trajectory files")

    for f in files:
        try:
            traj = Trajectory.load(f)
            trajectories.append(traj)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole analysis.
            print(f"Failed to load {f}: {e}")

    return trajectories

In [None]:
# Load the trained model onto the selected device and report its size.
model = load_model(CHECKPOINT_PATH, device=device)

if model:
    print(f"Model loaded successfully")
    # Materialize the parameter list once and derive both counts from it.
    all_params = list(model.parameters())
    total_params = sum(param.numel() for param in all_params)
    trainable_params = sum(param.numel() for param in all_params if param.requires_grad)
    for count_label, count in (("Total", total_params), ("Trainable", trainable_params)):
        print(f"{count_label} parameters: {count:,}")

In [None]:
# Load test trajectories: HDF5 files under DATA_PATH whose names contain MOLECULE,
# limited to MAX_TRAJECTORIES files. Files that fail to load are skipped by the
# loader, so the count printed here can be lower than the "Found ..." count.
trajectories = load_trajectories(DATA_PATH, molecule=MOLECULE, max_trajectories=MAX_TRAJECTORIES)
print(f"Loaded {len(trajectories)} trajectories")

---
## 3. Model Evaluation

Evaluate model on test trajectories and compute metrics.

In [None]:
def evaluate_trajectory(model, trajectory, max_steps=None, device='cpu'):
    """Roll the model out along one trajectory and pair the result with the reference.

    The rollout starts from the trajectory's first density matrix and is driven
    by the trajectory's field sequence, with physics projection disabled.

    Args:
        model: Model handed to ``Predictor``.
        trajectory: Object exposing ``densities``, ``fields``, ``overlap``,
            ``n_electrons``, ``positions``, ``atomic_numbers`` and, optionally,
            ``dipole_integrals``.
        max_steps: Upper bound on the number of rollout steps. A falsy value
            means "as many as the trajectory allows" (frames minus one).
        device: Device on which every tensor is created.

    Returns:
        dict with ``'pred'`` and ``'true'`` density sequences cut to a common
        length, the ``'overlap'`` tensor, ``'n_electrons'``,
        ``'dipole_integrals'`` (``None`` when the trajectory has none) and the
        matching slice of ``'fields'`` (one entry more than the densities).
    """
    def _on_device(values, dtype):
        # Every array below is converted with an explicit dtype on the target device.
        return torch.tensor(values, device=device, dtype=dtype)

    densities = _on_device(trajectory.densities, torch.complex64)
    fields = _on_device(trajectory.fields, torch.float32)
    overlap = _on_device(trajectory.overlap, torch.complex64)
    n_electrons = trajectory.n_electrons

    geometry = {
        'positions': _on_device(trajectory.positions, torch.float32),
        # Atomic numbers keep whatever dtype torch infers for them.
        'atomic_numbers': torch.tensor(trajectory.atomic_numbers, device=device),
    }

    # Dipole integrals are optional on the trajectory object.
    raw_dipoles = getattr(trajectory, 'dipole_integrals', None)
    if raw_dipoles is not None:
        dipole_integrals = _on_device(raw_dipoles, torch.complex64)
    else:
        dipole_integrals = None

    # The first frame is the initial state, so at most len - 1 steps can be compared.
    available_steps = len(densities) - 1
    n_steps = min(max_steps or available_steps, available_steps)

    predictor = Predictor(model)
    config = RolloutConfig(max_steps=n_steps, apply_physics_projection=False)

    with torch.no_grad():
        result = predictor.rollout(
            initial_density=densities[0],
            geometry=geometry,
            field_sequence=fields[:n_steps + 1],
            overlap=overlap,
            n_electrons=n_electrons,
            config=config,
        )

    # Reference frames follow the initial state; trim both sides to the shorter one.
    reference = densities[1:n_steps + 1]
    predicted = result.densities
    common_len = min(len(predicted), len(reference))

    return {
        'pred': predicted[:common_len],
        'true': reference[:common_len],
        'overlap': overlap,
        'n_electrons': n_electrons,
        'dipole_integrals': dipole_integrals,
        'fields': fields[:common_len + 1],
    }

In [None]:
# Run the rollout on every loaded trajectory and aggregate the metrics.
results = []
accumulator = MetricsAccumulator()

# Attributes of the metrics object that are fed into the accumulator.
tracked_metric_names = (
    'mse',
    'mae',
    'relative_error',
    'max_error',
    'trace_violation',
    'hermiticity_violation',
    'dipole_error',
    'spectrum_overlap',
)

if not (model and trajectories):
    print("Model or trajectories not available")
else:
    for i, traj in enumerate(trajectories):
        # Carriage return keeps the progress message on a single console line.
        print(f"Evaluating trajectory {i+1}/{len(trajectories)}...", end='\r')

        eval_result = evaluate_trajectory(
            model, traj, max_steps=MAX_ROLLOUT_STEPS, device=device
        )

        metrics = compute_trajectory_metrics(
            trajectory_pred=eval_result['pred'],
            trajectory_true=eval_result['true'],
            overlap=eval_result['overlap'],
            n_electrons=eval_result['n_electrons'],
            dipole_integrals=eval_result['dipole_integrals'],
            dt=DT,
        )

        # Keep the per-trajectory metrics next to the raw rollout data.
        eval_result['metrics'] = metrics
        results.append(eval_result)

        accumulator.add(
            {metric_name: getattr(metrics, metric_name) for metric_name in tracked_metric_names}
        )

    print(f"\nEvaluated {len(results)} trajectories")

In [None]:
# Print mean/std/min/max of every accumulated metric.
if results:
    summary = accumulator.compute_summary()

    banner = "=" * 50
    print(banner)
    print("EVALUATION SUMMARY")
    print(banner)

    for metric, stats in summary.items():
        print(f"\n{metric}:")
        for stat_name in ('mean', 'std', 'min', 'max'):
            # The "name:" label is left-aligned in 8 columns so the values line up.
            print(f"  {stat_name + ':':<8}{stats[stat_name]:.6e}")

---
## 4. Error Accumulation Analysis

Analyze how errors accumulate over rollout steps.

In [None]:
# Per-step errors of the first evaluated trajectory.
if results:
    result = results[0]
    step_errors = compute_step_errors(result['pred'], result['true'])

    print(f"Trajectory length: {len(result['pred'])} steps")

    step_mse = step_errors['step_mse']
    first_mse = step_mse[0]
    print(f"Initial step MSE: {first_mse:.6e}")
    last_mse = step_mse[-1]
    print(f"Final step MSE: {last_mse:.6e}")
    # Ratio of the last to the first per-step MSE.
    print(f"Error growth factor: {last_mse / first_mse:.2f}x")

In [None]:
# Visualize per-step and cumulative error for the first trajectory.
if results:
    per_step_relative = step_errors['step_relative_error']
    running_total = step_errors['cumulative_error']
    fig = plot_error_accumulation(
        step_errors=per_step_relative,
        cumulative_errors=running_total,
        dt=DT,
        title="Error Accumulation Over Rollout",
    )
    plt.show()

In [None]:
# Mean and spread of the per-step relative error over all evaluated trajectories.
if len(results) > 1:
    # Rollouts can differ in length; compare only the steps they all share.
    min_len = min(len(r['pred']) for r in results)

    all_step_errors = np.array([
        compute_step_errors(r['pred'][:min_len], r['true'][:min_len])['step_relative_error']
        .cpu()
        .numpy()
        for r in results
    ])
    mean_errors = all_step_errors.mean(axis=0)
    std_errors = all_step_errors.std(axis=0)

    # Mean curve with a one-standard-deviation band around it.
    x = np.arange(min_len) * DT
    band_low = mean_errors - std_errors
    band_high = mean_errors + std_errors

    plt.figure(figsize=(10, 6))
    plt.plot(x, mean_errors, label='Mean Error', color='tab:blue')
    plt.fill_between(x, band_low, band_high, alpha=0.3, color='tab:blue', label='Std Dev')
    plt.xlabel('Time (a.u.)')
    plt.ylabel('Relative Error')
    plt.title('Average Error Accumulation Across Trajectories')
    plt.legend()
    plt.grid(True)
    plt.show()

---
## 5. Physics Constraint Analysis

Analyze how well the model preserves physics constraints.

In [None]:
from src.utils.metrics import trace_violation, hermiticity_violation, idempotency_violation

def compute_physics_violations(pred_trajectory, overlap, n_electrons):
    """Evaluate the trace, Hermiticity and idempotency violations of a predicted trajectory.

    Args:
        pred_trajectory: Predicted density matrices, passed to each metric function.
        overlap: Overlap matrix, used by the trace and idempotency checks.
        n_electrons: Electron count, used by the trace check.

    Returns:
        dict with the keys ``'trace'``, ``'hermiticity'`` and ``'idempotency'``,
        each holding the corresponding metric output as a NumPy array on the CPU.
    """
    as_tensors = {
        'trace': trace_violation(pred_trajectory, overlap, n_electrons),
        'hermiticity': hermiticity_violation(pred_trajectory),
        'idempotency': idempotency_violation(pred_trajectory, overlap),
    }
    # Move every result to host memory so it can be printed and plotted directly.
    return {kind: values.cpu().numpy() for kind, values in as_tensors.items()}

In [None]:
# Constraint violations along the first evaluated trajectory.
if results:
    result = results[0]
    violations = compute_physics_violations(
        result['pred'], result['overlap'], result['n_electrons']
    )

    # (dict key, printed label) for each constraint, in report order.
    report_order = (
        ('trace', 'Trace'),
        ('hermiticity', 'Hermiticity'),
        ('idempotency', 'Idempotency'),
    )
    for viol_key, viol_label in report_order:
        series = violations[viol_key]
        print(f"{viol_label} violation - mean: {series.mean():.6e}, max: {series.max():.6e}")

In [None]:
# Plot the three constraint-violation series computed above.
if results:
    violation_series = {
        'trace_violations': violations['trace'],
        'hermiticity_violations': violations['hermiticity'],
        'idempotency_violations': violations['idempotency'],
    }
    fig = plot_physics_violations(
        **violation_series,
        dt=DT,
        title="Physics Constraint Violations",
    )
    plt.show()

---
## 6. Spectral Analysis

Compare predicted and true absorption spectra.

In [None]:
def compute_dipole_trajectory(densities, dipole_integrals):
    """Contract a density trajectory with the dipole integrals.

    For every time index ``t`` and component ``c`` the result is the real part
    of ``sum_ij densities[t, i, j] * dipole_integrals[c, j, i]``, i.e. the trace
    of the matrix product of the two.

    Args:
        densities: Tensor indexed as ``[t, i, j]``.
        dipole_integrals: Tensor indexed as ``[c, j, i]``.

    Returns:
        Real tensor indexed as ``[t, c]``.
    """
    contracted = torch.einsum("tij,cji->tc", densities, dipole_integrals)
    return contracted.real

In [None]:
# Absorption spectra of the predicted and reference dipole signals (first trajectory).
if not (results and results[0]['dipole_integrals'] is not None):
    print("Dipole integrals not available for spectral analysis")
else:
    result = results[0]
    dipole_matrices = result['dipole_integrals']

    # Dipole moment at every step of the predicted and the reference trajectory.
    dipole_pred = compute_dipole_trajectory(result['pred'], dipole_matrices)
    dipole_true = compute_dipole_trajectory(result['true'], dipole_matrices)

    # Both spectra come from the same DT, so the first frequency axis is reused.
    freqs, spec_pred = compute_absorption_spectrum(dipole_pred, DT)
    _, spec_true = compute_absorption_spectrum(dipole_true, DT)

    # Similarity of the two spectra inside the (0, 20) frequency window.
    overlap_val = spectrum_overlap(spec_pred, spec_true, freqs, freq_range=(0, 20))
    print(f"Spectrum overlap: {overlap_val:.4f}")

In [None]:
# Overlay the predicted and reference spectra; needs the spectra computed above.
if results and results[0]['dipole_integrals'] is not None:
    spectrum_title = f"Absorption Spectrum (Overlap: {overlap_val:.4f})"
    fig = plot_absorption_spectrum(
        freqs=freqs,
        spectrum_pred=spec_pred,
        spectrum_true=spec_true,
        freq_range=(0, 20),
        title=spectrum_title,
    )
    plt.show()

In [None]:
# Time series of the predicted and reference dipole moments.
if results and results[0]['dipole_integrals'] is not None:
    # The plotting helper is given NumPy arrays, so copy the tensors to the host first.
    dipole_pred_np = dipole_pred.cpu().numpy()
    dipole_true_np = dipole_true.cpu().numpy()
    fig = plot_dipole_trajectory(
        dipole_pred=dipole_pred_np,
        dipole_true=dipole_true_np,
        dt=DT,
        title="Dipole Moment Trajectory",
    )
    plt.show()

---
## 7. Density Matrix Visualization

Visualize predicted vs true density matrices.

In [None]:
# Choose the frames of the first trajectory that the following cells visualize.
if results:
    result = results[0]
    n_steps = len(result['pred'])

    # Start, quarter points and the final frame of the rollout.
    timesteps = [quarter * n_steps // 4 for quarter in range(4)]
    timesteps.append(n_steps - 1)

    print(f"Trajectory length: {n_steps}")
    print(f"Visualizing timesteps: {timesteps}")

In [None]:
# Magnitude of the predicted vs reference density matrix at the middle frame.
if results:
    t = n_steps // 2  # Middle of trajectory

    rho_pred_mid = result['pred'][t]
    rho_true_mid = result['true'][t]
    fig = plot_density_comparison(
        rho_pred=rho_pred_mid,
        rho_true=rho_true_mid,
        component='abs',
        title=f"Density Matrix at t={t} (|rho|)",
    )
    plt.show()

In [None]:
# Real (top row) and imaginary (bottom row) parts: predicted, reference, difference.
if results:
    t = n_steps // 2

    pred_mid = result['pred'][t]
    true_mid = result['true'][t]
    diff_mid = pred_mid - true_mid

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    for row, (part, part_label) in enumerate((('real', 'Real'), ('imag', 'Imag'))):
        pred_part = getattr(pred_mid, part).cpu().numpy()
        true_part = getattr(true_mid, part).cpu().numpy()
        diff_part = getattr(diff_mid, part).cpu().numpy()

        # Predicted and reference panels share one symmetric colour range;
        # the difference panel is scaled to its own largest magnitude.
        shared_vmax = max(np.abs(pred_part).max(), np.abs(true_part).max())
        panels = (
            (pred_part, f'Predicted ({part_label})', shared_vmax),
            (true_part, f'True ({part_label})', shared_vmax),
            (diff_part, f'Diff ({part_label})', np.abs(diff_part).max()),
        )

        for ax, (data, title, vmax) in zip(axes[row], panels):
            im = ax.imshow(data, cmap='RdBu_r', vmin=-vmax, vmax=vmax)
            ax.set_title(title)
            plt.colorbar(im, ax=ax)

    plt.suptitle(f'Density Matrix Comparison at t={t}')
    plt.tight_layout()
    plt.show()

In [None]:
# Animate density evolution (create frames)
# Shows |rho| of the prediction, of the reference and their absolute difference
# side by side for a subsampled set of rollout steps of the first trajectory.
if results:
    from matplotlib.animation import FuncAnimation
    from IPython.display import HTML, display

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Get data ranges for consistent colorbar
    vmax_pred = np.abs(result['pred'].cpu().numpy()).max()
    vmax_true = np.abs(result['true'].cpu().numpy()).max()
    vmax = max(vmax_pred, vmax_true)

    # One image artist per panel, created once and updated in place for each frame.
    ims = []
    for ax in axes:
        im = ax.imshow(np.zeros_like(result['pred'][0].abs().cpu().numpy()),
                       cmap='viridis', vmin=0, vmax=vmax)
        ims.append(im)

    axes[0].set_title('Predicted |rho|')
    axes[1].set_title('True |rho|')
    axes[2].set_title('Difference')

    def update(frame):
        """Load rollout step `frame` into the three panels and return the artists."""
        pred = result['pred'][frame].abs().cpu().numpy()
        true = result['true'][frame].abs().cpu().numpy()
        diff = np.abs(pred - true)

        ims[0].set_array(pred)
        ims[1].set_array(true)
        ims[2].set_array(diff)
        fig.suptitle(f't = {frame}')
        return ims

    # Create animation (subsample for speed)
    frames = range(0, len(result['pred']), max(1, len(result['pred']) // 50))
    anim = FuncAnimation(fig, update, frames=frames, interval=100, blit=True)

    plt.close()  # Don't show static figure

    # Jupyter only auto-renders the last *top-level* expression of a cell. This
    # statement is nested inside the `if`, so a bare `HTML(...)` expression would
    # be discarded; display() renders it explicitly.
    display(HTML(anim.to_jshtml()))

---
## 8. Model Comparison

Compare metrics across different models or checkpoints.

In [None]:
# Define checkpoints to compare
# Maps a display name to a checkpoint path. Uncomment or add entries to enable the
# comparison; while the dict is empty the comparison cells below do nothing.
checkpoints_to_compare = {
    # 'Phase 1': '../checkpoints/phase1_h2p/best.pt',
    # 'Phase 2': '../checkpoints/phase2_multi_mol/best.pt',
    # 'Phase 3': '../checkpoints/phase3_generalization/best.pt',
}

if not checkpoints_to_compare:
    print("Add checkpoint paths to compare models")

In [None]:
# Evaluate each checkpoint and collect metrics
# Result: comparison_metrics[name] -> {'mse': mean, 'mae': mean, 'relative_error': mean}
comparison_metrics = {}

for name, ckpt_path in checkpoints_to_compare.items():
    print(f"\nEvaluating {name}...")

    # Use a dedicated variable: rebinding the notebook-wide `model` here would
    # silently swap the primary model for every cell that is re-run afterwards.
    candidate_model = load_model(ckpt_path, device=device)
    if candidate_model is None:
        continue

    acc = MetricsAccumulator()
    n_evaluated = 0

    for traj in trajectories[:5]:  # Use subset for speed
        eval_result = evaluate_trajectory(
            candidate_model, traj, max_steps=MAX_ROLLOUT_STEPS, device=device
        )
        metrics = compute_trajectory_metrics(
            trajectory_pred=eval_result['pred'],
            trajectory_true=eval_result['true'],
            overlap=eval_result['overlap'],
            n_electrons=eval_result['n_electrons'],
        )
        acc.add({
            'mse': metrics.mse,
            'mae': metrics.mae,
            'relative_error': metrics.relative_error,
        })
        n_evaluated += 1

    # Without any evaluated trajectory there is nothing to summarise, and the
    # 'mse' lookup below would fail.
    if n_evaluated == 0:
        print("  No trajectories available; skipping")
        continue

    summary = acc.compute_summary()
    comparison_metrics[name] = {k: v['mean'] for k, v in summary.items()}
    print(f"  MSE: {comparison_metrics[name]['mse']:.6e}")

In [None]:
# Chart of the per-checkpoint mean metrics; skipped when nothing was compared.
if comparison_metrics:
    comparison_title = "Model Comparison"
    fig = plot_metrics_comparison(metrics_dict=comparison_metrics, title=comparison_title)
    plt.show()

---
## 9. Export Results

Save results to JSON for further analysis.

In [None]:
# Write the run configuration and the aggregated metrics to a JSON file.
if results:
    # Settings this evaluation was run with, stored next to the numbers for traceability.
    run_config = {
        'checkpoint': str(CHECKPOINT_PATH),
        'data_path': str(DATA_PATH),
        'molecule': MOLECULE,
        'max_rollout_steps': MAX_ROLLOUT_STEPS,
        'dt': DT,
    }
    export_data = {
        'config': run_config,
        'summary': accumulator.compute_summary(),
        'n_trajectories': len(results),
    }

    # One file per molecule inside ../results; the directory is created on first use.
    output_path = Path('../results') / f'{MOLECULE}_evaluation.json'
    output_path.parent.mkdir(exist_ok=True)

    with open(output_path, 'w') as out_file:
        json.dump(export_data, out_file, indent=2)

    print(f"Results saved to {output_path}")

---
## Cleanup

In [None]:
# Close all figures
# close_all is the helper imported from src.utils; presumably it closes every open
# matplotlib figure (confirm against its implementation).
close_all()
print("Analysis complete!")