Let’s build the smallest possible irregular-time dataset where a continuous-time model (Neural ODE–style) should beat a plain 3-layer MLP that relies on uniform resampling.

Quick intuition (why this works): with irregular timestamps, a fixed-grid MLP must first squash each example onto a uniform grid (zero-fill / interpolate). That throws away timing detail. A Neural-ODE pipeline can consume the actual event times and integrate to any query time, which is exactly where continuous-time models have been shown to help on uneven time series (e.g., ODE-RNN / Latent ODE, Neural CDE).

Below is self-contained JAX code that:
- Synthesizes a 1D latent ODE process with a per-sequence decay rate, observed with noise at irregular times, and then visualizes it.
- Trains two models to predict the value at a future query time $t$ given the irregular observations:
    1. Baseline MLP (3 layers) fed a uniformly resampled vector (plus mask & Δt channels).
    2. Neural-ODE model that extracts a rate parameter from the irregular events with a tiny encoder, then integrates $\frac{dy}{dt} = \alpha_\theta y$ using `diffrax.diffeqsolve` to $t$.
- Compares the two models' performance on the held-out validation set.

> Expectation: on this irregular dataset, the Neural-ODE variant typically achieves lower MSE because it respects true event times instead of relying on coarse resampling. This matches the literature’s advantage of continuous-time models on uneven data.

In [None]:
# Synthesizes a 1D latent ODE process with a per-sequence decay rate, observed with noise at irregular times.
# Trains two models to predict the value at a future query time t given the irregular observations.
import math, functools, numpy as np
import jax, jax.numpy as jnp, jax.random as jr
import equinox as eqx
import optax
import diffrax as dx

# ----------------------------
# 1) Synthetic irregular dataset
# ----------------------------
# Latent true dynamics per sequence: dy/dt = -w * y, with y(0)=y0
# We observe noisy y at irregular times in [0, T_obs], and want y(t_star).
# Root PRNG key for the synthetic data.
key = jr.PRNGKey(0)

# Observations live in [0, T_obs]; the model is queried at t_star, which is
# past the end of the observation window (mild extrapolation).
T_obs = 1.0
t_star = 1.3

# Number of sequences in the training and validation sets.
N_train = 4000
N_val = 1000

# Each sequence has between min_events and max_events observations (inclusive).
min_events = 4
max_events = 16

# Standard deviation of the additive observation noise.
sigma_noise = 0.03

def sample_sequence(key):
    """Draw one synthetic sequence from the decay process dy/dt = -w * y.

    The number of events is random, so the returned arrays have a
    data-dependent length; the event count must therefore be a concrete
    value, which means this function is meant to be called eagerly.

    Args:
        key: JAX PRNG key used for all randomness of this sequence.

    Returns:
        Tuple ``(ts, y_obs, y_target)``: the sorted observation times drawn
        from [0, T_obs), the noisy observations at those times, and the
        noise-free value of the process at ``t_star``.
    """
    k1, k2, k3 = jr.split(key, 3)
    # Every random draw gets its own subkey. Reusing a single key for the
    # event count, the timestamps and the noise would make those draws
    # statistically dependent on each other.
    k_count, k_times, k_noise = jr.split(k3, 3)

    # per-example decay rate and initial value
    w = jr.lognormal(k1, sigma=0.4) + 0.2     # positive, varied
    y0 = jr.normal(k2) * 0.5 + 1.0

    # irregular times: a random number of uniformly drawn timestamps, sorted
    n = int(jr.randint(k_count, (), min_events, max_events + 1))
    ts = jnp.sort(jr.uniform(k_times, shape=(n,), minval=0.0, maxval=T_obs))

    y_clean = y0 * jnp.exp(-w * ts)
    y_obs = y_clean + jr.normal(k_noise, shape=ts.shape) * sigma_noise
    y_target = y0 * jnp.exp(-w * t_star)     # ground truth at query time
    return ts, y_obs, y_target

