In [1]:
import equinox as eqx
import jax

class MyModule(eqx.Module):
    """Small ReLU MLP with an additional trainable output bias.

    The network is Linear(in_size -> width_size) -> ReLU ->
    Linear(width_size -> width_size) -> ReLU -> Linear(width_size -> out_size),
    followed by adding ``extra_bias``. The default sizes reproduce the
    original 2 -> 8 -> 8 -> 2 network, so ``MyModule(key)`` is unchanged.

    Args:
        key: PRNG key; split three ways to initialise the three linear layers.
        in_size: number of input features (default 2).
        width_size: number of units in each hidden layer (default 8).
        out_size: number of output features (default 2); also the length of
            ``extra_bias``.
    """

    layers: list
    extra_bias: jax.Array

    def __init__(self, key, in_size=2, width_size=8, out_size=2):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [eqx.nn.Linear(in_size, width_size, key=key1),
                       eqx.nn.Linear(width_size, width_size, key=key2),
                       eqx.nn.Linear(width_size, out_size, key=key3)]
        # This is a trainable parameter: being an array field of the Module,
        # it is a pytree leaf and receives a gradient like the layer weights.
        # Its length must match the last layer's output size.
        self.extra_bias = jax.numpy.ones(out_size)

    def __call__(self, x):
        """Apply the network to a single input vector ``x``.

        Every layer except the last is followed by ReLU; the last layer's
        output is shifted by ``extra_bias``. To process a batch, wrap the
        module with ``jax.vmap``.
        """
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x) + self.extra_bias

def mse_loss(model, x, y):
    """Mean squared error of ``model`` over a batch.

    ``model`` is applied per example via ``jax.vmap`` over the leading axis
    of ``x``; the squared error against ``y`` is averaged over all elements.

    Returns:
        A scalar JAX array with the loss value (useful for monitoring training).
    """
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)


# NOTE: despite its name, `loss` returns the *gradient* of `mse_loss` with
# respect to `model` (a pytree with the same structure as `model`), not a
# scalar loss. The name is kept so existing callers keep working; use
# `mse_loss` to get the loss value itself.
loss = jax.jit(jax.grad(mse_loss))

# Synthetic regression data and a freshly initialised model: one PRNG key
# each for the inputs, the targets and the model weights.
x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), num=3)
x = jax.random.normal(x_key, shape=(100, 2))
y = jax.random.normal(y_key, shape=(100, 2))
model = MyModule(model_key)

# One step of plain gradient descent, applied leaf-by-leaf over the model
# pytree. `loss` is wrapped in jax.grad, so it returns gradients shaped like
# `model`, not a scalar loss.
learning_rate = 0.1
grads = loss(model, x, y)
model = jax.tree_util.tree_map(
    lambda param, grad: param - learning_rate * grad,
    model,
    grads,
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [6]:
# `loss` is wrapped in jax.grad, so this displays the gradient pytree for the
# updated model (a MyModule of gradient arrays), not a scalar loss value.
loss(model, x, y)

MyModule(
  layers=[
    Linear(
      weight=f32[8,2],
      bias=f32[8],
      in_features=2,
      out_features=8,
      use_bias=True
    ),
    Linear(
      weight=f32[8,8],
      bias=f32[8],
      in_features=8,
      out_features=8,
      use_bias=True
    ),
    Linear(
      weight=f32[2,8],
      bias=f32[2],
      in_features=8,
      out_features=2,
      use_bias=True
    )
  ],
  extra_bias=f32[2]
)

In [9]:
# Forward pass of the updated model on a single, unbatched 2-feature input.
model(jax.numpy.array([1.0, 2.0]))

Array([0.88071775, 1.0653281 ], dtype=float32)