[DECOUPLED-MODE] Adding Decoupling Logic #2865
Merged
copybara-service[bot] merged 1 commit into AI-Hypercomputer:main on Feb 1, 2026
Conversation
shralex reviewed on Jan 20, 2026
Force-pushed from 886d8cf to d6c0251
shralex reviewed on Jan 29, 2026
Force-pushed from a394772 to cc35028
shralex approved these changes on Jan 31, 2026
- add decoupling logic config patch to tests
- add correct ICI parallelism to tests for decoupled mode
- fixing more UTs
- adding decoupling logic, biggest change
- add tensorboardX stub
- fixing train_tests
- removing tunix from decoupling logic
- renaming datasets to local_datasets to avoid confusion with HF datasets library
- make jax_remove_size_one_mesh_axis_from_type param setting in try block, todo: remove this after updating jax
- Configure ICI data parallelism for decoupled mode
- revert legacy rl trainer
- Update grpo_trainer.py
- adding conditional imports because pytest collect always imports before marks are used
- centralize decoupled dataset paths and base_output_directory
- skip packed attention if not on cuda sm90+
- parameterize test_env_smoke tests
- undo sft test decoupling changes as it is marked as external training
- move path logic to setup method and use if logic in train_smoke_test.py
- fix ref to dummy summary writer
- add yield to GOODPUT_STUB
- Add pytest marker for train_compile tests. These are actually requiring libtpu and should be TPU tests.
- fixed path for GrainArrayRecordBestFitPackingTest
- update local output
- adding refactoring for test_utils import
- moved local_datasets, refactored
- add missing import from flop_calculation test
- adding gcloud_stub_test.py
- fix gcloud_stub test, add cpu_only marker
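One commit above adds conditional imports because pytest collects (and therefore imports) test modules before skip marks are evaluated. Below is a minimal sketch of that pattern, assuming tensorboardX is the optional dependency being stubbed; the test body is illustrative only, not the actual MaxText change:

```python
import pytest

# Import lazily via importorskip so collection succeeds on machines without
# the optional dependency; the whole module is skipped instead of failing
# at import time.
tensorboardX = pytest.importorskip("tensorboardX")


def test_summary_writer_smoke(tmp_path):
  # Illustrative only: write one scalar through the (real or stubbed) writer.
  writer = tensorboardX.SummaryWriter(logdir=str(tmp_path))
  writer.add_scalar("loss", 0.5, global_step=1)
  writer.close()
```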
Force-pushed from bc822d8 to 1f56029
SurbhiJainUSC approved these changes on Jan 31, 2026
Merged commit 80762f9 into AI-Hypercomputer:main
32 of 33 checks passed
Description
This PR is the second part of decoupling support. It adds the decoupling logic itself, along with the test modifications needed for decoupled mode to be enabled.
Details:
Tests
Added two tests covering decoupled mode and gcloud_stub; a minimal sketch of the guard pattern they exercise appears after this list.
tests/unit/gcloud_stub_test.py:
- All stub function tests, with proper variable usage
- GCS utils integration tests (6 tests):
  - _gcs_guard behavior in different modes
  - write_config_raw_keys_for_gcs skipping logic
- Maxengine config integration tests (4 tests):
  - create_exp_maxengine device handling in decoupled and non-decoupled mode
  - get_server_config raising in decoupled mode
  - create_maxengine ignoring devices
- Maxengine server integration tests (4 tests):
  - Server startup guard in decoupled mode
  - Prefix caching config creation logic
All unit tests pass in decoupled mode.
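For context, here is a minimal, self-contained sketch of the skip-when-decoupled guard pattern these tests exercise. The stand-in `_gcs_guard` and `write_config_raw_keys_for_gcs` below only mirror the behavior described above (refuse or skip GCS access in decoupled mode); they are not the real MaxText implementations or signatures.

```python
import pytest


def _gcs_guard(decoupled: bool) -> None:
  # Stand-in: any GCS access should be refused when running decoupled.
  if decoupled:
    raise RuntimeError("GCS access is disabled in decoupled mode")


def write_config_raw_keys_for_gcs(raw_keys: dict, decoupled: bool) -> bool:
  # Stand-in: returns True only when the config upload actually happens.
  if decoupled:
    return False  # skip the upload entirely
  _gcs_guard(decoupled)
  return True


def test_gcs_guard_raises_in_decoupled_mode():
  with pytest.raises(RuntimeError):
    _gcs_guard(decoupled=True)


def test_write_config_skips_in_decoupled_mode():
  assert write_config_raw_keys_for_gcs({}, decoupled=True) is False
  assert write_config_raw_keys_for_gcs({}, decoupled=False) is True
```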
UT results:
== 306 passed, 170 skipped, 25 deselected, 6588 warnings in 975.16s (0:16:15) ==
Train test:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False
works.
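Related to the GPU run above, the commit list includes skipping packed attention when the GPU is older than CUDA sm90. Below is a sketch of how such a guard can be written with pytest and JAX; the helper name and threshold handling are assumptions rather than the exact MaxText code, and the check relies on recent JAX GPU devices exposing compute_capability as a string such as "9.0":

```python
import jax
import pytest


def _has_cuda_sm90_or_newer() -> bool:
  devices = jax.devices()
  if not devices or devices[0].platform != "gpu":
    return False
  # Recent JAX versions expose compute_capability (e.g. "9.0") on CUDA devices.
  capability = getattr(devices[0], "compute_capability", None)
  return capability is not None and float(capability) >= 9.0


# Reusable marker: packed-attention tests can be decorated with @requires_sm90.
requires_sm90 = pytest.mark.skipif(
    not _has_cuda_sm90_or_newer(), reason="packed attention needs CUDA sm90+"
)
```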
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.