def batch_dataset(key, N):
    """Sample N independent sequences.

    Args:
        key: JAX PRNG key; split into one subkey per sequence.
        N: number of sequences to generate.

    Returns:
        Tuple ``(ts_list, ys_list, target)``. The first two are Python lists
        of per-sequence arrays (sequences have different lengths, so they are
        not stacked); ``target`` is the stacked array of per-sequence targets.
    """
    ts_list, ys_list, targets = [], [], []
    for seq_key in jr.split(key, N):
        seq_ts, seq_ys, seq_target = sample_sequence(seq_key)
        ts_list.append(seq_ts)
        ys_list.append(seq_ys)
        targets.append(seq_target)
    return ts_list, ys_list, jnp.stack(targets)

# Separate keys for the two splits so training and validation data are drawn independently.
key_train, key_val = jr.split(key)
# Each call returns (list of time arrays, list of observation arrays, stacked targets).
train_ts, train_ys, train_ystar = batch_dataset(key_train, N_train)
val_ts,   val_ys,   val_ystar   = batch_dataset(key_val,   N_val)

# ----------------------------
# 2) Utilities: make uniform-grid tensors for the baseline
# ----------------------------
# Number of points in the uniform resampling grid used by the baseline.
GRID = 16
# Evenly spaced grid times from 0 to T_obs (both endpoints included).
grid = jnp.linspace(0.0, T_obs, GRID)

def to_uniform_grid(ts, ys):
    """Rasterize one irregular sequence onto the fixed grid for the MLP baseline.

    Each observation is written into the bin of the nearest grid point at or
    before its timestamp. Bins that receive no observation keep the value 0
    (zero-fill; nothing is carried forward from earlier bins).

    Args:
        ts: 1-D array of observation times.
        ys: 1-D array of observed values, same length as ``ts``.

    Returns:
        Flat array of length ``GRID * 4`` holding, for every grid point,
        ``[value, mask, time since last observation, t_star - grid time]``.
    """
    # NOTE: this is deliberately simplistic; it's what hurts the baseline.
    idx = jnp.clip(jnp.searchsorted(grid, ts, side="right") - 1, 0, GRID - 1)

    # NOTE(review): several events can land in the same bin; which of them the
    # scatter keeps is then left to the backend -- confirm if that matters.
    y_grid = jnp.zeros((GRID,)).at[idx].set(ys)
    m_grid = jnp.zeros((GRID,)).at[idx].set(1.0)

    # Time of the observation stored in each bin (-inf where the bin is empty).
    last_t = jnp.full((GRID,), -jnp.inf).at[idx].set(ts)
    # Forward-fill left -> right. Bins are ordered, disjoint time intervals, so
    # the stored times never decrease across occupied bins and a running
    # maximum is the same as "most recent occupied bin". This replaces a
    # Python loop that issued several tiny array ops per grid point.
    last_seen = jax.lax.cummax(last_t, axis=0)
    # Before the first observation the reference time is 0; bins that hold an
    # observation get 0 because their event is at or after the grid point.
    dt_grid = jnp.maximum(0.0, grid - jnp.maximum(last_seen, 0.0))

    # input features: [y_grid, mask, dt_grid, t_star - grid]  (the last lets MLP know the query offset)
    tdiff = t_star - grid
    x = jnp.stack([y_grid, m_grid, dt_grid, tdiff], axis=-1)   # (GRID, 4)
    return x.reshape(-1)                                       # (GRID*4,)

def pack_batch_uniform(batch_idx):
    """Return the rows of the precomputed training grid features selected by batch_idx."""
    rows = jnp.array(batch_idx)
    return train_uniform[rows]

def pack_targets(y_star, batch_idx):
    """Return the entries of y_star selected by batch_idx, in the order given."""
    rows = jnp.array(batch_idx)
    return y_star[rows]

