feat(pi07,fsdp): enable FSDP-FULL_SHARD for full unfreeze + profile_step audit + ZeRO-2 vs FSDP matrix#273
Conversation
Adds `configs/examples/pi07_low_level_libero.json` for the real pi07 (Gemma 3 backbone) variant — defaults `attention_implementation="sdpa"` and `gradient_checkpointing=true`, both already plumbed end-to-end. The existing `pi07_libero.json` (legacy pi07-PaliGemma variant) is left untouched. Defaults `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` in `opentau-train` / `opentau-eval` (`scripts/launch.py`) when the user hasn't set it themselves, with a one-line print so the value is visible on every launch. The 5B+ pi07 model on A100-80GB sits at the allocator's edge: issue #264 reproduces an OOM at 79.21/79.25 GiB used with 3.4 GiB reserved-but-unallocated, exactly the fragmentation pattern expandable_segments resolves.
Three coupled changes that together let pi07 full-unfreeze training run
under native PyTorch FSDP instead of DeepSpeed ZeRO-2:
1. Loosen the gradient_checkpointing guard in train.py to allow FSDP.
With `use_reentrant=False` (the existing call style at
gemma3_with_expert.py:684), `torch.utils.checkpoint` uses
`saved_tensors_hooks` which co-exist with FSDP's all-gather hooks in
PyTorch >=2.4. DeepSpeed ZeRO-3 is still rejected; its hook model
needs `deepspeed.checkpointing.checkpoint`, not torch's.
2. Skip the `MasterWeightOptimizer` wrap and the eager `.to(bfloat16)`
under FSDP. FSDP's own MixedPrecision (param_dtype=bf16,
reduce_dtype=fp32) already provides the bf16-compute / fp32-master
semantics; double-mastering would burn `4 B/p × trainable`.
3. Ship two new accelerate configs:
- `accelerate_fsdp_config.yaml` — FULL_SHARD with TRANSFORMER_BASED_WRAP
over Gemma3DecoderLayer, GemmaDecoderLayer, SiglipEncoderLayer.
`use_orig_params: true` keeps the existing optimizer/param-group
plumbing working; `SHARDED_STATE_DICT` for fast resume parity.
- `accelerate_deepspeed_zero3_config.yaml` — for the ZeRO-3 baseline
comparison point (gradient_checkpointing must be off under ZeRO-3).
…xample
The pi07 example config sets `n_obs_steps: 8` on the policy but didn't
mirror it on the dataset_mixture, so `cfg.validate()` in the train
pipeline raises:
policy.n_obs_steps (8) != dataset_mixture.n_obs_history
(None; treated as 1 when unset)
Setting `n_obs_history: 8` and `history_interval: 1` on the
dataset_mixture matches the policy and lets the example train cleanly.
Pure config fix to a public example file; no source-code changes.
DeepSpeed ZeRO-3 installs a `partition_parameters` wrapper that shards each parameter as it is constructed. The accelerate field `zero3_init_flag` only controls the transformers `from_pretrained` integration, not this construction-time wrapper. Several init paths in our model graph need the full tensor shape and crash on the per-rank shard. Concretely, transformers' SigLIP `_init_weights` calls `lecun_normal_` → `_calculate_fan_in_and_fan_out`, which raises `ValueError: Fan in and fan out can not be computed for tensor with fewer than 2 dimensions` when given a sharded weight. Wrap `make_policy(...)` in `deepspeed.zero.Init(enabled=False)` for the duration of construction. ZeRO re-shards correctly when `accelerator.prepare` later wraps the model, so steady-state behaviour is unchanged. The helper `_zero3_disabled_init_context` no-ops on every non-DeepSpeed-ZeRO-3 backend, so DDP / single / DeepSpeed-ZeRO-1/2 / FSDP code paths are untouched. profile_step.py mirrors the same wrap so the diagnostic script stays in lockstep with train.py.
…O-3 init The previous attempt used `deepspeed.zero.Init(enabled=False)` to suppress ZeRO-3 partitioning during model construction, but it had no effect: the partition wrapper is gated on transformers' `is_deepspeed_zero3_enabled()` flag, which stays True under `DistributedType.DEEPSPEED` + `zero_stage: 3` regardless of the accelerate yaml's `zero3_init_flag` (that field only controls the transformers `from_pretrained` path). Verified empirically: probe under the ZeRO-3 yaml shows `is_deepspeed_zero3_enabled() == True` even with `zero3_init_flag: false`, and `unset_hf_deepspeed_config()` flips it to False. Switch the helper to a `@contextmanager` that explicitly unsets and later restores the active `HfDeepSpeedConfig` for the duration of `make_policy`. ZeRO-3 still partitions correctly when `accelerator.prepare` later wraps the model — only the construction-time interception is suppressed.
Pi07's interleaved Gemma 3 backbone + Gemma-v1 expert decoder previously ran one layer-step as a free-floating method `_run_layer` on the parent model. The body reaches into sub-components of both the backbone and expert layer (`layer.input_layernorm`, `layer.self_attn.q_proj`, ...) directly. Under FSDP / ZeRO-3 those accesses bypass the param-sharding backend's all-gather hooks (which fire on the wrapped layer's own `forward`) and produce either silently-wrong gradients (sub-component sees a sharded weight) or a hard NCCL hang (different ranks request different params). Refactor: - New `InterleavedDecoderLayer(nn.Module)` owns one backbone layer and one expert layer as submodules. Its `forward` is the body of the old `_run_layer`, parameterised by `layer_idx` (still needed for the KV cache key) and a list of input embeddings. - `Gemma3WithExpertModel.__init__` builds the gemma3 + gemma_expert as before, then reparents per-layer modules out of `gemma3.language_model.model.layers` and `gemma_expert.model.layers` into a flat `interleaved_layers: nn.ModuleList[InterleavedDecoderLayer]`. The source ModuleLists are emptied so each Parameter is registered exactly once (FSDP / ZeRO-3 walk the module tree and would otherwise try to wrap each layer twice). - `set_requires_grad` and `train` updated to also visit `interleaved_layers[i].backbone_layer` when freezing the backbone for `train_expert_only=True`. - `to_bfloat16_like_physical_intelligence` selector swapped from `language_model.model.layers` / `gemma_expert.model.layers` to the new `interleaved_layers` prefix. - Main `forward` iterates `self.interleaved_layers` and calls each as a unit-of-forward; gradient checkpointing now wraps the bundled module rather than the bare method, so checkpoint recompute is FSDP-safe. - `accelerate_fsdp_config.yaml` updated: `fsdp_transformer_layer_cls_to_wrap: InterleavedDecoderLayer,SiglipEncoderLayer`. Wrapping at the bundled level is what makes the all-gather fire before the sub-component accesses inside `forward`. CPU pi07 tests (76 in tests/policies/test_pi07_*.py) pass. GPU verification (FSDP and ZeRO-3 smoke runs) lands in follow-up commits.
Mirrors the train.py guard added in 63df8c5. profile_step.py was reverted during the FSDP1 attempt and the FSDP-skip never got reapplied. Wrapping MasterWeightOptimizer under FSDP1 misaligns the wrapper's fp32 master list with FSDP's FlatParameter handles, which (combined with mixed precision: bf16) sends rank 2/6 down a different code path during the first backward — observed as a NCCL ALLGATHER_BASE timeout with rank 2/6 at SeqNum=279 while ranks 0/1/3/4/5/7 stuck at SeqNum=274.
Default DatasetMixtureConfig leaves history_state/subgoal/response/metadata drops stochastic (e.g. subgoal_drop_prob=0.75). With per-sample rolls and data-parallel sharding, different ranks see different feature subsets, which makes pi07's InterleavedDecoderLayer.forward take rank-dependent branches (the `inputs_embeds[i] is None` checks on the backbone vs expert streams). The number of FSDP all-gather collectives launched per step then differs across ranks, producing the NCCL ALLREDUCE / ALLGATHER_BASE desync observed at SeqNum=274 (rank 0,1,3,4,5,7) vs 279 (rank 2,6). Match the existing pi07_libero.json (paligemma variant) and pin every drop_prob to 1.0. This drops history_state/subgoal/response/metadata unconditionally — not a quality regression for a smoke config and the only path that lets the example train cleanly under any param-sharding backend (FSDP, ZeRO-3).
The `InterleavedDecoderLayer` previously captured `model.get_attention_interface()` at init, which silently bypassed any monkey-patched `eager_attention_forward` / `sdpa_attention_forward` on the parent (the existing test pattern in `TestNoSlidingWindowEnforcement`). CI failure on PR #265: FAILED tests/policies/test_pi07_cpu.py::TestNoSlidingWindowEnforcement:: test_per_layer_mask_equals_input_mask_on_both_layer_types - AssertionError: expected one capture per layer, assert 0 == 2 Fix: store the parent's bound `get_attention_interface` method as `_attention_interface_provider` and call it on every forward to look up the current dispatch. The bound-method capture is one indirection that honours runtime patches. Also adds `TestInterleavedDecoderLayer` (9 tests) pinning the architectural invariants of the refactor: * One InterleavedDecoderLayer per text-config layer. * Source ModuleLists (`gemma3.language_model.layers`, `gemma_expert.model.layers`) are emptied so each Parameter is registered exactly once (FSDP / ZeRO-3 walk the tree and would otherwise wrap each layer twice). * Every Parameter object appears exactly once in `named_parameters`. * State-dict keys live under `interleaved_layers.X.{backbone,expert}_layer.*` and the old `language_model.model.layers.X.*` / `gemma_expert.model.layers.X.*` paths are gone. * `train_expert_only=True` freezes backbone layers under `interleaved_layers[i].backbone_layer` (used to be reachable via `self.gemma3.parameters()` only). * `train(mode=True)` with `train_expert_only=True` flips backbone layers to `.training=False` (was previously handled by walking `self.gemma3`). * Attention dispatch resolved at call time (the regression that broke CI). * Two seeded forwards are bit-identical. * The cached attention dispatch provider doesn't leak into state_dict.
…#264) `DatasetMixtureMetadata._to_standard_data_format` pads missing cameras up to `cfg.num_cams` and previously copied a `count` field from `standard_stats["state"]`. v2.0-format datasets like `physical-intelligence/libero` only carry `mean/std/min/max` — no `count` — so this raised `KeyError: 'count'` when `num_cams > actual_cams` (issue #264). Make the `count` copy conditional: if `state` has a `count` (v2.1+), mirror it; otherwise omit. The fields actually validated by the subsequent missing-keys check are `{mean, std, min, max}`, so omitting `count` is safe. Unblocks the `num_cams ∈ {3,4}` benchmark sweep on libero.
Under FSDP2 (`torch.distributed.fsdp.fully_shard`), parameters and
buffers attached to a wrapped module — including small replicated
ones like Normalize / Unnormalize stats — are exposed as `DTensor`.
Mixing a `DTensor` with a regular `Tensor` in arithmetic raises:
RuntimeError: aten.sub.Tensor got mixed torch.Tensor and DTensor,
need to convert all torch.Tensor to DTensor before calling
distributed operators!
This was the immediate blocker for FSDP2 on pi07: the very first
forward call hit `Normalize.forward`'s `(batch[key] - min) / (max - min + EPS)`,
where `min`/`max` were DTensors but `batch[key]` was a plain Tensor.
Add a tiny `_materialize(tensor)` helper that calls `.full_tensor()`
when the input is a `DTensor` (FSDP2 path) and returns the input as-is
otherwise (FSDP1 / DDP / single-process — fast no-op). Apply at every
stat read in `Normalize.forward` and `Unnormalize.forward`.
`full_tensor()` for these stats is cheap: they are tiny (per-feature
mean/std/min/max) and their sharding placement is "replicated" by
default, so the call is essentially a pointer-copy on each rank, no
collective involved.
CPU normalize tests still pass. Unblocks FSDP2 from its first failure
mode; further audit for other Tensor/DTensor mixing sites is a separate
change (see plan file for the FSDP2 status).
Adds ``configs/examples/accelerate_fsdp2_config.yaml`` so future work has a starting point for FSDP2. As of this commit FSDP2 is **not yet usable end-to-end** for pi07_low_level — verified on 8×A100, the Normalize fix in f3ba8a3 plus wrapping ``SiglipVisionEmbeddings`` here unblocks two of the three known DTensor-mixing sites, but ``SpaceTimeSiglipVideoEncoder.forward`` at ``video_encoder.py:346`` still calls ``self.layer_norm1(t_in)`` on a sub-component of the wrapped vision tower without going through that sub-component's own ``forward``, so the tensor stays a plain Tensor while the layernorm weight is a DTensor and the layernorm op refuses to mix. Fixing that requires the same architectural pattern as the ``InterleavedDecoderLayer`` refactor (commit aaefa6e): pull every direct sub-component access in ``video_encoder.py`` (and the temporal-attention sublayer it adds on top) into bundled ``nn.Module``s so FSDP2 has a single forward to hook on. This is ~half-day of focused work; deferred. FSDP1 (``accelerate_fsdp_config.yaml``) is the recommended path for now and delivers 19.0 samples/s @ bs=12 on 8×A100 (verified 200 measured steps); see PR #265 description for the full comparison.
The pi07 ``embed_prefix`` had data-dependent Python branches that decided
whether to emit response / subgoal / metadata blocks based on whether any
sample in the *local* batch had real (non-padded) tokens for that field.
The branches were optimizations: skipping a fully-dropped block avoids
(a) appending a spurious ``[1]`` causal-block boundary to ``att_masks``
that would shift every subsequent block id via cumsum, (b) injecting
real text indicators ("Subgoal: ", ";\n ") into the prefix when there is
no real content behind them, and (c) the state-end separator switch
between ", " (optional follows) and ":\n" (terminator).
Under FSDP / ZeRO-3 this was the NCCL desync trigger: each rank evaluates
``.any()`` on its own micro-batch, so under realistic stochastic
``*_drop_prob`` rolls one rank's batch can roll "at least one real" while
another's rolls "all dropped". The two ranks then issue different numbers
of ``embed_language_tokens`` calls — different number of FSDP all-gather
collectives — and NCCL hangs (observed: rank 0,1,3,4,5,7 stuck at
SeqNum=274 ALLREDUCE, rank 2,6 at SeqNum=279 ALLGATHER_BASE).
Fix: keep the optimization but synchronize the *branch decision* across
ranks. New ``_global_any(local: bool, device) -> bool`` helper does a
1-element MAX all-reduce when ``torch.distributed`` is initialised, and
returns the local value otherwise (CPU smoke / single-process). Compute
``has_response`` / ``has_subgoal`` / ``has_metadata`` once via the
helper and use those globally-synced flags at every branch site instead
of re-checking local ``.any()``.
Cost: 3 tiny scalar all-reduces per ``embed_prefix`` call, ~tens of µs
total — negligible against any real step time.
Side effect: ranks whose local micro-batch is all-dropped for a given
field still run the corresponding ``embed_language_tokens`` call (on
their own pad tokens) when some other rank's batch has real data. The
``pad_masks`` propagated downstream are still all-False on that rank,
so attention / CE loss correctly skip the padded tokens — accuracy is
preserved.
Drops back to honest defaults: pi07_low_level_libero.json no longer
forces every ``*_drop_prob: 1.0``. The earlier 1.0 setting was a
workaround for the desync this commit actually fixes.
Resolves conflicts from #263 (disable_action_expert) overlapping with the InterleavedDecoderLayer refactor: - gemma3_with_expert.py: drop the dead _run_layer block; thread disable_action_expert through _build_interleaved_layers and the new InterleavedDecoderLayer signature (expert_layer is Optional, the forward path's stream_idx==1 short-circuit makes the None safe). - test_pi07_cpu.py: keep both new test classes (TestInterleavedDecoderLayer + TestDisableActionExpert).
Review feedback on PR #265: the in-line comments at train.py and profile_step.py implied that FSDP MixedPrecision(param_dtype=bf16, reduce_dtype=fp32) gives an fp32-master / bf16-compute split. It does not — Gemma3WithExpertModel.__init__ already calls to_bfloat16_like_physical_intelligence() unconditionally, so the params are bf16 *before* FSDP wraps and MixedPrecision becomes a storage no-op. The optimizer steps on bf16 sharded params; only the gradient reduce-scatter touches fp32. Comments now describe this honestly. Also clarifies in modeling_pi07_low_level.py:_global_any that the branch-decision sync only protects against per-rank .any() divergence on a field that is uniformly *present* across ranks (true today because field presence is a global property of the dataset config); a future heterogeneous-mixture change with per-rank None-vs-Tensor variation would need its own sync.
…6 cast
Adds ``Gemma3WithExpertConfig.disable_internal_bf16_cast`` (default False).
When True, ``Gemma3WithExpertModel.__init__`` skips the unconditional
``to_bfloat16_like_physical_intelligence`` call so the model's outer
params stay in their constructed dtype (fp32). train.py and
profile_step.py now set the flag whenever the accelerator distributed
type is FSDP — matching the way the outer ``policy.to(bfloat16)`` cast
is also gated.
Net effect under FSDP-FULL_SHARD:
* Outer (master) params: fp32, sharded across ranks
* Compute params: bf16 transient, materialized by FSDP MixedPrecision
on every all-gather (param_dtype=bf16)
* Gradient reduce-scatter: bf16 (accelerate maps mixed_precision: bf16
to reduce_dtype=bf16, matching DeepSpeed BF16_Optimizer)
* AdamW state: fp32 (built over fp32 outer params)
Before this change the inner cast made params bf16 before FSDP wrapped,
so MixedPrecision became a storage no-op and AdamW's exp_avg /
exp_avg_sq were silently allocated in bf16 — a regression vs DDP /
DeepSpeed which both keep fp32 master + fp32 Adam state. The 7-bit
mantissa of bf16 underflows on small late-training updates.
Other backends (DDP, single, DeepSpeed ZeRO-1/2) are unchanged: the
flag stays False, the inner cast still runs, and either
MasterWeightOptimizer or BF16_Optimizer layers fp32 master back on top
as before.
Memory cost on 8×A100: roughly +4 GB/rank for params+states (fp32
sharded vs bf16 sharded), comfortably within the 80 GB budget.
Tests: 2 new ``TestDisableInternalBf16Cast`` cases pin both directions
(default keeps the cast, flag skips it). All 64 pi07 CPU tests pass.
…table
Four targeted fixes to make the benchmark numbers trustworthy. None
change the per-step computation (forward / backward / optim are
untouched), so the seeded loss series is unchanged.
B1. Throughput is now global, not rank-0. Gather per-rank step-time
means across ranks; use max() (slowest rank gates collectives) for
the headline samples/s figure. Per-rank means also printed for
diagnostics so straggler ranks are visible.
B2. Peak HBM tracking. Reset cuda peak stats at the warmup->measured
transition; gather max_memory_{allocated,reserved} across ranks
and report the worst-rank ceiling alongside per-rank values.
Also written into the JSON dump.
B3. backward_step dropped from the human-readable phase table. It
was an aggregate of (bwd + unscale_clip + optim_step +
zero_grad_sched) and printed alongside its components made shares
sum to ~135% rather than ~100%. Still collected and emitted into
the JSON dump for backward compatibility.
B4. assert gradient_accumulation_steps == 1. The loop has no
with accelerator.accumulate(policy): context, so optim fires
every micro-step regardless of cfg.gradient_accumulation_steps;
ga>1 silently produced a throughput number ga-times too high.
Sweep dataloader_batch_size for per-rank batch instead.
Also adds a DeepSpeed banner clarifying that under DeepSpeed the optim
work fuses into backward hooks, so optim_step / unscale_clip will
appear near-zero - only TOTAL step time is comparable across backends.
…k matrix The closed PR #265 ran the FSDP/ZeRO-2 matrix at n_obs_history=8. Per the latest hardware-budget analysis we want n_obs_history=6 for the cam=4 pretraining target this benchmark feeds, so align the example config so the matrix runs against the same value the production run will use. - dataset_mixture.n_obs_history: 8 -> 6 - policy.n_obs_steps: 8 -> 6 - policy.n_obs_history: 8 -> 6 history_interval stays at 1.
|
[claude-review] summary for commit 78c84c3 No blocking issues found. Recheck on 78c84c3 ( Architecture / refactor invariants on this branch all check out:
(Stale inline comments from the c670608 reviews can't be dismissed via the API — GitHub only accepts dismissals on CHANGES_REQUESTED / APPROVED reviews — but they're superseded by this summary.) |
|
@claude fix per review. |
- addresses @claude[bot] (profile_step.py stale FSDP comment): rewrite the rationale block at profile_step.py:330 to match the post-30ff84a regime (FSDP runs with fp32 master + fp32 Adam state via disable_internal_bf16_cast, and the reason for skipping MasterWeightOptimizer is flat-param handle misalignment + redundant fp32 wrap on top of FSDP's MixedPrecision, not bf16 storage). - addresses @claude[bot] (perf, _global_any batching): replace per-flag _global_any (3× 1-element MAX all-reduces + 3× .item() syncs per embed_prefix) with a single _global_or_branch_decisions SUM all-reduce over all flags. Cuts collective count and GPU→CPU syncs by 3× per step. - addresses @claude[bot] (defensive presence sync): same call also OR-reduces per-field presence flags and raises RuntimeError on cross-rank divergence (None vs Tensor). Future heterogeneous-mixture regressions fail loud at the embed_prefix contract instead of as a hung NCCL collective hours into a run. - addresses @claude[bot] (gpu-marked FSDP wrap test): add TestGemma3WithExpertFSDPWrap::test_fsdp_wrap_forward_backward — single- rank FSDP wrap on the InterleavedDecoderLayer-based model with one forward+backward, asserting the wrap actually picked up InterleavedDecoderLayer (not a single root-level wrap that would mask regressions). Plus three CPU tests for the new helper's no-distributed path and input-validation contract. tests: passed — pytest -m "not gpu" -n auto tests/policies/test_pi07_cpu.py tests/policies/test_pi07_low_level.py tests/scripts/test_profile_step.py tests/scripts/test_train.py Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Recheck on c670608. Prior findings (stale FSDP comment, _global_any batching, defensive presence sync, gpu-marked FSDP wrap test) all look correctly addressed in this commit; one minor comment-vs-yaml inconsistency below.
| # cast, so by the time ``accelerator.prepare`` wraps the policy under | ||
| # FSDP, parameters are still fp32. The optimizer is therefore built | ||
| # over fp32 outer params, and FSDP's ``MixedPrecision(param_dtype=bf16, | ||
| # reduce_dtype=fp32, buffer_dtype=fp32)`` provides bf16 compute on the |
There was a problem hiding this comment.
nit (stale comment, second pass) — this rewritten block now says FSDP runs with MixedPrecision(param_dtype=bf16, reduce_dtype=fp32, buffer_dtype=fp32), but configs/examples/accelerate_fsdp_config.yaml sets mixed_precision: bf16 (no overrides for fsdp_reduce_dtype/fsdp_buffer_dtype), which accelerate translates to MixedPrecision(param_dtype=bf16, reduce_dtype=bf16, buffer_dtype=bf16). The corresponding rationale block in train.py:378-379 correctly says all-bf16. The substantive argument (skip the MasterWeightOptimizer wrap because FSDP already provides fp32 master via the fp32-built outer params, and the wrap would misalign with FSDP's flat-param handles) is unaffected; just the parenthesised MixedPrecision(...) summary is wrong here. Worth a one-line fix to keep the two comments in sync.
The post-c670608 rationale block at profile_step.py:330 said FSDP runs with ``MixedPrecision(param_dtype=bf16, reduce_dtype=fp32, buffer_dtype=fp32)``, but ``configs/examples/accelerate_fsdp_config.yaml`` sets only ``mixed_precision: bf16`` with no ``fsdp_reduce_dtype`` / ``fsdp_buffer_dtype`` overrides. Verified against accelerate 1.12.0 ``FullyShardedDataParallelPlugin.set_mixed_precision``: that translates to all-bf16 (``param_dtype=bf16, reduce_dtype=bf16, buffer_dtype=bf16``), matching the parallel comment in train.py:378-379. The substantive argument (skip MasterWeightOptimizer because FSDP already gives fp32 master via fp32-built outer params + a wrap would misalign with FSDP's flat-param handles) is unaffected — ``param_dtype`` only controls the compute-time downcast, not the storage of the outer master shards. Addresses second-pass review note from claude-pr-review on PR #273.
What this does
Enables FSDP-FULL_SHARD training for pi07_low_level (full unfreeze), audits
and fixes
profile_step.pymeasurement bugs that caused unreconcilablenumbers in the prior closed PR, and runs a full FSDP vs ZeRO-2 throughput
upcoming cam=4 pretraining run.
The branch carries the FSDP-enabling code from earlier work (the closed
PR was on this same branch) plus two new commits with the
measurement-side fixes:
aaefa6e—InterleavedDecoderLayerrefactor inpolicies/pi07/gemma3_with_expert.py. Bundles each (Gemma3 backbonelayer + Gemma-v1 expert layer) pair into one
nn.Moduleso the FSDPall-gather hook fires before any sub-component access. Without this,
FSDP1 hangs at NCCL with mismatched all-gather sizes.
1f27b98—_global_anybranch sync inpolicies/pi07/low_level/modeling_pi07_low_level.py. OR-reducesper-rank decisions to enter
if has_response / has_subgoal / has_metadatablocks so all ranks issue the same number of FSDPcollectives.
30ff84a—disable_internal_bf16_castflag onGemma3WithExpertConfigplus FSDP-conditional setter intrain.pyand
profile_step.py. Keeps the policy fp32 going into FSDP soMixedPrecision can manage the bf16-compute / fp32-master split itself
(preserves fp32 Adam state).
f43fc91—unset HfDeepSpeedConfigcontext manager somake_policydoesn't fault under ZeRO-3 init.63df8c5— relax thegradient_checkpointingguard to allow FSDP;ships
accelerate_fsdp_config.yamlandaccelerate_deepspeed_zero3_config.yaml.3bd7817— NEW fourprofile_step.pybug fixes (B1-B4) +DeepSpeed phase-blur banner. Detail below.
ffd6208— NEWpi07_low_level_libero.jsonn_obs_history8→6for the matrix.
profile_step.py audit findings (fixed in 3bd7817)
The prior numbers couldn't reconcile because the script had four real
bugs in measurement / aggregation. None affect the per-step computation
(forward / backward / optim are untouched), so the seeded loss series is
unchanged.
num_processes. Under stragglers this is optimistic.max()(slowest rank gates collectives). Print per-rank means too.torch.cuda.reset_peak_memory_stats()at warmup→measured boundary; gathermax_memory_{allocated,reserved}across ranks; report worst rank + per-rank values; include in JSON dump.backward_step(an aggregate of bwd+unscale_clip+optim_step+zero_grad_sched) was printed alongside its components, so the share column summed to ~135% rather than ~100%.with accelerator.accumulate(policy):context, sooptimizer.step()runs every micro-step regardless ofcfg.gradient_accumulation_steps. ga>1 silently produced a throughput number ga× too high.assert cfg.gradient_accumulation_steps == 1at script start. Sweepdataloader_batch_sizefor per-rank batch instead.A documented limitation (banner in output, not fixed): under DeepSpeed,
optim work fuses into backward hooks, so
optim_step/unscale_clipappear near-zero (tens of µs). Compare TOTAL step time across backends,
not per-phase breakdowns.
How it was tested
Unit tests (CPU)
88 passed.
Integration: full FSDP vs ZeRO-2 matrix on an internal 8×A100 node
10 cells × probe-walk-to-OOM at PROFILE_STEPS=10 + 20 warmup. Branch tip
n_obs_history=6example config + full unfreeze(
--policy.freeze_vision_encoder=false --policy.train_expert_only=false).bf16 mixed precision,
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,gradient_accumulation_steps=1, A100 80 GiB.Probe matrix (raw, all completed cells)
Bold rows = throughput peak per cell. SHARD_GRAD_OP cam=4 sweep was
queued but stopped early — the cam=1 data already shows SGO behaves
like FULL_SHARD on memory (both ~70 GB reserved at bs=8) and collapses
under memory pressure rather than approaching ZeRO-2's profile.
Headline backend comparison (sdpa+grad_ckpt only)
Reconciliation vs theory (closed-PR's predictions)
"ZeRO-2 has +11 GB persistent vs FSDP" — REFUTED. Observation:
ZeRO-2 cam=1 bs=8 alloc=36.22 GB, FSDP=43.86 GB — ZeRO-2 has 7.6 GB
less. The +11 GB is real for replicated params, but FSDP carries
even more than that in transient unsharded-layer buffers (FlatParam
all-gather staging held during forward+backward). Net: FSDP has the
higher peak ceiling AND much worse allocator fragmentation
(reserved-vs-allocated gap of 26 GB for FSDP cam=1 bs=8 vs 0.2 GB
for ZeRO-2).
FSDP per-step overhead "10× the all-gather prediction" — CONFIRMED
AND EXPLAINED. At cam=4 bs=8 FSDP is +3120 ms vs ZeRO-2 (vs ~440
ms predicted). 88% of the gap is in the backward (FSDP 6633 ms vs
ZeRO-2 3900 ms): every layer pays an all-gather for grad
computation on top of the reduce-scatter — ZeRO-2 only pays
reduce-scatter because params are replicated. Add to that grad_ckpt
activation recompute, which fires another all-gather per
checkpointed segment.
"FSDP +1.27× over ZeRO-2 baseline" headline (closed PR) —
REFUTED. That comparison was apples-to-oranges (FSDP-with-sdpa-ckpt
vs ZeRO-2-eager-no-ckpt). Apples-to-apples, ZeRO-2 wins 1.62× to
1.96×.
SHARD_GRAD_OP doesn't recover ZeRO-2's profile. SGO keeps params
replicated like ZeRO-2 in theory, but in practice the FSDP framework
wrapping (TRANSFORMER_BASED_WRAP, prefetch, hooks) keeps the
allocator pressure pattern of FULL_SHARD. The FSDP/DeepSpeed gap is
not just about FULL_SHARD's all-gather — it's about the framework's
memory and execution patterns end-to-end.
eager+no-ckpt is strictly worse than sdpa+grad_ckpt for full
unfreeze at this scale. Without checkpointing the activation
memory blows up; even at the maximum-fitting batch, throughput is
~2× lower (e.g. ZeRO-2 cam=1 best 12.3 sps without ckpt vs 26.7 sps
with ckpt).
Recommendation
Use ZeRO-2 with
accelerate_deepspeed_config.yamlatbs=20perrank for the cam=4 pre-training run — 13.7 global samples/s (1.96×
FSDP), 79% of HBM (real headroom). The
disable_internal_bf16_castflag should NOT be set under ZeRO-2; itonly helps FSDP's MixedPrecision.
If FSDP is required for an unrelated reason, use cam=1 bs=12 or
cam=4 bs=8 — but expect ~half the throughput and a 91-98% memory
budget.
How to checkout & try? (for the reviewer)
CPU sanity:
pytest -m "not gpu" -n auto tests/policies/test_pi07_cpu.py \ tests/scripts/test_profile_step.py \ tests/scripts/test_train.pyReproduce one matrix cell on an 8×A100 node (the recommended
production setting):
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True PROFILE_STEPS=50 \ accelerate launch \ --config_file configs/examples/accelerate_deepspeed_config.yaml \ --num_processes=8 \ src/opentau/scripts/profile_step.py \ --config_path=configs/examples/pi07_low_level_libero.json \ --batch_size=20 --dataloader_batch_size=20 \ --gradient_accumulation_steps=1 --num_cams=4 \ --policy.freeze_vision_encoder=false --policy.train_expert_only=false \ --policy.gradient_checkpointing=true \ --policy.attention_implementation=sdpa \ --wandb.enable=falseThe same invocation against
accelerate_fsdp_config.yamlreproducesthe FSDP comparison cell.
Checklist