Skip to content

fix mimo optimizer checkpoint metadata restore#4791

Draft
liding-nv wants to merge 1 commit into
NVIDIA:mainfrom
liding-nv:mimo-ckpt-metadata
Draft

fix mimo optimizer checkpoint metadata restore#4791
liding-nv wants to merge 1 commit into
NVIDIA:mainfrom
liding-nv:mimo-ckpt-metadata

Conversation

@liding-nv
Copy link
Copy Markdown

Problem

Distributed-checkpoint load for MimoOptimizer can return the sharded tensor state plus the extracted MIMO per-module param-groups but without the original nested "optimizer" common-state wrapper. The inner optimizer's load_state_dict then KeyErrors on the missing wrapper. Separately, param_state_sharding_type is not extracted/restored across save+load — a real divergence for non-colocated rank module ownership, where rank 0 may own the language module while rank 1 owns the vision module.

Changes

Save side: extract param_state_sharding_type into a ShardedObject keyed optimizer.mimo.<module>.<suffix>.param_state_sharding_type so it round-trips through DistCkpt.

Load side: restore param_state_sharding_type from the _mimo_param_state_sharding_type* key, and reconstruct {'optimizer': {'param_groups': ...}} via setdefault so the inner optimizer's load_state_dict sees the keys it expects.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Signed-off-by: Li Ding <liding@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@liding-nv
Copy link
Copy Markdown
Author

@yashaswikarnati can u take a look? not sure if this PR solves the same problem as #4801

