feat: migrate pipeline to nnx #2885
mesakhcienet wants to merge 1 commit into AI-Hypercomputer:main from
Conversation
bvandermoon left a comment:
@gobbleturk what testing do you recommend for migrating pipeline parallelism to NNX? I'll send over an internal doc that @hsuan-lun-chiang, @mesakhcienet, and others put together, which shows the tests they have already run.
@NuojCheng any thoughts here?
NuojCheng left a comment:
Some additional train compile test for pipeline NNX migration:
- Train compile test 1: https://paste.googleplex.com/5960957017849856
- Train compile test 2: https://paste.googleplex.com/5749974483730432
- Train compile test 3: https://paste.googleplex.com/5201745681711104
If the train compile tests above pass without OOM, and the current tests in pipeline_parallelism_test.py all pass, then I think it is good to go! Please ping me when the PR is ready for review.
There is also some Linen usage in
I don't see it updated in this PR, but I think it probably should be? Another thing is the usage of the function in maxtext/src/maxtext/utils/pipeline_utils.py (Lines 151 to 162 in 77f5334).
As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!
Shouldn't we have an NNX version of the functions in pipeline_utils.py as well?
Are we able to bridge the NNX version back to Linen at a higher layer? If so, then I think we could get rid of the old Linen code that is no longer used.
@bvandermoon
Option 1: If we use
Option 2: Delay the full migration until
Please let me know which of these two solutions you prefer, thank you.
@NuojCheng The NNX pipeline classes (NNXPipeline, NNXCircularPipeline) already handle these internally with JAX-native equivalents:
So no NNX versions of those functions are needed; the NNX path bypasses them entirely. Maybe you have some suggestions, or is there any part where I am wrong? Please let me know. Thank you.
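For intuition, the stage/microbatch bookkeeping that a pipeline class like NNXPipeline performs internally can be sketched in plain Python. This is a simplified GPipe-style schedule, not the actual MaxText implementation; the function name and shapes here are illustrative only:

```python
def pipeline_schedule(num_stages: int, num_microbatches: int):
    """Return schedule[t][s] = microbatch processed by stage s at step t,
    or None when that stage is idle (warm-up or drain)."""
    total_steps = num_stages + num_microbatches - 1
    return [
        [t - s if 0 <= t - s < num_microbatches else None
         for s in range(num_stages)]
        for t in range(total_steps)
    ]

# With this PR's test config: ici_pipeline_parallelism=2, num_pipeline_microbatches=4.
sched = pipeline_schedule(2, 4)
assert len(sched) == 5        # 2 + 4 - 1 steps to fill and drain the pipeline
assert sched[0] == [0, None]  # warm-up: only stage 0 is busy
assert sched[4] == [None, 3]  # drain: only the last stage is busy
```

Each microbatch visits each stage exactly once, which is the invariant the rotation logic in the real classes maintains.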
Thank you @mesakhcienet. Let's go with option 1 please. That way we can continue running unit tests along the way, and we don't need to worry about the Linen/NNX versions diverging before the migration is fully done.
Update on the pipeline migration approach: after further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why: the current branch keeps both Linen (
Reasoning:
Plan to converge to Option 1: Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.
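Mechanically, the hybrid approach amounts to a config-gated dispatch between the two implementations. A minimal stdlib-only sketch; `build_pipeline` and the constructor arguments are hypothetical, though the `enable_nnx` flag does appear in this PR's test command:

```python
def build_pipeline(config, linen_ctor, nnx_ctor):
    """Choose the pipeline implementation from a config flag.

    Both constructors are stand-ins for the real Linen/NNX classes;
    only the dispatch logic is the point here.
    """
    if config.get("enable_nnx", False):
        return nnx_ctor()
    return linen_ctor()

# The flag defaults off, so existing users stay on the Linen path.
assert build_pipeline({}, lambda: "linen", lambda: "nnx") == "linen"
assert build_pipeline({"enable_nnx": True}, lambda: "linen", lambda: "nnx") == "nnx"
```

The trade-off discussed above is exactly this fork: the gate makes the migration opt-in, at the cost of two code paths to keep in sync until the Linen branch is deleted.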
Thanks @mesakhcienet. I feel we should push strongly to remove the Linen logic if possible. Even though the NNX and Linen versions are equivalent now, that could change if someone updates the Linen version and misses that the NNX one needs to be updated also. Is testing differences between Linen/NNX the main reason to fork these two? Maybe we could test before/after this commit if needed?
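One lightweight guard against the divergence described above is a parity test that feeds both implementations the same input and compares outputs numerically. A stdlib-only sketch; `linen_forward` and `nnx_forward` are stand-ins for calls into the real Linen and NNX pipelines:

```python
def linen_forward(xs):
    # Stand-in for the Linen pipeline's forward pass.
    return [2.0 * x + 1.0 for x in xs]

def nnx_forward(xs):
    # Stand-in for the NNX pipeline's forward pass.
    return [2.0 * x + 1.0 for x in xs]

def assert_parity(xs, tol=1e-6):
    """Fail loudly if the two implementations drift apart."""
    a, b = linen_forward(xs), nnx_forward(xs)
    assert len(a) == len(b)
    for va, vb in zip(a, b):
        assert abs(va - vb) <= tol, f"divergence: {va} vs {vb}"

assert_parity([0.5, 1.0, 2.0])
```

Run as a unit test, this catches the case where someone updates one implementation and forgets the other, for as long as both code paths exist.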
Description
Implement an NNX-based pipeline.
This PR extends PR #2831.
Main changes:
- NNXPipeline, an NNX-based pipeline class.

Tests
We run the pipeline process with the command below:

```shell
MODEL_NAME=llama2-7b python -m MaxText.train src/maxtext/configs/base.yml \
  run_name=pipeline_test_${MODEL_NAME}_nnx \
  base_output_directory=/dev/shm/pipeline_test_nnx \
  model_name=${MODEL_NAME} \
  dataset_type=synthetic \
  steps=15 \
  debug_sharding=true \
  per_device_batch_size=2 \
  max_target_length=32 \
  ici_pipeline_parallelism=2 \
  num_pipeline_microbatches=4 \
  num_layers_per_pipeline_stage=2 \
  enable_checkpointing=false \
  enable_nnx=true \
  pure_nnx_decoder=true \
  scan_layers_per_stage=false \
  async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1
```

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.