In [1]:
from flax import nnx
import jax
import jax.numpy as jnp

In [2]:
class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

In [4]:
jax.config.update("jax_platform_name", "cpu")

In [7]:
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))

print(f"This is y {y}")
nnx.display(model)

This is y [[1.245453   0.74195766 0.8553282  0.6763327  1.2617068 ]]
Linear(
  w=Param(
    value=Array(shape=(2, 5), dtype=float32)
  ),
  b=Param(
    value=Array(shape=(5,), dtype=float32)
  ),
  din=2,
  dout=5
)


In [13]:
class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))
    print(f'This is the type of count inside __init__ function {type(self.count)}')

  def __call__(self):
    print(f'This is type of self.count before addition operation {type(self.count)}')
    self.count = self.count+1
    print(f'This is type of self.count before {type(self.count)}')
    self.count = Count(self.count)
    print(f'This is type of self.count after conversion {type(self.count)}')

counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')

This is the type of count inside __init__ function <class '__main__.Count'>
counter.count.value = Array(0, dtype=int32, weak_type=True)
This is type of self.count before addition operation <class '__main__.Count'>
This is type of self.count before <class 'jaxlib.xla_extension.ArrayImpl'>
This is type of self.count after conversion <class '__main__.Count'>
counter.count.value = Array(1, dtype=int32, weak_type=True)


In [14]:
class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

In [15]:
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

In [16]:
y = model(x=jnp.ones((3, 2)))

In [17]:
nnx.display(model)

MLP(
  linear1=Linear(
    w=Param(
      value=Array(shape=(2, 16), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    din=2,
    dout=16
  ),
  dropout=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(5, dtype=uint32),
        tag='default'
      )
    )
  )),
  bn=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    num_features=16,
    use_running_average=False,
    axis=-1,
    momentum=0.99,
    epsilon=1e-05,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    u

In [21]:
class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
    self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)
nnx.display(model)

MLP(
  linear1=Linear(
    w=Param(
      value=Array(shape=(2, 32), dtype=float32)
    ),
    b=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    din=2,
    dout=32
  ),
  dropout=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(4, dtype=uint32),
        tag='default'
      )
    )
  )),
  bn=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(32,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(32,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    num_features=32,
    use_running_average=False,
    axis=-1,
    momentum=0.99,
    epsilon=1e-05,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    u

In [22]:
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)

In [23]:
y = model(x=jnp.ones((3, 2)))

nnx.display(model)

MLP(
  linear1=LoraLinear(
    linear=Linear(
      w=Param(
        value=Array(shape=(2, 32), dtype=float32)
      ),
      b=Param(
        value=Array(shape=(32,), dtype=float32)
      ),
      din=2,
      dout=32
    ),
    A=LoraParam(
      value=Array(shape=(2, 4), dtype=float32)
    ),
    B=LoraParam(
      value=Array(shape=(4, 32), dtype=float32)
    )
  ),
  dropout=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(9, dtype=uint32),
        tag='default'
      )
    )
  )),
  bn=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(32,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(32,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(32,), dty

In [24]:
import optax

In [25]:
# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

In [26]:
@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # In place updates.

  return loss

In [27]:
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)

In [28]:
print(f'{loss = }')
print(f'{optimizer.step.value = }')

loss = Array(1.0000278, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)


: 