# Physics Models and Cardiac Simulation

This notebook provides a comprehensive exploration of cardiac hemodynamics modeling using enhanced Windkessel models. You'll learn about the physics behind cardiac function and how to simulate various clinical conditions.

## 🫀 Cardiac Hemodynamics Fundamentals

### The Windkessel Model

The Windkessel model represents the cardiovascular system as an electrical circuit analog:

- **Resistances (R)**: Resistance to blood flow through vessels and valves (e.g. mitral Rm, aortic Ra)
- **Capacitances (C)**: Arterial compliance and elasticity
- **Inductances (L)**: Blood inertia effects
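- **Elastance (E)**: Time-varying ventricular stiffness, defined through P_lv(t) = E(t) · (V_lv(t) - Vd), where Vd is the dead volume; its peak value Emax indexes contractility (E(t) acts as the inverse of a time-varying capacitance)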
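The notebook does not show how E(t) is shaped over a beat. A widely used choice in Windkessel-ventricle models is a normalized "double-Hill" curve. The sketch below assumes that form, with constants 0.7, 1.9, 1.17, 21.9 and 1.55 and a time-to-peak of Tmax = 0.2 + 0.15·Tc. It is not necessarily what `physics.windkessel` implements, and the helper names are hypothetical.

```python
import numpy as np


def normalized_elastance(t, Tc):
    """Double-Hill normalized elastance E_n, peaking close to 1 at t = Tmax.

    Assumed form (not taken from physics.windkessel): Tmax = 0.2 + 0.15 * Tc.
    """
    t_in_cycle = np.mod(t, Tc)        # restart the curve every heartbeat
    t_max = 0.2 + 0.15 * Tc           # assumed time to peak elastance (s)
    tn = t_in_cycle / t_max
    contraction = (tn / 0.7) ** 1.9 / (1.0 + (tn / 0.7) ** 1.9)
    relaxation = 1.0 / (1.0 + (tn / 1.17) ** 21.9)
    return 1.55 * contraction * relaxation


def elastance(t, Emax, Emin, Tc):
    """E(t) = (Emax - Emin) * E_n(t) + Emin, in mmHg/ml."""
    return (Emax - Emin) * normalized_elastance(t, Tc) + Emin


def lv_pressure(V_lv, t, Emax, Emin, Tc, Vd):
    """LV pressure from the elastance relation P_lv = E(t) * (V_lv - Vd)."""
    return elastance(t, Emax, Emin, Tc) * (V_lv - Vd)


t = np.linspace(0.0, 1.0, 1001)
E = elastance(t, Emax=2.0, Emin=0.03, Tc=1.0)
# E starts at Emin and rises close to Emax near t = Tmax (0.35 s for Tc = 1 s)
print(f"E(0) = {E[0]:.3f}, peak E = {E.max():.3f} at t = {t[E.argmax()]:.3f} s")
```

With the "Healthy" values used later (Emax = 2.0, Emin = 0.03, Tc = 1.0), the curve stays near Emin during diastole and peaks just below Emax during systole.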

### Enhanced Model Features

Our enhanced implementation includes:
- ✅ **Parameter validation** and error handling
- ✅ **Adaptive ODE solvers** for numerical stability
- ✅ **Clinical metrics** calculation (EF, stroke volume, etc.)
- ✅ **Batch simulation** capabilities
- ✅ **Uncertainty quantification** methods
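The following standalone sketch makes the circuit analogy concrete and shows why an adaptive solver needs some care. It is a minimal two-element Windkessel, with one resistor R and one capacitor C, driven by an assumed half-sine ejection flow. All values are illustrative. This is not the notebook's enhanced model, which also includes valves, inertance and a time-varying ventricle.

```python
import numpy as np
from scipy.integrate import solve_ivp

R = 1.0    # peripheral resistance (mmHg·s/ml), illustrative
C = 1.33   # arterial compliance (ml/mmHg), illustrative
Tc = 1.0   # cardiac cycle length (s)
Ts = 0.3   # ejection duration (s), assumed
SV = 70.0  # stroke volume (ml), assumed

# Peak of the half-sine inflow chosen so that each beat ejects exactly SV
Q0 = SV * np.pi / (2 * Ts)


def inflow(t):
    """Half-sine aortic inflow during ejection, zero during diastole (ml/s)."""
    tc = t % Tc
    return Q0 * np.sin(np.pi * tc / Ts) if tc < Ts else 0.0


def rhs(t, y):
    # Current balance at the capacitor: C dP/dt = Q_in - P / R
    P = y[0]
    return [(inflow(t) - P / R) / C]


n_cycles = 10
# max_step stops the adaptive RK45 solver from stepping over the short pulse
sol = solve_ivp(rhs, (0.0, n_cycles * Tc), [80.0],
                method="RK45", rtol=1e-6, atol=1e-8, max_step=0.005,
                dense_output=True)

# Sample the last beat, after the start-up transient (time constant R*C) has decayed
t_last = np.linspace((n_cycles - 1) * Tc, n_cycles * Tc, 1001)
P_last = sol.sol(t_last)[0]
print(f"systolic {P_last.max():.1f}, diastolic {P_last.min():.1f}, "
      f"mean {P_last.mean():.1f} mmHg")
```

Capping `max_step` matters because the forcing is discontinuous at the end of ejection. Without the cap, a step-size controller can grow its steps during diastole and miss part of the next pulse.

At periodic steady state, integrating C dP/dt over one beat gives zero. The mean pressure therefore equals R times the mean flow, which is R·SV/Tc = 70 mmHg here.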

In [None]:
# Import necessary libraries
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
from scipy import stats
import pandas as pd
from tqdm import tqdm

# Add src to path
project_root = Path.cwd().parent
src_path = project_root / 'src'
sys.path.append(str(src_path))

# Import enhanced physics modules
from physics.windkessel import (
    WindkesselModel, 
    WindkesselParameters, 
    WindkesselSimulator,
    ParameterInterpolator
)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
%matplotlib inline

print("📚 Physics modules loaded successfully!")

## 🔧 Parameter Exploration and Validation

Let's explore how different parameters affect cardiac function:

In [None]:
# Define parameter ranges for exploration
print("🔍 Parameter Sensitivity Analysis")
print("=" * 40)

# Create parameter ranges for sensitivity analysis
parameter_ranges = {
    'Emax': (1.0, 4.0, 5),  # Maximum elastance
    'Emin': (0.01, 0.1, 4), # Minimum elastance
    'Tc': (0.6, 1.4, 4),    # Cardiac cycle time
    'Rm': (0.005, 0.05, 3), # Mitral resistance
    'Ra': (0.001, 0.01, 3)  # Aortic resistance
}

# Create simulator
simulator = WindkesselSimulator()

# Generate parameter grid
param_sets = simulator.create_parameter_grid(parameter_ranges)
print(f"📊 Generated {len(param_sets)} parameter combinations")

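# Run batch simulation on a reproducible random subset (demo). If the grid is
# product-ordered, its first N points would hold the slowest-varying
# parameters (e.g. Emax) fixed.
subset_size = min(50, len(param_sets))
rng = np.random.default_rng(42)
subset_idx = rng.choice(len(param_sets), size=subset_size, replace=False)
subset_params = [param_sets[i] for i in subset_idx]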

print(f"🔄 Running {subset_size} simulations...")
results = simulator.batch_simulate(subset_params, n_cycles=3)

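# Keep only successful runs, and drop the matching parameter sets so that
# subset_params[i] still corresponds to successful_results[i]
subset_params = [p for p, r in zip(subset_params, results) if r is not None]
successful_results = [r for r in results if r is not None]
print(f"✅ {len(successful_results)} simulations completed successfully")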
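`create_parameter_grid` belongs to the project's `WindkesselSimulator`, and its ordering is not shown here. The sketch below assumes it expands each `(low, high, n)` tuple with `np.linspace` and takes a Cartesian product; `make_parameter_grid` and `random_subset` are hypothetical stand-ins.

Under that assumption, the grid size is the product of the `n` values. Slicing off the first 50 entries would then hold the slowest-varying parameters fixed, while a seeded random subset spreads the sample across every axis.

```python
import itertools

import numpy as np


def make_parameter_grid(parameter_ranges):
    """Expand {name: (low, high, n)} into a list of dicts, one per grid point.

    Hypothetical stand-in for WindkesselSimulator.create_parameter_grid.
    """
    names = list(parameter_ranges)
    axes = [np.linspace(low, high, n) for low, high, n in parameter_ranges.values()]
    # itertools.product varies the last axis fastest
    return [dict(zip(names, combo)) for combo in itertools.product(*axes)]


def random_subset(grid, size, seed=0):
    """Uniformly sample grid points without replacement (reproducible)."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(grid), size=min(size, len(grid)), replace=False)
    return [grid[i] for i in idx]


ranges = {
    'Emax': (1.0, 4.0, 5), 'Emin': (0.01, 0.1, 4), 'Tc': (0.6, 1.4, 4),
    'Rm': (0.005, 0.05, 3), 'Ra': (0.001, 0.01, 3),
}
grid = make_parameter_grid(ranges)
first_50 = grid[:50]
sample_50 = random_subset(grid, 50)

print(len(grid))                               # 5 * 4 * 4 * 3 * 3 grid points
print(sorted({p['Emax'] for p in first_50}))   # only Emax = 1.0 appears
print(sorted({p['Emax'] for p in sample_50}))  # several Emax values appear
```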

In [None]:
# Analyze simulation results
if successful_results:
    # Extract clinical metrics
    efs = [r['EF'] for r in successful_results]
    veds = [r['VED'] for r in successful_results]
    vess = [r['VES'] for r in successful_results]
    svs = [r['stroke_volume'] for r in successful_results]
    
    # Create summary statistics
    metrics_df = pd.DataFrame({
        'Ejection_Fraction': efs,
        'End_Diastolic_Volume': veds,
        'End_Systolic_Volume': vess,
        'Stroke_Volume': svs
    })
    
    print("📊 Clinical Metrics Summary:")
    print(metrics_df.describe())
    
    # Visualize distributions
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    metrics = ['Ejection_Fraction', 'End_Diastolic_Volume', 'End_Systolic_Volume', 'Stroke_Volume']
    units = ['%', 'ml', 'ml', 'ml']
    
    for i, (metric, unit) in enumerate(zip(metrics, units)):
        ax = axes[i//2, i%2]
        
        # Histogram with KDE
        sns.histplot(metrics_df[metric], kde=True, ax=ax, alpha=0.7)
        
        # Add mean line
        mean_val = metrics_df[metric].mean()
        ax.axvline(mean_val, color='red', linestyle='--', 
                  label=f'Mean: {mean_val:.1f}{unit}')
        
        ax.set_title(f'{metric.replace("_", " ")} Distribution', fontweight='bold')
        ax.set_xlabel(f'{metric.replace("_", " ")} ({unit})')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('Clinical Metrics Distribution Across Parameter Space', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
else:
    print("⚠️ No successful simulations to analyze")
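The summary above relies on metrics returned by the simulator. For reference, here is a minimal sketch of the standard definitions: SV = VED - VES and EF = 100 · SV / VED. It assumes VED and VES are taken as the maximum and minimum LV volume over one cycle. The project may instead detect them at valve events, and `clinical_metrics` is a hypothetical helper.

```python
import numpy as np


def clinical_metrics(V_lv):
    """VED, VES, stroke volume (ml) and EF (%) from one cycle of LV volume."""
    V = np.asarray(V_lv, dtype=float)
    ved = V.max()             # end-diastolic volume: largest volume in the cycle
    ves = V.min()             # end-systolic volume: smallest volume in the cycle
    sv = ved - ves            # stroke volume
    ef = 100.0 * sv / ved     # ejection fraction in percent
    return {'VED': ved, 'VES': ves, 'stroke_volume': sv, 'EF': ef}


# Synthetic single-beat volume trace oscillating between 50 ml and 120 ml
t = np.linspace(0.0, 1.0, 1000, endpoint=False)
V = 85.0 + 35.0 * np.cos(2 * np.pi * t)
m = clinical_metrics(V)
print({k: round(float(v), 1) for k, v in m.items()})
```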

## 🏥 Clinical Condition Modeling

Let's approximate different cardiac conditions by adjusting specific parameters. The values below are illustrative settings, not fits to patient data:

In [None]:
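# Define clinical conditions with specific parameter sets
# (comments note the main changes relative to 'Healthy')
clinical_conditions = {
    'Healthy': WindkesselParameters(
        Emax=2.0, Emin=0.03, Tc=1.0, Rm=0.01, Ra=0.002, Vd=10.0
    ),
    # Halved Emax (weaker contraction), longer cycle, doubled Rm, larger Vd
    'Heart_Failure': WindkesselParameters(
        Emax=1.0, Emin=0.02, Tc=1.2, Rm=0.02, Ra=0.003, Vd=15.0
    ),
    # Higher Emax and Emin (stiffer ventricle) and doubled Ra;
    # systemic resistance Rs is left at its default
    'Hypertension': WindkesselParameters(
        Emax=2.8, Emin=0.05, Tc=0.9, Rm=0.008, Ra=0.004, Vd=8.0
    ),
    # Aortic valve resistance Ra raised 7.5x, together with a higher Emax
    'Aortic_Stenosis': WindkesselParameters(
        Emax=3.5, Emin=0.04, Tc=1.1, Rm=0.01, Ra=0.015, Vd=12.0
    ),
    # Lower Rm and larger Vd; no backward (regurgitant) flow term is set here
    'Mitral_Regurgitation': WindkesselParameters(
        Emax=1.8, Emin=0.025, Tc=1.0, Rm=0.005, Ra=0.002, Vd=18.0
    )
}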

print("🏥 Simulating Clinical Conditions")
print("=" * 40)

# Simulate each condition
condition_results = {}
colors = ['blue', 'red', 'green', 'orange', 'purple']

for condition, params in clinical_conditions.items():
    if params.validate():
        model = WindkesselModel(params)
        results = model.simulate(n_cycles=3, time_points_per_cycle=2000)
        condition_results[condition] = results
        
        print(f"✅ {condition}: EF={results['EF']:.1f}%, "
              f"VED={results['VED']:.1f}ml, VES={results['VES']:.1f}ml")
    else:
        print(f"❌ {condition}: Invalid parameters")

print(f"\n📊 Successfully simulated {len(condition_results)} conditions")

In [None]:
# Create comprehensive visualization of clinical conditions
if condition_results:
    fig = plt.figure(figsize=(20, 15))
    
    # Create subplots
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
    
    # Pressure-Volume loops
    ax_pv = fig.add_subplot(gs[0, :])
    
    for i, (condition, results) in enumerate(condition_results.items()):
        # Plot last cardiac cycle
        last_cycle = -2000
        ax_pv.plot(results['V_lv'][last_cycle:], results['P_lv'][last_cycle:], 
                   color=colors[i], linewidth=2.5, label=f'{condition.replace("_", " ")}')
    
    ax_pv.set_xlabel('LV Volume (ml)', fontsize=12)
    ax_pv.set_ylabel('LV Pressure (mmHg)', fontsize=12)
    ax_pv.set_title('Pressure-Volume Loops: Clinical Conditions Comparison', 
                    fontsize=14, fontweight='bold')
    ax_pv.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax_pv.grid(True, alpha=0.3)
    
    # Individual time series plots (colors match the PV-loop panel)
    time_plots = [
        ('Volume', 'V_lv', 'ml'),
        ('Pressure', 'P_lv', 'mmHg'),
        ('Elastance', 'elastance', 'mmHg/ml')
    ]
    
    for i, (title, key, unit) in enumerate(time_plots):
        ax = fig.add_subplot(gs[1, i])
        
        for j, (condition, results) in enumerate(condition_results.items()):
            # Plot the last two cycles (2 x 2000 points) for better visualization
            last_cycles = -4000
            time_norm = (results['time'][last_cycles:] - results['time'][last_cycles]) / \
                       clinical_conditions[condition].Tc
            
            ax.plot(time_norm, results[key][last_cycles:], color=colors[j],
                   label=condition.replace('_', ' '), alpha=0.8)
        
        ax.set_xlabel('Normalized Time (cardiac cycles)', fontsize=10)
        ax.set_ylabel(f'{title} ({unit})', fontsize=10)
        ax.set_title(f'{title} Time Series', fontweight='bold')
        ax.grid(True, alpha=0.3)
        if i == 0:
            ax.legend(fontsize=8)
    
    # Clinical metrics comparison
    ax_metrics = fig.add_subplot(gs[2, :])
    
    conditions = list(condition_results.keys())
    efs = [condition_results[c]['EF'] for c in conditions]
    veds = [condition_results[c]['VED'] for c in conditions]
    vess = [condition_results[c]['VES'] for c in conditions]
    
    x = np.arange(len(conditions))
    width = 0.25
    
    bars1 = ax_metrics.bar(x - width, efs, width, label='Ejection Fraction (%)', alpha=0.8)
    
    # Create secondary y-axis for volumes
    ax_metrics2 = ax_metrics.twinx()
    bars2 = ax_metrics2.bar(x, veds, width, label='EDV (ml)', alpha=0.8)
    bars3 = ax_metrics2.bar(x + width, vess, width, label='ESV (ml)', alpha=0.8)
    
    ax_metrics.set_xlabel('Clinical Conditions', fontsize=12)
    ax_metrics.set_ylabel('Ejection Fraction (%)', fontsize=12)
    ax_metrics2.set_ylabel('Volume (ml)', fontsize=12)
    ax_metrics.set_title('Clinical Metrics Comparison', fontsize=14, fontweight='bold')
    ax_metrics.set_xticks(x)
    ax_metrics.set_xticklabels([c.replace('_', ' ') for c in conditions], rotation=45)
    
    # Combine legends
    lines1, labels1 = ax_metrics.get_legend_handles_labels()
    lines2, labels2 = ax_metrics2.get_legend_handles_labels()
    ax_metrics.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
    
    ax_metrics.grid(True, alpha=0.3)
    
    plt.suptitle('Comprehensive Clinical Conditions Analysis', 
                 fontsize=18, fontweight='bold')
    plt.show()
    
    print("📈 Clinical conditions analysis complete!")
else:
    print("⚠️ No condition results available for visualization")
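The PV loops above are compared by eye. One quantitative summary is the loop's enclosed area, which equals the external stroke work done by the ventricle in one beat. The sketch below computes it with the shoelace formula; `pv_loop_area` is a hypothetical helper, not part of `physics.windkessel`.

```python
import numpy as np


def pv_loop_area(V, P):
    """Area enclosed by a closed PV loop (mmHg·ml) via the shoelace formula.

    Assumes V and P trace exactly one cycle; the polygon is closed implicitly.
    """
    V = np.asarray(V, dtype=float)
    P = np.asarray(P, dtype=float)
    # Shoelace: 0.5 * |sum(V_i * P_{i+1} - V_{i+1} * P_i)|, wrapping around
    return 0.5 * abs(np.dot(V, np.roll(P, -1)) - np.dot(np.roll(V, -1), P))


# Idealized rectangular loop: volumes 120/50 ml, pressures 10/110 mmHg
V_rect = [120, 120, 50, 50]
P_rect = [10, 110, 110, 10]
print(pv_loop_area(V_rect, P_rect))  # 70 ml * 100 mmHg = 7000.0
```

Applied to `condition_results[c]['V_lv'][-2000:]` and `condition_results[c]['P_lv'][-2000:]`, this gives one stroke-work value per condition. That assumes those 2000 samples span exactly one cycle, as with `time_points_per_cycle=2000` above.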

## 🧮 Parameter Interpolation and Neural Network Training

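Now let's train a neural network surrogate that maps a parameter vector to the simulated end-diastolic and end-systolic volumes (VED, VES). Once trained, it can evaluate new parameter sets without running the ODE solver: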

In [None]:
# Prepare data for neural network training
print("🧠 Training Parameter Interpolation Network")
print("=" * 45)

if successful_results and len(successful_results) > 10:
    # Extract parameters and outputs
    param_arrays = []
    output_arrays = []
    
    for i, result in enumerate(successful_results):
        if i < len(subset_params):
            params = subset_params[i]
            
            # Parameter vector
            param_vec = [
                params.Emax, params.Emin, params.Tc, 
                params.Rm, params.Ra, params.Rs, params.Ca
            ]
            
            # Output vector (clinical metrics)
            output_vec = [result['VED'], result['VES']]
            
            param_arrays.append(param_vec)
            output_arrays.append(output_vec)
    
    # Convert to tensors
    param_tensor = torch.tensor(param_arrays, dtype=torch.float64)
    output_tensor = torch.tensor(output_arrays, dtype=torch.float64)
    
    print(f"📊 Training data shape: {param_tensor.shape} -> {output_tensor.shape}")
    
    # Create and train interpolator
    interpolator = ParameterInterpolator(
        n_parameters=param_tensor.shape[1],
        n_outputs=output_tensor.shape[1],
        hidden_size=128
    )
    
    # Train the network
    n_epochs = 5000
    print("🔄 Training interpolator network...")
    history = interpolator.train_interpolator(
        param_tensor, output_tensor,
        epochs=n_epochs,
        learning_rate=0.001,
        validation_split=0.2
    )
    
    # Plot training history
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    
    # Assumes losses were logged at evenly spaced epochs; infer the spacing
    # instead of hard-coding it
    n_logged = len(history['train_losses'])
    epochs = np.arange(n_logged) * (n_epochs // n_logged)
    ax.plot(epochs, history['train_losses'], 'b-', label='Training Loss', linewidth=2)
    ax.plot(epochs, history['val_losses'], 'r-', label='Validation Loss', linewidth=2)
    
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Parameter Interpolator Training History', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')
    
    plt.tight_layout()
    plt.show()
    
    print(f"✅ Training completed! Final validation loss: {history['val_losses'][-1]:.6f}")
    
else:
    print("⚠️ Insufficient data for neural network training")
    interpolator = None
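The input parameters span orders of magnitude (Emax is about 1 to 4, Ra about 0.001 to 0.01). In addition, Rs and Ca are not in `parameter_ranges`, so they are presumably constant columns.

Whether `ParameterInterpolator.train_interpolator` already rescales its inputs is not shown. If it does not, z-scoring the inputs usually helps gradient-based training. A naive z-score would divide by zero on the constant columns.

The sketch below uses hypothetical numpy helpers. It fits the scaling on the training rows only, so that validation data does not leak into the scaling.

```python
import numpy as np


def fit_standardizer(X, eps=1e-12):
    """Column means and scales for z-scoring; constant columns get scale 1."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # Avoid dividing by zero for parameters held fixed across the dataset
    scale = np.where(std > eps, std, 1.0)
    return mean, scale


def train_val_split(n, val_fraction=0.2, seed=0):
    """Shuffled index split into training and validation indices."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_val = max(1, int(round(val_fraction * n)))
    return idx[n_val:], idx[:n_val]


# Illustrative parameter matrix: columns Emax, Ra, Rs (Rs held fixed)
X = np.array([[1.0, 0.001, 1.0],
              [2.5, 0.005, 1.0],
              [4.0, 0.010, 1.0]])
train_idx, val_idx = train_val_split(len(X))
mean, scale = fit_standardizer(X[train_idx])  # fit on training rows only
X_std = (X - mean) / scale                    # the fixed Rs column becomes all zeros
print(X_std)
```

The same `mean` and `scale` would then be applied to any test vector, such as `test_param_vec`, before it is passed to the network.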

In [None]:
# Test the trained interpolator
if interpolator is not None:
    print("🧪 Testing Parameter Interpolator")
    print("=" * 35)
    
    # Create test parameters
    test_params = WindkesselParameters(
        Emax=2.2, Emin=0.035, Tc=1.1, 
        Rm=0.012, Ra=0.0025
    )
    
    if test_params.validate():
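        # Ground truth simulation (2000 points per cycle, matching the
        # last_cycle = -2000 slice used for the PV loop below)
        test_model = WindkesselModel(test_params)
        true_results = test_model.simulate(n_cycles=3, time_points_per_cycle=2000)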
        
        # Neural network prediction
        test_param_vec = torch.tensor([
            test_params.Emax, test_params.Emin, test_params.Tc,
            test_params.Rm, test_params.Ra, test_params.Rs, test_params.Ca
        ], dtype=torch.float64).unsqueeze(0)
        
        interpolator.eval()
        with torch.no_grad():
            predicted = interpolator(test_param_vec)
        
        # Compare results
        true_ved, true_ves = true_results['VED'], true_results['VES']
        pred_ved, pred_ves = predicted[0, 0].item(), predicted[0, 1].item()
        
        print("📊 Comparison Results:")
        print(f"   VED - True: {true_ved:.2f} ml, Predicted: {pred_ved:.2f} ml, "
              f"Error: {abs(true_ved - pred_ved):.2f} ml ({abs(true_ved - pred_ved)/true_ved*100:.1f}%)")
        print(f"   VES - True: {true_ves:.2f} ml, Predicted: {pred_ves:.2f} ml, "
              f"Error: {abs(true_ves - pred_ves):.2f} ml ({abs(true_ves - pred_ves)/true_ves*100:.1f}%)")
        
        # Visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        # Comparison bar plot
        metrics = ['VED', 'VES']
        true_vals = [true_ved, true_ves]
        pred_vals = [pred_ved, pred_ves]
        
        x = np.arange(len(metrics))
        width = 0.35
        
        ax1.bar(x - width/2, true_vals, width, label='Ground Truth', alpha=0.8)
        ax1.bar(x + width/2, pred_vals, width, label='Neural Network', alpha=0.8)
        
        ax1.set_xlabel('Metrics')
        ax1.set_ylabel('Volume (ml)')
        ax1.set_title('Ground Truth vs Neural Network Prediction')
        ax1.set_xticks(x)
        ax1.set_xticklabels(metrics)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
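        # Ground-truth PV loop with the network's predicted VED/VES marked
        # (the interpolator predicts only these two volumes, not waveforms)
        last_cycle = -2000
        ax2.plot(true_results['V_lv'][last_cycle:], true_results['P_lv'][last_cycle:], 
                'b-', linewidth=2, label='Ground Truth PV Loop')
        ax2.axvline(pred_ved, color='red', linestyle='--', label='Predicted VED')
        ax2.axvline(pred_ves, color='orange', linestyle='--', label='Predicted VES')
        
        ax2.set_xlabel('LV Volume (ml)')
        ax2.set_ylabel('LV Pressure (mmHg)')
        ax2.set_title('PV Loop (Ground Truth) vs Predicted Volumes')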
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        print("✅ Interpolator testing completed!")
    else:
        print("❌ Test parameters validation failed")
else:
    print("⚠️ No trained interpolator available for testing")

## 📊 Advanced Analysis: Parameter Sensitivity

Let's perform a detailed sensitivity analysis to understand how each parameter affects cardiac function:

In [None]:
# Parameter sensitivity analysis
print("📈 Parameter Sensitivity Analysis")
print("=" * 40)

# Base parameters
base_params = WindkesselParameters(
    Emax=2.0, Emin=0.03, Tc=1.0, Rm=0.01, Ra=0.002
)

# Parameters to analyze
sensitivity_params = {
    'Emax': np.linspace(1.0, 4.0, 10),
    'Emin': np.linspace(0.01, 0.08, 8),
    'Tc': np.linspace(0.6, 1.4, 9),
    'Rm': np.linspace(0.005, 0.03, 8),
    'Ra': np.linspace(0.001, 0.008, 8)
}

import copy

sensitivity_results = {}

for param_name, param_values in sensitivity_params.items():
    print(f"🔍 Analyzing {param_name}...")
    
    efs = []
    veds = []
    vess = []
    
    for value in param_values:
        # Copy the baseline (all fields) and change only the parameter under study
        test_params = copy.copy(base_params)
        setattr(test_params, param_name, value)
        
        # Invalid or failed runs are recorded as NaN so the sweep can continue
        ef = ved = ves = np.nan
        if test_params.validate():
            try:
                model = WindkesselModel(test_params)
                results = model.simulate(n_cycles=2, time_points_per_cycle=1000)
                ef, ved, ves = results['EF'], results['VED'], results['VES']
            except Exception as exc:
                print(f"   ⚠️ {param_name}={value:.4g} failed: {exc}")
        
        efs.append(ef)
        veds.append(ved)
        vess.append(ves)
    
    sensitivity_results[param_name] = {
        'values': param_values,
        'EF': np.array(efs),
        'VED': np.array(veds),
        'VES': np.array(vess)
    }

print("✅ Sensitivity analysis completed!")
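The sweeps use different ranges and units, so raw EF changes are not directly comparable across parameters. A dimensionless local sensitivity (elasticity), S = (dY/Y) / (dX/X), normalizes for this.

The helper below is a hypothetical addition sketched with `np.gradient`. It uses central differences in the interior and one-sided differences at the ends.

```python
import numpy as np


def normalized_sensitivity(x, y):
    """Local elasticity S = (dy/y) / (dx/x) along a one-at-a-time sweep.

    NaN entries in y (failed runs) also turn their neighbours' values into NaN.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return np.gradient(y, x) * x / y


# Check on a known power law: for y = 3 * x**2 the elasticity is exactly 2
x = np.linspace(1.0, 4.0, 10)
S = normalized_sensitivity(x, 3 * x**2)
print(np.round(S[1:-1], 6))  # interior points, where central differences are exact
```

For example, `normalized_sensitivity(sensitivity_results['Emax']['values'], sensitivity_results['Emax']['EF'])` gives the EF elasticity at each Emax value. A value of |S| near 1 means a 1% change in the parameter shifts EF by roughly 1%.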

In [None]:
# Visualize sensitivity analysis results
fig, axes = plt.subplots(3, 2, figsize=(18, 15))
axes = axes.flatten()

param_labels = {
    'Emax': 'Maximum Elastance (mmHg/ml)',
    'Emin': 'Minimum Elastance (mmHg/ml)',
    'Tc': 'Cardiac Cycle Time (s)',
    'Rm': 'Mitral Resistance (mmHg·s/ml)',
    'Ra': 'Aortic Resistance (mmHg·s/ml)'
}

for i, (param_name, results) in enumerate(sensitivity_results.items()):
    if i < len(axes):
        ax = axes[i]
        
        # Filter out NaN values
        valid_mask = ~np.isnan(results['EF'])
        
        if np.any(valid_mask):
            values = results['values'][valid_mask]
            ef_vals = results['EF'][valid_mask]
            ved_vals = results['VED'][valid_mask]
            ves_vals = results['VES'][valid_mask]
            
            # Plot EF
            ax.plot(values, ef_vals, 'bo-', linewidth=2, markersize=6, label='Ejection Fraction')
            ax.set_xlabel(param_labels[param_name], fontsize=10)
            ax.set_ylabel('Ejection Fraction (%)', fontsize=10, color='blue')
            ax.tick_params(axis='y', labelcolor='blue')
            
            # Create secondary y-axis for volumes
            ax2 = ax.twinx()
            ax2.plot(values, ved_vals, 'rs-', linewidth=2, markersize=4, alpha=0.7, label='VED')
            ax2.plot(values, ves_vals, 'gs-', linewidth=2, markersize=4, alpha=0.7, label='VES')
            ax2.set_ylabel('Volume (ml)', fontsize=10, color='red')
            ax2.tick_params(axis='y', labelcolor='red')
            
            ax.set_title(f'Sensitivity to {param_name}', fontweight='bold')
            ax.grid(True, alpha=0.3)
            
            # Add legends
            lines1, labels1 = ax.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(lines1 + lines2, labels1 + labels2, loc='best', fontsize=8)
        else:
            ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', 
                   transform=ax.transAxes, fontsize=12)
            ax.set_title(f'Sensitivity to {param_name}', fontweight='bold')

# Hide the last subplot if we have fewer than 6 parameters
if len(sensitivity_results) < len(axes):
    axes[-1].set_visible(False)

plt.suptitle('Parameter Sensitivity Analysis: Effect on Cardiac Function', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("📊 Sensitivity analysis visualization completed!")
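print("\n🔍 EF change across each sweep (max - min, valid runs only):")
# Each sweep covers a different range, so these spans reflect both the
# parameter's influence and the width of its range
for param_name, res in sensitivity_results.items():
    ef_valid = res['EF'][~np.isnan(res['EF'])]
    ef_span = ef_valid.max() - ef_valid.min() if ef_valid.size else np.nan
    print(f"   • {param_name}: {ef_span:.1f} percentage points")

print("\n🔍 Expected patterns (check against the plots and spans above):")
print("   • Emax should have the strongest effect on ejection fraction")
print("   • Cardiac cycle time (Tc) affects both EF and volumes")
print("   • Valve resistances (Ra, Rm) have nonlinear effects")
print("   • Parameter interactions create complex response surfaces")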

## 🎯 Key Takeaways

### Physics Model Insights:

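1. **🫀 Parameter Sensitivity**: 
   - Maximum elastance (Emax), which sets contractility, is expected to be the most influential parameter for EF; confirm this in the sensitivity plots, keeping in mind that each sweep covers a different range
   - Cardiac cycle time affects both systolic and diastolic function
   - Valve resistances create nonlinear effects on hemodynamics

2. **🏥 Clinical Conditions**:
   - Each condition produces a characteristic PV loop shape
   - Heart failure (reduced Emax) should show reduced EF and increased volumes
   - Hypertension should show elevated pressures with preserved EF
   - Valve diseases create specific hemodynamic signatures

3. **🧠 Neural Network Interpolation**:
   - A small network trained on a few dozen simulations can approximate VED and VES from the parameters; check its error on held-out parameter sets, as in the interpolator test
   - Training requires careful validation to avoid overfitting, especially with so few samples
   - Once trained, the network is much cheaper to evaluate than the ODE model, enabling fast parameter space exploration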

### Enhanced Framework Benefits:

✅ **Robust Parameter Validation**: Prevents invalid simulations  
✅ **Comprehensive Error Handling**: Graceful failure management  
✅ **Batch Processing**: Efficient parameter space exploration  
✅ **Clinical Metrics**: Automatic calculation of relevant measures  
✅ **Visualization Tools**: Rich plotting and analysis capabilities  

## 🚀 Next Steps

Continue your journey with:

- **[03_Neural_Network_Architectures.ipynb](03_Neural_Network_Architectures.ipynb)**: Deep dive into enhanced 3D CNNs
- **[04_Physics_Informed_SSL.ipynb](04_Physics_Informed_SSL.ipynb)**: Self-supervised learning implementation
- **[05_Digital_Twin_Applications.ipynb](05_Digital_Twin_Applications.ipynb)**: Real-world clinical applications

---

**🔬 The enhanced physics models provide the foundation for accurate digital twin creation!**