In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
from scipy.linalg import expm as scipy_expm
import time
import matplotlib.pyplot as plt
DTYPE_FP64 = jnp.float64

# Generate stiff dense A matrices
def generate_stiff_A(key, hidden_dim, input_dim, stiffness_factor=50):
    keys = jax.random.split(key, num=input_dim*2)
    A_list = []
    for i in range(input_dim):
        Q_i, _ = jnp.linalg.qr(jax.random.normal(keys[2*i], (hidden_dim, hidden_dim), dtype=DTYPE_FP64))
        eigvals = jnp.linspace(-stiffness_factor, 0.0, hidden_dim)
        D_i = jnp.diag(eigvals)
        A_i = Q_i @ D_i @ jnp.linalg.inv(Q_i)
        A_list.append(A_i)
    return jnp.stack(A_list, axis=0)

# Generate data paths with irregular sampling
def generate_paths(key, num_obs, d_omega):
    key1, key2 = jax.random.split(key)
    regular_times = jnp.linspace(0, 10, num_obs + 1)
    perturbations = jax.random.uniform(key1, (num_obs - 1,), minval=-0.03, maxval=0.03)
    irregular_times = jnp.sort(regular_times.at[1:-1].add(perturbations))
    dt_irregular = jnp.diff(irregular_times)
    dW = jax.random.normal(key2, (num_obs, d_omega), dtype=DTYPE_FP64) * jnp.sqrt(dt_irregular[:, None])
    omega_observed = jnp.vstack([jnp.zeros((1, d_omega), dtype=DTYPE_FP64), jnp.cumsum(dW, axis=0)])
    return irregular_times, omega_observed, dW

# Scaling and squaring w Padé
def expm_pade_with_scaling(A):
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    A_norm = jnp.linalg.norm(A, ord=1)
    theta_13 = 1
    A_norm_safe = jnp.maximum(A_norm, 1e-16)
    s = jnp.maximum(0, jnp.ceil(jnp.log2(A_norm_safe / theta_13))).astype(jnp.int32)
    A_scaled = A / (2.0 ** s)
    b = jnp.array([
        64764752532480000, 32382376266240000, 7771770303897600,
        1187353796428800, 129060195264000, 10559470521600,
        670442572800, 33522128640, 1323241920,
        40840800, 960960, 16380,
        182, 1
    ], dtype=A.dtype)
    A2 = A_scaled @ A_scaled
    A4 = A2 @ A2
    A6 = A2 @ A4
    U = A_scaled @ (A6 @ (b[13] * A6 + b[11] * A4 + b[9] * A2) + b[7]*A6 + b[5]*A4 + b[3]*A2 + b[1]*I)
    V = A6 @ (b[12]*A6 + b[10]*A4 + b[8]*A2) + b[6]*A6 + b[4]*A4 + b[2]*A2 + b[0]*I
    R = jnp.linalg.solve(-U + V, U + V)
    def body_fun(_, R_inner):
        return R_inner @ R_inner
    R = jax.lax.fori_loop(0, s, body_fun, R)
    return R, s

