In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
from pathlib import Path
import time
import dataclasses
import functools
import contextlib
from io import StringIO
from pprint import pprint

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
from jax import numpy as jnp, random as jrandom
from flax import nnx, linen as nn
from jax import sharding
from jax.sharding import Mesh
from jax.experimental.mesh_utils import create_device_mesh
import optax
import equinox as eqx
#from flax.linen import partitioning as nn_partitioning
#from flax.core import meta
from flax.nnx import bridge

paths = [Path("MaxText").absolute()]
[sys.path.append(str(path)) for path in paths if str(path) not in sys.path]

from MaxText.layers.normalizations import RMSNorm
from MaxText.nnx_layers.normalizations import RMSNorm as NNXRMSNorm
from MaxText.nnx_layers import LinenToNNX
from MaxText import pyconfig, train, max_utils
from MaxText.nnx_layers.models import Transformer
from MaxText.layers.models import Transformer as Transformer2

2024-08-28 15:03:12.386928: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-28 15:03:12.398480: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-28 15:03:12.402136: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
buf = StringIO()
with contextlib.redirect_stdout(buf):
  pyconfig.initialize(["python3", "MaxText/configs/base.yml", "hardware=other", 
                      "enable_single_controller=True", "decoder_block=default", 
                      "scan_layers=True"])
  config = pyconfig.config
pyconfig_output = (buf.seek(0), buf.read())[1]

input_tokens, input_positions = jnp.array([[0]]), jnp.array([[0]])
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

Num_devices: 2, shape (1, 1, 2, 1, 1, 1, 1)


# Testing `nnx.Scan` and others

In [22]:
class MyMod(nnx.Module):
  def __init__(self, in_dim, out_dim, rngs):
    self.rngs = rngs
    self.linear1 = nnx.Linear(in_dim, 2 * out_dim, rngs=self.rngs)
    #self.linear2 = nnx.Linear(2 * out_dim, out_dim, rngs=self.rngs)
    self.linear2 = bridge.ToNNX(nn.Dense(out_dim), rngs=self.rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.tanh(self.linear1(x))), None

In [8]:
def init_fn(x, mod: nnx.Scan):
  return bridge.lazy_init(mod.scan_module, x)
 
smod = nnx.Scan.constructor(MyMod, length=5)(in_dim=10, out_dim=10, rngs=nnx.Rngs(0))
#nnx.vmap(init_fn, in_axes=(None, 0))(jnp.ones((1, 10)), smod)

bridge.lazy_init(smod, jnp.ones((1, 10)))
smod(jnp.ones((1, 10)))

(Array([[-0.08562934,  0.16581826,  0.07718924,  0.07391047, -0.11791413,
         -0.39341435, -0.3302102 ,  0.02588149, -0.15963459,  0.05993025]],      dtype=float32),
 None)

In [30]:
model = Transformer(config, mesh, quant=None, 
                    rngs=nnx.Rngs(default=0, params=0))
#with jax.profiler.trace("nnx_init"):
#  model(input_tokens, input_positions)
#model = Transformer(config, mesh, quant=None)

In [31]:
bridge.lazy_init(model, input_tokens, input_positions)

