In [None]:
from flax import nnx
import jax
import jax.numpy as jnp
import optax

In [None]:
class MyModel(nnx.Module):
  """Two-layer MLP: Linear -> ReLU -> Linear.

  Args:
    inSize: input feature count for the first ``nnx.Linear``.
    outSize: output feature count of the second ``nnx.Linear``.
    rngs: NNX RNG streams used to initialise both linear layers.
    intermediateSize: width of the hidden layer (defaults to 64, the value
      that used to be hard-coded).
  """

  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs,
               intermediateSize: int = 64):
    # NOTE: an unused `rngs.params()` draw used to happen here; without it the
    # layers may receive different keys, so initial weights for a given seed
    # can differ from earlier runs of this notebook.
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    """Run ``x`` through linear1 -> ReLU -> linear2 and return the result."""
    hidden = jax.nn.relu(self.linear1(x))
    return self.linear2(hidden)

In [None]:
# Seed and hyper-parameters for this toy experiment.
SEED = 0
IN_SIZE = 2
OUT_SIZE = 38
LEARNING_RATE = 1e-3

rngs = nnx.Rngs(jax.random.key(SEED))
model = MyModel(IN_SIZE, OUT_SIZE, rngs=rngs)
tx = optax.adam(LEARNING_RATE)
# The optimizer keeps a reference to `model` and applies its updates to that
# object, so `model` is kept (deleting the name would not free it anyway).
optimizerState = nnx.Optimizer(model, tx)
# (graph definition, state) taken *before* any training. The optimizer does not
# write back into this split state, so it serves as the pre-training baseline.
modelGraph, modelWeights = nnx.split(model)

In [None]:
def lossFunction(model, x, index, target):
  """Mean squared error between the selected model output(s) and ``target``.

  Args:
    model: callable module; ``model(x)`` must return an array indexable by
      ``index``.
    x: input passed straight to ``model``.
    index: index (or slice) selecting which outputs enter the loss.
    target: value the selected outputs are compared against.

  Returns:
    ``jnp.mean`` of the squared differences (a scalar array).
  """
  prediction = model(x)[index]
  squaredError = (prediction - target) ** 2
  return jnp.mean(squaredError)

In [None]:
# Example and target used for both the baseline and the training loop.
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells refer to it.
input = jnp.array([0.1, -0.1])
index = 0
target = 1.0

# Rebuild a throwaway module from the split (graph, weights) pair and report
# its output and loss on the example above.
baselineModel = nnx.merge(modelGraph, modelWeights)
originalModelOutput = baselineModel(input)
print(f'Original model output: {originalModelOutput}')
loss = lossFunction(baselineModel, input, index, target)
print(f'Loss: {loss}')
del baselineModel

In [None]:
# Train the model owned by `optimizerState`.
# nnx.Optimizer applies each update to the module it was constructed with.
# Re-merging `modelGraph`/`modelWeights` every step would rebuild the
# *initial* weights each time, so gradients would always be taken at the
# starting point and the loss would never move.
# NOTE(review): uses the Flax NNX Optimizer API this notebook already relies on
# (`optimizerState.model`, single-argument `update(grads)`); newer Flax
# releases changed this to `update(model, grads)` — confirm the pinned version.
NUM_STEPS = 100

trainedModel = optimizerState.model
for _ in range(NUM_STEPS):
  gradients = nnx.grad(lossFunction)(trainedModel, input, index, target)
  optimizerState.update(gradients)
  newLoss = lossFunction(trainedModel, input, index, target)
  print(f'New loss: {newLoss}')

# `trainedModel` stays alive after the loop, unlike the per-step merged copy.
print(f'New model output: {trainedModel(input)}')