# Neural Ordinary Differential Equation on MNIST

In [1]:
# import time
# import numpy as np
# import math

import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

import equinox as eqx
import diffrax

# from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping

from utils import get_MNIST_dloaders

In [2]:
# Hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 1e-3 # learning rate for AdamW
EPOCHS = 20
HIDDEN_SIZE = 42     # Linear
PRINT_EVERY = 1
SEED = 5678

## Dataset Import

In [3]:
trainloader, testloader = get_MNIST_dloaders(batch_size=BATCH_SIZE, path='~/Data',download = False)

In [4]:
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)

(64, 1, 28, 28)
(64,)
[3 9 8 9 1 1 5 9 5 0 4 8 4 1 0 6 3 7 5 0 7 7 8 8 6 2 2 3 6 1 5 8 9 0 5 0 2
 7 6 2 3 8 9 1 4 4 8 3 8 3 7 9 2 6 0 3 6 6 6 0 9 2 6 0]


## Neural ODE Model

In [5]:
## This could work!!!
class Func(eqx.Module):
    Conv1: eqx.nn.Conv2d
    Conv2: eqx.nn.Conv2d
    Conv3: eqx.nn.Conv2d
    GroupNorm: eqx.nn.GroupNorm

    def __init__(self, dim, *, key, **kwargs):
        super().__init__(**kwargs)
        # input data = (64,1,28,28)
        keys = jrandom.split(key, 3)
        self.Conv1 = eqx.nn.Conv2d(1, dim, 3, padding=1, use_bias=False, key=keys[0])
        self.GroupNorm = eqx.nn.GroupNorm(dim, dim)
        self.Conv2 = eqx.nn.Conv2d(dim, dim, 3, padding=1, use_bias=False, key=keys[1])
        self.Conv3 = eqx.nn.Conv2d(dim, 1, 1, key=keys[2])
        
    def __call__(self, t, y, args):
        y = self.Conv1(y)
        y = jnn.softplus(y)
        y = self.Conv2(y)
        y = jnn.softplus(y)
        y = self.GroupNorm(y)
        y = self.Conv3(y)
        return y

class Fc(eqx.Module):
    Conv2d: eqx.nn.Conv2d
    AdAvgPool: eqx.nn.AdaptiveAvgPool2d
    Linear: eqx.nn.Linear
    def __init__(self, *, key, **args):
        keys = jrandom.split(key,2)
        self.Conv2d = eqx.nn.Conv2d(1,1,1, key=keys[0])
        self.AdAvgPool = eqx.nn.AdaptiveAvgPool2d(4)
        self.Linear = eqx.nn.Linear(1*4*4, 10, key=keys[1])
    def __call__(self, ys):
        y = self.Conv2d(ys)
        y = self.AdAvgPool(y)
        y = jnp.ravel(y)
        y = self.Linear(y)
        return y

class NeuralODE(eqx.Module):
    func: Func
    fc: Fc

    def __init__(self, dim, *, key, **kwargs):
        super().__init__(**kwargs)
        func_key, fc_key = jrandom.split(key, 2)
        self.func = Func(dim, key = func_key)
        self.fc = Fc(key= fc_key)


    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Dopri5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        # select the last point of solution trajectory
        return jax.vmap(self.fc, in_axes=0)(solution.ys)[-1]
        # return solution.ys

## Main

In [6]:
@eqx.filter_value_and_grad
def CrossEntropyLoss(model, ti, Xi, yi):
    pred_y = jax.vmap(model, in_axes=(None, 0))(ti, Xi) 
    labels = jnn.one_hot(yi, 10)
    _loss = optax.softmax_cross_entropy(pred_y, labels)
    return jnp.mean(_loss)
CrossEntropyLoss = eqx.filter_jit(CrossEntropyLoss)

@eqx.filter_jit
def compute_accuracy(model, ti, Xi, yi):
    pred_y = jax.vmap(model, in_axes=(None, 0))(ti, Xi)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(pred_y == yi)