# Rasterize every sequence once up front, so the training step only has to
# index into dense arrays (keeps JIT happy).
train_uniform = jnp.stack(
    [to_uniform_grid(seq_ts, seq_ys) for seq_ts, seq_ys in zip(train_ts, train_ys)]
)
val_uniform = jnp.stack(
    [to_uniform_grid(seq_ts, seq_ys) for seq_ts, seq_ys in zip(val_ts, val_ys)]
)

def pad_irregular(ts_list, ys_list, max_len):
    """Right-pad variable-length sequences with zeros to a common length.

    Args:
        ts_list: list of 1-D time arrays.
        ys_list: list of 1-D value arrays, paired with ``ts_list``.
        max_len: length every sequence is padded to.

    Returns:
        Tuple ``(ts, ys, mask)`` of stacked arrays with one row per sequence
        and ``max_len`` columns. ``mask`` is float32: 1.0 for real entries,
        0.0 for padding.
    """
    def _pad_one(seq_ts, seq_ys):
        length = seq_ts.shape[0]
        extra = max_len - int(length)
        padded_ts = jnp.pad(seq_ts, (0, extra))
        padded_ys = jnp.pad(seq_ys, (0, extra))
        # The first `length` positions are real, the rest is padding.
        valid = (jnp.arange(max_len) < length).astype(jnp.float32)
        return padded_ts, padded_ys, valid

    rows = [_pad_one(seq_ts, seq_ys) for seq_ts, seq_ys in zip(ts_list, ys_list)]
    ts_rows = [row[0] for row in rows]
    ys_rows = [row[1] for row in rows]
    mask_rows = [row[2] for row in rows]
    return jnp.stack(ts_rows), jnp.stack(ys_rows), jnp.stack(mask_rows)

# Pad every sequence to max_events so the irregular data can be indexed and batched as dense arrays.
train_ts_pad, train_ys_pad, train_mask = pad_irregular(train_ts, train_ys, max_events)
val_ts_pad, val_ys_pad, val_mask = pad_irregular(val_ts, val_ys, max_events)

# ----------------------------
# 3) Baseline: 3-layer MLP on uniform grid
# ----------------------------
class BaselineMLP(eqx.Module):
    """Baseline regressor: an MLP over the flattened uniform-grid features."""

    mlp: eqx.nn.MLP

    def __init__(self, key):
        """Build the network; ``key`` seeds the weight initialization."""
        self.mlp = eqx.nn.MLP(
            in_size=GRID * 4,
            out_size=1,
            width_size=128,
            depth=3,
            activation=jax.nn.silu,
            key=key,
        )

    def __call__(self, x):
        """Map one flat feature vector of length GRID * 4 to a scalar prediction."""
        out = self.mlp(x)
        # The network emits a length-1 vector; drop that axis.
        return out.squeeze(-1)

# ----------------------------
# 4) Neural-ODE model:
#    - tiny encoder predicts a scalar rate a_theta from (irregular) (t_i, y_i)
#    - ODE: dy/dt = a_theta * y
#    - integrate from the encoder's estimate of y(0) (y0_hat) up to t_star with diffrax
# ----------------------------
class IrregularEncoder(eqx.Module):
    """Tiny set encoder over irregular events.

    Every event is embedded from ``[t, y, t - prev_t]``; the embeddings of the
    valid events are averaged and a small head maps the average to two
    scalars. Kept tiny on purpose: the NODE's advantage is meant to come from
    using the true times and integrating, not from encoder capacity.
    """

    event_mlp: eqx.nn.MLP
    head: eqx.nn.MLP

    def __init__(self, key):
        """Build both sub-networks, each initialized from its own subkey."""
        event_key, head_key = jr.split(key)
        self.event_mlp = eqx.nn.MLP(
            in_size=3, out_size=32, width_size=64, depth=2,
            activation=jax.nn.tanh, key=event_key,
        )
        # outputs [a_theta, y0_hat]
        self.head = eqx.nn.MLP(
            in_size=32, out_size=2, width_size=64, depth=2,
            activation=jax.nn.tanh, key=head_key,
        )

    def __call__(self, ts, ys, mask):
        """Encode one padded sequence.

        Args:
            ts: padded event times, shape (n,).
            ys: padded event values, shape (n,).
            mask: 1.0 for real events and 0.0 for padding, shape (n,).

        Returns:
            Tuple ``(a_theta, y0_hat)`` of scalars.
        """
        # Gap to the previous event; the first event gets a gap of 0.
        gaps = jnp.diff(ts, prepend=ts[:1])
        per_event = jnp.stack([ts, ys, gaps], axis=-1)        # (n, 3)
        embeddings = jax.vmap(self.event_mlp)(per_event)      # (n, 32)

        # Masked mean: padded rows are zeroed and excluded from the count.
        weights = mask[:, None]
        n_valid = jnp.maximum(weights.sum(), 1.0)             # avoid dividing by 0
        pooled = (embeddings * weights).sum(axis=0) / n_valid # (32,)

        rate_and_init = self.head(pooled)                     # (2,)
        return rate_and_init[0], rate_and_init[1]

