<a href="https://colab.research.google.com/github/FannYYW/jax_tutorial/blob/main/Flax_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Basic Usage

In [None]:
from flax import nnx
import optax


class Model(nnx.Module):
  """Small network: Linear -> BatchNorm -> Dropout -> ReLU -> Linear."""

  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    # Every sub-layer is constructed from the same `rngs` object.
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    """Map `x` through the hidden stack and the output projection."""
    hidden = self.linear(x)
    hidden = self.bn(hidden)
    hidden = self.dropout(hidden)
    hidden = nnx.relu(hidden)
    return self.linear_out(hidden)

model = Model(din=2, dmid=64, dout=3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
  """Run one optimizer update on `model` and return the mean squared error."""

  def loss_fn(model):
    prediction = model(x)  # call methods directly
    residual = prediction - y
    return (residual ** 2).mean()

  grad_fn = nnx.value_and_grad(loss_fn)
  loss, grads = grad_fn(model)
  optimizer.update(grads)  # in-place updates

  return loss

INFO:2024-11-10 21:56:12,482:jax._src.xla_bridge:906: Unable to initialize backend 'rocm': Your process properly initialized the GPU backend, but //learning/brain/research/jax:gpu_support is not linked in. You most likely should add that build dependency to your program.
INFO:2024-11-10 21:56:14,350:jax._src.xla_bridge:906: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'
INFO:2024-11-10 21:56:14,351:jax._src.xla_bridge:906: Unable to initialize backend 'mock_tpu': Must pass --mock_tpu_platform flag to initialize the mock_tpu backend


# Basics

## The Flax NNX Module system

In [None]:
import jax
import jax.numpy as jnp

class Linear(nnx.Module):
  """Hand-written dense layer computing `x @ w + b`."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    kernel = jax.random.uniform(key, (din, dout))
    self.w = nnx.Param(kernel)
    self.b = nnx.Param(jnp.zeros((dout,)))
    # Plain Python attributes; `LoraLinear` later reads them to size its factors.
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    projected = x @ self.w
    return projected + self.b

model = Linear(din=2, dout=5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))

print(y)
nnx.display(model)

[[1.245453   0.74195766 0.8553282  0.6763327  1.2617068 ]]


In [None]:
# @title Stateful computation

class Count(nnx.Variable):
  """Variable subtype used to tag counter state."""


class Counter(nnx.Module):
  """Module whose only state is a `Count` that is incremented on each call."""

  def __init__(self):
    start = jnp.array(0)
    self.count = Count(start)

  def __call__(self):
    self.count += 1

counter = Counter()
print(f'{counter.count.value = }')
# Calling the module mutates its state in place; the second print shows the new value.
counter()
print(f'{counter.count.value = }')

counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)


In [None]:
# @title Nested Modules

class MLP(nnx.Module):
  """Linear -> BatchNorm -> Dropout -> GELU -> Linear, using the custom `Linear`."""

  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    hidden = self.linear1(x)
    hidden = self.bn(hidden)
    hidden = self.dropout(hidden)
    hidden = nnx.gelu(hidden)
    return self.linear2(hidden)

model = MLP(din=2, dmid=16, dout=5, rngs=nnx.Rngs(0))

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

In [None]:
# @title Model surgery

class LoraParam(nnx.Param):
  """Param subtype that tags the low-rank adapter factors."""


class LoraLinear(nnx.Module):
  """Wrap an existing `Linear` and add a rank-`rank` correction `x @ A @ B`."""

  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    # A is drawn before B, each from its own `rngs()` key.
    a_init = jax.random.normal(rngs(), (linear.din, rank))
    self.A = LoraParam(a_init)
    b_init = jax.random.normal(rngs(), (rank, linear.dout))
    self.B = LoraParam(b_init)

  def __call__(self, x: jax.Array):
    base = self.linear(x)
    low_rank = x @ self.A @ self.B
    return base + low_rank

rngs = nnx.Rngs(0)
model = MLP(din=2, dmid=32, dout=5, rngs=rngs)

# Model surgery: swap each custom Linear for a LoRA-wrapped version of itself.
model.linear1 = LoraLinear(model.linear1, rank=4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, rank=4, rngs=rngs)

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

## Flax transformations

In [None]:
# The updates to each of the nnx.BatchNorm and nnx.Dropout layer’s state is
# automatically propagated from within loss_fn to train_step all the way to the
# model reference outside.

# The optimizer holds a mutable reference to the model - this relationship is
# preserved inside the train_step function making it possible to update the
# model’s parameters using the optimizer alone.

import optax

# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(din=2, dmid=16, dout=10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  """Apply one optimizer update to `model` and return the mean squared error."""

  def loss_fn(model: MLP):
    prediction = model(x)
    squared_error = (prediction - y) ** 2
    return jnp.mean(squared_error)

  grad_fn = nnx.value_and_grad(loss_fn)
  loss, grads = grad_fn(model)
  optimizer.update(grads)  # In place updates.

  return loss

x = jnp.ones((5, 2))
y = jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')

loss = Array(1., dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)


In [None]:
# @title Scan over layers

@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  """Build an MLP from `key`; the decorator maps this over axis 0 of the keys."""
  layer_rngs = nnx.Rngs(key)
  return MLP(10, 32, 10, rngs=layer_rngs)

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  """Apply one slice of the model stack to the carry `x` and return the new carry."""
  return model(x)

x = jnp.ones((3, 10))
# This calls the forward function with the stack of MLP models and the input.
# Due to scan, each model in the stack processes the input sequentially.
y = forward(model, x)

print(f'{y.shape = }')
print(y)
# TODO: why model does not show 5 layers?
# NOTE(review): presumably the vmapped constructor returns a single MLP whose
# variables carry a leading axis of size 5 rather than 5 separate modules —
# confirm by checking the variable shapes in the display below.
nnx.display(model)


y.shape = (3, 10)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [None]:
# Split the current `model` into its two parts and display both.
graphdef, state = nnx.split(model)

nnx.display(graphdef, state)

In [None]:
def create_model(key: jax.Array):
  """Build a single (non-vmapped) MLP seeded from `key`."""
  model_rngs = nnx.Rngs(key)
  return MLP(10, 32, 10, rngs=model_rngs)

new_key, old_key = jax.random.split(jax.random.key(0), 2)
model = create_model(old_key)

def forward(model: MLP, x):
  """Plain forward pass: call the model on `x` and return the result."""
  return model(x)

x = jnp.ones((3, 10))
y = forward(model, x)

print(f'{y.shape = }')
print(y)
nnx.display(model)

y.shape = (3, 10)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


## The Flax Functional API

In [None]:
# @title nnx.Param and nnx.Variable example

class Count(nnx.Variable):
  """Variable subtype used to tag the call counter."""


class StatefulLinear(nnx.Module):
  """Dense layer (`x @ w + b`) that also counts how many times it was called."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    kernel = jax.random.uniform(rngs(), (din, dout))
    self.w = nnx.Param(kernel)
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array):
    # Bump the counter before computing the output.
    self.count += 1
    projected = x @ self.w
    return projected + self.b

model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))

nnx.display(model)




In [None]:
# @title State and GraphDef

# nnx.State is a Mapping from strings to nnx.Variables or nested States.
# nnx.GraphDef contains all the static information needed to reconstruct a
# nnx.Module graph, it is analogous to JAX’s PyTreeDef.

# Split the `StatefulLinear` instance into its GraphDef and State and show both.
graphdef, state = nnx.split(model)

nnx.display(graphdef, state)

In [None]:
# @title Split, merge, and update

# TODO: why splitting is necessary when using JAX transforms with Flax NNX modules.
# https://screenshot.googleplex.com/v4GB7hLwgsbsLQB

print(f'{model.count.value = }')

# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
  """Rebuild the module inside `jax.jit`, call it, and return output plus new state."""
  # 2. Use `nnx.merge` to create a new model inside the JAX transformation.
  rebuilt = nnx.merge(graphdef, state)
  # 3. Call the `nnx.Module`
  y = rebuilt(x)
  # 4. Use `nnx.split` to propagate `nnx.State` updates.
  _, updated_state = nnx.split(rebuilt)
  return y, updated_state

y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)

print(f'{model.count.value = }')

model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)


In [None]:
# @title Fine-grained State control

# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
# One State is returned per filter, in the order the filters are given.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)

nnx.display(graphdef, params, counts)

In [None]:
# Merge multiple `State`s back into a module using the GraphDef.
model = nnx.merge(graphdef, params, counts)
# Update an existing module in place with multiple `State`s.
nnx.update(model, params, counts)

# MNIST Example
https://flax.readthedocs.io/en/latest/mnist_tutorial.html

# Transformations

In [None]:
# @title Basic Examples
import jax
from jax import numpy as jnp, random
from flax import nnx


class Weights(nnx.Module):
  """Container module holding a kernel and a bias as `nnx.Param`s."""

  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel = nnx.Param(kernel)
    self.bias = nnx.Param(bias)

# Both parameters carry a leading axis of size 10 that the vmap below maps over.
weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))

def vector_dot(weights: Weights, x: jax.Array):
  """Compute `x @ kernel + bias` for a single, un-batched vector `x`."""
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  projected = x @ weights.kernel
  return projected + weights.bias

y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)

print(f'{y.shape = }')
nnx.display(weights)

y.shape = (3, 10)


In [None]:
# Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers.
def create_weights(seed: jax.Array):
  """Create a `Weights` module with a (2, 3) kernel seeded by `seed` and a zero bias."""
  key = random.key(seed)
  kernel = random.uniform(key, (2, 3))
  bias = jnp.zeros((3,))
  return Weights(kernel=kernel, bias=bias)


seeds = jnp.arange(10)
# vmapping the initializer over the 10 seeds returns one `Weights` object.
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)

In [None]:
# @title Transforming methods

class WeightStack(nnx.Module):
  """Module whose constructor and call are both wrapped in `nnx.vmap`."""

  @nnx.vmap
  def __init__(self, seed: jax.Array):
    key = random.key(seed)
    self.kernel = nnx.Param(random.uniform(key, (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=0, out_axes=1)
  def __call__(self, x: jax.Array):
    # Inside the vmap each call sees one un-batched kernel and one vector.
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    projected = x @ self.kernel
    return projected + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)

print(f'{y.shape = }')
nnx.display(weights)

y.shape = (3, 10)


In [None]:
# @title State propagation
# propagate state changes to preserve reference semantics


class Count(nnx.Variable):
  """Variable subtype used to tag the per-slice call counter."""


class Weights(nnx.Module):
  """Kernel and bias as `nnx.Param`s plus a `Count` variable."""

  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel = nnx.Param(kernel)
    self.bias = nnx.Param(bias)
    self.count = Count(count)

# Every variable, including `count`, has a leading axis of size 10.
weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def stateful_vector_dot(weights: Weights, x: jax.Array):
  """Increment the counter, then compute `x @ kernel + bias` for one vector."""
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  projected = x @ weights.kernel
  return projected + weights.bias


y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)

# updates were propagated to the original Weights object outside the transformation
weights.count


Count(
  value=Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
)

In [None]:
# @title Graph updates propagation
# Flax NNX’s state propagation machinery can track arbitrary updates to the
# objects as long as they’re local to the inputs (updates to globals inside
# transforms are not supported).
# While this feature is very powerful, it must be used with care because it can
# clash with JAX’s underlying assumptions for certain transforms. For example,
# jit expects the structure of the inputs to be stable in order to cache the
# compiled function, so changing the graph structure inside an nnx.jit-ed
# function causes continuous recompilations and performance degradation. On the
# other hand, scan only allows a fixed carry structure, so adding/removing
# sub-states declared as carry will cause an error.

class Count(nnx.Variable):
  """Variable subtype used to tag the per-slice call counter."""


class Weights(nnx.Module):
  """Kernel and bias as `nnx.Param`s plus a `Count` variable."""

  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel = nnx.Param(kernel)
    self.bias = nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
  """Compute `x @ kernel + bias`, then restructure `weights` inside the transform."""
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  # The output must be computed before `bias` is deleted below.
  projected = x @ weights.kernel
  y = projected + weights.bias
  weights.some_property = ['a', 2, False] # add attribute
  del weights.bias # delete attribute
  weights.new_param = weights.kernel # share reference
  return y

y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)

nnx.display(weights)

In [None]:
# @title Transforming sub-states (lift types)

