Skip to content

Merge claude/verify-precision-issue-SjtJf into claude/wonderful-fermat-0831fd#186

Closed
shuheng-liu wants to merge 19 commits into
claude/wonderful-fermat-0831fdfrom
claude/verify-precision-issue-SjtJf
Closed

Merge claude/verify-precision-issue-SjtJf into claude/wonderful-fermat-0831fd#186
shuheng-liu wants to merge 19 commits into
claude/wonderful-fermat-0831fdfrom
claude/verify-precision-issue-SjtJf

Conversation

@shuheng-liu
Copy link
Copy Markdown
Member

Summary

Merges the precision-fix and pi05 perf/feature work from claude/verify-precision-issue-SjtJf into the pi05_mem development branch claude/wonderful-fermat-0831fd.

Why this helps pi0.5-mem

pi05_mem reuses PaliGemmaWithExpertModel from opentau.policies.pi05.paligemma_with_expert and trains through the same src/opentau/scripts/train.py pipeline as pi05, so it inherits every fix and feature in this branch.

Critical correctness fixes (PR #181)

  • 2009918 fix(train): avoid bf16 optimizer state under DDP — under DDP the previous unconditional policy.to(torch.bfloat16) made torch.zeros_like(p) allocate Adam's exp_avg/exp_avg_sq in bf16 with eps=1e-8, silently corrupting training (pi05 LIBERO: 0% under DDP vs 94% under DeepSpeed).
  • d259616 refactor: fp32 master weights wrapper for DDP path — introduces MasterWeightOptimizer, mirroring DeepSpeed's BF16_Optimizer design (fp32 master copies + post-upcast clip) for DDP / FSDP / single-process backends.
  • 8be2cd1 fix(master_weights): subclass Optimizer + rebind LR scheduler — fixes two latent integration bugs (LR schedule never reaching the inner optimizer; MasterWeightOptimizer failing isinstance(Optimizer) so Accelerator.prepare skipped grad-accum gating).

Without these, any pi05_mem run under the now-default DDP backend would suffer the same precision degradation.

Performance & feature work that pi05_mem inherits

Notes for the reviewer

  • The merge base is 111d1dc (feat: support spec. start time of attached video (#160)); both branches have diverged since then, so expect a non-trivial merge.
  • 1f7a7f6 (fix: sync DeepSpeed gradient_accumulation_steps from TrainPipelineConfig (#175)) already exists on claude/wonderful-fermat-0831fd — no action needed; git will see the same change set.
  • After merge, verify that gradient_checkpointing from PI05Config is forwarded into the inner PaliGemmaWithExpertConfig from PI05MemConfig so pi05_mem can opt into the new flag.
  • Memory caveat from fix(pi05): DDP fp32 Adam state + SDPA / grad-ckpt opt-ins (#181, #177) #182: keeping fp32 master weights under DDP costs ~+6.8 GB/rank for params and roughly doubles Adam state (13.6 → 27.2 GB). Per-rank batch may need adjustment.

Test plan

  • Resolve any merge conflicts (likely in src/opentau/scripts/train.py and configs/).
  • pytest -m "not gpu" tests/optim/test_master_weights.py
  • pytest -m "not gpu" tests/policies/test_pi05_mem.py
  • GPU smoke: train pi05_mem under DDP+bf16 and confirm Adam state is fp32 (per MasterWeightOptimizer invariants).
  • Confirm pi05_mem end-to-end test (tests/policies/test_pi05_mem_gpu.py) still passes after the merge.

Refs: #181, #182, #176

https://claude.ai/code/session_01P7ou4rfy1D3KoTTjwAZC3o


Generated by Claude Code

bananaSnail and others added 19 commits April 22, 2026 09:28
Co-authored-by: nafeng <fengna@crestline.pro>
#163)

Co-authored-by: Shuheng Liu <wish1104@icloud.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Claude <noreply@anthropic.com>
Only cast the policy to bf16 under DeepSpeed; under DDP/FSDP/single-process
we now keep fp32 master parameters so ``torch.optim.AdamW`` allocates
``exp_avg`` and ``exp_avg_sq`` in fp32. This relies on accelerate's
``mixed_precision: bf16`` autocast for the bf16 forward/backward path.

Under DDP the previous unconditional ``policy.to(torch.bfloat16)`` made
``torch.zeros_like(p)`` allocate bf16 optimizer state, and Adam stepped
in bf16 with eps=1e-8. Over many steps this corrupts training: on pi05 /
libero, DDP 4-GPU hit 0% success while DeepSpeed ZeRO-2 4-GPU hit 94%
(see issue #181, regression introduced by PR #176's DDP default).

Also adds three diagnostic scripts and a CPU regression test:

- ``verify_adam_dtype_pure.py``: pure PyTorch dtype check.
- ``verify_adam_dtype_live.py``: accelerate-launched A/B for DDP vs
  DeepSpeed on a minimal ``nn.Linear``.
- ``verify_adam_divergence.py``: synthetic regression showing the bf16
  path plateaus far above the fp32+autocast path.
- ``tests/scripts/test_verify_adam_dtype_pure.py``: CPU unit test that
  pins the upstream PyTorch behaviour so a future regression is caught.

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
Introduce ``MasterWeightOptimizer`` (``src/opentau/optim/master_weights.py``),
a duck-typed proxy that mirrors DeepSpeed's ``BF16_Optimizer`` design:
fp32 master copies of bf16 model parameters, fp32 Adam state, and
post-upcast gradient clipping. The wrapper lives between
``torch.optim.AdamW`` (over the fp32 masters) and accelerate's
``AcceleratedOptimizer``, exposing ``param_groups`` / ``defaults`` /
``state`` / ``state_dict`` / ``load_state_dict`` / ``add_param_group``
/ ``step`` / ``zero_grad`` / ``clip_grad_norm_``.

In ``src/opentau/scripts/train.py``:

- Restore unconditional ``policy.to(torch.bfloat16)`` so forward and
  backward run in bf16 regardless of backend (activation memory + bf16
  compute).
- After ``make_optimizer_and_scheduler``, wrap the optimizer with
  ``MasterWeightOptimizer.from_existing(...)`` for every backend except
  DeepSpeed (which already provides equivalent semantics through
  ``BF16_Optimizer``). The wrapper rebuilds the inner optimizer over
  fp32 masters, preserving per-group hyperparameters and the original
  ``defaults`` dict.
- The grad-clip site now dispatches: when the inner optimizer is the
  master-weights wrapper, call its ``clip_grad_norm_`` (which performs
  the bf16 -> fp32 grad upcast and clips on fp32 master grads); under
  DeepSpeed, keep the existing ``accelerator.clip_grad_norm_`` path so
  ``BF16_Optimizer`` continues to clip internally.
- On resume, after ``accelerator.load_state(...)`` repopulates the live
  bf16 weights, ``rebuild_masters_from_live(policy.parameters())``
  re-syncs the fp32 masters so the inner optimizer's restored Adam
  state operates over consistent masters.

Under DDP this runs per-rank: gradients are already cross-rank synced
in bf16 by ``accelerator.backward``, so the per-rank fp32 upcast and
clip yield identical fp32 norms with no extra reduction.

Adds CPU-only unit tests in ``tests/optim/test_master_weights.py``
covering construction, convergence, bf16 live-param sync, fp32 Adam
state, ``clip_grad_norm_`` magnitude, ``state_dict`` round-trip,
``zero_grad`` semantics, ``from_existing`` hyperparameter preservation,
and ``rebuild_masters_from_live``. The existing diagnostic scripts and
``tests/scripts/test_verify_adam_dtype_pure.py`` are unchanged.

Refs: #181, #176

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
Two follow-ups to the live A/B script:

1. Cast model to bf16 unconditionally so the script faithfully reproduces
   the train.py prefix that triggers the bug. Previously cast_if_deepspeed
   only fired under DeepSpeed, so on DDP the model stayed fp32 and the
   script printed fp32 state — neither demonstrating the bug nor exercising
   the wrapper. Wrapper coverage lives in tests/optim/test_master_weights.py;
   this script's job is the minimal raw repro.

2. Set train_micro_batch_size_per_gpu=1 on the DeepSpeed plugin before
   prepare(). accelerate's DS path normally infers it from a passed
   dataloader; this script has none, so prepare() raised ValueError.

Also matches input batch dtype to the model's param dtype so the synthetic
forward works regardless of which backend cast the model.

Refs: #181

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
setdefault() was a no-op because the key is initialized to the string "auto"
by accelerate's DeepSpeedPlugin (utils/dataclasses.py:1268). accelerate's
is_auto() returns True for "auto", which then trips the dataloaderless guard
in _prepare_deepspeed (accelerator.py:2127) and raises ValueError.

Overwrite the value unconditionally so is_auto() returns False and the else
branch runs with batch_size_per_device = 1.

Refs: #181

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
…tate

Two real bugs in the live A/B script under DeepSpeed:

1. DeepSpeedOptimizerWrapper.step() is a no-op (accelerate/utils/deepspeed.py:313);
   the actual DS step runs inside accelerator.backward(loss) via
   DeepSpeedEngineWrapper.backward -> engine.step(). Replace
   loss.backward() + optimizer.step() with accelerator.backward(loss) so
   DS actually steps and populates state. (DDP path unaffected:
   accelerator.backward is loss.backward() under DDP.)

2. Under DS the populated state lives on fp32 flat-master partitions, not
   on the bf16 model parameters returned by named_parameters(). Walk the
   .optimizer chain (defence-in-depth) until .state is non-empty, then
   fall back to <flat-master-N> labels for keys that don't appear in
   model.named_parameters().

Refs: #181

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
The three verify_adam_*.py scripts were created to reproduce the bug in
issue #181 and demonstrate the fix's behaviour. Now that the bug is
documented and the MasterWeightOptimizer wrapper has its own production
unit tests (tests/optim/test_master_weights.py), the reproducers are no
longer useful for ongoing reuse — keeping them adds maintenance surface
for no recurring value.

Removes: verify_adam_dtype_pure.py, verify_adam_dtype_live.py,
verify_adam_divergence.py, and the test that depended on the pure script.

Refs: #181

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
Bug: ``MasterWeightOptimizer.from_existing`` is called BEFORE
``accelerator.prepare(...)`` in train.py, which is the moment the
policy moves from CPU to GPU. The wrapper consequently clones fp32
masters from the bf16 live params while live is still on CPU, and the
masters stay on CPU even after live is migrated. Adam then runs on
CPU master tensors, paying a CPU<->GPU memcpy on every grad upcast
and param downcast and never letting the master state appear in
nvidia-smi accounting. Functional but very slow under DDP.

Two fixes:

1. ``rebuild_masters_from_live``: replace the in-place
   ``master.data.copy_(...)`` (which preserves master's old device)
   with a re-assignment ``master.data = live.data.detach().to(fp32).clone()``
   that follows live's current device. Migrate any populated Adam
   state (``exp_avg`` / ``exp_avg_sq``) to the new device before the
   swap so the next ``step()`` does not crash on a device-mismatch
   addmul. Parameter object identity is preserved, so the inner
   optimizer's ``param_groups`` references remain valid.

2. ``train.py``: call ``inner_opt.rebuild_masters_from_live(policy.parameters())``
   once after ``accelerator.prepare`` so the migration fires on every
   fresh run, not just on resume. The resume path already calls it
   after ``load_state``; this just adds the same call on the
   non-resume path.

Refs: #181, #177
Three GPU-marked tests covering the bug fixed in dd1154e (where fp32
masters cloned at wrap-time on CPU were silently left behind when
accelerator.prepare migrated the live policy to GPU, causing Adam to
run on CPU master tensors with per-step GPU<->CPU memcpy):

1. fresh-run migration: masters constructed on CPU follow the live params
   to CUDA after rebuild_masters_from_live;
2. resume-path migration with populated state: previously-allocated
   exp_avg / exp_avg_sq move to the new device, and a follow-up step()
   succeeds (no device-mismatch addmul_);
3. parameter-object identity is preserved across migration so the inner
   optimizer's param_groups references stay valid.

Excluded from CPU CI by @pytest.mark.gpu; will run in GPU CI alongside
the policy GPU tests.

Refs: #181

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
Two integration bugs surfaced during a re-audit before flipping #182
out of draft. Both were latent — neither would have crashed; both
silently degrade training quality.

Bug 1 (CRITICAL): LR schedule never reached the inner optimizer.

  ``make_optimizer_and_scheduler`` builds the scheduler with
  ``cfg.scheduler.build(optimizer, ...)``, so ``lr_scheduler.optimizer``
  permanently points at the original ``torch.optim.AdamW``.
  ``MasterWeightOptimizer.from_existing`` discards that AdamW and builds
  a fresh inner with new ``param_groups`` dicts. ``lr_scheduler.step()``
  then mutates the orphaned optimizer's groups, never the wrapper's
  inner. Effective LR was frozen at the construction value forever.
  ``optimizer.param_groups[0]["lr"]`` (read for ``train_metrics.lr``)
  reads the inner's frozen value, so wandb showed a flat line that
  looked intentional.

  Fix: rebind ``lr_scheduler.optimizer = wrapped`` immediately after
  ``from_existing``, before ``accelerator.prepare``. The wrapper's
  ``param_groups`` property proxies to the inner, so scheduler mutations
  land on the right dicts. Verified by reading
  ``torch.optim.lr_scheduler.LRScheduler.step`` (lr_scheduler.py:296)
  and ``accelerate.Accelerator.prepare_scheduler`` (accelerator.py:2806).

Bug 2 (HIGH, latent): ``MasterWeightOptimizer`` was not recognised as a
``torch.optim.Optimizer``.

  ``Accelerator._prepare_one`` (accelerator.py:1403) is a strict
  ``isinstance`` check, so the wrapper was returned unchanged from
  ``Accelerator.prepare`` — never wrapped in ``AcceleratedOptimizer``.
  Consequences: gradient-accumulation gating in
  ``AcceleratedOptimizer.step`` / ``zero_grad`` was bypassed (running
  ``cfg.gradient_accumulation_steps > 1`` under DDP would silently
  collapse to accum=1); the wrapper was missing from
  ``Accelerator._optimizers``, so ``prepare_scheduler``'s identity
  match failed and ``AcceleratedScheduler`` got an empty optimizer
  list (less serious because Bug 1 already broke the schedule).

  Fix: subclass ``torch.optim.Optimizer`` without calling
  ``super().__init__()``. ``isinstance`` now succeeds; ``param_groups``,
  ``state``, ``defaults`` already proxy to ``self.inner`` via the
  existing ``@property`` accessors. ``AcceleratedOptimizer`` itself uses
  this same skip-super pattern, so the approach is proven safe.

Tests added (CPU, fast):

* ``test_wrapper_is_torch_optimizer_subclass``: pins Bug 2.
* ``test_lr_scheduler_reaches_inner_after_from_existing_with_rebind``:
  positive case for Bug 1 — runs LinearLR through the prefix that
  train.py executes and asserts the inner's lr decays.
* ``test_lr_scheduler_does_not_reach_inner_without_rebind``: negative
  control documenting the failure mode if a future change drops the
  rebind line.

Refs: #181

https://claude.ai/code/session_0128Zg2WaKkyhAf5JyLTJ544
…n flag

Phase 3.3: wire the already-declared attention_implementation flag to
actually dispatch. Adds an sdpa_attention_forward sibling to the existing
eager kernel and makes get_attention_interface() honor
self.config.attention_implementation (which was previously ignored).

Default stays "eager" in PI05Config; benchmark first via ATTENTION_IMPL
env var in profile_step, then flip the default once the A/B is clean.

Why this is scoped to one kernel swap and doesn't touch anything else:

  * attention_implementation was a dead flag. get_attention_interface()
    unconditionally returned eager regardless of config value. The
    validator accepted "eager" or "fa2" but "fa2" was never implemented.

  * pi05 bypasses HuggingFace's attention dispatch entirely. The custom
    PaliGemmaWithExpertModel.forward() at L353-503 hand-rolls the layer
    loop (Q/K/V projection, RoPE, attention, output projection, gated
    residual) and calls get_attention_interface() as the only attention
    chokepoint. HF's GemmaDecoderLayer.forward never runs, so HF's
    _attn_implementation="sdpa" would have no effect either.

  * src/opentau/utils/transformers_patch.py patches (notably
    PatchedGemmaRMSNorm which returns a (tensor, gate) tuple) are
    orthogonal to attention — they feed into the custom layer loop
    which already unpacks the tuple at L412, L478, L498. The attention
    kernel only sees (Q, K, V, attention_mask).

Numerical note: sdpa_attention_forward does NOT upcast Q/K to float32
before the matmul (the eager kernel does at L575-576). SDPA/Flash does
the matmul in bf16 with fp32 accumulation inside softmax — cleaner and
faster, and matches modern attention conventions. Loss-equivalence
sanity check should be part of the validation run before flipping the
default.

Backward compat: "fa2" in the validator's allowed set is preserved so
existing configs don't throw; it now emits a one-time warning at config
construction and falls back to eager (which is what was effectively
happening anyway since "fa2" was never implemented).

Benchmark via profile_step.py:
  ATTENTION_IMPL=eager accelerate launch ... profile_step.py ...
  ATTENTION_IMPL=sdpa  accelerate launch ... profile_step.py ...

Expected on 8xA100 DDP bs=12: forward ~320ms (from 379ms), bwd ~600ms
(from 705ms), total ~990ms, samples/s ~97 (+17% vs 83.1). If we only
get half of that it means SDPA dispatched to the mem-efficient backend
instead of FlashAttention-2; still a win.

Files changed:
  src/opentau/policies/pi05/paligemma_with_expert.py:
    + import logging
    + extend validator to accept "sdpa", warn on "fa2"
    + make get_attention_interface dispatch on config flag
    + add sdpa_attention_forward method (~65 lines)
  src/opentau/scripts/profile_step.py:
    + ATTENTION_IMPL env var override before make_policy

Files explicitly NOT changed in this commit:
  src/opentau/utils/transformers_patch.py  (orthogonal to attention)
  src/opentau/policies/pi05/configuration_pi05.py  (default flip later)
  src/opentau/policies/pi0/paligemma_with_expert.py  (mirror in follow-up
    after pi05 benchmark validates the change)

https://claude.ai/code/session_0129LtYYua8s3sJ4gfcjn1VB
Phase 3.4: add opt-in per-layer activation checkpointing. Trades ~25-33%
same-batch compute for ~30-40 GB of activation memory per rank, enabling
a larger per-rank batch that amortizes fixed per-step cost (~70ms of
optim_step + clip + sync_gather that doesn't scale with bs).

Default stays False in PI05Config; set True in the training config JSON
or via GRAD_CHECKPOINT env var on profile_step to benchmark.

Design notes:

1. Custom forward loop. PaliGemmaWithExpertModel.forward hand-rolls a
   dual-tower per-layer loop that bypasses HuggingFace's GemmaDecoderLayer
   entirely, so HF's model.gradient_checkpointing_enable() would be a
   no-op. We instead extract the per-layer body into a new _run_layer
   method and wrap the call in torch.utils.checkpoint.checkpoint with
   use_reentrant=False (modern, DDP-safe, RNG-preserving) when the flag
   is active AND model is in training mode.

2. Backend safety. torch.utils.checkpoint is only semantically correct
   under backends that replicate full params during forward. DDP
   (MULTI_GPU), NO, and DeepSpeed ZeRO-1/2 all qualify; ZeRO-3 and FSDP
   re-shard parameters and rely on forward-time all-gather hooks that
   plain checkpoint does not re-trigger during backward recompute,
   producing silent grad corruption or NCCL hangs.

   train.py adds a strict startup guard: if gradient_checkpointing=True
   and distributed_type is not in the allowlist (or is DEEPSPEED with
   zero_stage >= 3), raise ValueError with a pointer to the safe
   backends. Silent incorrectness is worse than a failed job start.

3. Past KV cache mutation. The layer body writes to past_key_values[layer_idx]
   when fill_kv_cache=True. Each layer writes a unique key so checkpoint
   recompute is idempotent. Hoisted the `if past_key_values is None: {}`
   init out of the loop so _run_layer always receives a non-None dict
   under fill_kv_cache=True — makes the mutation pattern explicit.

4. RMSNorm tuple return, gated residual, AdaRMS conditioning, custom
   RoPE, attention dispatch, dropout — all preserved inside _run_layer
   with byte-identical behavior vs the inlined original.

Files changed:

  src/opentau/policies/pi05/paligemma_with_expert.py:
    + gradient_checkpointing field on PaliGemmaWithExpertConfig
    + _run_layer method (extracted per-layer body)
    + forward calls _run_layer, optionally wrapped in checkpoint
    + past_key_values={} hoist before the layer loop

  src/opentau/policies/pi05/configuration_pi05.py:
    + gradient_checkpointing field on PI05Config, default False

  src/opentau/policies/pi05/modeling_pi05.py:
    + thread gradient_checkpointing into PaliGemmaWithExpertConfig

  src/opentau/scripts/train.py:
    + strict ValueError guard for unsupported distributed backends,
      covering both top-level distributed_type (MULTI_GPU/NO/DEEPSPEED)
      and DeepSpeed zero_stage >= 3 subcase.

  src/opentau/scripts/profile_step.py:
    + GRAD_CHECKPOINT env var override (benchmark-only).

TODO (separate PR): mirror to sibling policies that share the same
custom layer loop — src/opentau/policies/pi0/paligemma_with_expert.py
has the identical structure and would benefit from the same flag.
Deferred here to keep this PR focused on pi05.

Benchmark via:
  GRAD_CHECKPOINT=true FIND_UNUSED_PARAMS=false accelerate launch \
      --config_file configs/libero/reproduce_pi05_libero_accelerate_config_ddp_8gpu.yaml \
      src/opentau/scripts/profile_step.py \
      --config_path=configs/libero/reproduce_pi05_libero.json \
      --batch_size=20

Expected: bs=12 with ckpt runs ~25-33% slower vs bs=12 eager (1155 ms -> ~1450 ms);
bs=20 with ckpt should fit in memory and recover samples/s to ~95-105 (+15-25%
over the 83.1 baseline). If bs=20 OOMs we scale down to bs=18/16; if bs=20 fits
but throughput doesn't improve, the fixed-cost portion is smaller than expected
and we revert.

https://claude.ai/code/session_0129LtYYua8s3sJ4gfcjn1VB
… benchmarks

Before this fix, profile_step.py reproduced the issue #181 bug: it cast
the policy to bf16, then built torch.optim.AdamW directly over the bf16
params. Adam state (exp_avg, exp_avg_sq) was therefore allocated in bf16
under any non-DeepSpeed backend, taking ~half the memory of the fp32
master state PR #182 enforces in train.py.

The benchmark consequently:
- under-reported per-rank memory by ~+20 GB
- over-reported the largest batch that fits

(In practice: the post-#182 DDP probes here landed on the same bs=11-12
ceiling that PR #176 reported for the buggy bf16-state DDP, exactly
because both runs were measuring the same buggy config.)

Mirror train.py's wrap (src/opentau/scripts/train.py:272-279):

  - After make_optimizer_and_scheduler (and after any FUSED_ADAMW
    rebuild), if accelerator.distributed_type is not DEEPSPEED, wrap
    the optimizer via MasterWeightOptimizer.from_existing(...).
  - Dispatch the clip-grad-norm site between the wrapper's
    clip_grad_norm_ and accelerator.clip_grad_norm_ so the bf16->fp32
    grad upcast is amortized into the clip phase, matching train.py.

DeepSpeed already provides equivalent semantics via BF16_Optimizer, so
the wrap is skipped on that backend (matching train.py).

Refs: #177, #181, #182
Bug: ``MasterWeightOptimizer.from_existing`` clones fp32 masters from
the bf16 live params at wrap-time. In the train.py / profile_step.py
flow, the wrap fires AFTER ``policy.to(torch.bfloat16)`` but BEFORE
``accelerator.prepare(...)`` -- which is when the policy gets moved
from CPU to GPU. The masters were therefore allocated on CPU and stayed
there: Adam ran on CPU master tensors, every grad upcast and param
downcast paid a CPU<->GPU memcpy, and the fp32 master state never
showed up in nvidia-smi accounting (so memory benchmarks under-reported
peak by ~+20 GB / rank, and per-step latency was artificially high
from the CPU<->GPU traffic).

Two fixes:

1. ``rebuild_masters_from_live``: replace the in-place
   ``master.data.copy_(...)`` (which preserves master's old device) with
   a re-assignment ``master.data = live.data.detach().to(fp32).clone()``
   that moves master to live's current device. Migrate any existing
   Adam state (``exp_avg`` / ``exp_avg_sq``) to the new device before
   the swap so subsequent ``step()`` calls do not crash on a device-
   mismatch addmul. Parameter object identity is preserved, so the
   inner optimizer's ``param_groups`` references stay valid.

2. ``train.py`` and ``profile_step.py``: call
   ``inner_opt.rebuild_masters_from_live(policy.parameters())`` once
   after ``accelerator.prepare`` so the migration fires on every fresh
   run, not just on resume. The resume path in train.py already calls
   it after ``load_state``; this just adds the same call to the
   non-resume path.

Refs: #181, #182, #177
profile_step.py mirrors train.py's wrapping prefix and was missing the
LR scheduler rebind from 8be2cd1. Without it, profile_step's loss/lr
A/B benchmarks would silently freeze the inner LR (Bug 1 from PR #182).

Identical fix and rationale as train.py 18cf70f -> b2a8bfa.
@shuheng-liu shuheng-liu self-assigned this Apr 25, 2026
shuheng-liu added a commit that referenced this pull request Apr 25, 2026
shuheng-liu added a commit that referenced this pull request Apr 27, 2026
…ful encoder

The ``temporal_gate`` scalar added in ebf576b was a debugging aid to isolate
the DDP precision issue (#181) from any contribution of the temporal sublayer
itself. Now that #186 has landed (``MasterWeightOptimizer`` keeps fp32 master
copies for DDP, mirroring DeepSpeed ZeRO), DDP runs converge cleanly without
any gating: an end-to-end LIBERO run from ``william-yue/pi05_base`` reaches
~85% success on libero_10 by step 23k under DDP + master_weights, matching
the trajectory of the equivalent DeepSpeed ZeRO run to within sampling
variance, with α-gate values that converged to small magnitudes (~0.03–0.05
absolute) — i.e. the temporal pathway *does* contribute, but the gate
isn't doing meaningful interpolation work and the architecture works as
specified in the MEM paper.

The MEM paper (Torne, Pertsch, Walke et al., Section III-C + Appendix C)
specifies the encoder with no per-wrapper scalar gate. Keeping ``temporal_gate``
in the codebase introduces 6 extra learnable scalars and a state_dict-key
divergence from vanilla SigLIP for no demonstrated benefit on this workload.
This commit reverts the gate and the docstrings/tests that referenced it.

**Changes**

- ``SpaceTimeEncoderLayerWrapper.__init__``: drop
  ``self.temporal_gate = nn.Parameter(torch.zeros(()))``.
- ``SpaceTimeEncoderLayerWrapper.forward``: revert
  ``t_res = ... + self.temporal_gate.to(t_out.dtype) * t_out`` →
  ``t_res = ... + t_out``.
- Class + module docstrings: restore the original "no new learnable
  parameters" wording and the T=1-only single-frame-invariance claim.

**Tests**

- Drop ``test_alpha_zero_is_vanilla_siglip_at_t8``,
  ``test_alpha_gradient_flows``, ``test_state_dict_loads_without_alpha``.
- Restore ``test_state_dict_keys_match_vanilla_siglip`` (rename from
  ``test_state_dict_keys_are_vanilla_siglip_plus_alpha``) with the
  exact-equality invariant.
- Restore the strict ``set(vt_no_st.state_dict().keys()) == set(vt_st.
  state_dict().keys())`` check inside
  ``test_single_frame_invariance_structural``.

All 46 CPU tests pass (``pytest tests/policies/test_pi05_mem.py -m 'not gpu'``).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants