Skip to content

NNX train#3442

Draft
charlesli640 wants to merge 5 commits intoAI-Hypercomputer:mainfrom
CIeNET-International:charlesli/nnx_train
Draft

NNX train#3442
charlesli640 wants to merge 5 commits intoAI-Hypercomputer:mainfrom
CIeNET-International:charlesli/nnx_train

Conversation

@charlesli640
Copy link
Collaborator

Description

Implement pre-train using NNX style

Tests

python3 src/maxtext/trainers/pre_train/nnx_train.py src/maxtext/configs/base.yml \
run_name="run_llama2_7b" \
model_name="llama2-7b" \
dataset_type=synthetic \
steps=10 \
scan_layers=True \
debug_sharding=True \
async_checkpointing=False \
remat_policy=full \
checkpoint_storage_use_zarr3=false \
enable_checkpointing=false \
enable_nnx=true \
pure_nnx_decoder=true

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Mar 18, 2026

Copy link
Collaborator

@bvandermoon bvandermoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @charlesli640. Is it possible to migrate train.py directly instead of forking this logic? I am concerned these two files will get out of sync before they are merged, this setup will skip running unit tests, and it makes the code a bit more complicated/harder to follow

@charlesli640
Copy link
Collaborator Author

Thank you @charlesli640. Is it possible to migrate train.py directly instead of forking this logic? I am concerned these two files will get out of sync before they are merged, this setup will skip running unit tests, and it makes the code a bit more complicated/harder to follow

Definitely we can move the logic to train.py and make it controlled by enable_nnx config. Actually this is one of experimental solutions I am doing internally - try to create brand-new pre-train using pure NNX style, leaving old linen style pre-train untouched/co-existing.

Another solution is submitted on PR #3427. This solution tries to keep/re-use current linen style train loop as much as possible. It created TrainStateNNX class and make existing linen functions compatible to both linen model and nnx model. Please also review the PR3427. We can discuss more on which direction we are going.

@charlesli640 charlesli640 marked this pull request as draft March 19, 2026 01:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants