# Location: notebooks/06_special_topics/17_differentiable_odes.ipynb

## Differentiable Ordinary Differential Equations (ODEs) in JAX

This notebook explores solving and differentiating through ODEs using JAX, including neural ODEs, adjoint methods, and applications to dynamical systems modeling.

## Basic ODE Solving with JAX

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial

# Simple Euler method implementation
def euler_step(f, y, t, dt):
    """Advance the state by one explicit (forward) Euler step.

    Args:
        f: Right-hand side of the ODE, called as ``f(y, t)``.
        y: State at time ``t``.
        t: Current time.
        dt: Step size.

    Returns:
        The updated state ``y + dt * f(y, t)``.
    """
    slope = f(y, t)
    increment = dt * slope
    return y + increment

def ode_solve_euler(f, y0, t_span, dt):
    """Integrate ``dy/dt = f(y, t)`` with fixed-step explicit Euler.

    The interval is split into a whole number of equal steps, so the step
    that is actually taken is ``(t_end - t_start) / n_steps`` rather than
    ``dt`` itself whenever ``dt`` does not divide the interval evenly. This
    keeps the integrated states consistent with the returned time grid. The
    sign of the step follows the direction of ``t_span``, so a reversed
    interval (``t_end < t_start``) is integrated backwards in time.

    Args:
        f: Right-hand side of the ODE, called as ``f(y, t)``.
        y0: Initial state array.
        t_span: ``(t_start, t_end)`` pair of concrete (non-traced) scalars.
        dt: Requested step size; only its magnitude is used.

    Returns:
        ``(trajectory, t_points)``: the states stacked along a new leading
        axis (``y0`` first) and the ``n_steps + 1`` matching time points.

    Raises:
        ValueError: If ``dt`` is zero.
    """
    t_start, t_end = t_span
    if dt == 0:
        raise ValueError("dt must be non-zero")

    duration = t_end - t_start
    # round() rather than int(): a ratio such as 0.3 / 0.1 evaluates to
    # 2.999..., and truncation would silently drop the last step.
    n_steps = int(round(abs(float(duration)) / abs(float(dt))))
    if n_steps == 0 and duration != 0:
        n_steps = 1  # always cover a non-empty interval
    step = duration / n_steps if n_steps else 0.0

    def scan_step(carry, t):
        y_next = euler_step(f, carry, t, step)
        return y_next, y_next

    t_points = jnp.linspace(t_start, t_end, n_steps + 1)
    _, trajectory = jax.lax.scan(scan_step, y0, t_points[:-1])

    return jnp.concatenate([y0[None], trajectory]), t_points

# Example: Simple harmonic oscillator
def harmonic_oscillator(y, t):
    """Right-hand side of the unit-frequency harmonic oscillator.

    Args:
        y: State ``[position, velocity]``.
        t: Time (unused; the system is autonomous).

    Returns:
        ``[velocity, -position]`` as a JAX array.
    """
    position = y[0]
    velocity = y[1]
    return jnp.array([velocity, -position])

# Integrate the harmonic oscillator over one nominal period with Euler steps
y0 = jnp.array([1.0, 0.0])  # start at position=1 with zero velocity
dt = 0.01
t_span = (0.0, 2 * jnp.pi)

trajectory, t_points = ode_solve_euler(harmonic_oscillator, y0, t_span, dt)

initial_state = trajectory[0]
final_state = trajectory[-1]
print(f"Trajectory shape: {trajectory.shape}")
print(f"Time points shape: {t_points.shape}")
print(f"Initial state: {initial_state}")
print(f"Final state: {final_state}")

# Exact solution from this initial condition is [cos(t), -sin(t)]
end_time = 2 * jnp.pi
expected_final = jnp.array([jnp.cos(end_time), -jnp.sin(end_time)])
is_close = jnp.allclose(final_state, expected_final, atol=1e-1)
print(f"Expected final state: {expected_final}")
print(f"Close to expected: {is_close}")
```

## Advanced ODE Solvers

```python
def runge_kutta_4(f, y, t, dt):
    """Advance the state by one classical fourth-order Runge-Kutta step.

    Args:
        f: Right-hand side of the ODE, called as ``f(y, t)``.
        y: State at time ``t``.
        t: Current time.
        dt: Step size.

    Returns:
        The state after the step, using the 1-2-2-1 weighted average of the
        four stage increments.
    """
    t_mid = t + dt/2
    t_next = t + dt

    # Each stage increment is already scaled by dt
    inc_a = dt * f(y, t)
    inc_b = dt * f(y + inc_a/2, t_mid)
    inc_c = dt * f(y + inc_b/2, t_mid)
    inc_d = dt * f(y + inc_c, t_next)

    weighted_sum = inc_a + 2*inc_b + 2*inc_c + inc_d
    return y + weighted_sum / 6

