NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427

Open

ecnal-cienet wants to merge 1 commit into main from feat/pure_nnx_flag_and_init_state_fn

Conversation

@ecnal-cienet (Collaborator) commented Mar 16, 2026

NNX Migration Route Map

  1. ✅ Add NNX support functions and utilities; pure_nnx defaults to False. Does not affect the current Linen workflow.
  2. ❌ NNX fully supported; pure_nnx still defaults to False, but users can run NNX training and tests.
  3. ❌ NNX unit tests and performance tests completed and verified; pure_nnx switched to True.
  4. ❌ Remove the Linen code and the related NNX flags.

Description

Note: This is the first in a series of NNX migration PRs. Pure NNX training is not yet implemented — all NNX code paths currently raise NotImplementedError. This PR only introduces the structural scaffolding needed for subsequent patches to plug in NNX logic without modifying shared infrastructure.

This PR introduces two abstractions that enable incremental NNX migration while keeping existing Linen code fully functional.

pure_nnx config flag

A boolean (configs/base.yml, configs/types.py) that will route all major code paths — training, compilation, inference, RL, and utilities — to pure-NNX logic when True, falling back to Linen otherwise. Defaults to False so all existing behaviour is unchanged.
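As a minimal sketch of the intended routing (the function body below is illustrative, not this PR's code, and the create_linen_state helper is hypothetical), a shared entry point simply branches on the flag; in this PR every pure-NNX branch raises NotImplementedError:

```python
# Hedged sketch of flag-based routing; helper names are illustrative.
def setup_training_state(config, *args, **kwargs):
  if config.pure_nnx:
    # Pure-NNX path: scaffolding only in this PR, filled in by later PRs.
    raise NotImplementedError("Pure NNX training is not implemented yet.")
  # Default (pure_nnx=False): the existing Linen path is untouched.
  return create_linen_state(config, *args, **kwargs)
```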

init_state_fn

A pluggable callable for initializing the model training state, threaded through create_checkpoint_manager, setup_training_state, setup_decode_state, and get_abstract_state. This decouples state
initialization from shared infrastructure so future NNX and Linen paths can provide their own implementations without forking utilities.
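To illustrate the plug-in point (the exact signatures and bodies in the PR may differ), shared utilities receive the initializer as an argument and never construct the state themselves:

```python
# Hedged sketch: infrastructure only calls the injected init_state_fn, so the
# Linen and NNX initializers live outside these shared utilities.
from typing import Any, Callable

import jax

InitStateFn = Callable[..., Any]  # returns a TrainState-like container

def get_abstract_state(config, mesh, init_state_fn: InitStateFn):
  # Evaluate only the shape/dtype structure of whatever the injected
  # initializer would return; no real parameters are materialized.
  return jax.eval_shape(lambda: init_state_fn(config, mesh))

# Callers pick the implementation:
#   Linen path:  get_abstract_state(config, mesh, linen_init_state_fn)
#   NNX path:    get_abstract_state(config, mesh, nnx_init_state_fn)  # later PRs
```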

Other structural changes

  • create_training_tools is split into create_training_optimizer and create_checkpoint_manager for cleaner separation of concerns.
  • jit_train_step gains a mesh parameter to accommodate NNX callers where model is a GraphDef with no .mesh attribute (see the sketch after this list).
  • get_shaped_inputs in train_compile.py adds a pure_nnx branch that omits example_rng, matching the future NNX train_step(state, batch) signature.
  • get_first_step restored to the two-argument (model, state) form to support both Linen TrainState and NNX TrainStateNNX step retrieval.
  • All entry points (grpo_trainer, maxengine, generate_param_only_checkpoint, layerwise_quantization, lora_utils, standalone_checkpointer, integration tests) updated to accept and pass init_state_fn.
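A hedged sketch of the jit_train_step change referenced above; parameter names and shardings are illustrative, and the key point is only that the mesh becomes an explicit argument instead of being read from model.mesh:

```python
import jax

def jit_train_step(config, model, mesh, state_mesh_shardings, data_sharding, train_step_fn):
  # Previously the mesh was read from model.mesh, which only exists for Linen
  # models; NNX callers pass an nnx.GraphDef as `model`, so the caller now
  # supplies the mesh directly.
  return jax.jit(
      train_step_fn,
      in_shardings=(state_mesh_shardings, data_sharding, None),
      out_shardings=(state_mesh_shardings, None),
      donate_argnums=(0,),
  )
```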

New files

  • src/maxtext/layers/train_state_nnx.py — NNX TrainStateNNX container wrapping nnx.Module + nnx.Optimizer (mirrors Linen TrainState; see the sketch after this list).
  • src/maxtext/utils/maxtext_utils_nnx.py — NNX-specific utilities: abstract state, named sharding helpers, and sharded model creation.
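For orientation, a minimal sketch of what such a container could look like, assuming it mirrors the Linen TrainState; field names and the apply_gradients body are illustrative, and the optimizer-update call convention depends on the Flax release:

```python
import dataclasses
from flax import nnx

@dataclasses.dataclass
class TrainStateNNX:
  step: int
  model: nnx.Module         # the NNX model (replaces params + apply_fn)
  optimizer: nnx.Optimizer  # wraps the optax transformation and its state

  def apply_gradients(self, grads):
    # nnx.Optimizer updates the wrapped model's parameters in place
    # (exact call convention varies across Flax versions).
    self.optimizer.update(grads)
    return dataclasses.replace(self, step=self.step + 1)
```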

Lint / test fixes

  • maxtext_utils.py — remove unused ShardMode import.
  • maxtext_utils_test.py — fix duplicate Any import; restore Transformer alias; update get_abstract_state call to new (config, mesh, init_state_fn) signature.
  • sharding_compare_test.py — add pure_nnx=False, enable_nnx=False, and pure_nnx_decoder=False to config params; update get_abstract_state and get_logical_annotations calls to new API.
  • state_dtypes_test.py — update get_abstract_state call to new API.

Tests

python3 -m pytest tests/unit/train_utils_test.py -v
python3 -m pytest tests/unit/train_compile_test.py -v
python3 -m pytest tests/unit/maxtext_utils_test.py -v
python3 -m pytest tests/unit/state_dtypes_test.py -v
python3 -m pytest tests/unit/sharding_compare_test.py -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet force-pushed the feat/pure_nnx_flag_and_init_state_fn branch 6 times, most recently from b76f427 to 3fc5161, on March 17, 2026 15:14
@ecnal-cienet force-pushed the feat/pure_nnx_flag_and_init_state_fn branch 12 times, most recently from ed69422 to cd34a4d, on March 18, 2026 19:20
@ecnal-cienet changed the title from "NNX: add pure_nnx flag and init_state_fn for NNX/Linen co-existence" to "NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding" on Mar 18, 2026
@ecnal-cienet force-pushed the feat/pure_nnx_flag_and_init_state_fn branch from cd34a4d to 77ba3df on March 18, 2026 20:11
@charlesli640 mentioned this pull request on Mar 19, 2026
@ecnal-cienet force-pushed the feat/pure_nnx_flag_and_init_state_fn branch from 77ba3df to ac4408a on March 19, 2026 19:12
- pure_nnx: a flag to choose pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for training.
  It will be set to a different function for NNX and for Linen.
@ecnal-cienet force-pushed the feat/pure_nnx_flag_and_init_state_fn branch from ac4408a to 97e2ab7 on March 20, 2026 18:35
@ecnal-cienet changed the title from "NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding" to "NNX migration prep (2/N): pure_nnx flag and init_state_fn scaffolding" on Mar 20, 2026
@ecnal-cienet changed the title from "NNX migration prep (2/N): pure_nnx flag and init_state_fn scaffolding" back to "NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding" on Mar 20, 2026
@ecnal-cienet marked this pull request as ready for review on March 20, 2026 19:14