In [1]:
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
import diffrax
import tensorflow_datasets as tfds
from jax.example_libraries import optimizers

In [2]:
import tensorflow

In [3]:
# ---------------------
# Model Definition
# ---------------------

# Define network components with stax.
hidden_dim = 64
num_classes = 10

# fc_in: projects input (784-dim) to hidden_dim.
fc_in_init, fc_in_apply = stax.serial(
    Dense(hidden_dim),
    Relu
)

# ode_net: defines the ODE dynamics on the hidden state.
# It takes a hidden state and returns its time derivative.
ode_net_init, ode_net_apply = stax.serial(
    Dense(50),
    Relu,
    Dense(hidden_dim)
)

# fc_out: maps the final hidden state to logits.
fc_out_init, fc_out_apply = stax.serial(
    Dense(num_classes)
)

In [4]:
# The full model function. Given parameters and an input x,
# it applies fc_in, evolves the hidden state using diffrax, then applies fc_out.
def apply_model(params, x):
    params_fc_in, params_ode, params_fc_out = params
    # Project input using fc_in.
    h0 = fc_in_apply(params_fc_in, x)
    
    # Define the ODE function.
    def ode_fun(t, y, args):
        # Here we ignore t (you could add time-dependence if desired).
        return ode_net_apply(params_ode, y)
    
    term = diffrax.ODETerm(ode_fun)
    # Solve the ODE from t=0 to t=1.
    sol = diffrax.diffeqsolve(
        term,
        solver=diffrax.Dopri5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=h0,
        saveat=diffrax.SaveAt(ts=jnp.array([1.0]))
    )
    h_T = sol.ys[0]  # final hidden state at t = 1.0
    logits = fc_out_apply(params_fc_out, h_T)
    return logits

# Cross-entropy loss.
def loss_fn(params, x, y):
    logits = apply_model(params, x)
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(y, num_classes)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

# Accuracy function.
def accuracy_fn(params, x, y):
    logits = apply_model(params, x)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == y)


In [5]:
# ---------------------
# Data Loading & Preprocessing
# ---------------------

# Load MNIST via tensorflow_datasets.
ds_builder = tfds.builder("mnist")
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
test_ds  = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))

# Preprocess: flatten images and normalize.
x_train = train_ds["image"].reshape(-1, 28*28) / 255.0
y_train = train_ds["label"]
x_test  = test_ds["image"].reshape(-1, 28*28) / 255.0
y_test  = test_ds["label"]

# Utility: create minibatches.
def get_batches(x, y, batch_size, key):
    n = x.shape[0]
    perm = jax.random.permutation(key, n)
    for i in range(0, n, batch_size):
        batch_idx = perm[i:i+batch_size]
        yield x[batch_idx], y[batch_idx]


2025-02-18 12:05:47.503014: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /Users/goutham/tensorflow_datasets/mnist/3.0.1...[0m


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

[1mDataset mnist downloaded and prepared to /Users/goutham/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [6]:
# ---------------------
# Parameter Initialization & Optimizer Setup
# ---------------------

rng = jax.random.PRNGKey(42)
rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
# Initialize fc_in.
_, params_fc_in = fc_in_init(rng1, (-1, 28*28))
# Initialize ode_net.
_, params_ode = ode_net_init(rng2, (-1, hidden_dim))
# Initialize fc_out.
_, params_fc_out = fc_out_init(rng3, (-1, hidden_dim))
params = (params_fc_in, params_ode, params_fc_out)

# Set up Adam optimizer.
opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(params)


In [7]:
# ---------------------
# Training Loop
# ---------------------

num_epochs = 5
batch_size = 128

@jax.jit
def update(step, opt_state, x_batch, y_batch):
    params = get_params(opt_state)
    loss, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch)
    opt_state = opt_update(step, grads, opt_state)
    return opt_state, loss

for epoch in range(num_epochs):
    # Use a new PRNG key per epoch for shuffling.
    key = jax.random.PRNGKey(epoch)
    batch_losses = []
    for i, (x_batch, y_batch) in enumerate(get_batches(x_train, y_train, batch_size, key)):
        opt_state, loss_val = update(i, opt_state, x_batch, y_batch)
        batch_losses.append(loss_val)
    params = get_params(opt_state)
    train_loss = jnp.mean(jnp.array(batch_losses))
    train_acc = accuracy_fn(params, x_train, y_train)
    test_acc  = accuracy_fn(params, x_test, y_test)
    print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%, Test Acc: {test_acc*100:.2f}%")


Epoch 1, Loss: 0.3400, Train Acc: 94.89%, Test Acc: 94.78%
Epoch 2, Loss: 0.1471, Train Acc: 96.04%, Test Acc: 95.39%
Epoch 3, Loss: 0.1111, Train Acc: 96.66%, Test Acc: 95.91%
Epoch 4, Loss: 0.0887, Train Acc: 97.63%, Test Acc: 96.60%
Epoch 5, Loss: 0.0764, Train Acc: 98.06%, Test Acc: 96.74%
