In [22]:
from flax import nnx
import jax
import jax.numpy as jnp
import optax

In [23]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [24]:
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(2, 38, rngs=rngs)
tx = optax.adam(1e-3)
optimizerState = nnx.Optimizer(model, tx)
modelGraph, modelWeights = nnx.split(model)
del model

In [25]:
def lossFunction(model, x, index, target):
  values = model(x)
  return jnp.mean((values[index] - target) ** 2)

In [26]:
tmpModel = nnx.merge(modelGraph, modelWeights)
input = jnp.array([0.1, -0.1])
originalModelOutput = tmpModel(input)
print(f'Original model output: {originalModelOutput}')
index = 0
target = 1.0
loss = lossFunction(tmpModel, input, index, target)
print(f'Loss: {loss}')
del tmpModel

Original model output: [-0.03245787 -0.00417501 -0.05088333  0.06116954  0.08806434 -0.05966324
 -0.10697126  0.00481742  0.0847846  -0.15256187 -0.07400742  0.03143728
  0.03714191  0.01746072 -0.08262251  0.04910205 -0.04121367  0.06232316
  0.08703245 -0.07444818  0.01227359  0.07393724 -0.14987452 -0.06048569
 -0.09122576 -0.04660274 -0.07468466 -0.03908684  0.00765305 -0.04118071
  0.0709928  -0.06787866 -0.04932142 -0.07135013  0.02210368 -0.02094354
 -0.00290814  0.06492181]
Loss: 1.0659691095352173


In [27]:
for i in range(100):
  tmpModel = nnx.merge(modelGraph, modelWeights)
  gradients = nnx.grad(lossFunction)(tmpModel, input, index, target)
  optimizerState.update(gradients)
  updatedModelOutput = tmpModel(input)
  # print(f'Updated model output: {updatedModelOutput}')
  newLoss = lossFunction(tmpModel, input, index, target)
  print(f'New loss: {newLoss}')
  del tmpModel

print(f'New model output: {tmpModel(input)}')

New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173
New loss: 1.0659691095352173


KeyboardInterrupt: 