def ode_solve_rk4(f, y0, t_span, dt):
    """Integrate ``dy/dt = f(y, t)`` with fixed-step fourth-order Runge-Kutta.

    The interval is split into a whole number of equal steps, so the step
    that is actually taken is ``(t_end - t_start) / n_steps`` rather than
    ``dt`` itself whenever ``dt`` does not divide the interval evenly. This
    keeps the integrated states consistent with the returned time grid. The
    sign of the step follows the direction of ``t_span``, so a reversed
    interval (``t_end < t_start``) is integrated backwards in time.

    Args:
        f: Right-hand side of the ODE, called as ``f(y, t)``.
        y0: Initial state array.
        t_span: ``(t_start, t_end)`` pair of concrete (non-traced) scalars.
        dt: Requested step size; only its magnitude is used.

    Returns:
        ``(trajectory, t_points)``: the states stacked along a new leading
        axis (``y0`` first) and the ``n_steps + 1`` matching time points.

    Raises:
        ValueError: If ``dt`` is zero.
    """
    t_start, t_end = t_span
    if dt == 0:
        raise ValueError("dt must be non-zero")

    duration = t_end - t_start
    # round() rather than int(): a ratio such as 0.3 / 0.1 evaluates to
    # 2.999..., and truncation would silently drop the last step.
    n_steps = int(round(abs(float(duration)) / abs(float(dt))))
    if n_steps == 0 and duration != 0:
        n_steps = 1  # always cover a non-empty interval
    step = duration / n_steps if n_steps else 0.0

    def scan_step(carry, t):
        y_next = runge_kutta_4(f, carry, t, step)
        return y_next, y_next

    t_points = jnp.linspace(t_start, t_end, n_steps + 1)
    _, trajectory = jax.lax.scan(scan_step, y0, t_points[:-1])

    return jnp.concatenate([y0[None], trajectory]), t_points

# Adaptive step size solver (simplified)
def adaptive_step_solver(f, y0, t_span, initial_dt=0.01, tol=1e-6):
    """Integrate an ODE with RK4 and step-doubling error control.

    Each attempt compares one full RK4 step against two half steps. The
    norm of their difference is the error estimate: the attempt is accepted
    when it is at most ``tol``, the step shrinks by 0.8 after a rejection,
    and grows by 1.2 when the error is below ``tol / 10``. The step size is
    kept within ``[1e-6, 0.1]``.

    Only accepted steps are recorded, so the returned time points are
    strictly increasing, and the final step is shortened so the trajectory
    ends exactly at ``t_end``. The loop runs in Python (not ``lax.scan``)
    because the number of steps is data dependent, so ``t_span`` must hold
    concrete scalars.

    Args:
        f: Right-hand side of the ODE, called as ``f(y, t)``.
        y0: Initial state array.
        t_span: ``(t_start, t_end)`` pair of concrete scalars.
        initial_dt: First step size to try.
        tol: Error tolerance for accepting a step.

    Returns:
        ``(y_vals, t_vals, errors)``: accepted states (``y0`` first), their
        times, and the error estimate of each accepted step (``0.0`` for the
        initial point). If 1000 attempts are exhausted first, the arrays
        stop short of ``t_end``.
    """
    t_start, t_end = t_span
    min_dt, max_dt = 1e-6, 0.1
    max_attempts = 1000

    y = y0
    t = float(t_start)
    t_end = float(t_end)
    dt = float(initial_dt)

    y_vals, t_vals, errors = [y0], [t], [0.0]

    for _ in range(max_attempts):
        if t >= t_end:
            break

        # Shorten the last step so the solver never integrates past t_end
        remaining = t_end - t
        is_final_step = dt >= remaining
        step = remaining if is_final_step else dt

        # One full step versus two half steps
        y_full = runge_kutta_4(f, y, t, step)
        y_mid = runge_kutta_4(f, y, t, step / 2)
        y_fine = runge_kutta_4(f, y_mid, t + step / 2, step / 2)
        error = float(jnp.linalg.norm(y_full - y_fine))

        if error <= tol:
            y = y_fine
            # Assign t_end directly to avoid round-off on the last step
            t = t_end if is_final_step else t + step
            y_vals.append(y)
            t_vals.append(t)
            errors.append(error)

        # Adapt the step size for the next attempt
        if error > tol:
            dt = step * 0.8
        elif error < tol / 10:
            dt = dt * 1.2
        dt = min(max(dt, min_dt), max_dt)

    return jnp.array(y_vals), jnp.array(t_vals), jnp.array(errors)

