In [33]:
from flax import nnx
import jax
import jax.numpy as jnp

In [34]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, dropoutRate: float, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.dropout1 = nnx.Dropout(rate=dropoutRate, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x, deterministic, rngs):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.dropout1(x, deterministic=deterministic, rngs=rngs)
    x = self.linear2(x)
    return x

In [35]:
dropoutRate = 0.1
# key = jax.random.key(0)
# key1, key2 = jax.random.split(key)
# rngs = nnx.Rngs(params=key1, dropout=key2)
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(10, 5, dropoutRate, rngs)

In [36]:
rngs2 = nnx.Rngs(jax.random.key(1))
input = jnp.ones((10,))
output = model(input, deterministic=False, rngs=rngs2)
print(output)
output = model(input, deterministic=True, rngs=rngs2)
print(output)

[ 0.40495783  0.28237268 -0.63730925  0.52325785 -0.5349195 ]
[ 0.3173416   0.26317954 -0.58215386  0.45295987 -0.45487443]


In [37]:
@nnx.jit
def jittedModel(x, rngs):
  return model(x, deterministic=False, rngs=rngs), model(x, deterministic=True, rngs=rngs)

x, y = jittedModel(input, rngs2)
print(x)
print(y)


[ 0.26417822  0.4194631  -0.06500262  0.41393083 -0.65323234]
[ 0.3173416   0.26317954 -0.5821538   0.4529599  -0.4548744 ]
