In [None]:
import functools

import jax
import jax.numpy as jnp
import optax

def init_params(key, layer_dims):
    """Initialise weights and biases for a fully connected network.

    Weights are drawn from a standard normal scaled by ``sqrt(2 / fan_in)``
    (He initialisation, matching the ReLU hidden layers used by ``forward``);
    biases start at zero.

    Args:
        key: JAX PRNG key. It is split once per layer so every layer's
            weights come from an independent subkey.
        layer_dims: Sequence of layer widths, input width first,
            e.g. ``[784, 512, 256, 10]``.

    Returns:
        List of ``(W, b)`` tuples, one per consecutive pair of widths, where
        ``W`` has shape ``(fan_in, fan_out)`` and ``b`` has shape
        ``(fan_out,)``. Empty when ``layer_dims`` has fewer than two entries.
    """
    params = []
    # Pair each width with the next one: (fan_in, fan_out) for every layer.
    for fan_in, fan_out in zip(layer_dims[:-1], layer_dims[1:]):
        key, subkey = jax.random.split(key)
        W = jax.random.normal(subkey, (fan_in, fan_out)) * jnp.sqrt(2 / fan_in)
        b = jnp.zeros(fan_out)
        params.append((W, b))
    return params

def forward(params, x):
    """Apply the MLP described by ``params`` to ``x``.

    Every layer except the last is an affine map followed by ReLU. The final
    layer is affine only, so its outputs are returned without an activation.

    Args:
        params: Sequence of ``(W, b)`` pairs, e.g. as built by ``init_params``.
        x: Input array; its last axis is contracted with the first axis of
            the first layer's ``W``.

    Returns:
        The raw output of the last layer.
    """
    hidden = x
    for weight, bias in params[:-1]:
        hidden = jax.nn.relu(jnp.dot(hidden, weight) + bias)
    out_weight, out_bias = params[-1]
    return jnp.dot(hidden, out_weight) + out_bias

def loss_fn(params, x, y):
    """Return the mean squared error of ``forward(params, x)`` against ``y``.

    The mean is taken over every element of the residual, i.e. across all
    axes at once.
    """
    residual = forward(params, x) - y
    return jnp.mean(jnp.square(residual))

@functools.partial(jax.jit, static_argnames=("tx",))
def update(params, x, y, optimizer_state, tx=None):
    """Run one optimiser step on the batch ``(x, y)``.

    Args:
        params: Parameter list as produced by ``init_params``.
        x: Model inputs passed to ``loss_fn``.
        y: Targets passed to ``loss_fn``.
        optimizer_state: Current optax state for ``params``.
        tx: optax ``GradientTransformation`` to apply. It is a static jit
            argument, so passing a different optimiser triggers a retrace
            instead of silently reusing a previously compiled step. When
            ``None`` (the default, kept for backward compatibility), the
            module-level ``optimizer`` global is read at trace time. Cached
            traces keep whatever optimiser they captured, so rebinding that
            global later may not take effect.

    Returns:
        Tuple ``(new_params, new_optimizer_state)``.
    """
    if tx is None:
        # Legacy path: requires a global named ``optimizer`` to exist when this
        # function is first traced (in this file it is only bound under
        # ``__main__``), otherwise a NameError is raised.
        tx = optimizer
    grads = jax.grad(loss_fn)(params, x, y)
    updates, optimizer_state = tx.update(grads, optimizer_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, optimizer_state

if __name__ == '__main__':
    # Give every random draw its own subkey. JAX PRNG keys must not be reused:
    # drawing parameters and both data arrays from one key can yield
    # correlated (even overlapping) values.
    key = jax.random.PRNGKey(0)
    init_key, x_key, y_key = jax.random.split(key, 3)
    layer_dims = [784, 512, 256, 10]
    params = init_params(init_key, layer_dims)

    # Synthetic regression data: 1000 samples, 784 features, 10 targets.
    X_train = jax.random.normal(x_key, (1000, 784))
    y_train = jax.random.normal(y_key, (1000, 10))

    # ``update`` reads a module-level name ``optimizer``, so it must be bound
    # at module scope (not inside a function) before the first call.
    optimizer = optax.adam(1e-3)
    optimizer_state = optimizer.init(params)

    epochs = 10
    for epoch in range(epochs):
        # Each "epoch" is a single full-batch optimiser step.
        params, optimizer_state = update(params, X_train, y_train, optimizer_state)
        loss = loss_fn(params, X_train, y_train)
        print(f"Epoch {epoch+1}, Loss: {loss}")