# Compare different solvers on a challenging problem
def van_der_pol(y, t, mu=2.0):
    """Right-hand side of the Van der Pol oscillator.

    Args:
        y: Two-component state ``[x, v]``.
        t: Time (unused; the system is autonomous).
        mu: Nonlinear damping strength.

    Returns:
        ``[v, mu * (1 - x**2) * v - x]`` as a JAX array.
    """
    position, velocity = y
    acceleration = mu * (1 - position**2) * velocity - position
    return jnp.array([velocity, acceleration])

# Shared initial condition and time window for both solvers
y0_vdp = jnp.array([2.0, 0.0])
t_span_vdp = (0.0, 10.0)

# Integrate the same problem with Euler and with RK4 at the same step size
traj_euler, t_euler = ode_solve_euler(van_der_pol, y0_vdp, t_span_vdp, dt=0.01)
traj_rk4, t_rk4 = ode_solve_rk4(van_der_pol, y0_vdp, t_span_vdp, dt=0.01)

euler_final = traj_euler[-1]
rk4_final = traj_rk4[-1]
endpoint_gap = jnp.linalg.norm(euler_final - rk4_final)

print("Van der Pol Oscillator Solutions:")
print(f"Euler final state: {euler_final}")
print(f"RK4 final state: {rk4_final}")
print(f"Solutions differ by: {endpoint_gap}")
```

## Neural ODEs: Parameterized Dynamics

```python
def neural_dynamics(params, y, t):
    """Evaluate the MLP vector field ``dy/dt`` at state ``y`` and time ``t``.

    Defined at module level because other cells in this notebook (the
    adjoint and CNF examples) call ``neural_dynamics`` directly.

    Args:
        params: Dict with keys ``w1, b1, w2, b2, w3, b3`` as produced by the
            ``init_params`` function returned from ``create_neural_ode``.
        y: State vector of length ``input_dim``.
        t: Scalar time, appended to the state as an extra input feature.

    Returns:
        The time derivative, a vector with the same length as ``y``.
    """
    # Input is [y, t] concatenated, so it has input_dim + 1 entries
    x = jnp.concatenate([y, jnp.array([t])])

    # Forward pass through the MLP
    h1 = jax.nn.tanh(x @ params['w1'] + params['b1'])
    h2 = jax.nn.tanh(h1 @ params['w2'] + params['b2'])
    return h2 @ params['w3'] + params['b3']


def create_neural_ode(hidden_dim=64):
    """Create a neural ODE with learnable dynamics.

    Args:
        hidden_dim: Width of the two hidden layers of the dynamics MLP.

    Returns:
        ``(init_params, neural_ode_forward)``: a parameter initialiser and a
        function that integrates the learned dynamics with RK4.
    """

    def init_params(key, input_dim):
        """Initialise the MLP weights (scaled normal) and zero biases.

        Args:
            key: JAX PRNG key.
            input_dim: Dimension of the ODE state.

        Returns:
            Dict of parameters for ``neural_dynamics``.
        """
        k1, k2, k3 = jax.random.split(key, 3)

        return {
            # First layer consumes the state plus the appended time feature
            'w1': jax.random.normal(k1, (input_dim + 1, hidden_dim)) * 0.1,
            'b1': jnp.zeros(hidden_dim),
            'w2': jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.1,
            'b2': jnp.zeros(hidden_dim),
            'w3': jax.random.normal(k3, (hidden_dim, input_dim)) * 0.1,
            'b3': jnp.zeros(input_dim)
        }

    def neural_ode_forward(params, y0, t_span, dt=0.01):
        """Integrate the learned dynamics from ``y0`` over ``t_span``.

        Args:
            params: Parameters from ``init_params``.
            y0: Initial state.
            t_span: ``(t_start, t_end)`` pair.
            dt: Step size passed to ``ode_solve_rk4``.

        Returns:
            ``(trajectory, t_points)`` as returned by ``ode_solve_rk4``.
        """
        dynamics_fn = lambda y, t: neural_dynamics(params, y, t)
        trajectory, t_points = ode_solve_rk4(dynamics_fn, y0, t_span, dt)
        return trajectory, t_points

    return init_params, neural_ode_forward

