# RT-TDDFT Data Exploration

This notebook explores real-time time-dependent density functional theory (RT-TDDFT) trajectory data:

1. **Simulation Index** - Load and explore trajectory metadata
2. **Data Loading** - Load and inspect HDF5 trajectory files
3. **Data Structure** - Understand density matrices, fields, overlaps
4. **Molecular Geometry** - Visualize atomic positions and bonds
5. **Density Matrix Analysis** - Properties, eigenvalues, traces
6. **External Field Analysis** - Field characteristics and spectra
7. **Data Statistics** - Distributions and summary stats
8. **Data Quality Checks** - Validate physics constraints
9. **Multi-Trajectory Analysis** - Compare across trajectories

In [None]:
import glob
import sys
from itertools import combinations
from pathlib import Path
from types import SimpleNamespace

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- side-effect import: registers the '3d' projection on older Matplotlib

# Make the project root importable so the `src.*` imports in the next cell resolve
sys.path.insert(0, '..')

%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

In [None]:
# Import project modules
from src.data import Trajectory, TrajectoryDataset, SimulationIndex, SimulationRecord
from src.utils import (
    build_molecular_graph,
    check_hermiticity,
    check_trace,
    check_idempotency,
    plot_density_matrix,
    close_all,
)

## Configuration

In [None]:
# Data paths - modify as needed
DATA_DIR = Path("../data/processed")
RAW_DIR = Path("../data/raw")
INDEX_PATH = Path("../rt_simulations.pkl")  # Simulation index
MAX_LISTED_FILES = 10  # number of HDF5 file names echoed below

# Check what data is available
print("Processed data directory:", DATA_DIR)
print("Exists:", DATA_DIR.exists())

# `h5_files` is always defined (empty when DATA_DIR is missing) so later cells can test it
# directly. It is sorted because Path.glob yields files in filesystem order, which is not
# guaranteed to be stable, and later cells treat h5_files[0] as "the" trajectory.
h5_files = []
if DATA_DIR.exists():
    h5_files = sorted(list(DATA_DIR.glob("**/*.h5")) + list(DATA_DIR.glob("**/*.hdf5")))
    print(f"Found {len(h5_files)} HDF5 files")
    for f in h5_files[:MAX_LISTED_FILES]:
        print(f"  - {f.name}")
    if len(h5_files) > MAX_LISTED_FILES:
        print(f"  ... and {len(h5_files) - MAX_LISTED_FILES} more")

print("\nSimulation index:", INDEX_PATH)
print("Exists:", INDEX_PATH.exists())

---
## 1. Simulation Index

Explore trajectory metadata using the SimulationIndex class.

In [None]:
# Load the simulation index when present; later index cells are guarded on `sim_index`
sim_index = None
if not INDEX_PATH.exists():
    print("No simulation index found at", INDEX_PATH)
    print("Continuing with direct HDF5 file loading...")
else:
    sim_index = SimulationIndex.load(INDEX_PATH)
    print(f"Loaded simulation index with {len(sim_index)} simulations")
    print(f"\nAvailable molecules: {sim_index.molecules}")
    print(f"Geometries: {sim_index.geometries}")
    print(f"Field types: {sim_index.field_types}")
    print(f"Basis sets: {sim_index.basis_sets}")

In [None]:
# Overview of the index: counts per molecule/field type, then distinct molecule/basis pairs
if sim_index:
    banner = "=" * 60
    print(banner)
    print("SIMULATION INDEX SUMMARY")
    print(banner)

    print("\nSimulations per molecule/field type:")
    summary_table = sim_index.summary()
    print(summary_table)

    print("\nUnique molecule/basis configurations:")
    configs = sim_index.get_unique_configs()
    print(configs.to_string(index=False))

In [None]:
# Filter and explore simulation records
if sim_index:
    divider = "=" * 60
    print(divider)
    print("FILTERING EXAMPLES")
    print(divider)

    # Each subset keeps its own name so it can be explored interactively afterwards
    h2_sims = sim_index.filter(molecule="h2")                                     # by molecule
    eq_sims = sim_index.filter(geometry="eq")                                      # by geometry type
    small_basis = sim_index.filter(max_nbf=20)                                     # by basis-size ceiling
    h2_delta = sim_index.filter(molecule="h2", field_type="delta", geometry="eq")  # combined

    for caption, subset in (
        ("\nH2 simulations", h2_sims),
        ("Equilibrium geometry", eq_sims),
        ("Small basis (nbf <= 20)", small_basis),
        ("H2 + delta + equilibrium", h2_delta),
    ):
        print(f"{caption}: {len(subset)}")

    # Preview the first five records; "..." is printed only if a sixth one exists
    print("\nSample records:")
    for rank, record in enumerate(sim_index):
        if rank == 5:
            print("  ...")
            break
        print(f"  {record.calc_name}: {record.molecule}, nbf={record.nbf}, "
              f"tsteps={record.tsteps}, {record.field_type}_{record.field_polarization}")

---
## 2. Data Loading

Explore how to load trajectory data from HDF5 files.

In [None]:
def inspect_hdf5(filepath):
    """Print the layout of an HDF5 file: every group/dataset, then the root attributes.

    Datasets are listed as ``name: shape dtype`` and groups as ``name/``. Indentation grows
    with the number of ``/`` separators in the item's path.
    """
    def _describe(name, node):
        # Callback for visititems; returning None keeps the traversal going
        pad = "  " * name.count("/")
        if isinstance(node, h5py.Dataset):
            line = f"{pad}{name}: {node.shape} {node.dtype}"
        else:
            line = f"{pad}{name}/"
        print(line)

    with h5py.File(filepath, 'r') as handle:
        print(f"File: {filepath}")
        print("\nStructure:")
        handle.visititems(_describe)

        print("\nAttributes:")
        for attr_name, attr_value in handle.attrs.items():
            print(f"  {attr_name}: {attr_value}")

In [None]:
# Inspect first HDF5 file
if DATA_DIR.exists() and h5_files:
    inspect_hdf5(h5_files[0])
else:
    print("No HDF5 files found. Creating sample data for demonstration...")

    # Seeded generator so Restart & Run All reproduces the same demo data
    SAMPLE_SEED = 0
    rng = np.random.default_rng(SAMPLE_SEED)

    # Create sample trajectory for demo
    n_steps = 100
    n_basis = 6
    n_atoms = 2
    n_electrons = 2
    dt = 0.1  # step of the synthetic propagation (arbitrary units)

    # Sample H2 molecule
    positions = np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]])  # Bohr
    atomic_numbers = np.array([1, 1])

    # Physically consistent densities for an orthonormal basis (S = I).
    # Start from a closed-shell projector rho0 with n_electrons/2 doubly occupied orbitals.
    # Propagate it unitarily: rho(t) = U(t) rho0 U(t)^dagger, with U(t) = exp(-i H t) for a
    # random Hermitian H. Every frame is then Hermitian and positive semidefinite, has
    # Tr(rho) = n_electrons, and keeps occupations {0, 2}, so the physics checks below are
    # meaningful. Rescaling random Hermitian matrices by n_electrons / trace would instead
    # divide by a random, possibly near-zero, trace and give negative eigenvalues.
    h_raw = rng.standard_normal((n_basis, n_basis)) + 1j * rng.standard_normal((n_basis, n_basis))
    hamiltonian = 0.5 * (h_raw + h_raw.conj().T)
    energies, modes = np.linalg.eigh(hamiltonian)

    n_occ = n_electrons // 2
    rho0 = np.diag([2.0] * n_occ + [0.0] * (n_basis - n_occ)).astype(np.complex128)

    times = dt * np.arange(n_steps)
    phases = np.exp(-1j * np.outer(times, energies))  # (n_steps, n_basis)
    # U(t)_ij = sum_k V_ik exp(-i E_k t) conj(V_jk), for all timesteps at once
    propagators = np.einsum('ik,tk,jk->tij', modes, phases, modes.conj())
    densities = propagators @ rho0 @ propagators.conj().transpose(0, 2, 1)
    # Symmetrize so every frame is exactly Hermitian despite round-off in the products
    densities = 0.5 * (densities + densities.conj().transpose(0, 2, 1))

    # Random fields
    fields = 0.01 * rng.standard_normal((n_steps, 3))

    # Identity overlap (orthonormal basis)
    overlap = np.eye(n_basis, dtype=np.complex128)

    sample_data = {
        'densities': densities,
        'fields': fields,
        'overlap': overlap,
        'positions': positions,
        'atomic_numbers': atomic_numbers,
        'n_electrons': n_electrons,
        'n_basis': n_basis,
    }

    print("Created sample data with:")
    for key, val in sample_data.items():
        if isinstance(val, np.ndarray):
            print(f"  {key}: {val.shape} {val.dtype}")
        else:
            print(f"  {key}: {val}")

In [None]:
# Load trajectory using Trajectory class
def load_trajectory_data(filepath=None, sample=None):
    """Load a trajectory from ``filepath``, falling back to in-memory sample data.

    Args:
        filepath: Path to an HDF5 trajectory file; used only if it exists on disk.
        sample: Optional dict of trajectory fields (e.g. ``sample_data``). When omitted,
            the notebook-level ``sample_data`` global is used if it has been defined.

    Returns:
        A ``Trajectory`` for a real file, a ``SimpleNamespace`` exposing the sample
        dict's keys as attributes, or ``None`` when neither source is available.
    """
    if filepath and Path(filepath).exists():
        return Trajectory.load(filepath)
    # A bare dir() inside a function lists only the function's locals, so it can never
    # see the notebook's `sample_data`; look it up in module globals instead.
    if sample is None:
        sample = globals().get('sample_data')
    if sample is not None:
        return SimpleNamespace(**sample)
    return None


# Try to load real data, fall back to sample
first_h5 = h5_files[0] if DATA_DIR.exists() and h5_files else None
traj = load_trajectory_data(first_h5)
if first_h5 is not None:
    print(f"Loaded trajectory from {first_h5.name}")
elif traj is not None:
    print("Using sample trajectory data")
else:
    print("No data available")

---
## 3. Data Structure

Understand the components of RT-TDDFT trajectory data.

In [None]:
# Pull each trajectory component into a NumPy array, then describe its layout
if traj:
    densities = np.array(traj.densities)
    fields = np.array(traj.fields)
    overlap = np.array(traj.overlap)
    positions = np.array(traj.positions)
    atomic_numbers = np.array(traj.atomic_numbers)
    n_electrons = traj.n_electrons

    divider = "=" * 50
    print(divider)
    print("TRAJECTORY DATA STRUCTURE")
    print(divider)

    # Axis 0 is reported as the timestep count, axis 1 as the basis size
    print("\nDensity matrices:")
    print(f"  Shape: {densities.shape}")
    print(f"  Dtype: {densities.dtype}")
    print(f"  n_timesteps: {densities.shape[0]}")
    print(f"  n_basis: {densities.shape[1]}")

    print("\nExternal fields:")
    print(f"  Shape: {fields.shape}")
    print("  Components: (E_x, E_y, E_z)")

    print("\nOverlap matrix:")
    print(f"  Shape: {overlap.shape}")
    print(f"  Dtype: {overlap.dtype}")

    print("\nMolecular geometry:")
    print(f"  n_atoms: {len(atomic_numbers)}")
    print(f"  Atomic numbers: {atomic_numbers}")
    print(f"  Positions shape: {positions.shape}")

    print(f"\nElectrons: {n_electrons}")

In [None]:
# Element symbols lookup
ELEMENT_SYMBOLS = {
    1: 'H', 2: 'He', 3: 'Li', 4: 'Be', 5: 'B', 6: 'C', 7: 'N', 8: 'O', 9: 'F', 10: 'Ne',
    11: 'Na', 12: 'Mg', 13: 'Al', 14: 'Si', 15: 'P', 16: 'S', 17: 'Cl', 18: 'Ar',
}


def hill_formula(atomic_numbers):
    """Return the Hill-system molecular formula, e.g. ``[1, 1] -> 'H2'``, ``[8, 1, 1] -> 'H2O'``.

    With carbon present, C comes first, then H, then the remaining elements alphabetically.
    Without carbon, all elements are alphabetical. A count of 1 is omitted. Atomic numbers
    missing from ELEMENT_SYMBOLS are written as ``X{Z}``.
    """
    counts = {}
    for z in atomic_numbers:
        symbol = ELEMENT_SYMBOLS.get(int(z), f'X{int(z)}')
        counts[symbol] = counts.get(symbol, 0) + 1

    if 'C' in counts:
        leading = ['C'] + (['H'] if 'H' in counts else [])
        order = leading + sorted(s for s in counts if s not in ('C', 'H'))
    else:
        order = sorted(counts)
    return ''.join(s + (str(counts[s]) if counts[s] > 1 else '') for s in order)


if traj:
    print("\nMolecular formula:")
    symbols = [ELEMENT_SYMBOLS.get(z, f'X{z}') for z in atomic_numbers]
    # A real formula (e.g. "H2") rather than the per-atom concatenation ("HH")
    formula = hill_formula(atomic_numbers)
    print(f"  {formula}")

    print("\nAtomic positions (Bohr):")
    for i, (sym, pos) in enumerate(zip(symbols, positions)):
        print(f"  {sym}{i+1}: ({pos[0]:8.4f}, {pos[1]:8.4f}, {pos[2]:8.4f})")

---
## 4. Molecular Geometry Visualization

In [None]:
def plot_molecule_3d(positions, atomic_numbers, title="Molecular Geometry",
                     bond_threshold=3.5, padding=1.0):
    """Plot a 3D ball-and-stick view of a molecule.

    Args:
        positions: Array-like of shape (n_atoms, 3), Cartesian coordinates in Bohr.
        atomic_numbers: Array-like of length n_atoms; selects colour, marker size and label.
        title: Axes title.
        bond_threshold: Atom pairs closer than this distance (Bohr) are drawn as bonds.
        padding: Margin (Bohr) added on each side of the cubic plotting box.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    # Element colors
    COLORS = {
        1: 'white', 6: 'gray', 7: 'blue', 8: 'red', 9: 'green',
        3: 'purple', 11: 'purple', 16: 'yellow', 17: 'green',
    }
    # Element sizes (van der Waals radii, scaled)
    SIZES = {
        1: 100, 6: 200, 7: 180, 8: 170, 9: 160,
        3: 250, 11: 300, 16: 220, 17: 200,
    }

    # Accept lists or tensors-as-arrays as well as ndarrays
    positions = np.asarray(positions, dtype=float)
    atomic_numbers = np.asarray(atomic_numbers)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Plot atoms, each labelled slightly above its marker
    for pos, z in zip(positions, atomic_numbers):
        z = int(z)
        ax.scatter(*pos, c=COLORS.get(z, 'pink'), s=SIZES.get(z, 150),
                   edgecolors='black', linewidth=1)
        ax.text(pos[0], pos[1], pos[2] + 0.3,
                ELEMENT_SYMBOLS.get(z, f'X{z}'), fontsize=12, ha='center')

    # Draw bonds (simple distance criterion)
    n_atoms = len(positions)
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            if np.linalg.norm(positions[i] - positions[j]) < bond_threshold:
                # zip pairs the two endpoints into (xs, ys, zs)
                ax.plot(*zip(positions[i], positions[j]), 'k-', linewidth=2)

    ax.set_xlabel('X (Bohr)')
    ax.set_ylabel('Y (Bohr)')
    ax.set_zlabel('Z (Bohr)')
    ax.set_title(title)

    # Cubic box (same half-width on every axis) so the geometry is not visually stretched
    half_width = np.ptp(positions, axis=0).max() / 2 + padding
    center = positions.mean(axis=0)
    ax.set_xlim(center[0] - half_width, center[0] + half_width)
    ax.set_ylim(center[1] - half_width, center[1] + half_width)
    ax.set_zlim(center[2] - half_width, center[2] + half_width)

    return fig

In [None]:
# 3D ball-and-stick view of the loaded molecule
if traj:
    fig = plot_molecule_3d(
        positions,
        atomic_numbers,
        title=f"Molecule: {formula}",
    )
    plt.show()

In [None]:
# Compute interatomic distances
BOHR_TO_ANGSTROM = 0.529177210903  # Bohr radius in Angstrom (CODATA 2018)

if traj:
    print("Interatomic distances (Bohr):")
    n_atoms = len(positions)
    # Same "X{Z}" fallback for unknown elements as the formula/positions cell
    labels = [f"{ELEMENT_SYMBOLS.get(z, f'X{z}')}{k + 1}" for k, z in enumerate(atomic_numbers)]
    for i, j in combinations(range(n_atoms), 2):
        dist = np.linalg.norm(positions[i] - positions[j])
        print(f"  {labels[i]}-{labels[j]}: {dist:.4f} Bohr ({dist * BOHR_TO_ANGSTROM:.4f} Angstrom)")

In [None]:
# Convert the geometry to tensors and connect atoms within cutoff=5.0
# (same length units as `positions`, which earlier cells label as Bohr)
if traj:
    pos_tensor, z_tensor = torch.tensor(positions, dtype=torch.float32), torch.tensor(atomic_numbers)
    graph = build_molecular_graph(pos_tensor, z_tensor, cutoff=5.0)

    graph_report = [
        "\nMolecular Graph:",
        f"  Nodes: {graph.num_nodes}",
        f"  Edges: {graph.num_edges}",
        f"  Edge index shape: {graph.edge_index.shape}",
    ]
    print("\n".join(graph_report))

---
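## 5. Density Matrix Analysis

Density matrices are expressed in a basis with overlap matrix $S$ (the identity for an orthonormal basis), so the electron count is $\mathrm{Tr}(\rho S) = N_e$. The eigenvalues of $\rho$ itself equal orbital occupation numbers only when $S = I$; in general the occupations are the eigenvalues of $S^{1/2} \rho S^{1/2}$.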

In [None]:
if traj:
    # Analyze density matrix properties at a few representative timesteps
    print("=" * 50)
    print("DENSITY MATRIX ANALYSIS")
    print("=" * 50)

    # Symmetric square root of the overlap, S^(1/2) = V diag(sqrt(s)) V^dagger.
    # Eigenvalues of S^(1/2) rho S^(1/2) are the orbital occupations for any positive-definite
    # S; the plain eigenvalues of rho coincide with them only when S = I. Tiny negative
    # eigenvalues from round-off are clipped so the square root stays real.
    s_vals, s_vecs = np.linalg.eigh(overlap)
    s_half = (s_vecs * np.sqrt(np.clip(s_vals, 0.0, None))) @ s_vecs.conj().T

    # First, middle and last timestep (a set, so short trajectories do not repeat a step)
    t_samples = sorted({0, len(densities) // 2, len(densities) - 1})

    for t in t_samples:
        rho = densities[t]
        print(f"\nTimestep {t}:")

        # Trace
        trace = np.trace(rho @ overlap).real
        print(f"  Trace(rho*S): {trace:.6f} (expected: {n_electrons})")

        # Hermiticity
        herm_error = np.abs(rho - rho.conj().T).max()
        print(f"  Hermiticity error: {herm_error:.2e}")

        # Eigenvalues of rho itself
        eigvals = np.linalg.eigvalsh(rho)
        print(f"  Eigenvalues: min={eigvals.min():.4f}, max={eigvals.max():.4f}")

        # Occupation numbers, valid in a non-orthogonal basis too
        occupations = np.linalg.eigvalsh(s_half @ rho @ s_half)
        print(f"  Occupations: min={occupations.min():.4f}, max={occupations.max():.4f}")

        # Frobenius norm ||rho||_F = sqrt(sum_ij |rho_ij|^2)
        frob_norm = np.linalg.norm(rho, 'fro')
        print(f"  Frobenius norm: {frob_norm:.4f}")

In [None]:
# Visualize initial density matrix
if traj:
    rho_0 = np.asarray(densities[0])

    # (data, title, colormap, centre the colour scale on zero?)
    panels = [
        (rho_0.real, 'Real Part', 'RdBu_r', True),
        (rho_0.imag, 'Imaginary Part', 'RdBu_r', True),
        (np.abs(rho_0), 'Magnitude', 'viridis', False),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (data, panel_title, cmap, symmetric) in zip(axes, panels):
        if symmetric:
            # A diverging colormap is only readable when white means zero. `or 1.0` guards
            # an all-zero panel (e.g. Im rho of a real initial state).
            vmax = np.abs(data).max() or 1.0
            im = ax.imshow(data, cmap=cmap, vmin=-vmax, vmax=vmax)
        else:
            im = ax.imshow(data, cmap=cmap)
        ax.set_title(panel_title)
        plt.colorbar(im, ax=ax)

    plt.suptitle('Initial Density Matrix (t=0)')
    plt.tight_layout()
    plt.show()

In [None]:
# Plot eigenvalue evolution
if traj:
    n_eigvals = min(densities.shape[1], 10)
    # eigvalsh accepts the whole (n_steps, n_basis, n_basis) stack and returns each
    # timestep's eigenvalues in ascending order; flip to descending and keep the largest.
    # NOTE: these are eigenvalues of rho itself, i.e. occupations only when S = I.
    eigval_history = np.linalg.eigvalsh(densities)[:, ::-1][:, :n_eigvals]

    fig, ax = plt.subplots(figsize=(12, 6))
    for i in range(n_eigvals):
        ax.plot(eigval_history[:, i], label=rf'$\lambda_{{{i + 1}}}$', alpha=0.7)

    ax.set_xlabel('Timestep')
    ax.set_ylabel('Eigenvalue')
    ax.set_title('Density Matrix Eigenvalue Evolution')
    ax.legend(loc='upper right')
    ax.grid(True)
    plt.show()

In [None]:
# Plot trace evolution
if traj:
    # Tr(rho(t) S) for all timesteps in one contraction: sum_ij rho_t[i, j] * S[j, i]
    traces = np.einsum('tij,ji->t', densities, overlap).real
    trace_deviation = np.abs(traces - n_electrons)

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(traces)
    ax.axhline(y=n_electrons, color='r', linestyle='--', label=f'Expected: {n_electrons}')
    ax.set_xlabel('Timestep')
    ax.set_ylabel('Tr(ρS)')
    ax.set_title('Trace Conservation')
    ax.legend()
    ax.grid(True)
    plt.show()

    print(f"Trace deviation: mean={trace_deviation.mean():.2e}, "
          f"max={trace_deviation.max():.2e}")

---
## 6. External Field Analysis

In [None]:
if traj:
    print("=" * 50)
    print("EXTERNAL FIELD ANALYSIS")
    print("=" * 50)

    print(f"\nField shape: {fields.shape}")
    print("\nField statistics:")
    # One summary line per Cartesian component (columns 0, 1, 2 of `fields`)
    for axis_idx, label in enumerate(("E_x", "E_y", "E_z")):
        component = fields[:, axis_idx]
        print(f"  {label}: mean={component.mean():.4e}, std={component.std():.4e}, "
              f"range=[{component.min():.4e}, {component.max():.4e}]")

    # Total field magnitude: Euclidean norm across components at each timestep
    field_mag = np.linalg.norm(fields, axis=1)
    print("\nTotal magnitude:")
    print(f"  mean={field_mag.mean():.4e}, max={field_mag.max():.4e}")

In [None]:
# Plot field components
if traj:
    component_labels = ('E_x', 'E_y', 'E_z')
    component_colors = ('tab:blue', 'tab:orange', 'tab:green')

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    ax_comp, ax_mag, ax_hist, ax_fft = axes.ravel()

    # Individual components
    for k, (label, color) in enumerate(zip(component_labels, component_colors)):
        ax_comp.plot(fields[:, k], label=label, color=color)
    ax_comp.set_xlabel('Timestep')
    ax_comp.set_ylabel('Field (a.u.)')
    ax_comp.set_title('Field Components')
    ax_comp.legend()
    ax_comp.grid(True)

    # Field magnitude
    ax_mag.plot(field_mag, color='tab:purple')
    ax_mag.set_xlabel('Timestep')
    ax_mag.set_ylabel('|E| (a.u.)')
    ax_mag.set_title('Field Magnitude')
    ax_mag.grid(True)

    # Field histogram
    ax_hist.hist(fields.flatten(), bins=50, edgecolor='black', alpha=0.7)
    ax_hist.set_xlabel('Field value (a.u.)')
    ax_hist.set_ylabel('Count')
    ax_hist.set_title('Field Distribution')

    # Spectrum of the strongest component. Field polarization differs between simulations,
    # so always transforming E_x gives an empty spectrum for runs polarized along y or z.
    # No timestep length is available here, so frequency is in cycles per timestep (d=1).
    k_dom = int(np.argmax(np.abs(fields).max(axis=0)))
    fft_field = np.abs(np.fft.rfft(fields[:, k_dom]))
    freqs = np.fft.rfftfreq(len(fields), d=1)
    n_show = max(len(freqs) // 4, 1)  # low-frequency quarter, at least one bin
    ax_fft.plot(freqs[:n_show], fft_field[:n_show])
    ax_fft.set_xlabel('Frequency (1/timestep)')
    ax_fft.set_ylabel('Amplitude')
    ax_fft.set_title(f'Field Spectrum ({component_labels[k_dom]})')
    ax_fft.grid(True)

    plt.tight_layout()
    plt.show()

---
## 7. Data Statistics

In [None]:
if traj:
    print("=" * 50)
    print("DENSITY MATRIX STATISTICS")
    print("=" * 50)

    # Component arrays are kept as globals; later cells plot and summarize them
    real_parts, imag_parts = densities.real, densities.imag
    magnitudes = np.abs(densities)

    # (heading, array, reductions to report) -- the magnitude block omits "min"
    summaries = (
        ("Real part", real_parts, ("mean", "std", "min", "max")),
        ("Imaginary part", imag_parts, ("mean", "std", "min", "max")),
        ("Magnitude", magnitudes, ("mean", "std", "max")),
    )
    for heading, values, reductions in summaries:
        print(f"\n{heading}:")
        for reduction in reductions:
            result = getattr(values, reduction)()
            print(f"  {reduction + ':':<6}{result:.4e}")

In [None]:
# Distribution plots
if traj:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    histograms = (
        (real_parts, 'Real Part Distribution'),
        (imag_parts, 'Imaginary Part Distribution'),
        (magnitudes, 'Magnitude Distribution'),
    )
    for ax, (values, panel_title) in zip(axes, histograms):
        ax.hist(values.flatten(), bins=100, edgecolor='black', alpha=0.7)
        ax.set_xlabel('Value')
        ax.set_ylabel('Count')
        ax.set_title(panel_title)
        ax.set_yscale('log')  # log counts so sparsely populated bins stay visible

    plt.tight_layout()
    plt.show()

In [None]:
# Element-wise variance over time
if traj:
    variance_map = np.var(np.abs(densities), axis=0)

    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(variance_map, cmap='hot')
    fig.colorbar(im, ax=ax, label='Variance')
    ax.set_xlabel('Basis Index')
    ax.set_ylabel('Basis Index')
    ax.set_title('Element-wise Variance Over Trajectory')
    plt.show()

    # |rho_ij| == |rho_ji| for a Hermitian density, so the map is symmetric. Rank only the
    # upper triangle (i <= j) so each off-diagonal element is listed once, not twice.
    rows, cols = np.triu_indices(variance_map.shape[0])
    upper_vals = variance_map[rows, cols]
    top = np.argsort(upper_vals)[::-1][:5]

    print("Most variable elements (indices):")
    for idx in top:
        i, j = rows[idx], cols[idx]
        print(f"  [{i}, {j}]: variance = {variance_map[i, j]:.4e}")

---
## 8. Data Quality Checks

In [None]:
if traj:
    print("=" * 50)
    print("DATA QUALITY CHECKS")
    print("=" * 50)

    # Torch copies for the project's check_* helpers
    rho_tensor = torch.tensor(densities, dtype=torch.complex64)
    S_tensor = torch.tensor(overlap, dtype=torch.complex64)

    # 1. Hermiticity: one violation value per timestep
    herm_violations = np.array([check_hermiticity(rho_t).item() for rho_t in rho_tensor])
    herm_max = herm_violations.max()
    print("\n1. Hermiticity Check (||ρ - ρ†||):")
    print(f"   mean: {herm_violations.mean():.2e}")
    print(f"   max:  {herm_max:.2e}")
    print("   PASS" if herm_max < 1e-6 else f"   WARN: max violation = {herm_max:.2e}")

    # 2. Trace conservation against the expected electron count
    trace_violations = np.array([
        check_trace(rho_t, S_tensor, n_electrons).item() for rho_t in rho_tensor
    ])
    trace_max = trace_violations.max()
    print("\n2. Trace Conservation Check (|Tr(ρS) - N_e|):")
    print(f"   mean: {trace_violations.mean():.2e}")
    print(f"   max:  {trace_max:.2e}")
    print("   PASS" if trace_max < 0.01 else f"   WARN: max violation = {trace_max:.2e}")

    # 3. NaN / Inf anywhere in the trajectory
    has_nan = np.isnan(densities).any()
    has_inf = np.isinf(densities).any()
    numerically_ok = not (has_nan or has_inf)
    print("\n3. Numerical Stability:")
    print(f"   Contains NaN: {has_nan}")
    print(f"   Contains Inf: {has_inf}")
    print("   PASS" if numerically_ok else "   FAIL")

    # 4. Overlap matrix: Hermitian with positive eigenvalues
    S_herm = np.abs(overlap - overlap.conj().T).max()
    S_eigvals = np.linalg.eigvalsh(overlap)
    print("\n4. Overlap Matrix Check:")
    print(f"   Hermiticity error: {S_herm:.2e}")
    print(f"   Eigenvalues: [{S_eigvals.min():.4f}, {S_eigvals.max():.4f}]")
    print(f"   Positive definite: {S_eigvals.min() > 0}")

In [None]:
# Plot quality metrics over time
if traj:
    fig, (ax_herm, ax_trace) = plt.subplots(1, 2, figsize=(12, 4))

    ax_herm.plot(herm_violations)
    ax_herm.set_xlabel('Timestep')
    ax_herm.set_ylabel('Hermiticity Violation')
    ax_herm.set_title('Hermiticity Over Time')
    # A log axis has nothing to draw when every violation is exactly 0 (perfectly
    # Hermitian data); fall back to a linear axis in that case.
    has_positive = np.any(np.asarray(herm_violations) > 0)
    ax_herm.set_yscale('log' if has_positive else 'linear')
    ax_herm.grid(True)

    ax_trace.plot(trace_violations)
    ax_trace.axhline(y=0, color='r', linestyle='--')
    ax_trace.set_xlabel('Timestep')
    ax_trace.set_ylabel('Trace Violation')
    ax_trace.set_title('Trace Conservation Over Time')
    ax_trace.grid(True)

    plt.tight_layout()
    plt.show()

---
## 9. Multi-Trajectory Analysis

Compare statistics across multiple trajectories.

In [None]:
# Load multiple trajectories if available
MAX_COMPARE = 10  # upper bound on trajectories loaded for the comparison

if DATA_DIR.exists() and len(h5_files) > 1:
    to_compare = h5_files[:MAX_COMPARE]
    print(f"Loading {len(to_compare)} trajectories for comparison...")

    all_stats = []
    for path in to_compare:
        # Best-effort: a corrupt or incompatible file is reported and skipped so that one
        # bad file does not abort the whole comparison.
        try:
            other = Trajectory.load(path)
            dens = np.asarray(other.densities)
            dens_abs = np.abs(dens)
            all_stats.append({
                'file': path.name,
                'n_steps': len(dens),
                'n_basis': dens.shape[1],
                'n_electrons': other.n_electrons,
                'mean_magnitude': dens_abs.mean(),
                'max_magnitude': dens_abs.max(),
                # np.asarray so tensor-valued fields yield a plain NumPy scalar
                'field_max': np.abs(np.asarray(other.fields)).max(),
            })
        except Exception as e:
            print(f"  Failed to load {path.name}: {e}")

    print(f"\nLoaded {len(all_stats)} trajectories")
else:
    all_stats = None
    print("Multiple trajectories not available")

In [None]:
# Display comparison table
if all_stats:
    rule = "-" * 90
    header = (f"{'File':<30} {'Steps':>8} {'Basis':>6} {'N_e':>4} "
              f"{'Mean |ρ|':>12} {'Max |E|':>12}")
    # `.28` truncates long file names to 28 characters inside the 30-wide column
    row_template = ("{file:<30.28} {n_steps:>8} {n_basis:>6} {n_electrons:>4} "
                    "{mean_magnitude:>12.4e} {field_max:>12.4e}")

    print("\nTrajectory Comparison:")
    print(rule)
    print(header)
    print(rule)
    for stats_row in all_stats:
        print(row_template.format(**stats_row))

In [None]:
# Plot comparison: one bar chart per summary statistic
if all_stats and len(all_stats) > 1:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    bar_x = range(len(all_stats))

    # (stats key, y-axis label, panel title)
    bar_panels = (
        ('n_steps', 'Steps', 'Trajectory Lengths'),
        ('n_basis', 'Basis Size', 'Basis Sizes'),
        ('field_max', 'Max |E|', 'Maximum Field Strengths'),
    )
    for ax, (key, ylabel, panel_title) in zip(axes, bar_panels):
        ax.bar(bar_x, [entry[key] for entry in all_stats])
        ax.set_xlabel('Trajectory')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)

    plt.tight_layout()
    plt.show()

---
## Summary

In [None]:
if traj:
    # Same thresholds as the data-quality cell
    hermiticity_status = 'PASS' if herm_violations.max() < 1e-6 else 'CHECK'
    trace_status = 'PASS' if trace_violations.max() < 0.01 else 'CHECK'
    numerical_status = 'PASS' if not (has_nan or has_inf) else 'FAIL'

    report = [
        "=" * 50,
        "DATA EXPLORATION SUMMARY",
        "=" * 50,
        "\nTrajectory:",
        f"  Timesteps: {len(densities)}",
        f"  Basis size: {densities.shape[1]}",
        f"  Electrons: {n_electrons}",
        f"  Atoms: {len(atomic_numbers)}",
        "\nDensity Matrix:",
        f"  Real range: [{real_parts.min():.4f}, {real_parts.max():.4f}]",
        f"  Imag range: [{imag_parts.min():.4f}, {imag_parts.max():.4f}]",
        "\nExternal Field:",
        f"  Max magnitude: {field_mag.max():.4e} a.u.",
        "\nData Quality:",
        f"  Hermiticity: {hermiticity_status}",
        f"  Trace: {trace_status}",
        f"  Numerical: {numerical_status}",
    ]
    print("\n".join(report))

In [None]:
# Cleanup: release resources via the project's close_all helper, then report completion
close_all()
completion_message = "\nExploration complete!"
print(completion_message)