# Linear CDE
class LinearCDE:
    def __init__(self, input_dim, hidden_dim, A_matrices, g_scale=2.0):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.A_matrices = A_matrices
        self.g_scale = g_scale
        self.times = None
        self.omega_vals = None

    def set_linear_path(self, times, omega_vals):
        self.times = jnp.array(times, dtype=DTYPE_FP64)
        self.omega_vals = jnp.array(omega_vals, dtype=DTYPE_FP64)

    def evaluate_omega_at_time(self, t_eval):
        # Find in which interval t_eval is
        interval_idx = jnp.searchsorted(self.times, t_eval) - 1
        interval_idx = jnp.clip(interval_idx, 0, len(self.times) - 2)
        t_start = self.times[interval_idx]
        t_end = self.times[interval_idx + 1]
        # Get omega values at interval endpoints
        omega_start = self.omega_vals[interval_idx]
        omega_end = self.omega_vals[interval_idx + 1]
        # Linear interpolation
        alpha = (t_eval - t_start) / (t_end - t_start)
        omega_at_t = omega_start + alpha * (omega_end - omega_start)
        return omega_at_t

    def fixed_step_generator(self, t_start, t_end):
        omega_start = self.evaluate_omega_at_time(t_start)
        omega_end = self.evaluate_omega_at_time(t_end)
        delta_omega = omega_end - omega_start
        G = self.g_scale * jnp.einsum('i,ijk->jk', delta_omega, self.A_matrices)
        return G

    def evaluate_generator_at_time(self, t_eval):

        interval_idx = jnp.searchsorted(self.times, t_eval) - 1
        interval_idx = jnp.clip(interval_idx, 0, len(self.times) - 2)
        t_start = self.times[interval_idx]
        t_end = self.times[interval_idx + 1]
        omega_start = self.omega_vals[interval_idx]
        omega_end = self.omega_vals[interval_idx + 1]
        interval_length = t_end - t_start
        omega_rate = (omega_end - omega_start) / interval_length  # dω/dt in that interval
        G_inst = self.g_scale * jnp.einsum('i,ijk->jk', omega_rate, self.A_matrices)
        return G_inst

    def estimate_starting_step_size(self, h0, t0, p, Atol=1e-6, Rtol=1e-3):
        # Compute d0 and d1
        d0 = jnp.linalg.norm(h0)
        A_t0 = self.evaluate_generator_at_time(t0)
        f0 = A_t0 @ h0
        d1 = jnp.linalg.norm(f0)

        # Handle cases where d0 or d1 are very small
        d0_safe = jnp.maximum(d0, 1e-5)
        d1_safe = jnp.maximum(d1, 1e-5)
        h0_step = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0_safe / d1_safe)

        # Perform one Euler step
        h1 = h0 + h0_step * f0
        f1 = self.evaluate_generator_at_time(t0 + h0_step) @ h1

        # Estimate second derivative
        d2 = jnp.linalg.norm(f1 - f0) / h0_step

        # Second guess for step size
        max_d1_d2 = jnp.maximum(d1, d2)

        # Handle case where max(d1, d2) is small
        h1_step = jnp.where(
            max_d1_d2 <= 1e-15,
            jnp.maximum(1e-6, h0_step * 1e-3),
            (0.01 / max_d1_d2) ** (1.0 / (p + 1))
        )

        # Final step size
        dt_start = jnp.minimum(100 * h0_step, h1_step)
        return float(dt_start)

    # Fixed step method
    def forward_fixed(self, times, omega_vals, h0, method="pade_fixed", dt_fixed=0.01):
        self.set_linear_path(times, omega_vals)
        t_min, t_max = float(times[0]), float(times[-1])
        fixed_times = np.arange(t_min, t_max + dt_fixed/2, dt_fixed)

        # Compute generator matrices
        G_seq = []
        for i in range(len(fixed_times) - 1):
            t_start, t_end = fixed_times[i], fixed_times[i+1]
            G = self.fixed_step_generator(t_start, t_end)
            G_seq.append(G)
        G_seq = jnp.array(G_seq)
        matrix_norms = jnp.array([jnp.linalg.norm(G, ord=1) for G in G_seq])

        def compute_expm(G):
            R, s = expm_pade_with_scaling(G)
            return R, s
        exp_seq, scales = jax.vmap(compute_expm)(G_seq)

        # Associative scan
        def matmul(x, y):
            return x @ y
        cumulative_products = jax.lax.associative_scan(matmul, exp_seq[::-1])
        h_final = cumulative_products[-1] @ h0

        return h_final, fixed_times, matrix_norms, scales

    # Midpoint rule
    def magnus2(self, h, t, dt):

        t_mid = t + dt/2
        G_mid = self.evaluate_generator_at_time(t_mid)
        Omega = dt * G_mid
        exp_Omega, _ = expm_pade_with_scaling(Omega)
        return exp_Omega @ h

    # Step doubling
    def step_doubling(self, h, t, dt):
        h_coarse = self.magnus2(h, t, dt)
        h_half = self.magnus2(h, t, dt/2)
        h_fine = self.magnus2(h_half, t + dt/2, dt/2)
        return h_fine, h_coarse

    # Magnus embedded method
    def magnus42_step_correct(self, h, t, dt):

        # Gauss-Legendre nodes for interval
        sqrt3_inv = 1.0 / jnp.sqrt(3.0)
        xi_minus = 0.5 * (1.0 - sqrt3_inv)
        xi_plus = 0.5 * (1.0 + sqrt3_inv)
        t_minus = t + dt * xi_minus
        t_plus = t + dt * xi_plus

        # Evaluate generator at nodes
        A_minus = self.evaluate_generator_at_time(t_minus)
        A_plus = self.evaluate_generator_at_time(t_plus)

        # Second-order Magnus term
        Omega2 = dt * 0.5 * (A_minus + A_plus)

        # Fourth-order Magnus term
        commutator = A_minus @ A_plus - A_plus @ A_minus
        Omega4 = Omega2 - (jnp.sqrt(3.0) * dt**2 / 12.0) * commutator

        # Compute matrix exponentials
        exp_Omega2, _ = expm_pade_with_scaling(Omega2)
        exp_Omega4, _ = expm_pade_with_scaling(Omega4)

        # Low-order solution
        h_low = exp_Omega2 @ h
        # High-order solution
        h_high = exp_Omega4 @ h
        return h_high, h_low

    # Adaptive methods
    def forward_adaptive(self, times, omega_vals, h0, method="adaptive_stepdoubling",
                              Atol=1e-6, Rtol=1e-3, fac_safety=0.9,
                              fac_min=0.2, fac_max=5.0, max_iters=1000):
        self.set_linear_path(times, omega_vals)
        h = h0
        t = float(times[0])
        t_end = float(times[-1])
        p = 2 if method == "adaptive_stepdoubling" else 4

        # Estimate starting step size
        dt = self.estimate_starting_step_size(h0, t, p, Atol, Rtol)
        print(f"Estimated starting step size: {dt:.6e}")

        hs, ts = [h0], [t]
        iters = 0
        rejections = 0
        while t < t_end and iters < max_iters:
            iters += 1
            if t + dt > t_end:
                dt = t_end - t
            if method == "adaptive_stepdoubling":
                h_fine, h_coarse = self.step_doubling(h, t, dt)
            elif method == "adaptive_embedded":
                h_fine, h_coarse = self.magnus42_step_correct(h, t, dt)
            else:
                raise NotImplementedError("Unknown method")

            # Error control
            sci = Atol + jnp.maximum(jnp.abs(h), jnp.abs(h_fine)) * Rtol
            local_errors = jnp.abs((h_fine - h_coarse) / sci)
            err = jnp.max(local_errors)
            accept = err <= 1.0
            if accept:
                t += dt
                h = h_fine
                hs.append(h)
                ts.append(t)
            else:
                rejections += 1

            # Step size update
            dt_new = dt * fac_safety * (1.0 / jnp.maximum(err, 1e-6))**(1.0/(p+1))
            dt = float(jnp.clip(dt_new, fac_min*dt, fac_max*dt))
        print(f"Method: {method}, Total iterations: {iters}, Rejections: {rejections}")
        return jnp.array(ts), jnp.array(hs)

