
[bugfix] fix ZCH finetune from checkpoint with different world size #467

Merged
tiankongdeguiji merged 8 commits into alibaba:master from tiankongdeguiji:fix/zch-finetune-world-size on Apr 9, 2026

@tiankongdeguiji
Collaborator

Summary

Loading a ZCH (managed-collision) checkpoint into a model with a different world size crashed in MCHManagedCollisionModule.validate_state():

AssertionError: shard within range [500000, 1000000] cannot be built out of
segements tensor([0, 1000000, -1, ..., -1, -1, -1])

Two distinct issues stacked on top of each other:

  1. _output_segments_tensor is a fixed-shape [1025] replicated buffer in MCHManagedCollisionModule whose contents are rank-specific partition boundaries. Its shape is identical across world sizes, so torch DCP silently overwrites the freshly-built local value with the saved one, then validate_state() fails because _output_global_offset (a Python int that is not in the state dict) no longer appears in the segments.

  2. The per-position MCH lookup state (_mch_sorted_raw_ids, _mch_remapped_ids_mapping, _mch_<metadata>) is wrapped as ShardedTensor by ShardedManagedCollisionCollection._initialize_torch_state (torchrec mc_modules.py:262). Across matching world sizes this works fine, but across different world sizes torch DCP only does byte-level position slicing — and the values in _mch_remapped_ids_mapping are global slot indices in [output_global_offset, output_global_offset + zch_size), so the loaded values fall outside the new local range and crash the FBGEMM TBE bounds check / CUDA gather kernel during the first training step.
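
To make the second issue concrete, here is a minimal sketch in plain Python of what position-based slicing does to a 1-rank checkpoint loaded on 2 ranks. Lists stand in for the sharded buffers, and the table size, the stored order, and the helper names `position_slice` and `local_range` are made up for illustration; they are not taken from tzrec or torchrec.

```python
# Illustrative only: plain lists stand in for the sharded MCH buffers.
GLOBAL_ZCH_SIZE = 8  # hypothetical table size


def position_slice(saved, world_size, rank):
    """Return the chunk a rank receives under position-based slicing."""
    per_rank = len(saved) // world_size
    return saved[rank * per_rank:(rank + 1) * per_rank]


def local_range(world_size, rank):
    """Global slot range [start, end) owned by a rank under row-wise sharding."""
    per_rank = GLOBAL_ZCH_SIZE // world_size
    return rank * per_rank, (rank + 1) * per_rank


# A 1-rank checkpoint: the remapped values are global slot indices. Their
# order is assumed here to follow the sorted raw ids, not the slot order.
saved_remapped = [5, 0, 7, 2, 1, 6, 3, 4]

for rank in range(2):
    chunk = position_slice(saved_remapped, world_size=2, rank=rank)
    start, end = local_range(world_size=2, rank=rank)
    out_of_range = [v for v in chunk if not start <= v < end]
    print(f"rank {rank}: owns [{start}, {end}), got {chunk}, "
          f"out of range: {out_of_range}")
```

Each rank ends up holding some slot indices that belong to the other rank, which is the condition that trips the bounds check described above.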

Fix

  • PartialLoadPlanner always skips _output_segments_tensor; the local buffer is rebuilt by fix_mch_state, which is extended to handle the post-init_parameters all-zeros case used by the export path and is now also called from restore_model so that the load post-hook validation passes.
  • When the saved sharding plan's world size differs from the current world size, PartialLoadPlanner also skips the sharded MCH state buffers, and restore_model runs a new _redistribute_mch_state pass that:
    1. Reads the full saved global tensors for each MC table.
    2. For every non-empty saved entry, computes target_rank = saved_remapped_value // local_zch_size. This matches the row-wise position-based sharding that ShardedTensor already uses for the embedding weight, so the (raw_id → embedding row) binding is preserved end-to-end across the world size change.
    3. Writes the kept entries into the local module's buffers, fills the unused positions with the local-range remapped values, and calls _sort_mch_buffers to keep the binary-search invariant.

The embedding weight tensors are still loaded by torch DCP via the normal ShardedTensor path; the value-based bucket above is chosen to align with that position-based slice, so no extra weight movement is needed.
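
The value-based bucketing in step 2 can be sketched as follows. This is a simplified, single-process illustration rather than the PR's implementation: `bucket_by_value` is a hypothetical helper, the ids are invented, and plain lists stand in for tensors.

```python
def bucket_by_value(raw_ids, remapped, local_zch_size, world_size):
    """Group (raw_id, slot) pairs by the rank that owns the remapped slot."""
    buckets = {rank: [] for rank in range(world_size)}
    for raw_id, slot in zip(raw_ids, remapped):
        # Same rule as in the description: the owner is slot // local_zch_size.
        target_rank = slot // local_zch_size
        buckets[target_rank].append((raw_id, slot))
    return buckets


# Hypothetical saved state for a table with 4 global slots.
saved_raw_ids = [11, 23, 42, 57]
saved_remapped = [3, 0, 2, 1]

buckets = bucket_by_value(
    saved_raw_ids, saved_remapped, local_zch_size=2, world_size=2
)
for rank, entries in buckets.items():
    print(rank, entries)
```

Because each entry keeps its original slot value, the embedding row it points at is the same row that the position-based weight slice places on that rank.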

Test plan

  • New test_multi_tower_din_zch_finetune_world_size_change integration test in tzrec/tests/rank_integration_test.py: trains the ZCH multi-tower DIN config on 1 GPU, then runs --fine_tune_checkpoint on 2 GPUs.
  • Existing test_multi_tower_din_zch_with_fg_train_eval_export (same world size, ZCH train + export) still passes — exercises the fix_mch_state export path.
  • Existing test_multi_tower_din_fg_encoded_finetune (non-ZCH finetune) still passes — sanity check that the planner skip does not affect non-MCH modules.
  • Manual end-to-end: 1-GPU → 2-GPU finetune and 2-GPU → 1-GPU finetune both reach "Train and Evaluate Finished." cleanly.

🤖 Generated with Claude Code

tiankongdeguiji and others added 2 commits April 7, 2026 21:12
Loading a ZCH (managed-collision) checkpoint into a model with a different
world size crashed in MCHManagedCollisionModule.validate_state() because
`_output_segments_tensor` is a fixed-shape replicated buffer whose contents
are rank-specific, and torch DCP silently overwrote the freshly-built local
value with the saved one from a different partition layout.

Even after fixing that, the per-position MCH lookup state
(`_mch_sorted_raw_ids`, `_mch_remapped_ids_mapping`, `_mch_<metadata>`) is
wrapped as `ShardedTensor` by torchrec and can only be loaded by byte-level
position slicing, which does not preserve the per-rank value semantics: the
remapped values are global slot indices in the saved world's range and fall
outside the new local range, crashing the FBGEMM/CUDA gather kernel during
the first training step.

Fix:
- `PartialLoadPlanner` always skips `_output_segments_tensor`; the local
  buffer is rebuilt by `fix_mch_state` (now also handles the post-init
  all-zeros case used by the export path).
- When the saved sharding plan's world size differs from the current world
  size, also skip the sharded MCH state buffers and run a new
  `_redistribute_mch_state` pass that reads the full saved tensors and
  reassigns each non-empty entry to the rank that owns its global value
  range. This matches the row-wise position-based sharding used by
  `ShardedTensor` for the embedding `weight`, so the
  (raw_id → embedding row) binding is preserved end-to-end across the
  world size change.

Add an integration test exercising 1-GPU train → 2-GPU finetune through the
ZCH config; verified that 1↔2 GPU finetune both directions plus the
existing same-world-size ZCH train+export path still pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The previous fix triggered `_redistribute_mch_state` whenever the saved
checkpoint world size differed from the current one. That over-triggers on
single-rank export of a multi-rank training checkpoint
(`WORLD_SIZE=1` is forced in export_model_normal), where the path:

  1. skips loading every `_managed_collision_modules.` buffer via
     PartialLoadPlanner;
  2. only redistributes the sharded main `mc_ebc` modules by value.

INPUT_TILE=3 export also creates user-side `mc_ebc_user` MCH modules whose
state is supposed to be loaded from the saved main keys via
`ckpt_param_map_path`. The user-side FQNs do not appear in the saved
metadata, so `_redistribute_mch_state` silently skips them and the
user-side `_mch_sorted_raw_ids` / `_mch_remapped_ids_mapping` buffers stay
at the all-zeros fill produced by `init_parameters`. Every input then
matches raw_id 0, all user-side embeddings collapse to slot 0, and
INPUT_TILE=3 predictions diverge from the no-tile baseline — breaking
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile` and
`test_multi_tower_zch_with_fg_train_eval_export_trt` in CI.

The bug only occurs when the current per-rank zch range is *strictly
smaller* than the saved per-rank zch range (i.e. `cur_world_size >
saved_world_size`); going the other direction the saved values already
fit within the current per-rank range and position-based `ShardedTensor`
slicing is correct. Narrow the condition accordingly so export (and any
`cur <= saved` load) keeps the stock load path and `mc_ebc_user` state
is still populated via the ckpt-param-map remap.

Verified locally that the four affected ZCH tests now pass:
`test_multi_tower_din_zch_finetune_world_size_change`,
`test_multi_tower_din_zch_with_fg_train_eval_export`,
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile`,
`test_multi_tower_zch_with_fg_train_eval_export_trt`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Apr 8, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Apr 8, 2026
Comment thread tzrec/utils/checkpoint_util.py Outdated
Comment on lines +567 to +568
except Exception as e:
    logger.warning(f"Failed to inspect saved plan {plan_path}: {e}")

Bug risk: broad except Exception silently disables MCH redistribution

If the plan file exists but is corrupted or has an unexpected schema, this catch-all swallows the error and leaves saved_world_size = None, which sets needs_mch_redistribution = False. The code then falls back to position-based ShardedTensor loading — exactly the broken path this PR is fixing — without any indication that redistribution was needed but couldn't be determined.

Consider distinguishing "file not found" (acceptable) from "file exists but unparseable" (should raise or at least log at ERROR level):

Suggested change
- except Exception as e:
-     logger.warning(f"Failed to inspect saved plan {plan_path}: {e}")
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
+     logger.error(
+         f"Plan file {plan_path} exists but failed to parse: {e}. "
+         "MCH redistribution will be skipped — this may cause "
+         "incorrect state loading if the world size has changed."
+     )

Comment thread tzrec/utils/checkpoint_util.py Outdated
Comment on lines +490 to +493
taken = torch.zeros(local_zch_size, dtype=torch.bool, device=local_dev)
taken[kept_values - local_offset] = True
unused = all_local[~taken]
new_remapped[n:] = unused[: local_zch_size - n]

Robustness: add a defensive assertion for duplicate remapped IDs

If a corrupted checkpoint contains duplicate values in _mch_remapped_ids_mapping, taken would have fewer True entries than there are kept entries, so unused would hold more than local_zch_size - n slots. The slice would still fill new_remapped[n:], but two raw IDs would share one slot and some local slots would never be assigned, silently corrupting the model state without any error.

Consider adding:

assert kept_values.unique().numel() == kept_values.numel(), (
    f"MCH [{prefix}]: duplicate remapped IDs in checkpoint data"
)

And after line 492:

assert unused.numel() >= local_zch_size - n, (
    f"MCH [{prefix}]: insufficient unused slots ({unused.numel()}) "
    f"to fill {local_zch_size - n} empty positions"
)

Comment thread tzrec/utils/state_dict_util.py Outdated
m._output_global_offset + m._zch_size,
]
m._buffers["_output_segments_tensor"] = torch.tensor(
output_segments + [-1] * (1025 - len(output_segments)),

Nit: hardcoded magic number 1025

This value presumably matches a torchrec internal constant (MAX_WORLD_SIZE + 1 or similar). Consider deriving it from the existing buffer's shape to be resilient against future torchrec changes:

Suggested change
- output_segments + [-1] * (1025 - len(output_segments)),
+ output_segments + [-1] * (buf.shape[0] - len(output_segments)),

This also handles the is_meta case correctly since meta tensors preserve their shape.
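
The padding logic under discussion can be sketched in plain Python, with the buffer length passed in rather than hardcoded. Both helpers are hypothetical, and `shard_in_segments` is only a simplified stand-in for the check that validate_state() performs; it is included to show why the stale segments from the PR description fail while the rebuilt ones pass.

```python
def build_output_segments(global_offset, zch_size, buffer_len):
    """Return [offset, offset + zch_size] padded with -1 up to buffer_len."""
    segments = [global_offset, global_offset + zch_size]
    return segments + [-1] * (buffer_len - len(segments))


def shard_in_segments(segments, start, end):
    """Simplified check: [start, end) is bounded by two adjacent entries."""
    return any(a == start and b == end for a, b in zip(segments, segments[1:]))


BUFFER_LEN = 1025  # length of the existing buffer, as quoted in the PR

# Segments saved by a 1-rank run versus the ones rank 1 of a 2-rank run needs.
stale = build_output_segments(0, 1_000_000, BUFFER_LEN)
rebuilt = build_output_segments(500_000, 500_000, BUFFER_LEN)

print(shard_in_segments(stale, 500_000, 1_000_000))
print(shard_in_segments(rebuilt, 500_000, 1_000_000))
```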

Comment thread tzrec/utils/checkpoint_util.py Outdated
Comment on lines +373 to +374
routes inputs to the rank that owns each value's range. Position-based
`ShardedTensor` resharding therefore does not preserve semantics across

Nit: ambiguous zch_size in docstring

Here zch_size means the global total across all ranks, but throughout the code m._zch_size (assigned to local_zch_size) is the per-rank size. Consider clarifying:

Suggested change
- routes inputs to the rank that owns each value's range. Position-based
- `ShardedTensor` resharding therefore does not preserve semantics across
+ slot indices in `[0, global_zch_size)` (where `global_zch_size = local_zch_size *
+ world_size`) — NOT positions, NOT bytes — and torchrec

@github-actions

github-actions Bot commented Apr 8, 2026

Code Review Summary

Well-crafted bugfix for a genuinely tricky distributed systems problem. The value-aware MCH redistribution algorithm is sound, the inline comments are excellent (especially explaining why position-based ShardedTensor slicing breaks across world sizes), and the second commit narrowing the condition to cur > saved shows good iterative debugging. A few items worth addressing:

Issues

  1. Silent fallback on plan file corruption — The broad except Exception when parsing the saved plan file silently disables MCH redistribution if the file is corrupt, falling back to the exact broken path this PR fixes. See inline comment.

  2. No defense against duplicate remapped IDs from corrupted checkpoints: taken[kept_values - local_offset] = True with duplicate values would mark fewer slots as taken than there are kept entries, silently corrupting the local buffer. See inline comment.

  3. Hardcoded magic number 1025 in fix_mch_state — should derive from the buffer's existing shape (buf.shape[0]). See inline comment.

Testing gaps

  1. Integration test only covers 1→2 GPU direction. The reverse (2→1) exercises a different code path (fix_mch_state rebuild without redistribution) that has no dedicated test. The PR description mentions manual verification of both directions — consider adding the 2→1 test as well.

  2. Test verifies non-crash but not correctness. The assertions check self.success and steps > 0, but don't verify that embeddings were actually preserved across the world-size change. A smoke check comparing a known feature's embedding before and after would strengthen confidence in the redistribution logic.

Scaling note

  1. Every rank independently loads the full global MCH state in _read_full_mch_tensors. For a bugfix this is fine, but at larger scale (many ranks, large ZCH tables) this multiplies both I/O and CPU memory by world_size. Worth a # TODO for future optimization (rank-0 load + broadcast).

Positive notes

  • The algorithm in _redistribute_mch_state correctly aligns value-based MCH bucketing with the position-based ShardedTensor slicing used for embedding weights — this is the key insight that makes the fix work end-to-end.
  • The PartialLoadPlanner skip logic is clean and well-documented.
  • The second commit narrowing needs_mch_redistribution to only cur > saved is a good catch that preserves the ckpt_param_map_path remap path for export.

🤖 Generated with Claude Code

tiankongdeguiji and others added 6 commits April 8, 2026 13:26
- Replace the `plan` JSON inspection with `_ckpt_world_size`, which counts
  the per-rank `__<rank>_<part>.distcp` shard files under
  `<ckpt>/model`. The plan file may not always be present, the rank set
  derived from it duplicates information that DCP already records on
  disk, and the shard count is the authoritative source.
- Drop `skip_mc_module_state` from the optimizer planner construction.
  MCH buffers (`_mch_*`, `_output_segments_tensor`) are registered via
  `register_buffer`, never appear in the optimizer state dict, and the
  fused-optimizer accumulators for embedding `weight` are row-sharded
  with the same position-based semantics in both directions. The flag
  was dead code on the optimizer side.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…visible

The previous condition `cur_world_size > saved_world_size` is too narrow:
position-based `ShardedTensor` slicing of MCH state buffers is only safe
when each saved per-rank value range
`[S * sps, (S+1) * sps)` lies entirely inside one current per-rank value
range `[R * cps, (R+1) * cps)`. For uniform row-wise sharding this holds
iff `saved_world_size % cur_world_size == 0`.

That divisibility test agrees with `cur > saved` for the common pow-of-2
GPU counts (1/2/4/8) reported in alibaba#360 — verified by
reproducing the issue locally:

  - train on 4 GPUs (saves `__0_0.distcp` ... `__3_0.distcp`)
  - finetune on 2 GPUs and on 1 GPU
  - both reach `Train and Evaluate Finished.` with no out-of-bounds
    warnings, because position-based loading correctly handles 4→2 and
    4→1 (saved per-rank value range fits entirely inside one cur per-rank
    range) and the unconditional `_output_segments_tensor` skip already
    clears the validate_state assertion the issue reporter was hitting.

But the two conditions disagree for non-pow-of-2 shrinking like 6→4 or
4→3, where `cur < saved` but `saved % cur != 0`: a saved chunk straddles
two adjacent cur ranks, the splits inherit that saved chunk's value
range, and one of the resulting cur ranks ends up with values from the
neighbouring chunk that exceed its local range. Switch to
`saved_world_size % cur_world_size != 0` so those cases also fall back
to value-aware redistribution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
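
The divisibility rule in this commit message can be checked by brute force. The sketch below is an illustration under the stated assumption of uniform row-wise sharding; `position_load_is_safe` is a hypothetical helper, and the global size is chosen as `saved * cur` so that it divides evenly in both layouts.

```python
def position_load_is_safe(saved_world_size, cur_world_size, global_size):
    """True if every saved per-rank range fits inside one current range."""
    sps = global_size // saved_world_size  # saved per-rank size
    cps = global_size // cur_world_size    # current per-rank size
    for s in range(saved_world_size):
        lo, hi = s * sps, (s + 1) * sps
        # First and last value of the saved range must have the same owner.
        if lo // cps != (hi - 1) // cps:
            return False
    return True


for saved, cur in [(4, 2), (4, 1), (1, 2), (2, 4), (6, 4), (4, 3)]:
    safe = position_load_is_safe(saved, cur, global_size=saved * cur)
    print(f"{saved}->{cur}: safe={safe}, saved % cur == 0: {saved % cur == 0}")
```

For every pair tried here the brute-force answer agrees with `saved_world_size % cur_world_size == 0`, including the 6→4 and 4→3 cases where `cur < saved` alone would have given the wrong answer.
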
`_output_segments_tensor` is replicated and rank-specific, but the saved
tensor is safe to load via DCP whenever the same divisibility condition
that makes the rest of the MCH state safe to load via position-based
slicing also holds: every current-world boundary `R*cps` then equals
some saved-world boundary `S*sps`, so the loaded segments tensor still
contains the values that `validate_state()` checks. Conversely, when
`saved_world_size % cur_world_size != 0`, the saved boundaries do not
contain the current ones and `_output_segments_tensor` must be left at
its locally-built value — which is exactly the case where
`_redistribute_mch_state` is also needed.

Switch the `_output_segments_tensor` skip in `PartialLoadPlanner` to
piggy-back on `_skip_mc_module_state` so it is only suppressed in the
redistribution case. This removes the need for the unconditional
`fix_mch_state(model)` call I had added at the start of `restore_model`
(plus the matching all-zeros branch in `fix_mch_state` itself):

  - Train continue / finetune (same world size): MCH modules are built
    on cuda by DMP via `rebuild_with_output_id_range` and the buffer is
    never zeros. The load is allowed and is a no-op since saved == local.
  - Finetune across non-divisible world sizes (e.g. 1→2): the skip is
    on, the locally-correct buffer from `rebuild_with_output_id_range`
    is preserved, and `validate_state()` passes against it.
  - Export of any multi-rank training checkpoint (cur=1 forced; saved %
    cur == 0 always): the skip is off, the saved value is loaded into
    the post-init_parameters zeros buffer, `validate_state()` passes,
    the existing `fix_mch_state` call at `export_util.py:169` then
    rebuilds the now-meta buffer (validate_state replaces it with the
    stale `_init_output_segments_tensor` reference) — the original
    master flow.

Revert `fix_mch_state` to its original meta-only behaviour and drop the
extra `restore_model` call. Verified all four ZCH tests still pass:
`test_multi_tower_din_zch_finetune_world_size_change`,
`test_multi_tower_din_zch_with_fg_train_eval_export`,
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile`,
`test_multi_tower_zch_with_fg_train_eval_export_trt`, plus manual 4→2
and 4→1 finetune from a 4-GPU training checkpoint.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…omments

- Move the saved/current world-size compatibility check into a small
  `_needs_mch_redistribution(model_ckpt_path, cur_world_size) -> bool`
  helper next to `_ckpt_world_size`.
- `_redistribute_mch_state` no longer takes `saved_world_size` — it
  reads it inline for the rank-0 log message.
- Drop the multi-paragraph block comments inside `restore_model` and
  `PartialLoadPlanner.create_local_plan`; the helpers' docstrings are
  the single source of truth for the divisibility rationale.
- Drop the optimizer-load explanatory comment.

Verified: `test_multi_tower_din_zch_finetune_world_size_change`,
`test_multi_tower_din_zch_with_fg_train_eval_export`, and
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile` all
still pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The broad `._managed_collision_modules.` skip predicate also dropped
`_current_iter_tensor`, which is a plain replicated buffer excluded
from torchrec's `sharded_parameter_names` and holds a globally identical
iteration counter used by `MCHManagedCollisionModule.profile` to drive
`_eviction_interval` accounting. On the non-divisible world-size path
(e.g. 1→2 finetune) the load was silently skipped and every rank's
counter stayed at the post-init 0, resetting the eviction clock and
desynchronizing the LFU/LRU/DistanceLFU eviction schedule from whatever
the saved training state had reached.

Narrow the predicate to a small `_is_resharded_mc_buffer` helper that
skips everything under `._managed_collision_modules.` *except*
`_current_iter_tensor`. `_output_segments_tensor` is kept in the skip
set: it is replicated but rank-specific, so the locally-rebuilt value
from `rebuild_with_output_id_range` must be preserved when the saved
boundaries don't align with the current ones.

Verified with a debug hook on 1→2 finetune of a 1-GPU ZCH checkpoint
whose saved `_current_iter_tensor` was 8: both cur ranks' buffers now
start at 8 after `restore_model` (was 0 before the fix).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
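
As a rough illustration of the narrowed predicate described in this commit, the sketch below skips every key under `._managed_collision_modules.` except `_current_iter_tensor`. The function body and the example FQNs are assumptions made for illustration; only the two quoted name fragments come from the commit message.

```python
MC_MODULE_MARKER = "._managed_collision_modules."
KEPT_BUFFER = "_current_iter_tensor"


def is_resharded_mc_buffer(fqn: str) -> bool:
    """True if the key is MCH state that should be skipped on load."""
    if MC_MODULE_MARKER not in fqn:
        return False
    # The last dotted component is the buffer name.
    return fqn.rsplit(".", 1)[-1] != KEPT_BUFFER


# Hypothetical fully qualified names, for illustration only.
prefix = "model.mc_ebc._managed_collision_collection._managed_collision_modules.user_id."
for key in (
    prefix + "_mch_sorted_raw_ids",
    prefix + "_output_segments_tensor",
    prefix + "_current_iter_tensor",
    "model.mlp.0.weight",
):
    print(key.rsplit(".", 1)[-1], is_resharded_mc_buffer(key))
```
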
…ll global read

The previous design skipped every _managed_collision_modules.* buffer
in PartialLoadPlanner and then had _redistribute_mch_state read the
full saved _mch_* tensors from disk on every rank via an extra DCP
load, keeping O(saved_total) CPU tensors per buffer per table per
rank. For a table with zch_size=100M that is ~800 MB per replica per
rank.

Switch to a collective value-aware permutation:

- PartialLoadPlanner now only skips _output_segments_tensor (the one
  replicated but rank-specific buffer whose saved boundaries do not
  contain the current ones when divisibility fails). Every _mch_*
  buffer and _current_iter_tensor is loaded via the normal DCP
  position-based path; after the load each rank already holds a
  cps-sized slice of the saved global state, just at the wrong
  position for the new sharding.
- _redistribute_mch_state now runs before model.load_state_dict and
  works off the live per-rank buffers: filter valid entries, compute
  dest_rank = remapped // cps, sort by dest_rank, exchange split
  sizes via one all_to_all_single, then exchange _mch_sorted_raw_ids
  / _mch_remapped_ids_mapping / each sharded metadata buffer via
  all_to_all_single with input_split_sizes / output_split_sizes, and
  finally place the received entries at [0, n) plus the unused
  local-range values as padding. Live buffers are updated via
  buf.copy_(...) so state_dict ShardedTensor wrappers pick up the new
  values; the subsequent model.load_state_dict post-hook
  (validate_state -> _sort_mch_buffers) re-sorts the new local table
  without any explicit call here.
- _read_full_mch_tensors is removed; _redistribute_mch_state no
  longer needs the checkpoint path argument.
- PartialLoadPlanner parameter renamed from skip_mc_module_state to
  skip_output_segments_tensor to reflect the narrowed scope.

Memory: worst-case O(cps) extra per rank instead of O(saved_total) on
every rank. I/O: one normal DCP load. Comms: one small
all_to_all_single for split sizes plus one per sharded MCH buffer
per table.

Verified: all three ZCH integration tests still pass; manual 4->2 and
4->1 finetune (divisible, position-based); manual 2->4 finetune
(non-divisible, all-to-all path, log confirms "Redistributing MCH
(ZCH) state via all_to_all_single across 6 MC modules");
_current_iter_tensor carry-over check on 1->2 finetune still shows
both cur ranks inherit the saved counter value (8, not 0).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
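
The collective permutation described in this commit can be approximated in a single process. The sketch below simulates the exchange with plain lists instead of `all_to_all_single`; the function name, the sample ids, and the use of `None` for an empty raw-id position are assumptions made for illustration, not details taken from the PR.

```python
def redistribute(per_rank_raw, per_rank_remapped, cps):
    """Simulate the value-aware exchange for all ranks in one process.

    Each input holds one list per rank with that rank's valid entries
    after the position-based load. cps is the current per-rank size.
    """
    world_size = len(per_rank_raw)
    # inbox[dest] collects what every source rank sends to rank dest.
    inbox = [[] for _ in range(world_size)]
    for raw, remapped in zip(per_rank_raw, per_rank_remapped):
        # Sort by destination rank, as a real exchange needs contiguous splits.
        pairs = sorted(zip(raw, remapped), key=lambda p: p[1] // cps)
        for raw_id, slot in pairs:
            inbox[slot // cps].append((raw_id, slot))

    result = []
    for rank, received in enumerate(inbox):
        n = len(received)
        if n > cps:
            raise RuntimeError(f"rank {rank}: received {n} > local zch_size {cps}")
        local_offset = rank * cps
        taken = {slot for _, slot in received}
        unused = [s for s in range(local_offset, local_offset + cps)
                  if s not in taken]
        # Received entries go to [0, n); unused local slots pad the rest.
        new_remapped = [slot for _, slot in received] + unused[: cps - n]
        new_raw = [raw_id for raw_id, _ in received] + [None] * (cps - n)
        result.append((new_raw, new_remapped))
    return result


# A 1-rank checkpoint with 7 valid entries, position-sliced onto 2 ranks.
result = redistribute(
    per_rank_raw=[[10, 20, 30, 40], [50, 60, 70]],
    per_rank_remapped=[[5, 0, 7, 2], [1, 6, 3]],
    cps=4,
)
for rank, (raw, remapped) in enumerate(result):
    print(rank, raw, remapped)
```

After the exchange every remapped value lies in its owner's local range, and the one slot nobody claimed (4) is used as padding on rank 1. The buffers are not yet sorted by raw id; per the commit message, that re-sort is left to the load post-hook.
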
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Apr 8, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Apr 8, 2026
Comment on lines +292 to +296
try:
    saved_world_size = _ckpt_world_size(model_ckpt_path)
except Exception as e:
    logger.warning(f"Failed to detect saved world size from {model_ckpt_path}: {e}")
    return False

Bug (deadlock risk): If _ckpt_world_size raises on some ranks but not others (e.g. NFS latency, partial file visibility), those ranks return False here while others return True. Since _redistribute_mch_state calls dist.all_to_all_single — a collective — the process group will deadlock if only a subset of ranks enters.

Consider synchronizing the decision across ranks after computing it locally:

needs = _needs_mch_redistribution(model_ckpt_path, cur_world_size)
if dist.is_initialized():
    flag = torch.tensor([int(needs)], device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    needs = flag.item() > 0

Or alternatively, let the exception propagate (remove the try/except) so all ranks fail consistently rather than silently diverging.

)
taken = torch.zeros(cps, dtype=torch.bool, device=local_dev)
if n > 0:
    taken[remapped_recv - local_offset] = True

Bug (silent data corruption): remapped_recv - local_offset is used as an index into taken (size cps) with no bounds check. If any received value falls outside [local_offset, local_offset + cps), PyTorch's negative-index wrapping will silently corrupt unrelated positions rather than raising an error.

Add a bounds check before indexing:

indices = remapped_recv - local_offset
if n > 0:
    assert (indices >= 0).all() and (indices < cps).all(), (
        f"MCH [{prefix}] rank {cur_rank}: received remapped values outside "
        f"local range [{local_offset}, {local_offset + cps})"
    )
    taken[indices] = True
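
The same guard can be sketched without torch. Python lists also wrap negative indices, so the plain-Python version below shows the failure mode the check prevents. `mark_taken` is a hypothetical helper, and it raises `RuntimeError` instead of using `assert` so the check survives `python -O`.

```python
def mark_taken(received, local_offset, cps):
    """Mark which local slots are occupied, rejecting out-of-range values."""
    indices = [value - local_offset for value in received]
    bad = [value for value, idx in zip(received, indices) if not 0 <= idx < cps]
    if bad:
        raise RuntimeError(
            f"received remapped values {bad} outside local range "
            f"[{local_offset}, {local_offset + cps})"
        )
    taken = [False] * cps
    for idx in indices:
        taken[idx] = True
    return taken


taken = mark_taken([4, 6], local_offset=4, cps=4)
print(taken)

# Without the guard, value 3 would become index -1 and silently mark the
# last slot as taken instead of failing.
try:
    mark_taken([3], local_offset=4, cps=4)
except RuntimeError as err:
    print("rejected:", err)
```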

Comment on lines +399 to +401
assert n <= cps, (
    f"MCH [{prefix}] rank {cur_rank}: received {n} > local zch_size {cps}"
)

Nit: assert is stripped under python -O. For a runtime invariant guarding data integrity in distributed code, prefer if ... : raise RuntimeError(...).

Suggested change
- assert n <= cps, (
-     f"MCH [{prefix}] rank {cur_rank}: received {n} > local zch_size {cps}"
- )
+ if n > cps:
+     raise RuntimeError(
+         f"MCH [{prefix}] rank {cur_rank}: received {n} > local zch_size {cps}"
+     )

ranks.add(int(m.group(1)))
if not ranks:
    raise RuntimeError(f"No .distcp files under {ckpt_dir}")
return max(ranks) + 1

Suggestion: max(ranks) + 1 assumes contiguous rank numbering. If a checkpoint is partially written (e.g. rank 1 shard missing, only ranks 0 and 2 present), this silently returns the wrong world size, which could cascade into incorrect redistribution decisions.

Consider adding a consistency check:

expected = max(ranks) + 1
if len(ranks) != expected:
    raise RuntimeError(
        f"Checkpoint {ckpt_dir} has rank gaps: found {sorted(ranks)}, "
        f"expected contiguous 0..{expected - 1}"
    )
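
Put together, the shard-counting idea plus the contiguity check might look like the sketch below. It works on a list of file names so it can run without a checkpoint on disk. The regular expression is an assumption based on the `__<rank>_<part>.distcp` pattern quoted in the commit message, and `world_size_from_shard_names` is a hypothetical name.

```python
import re

# Assumed pattern, based on names such as "__0_0.distcp".
_DISTCP_RE = re.compile(r"^__(\d+)_\d+\.distcp$")


def world_size_from_shard_names(names):
    """Infer the saved world size from DCP shard file names."""
    ranks = set()
    for name in names:
        m = _DISTCP_RE.match(name)
        if m:
            ranks.add(int(m.group(1)))
    if not ranks:
        raise RuntimeError("No .distcp files found")
    expected = max(ranks) + 1
    if len(ranks) != expected:
        raise RuntimeError(
            f"rank gaps: found {sorted(ranks)}, "
            f"expected contiguous 0..{expected - 1}"
        )
    return expected


print(world_size_from_shard_names(["__0_0.distcp", "__1_0.distcp", ".metadata"]))

try:
    world_size_from_shard_names(["__0_0.distcp", "__2_0.distcp"])
except RuntimeError as err:
    print("rejected:", err)
```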

Comment on lines +496 to 498
if needs_mch_redistribution:
    _redistribute_mch_state(model)
model.load_state_dict(state_dict)

Fragility warning: This sequence relies on model.state_dict() (line 487) returning tensors that share underlying storage with model._buffers, so that _redistribute_mch_state's .copy_() into the buffers also updates what state_dict references. If PyTorch/TorchRec ever changes state_dict() to return detached copies, the redistributed values would be silently overwritten by pre-redistribution data in load_state_dict.

Consider explicitly copying the redistributed buffer values back into state_dict after redistribution, or adding a comment + assertion that the storage sharing invariant holds (e.g. assert state_dict[some_key].data_ptr() == model.some_buffer.data_ptr()).

@github-actions

github-actions Bot commented Apr 8, 2026

Code Review Summary

Well-crafted fix for a gnarly problem — the value-aware all_to_all_single redistribution is the right approach, and the commit history shows careful iterative refinement of the divisibility condition and skip scope. The PR description and commit messages are exemplary.

Issues to address

1. Deadlock risk from inconsistent redistribution decision across ranks (comment)
_needs_mch_redistribution catches filesystem exceptions and returns False. If _ckpt_world_size fails on some ranks but not others (NFS latency, partial file visibility), some ranks enter all_to_all_single while others skip it → distributed deadlock. Either synchronize the decision via all_reduce or let the exception propagate.

2. Missing bounds check on remapped_recv indexing (comment)
taken[remapped_recv - local_offset] = True has no guard. Out-of-range values silently corrupt via PyTorch negative-index wrapping rather than raising. Add an assertion before indexing.

3. assert used for runtime invariant (comment)
The assert n <= cps check is stripped under python -O. Use if/raise RuntimeError for data integrity checks in distributed code.

4. _ckpt_world_size assumes contiguous rank numbering (comment)
max(ranks) + 1 silently returns the wrong world size if any intermediate shard file is missing. A len(ranks) != max(ranks) + 1 check would catch corrupted checkpoints early.

5. Implicit storage-sharing dependency in restore_model (comment)
The state_dict → load → redistribute → load_state_dict flow relies on state_dict tensors sharing storage with model buffers. If this invariant breaks in a future PyTorch/TorchRec version, redistributed values would be silently overwritten.

Optimization opportunity

Fuse multiple all_to_all_single calls per MC module. Currently K+2 collectives are issued per module (raw_ids + remapped + K metadata buffers). Since all share identical split sizes, they could be concatenated into one tensor, exchanged with a single all_to_all_single, and sliced apart. This reduces collective synchronization overhead, especially at higher rank counts.

Testing

The integration test covers the critical 1→2 GPU path and is correctly gated on GPU availability. Consider adding unit tests for the pure helper functions (_ckpt_world_size, _needs_mch_redistribution, _strip_dmp_prefix) — they require no GPU/distributed setup and would catch edge cases like non-contiguous ranks or the divisibility boundary conditions.

🤖 Generated with Claude Code

@tiankongdeguiji tiankongdeguiji merged commit 1911ce5 into alibaba:master Apr 9, 2026
7 checks passed