In [16]:
from flax import nnx
import jax
import jax.numpy as jnp
import optax

In [17]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [18]:
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(2, 38, rngs=rngs)

nnx.display(model)

def isKernel(x):
  return jax.tree.map(lambda y: y.ndim > 1, x)
tx = optax.adamw(1e-3, weight_decay=1, mask=isKernel)
optimizerState = nnx.Optimizer(model, tx)

In [19]:
def lossFunction(model, x, index, target):
  values = model(x)
  return jnp.mean((values[index] - target) ** 2)

In [20]:
input = jnp.array([0.1, -0.1])
originalModelOutput = model(input)
print(f'Original model output: {originalModelOutput}')

@nnx.jit
def trainStep(model, optimizerState, input, index):
  target = 1.0
  gradients = nnx.grad(lossFunction)(model, input, index, target)
  optimizerState.update(gradients)

for i in range(10000):
  trainStep(model, optimizerState, input, i%10)

updatedModelOutput = model(input)
print(f'Updated model output: {updatedModelOutput}')

nnx.display(model)


Original model output: [-0.03245787 -0.00417501 -0.05088333  0.06116954  0.08806434 -0.05966324
 -0.10697126  0.00481742  0.0847846  -0.15256187 -0.07400742  0.03143728
  0.03714191  0.01746072 -0.08262251  0.04910205 -0.04121367  0.06232316
  0.08703245 -0.07444818  0.01227359  0.07393724 -0.14987452 -0.06048569
 -0.09122576 -0.04660274 -0.07468466 -0.03908684  0.00765305 -0.04118071
  0.0709928  -0.06787866 -0.04932142 -0.07135013  0.02210368 -0.02094354
 -0.00290814  0.06492181]
Updated model output: [ 9.8095423e-01  9.8884177e-01  9.9346423e-01  9.9687809e-01
  9.9960864e-01  1.0013217e+00  1.0008665e+00  9.9614727e-01
  9.8705089e-01  9.8062277e-01  2.0272502e-05 -1.2461627e-05
 -1.3578127e-05  1.2785282e-05 -7.8782678e-06 -2.0795701e-06
  3.3008018e-07  1.3384075e-05 -2.2760054e-05  1.1880682e-05
 -3.2228916e-05  2.2382064e-06 -2.7051070e-05  6.6846628e-06
 -3.3874865e-05 -8.4975836e-06 -1.1391238e-05 -1.1243656e-05
 -1.9438661e-05 -5.0132792e-05  4.2523420e-06 -1.9590050e-05
 -8