class Weights(nnx.Module):
  """Kernel and bias as `nnx.Param`s plus a `Count` variable."""

  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel = nnx.Param(kernel)
    self.bias = nnx.Param(bias)
    self.count = Count(count)

# The params have a leading axis of size 10; `count` is a scalar shared by all slices.
weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.array(1),
)
x = jax.random.normal(random.key(1), (10, 2))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  """Increment the counter, then compute `x @ kernel + bias` for one vector."""
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  projected = x @ weights.kernel
  return projected + weights.bias

# note that nnx.StateAxes can only be used directly on Flax NNX objects, and
# it cannot be used as a prefix for a pytree of objects.
state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

print(weights.count)
nnx.display(weights)
nnx.display(y)

Count(
  value=Array(2, dtype=int32, weak_type=True)
)


In [None]:
# If we don't broadcast Count, how do we make this work?
# Mapping `Count` on axis 0 (instead of broadcasting it with `None`) requires
# the variable to have a leading axis of the mapped size, just like the params.
# With a scalar `count=jnp.array(10)` this cell raised
# `ValueError: vmap was requested to map its argument along axis 0, which
# implies that its rank should be at least 1, but is only 0`.
# (Any recorded output below this cell is from that scalar version; re-run to
# refresh it.)

class Weights(nnx.Module):
  """Kernel and bias as `nnx.Param`s plus a `Count` variable."""

  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  # One counter per mapped slice: shape (10,) so that axis 0 can be vectorized.
  count=jnp.full((10,), 10),
)
x = jax.random.normal(random.key(1), (10, 2))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  """Increment this slice's counter, then compute `x @ kernel + bias`."""
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

# note that nnx.StateAxes can only be used directly on Flax NNX objects, and
# it cannot be used as a prefix for a pytree of objects.
state_axes = nnx.StateAxes({nnx.Param: 0, Count: 0}) # vectorize Count on axis 0 (no broadcast)
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

print(weights.count)
nnx.display(weights)
nnx.display(y)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

In [None]:
# @title Random State

class Weights(nnx.Module):
  """Kernel, bias and counter plus an `nnx.Rngs` stream named `noise`."""

  def __init__(self, kernel, bias, count, seed):
    self.kernel = nnx.Param(kernel)
    self.bias = nnx.Param(bias)
    self.count = Count(count)
    # a random state is just a regular state
    self.rngs = nnx.Rngs(noise=seed)

# Params and count are un-batched here; only the seed has a leading axis of 10.
weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))

def noisy_vector_dot(weights: Weights, x: jax.Array):
  """Compute `x @ kernel + bias` and add normal noise drawn from `weights.rngs`."""
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  # Because Rngs’s state is updated in place and automatically propagated by
  # nnx.vmap, we will get a different result every time that noisy_vector_dot is called.
  noise = random.normal(weights.rngs.noise(), y.shape)
  return y + noise

# Map the RNG state on axis 0; broadcast the params and the counter.
state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
nnx.display(y1)
nnx.display(y2)

False


## Rules and limitations

In [None]:
# @title Mutable Module cannot be passed by closure

# To solve this issue pass all Module as arguments to the functions being transformed. In this case f should accept counter as an argument.

class Counter(nnx.Module):
  """Module with one `nnx.Param` counter and a method that increments it."""

  def __init__(self):
    start = jnp.array(0)
    self.count = nnx.Param(start)

  def increment(self):
    self.count += jnp.array(1)

counter = Counter()

@nnx.jit
def f(x):
  """Double `x` while mutating `counter`, which is captured by closure."""
  # `counter` is not an argument of `f`; this is the pattern the cell warns about.
  counter.increment()
  return 2 * x

try:
  y = f(3)
except Exception as e:
  print(e)


In [None]:
# @title Consistent aliasing

class Weights(nnx.Module):
  """Module wrapping a single array as an `nnx.Param`."""

  def __init__(self, array: jax.Array):
    self.param = nnx.Param(array)

m = Weights(jnp.arange(10))
# The same module object `m` appears several times in each argument.
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]

# this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways.
@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
  ...

try:
  f(arg1, arg2)
except ValueError as e:
  print(e)


# arg1 is vectorized on axis 0 on the input, and axis 1 on the output. As expected, this is problematic and Flax NNX will raise an error.
@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
  return arg1

try:
  f(arg1)
except ValueError as e:
  print(e)




In [None]:
# @title Axis metadata

class Weights(nnx.Module):
  """Module wrapping one array as an `nnx.Param` with `sharding` metadata."""

  def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
    self.param = nnx.Param(array, sharding=sharding)

m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))

# vmapping should happen along the 'b' axis of the sharding.
@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
  """Print the param's shape and sharding metadata as seen inside the vmap."""
  print(f'Inner {m.param.shape = }')
  print(f'Inner {m.param.sharding = }')

f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')

@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
  """Create un-batched `Weights`; the vmap adds the new axis at position 1."""
  return Weights(jnp.ones((3, 5)), sharding=('a', None))

m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')

Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)


# Scale up on multiple devices

In [None]:
# @title Overview

import os
# Request 8 host-platform (CPU) devices from XLA.
# NOTE(review): presumably this only takes effect if it is set before JAX
# initializes its backends; the recorded output of this cell lists TPU devices,
# so confirm what `jax.devices()` returns on your runtime.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from typing import *

import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from flax import nnx

import optax # Optax for common losses and optimizers.

print(f'You have 8 “fake” JAX devices now: {jax.devices()}')

# Create a mesh of two dimensions and annotate each axis with a name.
# The reshape to (2, 4) requires exactly 8 devices to be available.
mesh = Mesh(devices=np.array(jax.devices()).reshape(2, 4),
            axis_names=('data', 'model'))
print(mesh)

INFO:2024-11-16 00:11:00,843:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': Your process properly initialized the GPU backend, but //learning/brain/research/jax:gpu_support is not linked in. You most likely should add that build dependency to your program.
INFO:2024-11-16 00:11:01,873:jax._src.xla_bridge:927: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'
INFO:2024-11-16 00:11:01,874:jax._src.xla_bridge:927: Unable to initialize backend 'mock_tpu': Must pass --mock_tpu_platform flag to initialize the mock_tpu backend


You have 8 “fake” JAX devices now: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Mesh('data': 2, 'model': 4)


