# Tutorial 6: Custom Solver Integration

**Learning Objectives:**
- Use VectorBuilder with custom ODE solvers
- Integrate JAX/Diffrax for GPU-accelerated solving
- Compare different solver backends (Diffsol, Diffrax, DifferentialEquations.jl)
- Understand performance trade-offs between solvers

**Prerequisites:** Tutorials 1-2, familiarity with JAX (optional)

**Runtime:** ~15 minutes

## Introduction

While `DiffsolBuilder` provides a convenient interface for fitting ODE models written in the DiffSL language, `VectorBuilder` offers maximum flexibility: it accepts any Python callable that returns model predictions, so you can use **any** ODE solver. This enables:

1. **GPU acceleration** via JAX/Diffrax
2. **Specialised solvers** from Julia's DifferentialEquations.jl
3. **Custom dynamics** that don't fit the DiffSL syntax
4. **Pre-existing code** integration

In this tutorial, we'll fit the parameters of the predator-prey model with three different backends and compare their run time and fit quality.

In [None]:
# Imports used throughout the tutorial
import time

import diffid
import matplotlib.pyplot as plt
import numpy as np

## The Lotka-Volterra Model

The predator-prey dynamics are described by:

$$\frac{dx}{dt} = \alpha x - \beta xy \quad \text{(prey growth and predation)}$$

$$\frac{dy}{dt} = \delta xy - \gamma y \quad \text{(predator growth and death)}$$

where:
- $x$ = prey population
- $y$ = predator population  
- $\alpha$ = prey birth rate
- $\beta$ = predation rate
- $\delta$ = predator reproduction efficiency
- $\gamma$ = predator death rate
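Two properties of this model make handy solver-independent sanity checks. Setting both derivatives to zero gives the coexistence equilibrium $(x^*, y^*) = (\gamma/\delta,\ \alpha/\beta)$ that trajectories orbit, and the quantity $V(x, y) = \delta x - \gamma \ln x + \beta y - \alpha \ln y$ is exactly conserved along true solutions. The sketch below uses only NumPy and SciPy with the tutorial's true parameters; the helper names are illustrative, not part of diffid. A noticeable drift in $V$ suggests a solver's tolerances are too loose.

```python
import numpy as np
from scipy.integrate import solve_ivp


def lv_rhs(t, state, alpha, beta, delta, gamma):
    """Lotka-Volterra right-hand side (same equations as above)."""
    x, y = state
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]


def lv_invariant(x, y, alpha, beta, delta, gamma):
    """V(x, y) is constant along exact solutions (requires x, y > 0)."""
    return delta * x - gamma * np.log(x) + beta * y - alpha * np.log(y)


params = (1.1, 0.4, 0.1, 0.4)  # alpha, beta, delta, gamma
alpha, beta, delta, gamma = params

# Coexistence equilibrium: both derivatives vanish with x, y > 0
x_star, y_star = gamma / delta, alpha / beta

# Integrate with tight tolerances and measure how far V drifts from its start value
sol = solve_ivp(lv_rhs, (0.0, 15.0), [10.0, 5.0], args=params, rtol=1e-10, atol=1e-10)
V = lv_invariant(sol.y[0], sol.y[1], *params)
drift = np.max(np.abs(V - V[0]))

print(f"Equilibrium (x*, y*) = ({x_star:.2f}, {y_star:.2f})")
print(f"Max invariant drift:  {drift:.2e}")
```

With these parameters the equilibrium is $(4, 2.75)$, the point around which the oscillations in the plots below circle.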

## Method 1: DiffsolBuilder (Baseline)

First, let's generate synthetic data and fit the model with the built-in `DiffsolBuilder` as a baseline:

In [None]:
# Generate synthetic data using scipy
from scipy.integrate import solve_ivp

np.random.seed(42)

# True parameters
true_params = {
    "alpha": 1.1,  # prey birth rate
    "beta": 0.4,  # predation rate
    "delta": 0.1,  # predator efficiency
    "gamma": 0.4,  # predator death rate
}

# Time points
t_data = np.linspace(0, 15, 50)

# DiffSL model for fitting
model_str = """
in_i { alpha = 2.0/3.0, beta = 4.0/3.0, delta = 1.0, gamma = 1.0 }
x0 { 10.0 } y0 { 5.0 }
u_i {
    y1 = x0,
    y2 = y0,
}
F_i {
    alpha * y1 - beta * y1 * y2,
    delta * y1 * y2 - gamma * y2,
}
"""


# Define ODE for scipy
def lotka_volterra(t, state, alpha, beta, delta, gamma):
    x, y = state
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]


# Generate the noise-free "true" solution with scipy (noise is added below)
y0 = [10.0, 5.0]  # initial conditions
sol = solve_ivp(
    lotka_volterra,
    [t_data[0], t_data[-1]],
    y0,
    args=(
        true_params["alpha"],
        true_params["beta"],
        true_params["delta"],
        true_params["gamma"],
    ),
    t_eval=t_data,
    method="RK45",
)

y_true = sol.y.T  # Shape: (n_times, 2)
y_observed = y_true + np.random.normal(0, 0.5, y_true.shape)

print(f"Data shape: {y_observed.shape}")
print(f"Time span: [{t_data[0]:.1f}, {t_data[-1]:.1f}]")
print(f"Number of observations: {len(t_data)}")

In [None]:
# Visualize data
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(t_data, y_observed[:, 0], "o", label="Prey (observed)", alpha=0.6, markersize=6)
ax.plot(
    t_data, y_observed[:, 1], "s", label="Predator (observed)", alpha=0.6, markersize=6
)
ax.plot(t_data, y_true[:, 0], "--", label="Prey (true)", linewidth=2, alpha=0.8)
ax.plot(t_data, y_true[:, 1], "--", label="Predator (true)", linewidth=2, alpha=0.8)

ax.set_xlabel("Time", fontsize=12)
ax.set_ylabel("Population", fontsize=12)
ax.set_title("Predator-Prey Dynamics: Synthetic Data", fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Fit with DiffsolBuilder
start_time = time.time()

# Combine times and observations for DiffsolBuilder
# Data format: first column is time, remaining columns are observations
data = np.column_stack((t_data, y_observed))

result_diffsol = (
    diffid.DiffsolBuilder()
    .with_diffsl(model_str)
    .with_data(data)
    .with_parameter("alpha", 1.0)  # initial guess
    .with_parameter("beta", 0.3)
    .with_parameter("delta", 0.05)
    .with_parameter("gamma", 0.5)
    .with_cost(diffid.SSE())
    .with_optimiser(diffid.NelderMead().with_max_iter(500))
    .build()
    .optimise()
)

diffsol_time = time.time() - start_time

print("\n" + "=" * 60)
print("DIFFSOL RESULTS")
print("=" * 60)
print(f"True parameters:      {list(true_params.values())}")
print(f"Estimated parameters: {result_diffsol.x}")
print(f"Final SSE:            {result_diffsol.value:.6f}")
print(f"Iterations:           {result_diffsol.iterations}")
print(f"Function evals:       {result_diffsol.evaluations}")
print(f"Time:                 {diffsol_time:.3f}s")
print(f"Success:              {result_diffsol.success}")
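As an independent cross-check of the baseline, the same least-squares problem can be posed with SciPy alone: `solve_ivp` for the simulation and `scipy.optimize.minimize` with `method="Nelder-Mead"` for the fit. This is a sketch outside diffid, and the helper names (`simulate`, `sse`, `reference_fit`) are illustrative. Failed integrations are given an infinite cost so the simplex moves away from them.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import minimize


def lv_rhs(t, state, alpha, beta, delta, gamma):
    """Lotka-Volterra right-hand side, same equations as the tutorial."""
    x, y = state
    return [alpha * x - beta * x * y, delta * x * y - gamma * y]


def simulate(params, t_eval, y0=(10.0, 5.0)):
    """Predictions with shape (n_times, 2); inf-filled if the solve fails."""
    sol = solve_ivp(
        lv_rhs,
        (t_eval[0], t_eval[-1]),
        y0,
        args=tuple(params),
        t_eval=t_eval,
        rtol=1e-8,
        atol=1e-8,
    )
    if not sol.success or sol.y.shape[1] != len(t_eval):
        return np.full((len(t_eval), 2), np.inf)
    return sol.y.T


def sse(params, t_eval, data):
    """Sum of squared errors between the simulation and the data."""
    residual = simulate(params, t_eval) - data
    return float(np.sum(residual**2))


def reference_fit(t_eval, data, x0, max_iter=500):
    """Minimise the SSE with SciPy's Nelder-Mead, starting from x0."""
    return minimize(
        sse,
        np.asarray(x0, dtype=float),
        args=(t_eval, data),
        method="Nelder-Mead",
        options={"maxiter": max_iter},
    )
```

With the tutorial's variables, `reference_fit(t_data, y_observed, [1.0, 0.3, 0.05, 0.5]).x` should land near `result_diffsol.x` if both optimisers reach the same minimum; SciPy's iteration counts are not directly comparable with diffid's.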

## Method 2: VectorBuilder with JAX/Diffrax

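`VectorBuilder` accepts any callable that maps a parameter vector to predicted outputs. In this tutorial the callable returns an array with the same shape as the observed data, `(n_times, 2)`. This makes it possible to use **JAX/Diffrax**, which provides JIT-compiled, GPU-capable ODE solving with automatic differentiation.

### Advantages:
- ⚡ GPU acceleration (most useful for large or batched problems)
- 🔥 JIT compilation
- 📐 Automatic differentiation (gradients for free)
- 🚀 Efficient gradients for gradient-based optimisers (the fit below uses gradient-free Nelder-Mead, so this is not exercised here)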

In [None]:
# Optional: install JAX/Diffrax if not already installed
# !pip install jax jaxlib diffrax

try:
    import diffrax as dfx
    import jax.numpy as jnp
    from jax import config, jit

    # Enable float64 precision
    config.update("jax_enable_x64", True)

    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False
    print("⚠️  JAX/Diffrax not installed. Skipping GPU example.")
    print("    Install with: pip install jax jaxlib diffrax")

In [None]:
if JAX_AVAILABLE:
    # Define dynamics in JAX
    def lotka_volterra_jax(t, state, params):
        """Lotka-Volterra dynamics in JAX."""
        x, y = state
        alpha, beta, delta, gamma = params
        return jnp.array(
            [
                alpha * x - beta * x * y,  # prey
                delta * x * y - gamma * y,  # predator
            ]
        )

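    # Configure ODE solver: Tsit5 is an explicit Runge-Kutta 5(4) pair, so an
    # adaptive step-size controller can use its embedded error estimate
    # (without one, diffrax takes constant steps of size dt0)
    solver = dfx.Tsit5()
    stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-8)
    saveat = dfx.SaveAt(ts=t_data)
    term = dfx.ODETerm(lotka_volterra_jax)

    # Integration interval and initial step size (plain Python floats)
    t0, t1 = float(t_data[0]), float(t_data[-1])
    dt0 = float(t_data[1] - t_data[0])
    y0 = jnp.array([10.0, 5.0])  # initial state

    @jit
    def simulate_jax(params):
        """JAX-native ODE integration (JIT-compiled)."""
        sol = dfx.diffeqsolve(
            term,
            solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=y0,
            args=params,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            throw=False,  # report failed solves via non-finite outputs instead of raising
        )
        return sol.ys  # Shape: (n_times, 2)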

    def simulate_numpy(params):
        """NumPy wrapper for Diffid compatibility."""
        return np.asarray(simulate_jax(jnp.asarray(params)))

    # Call once so JIT compilation happens before the timed fit below
    _ = simulate_numpy([1.0, 0.4, 0.1, 0.4])
    print("JAX/Diffrax solver ready (JIT compiled)")

In [None]:
if JAX_AVAILABLE:
    # Fit with VectorBuilder + JAX/Diffrax
    start_time = time.time()

    result_diffrax = (
        diffid.VectorBuilder()
        .with_objective(simulate_numpy)
        .with_data(y_observed)
        .with_parameter("alpha", 1.0)
        .with_parameter("beta", 0.3)
        .with_parameter("delta", 0.05)
        .with_parameter("gamma", 0.5)
        .with_cost(diffid.SSE())
        .with_optimiser(diffid.NelderMead().with_max_iter(500))
        .build()
        .optimise()
    )

    diffrax_time = time.time() - start_time

    print("\n" + "=" * 60)
    print("DIFFRAX (JAX) RESULTS")
    print("=" * 60)
    print(f"True parameters:      {list(true_params.values())}")
    print(f"Estimated parameters: {result_diffrax.x}")
    print(f"Final SSE:            {result_diffrax.value:.6f}")
    print(f"Iterations:           {result_diffrax.iterations}")
    print(f"Function evals:       {result_diffrax.evaluations}")
    print(f"Time:                 {diffrax_time:.3f}s")
    print(f"Success:              {result_diffrax.success}")
    print(f"\nSpeedup vs Diffsol:   {diffsol_time / diffrax_time:.2f}x")

## Method 3: VectorBuilder with Julia DifferentialEquations.jl

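Julia's **DifferentialEquations.jl** ecosystem offers:
- 🔬 One of the largest collections of ODE solvers
- 🎯 Specialised methods (stiff, stochastic, DAE, etc.)
- 📊 Advanced features (sensitivity analysis, callbacks)

Integration via `diffeqpy` provides access to Julia's ecosystem from Python. Because the right-hand side below is a Python function, Julia calls back into Python on every evaluation, which adds overhead compared with a native Julia model.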

In [None]:
# Optional: install Julia backend
# !pip install diffeqpy
# Then in Python: from diffeqpy import install; install()

try:
    from diffeqpy import ode

    JULIA_AVAILABLE = True
except ImportError:
    JULIA_AVAILABLE = False
    print("⚠️  DifferentialEquations.jl not installed. Skipping Julia example.")
    print("    Install with: pip install diffeqpy")
    print("    Then run: from diffeqpy import install; install()")

In [None]:
if JULIA_AVAILABLE:
    # Define dynamics for Julia
    def lotka_volterra_julia(u, p, t):
        """Lotka-Volterra for Julia (u, p, t) signature."""
        x, y = u
        alpha, beta, delta, gamma = p
        return [alpha * x - beta * x * y, delta * x * y - gamma * y]

    def simulate_julia(params):
        """Solve ODE using Julia's Tsit5 solver."""
        u0 = [10.0, 5.0]
        tspan = (t_data[0], t_data[-1])

        prob = ode.ODEProblem(lotka_volterra_julia, u0, tspan, params)
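        sol = ode.solve(
            prob, ode.Tsit5(), saveat=t_data, reltol=1e-6, abstol=1e-8
        )

        # Stack the saved states into an (n_times, 2) NumPy array,
        # matching the layout of y_observed
        return np.array([np.asarray(u, dtype=float) for u in sol.u])

    # Smoke test: output must match the (n_times, 2) layout of y_observed
    test_output = simulate_julia([1.0, 0.4, 0.1, 0.4])
    assert test_output.shape == y_observed.shape, test_output.shape
    print("Julia/DifferentialEquations.jl ready")
    print(f"   Output shape: {test_output.shape}")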
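Different solvers return saved states in different layouts: some produce `(n_times, n_states)`, others `(n_states, n_times)`. Since `VectorBuilder` fits the callable's output against data of shape `(n_times, 2)`, a thin validating wrapper catches layout mistakes early. The names `as_time_major` and `checked` are hypothetical helpers for illustration, not part of diffid.

```python
import numpy as np


def as_time_major(pred, n_times, n_outputs):
    """Return `pred` as a float array of shape (n_times, n_outputs).

    Accepts either (n_times, n_outputs) or the transposed (n_outputs, n_times)
    layout. If n_times == n_outputs the layout is ambiguous, and the array is
    returned unchanged.
    """
    arr = np.asarray(pred, dtype=float)
    if arr.shape == (n_times, n_outputs):
        return arr
    if arr.shape == (n_outputs, n_times):
        return arr.T
    raise ValueError(
        f"expected shape ({n_times}, {n_outputs}) or "
        f"({n_outputs}, {n_times}), got {arr.shape}"
    )


def checked(simulate, n_times, n_outputs):
    """Wrap a params -> prediction callable so its output is always validated."""

    def wrapper(params):
        return as_time_major(simulate(params), n_times, n_outputs)

    return wrapper
```

For example, `checked(simulate_julia, len(t_data), 2)` could be passed to `.with_objective(...)` in place of `simulate_julia`.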

In [None]:
if JULIA_AVAILABLE:
    # Fit with VectorBuilder + Julia
    start_time = time.time()

    result_julia = (
        diffid.VectorBuilder()
        .with_objective(simulate_julia)
        .with_data(y_observed)
        .with_parameter("alpha", 1.0)
        .with_parameter("beta", 0.3)
        .with_parameter("delta", 0.05)
        .with_parameter("gamma", 0.5)
        .with_cost(diffid.SSE())
        .with_optimiser(diffid.NelderMead().with_max_iter(500))
        .build()
        .optimise()
    )

    julia_time = time.time() - start_time

    print("\n" + "=" * 60)
    print("JULIA DIFFERENTIALEQUATIONS.JL RESULTS")
    print("=" * 60)
    print(f"True parameters:      {list(true_params.values())}")
    print(f"Estimated parameters: {result_julia.x}")
    print(f"Final SSE:            {result_julia.value:.6f}")
    print(f"Iterations:           {result_julia.iterations}")
    print(f"Function evals:       {result_julia.evaluations}")
    print(f"Time:                 {julia_time:.3f}s")
    print(f"Success:              {result_julia.success}")
    print(f"\nSpeedup vs Diffsol:   {diffsol_time / julia_time:.2f}x")

## Performance Comparison

Let's compare the backends that ran (those whose optional dependencies are installed):

In [None]:
# Collect results
results_summary = []

results_summary.append(
    {
        "Backend": "Diffsol",
        "Time (s)": diffsol_time,
        "SSE": result_diffsol.value,
        "Iterations": result_diffsol.iterations,
        "Evaluations": result_diffsol.evaluations,
    }
)

if JAX_AVAILABLE:
    results_summary.append(
        {
            "Backend": "JAX/Diffrax",
            "Time (s)": diffrax_time,
            "SSE": result_diffrax.value,
            "Iterations": result_diffrax.iterations,
            "Evaluations": result_diffrax.evaluations,
        }
    )

if JULIA_AVAILABLE:
    results_summary.append(
        {
            "Backend": "Julia/DiffEq",
            "Time (s)": julia_time,
            "SSE": result_julia.value,
            "Iterations": result_julia.iterations,
            "Evaluations": result_julia.evaluations,
        }
    )

# Display table
print("\n" + "=" * 80)
print("BACKEND COMPARISON")
print("=" * 80)
print(f"{'Backend':<20} {'Time (s)':<12} {'SSE':<15} {'Iters':<8} {'Evals'}")
print("-" * 80)

for r in results_summary:
    print(
        f"{r['Backend']:<20} {r['Time (s)']:<12.3f} {r['SSE']:<15.3e} "
        f"{r['Iterations']:<8} {r['Evaluations']}"
    )
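The table records a single wall-clock run per backend, which may fold in one-off costs (such as building the Diffsol model) as well as run-to-run noise. To compare raw simulation speed, a best-of-n timing of one simulator call is steadier. The helper below (`best_time` is an illustrative name) uses `time.perf_counter`, which is designed for measuring short durations, runs untimed warm-up calls first so JIT compilation is excluded, and keeps the minimum, in the spirit of the advice in Python's `timeit` documentation.

```python
import time


def best_time(fn, *args, repeats=5, warmup=1):
    """Fastest of `repeats` timings of fn(*args), in seconds.

    Warm-up calls run first and are not timed, so one-off costs such as
    JIT compilation do not distort the result.
    """
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return min(timings)
```

For instance, `best_time(simulate_numpy, [1.0, 0.4, 0.1, 0.4])` and `best_time(simulate_julia, [1.0, 0.4, 0.1, 0.4])` compare a single solve on each backend, where those simulators were defined.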

In [None]:
# Visualize run time and final SSE for each backend
if len(results_summary) > 1:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    backends = [r["Backend"] for r in results_summary]
    times = [r["Time (s)"] for r in results_summary]
    sse_values = [r["SSE"] for r in results_summary]

    # Time comparison
    colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
    bars1 = ax1.bar(
        backends, times, color=colors[: len(backends)], alpha=0.7, edgecolor="black"
    )
    ax1.set_ylabel("Time (s)", fontsize=12)
    ax1.set_title("Optimisation Time Comparison", fontsize=14, fontweight="bold")
    ax1.grid(True, axis="y", alpha=0.3)

    # Add value labels on bars
    for bar in bars1:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.2f}s",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    # SSE comparison
    bars2 = ax2.bar(
        backends,
        sse_values,
        color=colors[: len(backends)],
        alpha=0.7,
        edgecolor="black",
    )
    ax2.set_ylabel("SSE", fontsize=12)
    ax2.set_title("Final Objective Value", fontsize=14, fontweight="bold")
    ax2.grid(True, axis="y", alpha=0.3)

    # Add value labels
    for bar in bars2:
        height = bar.get_height()
        ax2.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.1f}",
            ha="center",
            va="bottom",
            fontsize=10,
        )

    plt.tight_layout()
    plt.show()

## Visualize Fitted Models

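Compare the fitted trajectories from each backend. The Diffsol curve is re-simulated with SciPy on a fine time grid; the JAX and Julia curves are evaluated only at the data times, because their simulators save at `t_data`: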

In [None]:
# Generate predictions from each fitted model
t_fine = np.linspace(t_data[0], t_data[-1], 200)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Plot prey dynamics
ax1.plot(t_data, y_observed[:, 0], "o", label="Observed", alpha=0.5, markersize=6)
ax1.plot(t_data, y_true[:, 0], "k--", label="True", linewidth=2, alpha=0.7)

# Diffsol fit - use scipy to generate predictions with fitted parameters
sol_fitted = solve_ivp(
    lotka_volterra,
    [t_fine[0], t_fine[-1]],
    [10.0, 5.0],  # initial conditions
    args=tuple(result_diffsol.x),
    t_eval=t_fine,
    method="RK45",
)
y_pred_diffsol = sol_fitted.y.T
ax1.plot(t_fine, y_pred_diffsol[:, 0], label="Diffsol", linewidth=2)

if JAX_AVAILABLE:
    y_pred_diffrax = simulate_numpy(result_diffrax.x)
    ax1.plot(
        t_data, y_pred_diffrax[:, 0], label="JAX/Diffrax", linewidth=2, linestyle=":"
    )

if JULIA_AVAILABLE:
    y_pred_julia = simulate_julia(result_julia.x)
    ax1.plot(
        t_data, y_pred_julia[:, 0], label="Julia/DiffEq", linewidth=2, linestyle="-."
    )

ax1.set_xlabel("Time", fontsize=12)
ax1.set_ylabel("Prey Population", fontsize=12)
ax1.set_title("Prey Dynamics: Model Comparison", fontsize=14, fontweight="bold")
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot predator dynamics
ax2.plot(t_data, y_observed[:, 1], "s", label="Observed", alpha=0.5, markersize=6)
ax2.plot(t_data, y_true[:, 1], "k--", label="True", linewidth=2, alpha=0.7)
ax2.plot(t_fine, y_pred_diffsol[:, 1], label="Diffsol", linewidth=2)

if JAX_AVAILABLE:
    ax2.plot(
        t_data, y_pred_diffrax[:, 1], label="JAX/Diffrax", linewidth=2, linestyle=":"
    )

if JULIA_AVAILABLE:
    ax2.plot(
        t_data, y_pred_julia[:, 1], label="Julia/DiffEq", linewidth=2, linestyle="-."
    )

ax2.set_xlabel("Time", fontsize=12)
ax2.set_ylabel("Predator Population", fontsize=12)
ax2.set_title("Predator Dynamics: Model Comparison", fontsize=14, fontweight="bold")
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Key Takeaways

### When to Use Each Backend

**DiffsolBuilder (Diffsol):**
- ✅ Quick prototyping with DiffSL syntax
- ✅ Standard ODE problems
- ✅ No extra dependencies
- ❌ Limited to DiffSL expressiveness

**VectorBuilder + JAX/Diffrax:**
- ✅ GPU acceleration for large or batched problems
- ✅ Automatic differentiation (enables gradient-based optimisers)
- ✅ JIT compilation for speed
- ✅ Well suited to high-dimensional problems, where GPU parallelism and gradients pay off
- ❌ Requires JAX ecosystem

**VectorBuilder + Julia/DifferentialEquations.jl:**
- ✅ One of the largest solver collections (stiff, stochastic, DAE, DDE, etc.)
- ✅ Advanced features (callbacks, sensitivity analysis)
- ✅ Best for specialised problems
- ❌ Requires Julia installation

### Performance Insights

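1. **JAX/Diffrax** is often the fastest per evaluation once JIT compilation is done, although for a 2-state system the fixed cost of each NumPy-to-JAX call can outweigh the solve itself
2. **Diffsol** offers an excellent balance of speed and simplicity
3. **Julia/DiffEq** is best for problems requiring specialised solvers; calling a Python right-hand side from Julia adds overhead on every evaluation
4. All backends solve the same ODE with the same optimiser settings, so their parameter estimates should agree closely; remaining differences come from solver tolerances and step-size control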
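Because all fits use the same Nelder-Mead settings, total time mixes two effects: how many objective evaluations the optimiser needed and how expensive each one was. Dividing wall time by evaluation count separates them. The sketch below (`per_evaluation_cost` is an illustrative name) works on dictionaries shaped like the rows of `results_summary` in the comparison cell.

```python
def per_evaluation_cost(results_summary):
    """Return copies of the summary rows with a "ms/eval" column added.

    Expects dicts with the "Backend", "Time (s)" and "Evaluations" keys;
    rows with zero evaluations get NaN instead of dividing by zero.
    """
    rows = []
    for r in results_summary:
        evals = r["Evaluations"]
        ms_per_eval = 1e3 * r["Time (s)"] / evals if evals else float("nan")
        rows.append({**r, "ms/eval": ms_per_eval})
    return rows
```

Looping over `per_evaluation_cost(results_summary)` and printing each row's `"Backend"` and `"ms/eval"` gives a per-evaluation view of the comparison table.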

### VectorBuilder Flexibility

The key advantage of `VectorBuilder` is **total control**:
- Any Python callable works
- Integrate pre-existing simulation code
- Mix solver backends in the same workflow
- Enable advanced features like GPU acceleration
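Because `VectorBuilder` only needs a callable, even a solver written from scratch can serve as the objective. The sketch below is a hypothetical pure-NumPy example using the classical fourth-order Runge-Kutta method with a fixed number of substeps between output times. It is not a diffid feature, and unlike the adaptive solvers used above, a fixed step size offers no error control.

```python
import numpy as np


def lv_rhs(state, alpha, beta, delta, gamma):
    """Lotka-Volterra right-hand side (autonomous, so no time argument)."""
    x, y = state
    return np.array([alpha * x - beta * x * y, delta * x * y - gamma * y])


def make_rk4_simulator(t_eval, y0, substeps=20):
    """Build a params -> (n_times, 2) callable using fixed-step classical RK4.

    Assumes t_eval is increasing and t_eval[0] is the initial time. Each
    interval between output times is split into `substeps` equal steps.
    """
    t_eval = np.asarray(t_eval, dtype=float)
    y0 = np.asarray(y0, dtype=float)

    def simulate(params):
        alpha, beta, delta, gamma = params

        def f(state):
            return lv_rhs(state, alpha, beta, delta, gamma)

        state = y0.copy()
        out = [state.copy()]
        for t_start, t_end in zip(t_eval[:-1], t_eval[1:]):
            h = (t_end - t_start) / substeps
            for _ in range(substeps):
                k1 = f(state)
                k2 = f(state + 0.5 * h * k1)
                k3 = f(state + 0.5 * h * k2)
                k4 = f(state + h * k3)
                state = state + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
            out.append(state.copy())
        return np.array(out)

    return simulate
```

Assuming the same builder calls as in the JAX and Julia cells, `make_rk4_simulator(t_data, [10.0, 5.0])` would be passed to `.with_objective(...)`. Increasing `substeps` trades speed for accuracy.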

## Next Steps

- [Tutorial 7: Parallel Optimisation](07_parallel_optimization.ipynb) - Scale optimisation with parallelism
- [Guide: Custom Solvers](../../guides/custom-solvers.md) - Detailed integration patterns
- [API Reference: VectorBuilder](../../api-reference/python/builders.md#vectorbuilder) - Complete API documentation

## Exercises

1. **Add Gradient-Based Optimiser**: Modify the JAX/Diffrax example to use the `Adam()` optimiser with gradients

2. **Custom Dynamics**: Implement a different ODE system (e.g., Lorenz attractor, SIR model) using VectorBuilder

3. **Benchmark Scaling**: Test how performance scales with:
   - Number of time points (50, 100, 500, 1000)
   - System size (2, 5, 10 state variables)
   - Problem stiffness

4. **Hybrid Approach**: Use VectorBuilder with a custom callable that switches solvers based on problem characteristics
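One possible starting point for Exercise 4 is a callable that tries several SciPy integrators in order and falls back whenever `solve_ivp` reports a failed integration. This is a sketch with illustrative names (`make_switching_simulator`, `last_method`). Note that SciPy's `LSODA` already switches between non-stiff and stiff methods automatically, which is another way to approach the exercise.

```python
import numpy as np
from scipy.integrate import solve_ivp


def make_switching_simulator(
    rhs, t_eval, y0, methods=("RK45", "LSODA", "Radau"), rtol=1e-6, atol=1e-8
):
    """Build a params -> (n_times, n_states) callable that tries methods in order.

    The first method whose integration succeeds and reaches every output time
    is used. If all methods fail, an inf-filled array is returned so that an
    optimiser sees a very poor cost instead of an exception.
    """
    t_eval = np.asarray(t_eval, dtype=float)
    y0 = np.asarray(y0, dtype=float)

    def simulate(params):
        for method in methods:
            sol = solve_ivp(
                rhs,
                (t_eval[0], t_eval[-1]),
                y0,
                method=method,
                args=tuple(params),
                t_eval=t_eval,
                rtol=rtol,
                atol=atol,
            )
            if sol.success and sol.y.shape[1] == len(t_eval):
                simulate.last_method = method  # record which integrator worked
                return sol.y.T
        simulate.last_method = None
        return np.full((len(t_eval), len(y0)), np.inf)

    simulate.last_method = None
    return simulate
```

After a call, `simulate.last_method` shows which integrator handled that parameter set, which helps decide whether the fallback is ever needed.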