# Euler ground truth
def euler_ground_truth(A_matrices, delta_omega, h0, num_substeps=100000, g_scale=1):
    I = jnp.eye(h0.shape[0], dtype=h0.dtype)
    def step(h, dW):
        dW_sub = dW / num_substeps
        G_sub = g_scale * jnp.einsum('i,ijk->jk', dW_sub, A_matrices)
        def substep(h_inner, _):
            return (I + G_sub) @ h_inner, None
        h_final, _ = jax.lax.scan(substep, h, None, length=num_substeps)
        return h_final, None
    h_final, _ = jax.lax.scan(step, h0, delta_omega)
    return h_final

# Scipy expm
def scipy_method(A_matrices, delta_omega, h0, g_scale):
    h = np.array(h0, dtype=float)
    for t in range(delta_omega.shape[0]):
        G_seq = g_scale * np.einsum('i,ijk->jk', np.array(delta_omega[t]), np.array(A_matrices))
        h = scipy_expm(G_seq) @ h
    return jnp.array(h)

# Pade per observation
def pade_obs_method(A_matrices, delta_omega, h0, g_scale):
    h = jnp.array(h0, dtype=DTYPE_FP64)
    s_list = []
    norms = []
    for t in range(delta_omega.shape[0]):
        G = g_scale * jnp.einsum('i,ijk->jk', delta_omega[t], A_matrices)
        G_norm = jnp.linalg.norm(G, ord=1)
        E, s = expm_pade_with_scaling(G)
        h = E @ h
        s_list.append(int(s))
        norms.append(float(G_norm))
    return h, np.array(s_list, dtype=int), np.array(norms, dtype=float)