In [None]:
# @title Define a model with specified sharding

class DotReluDot(nnx.Module):
  """Two square matmuls with a ReLU in between, annotated for sharding."""

  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Kernel annotation for `self.dot1`: `sharding (None, 'model')`.
    # The first dimension will be replicated across all devices.
    # The second dimension will be sharded over the 'model' axis of the device
    # mesh. This means W1 will be sharded 4-way on devices (0, 4), (1, 5),
    # (2, 6) and (3, 7), in this dimension.
    dot1_kernel_init = nnx.with_partitioning(init_fn, (None, 'model'))
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=dot1_kernel_init,
      use_bias=False,  # or use `bias_init` to give it annotation too
      rngs=rngs)

    # Weight param `w2`, annotated with sharding ('model', None).
    # Note that this is simply adding `.sharding` to the variable as metadata!
    w2_value = init_fn(rngs.params(), (depth, depth))  # RNG key and shape for W2 creation
    self.w2 = nnx.Param(w2_value, sharding=('model', None))

  def __call__(self, x: jax.Array):
    hidden = jax.nn.relu(self.dot1(x))
    # In data parallelism, input / intermediate value's first dimension (batch)
    # will be sharded on `data` axis
    hidden = jax.lax.with_sharding_constraint(hidden, PartitionSpec('data', 'model'))
    return jnp.dot(hidden, self.w2.value)

In [None]:
# @title Initialize a sharded model

# Construct the model eagerly, outside of any mesh or jit.
unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))

# You have annotations stuck there, yay!
# (`.sharding` on the variable is the metadata set in `DotReluDot.__init__`.)
print(unsharded_model.dot1.kernel.sharding)     # (None, 'model')
print(unsharded_model.w2.sharding)              # ('model', None)

# But the actual arrays are not sharded?
# (`.value.sharding` is the sharding of the underlying array itself.)
print(unsharded_model.dot1.kernel.value.sharding)  # SingleDeviceSharding
print(unsharded_model.w2.value.sharding)           # SingleDeviceSharding

(None, 'model')
('model', None)
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device)
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), memory_kind=device)


In [None]:
@nnx.jit
def create_sharded_model():
  """Build a `DotReluDot` and constrain its state to the annotated shardings."""
  model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.
  model_state = nnx.state(model)                   # The model's state, a pure pytree.
  partition_specs = nnx.get_partition_spec(model_state)  # Strip out the annotations from state.
  constrained_state = jax.lax.with_sharding_constraint(model_state, partition_specs)
  nnx.update(model, constrained_state)             # The model is sharded now!
  return model

# The partition specs name mesh axes, so the model is created inside the mesh context.
with mesh:
  sharded_model = create_sharded_model()

# They are some `GSPMDSharding` now - not a single device!
print(sharded_model.dot1.kernel.value.sharding)
print(sharded_model.w2.value.sharding)

# Check out their equivalency with some easier-to-read sharding descriptions
assert sharded_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)

print("sharded_model.dot1.kernel (None, 'model') :")
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)
print("sharded_model.w2 ('model', None) :")
jax.debug.visualize_array_sharding(sharded_model.w2.value)

NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=device)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=device)
sharded_model.dot1.kernel (None, 'model') :
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,4│TPU 1,5│TPU 2,6│TPU 3,7│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
sharded_model.w2 ('model', None) :
┌───────────────────────┐
│        TPU 0,4        │
├───────────────────────┤
│        TPU 1,5        │
├───────────────────────┤
│        TPU 2,6        │
├───────────────────────┤
│        TPU 3,7        │
└───────────────────────┘


In [None]:
# @title Load a sharded model from a checkpoint

import orbax.checkpoint as ocp

# Save the sharded state.
sharded_state = nnx.state(sharded_model)
# NOTE(review): judging by its name this helper clears the directory first —
# confirm before pointing it at a path that holds real checkpoints.
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(path / 'checkpoint_name', sharded_state)

# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.
# use the nnx.eval_shape transform to generate a model of abstract JAX arrays,
# and only use its .sharding annotations to obtain the sharding tree.
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model)
# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`
# that contains both sharding and the shape/dtype of the arrays.
# Both arguments below are built from the old `abs_state` before the name is rebound.
abs_state = jax.tree.map(
  lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
  # generate such a sharding pytree with Flax’s nnx.get_named_sharding
  abs_state, nnx.get_named_sharding(abs_state, mesh)
)
# Restore using the abstract tree as the target.
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
                                      target=abs_state)
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)



┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,4│TPU 1,5│TPU 2,6│TPU 3,7│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
┌───────────────────────┐
│        TPU 0,4        │
├───────────────────────┤
│        TPU 1,5        │
├───────────────────────┤
│        TPU 2,6        │
├───────────────────────┤
│        TPU 3,7        │
└───────────────────────┘


In [None]:
# @title Compile the training loop

# In data parallelism, the first dimension (batch) will be sharded on the `data` axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells in this notebook reuse it.
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)

with mesh:
  output = sharded_model(input)
# Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without jit compilation.
print(output.shape)
jax.debug.visualize_array_sharding(output)  # Also sharded as `('data', None)`.

(8, 1024)
┌──────────────────────────────────────────────────────────────────────────────┐
│                                                                              │
│                                 TPU 0,1,2,3                                  │
│                                                                              │
│                                                                              │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│                                 TPU 4,5,6,7                                  │
│                                                                              │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘


In [None]:
# The training loop is the same as the earlier `train_step` examples, except
# that the inputs and labels are also explicitly sharded.
# nnx.jit will adjust and automatically choose the best layout based on how its
# inputs are already sharded, so try out different shardings for your own model
# and inputs.
optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))  # reference sharing

@nnx.jit
def train_step(model, optimizer, x, y):
  """Apply one optimizer update to `model` and return the mean squared error."""
  def loss_fn(model: DotReluDot):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)

  return loss

# Inputs and labels are placed with the batch dimension sharded on the `data` axis.
input = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)

with mesh:
  for i in range(5):
    loss = train_step(sharded_model, optimizer, input, label)
    print(loss)    # Model (over-)fitting to the labels quickly.

1.4579723
0.7753284
0.5239993
0.38882557
0.2845982


In [None]:
# @title Profiling

%%timeit

