In [6]:
from flax import nnx
import jax
import jax.numpy as jnp
import optax

In [7]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [8]:
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(2, 38, rngs=rngs)

nnx.display(model)

def isKernel(x):
  return jax.tree.map(lambda y: y.ndim > 1, x)
tx = optax.adamw(1e-3, weight_decay=1, mask=isKernel)
optimizerState = nnx.Optimizer(model, tx)

In [9]:
def lossFunction(model, x, index, target):
  values = model(x)
  return jnp.mean((values[index] - target) ** 2)

In [10]:
input = jnp.array([0.1, -0.1])
originalModelOutput = model(input)
print(f'Original model output: {originalModelOutput}')

@nnx.jit
def trainStep(model, optimizerState, input, index):
  target = 1.0
  gradients = nnx.grad(lossFunction)(model, input, index, target)
  globalNorm = optax.global_norm(gradients)
  optimizerState.update(gradients)
  return globalNorm

for i in range(200):
  globalNorm = trainStep(model, optimizerState, input, i%10)
  print(f'Global norm: {globalNorm}')

updatedModelOutput = model(input)
print(f'Updated model output: {updatedModelOutput}')

nnx.display(model)


Original model output: [-0.03245787 -0.00417501 -0.05088333  0.06116954  0.08806434 -0.05966324
 -0.10697126  0.00481742  0.0847846  -0.15256187 -0.07400742  0.03143728
  0.03714191  0.01746072 -0.08262251  0.04910205 -0.04121367  0.06232316
  0.08703245 -0.07444818  0.01227359  0.07393724 -0.14987452 -0.06048569
 -0.09122576 -0.04660274 -0.07468466 -0.03908684  0.00765305 -0.04118071
  0.0709928  -0.06787866 -0.04932142 -0.07135013  0.02210368 -0.02094354
 -0.00290814  0.06492181]
Global norm: 2.632449150085449
Global norm: 2.5957422256469727
Global norm: 2.784738063812256
Global norm: 2.3131484985351562
Global norm: 2.436494827270508
Global norm: 2.798032522201538
Global norm: 2.7770140171051025
Global norm: 2.5736753940582275
Global norm: 2.2914843559265137
Global norm: 3.151503086090088
Global norm: 2.5795822143554688
Global norm: 2.5416016578674316
Global norm: 2.718944787979126
Global norm: 2.2796943187713623
Global norm: 2.3739497661590576
Global norm: 2.7085020542144775
Global 