In [21]:
# GPU warmup
print("GPU warmup")
key_warmup = jax.random.PRNGKey(0)
input_dim_warmup = 16
hidden_dim_warmup = 16
num_obs_warmup = 50
A_warmup = generate_stiff_A(key_warmup, hidden_dim_warmup, input_dim_warmup)
t_warmup, omega_warmup, delta_warmup = generate_paths(key_warmup, num_obs_warmup, input_dim_warmup)
h0_warmup = jnp.ones(hidden_dim_warmup, dtype=DTYPE_FP64)
model_warmup = LinearCDE(input_dim_warmup, hidden_dim_warmup, A_warmup, g_scale=0.01)

# Warmup each method
_ = euler_ground_truth(A_warmup, delta_warmup, h0_warmup, num_substeps=1000, g_scale=0.01)
_ = scipy_method(A_warmup, delta_warmup, h0_warmup, 0.01)
_ = pade_obs_method(A_warmup, delta_warmup, h0_warmup, 0.01)
_ = model_warmup.forward_fixed(t_warmup, omega_warmup, h0_warmup, method="pade_fixed", dt_fixed=1.0)
_ = model_warmup.forward_adaptive(t_warmup, omega_warmup, h0_warmup, method="adaptive_stepdoubling")
_ = model_warmup.forward_adaptive(t_warmup, omega_warmup, h0_warmup, method="adaptive_embedded")
print("GPU warmup completed \n")

# Run all experiments
num_runs = 10
dt_fixed_values = [1, 2, 5, 10]
stiffness_factors = [1, 5, 10, 20]
all_results = {}

# Results storage
for dt_fixed in dt_fixed_values:
    all_results[dt_fixed] = {}
    for stiff_fac in stiffness_factors:
        all_results[dt_fixed][stiff_fac] = []

methods = [
    "Euler",
    "SciPy expm",
    "Pade per-observation",
    "Fixed step",
    "Adaptive step doubling",
    "Adaptive Magnus42"
]

print(f"Running {num_runs} experiments for each dt_fixed and stiffness_factor combination...")