def block_all(xs):
  """Wait until every JAX array in the pytree `xs` is computed, then return `xs`.

  Uses `jax.block_until_ready`, which walks the pytree itself and leaves
  leaves without a `block_until_ready` method (e.g. Python scalars) untouched
  instead of raising `AttributeError`.

  Args:
    xs: Any pytree, typically the outputs of a jitted step.

  Returns:
    The same pytree, with all JAX array leaves ready.
  """
  return jax.block_until_ready(xs)

# Block on the step's result so the timing covers the actual device execution.
# `train_step` as defined in this notebook returns the loss, so despite its
# name `new_state` holds that loss value.
with mesh:
  new_state = block_all(train_step(sharded_model, optimizer, input, label))

1.14 ms ± 52.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
# @title Logical axis annotation

# The mapping from alias annotation to the device mesh.
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))

class LogicalDotReluDot(nnx.Module):
  """`DotReluDot` variant whose params are annotated with logical axis names."""

  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Sublayer `self.dot1`: the sharding rules are attached to the kernel here.
    dot1_kernel_init = nnx.with_metadata(
      init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules)
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=dot1_kernel_init,
      use_bias=False,
      rngs=rngs)

    # Weight param `w2`.
    # Didn't provide the sharding rules here to show you how to overwrite it later.
    w2_init = nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))
    self.w2 = nnx.Param(w2_init(rngs.params(), (depth, depth)))

  def __call__(self, x: jax.Array):
    hidden = jax.nn.relu(self.dot1(x))
    # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.
    hidden = jax.lax.with_sharding_constraint(hidden, PartitionSpec('data', None))
    return jnp.dot(hidden, self.w2.value)

In [None]:
# If you didn’t provide all sharding_rule annotations in the model definition,
# you can write a few lines to add it to Flax’s nnx.State of the model, before
# the call of nnx.get_partition_spec or nnx.get_named_sharding.


def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState:
  """Set `sharding_rules` on `vs` in place and return the same object."""
  vs.sharding_rules = sharding_rules
  return vs

@nnx.jit
def create_sharded_logical_model():
  """Build a `LogicalDotReluDot` and constrain its state to the mapped shardings."""
  model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))

  def is_variable_state(node):
    # Stop the tree traversal at VariableState so the rule is set on each one.
    return isinstance(node, nnx.VariableState)

  model_state = nnx.state(model)
  model_state = jax.tree.map(add_sharding_rule, model_state,
                             is_leaf=is_variable_state)
  partition_specs = nnx.get_partition_spec(model_state)
  constrained_state = jax.lax.with_sharding_constraint(model_state, partition_specs)
  nnx.update(model, constrained_state)
  return model

with mesh:
  sharded_logical_model = create_sharded_logical_model()

jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)
jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)

# Check out their equivalency with some easier-to-read sharding descriptions.
assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_logical_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)

# The forward pass output should be sharded on the `data` axis only.
with mesh:
  logical_output = sharded_logical_model(input)
  assert logical_output.sharding.is_equivalent_to(
    NamedSharding(mesh, PartitionSpec('data', None)), ndim=2
  )


┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,4│TPU 1,5│TPU 2,6│TPU 3,7│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
┌───────────────────────┐
│        TPU 0,4        │
├───────────────────────┤
│        TPU 1,5        │
├───────────────────────┤
│        TPU 2,6        │
├───────────────────────┤
│        TPU 3,7        │
└───────────────────────┘


# Using Filters

In [None]:
from flax import nnx

class Foo(nnx.Module):
  """Minimal module with one variable of each of two types, used to demo Filters."""
  def __init__(self):
    self.a = nnx.Param(0)  # a `Param`-typed variable
    self.b = nnx.BatchStat(True)  # a `BatchStat`-typed variable

foo = Foo()

# nnx.Param and nnx.BatchStat are used as Filters to split the model into two groups
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')

params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})


In [None]:
# @title The Filter Protocol

# (path: tuple[Key, ...], value: Any) -> bool

def is_param(path, value) -> bool:
  """Filter predicate that accepts `nnx.Param` values.

  Returns True when `value` is an `nnx.Param` instance, or when it exposes a
  `type` attribute that is a subclass of `nnx.Param`. `path` is part of the
  filter signature and is not inspected.
  """
  if isinstance(value, nnx.Param):
    return True
  if not hasattr(value, 'type'):
    return False
  return issubclass(value.type, nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')

is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')

is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True


In [None]:
# @title The Filter DSL

# nnx.filterlib.Filter

# example
# we can use the following filters to define a nnx.StateAxes object that we can
# pass to nnx.vmap’s in_axes to specify how model’s various substates should be vectorized:
# selects all nnx.Param objects with the tag 'dropout', e.g., self.dropout_rate = nnx.Param(0.5, tags=('dropout',))
# the selected parameters will be vmapped over their first axis (axis 0).
# ... selects all other parts of the model's state not matched by the previous filter
# None: all other parts of the state will be broadcasted across the vmapped
# dimension, effectively remaining the same for each vmapped iteration.
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})

@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
  # Placeholder body: this cell only illustrates passing `state_axes` as the
  # `in_axes` entry for the model argument.
  ...

In [None]:
# convert literal into a predicate

is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))

print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')

In [None]:
# @title Grouping States

# how nnx.split is roughly implemented
from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  """Rough re-implementation of `nnx.split`, for illustration.

  Flattens `node`, then puts every (path, value) entry of its state into the
  group of the first filter that matches it.

  Returns:
    `(graphdef, state_0, ..., state_k)`, one state per filter, in filter order.

  Raises:
    ValueError: if some entry is matched by none of the filters.
  """
  graphdef, state = nnx.graph.flatten(node)
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  buckets: list[dict[KeyPath, Any]] = [{} for _ in predicates]

  for path, value in state.flat_state().items():
    # Index of the first predicate accepting this entry; order matters.
    first_match = next(
      (i for i, accepts in enumerate(predicates) if accepts(path, value)),
      None,
    )
    if first_match is None:
      raise ValueError(f'No filter matched {path = } {value = }')
    buckets[first_match][path] = value

  states = tuple(nnx.State.from_flat_path(bucket) for bucket in buckets)
  return graphdef, *states

# lets test it...
foo = Foo()

graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')

params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})