# Build the neural ODE factory outputs and draw initial parameters
init_neural_ode, neural_ode_forward = create_neural_ode(hidden_dim=32)
neural_params = init_neural_ode(jax.random.PRNGKey(42), input_dim=2)

print("Neural ODE Parameters:")
for name in neural_params:
    param = neural_params[name]
    print(f"  {name}: {param.shape}")

# Roll the untrained model forward from a test initial condition
y0_test = jnp.array([1.0, 0.0])
neural_traj, neural_t = neural_ode_forward(neural_params, y0_test, (0.0, 2.0))

first_state = neural_traj[0]
last_state = neural_traj[-1]
print(f"Neural ODE trajectory shape: {neural_traj.shape}")
print(f"Initial state: {first_state}")
print(f"Final state: {last_state}")
```

## Training Neural ODEs

```python
def create_neural_ode_loss(forward_fn=None):
    """Create loss functions for training neural ODEs.

    Args:
        forward_fn: Optional solver called as
            ``forward_fn(params, y0, t_span, dt)`` (``dt`` may be omitted)
            and returning ``(trajectory, t_points)``. When ``None``, the
            module-level ``neural_ode_forward`` is looked up at call time.

    Returns:
        ``(trajectory_loss, endpoint_loss)``.
    """

    def _solver():
        # Resolve lazily so the default follows the current global binding
        return neural_ode_forward if forward_fn is None else forward_fn

    def trajectory_loss(params, y0, target_trajectory, t_points, dt=0.01):
        """Mean squared error between predicted and target trajectories.

        The prediction is integrated between the first and last entries of
        ``t_points`` and then resampled to the target length by picking the
        nearest prediction index at evenly spaced positions. This works
        whether the prediction has more or fewer rows than the target, and
        assumes both cover the same interval on uniform grids.

        Args:
            params: Model parameters.
            y0: Initial state.
            target_trajectory: Target states, one row per time point.
            t_points: Target time points; only the first and last are used.
            dt: Step size passed to the solver.

        Returns:
            Scalar mean squared error.
        """
        t_span = (t_points[0], t_points[-1])
        pred_trajectory, _ = _solver()(params, y0, t_span, dt)

        n_target = target_trajectory.shape[0]
        n_pred = pred_trajectory.shape[0]

        # Round to the nearest row; plain truncation can pick the previous
        # row when a position lands just below an integer.
        positions = jnp.linspace(0, n_pred - 1, n_target)
        indices = jnp.round(positions).astype(int)
        pred_interp = pred_trajectory[indices]

        return jnp.mean((pred_interp - target_trajectory) ** 2)

    def endpoint_loss(params, y0, target_endpoint, t_final):
        """Mean squared error between the final state and a target endpoint.

        Args:
            params: Model parameters.
            y0: Initial state at time 0.
            target_endpoint: Desired state at ``t_final``.
            t_final: End of the integration interval, which starts at 0.

        Returns:
            Scalar mean squared error.
        """
        t_span = (0.0, t_final)
        trajectory, _ = _solver()(params, y0, t_span)
        final_state = trajectory[-1]
        return jnp.mean((final_state - target_endpoint) ** 2)

    return trajectory_loss, endpoint_loss

# Build the two loss functions used by the training examples
neural_ode_losses = create_neural_ode_loss()
trajectory_loss_fn, endpoint_loss_fn = neural_ode_losses

# Generate synthetic training data (spiral dynamics)
def true_spiral_dynamics(y, t, omega=1.0):
    """Ground-truth vector field used to generate training data.

    The field is a pure rotation at angular rate ``omega``; it has no
    radial component.

    Args:
        y: Two-component state ``[x, v]``.
        t: Time (unused; the system is autonomous).
        omega: Angular rate.

    Returns:
        ``[-omega * v, omega * x]`` as a JAX array.
    """
    x, v = y
    dx_dt = -omega * v
    dv_dt = omega * x
    return jnp.array([dx_dt, dv_dt])

