### Plan

* ✅ Convert datasets into arrays of shape 
  * [Replicate, Time, Species]
* ✅ VMap the simulate function to iterate over each SaveAt and Y0 in parallel
  * Keep Parameters constant!
  * Each measurement has a y0 and timesteps plus data
  * In case of replicates, duplicate y0
* Upon conversion to a tensor, zero-pad data to the maximum array size
  * This addresses varying numbers of sample times per measurement
  * To begin with, let's stick with constant SaveAt array sizes
* Effectively, simulate and calculate residuals simultaneously
  * Jit this for maximum performance!
* Run LMFit as usual, as most of the overhead comes from function evaluation
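
The vmap and jit steps of the plan can be sketched in plain JAX. This is a minimal illustration, not SysBioJax's implementation: `rhs`, `simulate_one`, `simulate_batch` and `residuals` are hypothetical names, the fixed-step RK4 integrator is a stand-in for the real ODE solver, and each measurement is assumed to start at its first `saveat` time.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rhs(y, theta):
    """Michaelis-Menten right-hand side; theta = (K_m, v_max)."""
    K_m, v_max = theta
    rate = v_max * y[0] / (K_m + y[0])
    return jnp.array([-rate, rate])


def simulate_one(y0, theta, saveat, substeps=20):
    """Integrate one measurement with fixed-step RK4 (stand-in solver).

    Assumes the integration starts at saveat[0].
    Returns the state at every time in `saveat`, shape (time, species).
    """

    def rk4_step(y, h):
        k1 = rhs(y, theta)
        k2 = rhs(y + 0.5 * h * k1, theta)
        k3 = rhs(y + 0.5 * h * k2, theta)
        k4 = rhs(y + h * k3, theta)
        return y + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

    def advance(y, t_pair):
        # Move from the previous save time to the next one in `substeps` steps
        t_prev, t_next = t_pair
        h = (t_next - t_prev) / substeps
        y = jax.lax.fori_loop(0, substeps, lambda _, s: rk4_step(s, h), y)
        return y, y

    # The first interval has zero length, so the first output equals y0
    t_prev = jnp.concatenate([saveat[:1], saveat[:-1]])
    _, ys = jax.lax.scan(advance, y0, (t_prev, saveat))
    return ys


# Map over y0 (axis 0) and saveat (axis 0); keep the parameters constant (None)
simulate_batch = jax.vmap(simulate_one, in_axes=(0, None, 0))


@jax.jit
def residuals(theta, y0s, saveat, data):
    """Simulate all measurements and return the flattened residual vector."""
    pred = simulate_batch(y0s, theta, saveat)
    return (pred - data).ravel()


theta_true = jnp.array([100.0, 5.0])  # (K_m, v_max)
y0s = jnp.array([[50.0, 0.0], [120.0, 0.0], [200.0, 0.0]])
saveat = jnp.tile(jnp.linspace(0.0, 100.0, 11), (3, 1))

data = simulate_batch(y0s, theta_true, saveat)  # (measurement, time, species)
res = residuals(jnp.array([80.0, 4.0]), y0s, saveat, data)
```

Because simulation and residual calculation live in one jitted function, the optimizer only ever calls compiled code.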

In [None]:
try:
    from sysbiojax import Model
except ImportError:
    import sys
    !{sys.executable} -m pip install git+https://github.com/JR-1991/sysbiojax.git

    from sysbiojax import Model

In [1]:
import jax
import jax.numpy as jnp
from jax import config

config.update("jax_enable_x64", True)

In [2]:
# Initialize the model
model = Model(name="Simple Michaelis-Menten model")

# Add species
model.add_species("s1, s2")

# Add ODEs
model.add_ode("s1", "- (v_max * s1) / ( K_m + s1)")
model.add_ode("s2", "(v_max * s1) / ( K_m + s1)")

# Add parameter values
model.parameters.v_max.value = 5.0
model.parameters.K_m.value = 100.0

model

Eq(x, Matrix([[s1, s2]]))

Eq(theta, Matrix([[K_m, v_max]]))

Eq(Derivative(s1, t), -s1*v_max/(K_m + s1))

Eq(Derivative(s2, t), s1*v_max/(K_m + s1))



In [13]:
# Create a mock dataset to fit
import numpy as np

DATASET_SIZE = 10
TIME_STEPS = 100
MAX_TIME = 100

dataset = {
    "initial_conditions": [
        {"s1": np.random.uniform(50, 200), "s2": 0.0} 
        for _ in range(DATASET_SIZE) 
    ],
    # Every measurement shares the same evenly spaced sampling grid
    "time": jnp.tile(jnp.linspace(0, MAX_TIME, TIME_STEPS), (DATASET_SIZE, 1)),
}

len(dataset["initial_conditions"]), dataset["time"].shape

(10, (10, 100))

In [22]:
# Create synthetic data from the given dataset
#
# SysBioJax vmaps the simulation over per-measurement time points,
# which handles irregular sampling times.
# In the future, a more general solution will also support
# varying numbers of data points per measurement.

t0 = jnp.min(dataset["time"])
t1 = jnp.ceil(jnp.max(dataset["time"])) + 1
dt0 = 0.01

times, data = model.simulate(
    initial_conditions=dataset["initial_conditions"],
    t0=t0, t1=t1, dt0=dt0, saveat=dataset["time"], in_axes=(0, None, 0)
)

dataset["data"] = data

data.shape # (DATASET_SIZE, TIME_STEPS, SPECIES)

(10, 100, 2)
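
One way to handle measurements with different numbers of data points is the zero-padding idea from the plan, combined with a boolean mask so that padded entries do not contribute to the residuals. The sketch below is an assumption about how this could look, not part of SysBioJax: `pad_measurements` and `masked_residuals` are hypothetical helpers, and padding the time axis by repeating the last sample time (rather than with zeros) is a design choice made here to keep the save times non-decreasing.

```python
import numpy as np
import jax.numpy as jnp


def pad_measurements(times, values):
    """Pad ragged measurements to a common length.

    times:  list of 1D arrays, one per measurement
    values: list of arrays with shape (n_i, species)
    Returns padded times (N, T_max), zero-padded data (N, T_max, S)
    and a mask (N, T_max) that is True for real data points.
    """
    n_max = max(len(t) for t in times)
    n_species = values[0].shape[1]

    t_pad = np.zeros((len(times), n_max))
    y_pad = np.zeros((len(times), n_max, n_species))
    mask = np.zeros((len(times), n_max), dtype=bool)

    for i, (t, y) in enumerate(zip(times, values)):
        n = len(t)
        t_pad[i, :n] = t
        t_pad[i, n:] = t[-1]  # repeat the last time instead of padding with 0
        y_pad[i, :n] = y      # remaining data entries stay zero
        mask[i, :n] = True

    return jnp.asarray(t_pad), jnp.asarray(y_pad), jnp.asarray(mask)


def masked_residuals(pred, data, mask):
    """Residuals with padded positions forced to zero."""
    return jnp.where(mask[..., None], pred - data, 0.0).ravel()


times = [np.array([0.0, 1.0, 2.0]), np.array([0.0, 5.0])]
values = [np.ones((3, 2)), 2.0 * np.ones((2, 2))]

t_pad, y_pad, mask = pad_measurements(times, values)
padded_res = masked_residuals(3.0 * jnp.ones_like(y_pad), y_pad, mask)
```

The padded arrays have a fixed shape, so they can be passed through `vmap` and `jit` in the same way as the constant-size `SaveAt` arrays used here.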

In [17]:
# TODO - Write 'residuals' function for parameter estimation

In [18]:
# TODO - Write LMFit wrapper for parameter estimation
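
As a rough idea of what such a wrapper has to do, the sketch below connects a jitted residual function to a Levenberg-Marquardt optimizer. Note the substitutions: it uses SciPy's `least_squares` with `method="lm"` as a stand-in for LMFit, and a toy exponential decay instead of the Menten model. `toy_residuals`, `toy_jacobian` and `fit` are hypothetical names. The Jacobian from `jax.jacfwd` is an optional extra that JAX makes cheap to provide.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import least_squares

jax.config.update("jax_enable_x64", True)


@jax.jit
def toy_residuals(theta, t, data):
    """Stand-in residual function: y = a * exp(-k * t)."""
    a, k = theta
    return a * jnp.exp(-k * t) - data


# Exact Jacobian of the residuals w.r.t. theta, shape (n_points, n_params)
toy_jacobian = jax.jit(jax.jacfwd(toy_residuals))


def fit(residual_fn, jacobian_fn, initial_guess, args=()):
    """Fit named parameters by Levenberg-Marquardt.

    initial_guess: dict mapping parameter name -> starting value.
    Returns a dict of fitted values and the raw SciPy result.
    """
    names = list(initial_guess)
    x0 = np.array([initial_guess[name] for name in names], dtype=float)

    result = least_squares(
        # Convert between NumPy (optimizer side) and JAX (model side)
        lambda x, *a: np.array(residual_fn(jnp.asarray(x), *a)),
        x0,
        jac=lambda x, *a: np.array(jacobian_fn(jnp.asarray(x), *a)),
        method="lm",
        args=args,
    )
    return dict(zip(names, result.x)), result


t = jnp.linspace(0.0, 5.0, 50)
data = 2.0 * jnp.exp(-0.7 * t)  # noiseless synthetic data

fitted, result = fit(toy_residuals, toy_jacobian, {"a": 1.0, "k": 0.3}, args=(t, data))
```

An LMFit version would follow the same pattern: unpack the parameter object into an array, call the jitted residual function, and hand a NumPy array back to the optimizer.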

In [20]:
# TODO - Write a function that takes a model and a dataset and returns a fitted model

In [21]:
# TODO - Write a measurement class that wraps data and provides it to the fitting function