Merge release/v26.2 back into main#635
Conversation
…32B Configs for MI300X & MI355X (#556)

YF: Only SFT related config and Doc changes, bypassing unit CI tests

## Summary

This PR introduces post-training documentation and updates Qwen3 32B model configuration files to support AMD MI300X and MI355X accelerators.

---

## Changes

### 📘 Documentation

- **Added `posttraining.md`**
  - New comprehensive guide for post-training workflows
  - Covers setup instructions, configuration details, and usage examples
- **Updated `docs/README.md`**
  - Added a new section referencing the post-training documentation
  - Improved documentation organization and navigation

---

### ⚙️ Configuration Updates

- **Updated Qwen3_32B model YAML configs**
  - Added/modified configurations optimized for MI300X and MI355X
  - Adjusted parameters for compatibility and stable execution

---

## Validation

- Verified updated configs load and execute successfully on MI300X and MI355X environments
- Confirmed documentation links and structure render correctly

---

## Checklist

- [x] Added `posttraining.md`
- [x] Updated `docs/README.md`
- [x] Modified Qwen3_32B YAML configs
- [x] Verified changes locally
Co-authored-by: Mingyu Yang <Mingyu.Yang@amd.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Kailash Gogineni <gkailashnath1998@gmail.com>
Co-authored-by: HuangWei-95 <Wei.Huang4@amd.com>
Co-authored-by: HuangWei-95 <weihuan@amd.com>
Co-authored-by: Xiaoming-AMD <Xiaoming.Peng@amd.com>
Co-authored-by: WangLingxun <linxwang@amd.com>
…578) Expand projection.md with memory projection and performance details.
…581) Hook Megatron validate_args alongside parse_args so Primus-injected arguments are validated consistently, and run additional ROCm-specific argument checks during initialization.
Pull request overview
This PR merges release/v26.2 back into main, bringing in Megatron-side Mamba/hybrid model support, a ROCm-specific Mamba correctness patch, and a broad set of config additions/updates for new models and tuned example workloads.
Changes:
- Add Mamba-aware model-provider resolution and select GPT vs Mamba pretrain entrypoints based on model_type.
- Introduce a hybrid Mamba+MLA stack implementation/spec and a ROCm patch for Mamba's Triton backward path.
- Add new Mamba/Zebra model configs and update many example configs (fusion flags, batch sizes, GC frequency, DeepSeek turbo flags).
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| primus/tools/preflight/network/network_probe.py | Minor formatting/clarity changes in distributed-intent detection. |
| primus/modules/trainer/megatron/trainer.py | Adds defaults/logic for model_type=mamba model-provider selection and extra logging. |
| primus/modules/trainer/megatron/pre_trainer.py | Adjusts forward path for Mamba vs GPT (schedule-plan gating, loss_mask handling). |
| primus/core/utils/import_utils.py | Extends get_model_provider() to support model_type (GPT vs Mamba). |
| primus/configs/modules/megatron/trainer_base.yaml | Comments out hybrid/mamba config blocks; removes is_hybrid_model from trainer base. |
| primus/configs/models/megatron/zebra_llama_8B.yaml | Adds Zebra Llama 8B hybrid-Mamba model config. |
| primus/configs/models/megatron/zebra_llama_3B.yaml | Adds Zebra Llama 3B hybrid-Mamba model config. |
| primus/configs/models/megatron/zebra_llama_1B.yaml | Adds Zebra Llama 1B hybrid-Mamba model config. |
| primus/configs/models/megatron/mamba_base.yaml | Adds shared base config for Mamba models. |
| primus/configs/models/megatron/mamba_370M.yaml | Adds a Mamba 370M model config. |
| primus/configs/models/megatron/language_model.yaml | Adds model_type and original_max_position_embeddings; removes MTP keys from this base. |
| primus/configs/models/megatron/hybrid_model_base.yaml | Adds a hybrid-model base (mamba + hybrid ratios + MTP keys). |
| primus/configs/models/megatron/deepseek_v3_base.yaml | Adds original_max_position_embeddings metadata. |
| primus/configs/models/megatron/deepseek_v2_base.yaml | Adds original_max_position_embeddings metadata. |
| primus/backends/megatron/patches/mamba_rocm_patches.py | New ROCm-only patch to disable Triton buffer_ops for a Mamba backward path. |
| primus/backends/megatron/megatron_pretrain_trainer.py | Selects GPT vs Mamba pretrain modules/provider based on backend_args.model_type. |
| primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py | New ModuleSpec wiring for hybrid Mamba+MLA stack. |
| primus/backends/megatron/core/models/hybrid/hybrid_block.py | New HybridStack implementation (layer allocation, forward, sharded state). |
| primus/backends/megatron/core/models/hybrid/__init__.py | Exposes hybrid_stack_spec for import by configs. |
| examples/torchtitan/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml | Adds gc_freq in torchtitan training section. |
| examples/torchtitan/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml | Adds gc_freq in torchtitan training section. |
| examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml | Updates batch sizing/steps and adds gc_freq. |
| examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml | Updates batch sizing/steps and adds gc_freq. |
| examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml | Adds gc_freq in torchtitan training section. |
| examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml | Adds gc_freq in torchtitan training section. |
| examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml | Updates batch sizing and adds gc_freq. |
| examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml | Updates batch sizing and adds gc_freq. |
| examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml | New example pretrain config for Zebra Llama 8B hybrid spec. |
| examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml | New example pretrain config for Zebra Llama 3B hybrid spec. |
| examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml | New example pretrain config for Zebra Llama 1B hybrid spec. |
| examples/megatron/configs/MI355X/qwen2.5_7B-FP8-pretrain.yaml | Enables grad-accum fusion and cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/qwen2.5_7B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/qwen2.5_72B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/qwen2.5_72B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3_8B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3_8B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3_70B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3_70B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3.3_70B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3.3_70B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml | Updates batch sizes; enables grad-accum fusion and cross-entropy TE fusion. |
| examples/megatron/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml | Updates batch sizes; enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml | Adjusts batch/recompute; enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama2_7B-FP8-pretrain.yaml | Enables grad-accum fusion and cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama2_7B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama2_70B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/llama2_70B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/deepseek_v3-FP8-pretrain.yaml | Adds deepep/turbo-related flags; enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/deepseek_v3-BF16-pretrain.yaml | Adds turbo/deepep flags and MoE fusion toggles; enables cross-entropy TE fusion. |
| examples/megatron/configs/MI355X/deepseek_v2_lite-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/deepseek_v2_lite-BF16-pretrain.yaml | Flips grouped-gemm legacy flag; adds MoE fusion toggles; enables cross-entropy TE fusion. |
| examples/megatron/configs/MI355X/deepseek_v2-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI355X/deepseek_v2-BF16-pretrain.yaml | Adds MoE fusion toggles; enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml | New example pretrain config for Zebra Llama 8B hybrid spec (MI300X). |
| examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml | New example pretrain config for Zebra Llama 3B hybrid spec (MI300X). |
| examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml | New example pretrain config for Zebra Llama 1B hybrid spec (MI300X). |
| examples/megatron/configs/MI300X/qwen2.5_7B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/qwen2.5_7B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/qwen2.5_72B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/qwen2.5_72B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml | New example pretrain config for Mamba 370M. |
| examples/megatron/configs/MI300X/llama3_8B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3_8B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3_70B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3_70B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama2_70B-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/deepseek_v3-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/deepseek_v3-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/deepseek_v2_lite-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml | Adds MoE fusion toggles; enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/deepseek_v2-FP8-pretrain.yaml | Enables cross-entropy TE fusion flags. |
| examples/megatron/configs/MI300X/deepseek_v2-BF16-pretrain.yaml | Enables cross-entropy TE fusion flags. |
The decision to include/exclude loss_mask is based on the model class name containing "Mamba". This is a fragile proxy for API compatibility and can lead to incorrect calls (or missed optimization paths) if the class name changes or a wrapper alters it. Prefer checking support for the loss_mask argument directly (signature inspection or a safe call-and-fallback).
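A minimal sketch of the signature-inspection alternative; the helper name and the commented call site are hypothetical, not the trainer's actual code:

```python
import inspect


def forward_accepts_loss_mask(model) -> bool:
    """Return True if the (unwrapped) model's forward takes a loss_mask argument."""
    # Unwrap common containers (DDP, Float16Module, ...) before inspecting.
    unwrapped = getattr(model, "module", model)
    try:
        params = inspect.signature(unwrapped.forward).parameters
    except (TypeError, ValueError):
        return False
    return "loss_mask" in params


# Hypothetical call site: only pass loss_mask when the forward signature supports it.
# extra_kwargs = {"loss_mask": loss_mask} if forward_accepts_loss_mask(model) else {}
# output = model(tokens, position_ids, attention_mask, labels=labels, **extra_kwargs)
```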
model_type is treated as a free-form string: any value other than exactly "mamba" silently falls back to the GPT import/provider path. This makes misconfiguration hard to detect and can run the wrong training stack. Consider validating model_type against an allowed set (e.g., {"gpt","mamba"}) and raising a clear ValueError for unknown values.
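A minimal validation sketch of the allowed-set idea; the helper name and the default value are assumptions, not the repository's actual API:

```python
_SUPPORTED_MODEL_TYPES = {"gpt", "mamba"}


def resolve_model_type(model_type) -> str:
    """Normalize model_type and fail loudly on unsupported values."""
    if model_type is None:
        return "gpt"  # assumed default; only applies when the field is truly absent
    normalized = str(model_type).strip().lower()
    if normalized not in _SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"Unsupported model_type={model_type!r}; expected one of {sorted(_SUPPORTED_MODEL_TYPES)}"
        )
    return normalized
```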
Unknown model_type values currently fall through to the GPT training components path. This can hide configuration errors and run the wrong pretrain entrypoint. Consider validating model_type and raising an error for unsupported values (or explicitly mapping supported types).
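For the entrypoint selection specifically, an explicit mapping keeps unknown values loud instead of silently falling back to GPT; the module and provider names below are placeholders, not the repository's real import paths:

```python
# Hypothetical registry: model_type -> (pretrain module, model-provider attribute).
_PRETRAIN_ENTRYPOINTS = {
    "gpt": ("pretrain_gpt", "gpt_model_provider"),
    "mamba": ("pretrain_mamba", "mamba_model_provider"),
}


def select_pretrain_entrypoint(model_type: str):
    try:
        return _PRETRAIN_ENTRYPOINTS[model_type]
    except KeyError:
        raise ValueError(
            f"No pretrain entrypoint registered for model_type={model_type!r}; "
            f"known types: {sorted(_PRETRAIN_ENTRYPOINTS)}"
        ) from None
```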
The header comment says "Zebra Llama 8B configuration" but this file is the 3B config; please update the comment to match the model size to avoid confusion.
The header comment says "Zebra Llama 8B configuration" but this file is the 1B config; please update the comment to match the model size to avoid confusion.
allocate_layers() will raise ZeroDivisionError when hybrid_attention_ratio is 0 (or small enough that num_attention_layers becomes 0), because num_mamba_per_attention_layer = num_mamba_layers // num_attention_layers and the subsequent modulo also divides by num_attention_layers. Please add a guard for num_attention_layers == 0 (and potentially == num_layers//2) to produce a valid all-Mamba/all-attention layout without crashing.
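A hedged sketch of the guard being requested; the variable names mirror the comment, but the real allocate_layers() interleaving logic is more involved:

```python
def allocate_layers_sketch(num_layers: int, hybrid_attention_ratio: float):
    """Return a per-layer type list, handling the all-Mamba/all-attention edge cases."""
    num_attention_layers = round(num_layers * hybrid_attention_ratio)
    num_mamba_layers = num_layers - num_attention_layers

    # Guard the degenerate layouts so we never divide (or take a modulo) by zero.
    if num_attention_layers == 0:
        return ["mamba"] * num_layers
    if num_mamba_layers == 0:
        return ["attention"] * num_layers

    # Interleave: one attention layer after each group of Mamba layers.
    num_mamba_per_attention_layer = num_mamba_layers // num_attention_layers
    remainder = num_mamba_layers % num_attention_layers
    layout = []
    for i in range(num_attention_layers):
        group = num_mamba_per_attention_layer + (1 if i < remainder else 0)
        layout.extend(["mamba"] * group)
        layout.append("attention")
    return layout
```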
Using print() here will emit from every rank and bypass the project logging controls; this can spam stdout in multi-rank runs. Prefer log_rank_0(...) (or a logger) and consider gating behind a debug flag.
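A minimal rank-0-gated alternative; log_rank_0 is the helper the comment refers to, while the fallback below uses torch.distributed directly and is only a sketch:

```python
import torch.distributed as dist


def debug_log_rank_0(msg: str, enabled: bool = False) -> None:
    """Print only on rank 0, and only when explicitly enabled."""
    if not enabled:
        return
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    print(msg, flush=True)
```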
```python
def _is_rocm(ctx: PatchContext) -> bool:
    """Return True when running on an AMD ROCm platform."""
    return getattr(torch.version, "hip", None) is not None


def _make_triton_wrapper(original_fn):
    from triton import knobs as _triton_knobs

    def _chunk_state_bwd_db_no_buffer_ops(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
        with _triton_knobs.amd.scope():
            _triton_knobs.amd.use_buffer_ops = False
            return original_fn(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)

    return _chunk_state_bwd_db_no_buffer_ops


@register_patch(
    "megatron.mamba.rocm_chunk_state_bwd_db",
    backend="megatron",
    phase="before_train",
    description=(
        "Disable Triton buffer_ops in Mamba _chunk_state_bwd_db backward pass "
        "to work around ROCm-specific correctness issues."
    ),
    condition=_is_rocm,
    tags=["rocm", "mamba"],
)
def patch_mamba_rocm_chunk_state_bwd_db(ctx: PatchContext):
    """
    Patch mamba_ssm _chunk_state_bwd_db to disable Triton buffer_ops on ROCm.

    The Triton buffer_ops feature can cause correctness issues on AMD GPUs
    during the backward pass of the Mamba chunk state computation. This patch
    wraps the original function to set use_buffer_ops = False within an AMD
    Triton knobs scope.

    Both ``ssd_chunk_state`` (definition) and ``ssd_combined`` (import-time
    binding) module namespaces are patched so every call-site picks up the
    wrapper regardless of which module it was imported from.
    """
    import mamba_ssm.ops.triton.ssd_chunk_state as ssd_chunk_state
    import mamba_ssm.ops.triton.ssd_combined as ssd_combined
```
This patch is gated only on ROCm, but the handler unconditionally imports mamba_ssm (and _make_triton_wrapper imports triton). On ROCm systems running non-Mamba workloads or without these optional deps installed, applying patches will raise ImportError before training starts. Please tighten the condition (e.g., also check importlib.util.find_spec(...)) and/or make the handler a no-op with a log when the modules are unavailable.
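A sketch of the tightened condition; the PatchContext argument is kept to match the excerpt above, and the dependency probing is an assumption about one way to do it:

```python
import importlib.util

import torch


def _is_rocm_with_mamba_deps(ctx) -> bool:
    """Gate the patch on ROCm *and* on the optional Triton/mamba_ssm packages being importable."""
    on_rocm = getattr(torch.version, "hip", None) is not None
    deps_present = (
        importlib.util.find_spec("triton") is not None
        and importlib.util.find_spec("mamba_ssm") is not None
    )
    return on_rocm and deps_present
```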
Model-type detection via unwrapping .module and substring matching on the class name ("Mamba" in model_class_name) is brittle (can break with wrappers, renames, or hybrid classes). Prefer using a reliable signal like get_args().model_type, a model attribute, or capability detection (e.g., try calling with loss_mask and catch TypeError, or inspect the forward signature) so this doesn’t mis-detect and fail at runtime.
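A call-and-fallback variant of the same idea; the arguments are illustrative, and the trainer's actual forward call will differ:

```python
def forward_with_optional_loss_mask(model, tokens, position_ids, attention_mask, labels, loss_mask):
    """Prefer the loss_mask-aware signature; fall back if the model doesn't accept it."""
    try:
        return model(tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask)
    except TypeError as err:
        # Only swallow the signature mismatch, not genuine TypeErrors raised inside forward.
        if "loss_mask" not in str(err):
            raise
        return model(tokens, position_ids, attention_mask, labels=labels)
```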
@wenxie-amd can you help take a look at this and merge it? Nothing changes, just merging release/v26.2 back into main.
No description provided.