In [None]:
# filtering is order-dependent, you should place more specific filters before more general filters

class SpecialParam(nnx.Param):
  """Subclass of `nnx.Param` used to show that filter order matters."""

class Bar(nnx.Module):
  """Module with one plain `nnx.Param` and one `SpecialParam` (a `Param` subclass)."""
  def __init__(self):
    self.a = nnx.Param(0)  # matched by the `nnx.Param` filter only
    self.b = SpecialParam(0)  # matched by both `SpecialParam` and `nnx.Param`

bar = Bar()

graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')

graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')

params = State({
  'a': VariableState(
    type=Param,
    value=0
  ),
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
special_params = State({})
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
special_params = State({
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})


# Randomness

In [None]:
# random state is just another type of state

from flax import nnx
import jax
from jax import random, numpy as jnp

## Rngs, RngStream, and RngState

- Rngs: The main user interface. It defines a set of named RngStream objects.

- nnx.RngStream: An object that can generate a stream of RNG keys. It holds a root key and a count inside RngKey and RngCount Variables, respectively. When a new key is generated, the count is incremented.

- nnx.RngState: The base type for all RNG-related state.

 - nnx.RngKey: Variable type for holding RNG keys, it includes a tag attribute containing the name of the stream.
 - nnx.RngCount: Variable type for holding RNG counts, it includes a tag attribute containing the name of the stream.

In [None]:
# created two streams
# This creates an RNG stream named "params" and initializes it with a seed value of 0
# This creates another RNG stream named "dropout" and initializes it with a JAX random key generated from the seed 1.
rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)

INFO:2024-11-16 18:34:20,874:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': Your process properly initialized the GPU backend, but //learning/brain/research/jax:gpu_support is not linked in. You most likely should add that build dependency to your program.
INFO:2024-11-16 18:34:21,935:jax._src.xla_bridge:927: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'
INFO:2024-11-16 18:34:21,938:jax._src.xla_bridge:927: Unable to initialize backend 'mock_tpu': Must pass --mock_tpu_platform flag to initialize the mock_tpu backend


In [None]:
# This will return a new key
params_key = rngs.params()
dropout_key = rngs.dropout()

nnx.display(rngs)

In [None]:
# @title Standard stream names
# There are only two standard stream names used by Flax NNX’s built-in layers: params, dropout
class Model(nnx.Module):
  """Linear(20 -> 10) followed by dropout and ReLU, built from one `nnx.Rngs`."""

  def __init__(self, rngs: nnx.Rngs):
    # The same `rngs` object is handed to both layers. Each built-in layer
    # asks it for the stream it needs by name (`params` for initializing
    # `nnx.Linear`, `dropout` for the masks of `nnx.Dropout`), so no explicit
    # routing is needed here.
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    projected = self.linear(x)
    return nnx.relu(self.drop(projected))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
nnx.display(model)

y.shape = (1, 10)


In [None]:
# @title Default stream

# Flax NNX provides a default stream that can be used as a fallback when a stream is not found.
rngs = nnx.Rngs(0, params=1)

key1 = rngs.params() # call params
key2 = rngs.dropout() # fallback to default
key3 = rngs() # call default directly

# test with Model that uses params and dropout
model = Model(rngs)
y = model(jnp.ones((1, 20)))

nnx.display(rngs)
nnx.display(model)

## Filtering random state


In [None]:
model = Model(nnx.Rngs(params=0, dropout=1))

rng_state = nnx.state(model, nnx.RngState) # all random state
key_state = nnx.state(model, nnx.RngKey) # only keys
count_state = nnx.state(model, nnx.RngCount) # only counts
rng_params_state = nnx.state(model, 'params') # only params
rng_dropout_state = nnx.state(model, 'dropout') # only dropout
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys

nnx.display(rng_state)
nnx.display(params_key_state)

## Reseeding

In [None]:
model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))

y1 = model(x)
y2 = model(x)

nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)
y4 = model(x)

assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3)     # same
assert not jnp.allclose(y3, y4) # different

## Splitting Rngs

In [None]:
# When interacting with transforms like vmap or pmap it is often necessary to
# split the random state such that each replica has its own unique state.

rngs = nnx.Rngs(params=0, dropout=1)

@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
  # Displays `rngs` as seen inside the `split_rngs`-decorated function.
  print('Inside:')
  # Drawing a key from the split stream is not possible in here:
  # rngs.dropout() # ValueError: fold_in accepts a single key...
  nnx.display(rngs)

f(rngs)

print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)

Inside:


Outside:


## Transforms

In [None]:
# @title Data parallel dropout

model = Model(nnx.Rngs(params=0, dropout=1))

num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

# need to split the random state of the dropout to ensure that each replica gets different dropout masks
@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
  """Apply `model` to `x` under `nnx.pmap`.

  `state_axes` maps the 'dropout' state over axis 0 and gives every other
  part of the model state the axis `None`; `x` and the output are mapped
  over axis 0.
  """
  return model(x)

y = forward(model, x)
print(y.shape)

(8, 16, 10)


In [None]:
# @title Recurrent dropout

class Count(nnx.Variable):
  """Custom `nnx.Variable` type for the call counter kept by `RNNCell`."""

class RNNCell(nnx.Module):
  """Simple recurrent cell with dropout applied to the hidden state.

  The dropout layer draws from the 'recurrent_dropout' RNG stream, and
  `count` is incremented on every call.
  """

  def __init__(self, din, dout, rngs):
    self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
    self.dout = dout
    self.count = Count(jnp.array(0, jnp.uint32))

  def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
    """Advance one step; returns the new hidden state twice, as (carry, output)."""
    dropped_h = self.drop(h)  # recurrent dropout
    joined = jnp.concatenate([dropped_h, x], axis=-1)
    new_h = nnx.relu(self.linear(joined))
    self.count += 1
    return new_h, new_h

  def initial_state(self, batch_size: int):
    """Return an all-zeros hidden state of shape `(batch_size, dout)`."""
    shape = (batch_size, self.dout)
    return jnp.zeros(shape)

cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))