yashaswikarnati added a commit to yashaswikarnati/Megatron-LM that referenced this pull request May 15, 2026
…ng loop (#26)

* Add distributed-checkpoint save/load to the hetero MIMO training loop

Adds the standalone `examples/mimo/training/hetero/checkpointing.py` module
plus the CLI surface and loop wiring needed to round-trip MimoModel,
MimoOptimizer (ChainedOptimizer-of-DistributedOptimizers in the MoE recipe)
and the LR/WD scheduler through `megatron.core.dist_checkpointing` without
depending on the `parallel_state` singleton.

Layout stays compatible with `megatron/training/checkpointing.py` output:
`<save>/latest_checkpointed_iteration.txt` plus per-iteration directories
containing `common.pt`, `metadata.json`, `.metadata`, and torch_dist shards.
Common state now carries `args`, `checkpoint_version=3.0`, the LR scheduler
state, and a per-branch `mimo.{branch}.rng_state` ShardedObject; the tracker
read uses a cross-rank MAX reduce to mirror megatron's `read_metadata`.

Fixes three pre-existing dist-ckpt bugs that hetero usage uncovered:
- `megatron/core/ssm/mamba_mixer.py` was calling
  `make_sharded_tensors_for_checkpoint` without passing `tp_group` and
  `dp_cp_group`, which fell back to the parallel_state singleton and
  asserted in hetero mode (gated_delta_net was already correct).
- `MimoOptimizer.sharded_state_dict` now applies
  `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` to each per-branch
  optimizer sub-dict so two modules' identical internal ShardedObject keys
  (e.g. `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide.
- `_get_replica_id` now folds in `tp_rank` so two TP ranks within DP=0
  don't both claim primary writer for the same shard.

Also routes DistributedOptimizer's per-module `param_state_sharding_type`
config string through a new ShardedObject (`_extract_*` helpers) so the
non-rank-0 module owner doesn't lose it when only rank 0's common.pt is
authoritative.

A `_propagate_tp_groups_for_checkpoint` walker stamps `self.tp_group` on
descendants that omit it (e.g. `ExtendedRMSNorm`, RADIO submodules) so the
default `MegatronModule.sharded_state_dict` path doesn't fall through to
`parallel_state.get_tensor_model_parallel_group`.

Validated end-to-end on cw-dfw 8-GPU 20L mock (stage2):
- Save iter 3 (DistributedOptimizer + EP=4 + TP=2 + 2-module Chained)
- Reload iter 3 → resume at iter 4 with cosine LR continuation
  (1.59e-4 → 1.32e-4 → 1.01e-4), losses match prior trajectory.

New flags: `--save`, `--load`, `--save-interval`, `--no-save-optim`,
`--no-load-optim`, `--no-load-scheduler`, `--no-save-rng`, `--no-load-rng`,
`--finetune`, `--dist-ckpt-optim-fully-reshardable`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Adopt MimoOptimizer checkpoint patterns from NVIDIA#4801

Three convergent simplifications to MimoOptimizer's distributed-checkpoint
path, matching Kamran's MimoOptimizer fixes PR:

1. Replace the ShardedObject-based round-trip for
   `param_state_sharding_type` with a metadata stash. The sharding type is
   not per-rank state — it's a load-time interpretation hint that the caller
   supplies via the `metadata` kwarg on `sharded_state_dict()`. We stash that
   metadata in `self._last_sharded_metadata` at save and re-inject the
   sharding type into each per-module sub state-dict during
   `load_state_dict()` for ranks that lost it via dist_checkpointing's
   common-state path (i.e. non-rank-0 module owners in non-colocated
   layouts). Drops `_extract_param_state_sharding_type` /
   `_restore_param_state_sharding_type` along with their ShardedObject keys.

2. `_restore_param_groups` now uses `setdefault('optimizer', {})` before
   writing back `param_groups`. After `_extract_param_groups` deletes
   `param_groups` at save time, the leftover empty `'optimizer'` dict can be
   dropped by the common-state round-trip on ranks whose active module
   wasn't on rank 0 at save. The setdefault makes the restore path tolerant
   of that drop.

3. `_get_replica_id` reorders to `(tp_rank, pp_rank, dp_rank)` to match the
   convention used by `make_sharded_object_for_checkpoint` in
   `megatron/core/transformer/utils.py:168-172`. Dedup math is unchanged —
   `(0, 0, 0)` is still the primary replica — but the order is now
   consistent with the rest of the codebase.

Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer +
ChainedOptimizer + EP=4 + TP=2): save iter 3, reload, resume iter 4 with
cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4) and matching loss
trajectory. Save exit 0, load exit 0.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Document _propagate_tp_groups_for_checkpoint escapees from no-walker run

Disabled `_propagate_tp_groups_for_checkpoint` and re-ran the 20L mock
to enumerate exactly which modules fall through to
`parallel_state.get_tensor_model_parallel_group()` and assert. Confirmed
both branches escape:

- RADIO encoder internals (first failure, reached via
  `nemotron_moe_vlm.RadioEncoder.sharded_state_dict` → HF radio_model
  leaves with no tp_group + no own sharded_state_dict).
- `MambaLayer.__init__` in `megatron/core/ssm/mamba_layer.py` plumbs
  pg_collection to the mixer but never sets `self.tp_group`.
- `ExtendedRMSNorm` at `megatron/core/ssm/mamba_mixer.py:93` never sees
  pg_collection at all.

Fixing each at the source would mean patches across core (Mamba) plus a
partial walk of RADIO's HF wrapper, validated against all existing
non-hetero users of those modules. The walker is the smaller intervention:
one place, hasattr-guarded, applied per branch with the correct pg.

Re-enables the walker (it was already in PR1; this commit only updates
the docstring to record the experiment's findings).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Switch back to ShardedObject for param_state_sharding_type (NVIDIA#4791)

Kamran reverted the metadata-stash approach in NVIDIA#4801 (discussion
r3250847203) and adopted Li Ding's PR NVIDIA#4791 pattern, which is the same
ShardedObject round-trip we had originally. Align our MimoOptimizer with
that final shape:

- Restore `_extract_param_state_sharding_type` / `_restore_param_state_sharding_type`
  helpers. Hooks back into the existing `_iter_optimizer_sub_dicts` loop.
- Add `if not opt_sub: del sub_sd['optimizer']` to `_extract_param_groups`
  (from NVIDIA#4791) so the now-empty `'optimizer'` wrapper doesn't round-trip
  through common-state with undefined behavior on the load side.
- Drop `self._last_sharded_metadata` and the metadata-stash recover path
  from `load_state_dict` / `sharded_state_dict`. The ShardedObject route
  is self-contained and doesn't need caller-state coupling.

Kept (not in NVIDIA#4791, specific to our non-colocated hetero layout):
- `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` so the two
  branches' identical inner ShardedObject keys (e.g.
  `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide.
- `_get_replica_id` returning `(tp_rank, pp_rank, dp_rank)` (from NVIDIA#4801).

Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer
+ ChainedOptimizer + EP=4 + TP=2): save iter 3 exit 0, reload + resume
iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4),
matching loss trajectory across the boundary.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Drop _stamp_tp_group walker; fix the three constructors at the source

Three modules under our hetero save path don't store `self.tp_group` in
their constructors and therefore trip `MegatronModule.sharded_state_dict`'s
parallel_state fallback (`megatron/core/transformer/module.py:85`) in
heterogeneous-parallelism layouts where parallel_state is intentionally
not initialized. Fix them at the source instead of papering over with the
hasattr-guarded walker:

- `megatron/core/models/vision/radio.py:RADIOViTModel.__init__` — already
  extracts `tp_group` at line 129 for the embedder; now also stamps
  `self.tp_group = tp_group`.
- `megatron/core/ssm/mamba_layer.py:MambaLayer.__init__` — takes
  pg_collection and plumbs it into the mixer; now also stores
  `self.tp_group = pg_collection.tp` on the layer itself.
- `megatron/core/ssm/mamba_mixer.py:ExtendedRMSNorm` — adds an
  `__init__(*args, tp_group=None, **kwargs)` override that stores
  `self.tp_group` eagerly, and updates the single call site at line ~369
  to pass `tp_group=self.pg_collection.tp`. The lazy `hasattr` fallback
  inside `sharded_state_dict` is preserved for callers that don't pass
  tp_group.

With these three constructor fixes in place, the
`_propagate_tp_groups_for_checkpoint` walker (and `_stamp_tp_group`
helper) in `examples/mimo/training/hetero/runtime.py` is no longer needed.
Removed entirely.

Validated on cw-dfw 1-node 8-GPU 20L mock with the walker disabled:
- save iter 3 exit 0 (DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2)
- reload iter 3 → resume iter 4-5 with cosine LR continuation
  (1.59e-4 → 1.32e-4 → 1.01e-4), exit 0
- losses match prior runs (iter 1: 12.187, iter 2: 12.190, iter 3: 12.177,
  resume iter 4: 11.817, iter 5: 11.264)

The downstream check `if not hasattr(self, 'tp_group')` in subsequent
descendants (TransformerBlock, TransformerLayer, Attention, MLP,
ColumnParallelLinear) was already satisfied by their own constructors;
verified by reading those files.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kamran-nvidia added a commit to kamran-nvidia/Megatron-LM that referenced this pull request May 19, 2026
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant