Background
Pi07's Gemma3WithExpertModel.embed_prefix has data-dependent branches that decide whether to add response / metadata / subgoal / prefix-end blocks based on whether any sample in the local batch has real (non-padded) content for that field. Each branch's body calls embed_language_tokens(...) and/or embed_image(...), which trigger forward calls into the FSDP-wrapped tree.
Under FSDP / ZeRO-3 with realistic stochastic *_drop_prob settings, different ranks roll different drop outcomes, take different branches, and end up issuing different counts of FSDP all-gather collectives → NCCL deadlock. PR #265 fixes this with a _global_any(local, device) helper that OR-reduces each branch decision across ranks via a 1-element MAX all-reduce, so every rank takes the same branch.
That fix works, but it has two warts:
- Three small all-reduces per embed_prefix call (tens of µs in total: negligible against step time, but still extra collectives).
- Ranks whose local micro-batch has no data for a field still embed pad tokens through the whole prefix when some other rank has that field's content. Their pad_masks are all-False, so attention/loss correctly ignore them, but every interleaved layer still spends compute carrying those pad slots.
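To make the failure mode and the OR-reduce fix concrete, the toy model below replays which collectives each rank would issue. It is a pure-Python illustration, not the Pi07 code: `collectives_for`, `global_any`, and the dict layout are hypothetical, and real FSDP issues all-gathers per wrapped submodule rather than one per field.

```python
# Toy model: each rank records the collectives it would issue in embed_prefix.
# All names here are illustrative; nothing below is the real Pi07 or FSDP API.
FIELDS = ("response", "metadata", "subgoal")


def collectives_for(local_has, decide):
    """Return the ordered list of collectives one rank issues.

    local_has: dict mapping field -> whether this rank has real content.
    decide:    callable(field, local_flag) -> bool, the branch decision.
    """
    issued = []
    for field in FIELDS:
        if decide(field, local_has[field]):
            # Stand-in for the all-gather triggered by embed_language_tokens.
            issued.append(f"all_gather:{field}")
    return issued


def global_any(ranks, field):
    """OR of a flag across ranks (what a 1-element MAX all-reduce computes)."""
    return max(int(r[field]) for r in ranks) == 1


# Two ranks whose stochastic drops came out differently.
ranks = [
    {"response": True, "metadata": False, "subgoal": True},
    {"response": False, "metadata": False, "subgoal": True},
]

# Local decisions: the ranks issue different collective sequences (deadlock risk).
local = [collectives_for(r, lambda f, flag: flag) for r in ranks]

# OR-reduced decisions: every rank issues the same sequence.
synced = [collectives_for(r, lambda f, flag: global_any(ranks, f)) for r in ranks]

print(local)
print(synced)
```

In the synced case the second rank also issues the response all-gather even though it has no response content, which is exactly the wasted-pad compute described in the second wart.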
Proposed cleanup
Hoist every embed_language_tokens(...) and embed_image(...) call out of the if has_* branches. Make the call unconditional (so every rank issues the same FSDP all-gather count) but keep the bookkeeping (embs.append, pad_masks.append, att_masks += [...]) gated on the local condition.
Sketch:
```python
# Always run: uniform FSDP all-gather count across ranks.
response_emb = (
    self.gemma3_with_expert.embed_language_tokens(response_tokens)
    if response_tokens is not None
    else None
)
# Bookkeeping stays conditional on the LOCAL data; no global sync needed.
if response_emb is not None and response_masks is not None and response_masks.any():
    embs.append(response_emb)
    pad_masks.append(response_masks)
    att_masks += [1] * response_emb.shape[1]
```
The same pattern applies to metadata, prefix_end, and subgoal_images (the subgoal body has 2× embed_language_tokens + N× embed_image; all should be hoisted).
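Since the pattern repeats for each field, it could be factored into a small helper. The sketch below is one possible shape, not code from the repository: `append_block` and its `att_value` parameter are hypothetical, and `embed_fn` stands in for `embed_language_tokens` or `embed_image`. It assumes `tokens` is `None` either on every rank or on none, since a rank-dependent `None` would reintroduce the call-count mismatch.

```python
def append_block(embed_fn, tokens, masks, embs, pad_masks, att_masks, att_value=1):
    """Embed unconditionally, then append only if this rank has real content.

    Hypothetical helper. embed_fn is called whenever tokens is not None, so
    every rank issues the same number of forward calls. The three lists are
    mutated in place only when the local mask has at least one True entry.
    Returns True if the block was appended.
    """
    if tokens is None:
        return False
    emb = embed_fn(tokens)  # always runs: keeps the collective count uniform
    if masks is None or not bool(masks.any()):
        return False  # embedding is discarded; this rank's prefix stays shorter
    embs.append(emb)
    pad_masks.append(masks)
    att_masks.extend([att_value] * emb.shape[1])
    return True
```

A call site for the response block would then pass `self.gemma3_with_expert.embed_language_tokens`, `response_tokens`, `response_masks`, and the three running lists. The subgoal block would need one call per hoisted embed.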
Different ranks would then have different prefix lengths — that's OK because:
- FSDP collectives are tied to Module.forward call counts, not input shapes; uniform call counts → no desync.
- Each rank's gemma3_with_expert.forward processes its own seq length; the per-layer InterleavedDecoderLayer / SiglipEncoderLayer forwards are still uniformly called once per layer.
- Loss / metric reductions go through accelerator.gather_for_metrics, which handles per-rank scalar reductions.
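The first two points can be sanity-checked on a toy model. The pure-Python sketch below counts embed calls and prefix lengths per rank under the hoisted pattern; `build_prefix`, `BASE_LEN`, `FIELD_LEN`, and the token counts are made up for illustration and do not come from the Pi07 code.

```python
# Toy check: hoisted embeds give equal call counts but unequal prefix lengths.
# Everything here is illustrative; no real FSDP or Pi07 API is involved.
BASE_LEN = 16  # hypothetical length of the always-present part of the prefix
FIELD_LEN = {"response": 8, "metadata": 4, "prefix_end": 1}  # made-up sizes


def build_prefix(local_has):
    """Return (embed_calls, prefix_len) for one rank under the hoisted pattern."""
    embed_calls = 0
    prefix_len = BASE_LEN
    for field, length in FIELD_LEN.items():
        embed_calls += 1  # unconditional: stands in for embed_language_tokens
        if local_has[field]:  # bookkeeping gated on LOCAL content only
            prefix_len += length
    return embed_calls, prefix_len


ranks = [
    {"response": True, "metadata": True, "prefix_end": True},
    {"response": False, "metadata": True, "prefix_end": False},
]
results = [build_prefix(r) for r in ranks]
print(results)  # [(3, 29), (3, 20)]
```

Both ranks make three embed calls while their prefix lengths differ (29 vs 20), which is the situation the bullets above argue is safe. The toy says nothing about downstream ops, so the verification checklist still applies.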
Verification needed before merging
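The above only holds if no downstream op depends on a uniform prefix structure across ranks. Quick checklist to walk before merging:
- embed_prefix callers (PI07LowLevelFlowMatching.forward and friends) work with per-rank-variable prefix_embs.shape[1].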
Win
Drops the 3 all-reduces and the wasted-pad compute. Restores per-rank-honest prefix structure. Aesthetic match to the existing pattern of "the unconditional path always runs; the optional path appends" elsewhere in the codebase.
Out of scope (already handled in PR #265)
The _global_any fix is already shipped and verified — it works, just leaves these efficiency / cleanliness tradeoffs on the table. This issue is for the follow-up cleanup.