def evaluate(model, ti, testloader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        loss, _ = CrossEntropyLoss(model, ti, x, y)
        avg_loss += loss
        avg_acc += compute_accuracy(model, ti, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [7]:
def train(
    lr=1e-3, # learning rate for AdamW
    epochs=20,
    hidden_size=42,     # Linear
    seed=5678,
    print_every=1,
    dry_run = False, 
    model_path = None,
    save=False
):
    model_key = jrandom.PRNGKey(seed)
    iters = len(trainloader)
    if dry_run:
        epochs = 1
        iters = min(int(print_every*10), iters) 
    
    ts = jnp.linspace(0,1,2)

    model = NeuralODE(dim = hidden_size, key=model_key)

    optim = optax.adamw(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    
    @eqx.filter_jit
    def make_step(model, Xi, yi, ti,  opt_state):
        # Xi, yi = data_i
        loss, grads = CrossEntropyLoss(model, ti, Xi, yi)
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state
    
    # # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader
            
    # Training loop
    # start = time.time()
    # for epoch, (x, y) in zip(range(epochs), infinite_trainloader()):
        # start = time.time()    
    for epoch in range(epochs):
        for iter, (x, y) in zip(
            range(iters), infinite_trainloader()
        ):
            x = x.numpy()
            y = y.numpy()
            train_loss, model, opt_state = make_step(model, x, y, ts, opt_state)
            # end = time.time()
            if dry_run and ((iter % print_every) == 0 or (iter == iters - 1)):
                print(f"{iter=}, train_loss={train_loss.item():.4f}, data_shape = ({x.shape, y.shape})")
        if (epoch % print_every) == 0 or (epoch == epochs - 1):
            test_loss, test_accuracy = evaluate(model, ts, testloader)
            print(
                f"{epoch=}, train_loss={train_loss.item():.4f}, "
                f"test_loss={test_loss.item():.4f}, test_accuracy={test_accuracy.item():.4f}"
                )
            if save:
                eqx.tree_serialise_leaves(f"./model/NODE_epoch={epoch}.eqx", model)
            # print(f"Iter: {iter}, Loss: {loss}, Computation time: {end - start}")

    return ts, model

In [8]:
# ts, model = train(dry_run = True, print_every=10)

In [9]:
ts, model = train(
    lr = LEARNING_RATE, 
    epochs = EPOCHS,
    hidden_size = HIDDEN_SIZE,     # Linear
    seed = SEED,
    print_every = PRINT_EVERY,
    dry_run = False
)

  self.pid = os.fork()
2025-02-19 20:24:55.688922: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1035] Compiling 47 configs for 3 fusions on a single thread.
2025-02-19 20:32:29.927430: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1035] Compiling 6 configs for 3 fusions on a single thread.
  self.pid = os.fork()
2025-02-19 20:33:50.226656: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1035] Compiling 6 configs for 3 fusions on a single thread.


epoch=0, train_loss=0.0975, test_loss=0.2279, test_accuracy=0.9302
epoch=1, train_loss=0.3427, test_loss=0.1841, test_accuracy=0.9421
epoch=2, train_loss=0.2006, test_loss=0.1564, test_accuracy=0.9515
epoch=3, train_loss=0.0645, test_loss=0.1393, test_accuracy=0.9586
epoch=4, train_loss=0.0533, test_loss=0.1183, test_accuracy=0.9617
epoch=5, train_loss=0.0451, test_loss=0.1032, test_accuracy=0.9691
epoch=6, train_loss=0.1255, test_loss=0.0935, test_accuracy=0.9711
epoch=7, train_loss=0.0678, test_loss=0.0926, test_accuracy=0.9707
epoch=8, train_loss=0.0342, test_loss=0.1043, test_accuracy=0.9665
epoch=9, train_loss=0.0228, test_loss=0.0809, test_accuracy=0.9753
epoch=10, train_loss=0.0110, test_loss=0.0716, test_accuracy=0.9776
epoch=11, train_loss=0.0672, test_loss=0.0761, test_accuracy=0.9771
epoch=12, train_loss=0.1491, test_loss=0.0764, test_accuracy=0.9762
epoch=13, train_loss=0.0771, test_loss=0.0666, test_accuracy=0.9790
epoch=14, train_loss=0.0465, test_loss=0.0746, test_accura

In [None]:
# # save the model
# eqx.tree_serialise_leaves(f"./model/NODE_epoch=20_250219.eqx", model)

In [20]:
# Number of parameters
# Ref: https://github.com/jax-ml/jax/discussions/6153
param_count = sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array)))
print(f"{param_count=:,.0f}K")

param_count=16,553K
