# Universal ODEs with Catalax

This template demonstrates how to use Universal Ordinary Differential Equations (Universal ODEs) with the Catalax JAX library to model systems where part of the dynamics is known and part is unknown. Universal ODEs combine mechanistic knowledge with neural networks to capture both understood physics and unknown behaviors in a single, unified model.

## What are Universal ODEs?

Universal ODEs are a hybrid modeling approach that combines traditional differential equations (representing known physics) with neural networks (representing unknown dynamics). This allows you to leverage existing mechanistic knowledge while letting the neural network learn the parts of the system that are not well understood or too complex to model explicitly.

The key advantages of Universal ODEs include:

- **Physics-Informed Learning**: Incorporate existing mechanistic knowledge while learning unknown dynamics
- **Interpretable Models**: Separate known physics from learned components for better understanding
- **Data Efficiency**: Leverage prior knowledge to reduce data requirements
- **Robust Extrapolation**: Physics constraints help the model generalize beyond training data
- **Residual Modeling**: Learn corrections to existing models or capture missing physics

## How Universal ODEs Work

In a Universal ODE, the system dynamics are split into two components:

1. **Known Physics**: Traditional differential equations representing well-understood mechanisms
2. **Unknown Dynamics**: Neural networks that learn residual or missing behaviors from data

The total dynamics become: 

$$\frac{dx}{dt} = f_{physics}(x, \theta) + f_{neural}(x, \phi)$$

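Here, $f_{physics}$ is the known mechanistic model with parameters $\theta$ (for example, kinetic constants), and $f_{neural}$ is a neural network with weights $\phi$ that captures unknown or residual dynamics. Both act on the same state $x$, and their contributions are summed.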
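
The additive split can be sketched in plain Python without any Catalax calls. In this minimal sketch, the known term is an irreversible Michaelis-Menten rate law and the neural term is a tiny network with one hidden layer and hand-picked weights. The values of `v_max` and `k_m`, the weights, and the explicit Euler integrator are illustrative assumptions, not what Catalax uses internally.

```python
import math

def f_physics(s, v_max=1.0, k_m=0.5):
    """Known mechanistic term: Michaelis-Menten substrate consumption."""
    return -v_max * s / (k_m + s)

def f_neural(s, w1, b1, w2, b2):
    """Tiny MLP with one tanh hidden layer and a scalar output."""
    hidden = [math.tanh(w * s + b) for w, b in zip(w1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2

def simulate(s0, phi, t_end=2.0, dt=0.01):
    """Integrate ds/dt = f_physics(s) + f_neural(s) with explicit Euler."""
    s = s0
    for _ in range(int(round(t_end / dt))):
        s += dt * (f_physics(s) + f_neural(s, *phi))
    return s

# With all output weights at zero the neural term vanishes and the
# universal ODE reduces to the purely mechanistic model.
phi_zero = ([0.3, -0.2], [0.0, 0.1], [0.0, 0.0], 0.0)
# Non-zero output weights add a correction on top of the physics
# (hand-picked here; in practice they are learned from data).
phi_corr = ([0.3, -0.2], [0.0, 0.1], [-0.5, 0.2], 0.0)

s_mechanistic = simulate(2.0, phi_zero)
s_hybrid = simulate(2.0, phi_corr)
print(f"mechanistic only: {s_mechanistic:.3f}, with neural correction: {s_hybrid:.3f}")
```

Because the two terms are simply added, the neural network only has to represent what the mechanistic model gets wrong, rather than the full dynamics.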

## Applications in Enzyme Kinetics

Universal ODEs are particularly powerful for enzyme kinetics where:

- Basic kinetic mechanisms are known but regulatory effects are complex
- Allosteric or cooperative effects are present but difficult to model explicitly
- Environmental factors influence kinetics in unknown ways
- Multiple enzymes interact in ways that are hard to capture mechanistically
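
To make the first case concrete, the sketch below compares a plain Michaelis-Menten rate law (the "known" mechanism) with a rate law that includes substrate inhibition (standing in for an unmodeled regulatory effect). The difference between the two is the residual that the neural term of a Universal ODE would have to learn. The choice of substrate inhibition and all parameter values are hypothetical and serve only as an illustration.

```python
def michaelis_menten(s, v_max=1.0, k_m=0.5):
    """Known mechanism: irreversible Michaelis-Menten rate."""
    return v_max * s / (k_m + s)

def substrate_inhibition(s, v_max=1.0, k_m=0.5, k_i=2.0):
    """'True' rate with an extra inhibition term the modeler did not include."""
    return v_max * s / (k_m + s + s * s / k_i)

concentrations = [0.0, 0.5, 1.0, 2.0, 4.0]

# Residual = true rate minus known rate: the part left to the neural term.
residuals = [substrate_inhibition(s) - michaelis_menten(s) for s in concentrations]

for s, r in zip(concentrations, residuals):
    print(f"s = {s:.1f}  residual rate = {r:+.3f}")
```

In this example the residual is zero without substrate and grows in magnitude with concentration, which is the kind of state-dependent correction a small neural network can represent.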

## Getting Started

This template provides the basic framework for building Universal ODE models with Catalax. The neural network component learns to capture dynamics not explained by your mechanistic model, which makes the approach well suited to discovering new biological phenomena or improving model accuracy.

Learn more about Universal ODEs with Catalax in the [Catalax documentation](https://catalax.mintlify.app/neural/universal-ode). Also check out [this example](https://github.com/JR-1991/Catalax/blob/master/examples/UniversalODE.ipynb) for an advanced downstream application that recovers mathematical expressions from the learned dynamics.


In [None]:
# Install all required packages
%pip install pyenzyme catalax

In [None]:
import jax.nn as jnn
import pyenzyme as pe
import catalax as ctx
import catalax.neural as ctn

In the following cell, we will load the EnzymeML document from the EnzymeML Suite. The resulting object is an instance of the `EnzymeMLDocument` class, which you can inspect and re-use for your analysis. The following functions are available and compatible with the `EnzymeMLDocument` class:

- `pe.summary(enzmldoc)`: Print a summary of the EnzymeML document.
- `pe.plot(enzmldoc)`: Plot the EnzymeML document.
- `pe.plot_interactive(enzmldoc)`: Interactive plot of the EnzymeML document.
- `pe.to_pandas(enzmldoc)`: Convert the EnzymeML document to a pandas DataFrame.
- `pe.to_sbml(enzmldoc)`: Convert the EnzymeML document to an SBML document.
- `pe.to_petab(enzmldoc)`: Convert the EnzymeML document to a PEtab format.
- `pe.get_current()`: Get the current EnzymeML document from the EnzymeML Suite.

In [None]:
# Connect to the EnzymeML Suite
suite = pe.EnzymeMLSuite()

# Get the current EnzymeML document
enzmldoc = suite.get_current()

# Print a summary of the EnzymeML document
pe.summary(enzmldoc)

## Converting EnzymeML to Catalax

The `ctx.from_enzymeml` function converts an EnzymeML document into a Catalax dataset and a Catalax model. The dataset contains the experimental data, and the model holds the mechanistic description that you can use for parameter estimation.

In [None]:
dataset, model = ctx.from_enzymeml(enzmldoc)

# We will augment the dataset to generate more data for training
# and improve the generalization of the model
train_dataset = dataset.augment(n_augmentations=10)
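
A common way to augment time-series data is to create noisy copies of each measured series. The plain-Python sketch below illustrates that idea with additive Gaussian noise. Whether `dataset.augment` works exactly this way is an assumption, and the helper name `augment_series`, the noise level `sigma`, and the example values are hypothetical.

```python
import random

def augment_series(values, n_augmentations, sigma=0.01, seed=0):
    """Return the original series plus `n_augmentations` noisy copies."""
    rng = random.Random(seed)  # fixed seed keeps the sketch reproducible
    copies = [list(values)]
    for _ in range(n_augmentations):
        copies.append([v + rng.gauss(0.0, sigma) for v in values])
    return copies

measured = [2.0, 1.6, 1.3, 1.1]  # hypothetical substrate concentrations
augmented = augment_series(measured, n_augmentations=10)
print(f"{len(augmented)} series in total (1 original + 10 noisy copies)")
```

Noisy copies discourage the network from fitting individual data points exactly, which is one reason augmentation can help generalization.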

## Step 1: Define the Neural ODE

As a first step, we define the Universal ODE model by calling `ctn.UniversalODE.from_model`. This function determines the states that are modeled within the Universal ODE and embeds your mechanistic model in the neural model.

The neural part of the Universal ODE is a neural network and therefore requires hyperparameters that define its architecture:

- `width_size`: The number of neurons in each hidden layer. Keep it small for Universal ODEs so that the network does not dominate the dynamics of the system.
- `depth`: The number of hidden layers. Keep it small for Universal ODEs for the same reason.
- `activation`: The activation function of the neural network.
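
To get a feeling for what "small" means, the sketch below counts the weights and biases of a fully connected network for a given `width_size` and `depth`. The input and output sizes of 2 (a hypothetical two-species system) are assumptions; the network that Catalax builds may take additional inputs, so treat the numbers as an order-of-magnitude guide.

```python
def mlp_parameter_count(in_size, out_size, width_size, depth):
    """Count weights and biases of an MLP with `depth` hidden layers."""
    if depth == 0:
        # No hidden layer: a single linear map from input to output.
        return in_size * out_size + out_size
    first = in_size * width_size + width_size
    hidden = (depth - 1) * (width_size * width_size + width_size)
    last = width_size * out_size + out_size
    return first + hidden + last

small = mlp_parameter_count(in_size=2, out_size=2, width_size=4, depth=2)
large = mlp_parameter_count(in_size=2, out_size=2, width_size=64, depth=4)
print(f"width 4, depth 2: {small} parameters")
print(f"width 64, depth 4: {large} parameters")
```

The small configuration has only a few dozen parameters, comparable in scale to a mechanistic model, whereas the large one has thousands and could easily absorb dynamics that the mechanistic part should explain.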

In the following cell, we define a Universal ODE whose neural network has 2 hidden layers with 4 neurons each and a `celu` activation function. CELU is a good choice because it is smooth and non-linear, which keeps the learned vector field well-behaved during ODE integration. The choice of activation function is critical and determines the behavior of the model. Try **not** to use piecewise-linear activation functions like `relu`: their derivative jumps at zero, which makes the right-hand side of the ODE non-smooth and less well-suited for integration.


In [None]:
# Create a universal ODE model
universal_ode = ctn.UniversalODE.from_model(
    model,
    width_size=4,
    depth=2,
    activation=jnn.celu,
)

## Step 2: Set up a training strategy

Now that we have defined the Universal ODE model, we can set up a training strategy by creating a `ctn.Strategy` object. The strategy determines the optimizer, the loss function, the batch size, and the learning rate of the training process.

Neural networks are typically trained using gradient descent, which explores the parameter space gradually. Hence, our strategy determines how we move through this parameter space. The learning rate `lr` determines how big of a step we take in the parameter space. Too big and we might overshoot the minimum, too small and we will take forever to converge. Hence, the strategy helps us to determine when to take bigger or smaller steps.
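
The effect of the learning rate can be seen on a toy problem. The sketch below runs plain gradient descent on $f(x) = x^2$, which has its minimum at $x = 0$, with three fixed learning rates. The function and the specific values are illustrative assumptions and unrelated to the optimizer Catalax uses.

```python
def gradient_descent(lr, steps=50, x0=1.0):
    """Minimise f(x) = x**2 (gradient 2x) with a fixed learning rate."""
    x = x0
    for _ in range(steps):
        x -= lr * 2.0 * x
    return x

too_large = gradient_descent(lr=1.1)    # each step overshoots the minimum
too_small = gradient_descent(lr=1e-4)   # barely moves in 50 steps
reasonable = gradient_descent(lr=0.3)   # converges quickly

print(f"lr=1.1  -> x = {too_large:.3e}")
print(f"lr=1e-4 -> x = {too_small:.3e}")
print(f"lr=0.3  -> x = {reasonable:.3e}")
```

With the largest rate the iterate moves away from the minimum, with the smallest it is still close to its starting point, and only the intermediate rate ends up near zero.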

**Tips**

- Train the initial steps on only the first 15% of the data. Think of it as a warmup phase; without it, convergence tends to be unstable.
- Start with a relatively high learning rate to explore the parameter space quickly.
- After the warmup phase, the learning rate should be decreased to a lower value, as we want to fine-tune the parameters.
- The batch size should be set to a value that is a good compromise between memory efficiency and convergence speed.
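
The staged schedule described in these tips can be sketched in plain Python. Here each stage pairs a learning rate with the fraction of every trajectory that is used, and a hypothetical helper `truncate_series` keeps the leading part of a time series. Interpreting the fraction as a share of time points is an assumption; Catalax may cut trajectories differently.

```python
def truncate_series(times, values, length):
    """Keep the leading fraction `length` of a time series (at least two points)."""
    n_keep = max(2, int(round(length * len(times))))
    return times[:n_keep], values[:n_keep]

# Hypothetical trajectory with 20 time points
times = [float(i) for i in range(20)]
values = [2.0 * 0.9 ** i for i in range(20)]

# (learning rate, fraction of each trajectory used) per stage
schedule = [(1e-3, 0.15), (1e-3, 1.0), (1e-4, 1.0)]

for lr, length in schedule:
    t_sub, y_sub = truncate_series(times, values, length)
    print(f"lr={lr:g}: training on {len(t_sub)} of {len(times)} time points")
```

Starting on short trajectories keeps early integration errors small; the later stages then refine the fit on the full time course with the same or a lower learning rate.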

_Note: The defaults have been set to a good compromise for this example. You can try to change the parameters and see how the training behaves._


In [None]:
# We will use L2 regularization to prevent overfitting
penalties = ctn.Penalties.for_universal_ode()

# We will use a batch size of 20; learning rates are set per step below
strategy = ctn.Strategy(penalties=penalties, batch_size=20)

# We will first train on the first 15% of the data with a learning rate of 1e-3
strategy.add_step(lr=1e-3, length=0.15, steps=1000)

# Then we will train on the full data with a learning rate of 1e-3
strategy.add_step(lr=1e-3, steps=1000)

# Finally, we will fine-tune the parameters with a learning rate of 1e-4
strategy.add_step(lr=1e-4, steps=1000)

# Train universal ODE
trained = universal_ode.train(
    dataset=train_dataset,
    strategy=strategy,
    print_every=10,
    weight_scale=1e-6,
    save_milestones=False,  # Set to True to save model checkpoints
    # log="progress.log", # Uncomment this line to log progress
)

# Save the trained model to a file
trained.save_to_eqx(".", "trained_universal_ode.eqx")

_Important: We have not used a validation strategy, as we are only interested in the training process. However, in practice, you should always use a validation strategy to monitor the performance of the model on unseen data. Check out the [Catalax documentation](https://catalax.mintlify.app/basic/data-management#data-splitting-and-cross-validation) for more information._
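
As a rough illustration of what a validation setup involves, the sketch below performs a simple hold-out split over measurement labels in plain Python. The helper `holdout_split` and the measurement names are hypothetical; for real work, prefer the data-splitting utilities described in the linked Catalax documentation.

```python
import random

def holdout_split(items, validation_fraction=0.2, seed=0):
    """Shuffle item indices and split them into training and validation sets."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # fixed seed for reproducibility
    n_val = max(1, int(round(validation_fraction * len(items))))
    val_idx = sorted(indices[:n_val])
    train_idx = sorted(indices[n_val:])
    return [items[i] for i in train_idx], [items[i] for i in val_idx]

measurements = [f"measurement_{i}" for i in range(10)]
train, validation = holdout_split(measurements)

print(f"training on {len(train)} measurements, validating on {len(validation)}")
```

Splitting by whole measurements, rather than by individual time points, keeps the validation trajectories entirely unseen during training.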

In [None]:
# Visualize the trained model fit to the data
dataset.plot(
    predictor=trained,
    show=True,
    path="trained_universal_ode.png",
)