# Neural ODEs with Catalax

This template demonstrates how to use Neural Ordinary Differential Equations (Neural ODEs) with the Catalax JAX library to model enzyme kinetic systems. Neural ODEs combine the power of neural networks with differential equations, enabling you to learn complex dynamics directly from data while maintaining the interpretability and physical constraints of ODE-based models.

## What are Neural ODEs?

Neural ODEs are a class of deep learning models that use neural networks to parameterize the derivative function in an ordinary differential equation. Instead of discretizing time and using fixed layers like traditional neural networks, Neural ODEs treat the hidden state as a continuous function of time, solving an ODE to compute the output.

The key advantages of Neural ODEs include:

- **Continuous Dynamics**: Model systems with continuous-time dynamics rather than discrete steps
- **Memory Efficiency**: With adjoint-based backpropagation, memory cost stays approximately constant regardless of the number of solver steps
- **Adaptive Computation**: Automatically adjust computational effort based on dynamics complexity
- **Physical Interpretability**: Maintain ODE structure while learning complex behaviors
- **Irregular Time Series**: Handle data with non-uniform time sampling naturally

## How Neural ODEs Work

In a Neural ODE, the dynamics are defined as:

$$\frac{dx}{dt} = f_\theta(x(t), t)$$

where $f_\theta$ is a neural network with parameters $\theta$ that learns the derivative function. The solution $x(t)$ is obtained by integrating this ODE numerically. Gradients with respect to $\theta$ are obtained by differentiating through the solve, for example with the adjoint sensitivity method, which keeps backpropagation memory-efficient.
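To make the forward pass concrete, here is a minimal sketch in plain Python (no Catalax or JAX). It assumes a tiny one-hidden-layer network with a `tanh` activation as $f_\theta$ and a fixed-step Runge-Kutta integrator. The helper names `init_mlp`, `f_theta`, `rk4_step` and `solve` are illustrative only and not part of any library. Training, i.e. adjusting the weights so that the trajectory matches data, is left out.

```python
import math
import random


def init_mlp(n_state, width, seed=0):
    """Random weights for a one-hidden-layer network (illustrative only)."""
    rng = random.Random(seed)
    # Hidden layer takes the state plus the time t as input
    w1 = [[rng.gauss(0.0, 0.5) for _ in range(n_state + 1)] for _ in range(width)]
    w2 = [[rng.gauss(0.0, 0.5) for _ in range(width)] for _ in range(n_state)]
    return w1, w2


def f_theta(theta, x, t):
    """Neural network right-hand side: returns dx/dt for state x at time t."""
    w1, w2 = theta
    inputs = list(x) + [t]
    hidden = [math.tanh(sum(w * v for w, v in zip(row, inputs))) for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def rk4_step(theta, x, t, dt):
    """One classical fourth-order Runge-Kutta step of size dt."""
    k1 = f_theta(theta, x, t)
    k2 = f_theta(theta, [xi + 0.5 * dt * k for xi, k in zip(x, k1)], t + 0.5 * dt)
    k3 = f_theta(theta, [xi + 0.5 * dt * k for xi, k in zip(x, k2)], t + 0.5 * dt)
    k4 = f_theta(theta, [xi + dt * k for xi, k in zip(x, k3)], t + dt)
    return [
        xi + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
        for xi, a, b, c, d in zip(x, k1, k2, k3, k4)
    ]


def solve(theta, x0, times, substeps=20):
    """Integrate from times[0] and return the state at every requested time."""
    x, t, trajectory = list(x0), times[0], [list(x0)]
    for t_next in times[1:]:
        dt = (t_next - t) / substeps
        for _ in range(substeps):
            x = rk4_step(theta, x, t, dt)
            t += dt
        t = t_next  # avoid accumulating floating-point drift in t
        trajectory.append(list(x))
    return trajectory


theta = init_mlp(n_state=2, width=4)
# The requested time points are deliberately non-uniform
trajectory = solve(theta, x0=[1.0, 0.0], times=[0.0, 0.1, 0.5, 2.0])
for state in trajectory:
    print([round(v, 4) for v in state])
```

The solver simply integrates from one requested time point to the next, which is why unevenly spaced measurements need no special handling.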

## Applications in Enzyme Kinetics

Neural ODEs are particularly useful in enzyme kinetics when:

- Traditional kinetic models are too simple to capture the observed behavior
- Complex regulatory mechanisms are present but not well understood
- Multiple interacting pathways create non-linear dynamics
- Data-driven discovery of new kinetic mechanisms is desired
- The system is high-dimensional, with many species and reactions

## Surrogate Modeling

Neural ODEs can serve as fast surrogate models for complex mechanistic models, enabling:

- **Accelerated Simulation**: Faster evaluation than simulating the full mechanistic model, in favorable cases by orders of magnitude
- **Bayesian Inference**: Fast surrogates enable practical MCMC sampling for parameter estimation
- **Real-time Applications**: Enable real-time model evaluation and control

## Getting Started

This template provides the basic framework for building Neural ODE models with Catalax. The neural network will learn to capture the underlying dynamics of your enzyme kinetic system directly from time-series data, providing both accurate predictions and insights into the system behavior.

Learn more about Neural ODEs with Catalax in the [Catalax documentation](https://catalax.mintlify.app/neural/neural-ode).


In [None]:
# Install all required packages
%pip install pyenzyme catalax

In [None]:
import jax.nn as jnn
import pyenzyme as pe
import catalax as ctx
import catalax.neural as ctn

In the following cell, we will load the EnzymeML document from the EnzymeML Suite. The resulting object is an instance of the `EnzymeMLDocument` class, which you can inspect and re-use for your analysis. The following functions are available and compatible with the `EnzymeMLDocument` class:

- `pe.summary(enzmldoc)`: Print a summary of the EnzymeML document.
- `pe.plot(enzmldoc)`: Plot the EnzymeML document.
- `pe.plot_interactive(enzmldoc)`: Interactive plot of the EnzymeML document.
- `pe.to_pandas(enzmldoc)`: Convert the EnzymeML document to a pandas DataFrame.
- `pe.to_sbml(enzmldoc)`: Convert the EnzymeML document to an SBML document.
- `pe.to_petab(enzmldoc)`: Convert the EnzymeML document to the PEtab format.
- `pe.get_current()`: Get the current EnzymeML document from the EnzymeML Suite.

In [None]:
# Connect to the EnzymeML Suite
suite = pe.EnzymeMLSuite()

# Get the current EnzymeML document
enzmldoc = suite.get_current()

# Print a summary of the EnzymeML document
pe.summary(enzmldoc)

## Converting EnzymeML to Catalax

The `ctx.from_enzymeml` function converts an EnzymeML document into a Catalax dataset and a Catalax model. The dataset contains the experimental data, and the model is a Catalax model object that you can use for parameter estimation.

In [None]:
dataset, model = ctx.from_enzymeml(enzmldoc)

# We will augment the dataset to generate more data for training
# and improve the generalization of the model
train_dataset = dataset.augment(n_augmentations=10)

## Step 1: Define the Neural ODE

As a first step, we will define the Neural ODE model. This is done by calling the `ctn.NeuralODE.from_model` function, which will determine the states that are modeled within the Neural ODE. Please note, you need to make sure that your measurement data aligns with the defined states. Otherwise, the model will not be able to learn the dynamics.

The inner core of a Neural ODE is a neural network, which requires hyperparameters that define its architecture:

- `width_size`: The number of neurons in each hidden layer.
- `depth`: The number of hidden layers.
- `activation`: The activation function of the neural network.

In the following cell, we will define a Neural ODE model with 2 hidden layers of 4 neurons each and a `celu` activation function. CELU is a good choice because it is non-linear and continuously differentiable, so the learned vector field stays smooth, which suits numerical ODE integration. The choice of activation function is critical and shapes the behavior of the Neural ODE. Try to **not** use piecewise-linear activation functions like `relu`: their derivative is discontinuous at zero, which makes the vector field non-smooth and can hamper the integration.
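The difference in smoothness can be checked numerically. The following sketch re-implements both activations in plain Python (using the standard CELU definition with `alpha=1.0`, rather than calling `jax.nn`) and compares the slope just left and right of zero with a finite-difference estimate. The helper `slope` is an illustrative name.

```python
import math


def relu(x):
    """ReLU: 0 for x < 0, x otherwise."""
    return max(0.0, x)


def celu(x, alpha=1.0):
    """CELU: x for x >= 0, alpha * (exp(x / alpha) - 1) otherwise."""
    return max(0.0, x) + min(0.0, alpha * math.expm1(x / alpha))


def slope(fn, x, h=1e-6):
    """Central finite-difference estimate of the derivative of fn at x."""
    return (fn(x + h) - fn(x - h)) / (2.0 * h)


eps = 1e-3
# Change in slope when crossing zero: ReLU has a kink, CELU does not
relu_jump = slope(relu, eps) - slope(relu, -eps)
celu_jump = slope(celu, eps) - slope(celu, -eps)
print(f"ReLU slope change across 0: {relu_jump:.4f}")
print(f"CELU slope change across 0: {celu_jump:.4f}")
```

ReLU's slope jumps abruptly from 0 to 1 at the origin, whereas CELU's slope changes only by an amount that shrinks as `eps` shrinks. Adaptive ODE solvers generally estimate their step size under the assumption of a smooth right-hand side, which is one plausible reason why smooth activations tend to behave better here.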


In [None]:
# Create a neural ODE model
neural_ode = ctn.NeuralODE.from_model(
    model,
    width_size=4,
    depth=2,
    activation=jnn.celu,
)

## Step 2: Set up a training strategy

Now that we have defined the Neural ODE model, we can set up a training strategy by creating a `ctn.Strategy` object. The strategy determines the optimizer, the loss function, the batch size, and the learning rate of the training process.

Neural networks are typically trained using gradient descent, which explores the parameter space gradually, and the strategy determines how we move through this space. The learning rate `lr` sets how big a step we take. Too big and we might overshoot the minimum; too small and convergence will take a very long time. The strategy therefore lets us decide when to take bigger or smaller steps.
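The effect of the learning rate is easy to see on a toy problem. The sketch below is not Catalax code: it runs plain gradient descent on the one-dimensional loss $w^2$ (minimum at $w = 0$) with three different learning rates. The function name `gradient_descent` and the chosen values are purely illustrative.

```python
def gradient_descent(lr, steps=50, w0=5.0):
    """Minimize loss(w) = w**2 with plain gradient descent; returns final w."""
    w = w0
    for _ in range(steps):
        grad = 2.0 * w  # derivative of w**2
        w = w - lr * grad
    return w


too_large = gradient_descent(lr=1.1)    # every step overshoots, |w| grows
reasonable = gradient_descent(lr=0.1)   # converges towards the minimum at 0
too_small = gradient_descent(lr=1e-4)   # barely moves away from w0

print(f"lr=1.1   -> w = {too_large:.3g}")
print(f"lr=0.1   -> w = {reasonable:.3g}")
print(f"lr=1e-4  -> w = {too_small:.3g}")
```

Real loss landscapes are far less regular than this parabola, but the same trade-off applies, which is why training is usually split into stages with different learning rates.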

**Tips**

- Run the initial training steps on only the first 15% of the data. Think of this as a warmup phase; without it, convergence tends to be unstable.
- Start with a relatively high learning rate to explore the parameter space quickly.
- After the warmup phase, decrease the learning rate to fine-tune the parameters.
- The batch size should be set to a value that is a good compromise between memory efficiency and convergence speed.

_Note: The defaults have been set to a good compromise for this example. You can try to change the parameters and see how the training behaves._
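As a rough illustration of the staged approach in the tips above, the sketch below shows how a warmup stage could restrict training to the early part of a time series. It is written in plain Python and does not reflect Catalax internals: the helper `first_fraction`, the `schedule` list, and the reading of `length` as "leading fraction of each time series" are all assumptions made for this example.

```python
def first_fraction(times, values, length):
    """Keep the leading `length` fraction of a time series (at least 2 points)."""
    n_keep = max(2, int(round(length * len(times))))
    return times[:n_keep], values[:n_keep]


# Hypothetical three-stage schedule: warmup, full data, fine-tuning
schedule = [
    {"lr": 1e-3, "length": 0.15, "steps": 1000},  # warmup on early dynamics
    {"lr": 1e-3, "length": 1.0, "steps": 1000},   # full trajectories
    {"lr": 1e-4, "length": 1.0, "steps": 1000},   # fine-tuning, smaller steps
]

# Synthetic time series standing in for one measurement
times = [float(i) for i in range(20)]
values = [1.0 / (1.0 + t) for t in times]

for stage in schedule:
    t_sub, y_sub = first_fraction(times, values, stage["length"])
    print(
        f"lr={stage['lr']:g}, steps={stage['steps']}: "
        f"training on {len(t_sub)} of {len(times)} time points"
    )
```

Fitting the early dynamics first gives the network a sensible starting point before it has to integrate over the whole time span, where errors accumulate along the trajectory.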


In [None]:
# We will use L2 regularization to prevent overfitting
penalties = ctn.Penalties.for_neural_ode(l2_alpha=1e-5)

# We will use a batch size of 20; learning rates are set per training step below
strategy = ctn.Strategy(penalties=penalties, batch_size=20)

# We will first train on the first 15% of the data with a learning rate of 1e-3
strategy.add_step(lr=1e-3, length=0.15, steps=1000)

# Then we will train on the full time series with a learning rate of 1e-3
strategy.add_step(lr=1e-3, steps=1000)

# Finally, we will fine-tune the parameters with a learning rate of 1e-4
strategy.add_step(lr=1e-4, steps=1000)

# Train neural ODE
trained = neural_ode.train(
    dataset=train_dataset,
    strategy=strategy,
    print_every=10,
    weight_scale=1e-6,
    save_milestones=False,  # Set to True to save model checkpoints
    # log="progress.log", # Uncomment this line to log progress
)

# Save the trained model to a file
trained.save_to_eqx(".", "trained_neural_ode.eqx")

_Important: We have not used a validation strategy, as we are only interested in the training process. However, in practice, you should always use a validation strategy to monitor the performance of the model on unseen data. Check out the [Catalax documentation](https://catalax.mintlify.app/basic/data-management#data-splitting-and-cross-validation) for more information._
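A simple hold-out split is one way to obtain such unseen data. The sketch below is generic Python and not the Catalax API described in the linked documentation: `holdout_split` is a hypothetical helper that randomly assigns measurement indices to a training and a validation set.

```python
import random


def holdout_split(n_measurements, val_fraction=0.2, seed=0):
    """Return (train_indices, val_indices) for a random hold-out split."""
    indices = list(range(n_measurements))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_val = max(1, int(round(val_fraction * n_measurements)))
    return sorted(indices[n_val:]), sorted(indices[:n_val])


train_idx, val_idx = holdout_split(n_measurements=10)
print("train:", train_idx)
print("validation:", val_idx)
```

If you augment the data, it is generally advisable to split first and augment only the training part, so that noisy copies of validation measurements do not leak into training.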

In [None]:
# Visualize the trained model fit to the data
dataset.plot(
    predictor=trained,
    show=True,
    path="trained_neural_ode.png",
)