@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
  """Scan `cell` over axis 1 of `x` and return the per-step outputs stacked on axis 1."""
  # broadcast 'recurrent_dropout' RNG state to have the same mask on every step
  scan_state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})

  @nnx.scan(in_axes=(scan_state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
  def step(cell: RNNCell, carry, x_t) -> tuple[jax.Array, jax.Array]:
    # The cell already returns a (new_carry, output) pair.
    return cell(carry, x_t)

  initial_h = cell.initial_state(batch_size=x.shape[0])
  _, outputs = step(cell, initial_h, x)
  return outputs

x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)

print(f'{y.shape = }')
print(f'{cell.count.value = }')

y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)


# Model surgery

- Model surgery means modifying an existing model's building blocks and parameters, such as layer replacement, parameter or state manipulation, or even “monkey patching”.

In [None]:
from typing import *
from pprint import pprint
import functools

import jax
from jax import lax, numpy as jnp, tree_util as jtu

from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax

key = jax.random.key(0)

In [None]:
class TwoLayerMLP(nnx.Module):
  """Two stacked `dim -> dim` linear layers with no activation in between."""

  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    # Feed the first layer's output straight into the second.
    return self.linear2(self.linear1(x))

In [None]:
# @title Pythonic nnx.Module manipulation

# Pythonic operations on its sub-Modules, such as sub-Module swapping, Module sharing, variable sharing, and monkey-patching:

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))

# Sub-`Module` swapping.
original1, original2 = model.linear1, model.linear2
model.linear1, model.linear2 = model.linear2, model.linear1
np.testing.assert_allclose(model(x), original1(original2(x)))

# `Module` sharing (tying all weights together).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear2 = model.linear1
assert not hasattr(nnx.state(model), 'linear2')
np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))

# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel  # the bias parameter is kept separate
assert hasattr(nnx.state(model), 'linear2')
assert hasattr(nnx.state(model)['linear2'], 'bias')
assert not hasattr(nnx.state(model)['linear2'], 'kernel')

# Monkey-patching.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
def awesome_layer(x):
  """Identity function, used as a drop-in replacement for a layer."""
  return x
model.linear2 = awesome_layer
np.testing.assert_allclose(model(x), model.linear1(x))

In [None]:
# @title Creating an abstract model or state without memory allocation

# Create a function that returns a valid Flax NNX model; and
# Run nnx.eval_shape (not jax.eval_shape) upon it.
abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
gdef, abs_state = nnx.split(abs_model)
pprint(abs_state)

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
abs_state['linear1']['kernel'].value = model.linear1.kernel
abs_state['linear1']['bias'].value = model.linear1.bias
abs_state['linear2']['kernel'].value = model.linear2.kernel
abs_state['linear2']['bias'].value = model.linear2.bias
nnx.update(abs_model, abs_state)
np.testing.assert_allclose(abs_model(x), model(x))  # They are equivalent now!

State({
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  }
})


In [None]:
# @title Checkpoint surgery

# Save a version of model into a checkpoint
checkpointer = orbax.PyTreeCheckpointer()
old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
# Save the state of `old_model` itself rather than whatever `model` an earlier
# cell left behind: `old_model` is what the restored model is compared against,
# and a leftover `model` may have been modified (e.g. monkey-patched).
checkpointer.save('/tmp/nnx-surgery-state', nnx.state(old_model), force=True)

class ModifiedTwoLayerMLP(nnx.Module):
  """Same computation as `TwoLayerMLP`, but the sub-modules are named `layer1`/`layer2`."""

  def __init__(self, dim, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(dim, dim, rngs=rngs)  # no longer linear1!
    self.layer2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    # Feed the first layer's output straight into the second.
    return self.layer2(self.layer1(x))

abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
# In this new model, the sub-Modules are renamed from linear(1|2) to layer(1|2).
# Since the pytree structure has changed, it is impossible to directly load the
# old checkpoint with the new model state structure:
try:
  with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model))
  print(with_item)
except Exception as e:
  print(f'This will throw error: {type(e)}: {e}')

This will throw error: <class 'ValueError'>: Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.


In [None]:
def process_raw_dict(raw_state_dict):
  """Drop the trailing 'value' key from every leaf path of a nested dict.

  The dict is flattened to path tuples, any path whose last element is
  'value' loses that element, and the result is unflattened again.
  """
  trimmed = {}
  for path, leaf in nnx.traversals.flatten_mapping(raw_state_dict).items():
    # Cut the '.value' postfix, if present.
    key = path[:-1] if path[-1] == 'value' else path
    trimmed[key] = leaf
  return nnx.traversals.unflatten_mapping(trimmed)

# Make your local change on the checkpoint dictionary.
raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw_dict)
raw_dict['layer1'] = raw_dict.pop('linear1')
raw_dict['layer2'] = raw_dict.pop('linear2')

# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)

np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))



{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'value': Array([[-0.8034528 , -0.3407213 , -0.94083846,  0.0100597 ],
       [ 0.26146525,  1.1247874 ,  0.54563653, -0.37416318],
       [ 1.0281817 , -0.67987853, -0.14883997,  0.0569495 ],
       [-0.4430822 , -0.60587126,  0.43408424, -0.40540555]],      dtype=float32)}},
 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'value': Array([[ 0.21010023,  0.82893616,  0.04589561,  0.54226494],
       [ 0.41913688,  0.8435967 , -0.4793838 , -0.49135533],
       [-0.46072757,  0.46301958,  0.3927706 , -0.9441392 ],
       [-0.66906977, -0.18474793, -0.5762287 ,  0.48211244]],      dtype=float32)}}}


## Partial initialization

- In some cases - such as with LoRA (Low-Rank Adaptation) - you may want to randomly initialize only part of your model parameters. This can be achieved through:

- Naive partial initialization; or Memory-efficient partial initialization.

In [None]:
# @title Naive partial initialization

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# In this line, an extra kernel and bias are created inside the new LoRALinear!
# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
      ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 discarded - only lora_a & lora_b are used in model)')

Number of jax arrays in memory at start: 80
Number of jax arrays in memory midway: 90 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)
Number of jax arrays in memory at end: 88 (2 discarded - only lora_a & lora_b are used in model)


In [None]:
# @title Memory-efficient partial initialization