for dt_fixed in dt_fixed_values:
    print(f"\nTesting dt_fixed = {dt_fixed}")
    for stiff_fac in stiffness_factors:
        print(f"  Stiffness factor = {stiff_fac}")
        for run_idx in range(num_runs):
            print(f"    Run {run_idx + 1}/{num_runs}")

            # Generate new data for this run
            key = jax.random.PRNGKey(42 + run_idx)
            input_dim = 32
            hidden_dim = 32
            num_obs = 200
            g_scale = 0.01
            h0 = jnp.ones(hidden_dim, dtype=DTYPE_FP64)
            t_vals, omega_vals, _ = generate_paths(key, num_obs, input_dim)
            delta_omega = jnp.diff(omega_vals, axis=0)
            A_mats = generate_stiff_A(key, hidden_dim, input_dim, stiffness_factor=stiff_fac)
            model = LinearCDE(input_dim, hidden_dim, A_mats, g_scale=g_scale)

            # Run all methods for this experiment
            results = {}

            # Euler ground truth
            t0 = time.time()
            h_euler = euler_ground_truth(A_mats, delta_omega, h0, num_substeps=1000, g_scale=g_scale)
            euler_time = time.time() - t0
            euler_norm = jnp.linalg.norm(h_euler)
            results["Euler"] = {'time': euler_time, 'rel_error': 0.0}

            # SciPy expm
            t0 = time.time()
            h_scipy = scipy_method(A_mats, delta_omega, h0, g_scale)
            scipy_time = time.time() - t0
            scipy_rel_error = float(jnp.linalg.norm(h_scipy - h_euler) / euler_norm)
            results["SciPy expm"] = {'time': scipy_time, 'rel_error': scipy_rel_error}

            # Pade per observation
            t0 = time.time()
            h_pade_obs, _, _ = pade_obs_method(A_mats, delta_omega, h0, g_scale)
            pade_obs_time = time.time() - t0
            pade_obs_rel_error = float(jnp.linalg.norm(h_pade_obs - h_euler) / euler_norm)
            results["Pade per-observation"] = {'time': pade_obs_time, 'rel_error': pade_obs_rel_error}

            # Fixed step
            t0 = time.time()
            h_pade_parallel, _, _, _ = model.forward_fixed(
                t_vals, omega_vals, h0, method="pade_fixed", dt_fixed=dt_fixed
            )
            pade_parallel_time = time.time() - t0
            pade_parallel_rel_error = float(jnp.linalg.norm(h_pade_parallel - h_euler) / euler_norm)
            results["Fixed step"] = {'time': pade_parallel_time, 'rel_error': pade_parallel_rel_error}

            # Adaptive Step doubling
            t0 = time.time()
            ts_sd, hs_sd = model.forward_adaptive(t_vals, omega_vals, h0, method="adaptive_stepdoubling")
            adaptive_sd_time = time.time() - t0
            adaptive_sd_rel_error = float(jnp.linalg.norm(hs_sd[-1] - h_euler) / euler_norm)
            results["Adaptive step doubling"] = {'time': adaptive_sd_time, 'rel_error': adaptive_sd_rel_error}

            # Embedded Magnus
            t0 = time.time()
            ts_emb, hs_emb = model.forward_adaptive(t_vals, omega_vals, h0, method="adaptive_embedded")
            adaptive_emb_time = time.time() - t0
            adaptive_emb_rel_error = float(jnp.linalg.norm(hs_emb[-1] - h_euler) / euler_norm)
            results["Adaptive Magnus42"] = {'time': adaptive_emb_time, 'rel_error': adaptive_emb_rel_error}

            # Store results
            all_results[dt_fixed][stiff_fac].append(results)

# Aggregate results for all combinations
print(f"\n{'='*80}")
print("Results")
print(f"{'='*80}")

for dt_fixed in dt_fixed_values:
    for stiff_fac in stiffness_factors:
        print(f"\ndt_fixed = {dt_fixed}, stiffness_factor = {stiff_fac}")
        print(f"{'Method':<35} {'Rel Error':<25} {'Time (s)':<20}")
        print("-" * 80)
        aggregated_results = {}
        for method in methods:
            times = [run[method]['time'] for run in all_results[dt_fixed][stiff_fac]]
            rel_errors = [run[method]['rel_error'] for run in all_results[dt_fixed][stiff_fac]]
            aggregated_results[method] = {
                'time_mean': jnp.mean(jnp.array(times)),
                'time_std': jnp.std(jnp.array(times)),
                'rel_error_mean': jnp.mean(jnp.array(rel_errors)),
                'rel_error_std': jnp.std(jnp.array(rel_errors))}

        # Results
        for method in methods:
            stats = aggregated_results[method]
            time_str = f"{stats['time_mean']:.6f} ± {stats['time_std']:.6f}"
            if stats['rel_error_mean'] == 0.0:
                error_str = "0.00e+00 ± 0.00e+00"
            else:
                error_str = f"{stats['rel_error_mean']:.2e} ± {stats['rel_error_std']:.2e}"

            print(f"{method:<35} {error_str:<25} {time_str:<20}")

GPU warmup
Estimated starting step size: 7.923867e-02
Method: adaptive_stepdoubling, Total iterations: 267, Rejections: 132
Estimated starting step size: 2.184552e-01
Method: adaptive_embedded, Total iterations: 179, Rejections: 99
GPU warmup completed 

Running 10 experiments for each dt_fixed and stiffness_factor combination...

Testing dt_fixed = 1
  Stiffness factor = 1
    Run 1/10
Estimated starting step size: 8.309550e-02
Method: adaptive_stepdoubling, Total iterations: 624, Rejections: 366
Estimated starting step size: 2.247742e-01
Method: adaptive_embedded, Total iterations: 25, Rejections: 2
    Run 2/10
Estimated starting step size: 1.076926e-01
Method: adaptive_stepdoubling, Total iterations: 614, Rejections: 364
Estimated starting step size: 2.626101e-01
Method: adaptive_embedded, Total iterations: 26, Rejections: 4
    Run 3/10
Estimated starting step size: 1.848991e-01
Method: adaptive_stepdoubling, Total iterations: 660, Rejections: 390
Estimated starting step size: 3.6

In [5]:
# GPU warmup
print("GPU warmup")
key_warmup = jax.random.PRNGKey(0)
input_dim_warmup = 16
hidden_dim_warmup = 16
num_obs_warmup = 50
A_warmup = generate_stiff_A(key_warmup, hidden_dim_warmup, input_dim_warmup)
t_warmup, omega_warmup, delta_warmup = generate_paths(key_warmup, num_obs_warmup, input_dim_warmup)
h0_warmup = jnp.ones(hidden_dim_warmup, dtype=DTYPE_FP64)
model_warmup = LinearCDE(input_dim_warmup, hidden_dim_warmup, A_warmup, g_scale=0.01)

# Warmup each method
_ = euler_ground_truth(A_warmup, delta_warmup, h0_warmup, num_substeps=1000, g_scale=0.01)
_ = scipy_method(A_warmup, delta_warmup, h0_warmup, 0.01)
_ = pade_obs_method(A_warmup, delta_warmup, h0_warmup, 0.01)
_ = model_warmup.forward_fixed(t_warmup, omega_warmup, h0_warmup, method="pade_fixed", dt_fixed=1.0)
_ = model_warmup.forward_adaptive(t_warmup, omega_warmup, h0_warmup, method="adaptive_stepdoubling")
_ = model_warmup.forward_adaptive(t_warmup, omega_warmup, h0_warmup, method="adaptive_embedded")
print("GPU warmup completed \n")

# Run all experiments
num_runs = 10
dt_fixed_values = [1, 2, 5, 10]
stiffness_factors = [50]
all_results = {}

# Results storage
for stiff_fac in stiffness_factors:
    all_results[stiff_fac] = {}
    for dt_fixed in dt_fixed_values:
        all_results[stiff_fac][dt_fixed] = []
    all_results[stiff_fac]['adaptive'] = []

fixed_methods = [
    "Euler",
    "SciPy expm",
    "Pade per-observation",
    "Fixed step"
]

adaptive_methods = [
    "Adaptive step doubling",
    "Adaptive Magnus42"
]

print(f"Running {num_runs} experiments for each dt_fixed and stiffness_factor combination...")

