# Surrogate-Accelerated Bayesian Inference with Catalax

This template demonstrates how to perform surrogate-accelerated Bayesian parameter estimation for enzyme kinetic models using the Catalax JAX library. This approach combines the power of Bayesian inference with Neural ODE surrogate models to dramatically speed up parameter estimation while maintaining uncertainty quantification.

## What is Surrogate-Accelerated Bayesian Inference?

Surrogate-accelerated Bayesian inference uses fast neural network approximations (surrogates) of expensive differential equation models to speed up Bayesian parameter estimation. Instead of solving the full ODE system at each MCMC step, we use a pre-trained Neural ODE surrogate that can evaluate the model orders of magnitude faster.
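
To see where the cost arises, the following minimal sketch shows a Gaussian log-likelihood whose simulator is a pluggable function. The names `gaussian_log_likelihood` and `exact_decay`, as well as the first-order decay model, are hypothetical illustrations and not part of Catalax. In the surrogate-accelerated setting, the `simulate` argument would be the cheap surrogate rather than a numerical ODE solve.

```python
import math
from typing import Callable, List, Sequence


def gaussian_log_likelihood(
    simulate: Callable[[float, Sequence[float]], List[float]],
    k: float,
    times: Sequence[float],
    observed: Sequence[float],
    sigma: float,
) -> float:
    """Log-likelihood of the data under independent Gaussian noise."""
    predicted = simulate(k, times)  # evaluated once per likelihood call
    norm = -0.5 * math.log(2.0 * math.pi * sigma**2)
    return sum(
        norm - 0.5 * ((obs - pred) / sigma) ** 2
        for obs, pred in zip(observed, predicted)
    )


def exact_decay(k: float, times: Sequence[float]) -> List[float]:
    """Stand-in simulator: closed-form solution of ds/dt = -k * s, s(0) = 10."""
    return [10.0 * math.exp(-k * t) for t in times]


times = [0.0, 1.0, 2.0, 4.0]
observed = exact_decay(0.5, times)  # noise-free synthetic data with k = 0.5

ll_true = gaussian_log_likelihood(exact_decay, 0.5, times, observed, sigma=1.0)
ll_wrong = gaussian_log_likelihood(exact_decay, 2.0, times, observed, sigma=1.0)
```

Because the sampler calls the likelihood many times, any speed-up in `simulate` carries over directly to the total sampling time.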

The key advantages of surrogate-accelerated Bayesian inference include:

- **Dramatic Speed Improvements**: Neural ODE surrogates can be 100-1000x faster than traditional ODE solvers
- **Maintained Accuracy**: Well-trained surrogates preserve the accuracy of the original model
- **Full Uncertainty Quantification**: Get credible intervals and credible regions for your parameters
- **Prior Knowledge Integration**: Incorporate existing knowledge about parameter ranges
- **Robust Predictions**: Make predictions that account for parameter uncertainty

## Neural ODE Surrogates

Neural ODEs combine neural networks with differential equations, learning to approximate the dynamics of complex systems. In this context, they serve as fast approximations of enzyme kinetic models, enabling rapid evaluation during MCMC sampling while preserving the underlying physics.
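
As a point of reference, the sketch below integrates a mechanistic Michaelis-Menten rate law with a fixed-step RK4 scheme. This is the kind of numerical solve that a surrogate is meant to replace. The function names and parameter values are assumptions chosen for illustration; a Neural ODE would substitute a trained neural network for the hand-written rate function.

```python
import math


def michaelis_menten_rate(s: float, v_max: float, k_m: float) -> float:
    """ds/dt for irreversible Michaelis-Menten kinetics."""
    return -v_max * s / (k_m + s)


def rk4_solve(s0: float, v_max: float, k_m: float, t_end: float, n_steps: int):
    """Integrate the substrate concentration with classic fixed-step RK4."""
    dt = t_end / n_steps
    s = s0
    trajectory = [s]
    for _ in range(n_steps):
        k1 = michaelis_menten_rate(s, v_max, k_m)
        k2 = michaelis_menten_rate(s + 0.5 * dt * k1, v_max, k_m)
        k3 = michaelis_menten_rate(s + 0.5 * dt * k2, v_max, k_m)
        k4 = michaelis_menten_rate(s + dt * k3, v_max, k_m)
        s += dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
        trajectory.append(s)
    return trajectory


trajectory = rk4_solve(s0=10.0, v_max=1.0, k_m=2.0, t_end=5.0, n_steps=100)

# Check against the implicit closed-form solution
# k_m * ln(s0 / s) + (s0 - s) = v_max * t
s_end = trajectory[-1]
residual = 2.0 * math.log(10.0 / s_end) + (10.0 - s_end) - 1.0 * 5.0
```

Each call to `rk4_solve` evaluates the rate function four times per step, which is why repeated solves inside a sampler become expensive for larger systems.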

## Surrogate Hamiltonian Monte Carlo (Surrogate HMC)

This template uses Surrogate Hamiltonian Monte Carlo, which combines HMC sampling with Neural ODE surrogates. The surrogate model is first trained on the original ODE system, then used during MCMC to dramatically accelerate the sampling process while maintaining statistical rigor.
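
For readers unfamiliar with HMC, here is a minimal sketch of the plain algorithm (leapfrog integration followed by a Metropolis correction) on a one-dimensional toy target. It is not Catalax's sampler, and all names in it are hypothetical. In the surrogate setting, the log-probability and its gradient would be evaluated through the surrogate model instead of a closed-form expression.

```python
import math
import random


def log_prob(x: float) -> float:
    """Unnormalized log-density of a toy Normal(3, 1) target."""
    return -0.5 * (x - 3.0) ** 2


def grad_log_prob(x: float) -> float:
    """Gradient of the toy log-density."""
    return -(x - 3.0)


def hmc_sample(x0, n_samples, step_size, n_leapfrog, seed=0):
    """Plain HMC for a one-dimensional target with unit mass."""
    rng = random.Random(seed)
    x = x0
    samples = []
    for _ in range(n_samples):
        p = rng.gauss(0.0, 1.0)  # resample the momentum
        x_new, p_new = x, p
        # Leapfrog integration: half step in p, full steps in x and p, half step in p
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        for step in range(n_leapfrog):
            x_new += step_size * p_new
            if step < n_leapfrog - 1:
                p_new += step_size * grad_log_prob(x_new)
        p_new += 0.5 * step_size * grad_log_prob(x_new)
        # Metropolis correction based on the change in total energy
        h_old = -log_prob(x) + 0.5 * p * p
        h_new = -log_prob(x_new) + 0.5 * p_new * p_new
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            x = x_new
        samples.append(x)
    return samples


samples = hmc_sample(x0=0.0, n_samples=2000, step_size=0.2, n_leapfrog=10)
kept = samples[200:]  # discard warmup draws
mean = sum(kept) / len(kept)
std = math.sqrt(sum((x - mean) ** 2 for x in kept) / (len(kept) - 1))
```

Every leapfrog step needs a gradient of the log-probability, so a cheaper model evaluation reduces the cost of each proposed sample.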

## Getting Started

This template provides the basic code to get you started with surrogate-accelerated Bayesian inference using Catalax. For the template to work, you need to assign prior distributions to the parameters. The following example shows how to assign a prior distribution to a parameter.

```python
import catalax.mcmc as cmc

model.parameters["v_max"].prior = cmc.priors.Normal(mu=0.0, sigma=1.0)
```

Surrogate-accelerated Bayesian inference **requires** a trained Neural ODE surrogate model. We recommend using the **Neural ODE template** first to train the surrogate model. Learn more about surrogate-accelerated Bayesian inference with Catalax in the [Catalax documentation](https://catalax.mintlify.app/hmc/surrogate-hmc).

In [None]:
# Install all required packages
%pip install pyenzyme catalax

In [None]:
import pyenzyme as pe
import catalax as ctx
import catalax.mcmc as cmc
import catalax.neural as ctn

In the following cell, we will load the EnzymeML document from the EnzymeML Suite. The resulting object is an instance of the `EnzymeMLDocument` class, which you can inspect and reuse for your analysis. The following functions are available and compatible with the `EnzymeMLDocument` class:

- `pe.summary(enzmldoc)`: Print a summary of the EnzymeML document.
- `pe.plot(enzmldoc)`: Plot the EnzymeML document.
- `pe.plot_interactive(enzmldoc)`: Interactive plot of the EnzymeML document.
- `pe.to_pandas(enzmldoc)`: Convert the EnzymeML document to a pandas DataFrame.
- `pe.to_sbml(enzmldoc)`: Convert the EnzymeML document to an SBML document.
- `pe.to_petab(enzmldoc)`: Convert the EnzymeML document to the PEtab format.
- `pe.get_current()`: Get the current EnzymeML document from the EnzymeML Suite.

In [None]:
# Connect to the EnzymeML Suite
suite = pe.EnzymeMLSuite()

# Get the current EnzymeML document
enzmldoc = suite.get_current()

# Print a summary of the EnzymeML document
pe.summary(enzmldoc)

## Converting EnzymeML to Catalax

The `ctx.from_enzymeml` function converts an EnzymeML document into a Catalax dataset and a Catalax model. The dataset contains the experimental data, and the model is the object you will use for parameter estimation.

In [None]:
dataset, model = ctx.from_enzymeml(enzmldoc)

# Load the trained Neural ODE surrogate model (adjust the path to the correct file)
trained = ctn.NeuralODE.from_eqx("trained_neural_ode.eqx")

## Bayesian Inference

Bayesian inference is performed with the `cmc.HMC` class, which draws samples from the posterior distribution of the model parameters. The run returns a results object from which you can obtain the fitted model, a summary of the posterior, and diagnostic plots.

### Step 1: Set the prior distributions for the parameters

Before running the MCMC simulation, we need to set the prior distributions for the parameters. A prior distribution is a probability distribution that represents our beliefs about the parameters before we have any data. For instance, if you believe that a parameter is between 1 and 10, you can set a uniform prior distribution for that parameter.
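
To make the uniform example concrete, the sketch below evaluates the log-density of a uniform prior between 1 and 10. The helper `uniform_log_density` is a hypothetical illustration written in plain Python and is not the Catalax prior class. It shows that every value inside the range is equally plausible, while values outside are ruled out entirely.

```python
import math


def uniform_log_density(value: float, low: float, high: float) -> float:
    """Log-density of a Uniform(low, high) prior; -inf outside the support."""
    if low <= value <= high:
        return -math.log(high - low)
    return -math.inf


inside = uniform_log_density(5.0, low=1.0, high=10.0)
outside = uniform_log_density(12.0, low=1.0, high=10.0)
```

A very wide uniform prior therefore contributes almost no information, which is why narrower, literature-based priors are preferable when available.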

In [None]:
for param in model.parameters.values():
    # We are setting a very wide uniform prior to the parameters
    # This is a placeholder and should be replaced with a more informative prior
    param.prior = cmc.priors.Uniform(low=1e-6, high=1e5)

### Step 2: Perform MCMC simulation

Now we can perform the MCMC simulation using the `cmc.HMC` class. You can customize the sampling process by changing the parameters of the `cmc.HMC` class:

- `num_warmup`: The number of warmup samples.
- `num_samples`: The number of samples to draw from the posterior distribution.
- `dt0`: The initial step size.
- `max_steps`: The maximum number of steps.
- `verbose`: The verbosity level.

In contrast to the usual HMC workflow, we will pass the trained Neural ODE surrogate model to the `cmc.HMC.run` function via the `surrogate` argument.

**Tips**

- HMC is very efficient: you don't need millions of samples to get a good posterior distribution.
- If the sampling has not converged (indicated by `Rhat > 1.01`), try increasing the number of warmup samples or the number of samples.
- The choice of prior distributions is crucial for the quality of the posterior distribution. Use literature values if possible.
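
The `Rhat` diagnostic mentioned above compares the variance between chains with the variance within chains. The following sketch implements the classic (non-split) Gelman-Rubin statistic in plain Python on synthetic chains. The function name and the data are assumptions for illustration, and libraries often report split or rank-normalized variants, so the values in a summary may differ slightly.

```python
import math
import random
from typing import Sequence


def gelman_rubin_rhat(chains: Sequence[Sequence[float]]) -> float:
    """Classic (non-split) Gelman-Rubin R-hat for equally long chains."""
    m = len(chains)
    n = len(chains[0])
    means = [sum(c) / n for c in chains]
    grand_mean = sum(means) / m
    # Within-chain variance W: average of the per-chain sample variances
    w = sum(
        sum((x - mu) ** 2 for x in c) / (n - 1) for c, mu in zip(chains, means)
    ) / m
    # Between-chain variance B: n times the variance of the chain means
    b = n * sum((mu - grand_mean) ** 2 for mu in means) / (m - 1)
    var_plus = (n - 1) / n * w + b / n
    return math.sqrt(var_plus / w)


rng = random.Random(0)
# Four chains that explore the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# Four chains of which two are stuck around a different value
stuck = [
    [rng.gauss(shift, 1.0) for _ in range(1000)]
    for shift in (0.0, 0.0, 5.0, 5.0)
]

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
```

Chains that agree yield a value close to 1, whereas chains that disagree push the value well above the 1.01 threshold.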


In [None]:
# Perform MCMC simulation
hmc = cmc.HMC(
    num_warmup=1000,
    num_samples=1000,
    dt0=0.1,
    max_steps=64**4,
    verbose=1,
)

results = hmc.run(
    model=model,
    dataset=dataset,
    yerrs=2.0,
    surrogate=trained,
)

# Retrieve the fitted model
fitted_model = results.get_fitted_model()

# Print the summary
results.summary()

### Step 3: Diagnostics

Now that our Bayesian inference is done, we can visualize the results using common diagnostic plots. Covering all of them is beyond the scope of this template, but you can find more information in the [Catalax documentation](https://catalax.mintlify.app/hmc/mcmc-basic). We will use a corner plot to visualize the posterior distributions and the pairwise correlations of the parameters.


In [None]:
results.plot_corner(show=True, path="bayesian_inference_corner.png")

We can now also visualize the model fit to the data. Note that this plot displays the uncertainties of the parameters as bands. Medium-opacity bands correspond to 50% credible intervals, and light bands correspond to 95% credible intervals. Wide bands indicate high uncertainty and point to insufficient data or an overparameterized model.
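
Bands of this kind are typically derived from percentiles of the posterior draws. The sketch below computes equal-tailed 50% and 95% credible intervals from a list of samples in plain Python. The helper names and the stand-in samples are hypothetical, and how Catalax computes its bands internally is not shown here.

```python
from typing import Sequence, Tuple


def percentile(sorted_values: Sequence[float], q: float) -> float:
    """Percentile (0-100) with linear interpolation between order statistics."""
    position = (len(sorted_values) - 1) * q / 100.0
    lower = int(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    fraction = position - lower
    return sorted_values[lower] * (1.0 - fraction) + sorted_values[upper] * fraction


def credible_interval(samples: Sequence[float], level: float) -> Tuple[float, float]:
    """Equal-tailed credible interval covering `level` (e.g. 0.95) of the samples."""
    ordered = sorted(samples)
    tail = (1.0 - level) / 2.0 * 100.0
    return percentile(ordered, tail), percentile(ordered, 100.0 - tail)


samples = [float(i) for i in range(101)]  # stand-in for posterior draws
interval_50 = credible_interval(samples, 0.50)
interval_95 = credible_interval(samples, 0.95)
```

Applying the same computation to the model predictions at each time point would give the lower and upper edges of a band.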


In [None]:
# Visualize the fitted model against the data
dataset.plot(
    predictor=fitted_model,
    show=True,
    path="bayesian_inference.png",
)

## Update the EnzymeML Suite document

Once done, we can update the EnzymeML Suite document with the estimated parameters.

In [None]:
# Write the estimated parameters back into the EnzymeML document
updated_enzmldoc = fitted_model.to_enzymeml(enzmldoc)

# And update the EnzymeML Suite document
suite.update_current(updated_enzmldoc)