# Reference trajectory: integrate the true dynamics with RK4 at dt = 0.02
y0_train = jnp.array([1.0, 0.0])
train_t_span = (0.0, 4.0)
true_traj, true_t = ode_solve_rk4(true_spiral_dynamics, y0_train, train_t_span, 0.02)

train_shape = true_traj.shape
print(f"Training data shape: {train_shape}")

# Training step
def train_step(params, y0, target_traj, t_points, lr=0.01):
    """Run one plain-SGD update on the trajectory-matching loss.

    Args:
        params: Current model parameters (a pytree).
        y0: Initial state for the rollout.
        target_traj: Target trajectory to match.
        t_points: Time points of the target trajectory.
        lr: Learning rate.

    Returns:
        ``(new_params, loss)`` where ``loss`` is evaluated at the incoming
        ``params``, i.e. before the update.
    """
    loss, grads = jax.value_and_grad(trajectory_loss_fn)(params, y0, target_traj, t_points)

    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated and absent from newer JAX releases.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Fit the neural ODE to the reference trajectory with plain SGD
train_params = neural_params
losses = []

print("Training Neural ODE:")
for epoch in range(50):
    step_result = train_step(train_params, y0_train, true_traj, true_t)
    train_params, loss = step_result
    losses.append(loss)

    should_report = epoch % 10 == 0
    if should_report:
        print(f"Epoch {epoch}: Loss = {loss:.6f}")

last_loss = losses[-1]
print(f"Final loss: {last_loss:.6f}")

# Compare the trained model's endpoint against the reference endpoint
trained_traj, _ = neural_ode_forward(train_params, y0_train, (0.0, 4.0))
endpoint_gap = trained_traj[-1] - true_traj[-1]
final_error = jnp.linalg.norm(endpoint_gap)
print(f"Final state error: {final_error:.6f}")
```

## Adjoint Method for Memory-Efficient Training

```python
def create_adjoint_neural_ode():
    """Create an endpoint-matching training step for the neural ODE.

    NOTE: despite the name, no adjoint ODE is solved here. Parameter
    gradients come from reverse-mode differentiation through the unrolled
    RK4 solver (a vector-Jacobian product seeded with the loss gradient at
    the final state), so memory still grows with the number of steps.

    Returns:
        ``efficient_train_step(params, y0, target, t_span, lr=0.01)``.
    """

    def augmented_dynamics(aug_state, t, params):
        """Placeholder for the augmented adjoint dynamics (currently unused).

        Only the forward part is implemented: it evaluates the neural
        dynamics on the first two entries of ``aug_state``.
        """
        # aug_state = [y, lambda, params_flat]
        n_y = 2  # Dimension of original state
        y = aug_state[:n_y]

        # Forward dynamics
        dydt = neural_dynamics(params, y, t)

        return dydt

    def adjoint_solve(params, y0, loss_grad_output, t_span, dt=0.01):
        """Pull the final-state loss gradient back to the parameters.

        Args:
            params: Model parameters (a pytree).
            y0: Initial state.
            loss_grad_output: Gradient of the loss w.r.t. the final state.
            t_span: ``(t_start, t_end)`` pair.
            dt: Solver step size.

        Returns:
            ``(grads, trajectory)``: parameter gradients with the same
            structure as ``params``, and the forward trajectory.
        """
        def final_state_and_trajectory(p):
            traj, _ = neural_ode_forward(p, y0, t_span, dt)
            return traj[-1], traj

        # A single forward solve yields the trajectory (as aux output) and
        # the pullback used for the gradient.
        _, pullback, trajectory = jax.vjp(
            final_state_and_trajectory, params, has_aux=True
        )
        (grads,) = pullback(loss_grad_output)

        return grads, trajectory

    def efficient_train_step(params, y0, target, t_span, lr=0.01):
        """Run one SGD step on the squared error at the final state.

        Args:
            params: Current model parameters (a pytree).
            y0: Initial state.
            target: Desired final state.
            t_span: ``(t_start, t_end)`` pair.
            lr: Learning rate.

        Returns:
            ``(new_params, loss)`` where ``loss`` is the summed squared
            error evaluated before the update.
        """
        # Forward pass
        pred_traj, _ = neural_ode_forward(params, y0, t_span)
        final_pred = pred_traj[-1]

        # Loss and its gradient w.r.t. the final state
        loss = jnp.sum((final_pred - target) ** 2)
        loss_grad = 2 * (final_pred - target)

        # Parameter gradients via the pullback
        param_grads, _ = adjoint_solve(params, y0, loss_grad, t_span)

        # jax.tree_util.tree_map is the stable spelling; the top-level
        # jax.tree_map alias is deprecated and absent from newer JAX releases.
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - lr * g, params, param_grads
        )

        return new_params, loss

    return efficient_train_step