for stiff_fac in stiffness_factors:
    print(f"\nTesting stiffness factor = {stiff_fac}")
    print("  Running adaptive methods...")
    for run_idx in range(num_runs):
        print(f"    Adaptive run {run_idx + 1}/{num_runs}")

        # Generate new data for this run
        key = jax.random.PRNGKey(42 + run_idx)
        input_dim = 32
        hidden_dim = 32
        num_obs = 200
        g_scale = 0.01
        h0 = jnp.ones(hidden_dim, dtype=DTYPE_FP64)
        t_vals, omega_vals, _ = generate_paths(key, num_obs, input_dim)
        delta_omega = jnp.diff(omega_vals, axis=0)
        A_mats = generate_stiff_A(key, hidden_dim, input_dim, stiffness_factor=stiff_fac)
        model = LinearCDE(input_dim, hidden_dim, A_mats, g_scale=g_scale)

        # Euler ground truth
        h_euler = euler_ground_truth(A_mats, delta_omega, h0, num_substeps=1000, g_scale=g_scale)
        euler_norm = jnp.linalg.norm(h_euler)

        # Run adaptive methods
        adaptive_results = {}

        # Adaptive Step doubling
        t0 = time.time()
        ts_sd, hs_sd = model.forward_adaptive(t_vals, omega_vals, h0, method="adaptive_stepdoubling")
        adaptive_sd_time = time.time() - t0
        adaptive_sd_rel_error = float(jnp.linalg.norm(hs_sd[-1] - h_euler) / euler_norm)
        adaptive_results["Adaptive step doubling"] = {'time': adaptive_sd_time, 'rel_error': adaptive_sd_rel_error}

        # Embedded Magnus
        t0 = time.time()
        ts_emb, hs_emb = model.forward_adaptive(t_vals, omega_vals, h0, method="adaptive_embedded")
        adaptive_emb_time = time.time() - t0
        adaptive_emb_rel_error = float(jnp.linalg.norm(hs_emb[-1] - h_euler) / euler_norm)
        adaptive_results["Adaptive Magnus42"] = {'time': adaptive_emb_time, 'rel_error': adaptive_emb_rel_error}

        all_results[stiff_fac]['adaptive'].append(adaptive_results)

    # Fixed step methods
    for dt_fixed in dt_fixed_values:
        print(f"  Testing dt_fixed = {dt_fixed}")
        for run_idx in range(num_runs):
            print(f"    Fixed step run {run_idx + 1}/{num_runs}")

            # Generate same data as adaptive methods for this run
            key = jax.random.PRNGKey(42 + run_idx)
            input_dim = 32
            hidden_dim = 32
            num_obs = 200
            g_scale = 0.01
            h0 = jnp.ones(hidden_dim, dtype=DTYPE_FP64)
            t_vals, omega_vals, _ = generate_paths(key, num_obs, input_dim)
            delta_omega = jnp.diff(omega_vals, axis=0)
            A_mats = generate_stiff_A(key, hidden_dim, input_dim, stiffness_factor=stiff_fac)
            model = LinearCDE(input_dim, hidden_dim, A_mats, g_scale=g_scale)

            # Fixed methods
            fixed_results = {}

            # Euler ground truth
            t0 = time.time()
            h_euler = euler_ground_truth(A_mats, delta_omega, h0, num_substeps=1000, g_scale=g_scale)
            euler_time = time.time() - t0
            euler_norm = jnp.linalg.norm(h_euler)
            fixed_results["Euler"] = {'time': euler_time, 'rel_error': 0.0}

            # SciPy expm
            t0 = time.time()
            h_scipy = scipy_method(A_mats, delta_omega, h0, g_scale)
            scipy_time = time.time() - t0
            scipy_rel_error = float(jnp.linalg.norm(h_scipy - h_euler) / euler_norm)
            fixed_results["SciPy expm"] = {'time': scipy_time, 'rel_error': scipy_rel_error}

            # Pade per observation
            t0 = time.time()
            h_pade_obs, _, _ = pade_obs_method(A_mats, delta_omega, h0, g_scale)
            pade_obs_time = time.time() - t0
            pade_obs_rel_error = float(jnp.linalg.norm(h_pade_obs - h_euler) / euler_norm)
            fixed_results["Pade per-observation"] = {'time': pade_obs_time, 'rel_error': pade_obs_rel_error}

            # Fixed step
            t0 = time.time()
            h_pade_parallel, _, _, _ = model.forward_fixed(
                t_vals, omega_vals, h0, method="pade_fixed", dt_fixed=dt_fixed
            )
            pade_parallel_time = time.time() - t0
            pade_parallel_rel_error = float(jnp.linalg.norm(h_pade_parallel - h_euler) / euler_norm)
            fixed_results["Fixed step"] = {'time': pade_parallel_time, 'rel_error': pade_parallel_rel_error}

            all_results[stiff_fac][dt_fixed].append(fixed_results)

# Aggregate results
print(f"\n{'='*80}")
print("Results")
print(f"{'='*80}")

for stiff_fac in stiffness_factors:
    for dt_fixed in dt_fixed_values:
        print(f"\ndt_fixed = {dt_fixed}, stiffness_factor = {stiff_fac}")
        print(f"{'Method':<35} {'Rel Error':<25} {'Time (s)':<20}")
        print("-" * 80)
        for method in fixed_methods:
            times = [run[method]['time'] for run in all_results[stiff_fac][dt_fixed]]
            rel_errors = [run[method]['rel_error'] for run in all_results[stiff_fac][dt_fixed]]
            time_mean = jnp.mean(jnp.array(times))
            time_std = jnp.std(jnp.array(times))
            rel_error_mean = jnp.mean(jnp.array(rel_errors))
            rel_error_std = jnp.std(jnp.array(rel_errors))
            time_str = f"{time_mean:.6f} ± {time_std:.6f}"
            if rel_error_mean == 0.0:
                error_str = "0.00e+00 ± 0.00e+00"
            else:
                error_str = f"{rel_error_mean:.2e} ± {rel_error_std:.2e}"
            print(f"{method:<35} {error_str:<25} {time_str:<20}")

        # Adaptive method results
        if dt_fixed == dt_fixed_values[0]:
            for method in adaptive_methods:
                times = [run[method]['time'] for run in all_results[stiff_fac]['adaptive']]
                rel_errors = [run[method]['rel_error'] for run in all_results[stiff_fac]['adaptive']]
                time_mean = jnp.mean(jnp.array(times))
                time_std = jnp.std(jnp.array(times))
                rel_error_mean = jnp.mean(jnp.array(rel_errors))
                rel_error_std = jnp.std(jnp.array(rel_errors))
                time_str = f"{time_mean:.6f} ± {time_std:.6f}"
                error_str = f"{rel_error_mean:.2e} ± {rel_error_std:.2e}"
                print(f"{method:<35} {error_str:<25} {time_str:<20}")
        else:
            print(f"{'Adaptive methods':<35} {'(same as dt_fixed=1)':<45}")

GPU warmup
Estimated starting step size: 7.923867e-02
Method: adaptive_stepdoubling, Total iterations: 267, Rejections: 132
Estimated starting step size: 2.184552e-01
Method: adaptive_embedded, Total iterations: 179, Rejections: 99
GPU warmup completed 

Running 10 experiments for each dt_fixed and stiffness_factor combination...

Testing stiffness factor = 50
  Running adaptive methods...
    Adaptive run 1/10
Estimated starting step size: 2.450431e-02
Method: adaptive_stepdoubling, Total iterations: 763, Rejections: 380
Estimated starting step size: 1.078657e-01
Method: adaptive_embedded, Total iterations: 713, Rejections: 418
    Adaptive run 2/10
Estimated starting step size: 3.303682e-02
Method: adaptive_stepdoubling, Total iterations: 846, Rejections: 422
Estimated starting step size: 1.292406e-01
Method: adaptive_embedded, Total iterations: 665, Rejections: 383
    Adaptive run 3/10
Estimated starting step size: 1.937422e-02
Method: adaptive_stepdoubling, Total iterations: 709, 