# To do memory-efficient partial initialization, use nnx.jit’s efficiently
# compiled code to make sure only the state parameters you need are initialized:

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
# donate_argnums=0 tells nnx.jit that the first argument (old_state) can be
# "donated" to the function. This means that after the function call,
# old_state's memory can be reclaimed, preventing it from unnecessarily occupying memory.
@nnx.jit(donate_argnums=0)
def partial_init(old_state, rngs):
  """Build a `TwoLayerMLP` whose `linear1` is a `LoRALinear`, then load `old_state` into it.

  Args:
    old_state: Existing model state to copy into the new model (argument 0,
      which is the donated one).
    rngs: RNG streams used to initialize the new model.

  Returns:
    The new model after `old_state` has been applied to it.
  """
  lora_model = TwoLayerMLP(4, rngs=rngs)
  # Swap in the LoRA layer; this is the newly created state.
  lora_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  # Overlay the existing state on top of the fresh initialization.
  nnx.update(lora_model, old_state)
  return lora_model

print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')
# Note that `old_state` will be deleted after this `partial_init` call.
good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 new created - lora_a and lora_b)')

Number of JAX Arrays in memory at start: 92
Number of JAX Arrays in memory at end: 98 (2 new created - lora_a and lora_b)


# Save and load checkpoints

- how to save and load Flax NNX model checkpoints with Orbax

In [None]:
# @title Setup

from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')

In [None]:
class TwoLayerMLP(nnx.Module):
  """Two stacked `dim -> dim` linear layers, both created without bias."""

  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)

  def __call__(self, x):
    # Feed the first layer's output straight into the second.
    return self.linear2(self.linear1(x))

# Instantiate the model and show we can run it.
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

In [None]:
# @title Save checkpoints

_, state = nnx.split(model)
nnx.display(state)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'state', state)

In [None]:
# @title Restore checkpoints

# First, create an abstract Flax NNX model (without allocating any memory for
# arrays), and show its abstract variable state to the checkpointing library.
# Once you have the state, use nnx.merge to obtain your Flax NNX model, and use it as usual.

# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
print('The abstract NNX state (all leaves are abstract arrays):')
nnx.display(abstract_state)

state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
jax.tree.map(np.testing.assert_array_equal, state, state_restored)
print('NNX State restored: ')
nnx.display(state_restored)

# The model is now good to use!
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)

The abstract NNX state (all leaves are abstract arrays):


NNX State restored: 




In [None]:
# @title Save and restore as pure dictionaries

# Save as pure dict
pure_dict_state = state.to_pure_dict()
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)

# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4)  # The model still works!


In [None]:
# @title Restore when checkpoint structures differ

class ModifiedTwoLayerMLP(nnx.Module):
  """Variant of TwoLayerMLP whose two linear layers are created with bias."""

  def __init__(self, dim, rngs: nnx.Rngs):
    # Both layers need a bias now.
    self.linear1 = nnx.Linear(dim, dim, use_bias=True, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, use_bias=True, rngs=rngs)

  def __call__(self, x):
    # Feed the first layer's output straight into the second.
    return self.linear2(self.linear1(x))

# Accommodate your old checkpoint to the new code.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))
restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))

# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4)  # The new model works!

nnx.display(model.linear1)

# Flax NNX vs JAX Transformations

- Notice the function signature of Flax NNX-transformed functions can accept the nnx.Linear module directly and can make stateful updates to the module, whereas the function signature of JAX-transformed functions can only accept the pytree-registered State and GraphDef objects and must return an updated copy of them to maintain the purity of the transformed function.

In [None]:
# @title Flax

from flax import nnx
import jax

x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))

@nnx.jit
def train_step(model, x, y):
  """One gradient-descent step (step size 0.1) on the mean squared error.

  `model` is updated in place; nothing is returned.
  """
  def mse(net):
    residual = net(x) - y
    return (residual ** 2).mean()

  def descend(param, grad):
    return param - 0.1 * grad

  gradients = nnx.grad(mse)(model)
  stepped = jax.tree_util.tree_map(
    descend, nnx.state(model, nnx.Param), gradients)
  nnx.update(model, stepped)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)



In [None]:
# @title Jax
@jax.jit
def train_step(graphdef, state, x, y):
  """One gradient-descent step (step size 0.1) written with plain JAX transforms.

  Takes the model as a `(graphdef, state)` pair and returns the updated pair.
  """
  def mse(gdef, st):
    net = nnx.merge(gdef, st)
    residual = net(x) - y
    return (residual ** 2).mean()

  # Differentiate with respect to the state argument only.
  gradients = jax.grad(mse, argnums=1)(graphdef, state)

  net = nnx.merge(graphdef, state)
  stepped = jax.tree_util.tree_map(
    lambda weight, grad: weight - 0.1 * grad,
    nnx.state(net, nnx.Param),
    gradients,
  )
  nnx.update(net, stepped)
  return nnx.split(net)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)

In [None]:
# @title Mixing Flax NNX and JAX transformations

@nnx.jit
def train_step(model, x, y):
  """Gradient-descent step (step size 0.1): `nnx.jit` outside, `jax.grad` inside.

  `model` is updated in place; nothing is returned.
  """
  def mse(gdef, st):
    net = nnx.merge(gdef, st)
    residual = net(x) - y
    return (residual ** 2).mean()

  # `jax.grad` needs pytree inputs, so the model is split before the call.
  gdef, st = nnx.split(model)
  gradients = jax.grad(mse, 1)(gdef, st)
  stepped = jax.tree_util.tree_map(
    lambda weight, grad: weight - 0.1 * grad,
    nnx.state(model, nnx.Param),
    gradients,
  )
  nnx.update(model, stepped)

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
train_step(model, x, y)

@jax.jit
def train_step(graphdef, state, x, y):
  """Gradient-descent step (step size 0.1): `jax.jit` outside, `nnx.grad` inside.

  Takes the model as a `(graphdef, state)` pair and returns the updated pair.
  """
  net = nnx.merge(graphdef, state)

  def mse(m):
    residual = m(x) - y
    return (residual ** 2).mean()

  gradients = nnx.grad(mse)(net)
  stepped = jax.tree_util.tree_map(
    lambda weight, grad: weight - 0.1 * grad,
    nnx.state(net, nnx.Param),
    gradients,
  )
  nnx.update(net, stepped)
  return nnx.split(net)

graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
graphdef, state = train_step(graphdef, state, x, y)

# Example: Using Pretrained Gemma

In [None]:
# @title

In [None]:
# @title

In [None]:
# @title