class NODEPredictor(eqx.Module):
    """Predict y(t_star) by integrating dy/dt = a_theta * y from an encoded start value."""

    enc: IrregularEncoder

    def __init__(self, key):
        """Build the encoder; ``key`` seeds its weight initialization."""
        self.enc = IrregularEncoder(key)

    def __call__(self, ts, ys, mask):
        """Return the scalar prediction for one padded sequence.

        Args:
            ts: padded event times.
            ys: padded event values.
            mask: 1.0 for real events, 0.0 for padding.

        Returns:
            Scalar (0-d array): the ODE solution at ``t_star``.
        """
        # Encode irregular events to get rate and initial
        a_theta, y0_hat = self.enc(ts, ys, mask)

        # Define ODE: dy/dt = a_theta * y
        def vf(t, y, args):
            a = args
            return a * y

        term = dx.ODETerm(vf)
        solver = dx.Tsit5()
        # Integrate from t=0 to t_star, starting at y0_hat
        sol = dx.diffeqsolve(
            term, solver, t0=0.0, t1=t_star, dt0=1e-2,
            y0=y0_hat, args=a_theta, saveat=dx.SaveAt(t1=True)
        )
        # sol.ys carries a leading axis over the saved time points, and
        # SaveAt(t1=True) saves exactly one. Strip that axis: returning the
        # length-1 array would give (B, 1) predictions after vmap, which
        # broadcast against (B,) targets into a (B, B) error matrix.
        return sol.ys[0]

# ----------------------------
# 5) Training loops
# ----------------------------
@eqx.filter_jit
def mse(a, b):
    """Mean of the squared elementwise difference between a and b."""
    residual = a - b
    return jnp.mean(jnp.square(residual))

def train_baseline(*, seed=1, steps=1500, batch_size=128, learning_rate=3e-3, log_every=200):
    """Train the uniform-grid MLP baseline with Adam on MSE.

    Args:
        seed: seeds both the model initialization and the mini-batch sampler,
            so repeated calls give the same run.
        steps: number of optimizer steps.
        batch_size: number of training sequences per step.
        learning_rate: Adam learning rate.
        log_every: print train loss and validation MSE every this many steps.

    Returns:
        The trained ``BaselineMLP``.
    """
    model = BaselineMLP(jr.PRNGKey(seed))
    opt = optax.adam(learning_rate)
    opt_state = opt.init(eqx.filter(model, eqx.is_array))
    # Dedicated, seeded generator instead of the global NumPy RNG: the batch
    # order is then reproducible and independent of other code using np.random.
    rng = np.random.default_rng(seed)

    @eqx.filter_value_and_grad
    def loss_fn(model, batch_idx):
        x = pack_batch_uniform(batch_idx)             # (B, GRID*4)
        y = pack_targets(train_ystar, batch_idx)      # (B,)
        preds = jax.vmap(model)(x)
        return mse(preds, y)

    @eqx.filter_jit
    def step(model, opt_state, batch_idx):
        loss, grads = loss_fn(model, batch_idx)
        updates, opt_state = opt.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for i in range(steps):
        idx = rng.choice(N_train, size=batch_size, replace=False)
        model, opt_state, loss = step(model, opt_state, idx)
        if (i + 1) % log_every == 0:
            # quick val on the full validation set
            val_preds = jax.vmap(model)(val_uniform)
            print(f"[Baseline] step {i+1:4d}  train_loss={loss.item():.4e}  val_mse={mse(val_preds, val_ystar).item():.4e}")
    return model

