fix mimo optimizer checkpoint metadata restore#4791
Draft
liding-nv wants to merge 1 commit into
Draft
Conversation
Signed-off-by: Li Ding <liding@nvidia.com>
Author
|
@yashaswikarnati can u take a look? not sure if this PR solves the same problem as #4801 |
5 tasks
yashaswikarnati
added a commit
to yashaswikarnati/Megatron-LM
that referenced
this pull request
May 15, 2026
…ng loop (#26) * Add distributed-checkpoint save/load to the hetero MIMO training loop Adds the standalone `examples/mimo/training/hetero/checkpointing.py` module plus the CLI surface and loop wiring needed to round-trip MimoModel, MimoOptimizer (ChainedOptimizer-of-DistributedOptimizers in the MoE recipe) and the LR/WD scheduler through `megatron.core.dist_checkpointing` without depending on the `parallel_state` singleton. Layout stays compatible with `megatron/training/checkpointing.py` output: `<save>/latest_checkpointed_iteration.txt` plus per-iteration directories containing `common.pt`, `metadata.json`, `.metadata`, and torch_dist shards. Common state now carries `args`, `checkpoint_version=3.0`, the LR scheduler state, and a per-branch `mimo.{branch}.rng_state` ShardedObject; the tracker read uses a cross-rank MAX reduce to mirror megatron's `read_metadata`. Fixes three pre-existing dist-ckpt bugs that hetero usage uncovered: - `megatron/core/ssm/mamba_mixer.py` was calling `make_sharded_tensors_for_checkpoint` without passing `tp_group` and `dp_cp_group`, which fell back to the parallel_state singleton and asserted in hetero mode (gated_delta_net was already correct). - `MimoOptimizer.sharded_state_dict` now applies `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` to each per-branch optimizer sub-dict so two modules' identical internal ShardedObject keys (e.g. `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide. - `_get_replica_id` now folds in `tp_rank` so two TP ranks within DP=0 don't both claim primary writer for the same shard. Also routes DistributedOptimizer's per-module `param_state_sharding_type` config string through a new ShardedObject (`_extract_*` helpers) so the non-rank-0 module owner doesn't lose it when only rank 0's common.pt is authoritative. A `_propagate_tp_groups_for_checkpoint` walker stamps `self.tp_group` on descendants that omit it (e.g. `ExtendedRMSNorm`, RADIO submodules) so the default `MegatronModule.sharded_state_dict` path doesn't fall through to `parallel_state.get_tensor_model_parallel_group`. Validated end-to-end on cw-dfw 8-GPU 20L mock (stage2): - Save iter 3 (DistributedOptimizer + EP=4 + TP=2 + 2-module Chained) - Reload iter 3 → resume at iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4), losses match prior trajectory. New flags: `--save`, `--load`, `--save-interval`, `--no-save-optim`, `--no-load-optim`, `--no-load-scheduler`, `--no-save-rng`, `--no-load-rng`, `--finetune`, `--dist-ckpt-optim-fully-reshardable`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Adopt MimoOptimizer checkpoint patterns from NVIDIA#4801 Three convergent simplifications to MimoOptimizer's distributed-checkpoint path, matching Kamran's MimoOptimizer fixes PR: 1. Replace the ShardedObject-based round-trip for `param_state_sharding_type` with a metadata stash. The sharding type is not per-rank state — it's a load-time interpretation hint that the caller supplies via the `metadata` kwarg on `sharded_state_dict()`. We stash that metadata in `self._last_sharded_metadata` at save and re-inject the sharding type into each per-module sub state-dict during `load_state_dict()` for ranks that lost it via dist_checkpointing's common-state path (i.e. non-rank-0 module owners in non-colocated layouts). Drops `_extract_param_state_sharding_type` / `_restore_param_state_sharding_type` along with their ShardedObject keys. 2. `_restore_param_groups` now uses `setdefault('optimizer', {})` before writing back `param_groups`. After `_extract_param_groups` deletes `param_groups` at save time, the leftover empty `'optimizer'` dict can be dropped by the common-state round-trip on ranks whose active module wasn't on rank 0 at save. The setdefault makes the restore path tolerant of that drop. 3. `_get_replica_id` reorders to `(tp_rank, pp_rank, dp_rank)` to match the convention used by `make_sharded_object_for_checkpoint` in `megatron/core/transformer/utils.py:168-172`. Dedup math is unchanged — `(0, 0, 0)` is still the primary replica — but the order is now consistent with the rest of the codebase. Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2): save iter 3, reload, resume iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4) and matching loss trajectory. Save exit 0, load exit 0. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Document _propagate_tp_groups_for_checkpoint escapees from no-walker run Disabled `_propagate_tp_groups_for_checkpoint` and re-ran the 20L mock to enumerate exactly which modules fall through to `parallel_state.get_tensor_model_parallel_group()` and assert. Confirmed both branches escape: - RADIO encoder internals (first failure, reached via `nemotron_moe_vlm.RadioEncoder.sharded_state_dict` → HF radio_model leaves with no tp_group + no own sharded_state_dict). - `MambaLayer.__init__` in `megatron/core/ssm/mamba_layer.py` plumbs pg_collection to the mixer but never sets `self.tp_group`. - `ExtendedRMSNorm` at `megatron/core/ssm/mamba_mixer.py:93` never sees pg_collection at all. Fixing each at the source would mean patches across core (Mamba) plus a partial walk of RADIO's HF wrapper, validated against all existing non-hetero users of those modules. The walker is the smaller intervention: one place, hasattr-guarded, applied per branch with the correct pg. Re-enables the walker (it was already in PR1; this commit only updates the docstring to record the experiment's findings). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Switch back to ShardedObject for param_state_sharding_type (NVIDIA#4791) Kamran reverted the metadata-stash approach in NVIDIA#4801 (discussion r3250847203) and adopted Li Ding's PR NVIDIA#4791 pattern, which is the same ShardedObject round-trip we had originally. Align our MimoOptimizer with that final shape: - Restore `_extract_param_state_sharding_type` / `_restore_param_state_sharding_type` helpers. Hooks back into the existing `_iter_optimizer_sub_dicts` loop. - Add `if not opt_sub: del sub_sd['optimizer']` to `_extract_param_groups` (from NVIDIA#4791) so the now-empty `'optimizer'` wrapper doesn't round-trip through common-state with undefined behavior on the load side. - Drop `self._last_sharded_metadata` and the metadata-stash recover path from `load_state_dict` / `sharded_state_dict`. The ShardedObject route is self-contained and doesn't need caller-state coupling. Kept (not in NVIDIA#4791, specific to our non-colocated hetero layout): - `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` so the two branches' identical inner ShardedObject keys (e.g. `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide. - `_get_replica_id` returning `(tp_rank, pp_rank, dp_rank)` (from NVIDIA#4801). Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2): save iter 3 exit 0, reload + resume iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4), matching loss trajectory across the boundary. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Drop _stamp_tp_group walker; fix the three constructors at the source Three modules under our hetero save path don't store `self.tp_group` in their constructors and therefore trip `MegatronModule.sharded_state_dict`'s parallel_state fallback (`megatron/core/transformer/module.py:85`) in heterogeneous-parallelism layouts where parallel_state is intentionally not initialized. Fix them at the source instead of papering over with the hasattr-guarded walker: - `megatron/core/models/vision/radio.py:RADIOViTModel.__init__` — already extracts `tp_group` at line 129 for the embedder; now also stamps `self.tp_group = tp_group`. - `megatron/core/ssm/mamba_layer.py:MambaLayer.__init__` — takes pg_collection and plumbs it into the mixer; now also stores `self.tp_group = pg_collection.tp` on the layer itself. - `megatron/core/ssm/mamba_mixer.py:ExtendedRMSNorm` — adds an `__init__(*args, tp_group=None, **kwargs)` override that stores `self.tp_group` eagerly, and updates the single call site at line ~369 to pass `tp_group=self.pg_collection.tp`. The lazy `hasattr` fallback inside `sharded_state_dict` is preserved for callers that don't pass tp_group. With these three constructor fixes in place, the `_propagate_tp_groups_for_checkpoint` walker (and `_stamp_tp_group` helper) in `examples/mimo/training/hetero/runtime.py` is no longer needed. Removed entirely. Validated on cw-dfw 1-node 8-GPU 20L mock with the walker disabled: - save iter 3 exit 0 (DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2) - reload iter 3 → resume iter 4-5 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4), exit 0 - losses match prior runs (iter 1: 12.187, iter 2: 12.190, iter 3: 12.177, resume iter 4: 11.817, iter 5: 11.264) The downstream check `if not hasattr(self, 'tp_group')` in subsequent descendants (TransformerBlock, TransformerLayer, Attention, MLP, ColumnParallelLinear) was already satisfied by their own constructors; verified by reading those files. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kamran-nvidia
added a commit
to kamran-nvidia/Megatron-LM
that referenced
this pull request
May 19, 2026
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
Distributed-checkpoint load for
MimoOptimizercan return the sharded tensor state plus the extracted MIMO per-module param-groups but without the original nested "optimizer" common-state wrapper. The inner optimizer'sload_state_dictthen KeyErrors on the missing wrapper. Separately,param_state_sharding_typeis not extracted/restored across save+load — a real divergence for non-colocated rank module ownership, where rank 0 may own the language module while rank 1 owns the vision module.Changes
Save side: extract
param_state_sharding_typeinto a ShardedObject keyedoptimizer.mimo.<module>.<suffix>.param_state_sharding_typeso it round-trips through DistCkpt.Load side: restore
param_state_sharding_typefrom the_mimo_param_state_sharding_type*key, and reconstruct{'optimizer': {'param_groups': ...}}via setdefault so the inner optimizer'sload_state_dictsees the keys it expects.Contribution process
Pre-checks
Code review
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
.github/CODEOWNERS.Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change
megatron/core, once all expert reviewers have approved, theFinal Reviewlabel is applied automatically and final reviewers are assigned.For PRs outside
megatron/core, this step is skipped.Step 3: Approved
Once all required reviewers have approved, the
Approvedlabel is applied automatically.Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for `dev` branch is under active discussion.MRs are mergable after one approval by either
eharper@nvidia.comorzijiey@nvidia.com.