Skip to content

[DECOUPLED-MODE] Adding Decoupling Logic#2865

Merged
copybara-service[bot] merged 1 commit intoAI-Hypercomputer:mainfrom
ROCm:rocm-main-pr
Feb 1, 2026
Merged

[DECOUPLED-MODE] Adding Decoupling Logic#2865
copybara-service[bot] merged 1 commit intoAI-Hypercomputer:mainfrom
ROCm:rocm-main-pr

Conversation

@gulsumgudukbay
Copy link
Copy Markdown
Collaborator

@gulsumgudukbay gulsumgudukbay commented Dec 21, 2025

Description

This PR is the second part of the decoupling support. It adds logic for decoupling support, along with some test modifications for decoupling to be enabled.

Details:

  1. Update decoupled_base_test.yml
  2. Add decoupling locig to src/MaxText/decode.py, src/MaxText/elastic_train.py, src/MaxText/experimental/rl/grpo_trainer.py, src/MaxText/gcp_workload_monitor.py, src/MaxText/max_utils.py, src/MaxText/maxengine.py, src/MaxText/maxengine_config.py, src/MaxText/maxengine_server.py, src/MaxText/metric_logger.py, src/MaxText/prefill_packing.py, src/MaxText/profiler.py, src/MaxText/sft/hooks.py, src/MaxText/sft/sft_trainer.py, src/MaxText/train.py, src/MaxText/utils/gcs_utils.py, src/MaxText/utils/goodput_utils.py, src/MaxText/vertex_tensorboard.py
  3. Update src/MaxText/gcloud_stub.py to add IS_STUB variables, and add google_cloud_mldiagnostics stub
  4. Update tests to support decoupled mode (add markers, update file paths, make them use decoupled_base_test.yml config file).

Tests

Added two tests to test decoupled mode and gcloud_stub.

tests/unit/gcloud_stub_test.py:

  1. All stub function tests with proper variable usage

  2. GCS utils integration tests (6 tests):
    _gcs_guard behavior in different modes
    write_config_raw_keys_for_gcs skipping logic

  3. Maxengine config integration tests (4 tests):
    create_exp_maxengine device handling in decoupled/non-decoupled mode
    get_server_config raising in decoupled mode
    create_maxengine ignoring devices

  4. Maxengine server integration tests (4 tests):
    Server startup guard in decoupled mode
    Prefix caching config creation logic

All unit tests pass in decoupled mode.
UT results:
== 306 passed, 170 skipped, 25 deselected, 6588 warnings in 975.16s (0:16:15) ==

Train test:
python -m MaxText.train MaxText/configs/base.yml run_name=test hardware=gpu steps=5 model_name=llama2-7b attention=cudnn_flash_te enable_checkpointing=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 ici_data_parallelism=1 remat_policy=minimal scan_layers=True dataset_type=synthetic logits_dot_in_fp32=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=2048 shardy=False

works.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Comment thread tests/integration/smoke/train_using_ragged_dot_smoke_test.py
Comment thread tests/train_int8_smoke_test.py Outdated
Comment thread tests/test_env_smoke.py Outdated
Comment thread tests/test_env_smoke.py Outdated
Comment thread tests/sft_hooks_test.py Outdated
Comment thread src/MaxText/sft/sft_trainer.py Outdated
Comment thread src/MaxText/sft/hooks.py Outdated
Comment thread tests/integration_tests/grpo_correctness.py Outdated
Comment thread tests/grpo_trainer_correctness_test.py Outdated
Comment thread src/MaxText/configs/decoupled_base_test.yml Outdated
Comment thread src/MaxText/configs/decoupled_base_test.yml Outdated
Comment thread src/MaxText/configs/decoupled_base_test.yml Outdated
Comment thread src/maxtext/common/gcloud_stub.py
@gulsumgudukbay gulsumgudukbay force-pushed the rocm-main-pr branch 7 times, most recently from 886d8cf to d6c0251 Compare January 29, 2026 15:27
Comment thread tests/utils/test_utils.py Outdated
Comment thread tests/utils/__init__.py
Comment thread tests/utils/test_helpers.py
Comment thread tests/unit/test_env_smoke.py Outdated
Comment thread src/maxtext/common/gcloud_stub.py
Comment thread src/MaxText/max_utils.py Outdated
@gulsumgudukbay gulsumgudukbay force-pushed the rocm-main-pr branch 4 times, most recently from a394772 to cc35028 Compare January 31, 2026 07:31
add decoupling logic config patch to tests
add correct ICI parallelism to tests for decoupled mode
fixing more UTs
adding decoupling logic, biggest change
add tensorboardX stub
fixing train_tests
removing tunix from decoupling logic
renaming datasets to local_datasets to avoid confusion with HF datasets library
make jax_remove_size_one_mesh_axis_from_type param setting in try block, todo: remove this after updating jax. Configure ICI data parallelism for decoupled mode
revert legacy rl trainer
Update grpo_trainer.py
adding conditional imports because pytest collect always imports before marks are used
centralize decoupled dataset paths and base_output_directory
skip packed attention if  not on cuda sm90+
parameterize test_env_smoke tests
undo sft test decoupling changes as it is marked as external training
move path logic to setup method and use if logic in train_smoke_test.py
fix ref to dummy summary writer
add yield to GOODPUT_STUB
Add pytest marker for train_compile tests
These are actually requiring libtpu and should be TPU tests.
fixed path for GrainArrayRecordBestFitPackingTest
update local output
adding refactoring for test_utils import
moved local_datasets, refactored
add missing import from flop_calculation test
adding gcloud_stub_test.py
fix gcloud_stub test, add cpu_only marker
@copybara-service copybara-service Bot merged commit 80762f9 into AI-Hypercomputer:main Feb 1, 2026
32 of 33 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants