In [10]:
import jax
import equinox as eqx
import jax.numpy as jnp

class NeuralNetwork(eqx.Module):
    """Small MLP 2 -> 8 -> 8 -> 2 with ReLU on hidden layers plus a learned output offset.

    The Linear weights/biases and ``extra_bias`` are the module's array
    leaves, so differentiating with respect to the model yields gradients
    with the same structure as the model itself.
    """

    # Linear layers applied in order; ReLU follows every layer except the last.
    layers: list
    # Added to the final layer's output; initialised to ones.
    extra_bias: jax.Array

    def __init__(self, key):
        """Create the three Linear layers, each from its own subkey of ``key``."""
        layer_sizes = [(2, 8), (8, 8), (8, 2)]
        subkeys = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(fan_in, fan_out, key=subkey)
            for (fan_in, fan_out), subkey in zip(layer_sizes, subkeys)
        ]
        self.extra_bias = jnp.ones(2)

    def __call__(self, x):
        """Map one unbatched input of size 2 to an output of size 2 (vmap for batches)."""
        hidden, final = self.layers[:-1], self.layers[-1]
        h = x
        for linear in hidden:
            h = jax.nn.relu(linear(h))
        return final(h) + self.extra_bias

def mse_loss(model, x, y):
    """Mean squared error of ``model`` over a batch.

    Args:
        model: callable pytree mapping ONE example to a prediction; it is
            ``jax.vmap``-ed over the leading axis of ``x``.
        x: batch of inputs (leading axis = batch).
        y: targets, broadcast against the batched predictions.

    Returns:
        Scalar mean of the squared residuals.
    """
    pred_y = jax.vmap(model)(x)  # vectorise the model over a batch of data
    return jnp.mean((y - pred_y) ** 2)  # L2 loss


# NOTE: despite its name, `loss` returns the *gradient* of `mse_loss` with
# respect to `model` (a pytree with the same structure as `model`), not the loss
# value. Use `mse_loss` when the scalar is needed (e.g. to monitor training).
# jax.grad differentiates only the first argument, and plain jax.jit/jax.grad
# assume every pytree leaf of `model` is a floating-point array; for models with
# non-array leaves, Equinox's filtered transforms (eqx.filter_jit /
# eqx.filter_grad) are the usual alternative.
loss = jax.jit(jax.grad(mse_loss))

root_key = jax.random.PRNGKey(0)  # fixed seed => reproducible data and initial weights
x_key, y_key, model_key = jax.random.split(root_key, num=3)

# Example data: random inputs and random targets, both of shape (100, 2).
x = jax.random.normal(x_key, shape=(100, 2))
y = jax.random.normal(y_key, shape=(100, 2))
model = NeuralNetwork(model_key)


In [11]:
# One step of plain gradient descent on every array leaf of the model.
learning_rate = 0.1
# `loss` is wrapped in jax.grad, so this yields gradients shaped like `model`.
grads = loss(model, x, y)
new_model = jax.tree_util.tree_map(
    lambda param, grad: param - learning_rate * grad,
    model,
    grads,
)

IndentationError: unexpected indent (3285081436.py, line 5)

In [9]:
# Gradients w.r.t. every array in `model`; the result is itself a
# `NeuralNetwork` with the same layer structure and array shapes.
grads

NeuralNetwork(
  layers=[
    Linear(
      weight=f32[8,2],
      bias=f32[8],
      in_features=2,
      out_features=8,
      use_bias=True
    ),
    Linear(
      weight=f32[8,8],
      bias=f32[8],
      in_features=8,
      out_features=8,
      use_bias=True
    ),
    Linear(
      weight=f32[2,8],
      bias=f32[2],
      in_features=8,
      out_features=2,
      use_bias=True
    )
  ],
  extra_bias=f32[2]
)

In [6]:
# Prediction of the updated model on the first example (a single, unbatched input).
first_input = x[0]
new_model(first_input)

Array([0.94002867, 1.0102818 ], dtype=float32)

Array([[ 0.45058283, -0.5241236 ],
       [ 0.47175983, -0.9505155 ],
       [-0.274354  , -0.3354187 ],
       [-1.3399609 ,  0.51744205],
       [ 0.04267281, -2.5185785 ],
       [-1.8371971 , -1.3193125 ],
       [-1.3888389 ,  0.01685155],
       [-0.84184444,  0.4296381 ],
       [ 0.11726868,  1.4158473 ],
       [ 0.41861644,  0.5193756 ],
       [-0.66662365, -1.6508355 ],
       [-0.90173376, -1.5103353 ],
       [-0.51253945,  1.4452225 ],
       [ 1.2667034 ,  0.17484404],
       [ 1.1796353 , -1.2071054 ],
       [-0.9602053 ,  0.55503404],
       [ 0.791516  , -1.345151  ],
       [-0.20621188, -1.420752  ],
       [-0.37009287,  2.0824673 ],
       [ 0.435246  ,  0.42489448],
       [-0.4851248 ,  0.32981548],
       [-0.46110278, -0.31829634],
       [ 0.6464708 ,  0.644891  ],
       [-0.56586397, -2.2312171 ],
       [-0.36772147,  0.34090295],
       [-0.8522557 , -2.5428352 ],
       [-0.27524272, -0.831592  ],
       [ 0.6775455 , -0.01105266],
       [ 0.17004964,