def train_node(*, seed=2, steps=1500, batch_size=128, learning_rate=3e-3, log_every=200):
    """Train the Neural-ODE predictor with Adam on MSE.

    Args:
        seed: seeds both the model initialization and the mini-batch sampler,
            so repeated calls give the same run.
        steps: number of optimizer steps.
        batch_size: number of training sequences per step.
        learning_rate: Adam learning rate.
        log_every: print train loss and validation MSE every this many steps.

    Returns:
        The trained ``NODEPredictor``.
    """
    model = NODEPredictor(jr.PRNGKey(seed))
    opt = optax.adam(learning_rate)
    opt_state = opt.init(eqx.filter(model, eqx.is_array))
    # Dedicated, seeded generator instead of the global NumPy RNG: the batch
    # order is then reproducible and independent of other code using np.random.
    rng = np.random.default_rng(seed)

    def predict(model, ts_batch, ys_batch, mask_batch, like):
        preds = jax.vmap(model)(ts_batch, ys_batch, mask_batch)
        # Force predictions into the targets' shape. A stray singleton axis
        # (e.g. (B, 1) vs (B,)) would otherwise broadcast inside mse to a
        # (B, B) matrix and silently compute the wrong loss.
        return jnp.reshape(preds, like.shape)

    @eqx.filter_value_and_grad
    def loss_fn(model, ts_batch, ys_batch, mask_batch, targets):
        preds = predict(model, ts_batch, ys_batch, mask_batch, targets)
        return mse(preds, targets)

    @eqx.filter_jit
    def step(model, opt_state, batch_idx):
        batch_idx = jnp.array(batch_idx)
        ts_batch = train_ts_pad[batch_idx]
        ys_batch = train_ys_pad[batch_idx]
        mask_batch = train_mask[batch_idx]
        target_batch = train_ystar[batch_idx]
        loss, grads = loss_fn(model, ts_batch, ys_batch, mask_batch, target_batch)
        updates, opt_state = opt.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for i in range(steps):
        idx = rng.choice(N_train, size=batch_size, replace=False)
        model, opt_state, loss = step(model, opt_state, idx)
        if (i + 1) % log_every == 0:
            val_preds = predict(model, val_ts_pad, val_ys_pad, val_mask, val_ystar)
            print(f"[NODE]     step {i+1:4d}  train_loss={loss.item():.4e}  val_mse={mse(val_preds, val_ystar).item():.4e}")
    return model

# ----------------------------
# 6) Run both
# ----------------------------
baseline = train_baseline()
node = train_node()

# Final validation comparison
xval = val_uniform
pred_base = jax.vmap(baseline)(xval)
pred_node = jax.vmap(lambda t, y, m: node(t, y, m))(val_ts_pad, val_ys_pad, val_mask)
# Give the NODE predictions exactly the targets' shape before scoring: a
# trailing singleton axis, (N, 1) vs (N,), would broadcast inside mse to an
# (N, N) matrix and report a meaningless number.
pred_node = jnp.reshape(pred_node, val_ystar.shape)
print("Final Val MSE -- Baseline:", mse(pred_base, val_ystar).item())
print("Final Val MSE -- NODE    :", mse(pred_node, val_ystar).item())
