In [2]:
from flax import nnx
import jax
import jax.numpy as jnp

In [3]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [23]:
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(2, 38, rngs=rngs)
nnx.display(model)
graph, params = nnx.split(model)
nnx.display(params)
finalBias = params['linear2']['bias']
nnx.display(finalBias)
print(f'{type(finalBias.value)} {finalBias.value}')
params['linear2']['bias'] = finalBias.value.at[0].set(10.0)
nnx.display(params)
model = nnx.merge(graph, params)
nnx.display(model)

<class 'jaxlib.xla_extension.ArrayImpl'> [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [8]:
@nnx.jit
def selectAction(model, x):
  jax.debug.print('{}',x)
  res = model(x)
  jax.debug.print('{}',res)
  argmaxRes = jnp.argmax(res)
  return argmaxRes

In [24]:
input = jnp.array([0.1, 0.01])
action = selectAction(model, input)
print(action)

[0.1  0.01]
[ 1.00114412e+01 -1.33433845e-02 -6.04167581e-02  2.69002598e-02
  5.23151606e-02  1.62119158e-02 -5.86803108e-02 -2.71647871e-02
  6.99529052e-02 -9.18169320e-02 -5.63428625e-02  2.74759773e-02
  3.16478983e-02  3.91693972e-02 -2.82971263e-02  1.72396116e-02
 -2.26175245e-02  3.24044302e-02  7.67709687e-02  4.00238670e-03
 -1.78625286e-02  1.34895369e-02 -8.92210156e-02 -3.08890436e-02
 -8.13865885e-02 -1.67200882e-02 -6.93658292e-02 -3.06927562e-02
  7.63989985e-03 -3.19317207e-02  1.12243704e-02 -4.60273400e-02
  1.07448697e-02 -3.15227173e-02 -1.89291649e-02  3.29239294e-04
 -3.37286592e-02  4.92924675e-02]
0
