[bugfix] fix ZCH finetune from checkpoint with different world size #467
Loading a ZCH (managed-collision) checkpoint into a model with a different world size crashed in `MCHManagedCollisionModule.validate_state()` because `_output_segments_tensor` is a fixed-shape replicated buffer whose contents are rank-specific, and torch DCP silently overwrote the freshly-built local value with the saved one from a different partition layout.
Even after fixing that, the per-position MCH lookup state (`_mch_sorted_raw_ids`, `_mch_remapped_ids_mapping`, `_mch_<metadata>`) is wrapped as `ShardedTensor` by torchrec and can only be loaded by byte-level position slicing, which does not preserve the per-rank value semantics: the remapped values are global slot indices in the saved world's range and fall outside the new local range, crashing the FBGEMM/CUDA gather kernel during the first training step.
Fix:
- `PartialLoadPlanner` always skips `_output_segments_tensor`; the local buffer is rebuilt by `fix_mch_state` (now also handles the post-init all-zeros case used by the export path).
- When the saved sharding plan's world size differs from the current world size, also skip the sharded MCH state buffers and run a new `_redistribute_mch_state` pass that reads the full saved tensors and reassigns each non-empty entry to the rank that owns its global value range. This matches the row-wise position-based sharding used by `ShardedTensor` for the embedding `weight`, so the (raw_id → embedding row) binding is preserved end-to-end across the world size change.
Add an integration test exercising 1-GPU train → 2-GPU finetune through the ZCH config; verified that 1↔2 GPU finetune both directions plus the existing same-world-size ZCH train+export path still pass.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
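To make the second failure mode concrete, here is a minimal pure-Python sketch (no torch, toy sizes) of what position-based slicing does to a buffer whose values are global slot indices. The helper names `position_slice` and `out_of_range` and the toy data are illustrative assumptions, not code from this PR.

```python
def position_slice(global_buf, world_size):
    """Split a saved global buffer into equal contiguous chunks, one per rank."""
    per_rank = len(global_buf) // world_size
    return [global_buf[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def out_of_range(chunks, global_zch_size):
    """Per rank, list the loaded slot indices that fall outside its local range."""
    cps = global_zch_size // len(chunks)  # current per-rank zch size
    return [
        [v for v in chunk if not (r * cps <= v < (r + 1) * cps)]
        for r, chunk in enumerate(chunks)
    ]


# Toy 1-GPU checkpoint with 8 slots. Entries are ordered by raw id, so the
# remapped slot indices appear in arbitrary order.
saved_remapped = [5, 0, 7, 2, 1, 6, 3, 4]

# Loading by position into 2 ranks: rank 0 owns slots [0, 4), rank 1 owns [4, 8).
chunks = position_slice(saved_remapped, world_size=2)
bad = out_of_range(chunks, global_zch_size=8)
print(bad)  # [[5, 7], [1, 3]]
```

Each rank ends up holding slot indices that belong to the other rank, which is the kind of out-of-range value the commit message describes reaching the gather kernel.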
The previous fix triggered `_redistribute_mch_state` whenever the saved
checkpoint world size differed from the current one. That over-triggers on
single-rank export of a multi-rank training checkpoint
(`WORLD_SIZE=1` is forced in export_model_normal), where the path:
1. skips loading every `_managed_collision_modules.` buffer via
PartialLoadPlanner;
2. only redistributes the sharded main `mc_ebc` modules by value.
INPUT_TILE=3 export also creates user-side `mc_ebc_user` MCH modules whose
state is supposed to be loaded from the saved main keys via
`ckpt_param_map_path`. The user-side FQNs do not appear in the saved
metadata, so `_redistribute_mch_state` silently skips them and the
user-side `_mch_sorted_raw_ids` / `_mch_remapped_ids_mapping` buffers stay
at the all-zeros fill produced by `init_parameters`. Every input then
matches raw_id 0, all user-side embeddings collapse to slot 0, and
INPUT_TILE=3 predictions diverge from the no-tile baseline — breaking
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile` and
`test_multi_tower_zch_with_fg_train_eval_export_trt` in CI.
The bug only occurs when the current per-rank zch range is *strictly
smaller* than the saved per-rank zch range (i.e. `cur_world_size >
saved_world_size`); going the other direction the saved values already
fit within the current per-rank range and position-based `ShardedTensor`
slicing is correct. Narrow the condition accordingly so export (and any
`cur <= saved` load) keeps the stock load path and `mc_ebc_user` state
is still populated via the ckpt-param-map remap.
Verified locally that the four affected ZCH tests now pass:
`test_multi_tower_din_zch_finetune_world_size_change`,
`test_multi_tower_din_zch_with_fg_train_eval_export`,
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile`,
`test_multi_tower_zch_with_fg_train_eval_export_trt`.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| except Exception as e: | ||
| logger.warning(f"Failed to inspect saved plan {plan_path}: {e}") |
Bug risk: broad except Exception silently disables MCH redistribution
If the plan file exists but is corrupted or has an unexpected schema, this catch-all swallows the error and leaves saved_world_size = None, which sets needs_mch_redistribution = False. The code then falls back to position-based ShardedTensor loading — exactly the broken path this PR is fixing — without any indication that redistribution was needed but couldn't be determined.
Consider distinguishing "file not found" (acceptable) from "file exists but unparseable" (should raise or at least log at ERROR level):
| except Exception as e: | |
| logger.warning(f"Failed to inspect saved plan {plan_path}: {e}") | |
| except (json.JSONDecodeError, KeyError, TypeError) as e: | |
| logger.error( | |
| f"Plan file {plan_path} exists but failed to parse: {e}. " | |
| "MCH redistribution will be skipped — this may cause " | |
| "incorrect state loading if the world size has changed." | |
| ) |
| taken = torch.zeros(local_zch_size, dtype=torch.bool, device=local_dev) | ||
| taken[kept_values - local_offset] = True | ||
| unused = all_local[~taken] | ||
| new_remapped[n:] = unused[: local_zch_size - n] |
Robustness: add a defensive assertion for duplicate remapped IDs
If a corrupted checkpoint contains duplicate values in _mch_remapped_ids_mapping, taken would have fewer True entries than expected, making unused too short. new_remapped[n:] would then be partially filled with stale arange values — silently corrupting the model state without any error.
Consider adding:
assert kept_values.unique().numel() == kept_values.numel(), (
    f"MCH [{prefix}]: duplicate remapped IDs in checkpoint data"
)
And after line 492:
assert unused.numel() >= local_zch_size - n, (
    f"MCH [{prefix}]: insufficient unused slots ({unused.numel()}) "
    f"to fill {local_zch_size - n} empty positions"
)
| m._output_global_offset + m._zch_size, | ||
| ] | ||
| m._buffers["_output_segments_tensor"] = torch.tensor( | ||
| output_segments + [-1] * (1025 - len(output_segments)), |
Nit: hardcoded magic number 1025
This value presumably matches a torchrec internal constant (MAX_WORLD_SIZE + 1 or similar). Consider deriving it from the existing buffer's shape to be resilient against future torchrec changes:
| output_segments + [-1] * (1025 - len(output_segments)), | |
| output_segments + [-1] * (buf.shape[0] - len(output_segments)), |
This also handles the is_meta case correctly since meta tensors preserve their shape.
| routes inputs to the rank that owns each value's range. Position-based | ||
| `ShardedTensor` resharding therefore does not preserve semantics across |
Nit: ambiguous zch_size in docstring
Here zch_size means the global total across all ranks, but throughout the code m._zch_size (assigned to local_zch_size) is the per-rank size. Consider clarifying:
| routes inputs to the rank that owns each value's range. Position-based | |
| `ShardedTensor` resharding therefore does not preserve semantics across | |
| slot indices in `[0, global_zch_size)` (where `global_zch_size = local_zch_size * | |
| world_size`) — NOT positions, NOT bytes — and torchrec |
Code Review Summary
Well-crafted bugfix for a genuinely tricky distributed systems problem. The value-aware MCH redistribution algorithm is sound, the inline comments are excellent (especially explaining why position-based [...]
Issues
Testing gaps
Scaling note
Positive notes
🤖 Generated with Claude Code
- Replace the `plan` JSON inspection with `_ckpt_world_size`, which counts the per-rank `__<rank>_<part>.distcp` shard files under `<ckpt>/model`. The plan file may not always be present, the rank set derived from it duplicates information that DCP already records on disk, and the shard count is the authoritative source.
- Drop `skip_mc_module_state` from the optimizer planner construction. MCH buffers (`_mch_*`, `_output_segments_tensor`) are registered via `register_buffer`, never appear in the optimizer state dict, and the fused-optimizer accumulators for embedding `weight` are row-sharded with the same position-based semantics in both directions. The flag was dead code on the optimizer side.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…visible
The previous condition `cur_world_size > saved_world_size` is too narrow: position-based `ShardedTensor` slicing of MCH state buffers is only safe when each saved per-rank value range `[S * sps, (S+1) * sps)` lies entirely inside one current per-rank value range `[R * cps, (R+1) * cps)`. For uniform row-wise sharding this holds iff `saved_world_size % cur_world_size == 0`.
That divisibility test agrees with `cur > saved` for the common pow-of-2 GPU counts (1/2/4/8) reported in alibaba#360, verified by reproducing the issue locally:
- train on 4 GPUs (saves `__0_0.distcp` ... `__3_0.distcp`)
- finetune on 2 GPUs and on 1 GPU
- both reach `Train and Evaluate Finished.` with no out-of-bounds warnings, because position-based loading correctly handles 4→2 and 4→1 (saved per-rank value range fits entirely inside one cur per-rank range) and the unconditional `_output_segments_tensor` skip already clears the validate_state assertion the issue reporter was hitting.
But the two conditions disagree for non-pow-of-2 shrinking like 6→4 or 4→3, where `cur < saved` but `saved % cur != 0`: a saved chunk straddles two adjacent cur ranks, the splits inherit that saved chunk's value range, and one of the resulting cur ranks ends up with values from the neighbouring chunk that exceed its local range. Switch to `saved_world_size % cur_world_size != 0` so those cases also fall back to value-aware redistribution.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
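The divisibility claim can be checked by brute force. The sketch below compares a direct "every saved range nests inside one current range" test against `saved % cur != 0` for a handful of world sizes. It assumes uniform sharding and a total size divisible by every world size tried; the function names and `TOTAL` are illustrative, not taken from the PR.

```python
def ranges(total, world_size):
    """Uniform row-wise sharding: one half-open (lo, hi) value range per rank."""
    size = total // world_size
    return [(r * size, (r + 1) * size) for r in range(world_size)]


def position_load_is_safe(saved_ws, cur_ws, total):
    """True if every saved per-rank range lies inside a single current range."""
    cur = ranges(total, cur_ws)
    return all(
        any(clo <= lo and hi <= chi for clo, chi in cur)
        for lo, hi in ranges(total, saved_ws)
    )


def needs_redistribution(saved_ws, cur_ws):
    """The condition adopted in this commit."""
    return saved_ws % cur_ws != 0


TOTAL = 24  # divisible by every world size below
SIZES = [1, 2, 3, 4, 6, 8]

# Pairs where the brute-force nesting test and the modulo test disagree.
mismatches = [
    (s, c)
    for s in SIZES
    for c in SIZES
    if position_load_is_safe(s, c, TOTAL) == needs_redistribution(s, c)
]
print(mismatches)  # []
print(needs_redistribution(4, 2), needs_redistribution(6, 4), needs_redistribution(1, 2))
# False True True
```

For these sizes the two tests never disagree, and the 6→4 case is flagged even though the world size shrinks.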
`_output_segments_tensor` is replicated and rank-specific, but the saved
tensor is safe to load via DCP whenever the same divisibility condition
that makes the rest of the MCH state safe to load via position-based
slicing also holds: every current-world boundary `R*cps` then equals
some saved-world boundary `S*sps`, so the loaded segments tensor still
contains the values that `validate_state()` checks. Conversely, when
`saved_world_size % cur_world_size != 0`, the saved boundaries do not
contain the current ones and `_output_segments_tensor` must be left at
its locally-built value — which is exactly the case where
`_redistribute_mch_state` is also needed.
Switch the `_output_segments_tensor` skip in `PartialLoadPlanner` to
piggy-back on `_skip_mc_module_state` so it is only suppressed in the
redistribution case. This removes the need for the unconditional
`fix_mch_state(model)` call I had added at the start of `restore_model`
(plus the matching all-zeros branch in `fix_mch_state` itself):
- Train continue / finetune (same world size): MCH modules are built
on cuda by DMP via `rebuild_with_output_id_range` and the buffer is
never zeros. The load is allowed and is a no-op since saved == local.
- Finetune across non-divisible world sizes (e.g. 1→2): the skip is
on, the locally-correct buffer from `rebuild_with_output_id_range`
is preserved, and `validate_state()` passes against it.
- Export of any multi-rank training checkpoint (cur=1 forced; saved %
cur == 0 always): the skip is off, the saved value is loaded into
the post-init_parameters zeros buffer, `validate_state()` passes,
the existing `fix_mch_state` call at `export_util.py:169` then
rebuilds the now-meta buffer (validate_state replaces it with the
stale `_init_output_segments_tensor` reference) — the original
master flow.
Revert `fix_mch_state` to its original meta-only behaviour and drop the
extra `restore_model` call. Verified all four ZCH tests still pass:
`test_multi_tower_din_zch_finetune_world_size_change`,
`test_multi_tower_din_zch_with_fg_train_eval_export`,
`test_multi_tower_din_zch_with_fg_train_eval_export_input_tile`,
`test_multi_tower_zch_with_fg_train_eval_export_trt`, plus manual 4→2
and 4→1 finetune from a 4-GPU training checkpoint.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
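The boundary argument for `_output_segments_tensor` can be sketched the same way. The layout below (rank start offsets, then the global end, padded with `-1` to length 1025) is inferred from the diff fragment quoted in the review and is an assumption about torchrec internals; `build_segments` and `offsets_present` are hypothetical helpers that only mimic the membership check described for `validate_state()`.

```python
SEGMENTS_LEN = 1025  # fixed buffer length quoted in this PR


def build_segments(global_zch_size, world_size):
    """Assumed layout: rank start offsets, then the global end, padded with -1."""
    per_rank = global_zch_size // world_size
    bounds = [r * per_rank for r in range(world_size)] + [global_zch_size]
    return bounds + [-1] * (SEGMENTS_LEN - len(bounds))


def offsets_present(saved_ws, cur_ws, global_zch_size):
    """True if every current rank's start offset appears in the saved segments."""
    saved = set(build_segments(global_zch_size, saved_ws))
    cps = global_zch_size // cur_ws
    return all(r * cps in saved for r in range(cur_ws))


print(offsets_present(4, 2, 24))  # True: 4 % 2 == 0, loading the saved tensor is fine
print(offsets_present(1, 2, 24))  # False: rank 1's offset 12 is not a saved boundary
print(offsets_present(6, 4, 24))  # False: offset 6 falls between saved boundaries 4 and 8
```

The shape is the same for every world size, which is why a plain shape check cannot catch the mismatch.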
…omments
- Move the saved/current world-size compatibility check into a small `_needs_mch_redistribution(model_ckpt_path, cur_world_size) -> bool` helper next to `_ckpt_world_size`.
- `_redistribute_mch_state` no longer takes `saved_world_size`; it reads it inline for the rank-0 log message.
- Drop the multi-paragraph block comments inside `restore_model` and `PartialLoadPlanner.create_local_plan`; the helpers' docstrings are the single source of truth for the divisibility rationale.
- Drop the optimizer-load explanatory comment.
Verified: `test_multi_tower_din_zch_finetune_world_size_change`, `test_multi_tower_din_zch_with_fg_train_eval_export`, and `test_multi_tower_din_zch_with_fg_train_eval_export_input_tile` all still pass.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The broad `._managed_collision_modules.` skip predicate also dropped `_current_iter_tensor`, which is a plain replicated buffer excluded from torchrec's `sharded_parameter_names` and holds a globally identical iteration counter used by `MCHManagedCollisionModule.profile` to drive `_eviction_interval` accounting. On the non-divisible world-size path (e.g. 1→2 finetune) the load was silently skipped and every rank's counter stayed at the post-init 0, resetting the eviction clock and desynchronizing the LFU/LRU/DistanceLFU eviction schedule from whatever the saved training state had reached.
Narrow the predicate to a small `_is_resharded_mc_buffer` helper that skips everything under `._managed_collision_modules.` *except* `_current_iter_tensor`. `_output_segments_tensor` is kept in the skip set: it is replicated but rank-specific, so the locally-rebuilt value from `rebuild_with_output_id_range` must be preserved when the saved boundaries don't align with the current ones.
Verified with a debug hook on 1→2 finetune of a 1-GPU ZCH checkpoint whose saved `_current_iter_tensor` was 8: both cur ranks' buffers now start at 8 after `restore_model` (was 0 before the fix).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
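A predicate with the behaviour described in this commit could look roughly like the sketch below. The real `_is_resharded_mc_buffer` is not shown on this page, so the body, the constants, and the example FQNs (including the `user_id` table name) are hypothetical.

```python
MC_PREFIX = "._managed_collision_modules."
KEEP_SUFFIX = "._current_iter_tensor"  # replicated counter that should still be loaded


def is_resharded_mc_buffer(fqn):
    """True for managed-collision buffers that should be skipped at load time."""
    return MC_PREFIX in fqn and not fqn.endswith(KEEP_SUFFIX)


# Hypothetical fully qualified names, for illustration only.
base = "model.mc_ebc._managed_collision_collection._managed_collision_modules.user_id"
print(is_resharded_mc_buffer(base + "._mch_sorted_raw_ids"))      # True
print(is_resharded_mc_buffer(base + "._output_segments_tensor"))  # True
print(is_resharded_mc_buffer(base + "._current_iter_tensor"))     # False
print(is_resharded_mc_buffer("model.dense.weight"))               # False
```

Matching on the suffix keeps the exception list explicit, so any new replicated buffer would be skipped by default until it is added on purpose.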
…ll global read
The previous design skipped every `_managed_collision_modules.*` buffer in `PartialLoadPlanner` and then had `_redistribute_mch_state` read the full saved `_mch_*` tensors from disk on every rank via an extra DCP load, keeping O(saved_total) CPU tensors per buffer per table per rank. For a table with zch_size=100M that is ~800 MB per replica per rank.
Switch to a collective value-aware permutation:
- `PartialLoadPlanner` now only skips `_output_segments_tensor` (the one replicated but rank-specific buffer whose saved boundaries do not contain the current ones when divisibility fails). Every `_mch_*` buffer and `_current_iter_tensor` is loaded via the normal DCP position-based path; after the load each rank already holds a cps-sized slice of the saved global state, just at the wrong position for the new sharding.
- `_redistribute_mch_state` now runs before `model.load_state_dict` and works off the live per-rank buffers: filter valid entries, compute `dest_rank = remapped // cps`, sort by `dest_rank`, exchange split sizes via one `all_to_all_single`, then exchange `_mch_sorted_raw_ids` / `_mch_remapped_ids_mapping` / each sharded metadata buffer via `all_to_all_single` with `input_split_sizes` / `output_split_sizes`, and finally place the received entries at `[0, n)` plus the unused local-range values as padding. Live buffers are updated via `buf.copy_(...)` so state_dict `ShardedTensor` wrappers pick up the new values; the subsequent `model.load_state_dict` post-hook (`validate_state` -> `_sort_mch_buffers`) re-sorts the new local table without any explicit call here.
- `_read_full_mch_tensors` is removed; `_redistribute_mch_state` no longer needs the checkpoint path argument.
- `PartialLoadPlanner` parameter renamed from `skip_mc_module_state` to `skip_output_segments_tensor` to reflect the narrowed scope.
Memory: worst-case O(cps) extra per rank instead of O(saved_total) on every rank.
I/O: one normal DCP load.
Comms: one small `all_to_all_single` for split sizes plus one per sharded MCH buffer per table.
Verified: all three ZCH integration tests still pass; manual 4->2 and 4->1 finetune (divisible, position-based); manual 2->4 finetune (non-divisible, all-to-all path, log confirms "Redistributing MCH (ZCH) state via all_to_all_single across 6 MC modules"); `_current_iter_tensor` carry-over check on 1->2 finetune still shows both cur ranks inherit the saved counter value (8, not 0).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
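The permutation described above can be simulated in a single process to see what each rank ends up with. This is a sketch under stated assumptions: the collective exchange is replaced by plain Python lists, `EMPTY` is an assumed sentinel for an unused slot, padding raw ids with that sentinel is a guess, and `redistribute` is a hypothetical name. It is not the PR's `_redistribute_mch_state`.

```python
EMPTY = 2**63 - 1  # assumed sentinel marking an unused slot (illustrative)


def redistribute(raw_ids, remapped, cps):
    """Simulate the value-aware exchange for per-rank lists of equal length cps."""
    world = len(raw_ids)
    # "Send" phase: every source rank routes each valid entry to the rank that
    # owns its global slot index (dest_rank = remapped // cps).
    inbox = [[] for _ in range(world)]
    for src in range(world):
        for rid, slot in zip(raw_ids[src], remapped[src]):
            if rid != EMPTY:
                inbox[slot // cps].append((rid, slot))
    # "Receive" phase: place received entries first, then pad with the slots of
    # the local range that no received entry claimed.
    new_raw, new_remapped = [], []
    for rank in range(world):
        received = inbox[rank]
        if len(received) > cps:
            raise RuntimeError(f"rank {rank}: received {len(received)} > {cps}")
        taken = {slot for _, slot in received}
        unused = [s for s in range(rank * cps, (rank + 1) * cps) if s not in taken]
        pad = cps - len(received)
        new_raw.append([rid for rid, _ in received] + [EMPTY] * pad)
        new_remapped.append([slot for _, slot in received] + unused[:pad])
    return new_raw, new_remapped


# State after a position-based load of a toy 1-GPU checkpoint into 2 ranks.
raw = [[10, 20, 30, 40], [50, 60, EMPTY, EMPTY]]
rem = [[5, 0, 7, 2], [1, 6, 3, 4]]
new_raw, new_rem = redistribute(raw, rem, cps=4)
print(new_rem)  # [[0, 2, 1, 3], [5, 7, 6, 4]]
```

After the exchange every rank holds only slot indices from its own range, and each raw id keeps the slot it had in the checkpoint. The result is left unsorted here, mirroring the commit's note that re-sorting happens later in the load post-hook.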
| try: | ||
| saved_world_size = _ckpt_world_size(model_ckpt_path) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to detect saved world size from {model_ckpt_path}: {e}") | ||
| return False |
Bug (deadlock risk): If _ckpt_world_size raises on some ranks but not others (e.g. NFS latency, partial file visibility), those ranks return False here while others return True. Since _redistribute_mch_state calls dist.all_to_all_single — a collective — the process group will deadlock if only a subset of ranks enters.
Consider synchronizing the decision across ranks after computing it locally:
needs = _needs_mch_redistribution(model_ckpt_path, cur_world_size)
if dist.is_initialized():
    flag = torch.tensor([int(needs)], device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    needs = flag.item() > 0
Or alternatively, let the exception propagate (remove the try/except) so all ranks fail consistently rather than silently diverging.
| ) | ||
| taken = torch.zeros(cps, dtype=torch.bool, device=local_dev) | ||
| if n > 0: | ||
| taken[remapped_recv - local_offset] = True |
Bug (silent data corruption): remapped_recv - local_offset is used as an index into taken (size cps) with no bounds check. If any received value falls outside [local_offset, local_offset + cps), PyTorch's negative-index wrapping will silently corrupt unrelated positions rather than raising an error.
Add a bounds check before indexing:
indices = remapped_recv - local_offset
if n > 0:
    assert (indices >= 0).all() and (indices < cps).all(), (
        f"MCH [{prefix}] rank {cur_rank}: received remapped values outside "
        f"local range [{local_offset}, {local_offset + cps})"
    )
    taken[indices] = True
| assert n <= cps, ( | ||
| f"MCH [{prefix}] rank {cur_rank}: received {n} > local zch_size {cps}" | ||
| ) |
Nit: assert is stripped under python -O. For a runtime invariant guarding data integrity in distributed code, prefer if ... : raise RuntimeError(...).
| assert n <= cps, ( | |
| f"MCH [{prefix}] rank {cur_rank}: received {n} > local zch_size {cps}" | |
| ) | |
| if n > cps: | |
| raise RuntimeError( | |
| f"MCH [{prefix}] rank {cur_rank}: received {n} > local zch_size {cps}" | |
| ) |
| ranks.add(int(m.group(1))) | ||
| if not ranks: | ||
| raise RuntimeError(f"No .distcp files under {ckpt_dir}") | ||
| return max(ranks) + 1 |
Suggestion: max(ranks) + 1 assumes contiguous rank numbering. If a checkpoint is partially written (e.g. rank 1 shard missing, only ranks 0 and 2 present), this silently returns the wrong world size, which could cascade into incorrect redistribution decisions.
Consider adding a consistency check:
expected = max(ranks) + 1
if len(ranks) != expected:
    raise RuntimeError(
        f"Checkpoint {ckpt_dir} has rank gaps: found {sorted(ranks)}, "
        f"expected contiguous 0..{expected - 1}"
    )
| if needs_mch_redistribution: | ||
| _redistribute_mch_state(model) | ||
| model.load_state_dict(state_dict) |
Fragility warning: This sequence relies on model.state_dict() (line 487) returning tensors that share underlying storage with model._buffers, so that _redistribute_mch_state's .copy_() into the buffers also updates what state_dict references. If PyTorch/TorchRec ever changes state_dict() to return detached copies, the redistributed values would be silently overwritten by pre-redistribution data in load_state_dict.
Consider explicitly copying the redistributed buffer values back into state_dict after redistribution, or adding a comment + assertion that the storage sharing invariant holds (e.g. assert state_dict[some_key].data_ptr() == model.some_buffer.data_ptr()).
Code Review Summary
Well-crafted fix for a gnarly problem: the value-aware [...]
Issues to address
1. Deadlock risk from inconsistent redistribution decision across ranks (comment)
2. Missing bounds check on [...]
3. [...]
4. [...]
5. Implicit storage-sharing dependency in [...]
Optimization opportunity
Fuse multiple [...]
Testing
The integration test covers the critical 1→2 GPU path and is correctly gated on GPU availability. Consider adding unit tests for the pure helper functions ( [...]
🤖 Generated with Claude Code
Summary
Loading a ZCH (managed-collision) checkpoint into a model with a different world size crashed in `MCHManagedCollisionModule.validate_state()`.
Two distinct issues stacked on top of each other:
1. `_output_segments_tensor` is a fixed-shape `[1025]` replicated buffer in `MCHManagedCollisionModule` whose contents are rank-specific partition boundaries. Its shape is identical across world sizes, so torch DCP silently overwrites the freshly-built local value with the saved one, then `validate_state()` fails because `_output_global_offset` (a Python int that is not in the state dict) no longer appears in the segments.
2. The per-position MCH lookup state (`_mch_sorted_raw_ids`, `_mch_remapped_ids_mapping`, `_mch_<metadata>`) is wrapped as `ShardedTensor` by `ShardedManagedCollisionCollection._initialize_torch_state` (torchrec `mc_modules.py:262`). Across matching world sizes this works fine, but across different world sizes torch DCP only does byte-level position slicing, and the values in `_mch_remapped_ids_mapping` are global slot indices in `[output_global_offset, output_global_offset + zch_size)`, so the loaded values fall outside the new local range and crash the FBGEMM TBE bounds check / CUDA gather kernel during the first training step.
Fix
- `PartialLoadPlanner` always skips `_output_segments_tensor`; the local buffer is rebuilt by `fix_mch_state` (extended to handle the post-`init_parameters` all-zeros case used by the export path), which is also called from `restore_model` so the load post-hook validation passes.
- `PartialLoadPlanner` also skips the sharded MCH state buffers, and `restore_model` runs a new `_redistribute_mch_state` pass that:
  - [...] `target_rank = saved_remapped_value // local_zch_size`. This matches the row-wise position-based sharding that `ShardedTensor` already uses for the embedding `weight`, so the `(raw_id → embedding row)` binding is preserved end-to-end across the world size change.
  - [...] `_sort_mch_buffers` to keep the binary-search invariant.
The embedding `weight` tensors are still loaded by torch DCP via the normal `ShardedTensor` path; the value-based bucket above is chosen to align with that position-based slice, so no extra weight movement is needed.
Test plan
- `test_multi_tower_din_zch_finetune_world_size_change` integration test in `tzrec/tests/rank_integration_test.py`: trains the ZCH multi-tower DIN config on 1 GPU, then runs `--fine_tune_checkpoint` on 2 GPUs.
- `test_multi_tower_din_zch_with_fg_train_eval_export` (same world size, ZCH train + export) still passes; it exercises the `fix_mch_state` export path.
- `test_multi_tower_din_fg_encoded_finetune` (non-ZCH finetune) still passes; sanity check that the planner skip does not affect non-MCH modules.
- [...] `Train and Evaluate Finished.` cleanly.
🤖 Generated with Claude Code