# Build the endpoint-matching training step
efficient_train_step = create_adjoint_neural_ode()

# Fresh parameters; the target is the endpoint of the reference trajectory
adjoint_params = init_neural_ode(jax.random.PRNGKey(123), input_dim=2)
target_final = true_traj[-1]

print("Training with Adjoint Method:")
for epoch in range(10):
    step_result = efficient_train_step(
        adjoint_params, y0_train, target_final, (0.0, 4.0)
    )
    adjoint_params, loss = step_result

    should_report = epoch % 2 == 0
    if should_report:
        print(f"Epoch {epoch}: Loss = {loss:.6f}")

# Measure how far the trained endpoint is from the reference endpoint
adjoint_traj, _ = neural_ode_forward(adjoint_params, y0_train, (0.0, 4.0))
endpoint_gap = adjoint_traj[-1] - true_traj[-1]
adjoint_error = jnp.linalg.norm(endpoint_gap)
print(f"Adjoint method final error: {adjoint_error:.6f}")
```

## Continuous Normalizing Flows

```python
def create_cnf_model():
    """Create a Continuous Normalizing Flow (CNF) model.

    The flow state is augmented with one extra component ``a`` obeying
    ``da/dt = -div f(z, t)``, i.e. the change in log-density along the
    flow. The functions call ``neural_dynamics`` and ``ode_solve_rk4``,
    which must be resolvable as module-level names when they run.

    Returns:
        ``(cnf_forward, cnf_log_likelihood)``.
    """

    def cnf_dynamics(params, y, t):
        """Augmented dynamics ``[dz/dt, -div f]`` for state ``y = [z, a]``."""
        z = y[:-1]  # the last entry is the accumulated log-density change

        def velocity(state):
            return neural_dynamics(params, state, t)

        dzdt = velocity(z)

        # Exact divergence: trace of the full Jacobian of the vector field.
        # Cheap in low dimension; a stochastic (Hutchinson) estimator would
        # replace this for high-dimensional states.
        divergence = jnp.trace(jax.jacfwd(velocity)(z))

        return jnp.concatenate([dzdt, jnp.array([-divergence])])

    def cnf_forward(params, z0, t_span=(0.0, 1.0)):
        """Integrate the augmented flow from ``z0`` over ``t_span``.

        Args:
            params: Parameters of the dynamics network.
            z0: Starting point of the flow.
            t_span: ``(t_start, t_end)`` pair of concrete scalars. A
                reversed span (``t_end < t_start``) runs the flow backwards.

        Returns:
            ``(trajectory, log_det_jac, t_points)``: the flow states, the
            augmented component ``a`` along the path (starting at 0), and
            the matching time points.
        """
        t_start, t_end = t_span
        aug_z0 = jnp.concatenate([z0, jnp.array([0.0])])

        if t_end >= t_start:
            dynamics_fn = lambda y, t: cnf_dynamics(params, y, t)
            aug_trajectory, t_points = ode_solve_rk4(
                dynamics_fn, aug_z0, t_span, dt=0.01
            )
        else:
            # Substitute s = t_start - t so the solver only ever sees an
            # increasing time variable; dy/ds = -dy/dt.
            reversed_fn = lambda y, s: -cnf_dynamics(params, y, t_start - s)
            aug_trajectory, s_points = ode_solve_rk4(
                reversed_fn, aug_z0, (0.0, t_start - t_end), dt=0.01
            )
            t_points = t_start - s_points

        trajectory = aug_trajectory[:, :-1]
        log_det_jac = aug_trajectory[:, -1]

        return trajectory, log_det_jac, t_points

    def cnf_log_likelihood(params, x, base_log_prob_fn):
        """Compute the log-likelihood of ``x`` under the CNF model.

        Args:
            params: Parameters of the dynamics network.
            x: Data point, treated as the flow state at ``t = 1``.
            base_log_prob_fn: Log-density of the base distribution.

        Returns:
            ``(log_likelihood, z_base)``: the model log-density at ``x`` and
            the corresponding base-space point at ``t = 0``.
        """
        # Flow x back to the base distribution
        traj, log_det_jac, _ = cnf_forward(params, x, t_span=(1.0, 0.0))
        z_base = traj[-1]

        # Integrating da/dt = -div from t=1 down to t=0 with a(1) = 0 gives
        # a(0) = +integral_0^1 div dt, and log p(x) = log p0(z0) minus that
        # integral, so the accumulated term is subtracted.
        base_log_prob = base_log_prob_fn(z_base)
        log_likelihood = base_log_prob - log_det_jac[-1]

        return log_likelihood, z_base

    return cnf_forward, cnf_log_likelihood

# Build the CNF functions and draw parameters for the dynamics network
cnf_forward, cnf_log_likelihood = create_cnf_model()
cnf_params = init_neural_ode(jax.random.PRNGKey(456), input_dim=2)

# Draw one 2-D sample from the standard-normal base distribution
base_sample = jax.random.normal(jax.random.PRNGKey(789), (2,))
print(f"Base sample: {base_sample}")

# Push the sample through the flow and report the endpoint values
cnf_traj, cnf_log_det, cnf_t = cnf_forward(cnf_params, base_sample)
transformed_sample = cnf_traj[-1]
final_log_det = cnf_log_det[-1]

print(f"Transformed sample: {transformed_sample}")
print(f"Log-det-jacobian: {final_log_det}")

# Test log-likelihood computation
def standard_normal_log_prob(z):
    """Log-density of an isotropic standard normal at a single point.

    Args:
        z: Point to evaluate; all of its entries are treated as the
            coordinates of one sample.

    Returns:
        ``-0.5 * sum(z**2) - 0.5 * d * log(2 * pi)`` with ``d = z.size``.
        For a 2-D point the normaliser reduces to ``log(2 * pi)``.
    """
    z = jnp.asarray(z)
    # The normalising constant scales with the dimension of z
    log_normalizer = 0.5 * z.size * jnp.log(2 * jnp.pi)
    return -0.5 * jnp.sum(z ** 2) - log_normalizer

# Evaluate the model log-likelihood at one data-space point
test_point = jnp.array([0.5, -0.3])
likelihood_result = cnf_log_likelihood(cnf_params, test_point, standard_normal_log_prob)
log_lik, base_z = likelihood_result
print(f"Log-likelihood at test point: {log_lik}")
print(f"Corresponding base point: {base_z}")
```

## Applications: Modeling Physical Systems

```python
def create_physics_informed_ode():
    """Build a Hamiltonian neural ODE with a learned potential energy.

    The Hamiltonian is ``H(q, p) = 0.5 * |p|^2 + V(q)``, where ``V`` is a
    one-hidden-layer tanh network. The initialiser creates a ``(2, 16)``
    first-layer weight, so positions are 2-dimensional.

    Returns:
        ``(init_hamiltonian_params, hamiltonian_dynamics,
        energy_conservation_loss)``.
    """

    def _potential(params, pos):
        # Learned potential V(q): tanh hidden layer followed by a summed readout
        hidden = jax.nn.tanh(pos @ params['w1'] + params['b1'])
        return jnp.sum(hidden @ params['w2'] + params['b2'])

    def _split_state(state):
        # First half of the state is positions, second half is momenta
        half = len(state) // 2
        return state[:half], state[half:]

    def hamiltonian_dynamics(params, y, t):
        """Hamilton's equations for state ``y = [q, p]`` (unit mass).

        Returns ``[dq/dt, dp/dt] = [p, -dV/dq]``; ``t`` is unused.
        """
        q, p = _split_state(y)
        grad_potential = jax.grad(lambda pos: _potential(params, pos))(q)
        return jnp.concatenate([p, -grad_potential])

    def init_hamiltonian_params(key):
        """Draw scaled-normal weights and zero biases for the potential net."""
        key_w1, key_w2 = jax.random.split(key)
        w1 = jax.random.normal(key_w1, (2, 16)) * 0.1
        w2 = jax.random.normal(key_w2, (16,)) * 0.1
        return {'w1': w1, 'b1': jnp.zeros(16), 'w2': w2, 'b2': 0.0}

    def energy_conservation_loss(params, y0, t_span):
        """Variance of the total energy along an RK4 rollout from ``y0``."""
        rollout, _ = ode_solve_rk4(
            lambda y, t: hamiltonian_dynamics(params, y, t),
            y0,
            t_span,
            dt=0.01,
        )

        def total_energy(state):
            q, p = _split_state(state)
            kinetic = 0.5 * jnp.sum(p ** 2)
            return kinetic + _potential(params, q)

        # A perfectly conserved energy would give zero variance
        return jnp.var(jax.vmap(total_energy)(rollout))

    return init_hamiltonian_params, hamiltonian_dynamics, energy_conservation_loss

