In [None]:
import jax
from jax._src.typing import StaticScalar
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from numpy import dtype
import optax

class SimpleMLP(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool=True):
        counter = self.variable('stats', 'counter', lambda: jnp.zeros((1,), dtype=jnp.int32))
        if train:
            counter.value +=1
        x = nn.Dense(features=64)(x)
        x = nn.Dropout(rate=0.1, deterministic=not train)(x)
        gaussian_rng = self.make_rng("gaussian")
        eps = random.normal(gaussian_rng, x.shape)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x, eps

def make_train_step(model, tx):
    @jax.jit
    def train_step(params, stats, opt_state, x, y, dropout_key, gaussian_key):
        def loss_fn(params):
            (logits, _), updated_vars = model.apply(
                {'params': params, 'stats': stats}, 
                x, 
                rngs={"dropout": dropout_key, "gaussian": gaussian_key}, 
                mutable=['stats']
            )
            loss = optax.l2_loss(logits, y).mean()
            return loss, updated_vars

        (loss, new_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
        updates, opt_state = tx.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return loss, params, opt_state, new_stats['stats']

    return train_step

key = random.PRNGKey(0)
x_data = random.normal(key, (100, 1))
y_data = 3 * x_data + 1 + (random.normal(key, (100, 1)) * 0.1)

model = SimpleMLP()
key, init_key = random.split(key)
variables = model.init(init_key, jnp.ones((1, 1)))
params = variables['params']
stats = variables['stats']

tx = optax.adam(learning_rate=0.01)
opt_state = tx.init(params)
train_step_fn = make_train_step(model, tx)

for epoch in range(100):
    key, dropout_key, gaussian_key = random.split(key, 3)
    loss, params, opt_state, stats = train_step_fn(
        params, stats, opt_state, x_data, y_data, dropout_key, gaussian_key
    )
    if epoch % 20 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Counts: {stats['counter']}")

Epoch 1, Loss: 2.8178, Counts: [2]
Epoch 21, Loss: 0.1884, Counts: [22]
Epoch 41, Loss: 0.0604, Counts: [42]
Epoch 61, Loss: 0.0468, Counts: [62]
Epoch 81, Loss: 0.0353, Counts: [82]