Transformer(config=<MaxText.pyconfig.HyperParameters object at 0x7fe661e263d0>, mesh=Mesh(device_ids=array([[[[[[[0]]]],



         [[[[1]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive')), quant=None, rngs=Rngs(
  default=RngStream(
    count=RngCount(
      tag='default',
      value=Array(3, dtype=uint32)
    ),
    key=RngKey(
      tag='default',
      value=Array((), dtype=key<fry>) overlaying:
      [0 0]
    )
  ),
  params=RngStream(
    count=RngCount(
      tag='params',
      value=Array(1, dtype=uint32)
    ),
    key=RngKey(
      tag='params',
      value=Array((), dtype=key<fry>) overlaying:
      [0 0]
    )
  )
))

In [32]:
model(input_tokens, input_positions)

Array([[[ 0.41741037, -1.0495998 , -0.6531233 , ..., -0.22075425,
         -1.859736  ,  0.42536822]]], dtype=float32)

In [34]:
state = nnx.state(model)

In [None]:
from flax.core.meta import Partitioned
jax.tree.map(lambda x: None if not isinstance(x, Partitioned) else x, state, is_leaf=lambda x: isinstance(x, Partitioned))

In [83]:
class LazyMod(nnx.Module):
  def __init__(self):
    self.p1 = nnx.Param(nnx.with_partitioning(jnp.ones, ("embed", "fsdp"))(100))

  def __call__(self, x):
    if not hasattr(self, "p2"):
      self.p2 = nnx.Param(x, sharding=("embed", "embed"))
    return self.p1 + (self.p2 * x)

In [72]:
@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
  embed: str | None = None
  mlp: str | None = None
  kv: str | None = None
  vocab: str | None = None

  def __call__(self, *keys: str) -> tuple[str, ...]:
    return tuple(getattr(self, key) for key in keys)
    
mesh_rules = MeshRules(embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor')

In [75]:
mesh_rules("embed", "mlp")


('fsdp', 'tensor')

In [79]:
p = nnx.Param(nnx.with_partitioning(jnp.zeros, ("fsdp", "tensor"))((100, 2)))

In [116]:
NNXRMSNorm(100, rngs=nnx.Rngs(time.time_ns())).scale

Param(
  value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32),
  sharding=(),
  mesh=None
)

# Test Flax DictModule

In [59]:
class DictMmodule(nn.Module):
  def __init__(self, d: dict | None = None):
    if d is None:
      return

    assert isinstance(d, dict)
    for key, item in d.items():
      self[key] = item

  def __getitem__(self, key):
    return getattr(self, key)

  def __setitem__(self, key, value):
    setattr(self, key, value)

  def __repr__(self):
    out = "DictModule(\n"
    for attr in dir(self):
      if not attr.startswith("_"):
        line = str(getattr(self, attr))
        line = "\n".join("  " + row for row in line.strip().split("\n"))
        out += line + ",\n"
    out += ")"
    return out


class DictModuleV2(nn.Module):
  #def __init__(self, d: dict | None = None):
  #  super().__init__()
  #  self._setup(d)
  #  self.d = d
    
  def __getitem__(self, key):
    #return getattr(self, key)
    return self.get(key)

  def __setitem__(self, key, value):
    self._user_set.add(key)
    #setattr(self, key, value)
    self.set(key, value)
    
  def get(self, key):
    return self.d[key]

  def set(self, key, value):
    self.d[key] = value
    
  def setup(self, d: dict | None = None):
    self._setup(d)

  def _setup(self, d: dict | None = None):
    self.d = d
    self._user_set = set()
    #if d is None:
    #  return
    #assert isinstance(d, dict)
    #for key, item in d.items():
    #  #self[key] = item
    #  #self.set(key, item)
    #  self._user_set.add(key)
    
  def add_modules(self, d):
    self.d = d

  def __call__(self, key, *args, **kw):
    #return self[key](*args, **kw)
    return self.get(key)(*args, **kw)

  def __repr__(self):
    out = "DictModuleV2(\n"
    for key in self._user_set:
      #line = str(getattr(self, key))
      line = str(self.get(key))
      line = "\n".join("  " + f"{key}={row}" for row in line.strip().split("\n"))
      out += line + ",\n"
    out += ")"
    return out


In [60]:
mod = DictModuleV2()
mod.add_modules({"a": "hello"})
print(mod)

AttributeError: "DictModuleV2" object has no attribute "_user_set". If "_user_set" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [61]:
class FlaxModel(nn.Module):
  def setup(self): 
    self.mod_dict = DictModuleV2()
    self.mod_dict.add_modules({"a": nn.Dense(10), "b": nn.Dense(20)})

  def __call__(self, x):
    #return mod_dict["b"](jax.nn.tanh(mod_dict["a"](x)))
    return self.mod_dict("b", jax.nn.tanh(self.mod_dict("a", x)))

In [76]:
d = DictMmodule({"hi": nn.Dense(2)})
d["hello"] = nn.Dense(10)

# Continue with Transformer testing

In [4]:
model = Transformer(config, mesh, quant=None, 
                    rngs=nnx.Rngs(time.time_ns() % 2 ** 31))

In [5]:
bridge.lazy_init(model, input_tokens, input_positions)
out = model(input_tokens, input_positions)
print(out.shape)

(1, 1, 32000)


In [6]:
def get_model(input_tokens, input_positions):
  model = Transformer(config, mesh, quant=None, 
                      rngs=nnx.Rngs(time.time_ns() % 2 ** 31))
  bridge.lazy_init(model, input_tokens, input_positions)
  return nnx.split(model)

In [7]:
mod, params = jax.eval_shape(get_model, input_tokens, input_positions)

In [37]:
all_leaves = jax.tree.leaves(params, is_leaf=lambda x: isinstance(x, nnx.VariableState))
all_leaves = [x for x in all_leaves if not issubclass(x.type, nnx.RngState)]

In [41]:
all(hasattr(x, "sharding") for x in all_leaves)

True

In [70]:
model2 = Transformer2(config, mesh, quant=None)
params2 = model2.init(nnx.Rngs(default=0, params=0)(), input_tokens, input_positions)
out = model2.apply(params2, input_tokens, input_positions)
print(out.shape)

(1, 1, 32000)


In [30]:
@jax.jit
def get_model_def():
  model = Transformer(config, mesh, quant=None, 
                      rngs=nnx.Rngs(default=0, params=0))
  bridge.lazy_init(model, input_tokens, input_positions)
  #model(input_tokens, input_positions)
  return nnx.split(model)

In [31]:
gdef, _ = jax.eval_shape(get_model_def)
_, params = get_model_def()

In [13]:
#@functools.partial(jax.jit, static_argnums=(0,))
@jax.jit
def fwd_fn(gdef, state, *input):
  return nnx.merge(gdef, state)(*input)

In [17]:
with jax.log_compiles():
  fwd_fn(gdef, params, input_tokens, input_positions)

In [25]:
model(input_tokens, input_positions)

Array([[[ 0.90953326,  1.16376   ,  1.1164488 , ...,  0.5105222 ,
         -0.5272695 ,  0.1440109 ]]], dtype=float32)

In [4]:
@nnx.jit
def eval_param_shape():
  input_tokens, input_positions = jnp.array([[0]]), jnp.array([[0]])
  with nn_partitioning.axis_rules(config.logical_axis_rules):
    model = Transformer(config, mesh, quant=None)
    model(input_tokens, input_positions)
  #return nnx.state(model)
  return nnx.split(model)

def get_model_def():
  input_tokens, input_positions = jnp.array([[0]]), jnp.array([[0]])
  with nn_partitioning.axis_rules(config.logical_axis_rules):
    model = Transformer(config, mesh, quant=None)
    model(input_tokens, input_positions)
  return nnx.graphdef(model)
  #return nnx.split(model)[1]

In [5]:
gdef, params = jax.eval_shape(eval_param_shape)
state_logical_annotations = nn.get_partition_spec(params)
state_mesh_shardings = nn.logical_to_mesh_sharding(
  state_logical_annotations, mesh, config.logical_axis_rules)



In [6]:
gdef

GraphDef(
  nodedef=NodeDef(
    type=Transformer,
    index=0,
    attributes=('config', 'decoder', 'mesh', 'quant', 'shared_embedding'),
    subgraphs={
      'decoder': NodeDef(
        type=LinenToNNX,
        index=1,
        attributes=('deterministic', 'initialized', 'linen_module', 'linen_state', 'rngs', 'use_running_average'),
        subgraphs={
          'linen_state': NodeDef(
            type=PytreeType,
            index=-1,
            attributes=('params',),
            subgraphs={
              'params': NodeDef(
                type=PytreeType,
                index=-1,
                attributes=('decoder_norm', 'layers', 'logits_dense'),
                subgraphs={
                  'decoder_norm': NodeDef(
                    type=PytreeType,
                    index=-1,
                    attributes=('scale',),
                    subgraphs={},
                    static_fields={},
                    leaves={
                      'scale': 2
                   

In [7]:
params = jax.jit(lambda: eval_param_shape()[1], out_shardings=state_mesh_shardings)()

In [8]:
params["shared_embedding"]

State({
  'embedding': VariableState(
    type=Param,
    value=LogicallyPartitioned(value=Array([[ 0.0575558 , -0.31785473,  0.11529455, ...,  1.0821939 ,
             1.4235774 , -1.2933688 ],
           [-2.0068665 , -0.06486757,  0.1310754 , ..., -1.5467689 ,
             0.37397835,  0.41232687],
           [-0.57422966,  0.1731033 ,  0.9584525 , ...,  0.07480869,
             0.15087242,  0.41225332],
           ...,
           [ 0.7054784 , -0.4994459 ,  0.07542419, ..., -1.2780907 ,
            -0.12462003,  0.4509493 ],
           [ 1.3809816 , -1.2765152 ,  0.77147233, ...,  1.7020334 ,
             0.6716798 , -0.24864346],
           [ 1.4495107 ,  0.41864708,  1.412156  , ..., -1.0488809 ,
             0.12066022,  1.5232936 ]], dtype=float32), names=('fsdp', 'embed'), mesh=None, rules=None)
  )
})

In [9]:
params2 = jax.tree.map(lambda x: x.unbox() if isinstance(x, meta.Partitioned) else x, params, is_leaf=lambda x: isinstance(x, meta.Partitioned))

In [10]:
fwd_fn = jax.jit(lambda p, *inputs: nnx.merge(gdef, p)(*inputs))

In [15]:
gdef

GraphDef(
  nodedef=NodeDef(
    type=Transformer,
    index=0,
    attributes=('config', 'decoder', 'mesh', 'quant', 'shared_embedding'),
    subgraphs={
      'decoder': NodeDef(
        type=LinenToNNX,
        index=1,
        attributes=('deterministic', 'initialized', 'linen_module', 'linen_state', 'rngs', 'use_running_average'),
        subgraphs={
          'linen_state': NodeDef(
            type=PytreeType,
            index=-1,
            attributes=('params',),
            subgraphs={
              'params': NodeDef(
                type=PytreeType,
                index=-1,
                attributes=('decoder_norm', 'layers', 'logits_dense'),
                subgraphs={
                  'decoder_norm': NodeDef(
                    type=PytreeType,
                    index=-1,
                    attributes=('scale',),
                    subgraphs={},
                    static_fields={},
                    leaves={
                      'scale': 2
                   

In [14]:
fwd_fn(params2, input_tokens, input_positions)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[32000,2048] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was jit_fn at /home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/transforms/transforms.py:139 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/linen/spmd.py:361:6 (with_logical_partitioning.<locals>.wrapper). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:88:2 (_graph_node_meta_call)
/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/nnx/nnx/object.py:82:4 (ObjectMeta._object_meta_construct)
<string>:11:2 (__create_fn__.<locals>.__init__)
/home/rdyro/maxtext/MaxText/nnx_layers/embeddings.py:65:12 (Embed.__post_init__)
/home/rdyro/.pyenv/versions/3.11.9/envs/devel/lib/python3.11/site-packages/flax/linen/spmd.py:361:6 (with_logical_partitioning.<locals>.wrapper)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [None]:
#nnx.merge(get_model_def(), params2)(input_tokens, input_positions)
with jax.checking_leaks():
  fwd_fn(params2, input_tokens, input_positions)

In [19]:
jax.debug.visualize_array_sharding(params2["shared_embedding"]["embedding"].value)

In [16]:
jax.debug.visualize_array_sharding(params2["decoder"]["linen_state"]["params"]["layers"]["mlp"]["wi_0"]["kernel"].value[:, 0, ...])

In [14]:
params_shape = jax.eval_shape(eval_param_shape)



In [15]:
params_shape

State({
  'decoder': {
    'linen_state': {
      'params': {
        'decoder_norm': {
          'scale': VariableState(
            type=Param,
            value=ShapeDtypeStruct(shape=(2048,), dtype=float32)
          )
        },
        'layers': {
          'mlp': {
            'wi_0': {
              'kernel': VariableState(
                type=Param,
                value=ShapeDtypeStruct(shape=(2048, 16, 7168), dtype=float32)
              )
            },
            'wi_1': {
              'kernel': VariableState(
                type=Param,
                value=ShapeDtypeStruct(shape=(2048, 16, 7168), dtype=float32)
              )
            },
            'wo': {
              'kernel': VariableState(
                type=Param,
                value=ShapeDtypeStruct(shape=(7168, 16, 2048), dtype=float32)
              )
            }
          },
          'post_self_attention_layer_norm': {
            'scale': VariableState(
              type=Param,
              v

In [5]:
model.decoder.linen_state

{'params': {'decoder_norm': {'scale': LogicallyPartitioned(value=Param(
     value=Array([1., 1., 1., ..., 1., 1., 1.], dtype=float32)
   ), names=('norm',), mesh=None, rules=None)},
  'layers': {'mlp': {'wi_0': {'kernel': LogicallyPartitioned(value=Param(
       value=Array([[[-1.46824829e-02,  1.97991189e-02,  4.92318161e-02, ...,
                -2.27814689e-02, -1.77739235e-03,  7.30113918e-03],
               [ 6.71353471e-03,  2.53547207e-02,  2.49213744e-02, ...,
                 1.92197636e-02,  4.14784951e-03, -3.55075486e-02],
               [ 2.83581223e-02,  3.04420362e-03, -8.07493739e-03, ...,
                 6.35656202e-03, -3.63132101e-04, -8.82705022e-03],
               ...,
               [ 6.87895669e-03,  9.28692240e-03,  1.86208095e-02, ...,
                 1.10543484e-03,  1.05740549e-02,  2.69666575e-02],
               [ 3.79684679e-02, -4.19698209e-02, -4.50059175e-02, ...,
                -2.73766350e-02,  4.00221422e-02, -1.40723391e-02],
               [ 

In [8]:
nnx.state(model.shared_embedding)

State({
  'embedding': VariableState(
    type=Param,
    value=Array([[ 0.0575558 , -0.31785473,  0.11529455, ...,  1.0821939 ,
             1.4235774 , -1.2933688 ],
           [-2.0068665 , -0.06486757,  0.1310754 , ..., -1.5467689 ,
             0.37397835,  0.41232687],
           [-0.57422966,  0.1731033 ,  0.9584525 , ...,  0.07480869,
             0.15087242,  0.41225332],
           ...,
           [ 0.7054784 , -0.4994459 ,  0.07542419, ..., -1.2780907 ,
            -0.12462003,  0.4509493 ],
           [ 1.3809816 , -1.2765152 ,  0.77147233, ...,  1.7020334 ,
             0.6716798 , -0.24864346],
           [ 1.4495107 ,  0.41864708,  1.412156  , ..., -1.0488809 ,
             0.12066022,  1.5232936 ]], dtype=float32)
  ),
  'rngs': {
    'default': {
      'count': VariableState(
        type=RngCount,
        value=Array(2, dtype=uint32),
        tag='default'
      ),
      'key': VariableState(
        type=RngKey,
        value=Array((), dtype=key<fry>) overlaying:
   

In [21]:
init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = train.setup_mesh_and_model(pyconfig.config)
with jax.profiler.trace("linen_init"):
  params = model.init(init_rng, input_tokens, input_positions)

Num_devices: 2, shape (1, 1, 2, 1, 1, 1, 1)
Setting up checkpoint logger...
Creating checkpoint manager...
Checkpoint manager created!


In [27]:
params["params"]["token_embedder"]

{'embedding': LogicallyPartitioned(value=Array([[ 1.4774086 ,  0.29882422, -0.34196898, ...,  0.50766706,
         -1.3415289 ,  1.7924749 ],
        [ 0.47449157, -1.9720819 ,  0.38987246, ...,  1.4013296 ,
         -0.44650635,  0.5926745 ],
        [-0.74953204, -1.6305867 , -0.9119016 , ...,  0.9319291 ,
         -0.24040265,  0.36705223],
        ...,
        [ 1.0246398 ,  0.14891908,  1.2458084 , ..., -0.3720324 ,
          1.6434435 , -0.6255694 ],
        [-1.3611012 , -1.3714991 ,  0.5478506 , ...,  3.409998  ,
         -0.0136618 ,  0.4574892 ],
        [-0.4890293 , -1.4788411 ,  0.851169  , ...,  0.9657814 ,
         -0.2644767 ,  1.1933228 ]], dtype=float32), names=('vocab', 'embed'), mesh=None, rules=None)}

In [5]:
model.train()

In [6]:
state = nnx.state(model)

In [8]:
model(input_tokens, input_positions)

Update keys: []
Update keys: []


Array([[[ 0.3628558 , -1.0749931 ,  0.5897113 , ...,  0.47235385,
         -1.8449324 ,  1.3828944 ]]], dtype=float32)

In [9]:
@functools.partial(jax.jit, static_argnums=(0,))
def fwd_fn(gdef, state, *inputs):
  return nnx.merge(gdef, state)(*inputs)

In [12]:
fwd_fn(*nnx.split(model), input_tokens, input_positions)

Array([[[ 0.36080366, -1.0749974 ,  0.59387594, ...,  0.46911016,
         -1.8451215 ,  1.3860557 ]]], dtype=float32)

In [30]:
rng = nnx.Rngs(time.time_ns() % 2 ** 31)()
params = model.init(rng, input_tokens, input_positions)

In [8]:
params_flat = jax.tree.flatten(params)[0]
pprint({i: x.shape for i, x in enumerate(params_flat)})

{0: (2048,),
 1: (2048, 16, 7168),
 2: (2048, 16, 7168),
 3: (7168, 16, 2048),
 4: (2048, 16),
 5: (2048, 16),
 6: (2048, 16, 16, 128),
 7: (16, 16, 128, 2048),
 8: (2048, 16, 16, 128),
 9: (2048, 16, 16, 128),
 10: (2048, 32000),
 11: (32000, 2048)}


# Sharding

In [14]:
mesh = sharding.Mesh(jax.devices("cpu"), ("x",))
shard = sharding.NamedSharding(mesh, sharding.PartitionSpec("x", None))

In [29]:
arr = nnx.with_partitioning(nnx.initializers.ones, shard)(nnx.Rngs(0), (128, 16))
arr = nnx.Variable(jnp.ones((128, 20)))

In [49]:
o = nn.with_logical_partitioning(jnp.zeros, ("x", None), mesh)

# LinenToNNX

In [12]:
import jaxfi as jaxm
mod = LinenToNNX(nn.BatchNorm(use_running_average=False), rngs=nnx.Rngs(0))
y = mod(jaxm.randn((10, 100)))

In [11]:
r = jaxm.randn((1, 100))
%timeit y = mod(r)

296 μs ± 8.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
old_batch_stats = mod.state["batch_stats"]
print(old_batch_stats)
y = mod(jaxm.randn((10, 100)))
print(mod.state["batch_stats"])

{'mean': Variable(
  value=Array([-2.0821048e-03, -6.8043353e-04,  2.7525767e-03,  9.4063886e-05,
         -4.5324411e-04, -4.6478841e-03,  1.6677739e-03,  5.1256125e-03,
         -1.7300450e-03, -1.4472724e-03,  1.0186685e-03,  4.4281147e-03,
         -5.8264798e-04,  4.9753045e-03, -4.9023994e-04,  3.1962610e-04,
          3.4446979e-04, -1.7976237e-03,  3.2958160e-03, -2.7488766e-03,
         -7.9713884e-04, -4.9346022e-04,  3.1748249e-03,  2.1695246e-03,
         -1.1033370e-03,  6.7661819e-04, -9.1823353e-04,  4.7784806e-03,
         -2.0047871e-03,  5.3438372e-03,  9.5874531e-04,  2.3210477e-03,
          2.5819831e-03,  3.6093504e-03,  9.5296197e-04,  2.7141403e-03,
          1.1053076e-03, -1.9327267e-03, -2.2280891e-03,  2.1749218e-03,
          3.4572810e-03, -2.1481735e-03,  5.5041173e-03,  1.5461477e-03,
         -4.6485366e-04, -1.8874716e-03,  1.0145564e-03,  7.3133651e-03,
          1.3833139e-03, -4.3422944e-04,  3.4693078e-04,  1.7135066e-03,
          2.3201664e-03, -

In [9]:
mod.eval()

In [8]:
model = nnx.Linear(100, 100, rngs=nnx.Rngs(0))
opt = nnx.Optimizer(model, optax.adam(1e-5))
grads = nnx.split(model)[1]
grads = nnx.grad(lambda model: jnp.mean(model(jaxm.randn((4, 100)))))(model)
opt.update(grads)

In [3]:
#os.environ.setdefault("PROCESS_ID", "0")
#os.environ.setdefault("JAX_PROCESS_COUNT", "1")
#os.environ.setdefault("PROCESS_IN_JOB", "0")
#os.environ.setdefault("JAX_COORDINATOR_ADDRESS", "127.0.0.1")
#os.environ.setdefault("JAX_COORDINATOR_IP", "127.0.0.1")
#os.environ.setdefault("JAX_COORDINATOR_PORT", str(65323))
#os.environ.setdefault("NNODES", str(1))
#os.environ.setdefault("NODE_RANK", str(0))
#
##os.environ["COORDINATOR_ADDRESS"] = "127.0.0.1:1234"
#os.environ.setdefault("JOB_INDEX", "0")
#os.environ.setdefault("JOB_COMPLETION_INDEX", "0")
#os.environ["PROCESSES_IN_JOB"] = "1"
##os.environ["JAX_PROCESS_COUNT"] = "1"
pyconfig.initialize(["python3", "MaxText/configs/base.yml", "hardware=other", "enable_single_controller=True"])
config = dict(pyconfig.config.get_keys())

init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = train.setup_mesh_and_model(pyconfig.config)
input_tokens = jnp.array([[0]])
input_positions = jnp.array([[0]])
params = model.init(init_rng, input_tokens, input_positions)

Updating keys from env and command line: ['hardware', 'enable_single_controller']
Running Model: default
Updating keys from model: []
Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period
dataset_type set to tfds, will use keys['dataset_path']='' and keys['dataset_name']='c4/en:3.0.1'
Config param adam_b1: 0.9
Config param adam_b2: 0.95
Config param adam_eps: 1e-08
Config param adam_eps_root: 0.0
Config param adam_weight_decay: 0.1
Config param add_bos: True
Config param add_eos: True
Config param allow_split_physical_axes: False
Config param ar_cache_axis_order: 1,2,0,3
Config param async_checkpointing: True
Config param attention: autoselected
Config param attention_type: global
Config param attn_logits_soft_cap: None
Config param autoregressive_decode_assert: 
Config param base_emb_dim: 2048
Config param base_mlp_dim: 7168
Config param base_num_decoder_layers: 16
Config param base_num_kv_heads: 16
Config param base_num_query_heads: 16
Config

In [21]:
mod.state

{'batch_stats': {'mean': Variable(
    value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
  ),
  'var': Variable(
    value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 

In [26]:
mod.train()

In [None]:
nnx.Dropout

In [None]:
nnx.BatchNorm

In [25]:
nnx.display(mod)

LazyNNX(
  init_args=(),
  initialized=True,
  linen_mod=BatchNorm(
      # attributes
      use_running_average = True
      axis = -1
      momentum = 0.99
      epsilon = 1e-05
      dtype = None
      param_dtype = float32
      use_bias = True
      use_scale = True
      bias_init = zeros
      scale_init = ones
      axis_name = None
      axis_index_groups = None
      use_fast_variance = True
      force_float32_reductions = True
  ),
  state={'batch_stats': {'mean': Variable(
    value=Array(shape=(100,), dtype=float32)
  ), 'var': Variable(
    value=Array(shape=(100,), dtype=float32)
  )}, 'params': {'bias': Param(
    value=Array(shape=(100,), dtype=float32)
  ), 'scale': Param(
    value=Array(shape=(100,), dtype=float32)
  )}}
)


In [8]:
@functools.partial(jax.jit, static_argnames=["graphdef"])
def loss_fn(graphdef, params, *args):
  return jnp.sum(nnx.merge(graphdef, params)(*args))

In [9]:
graphdef, params = nnx.split(mod)
loss_fn(graphdef, params, jnp.ones(100))
grad_fn = jax.jit(jax.grad(loss_fn, argnums=1))
grad_fn(graphdef, params, jnp.ones(100))

State({
  'state': {
    'params': {
      'bias': VariableState(
        type=Param,
        value=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32)
      ),
      'kernel': VariableState(
        type=Param,
        value=Array([[1., 1., 1., ..., 1., 1., 1.],
               [1., 1., 1., ..., 1., 1., 1.],
               [1., 1., 1., ..., 1., 1., 1.],
               ...,
               [1., 1., 1., ..., 1., 1., 1.],
               [1., 1., 1., ..., 1., 1., 1.],
               [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)
      )
    }
  }
}

In [11]:
@functools.partial(nnx.vmap, axis_size=5, in_axes=(None, None))
def make_model(rngs: nnx.Rngs, x: jax.Array):
  mod = LazyNNX(nn.Dense, 100, rngs=rngs)
  y = mod(x) # run the model in accordance with linen's convention (nn.compact)
  return mod

In [19]:
vmap_mod = make_model(nnx.Rngs(0), jnp.ones((1, 100)))

In [23]:
nnx.scan(lambda x, mod: (jax.nn.tanh(mod(x)), None))(
  jnp.ones((1, 100)),
  vmap_mod, 
)

(Array([[-0.22681893,  0.12948306, -0.5660833 ,  0.54388836, -0.13207269,
         -0.28127185,  0.13340358,  0.21858714,  0.0342543 ,  0.40834817,
          0.13434777, -0.4555197 ,  0.06381328, -0.58547539, -0.76521007,
          0.16313892, -0.26837662, -0.54795653,  0.18024976,  0.34415041,
         -0.13140569,  0.56270905, -0.10050073,  0.38099573, -0.37621883,
          0.08581573,  0.47219038, -0.55236467, -0.43320429, -0.39558568,
          0.09934793,  0.15027676, -0.07010986, -0.22549472, -0.08241178,
          0.15171293, -0.33565122, -0.5415724 , -0.44550078, -0.02126699,
         -0.00690109, -0.00483307,  0.01377263,  0.22923029, -0.30733769,
          0.13028407, -0.23582329, -0.06795412,  0.17091568, -0.48275513,
          0.45854019, -0.15280847,  0.69775289, -0.27580605,  0.41420138,
          0.4121478 , -0.21960127, -0.3110823 ,  0.09149418,  0.60862518,
          0.15215905, -0.7924634 ,  0.27823312, -0.56006468, -0.19038152,
         -0.40971096, -0.07249114, -0.

In [54]:
X = jaxm.randn((10 ** 3,), dtype=jaxm.float32)
y = jaxm.sin(X)

@functools.partial(jax.jit, static_argnums=(0,))
def loss_fn(graphdef, params, x, y):
  model = nnx.merge(graphdef, params)
  x = x[..., None] + jnp.zeros(100)
  yp = nnx.scan(lambda x, mod: (jax.nn.tanh(mod(x)), None))(
    x,
    model, 
  )[0]
  return jaxm.mean((yp[..., 0] - y) ** 2)
  
@functools.partial(jax.jit, static_argnums=(0,))
def grad_fn(graphdef, params, x, y):
  return jax.grad(loss_fn, argnums=1)(graphdef, params, x, y)

In [55]:
graphdef, params = nnx.split(vmap_mod)
loss_fn(graphdef, params, X, y)
optimizer = optax.adam(1e-4)
opt_state = optimizer.init(params)
optimizer_update = jax.jit(optimizer.update)
optax_apply_updates = jax.jit(optax.apply_updates)
value_and_grad = jax.jit(jax.value_and_grad(loss_fn, argnums=1))
for _ in range(1000):
  l, gs = value_and_grad(graphdef, params, X, y)
  updates, opt_state = optimizer.update(gs, opt_state)
  params = optax_apply_updates(params, updates)
  print(f"loss = {l:.4e}")

loss = 7.5165e-01
loss = 6.2398e-01
loss = 5.0591e-01
loss = 4.0034e-01
loss = 3.0931e-01
loss = 2.3373e-01
loss = 1.7333e-01
loss = 1.2687e-01
loss = 9.2456e-02
loss = 6.7913e-02
loss = 5.1097e-02
loss = 4.0097e-02
loss = 3.3329e-02
loss = 2.9542e-02
loss = 2.7792e-02
loss = 2.7382e-02
loss = 2.7813e-02
loss = 2.8731e-02
loss = 2.9894e-02
loss = 3.1132e-02
loss = 3.2337e-02
loss = 3.3435e-02
loss = 3.4382e-02
loss = 3.5152e-02
loss = 3.5733e-02
loss = 3.6122e-02
loss = 3.6320e-02
loss = 3.6336e-02
loss = 3.6179e-02
loss = 3.5860e-02
loss = 3.5394e-02
loss = 3.4793e-02
loss = 3.4072e-02
loss = 3.3246e-02
loss = 3.2330e-02
loss = 3.1339e-02
loss = 3.0288e-02
loss = 2.9192e-02
loss = 2.8068e-02
loss = 2.6928e-02
loss = 2.5789e-02
loss = 2.4666e-02
loss = 2.3570e-02
loss = 2.2517e-02
loss = 2.1518e-02
loss = 2.0584e-02
loss = 1.9723e-02
loss = 1.8945e-02
loss = 1.8253e-02
loss = 1.7652e-02
loss = 1.7141e-02
loss = 1.6719e-02
loss = 1.6381e-02
loss = 1.6119e-02
loss = 1.5925e-02
loss = 1.5

In [42]:
print(nnx.split(vmap_mod)[1]["state"]["params"]["bias"].value.shape)
print(nnx.split(vmap_mod)[1]["state"]["params"]["kernel"].value.shape)

(100,)
(50, 100)


In [26]:
@nnx.jit
def gen_layer(r_in):
  return nnx.Linear(100, 100, rngs=r_in)
  
rng = rng
gen_layer().kernel.value

Array([[ 0.13516979, -0.05105809,  0.07632802, ...,  0.0749987 ,
        -0.05567247, -0.05779112],
       [-0.00094059,  0.04644845, -0.21303476, ...,  0.10651965,
        -0.08589667, -0.04726676],
       [-0.12583305,  0.17015733,  0.04022592, ..., -0.14773165,
        -0.0417266 , -0.00611958],
       ...,
       [ 0.08767611,  0.11031242, -0.05278822, ..., -0.02778996,
        -0.13942076,  0.12322947],
       [-0.00845618,  0.15537211,  0.13054934, ...,  0.17835702,
         0.10111373,  0.16907702],
       [ 0.05105361, -0.06231965,  0.04937202, ...,  0.07263377,
        -0.12006826, -0.04576119]], dtype=float32)

In [42]:
rngs = nnx.Rngs(0, bias=0)

In [53]:
stream = rngs.get("default")

In [36]:
rkey = stream.key

In [111]:
@jax.jit
def get_randn(rngs):
  return nn.initializers.normal(1)(rngs, 100)

In [349]:
@functools.partial(jax.tree_util.register_dataclass, data_fields=["r"], meta_fields=[])
@dataclasses.dataclass
class RngWrapper:
  r: nnx.Rngs
  def __call__(self, key : str="default"):
    return self.r.get(key)() if isinstance(self.r, nnx.Rngs) else self.r

In [239]:
r = RngWrapper(rngs)

In [346]:
@dataclasses.dataclass
class MyModel(nnx.Module):
  in_features: int
  out_features: int
  rngs: RngWrapper = None
  
  def __post_init__(self):
    shape = (self.in_features, self.out_features)
    #if isinstance(self.rngs, nnx.Rngs):
    #  self.kernel = nnx.initializers.lecun_normal()(self.rngs.get("default")(), shape)
    #  nnx.Linear
    #else:
    #  self.kernel = nnx.initializers.lecun_normal()(self.rngs, shape)
    self.kernel = nnx.initializers.lecun_normal()(self.rngs(), shape)
    #self.kernel = nnx.initializers.lecun_normal()(self.rngs("default"), shape)
    #self.kernel = nnx.initializers.lecun_normal()(self.rngs("default"), shape)
    #self.kernel = nnx.initializers.lecun_normal()(self.rngs, shape)
    
  def __call__(self):
    return self.kernel
  

In [348]:
MyModel(100, 100, RngWrapper(nnx.Rngs(0)))

TypeError: expected 0 arguments, got 1

In [345]:
@jax.jit
def create_linear(rngs):
  return nnx.Linear(100, 100, rngs=rngs)

In [316]:
t = time.time()
mod = nnx.vmap(lambda i, o, k: nnx.Linear(i, o, rngs=k), axis_size=50000, in_axes=(None, None, None))(100, 100, rngs)
t = time.time() - t
print(f"{t = :.4e}")

t = 2.8158e+00


In [337]:
graphdef, params = nnx.split(mod)
fn = jax.jit(nnx.scan(lambda x, p: (nnx.merge(graphdef, p)(x), None)))
#(jnp.zeros(100), params)

In [343]:
nnx.bridge

AttributeError: module 'flax.nnx' has no attribute 'bridge'

In [341]:
fn(jnp.zeros(100), params)

(Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float64),
 None)

In [211]:
under_jit(r)

Array([[ 3.55973911e-01, -2.37477006e-01, -9.74529953e-02,
        -3.39887647e-01,  6.99265619e-01,  4.49438212e-01,
        -1.17788739e-01,  4.91785924e-01,  1.43185656e-01,
        -1.35137965e-01],
       [ 2.11869889e-01, -4.85307116e-01, -1.85187564e-02,
         1.35506714e-01,  3.02765362e-02,  9.85176910e-02,
        -8.09381560e-02,  2.50044546e-01,  9.01392114e-03,
        -2.77441520e-02],
       [ 6.12978399e-01, -5.05180316e-01, -1.20650054e-01,
         3.75843736e-01,  2.17106242e-01,  1.02653012e-02,
         5.78397549e-01, -4.85449128e-01,  2.70678648e-01,
         2.44499461e-01],
       [-4.31290899e-01,  8.96568564e-02, -5.51868309e-01,
        -2.96766079e-01, -7.05394023e-01,  2.57234510e-01,
        -2.29387100e-01, -7.16125931e-01, -6.00397803e-01,
        -5.67827435e-01],
       [ 2.23155164e-01, -1.99485335e-01, -2.55706195e-01,
         4.28001152e-01, -3.65027100e-01, -1.38338403e-01,
        -5.10901823e-01,  4.56128621e-01,  5.56484914e-01,
        -5.

In [129]:
get_randn(rngs)

Array([-1.39944734e+00,  1.42729388e-01, -2.04663887e+00,  6.84886398e-01,
        7.53044294e-01,  1.75674106e+00, -2.62389780e-01, -4.68049738e-01,
       -4.79173378e-01,  1.64071374e-03,  1.82383563e+00, -2.44221580e+00,
       -1.25232444e+00,  3.04795824e-01,  1.15632370e+00,  7.60857420e-01,
       -1.33744763e+00, -1.77406268e+00,  5.06903744e-01,  1.72165095e+00,
        1.41896385e+00,  9.19465409e-01, -5.67572806e-01,  1.10987088e+00,
       -4.98479686e-01,  1.89477152e+00, -1.52960825e-01, -1.53472124e-01,
        2.82643017e+00, -2.21983111e+00,  2.00269146e-01,  1.09333757e+00,
       -1.66862222e+00, -8.56525357e-01, -1.02235849e+00, -7.24134529e-01,
        4.48094968e-01, -7.93246631e-01,  1.40813146e-01, -2.54582938e+00,
       -1.13855538e+00, -4.76501459e-01, -3.93253085e-01, -1.18342087e+00,
        2.13279447e+00, -1.40168059e+00,  6.49652544e-03, -4.12427051e-01,
       -6.57155787e-01, -1.19050183e+00,  6.96635743e-01,  2.27212476e-01,
       -3.50703119e-01,  

In [61]:
y1 = call_linen_model(layer, x)
y2 = nnx_layer(x)
err = jnp.linalg.norm(y1 - y2) / (jnp.linalg.norm(y1) + 1e-7)
print(f"{err = :.4e}")

err = 0.0000e+00


In [6]:
params = layer.init(jrandom.key(time.time_ns() % 2 ** 31), x)

In [17]:
nnx_layer.train()

In [25]:
jax.jit
def grad(graphdef, state, *args):
  gs = jax.grad(lambda state: jnp.sum(nnx.merge(graphdef, state)(*args)))(state)
  return gs

In [29]:
optimizer = optax.adam(1e-5)

In [32]:
opt_state = optimizer.init(nnx.split(nnx_layer)[1])
optimizer_update = jax.jit(optimizer.update)

In [50]:
params = nnx.split(nnx_layer)[1]

@jax.jit
def learning_step(graphdef, params, *args, opt_state=None):
  if opt_state is None:
    opt_state = optimizer.init(params)
  gs = grad(graphdef, params, x)
  updates, opt_state = optimizer.update(gs, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, opt_state

In [54]:
graphdef, params = nnx.split(nnx_layer)
opt_state = None
for _ in range(10000):
  params, opt_state = learning_step(graphdef, params, x, opt_state=opt_state)