# Unpack the factory outputs and draw initial potential-network parameters
physics_model = create_physics_informed_ode()
init_ham_params, ham_dynamics, energy_loss_fn = physics_model
ham_params = init_ham_params(jax.random.PRNGKey(999))

# Energy variance of the untrained model over a short rollout
y0_ham = jnp.array([1.0, 0.0, 0.0, 1.0])  # [q1, q2, p1, p2]
ham_t_span = (0.0, 2.0)
energy_loss = energy_loss_fn(ham_params, y0_ham, ham_t_span)
print(f"Initial energy conservation loss: {energy_loss:.6f}")

# Train to minimize energy variance
def train_physics_step(params, y0, t_span, lr=0.01):
    """Run one plain-SGD update on the energy-variance loss.

    Args:
        params: Current potential-network parameters (a pytree).
        y0: Initial state for the rollout.
        t_span: ``(t_start, t_end)`` pair for the rollout.
        lr: Learning rate.

    Returns:
        ``(new_params, loss)`` where ``loss`` is evaluated at the incoming
        ``params``, i.e. before the update.
    """
    loss, grads = jax.value_and_grad(energy_loss_fn)(params, y0, t_span)
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated and absent from newer JAX releases.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

print("Training Physics-Informed ODE:")
for epoch in range(20):
    ham_params, loss = train_physics_step(ham_params, y0_ham, (0.0, 2.0))

    if epoch % 5 == 0:
        print(f"Epoch {epoch}: Energy variance = {loss:.6f}")

print(f"Final energy conservation loss: {loss:.6f}")

# Test trained Hamiltonian system
# `hamiltonian_dynamics` is local to create_physics_informed_ode(); the
# handle returned by the factory is bound to `ham_dynamics`.
ham_dynamics_trained = lambda y, t: ham_dynamics(ham_params, y, t)
ham_trajectory, ham_t = ode_solve_rk4(ham_dynamics_trained, y0_ham, (0.0, 5.0), dt=0.01)

print(f"Hamiltonian trajectory shape: {ham_trajectory.shape}")
print(f"Initial state: {ham_trajectory[0]}")
print(f"Final state: {ham_trajectory[-1]}")
```

## Summary

In this notebook, we explored differentiable ODEs in JAX:

**Core Concepts:**
- **ODE Solvers**: Euler, Runge-Kutta, and adaptive methods
- **Neural ODEs**: Parameterized dynamics with neural networks  
- **Adjoint Method**: Memory-efficient gradient computation
- **Continuous Flows**: Normalizing flows with ODE dynamics

**Key Applications:**
- **Dynamical Systems**: Modeling physical and biological systems
- **Generative Models**: Continuous normalizing flows for density estimation
- **Time Series**: Neural ODEs for irregular time series modeling
- **Physics-Informed**: Incorporating physical constraints and conservation laws

**Training Strategies:**
- **Trajectory Matching**: Fit entire solution paths
- **Endpoint Fitting**: Match only final states
- **Conservation Laws**: Physics-informed losses
- **Likelihood Training**: Maximum likelihood for generative models

**Advanced Techniques:**
- **Augmented Dynamics**: Including log-determinant computation
- **Stochastic ODEs**: Adding noise for uncertainty quantification (a natural extension; not implemented in this notebook)
- **Hamiltonian Systems**: Energy-conserving dynamics
- **Multi-Scale**: Handling systems with different time scales (a natural extension; not implemented in this notebook)

**Computational Benefits:**
- **Automatic Differentiation**: End-to-end gradient computation
- **Adaptive Integration**: Error control and efficiency
- **Memory Efficiency**: Adjoint method for long trajectories
- **Parallelization**: Vectorized operations across initial conditions

Neural ODEs provide a powerful framework for learning continuous-time dynamics while leveraging the full power of automatic differentiation and modern optimization techniques in JAX.