Fixing reward normalizer #3
Merged
WilliamYue37 requested changes on Dec 29, 2025
shuheng-liu requested changes on Dec 29, 2025
WilliamYue37 approved these changes on Dec 29, 2025
shuheng-liu pushed a commit that referenced this pull request on Apr 23, 2026
Fixes three distinct correctness issues in the Gemma 3 4B backbone +
Gemma-v1 action expert wiring that a careful second pass uncovered:
1. Vision projector crashes at 448×448. `Gemma3MultiModalProjector`
hard-codes `patches_per_image = image_size // patch_size`, so the
default config's `vision_config.image_size = 896` made the
projector attempt a reshape to `(B, 1024, 64, 64)`, i.e. a 64×64
patch grid, when SigLIP actually emits only `B × 1024` patches
(a 32×32 grid) at 448 input. The result was a runtime crash on the
first forward pass. Fix: set `image_size = 448` to match π0.6's
stated resolution (→ 32 patches/side, 256 mm tokens/image, which
also matches the model card).
2. RoPE θ asymmetry between backbone and expert for global layers.
Gemma 3 interleaves local layers (θ = 10 000) with global layers
(θ = 1 000 000). The expert's own `rope_theta = 10 000` was being
applied to its Q/K even on global layers, so the shared
cross-attention ran backbone-Q (rotated at 1M) against expert-K
(rotated at 10k) — the `q·R(Δpos)·k` invariant breaks when the two
rotations live in different RoPE bases. Fix: both streams now use
the backbone's per-layer θ; the expert's fallback θ is ignored at
runtime and documented as such.
3. Sliding-window mask used dense indices instead of absolute
positions. During expert cross-attention the Q tokens sit at
`prefix_offsets + chunk_idx`, but `_build_sliding_window_mask`
built its mask from `torch.arange(seq_len)` and the call site
sliced `[:, :T_suffix, :]`. Result: every prefix key farther than
`window` from the dense suffix row index was dropped — i.e. the
sliding layers saw essentially no prefix during cross-attention.
Fix: take `(query_positions, key_positions)` and compute
`|pos_q - pos_k| < window` over absolute positions; stash the
prefix positions in the KV cache so the expert can reconstruct
them on each denoising step.
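The difference between dense indices and absolute positions in item 3 can be sketched in plain Python. This is a minimal illustration, not the repository's `_build_sliding_window_mask`: the function name `sliding_window_mask`, the list-based mask, and the sizes (8 prefix tokens, a 3-token chunk, window of 4) are assumptions chosen for readability.

```python
def sliding_window_mask(query_positions, key_positions, window):
    """Return mask[i][j] = True where query i may attend to key j.

    Applies the rule quoted above, |pos_q - pos_k| < window, to absolute
    positions rather than to row/column indices of the mask.
    """
    return [
        [abs(q - k) < window for k in key_positions]
        for q in query_positions
    ]


# Hypothetical sizes, for illustration only.
prefix_len, chunk_len, window = 8, 3, 4
key_positions = list(range(prefix_len + chunk_len))            # prefix keys, then suffix keys
query_positions = [prefix_len + i for i in range(chunk_len)]   # prefix_offsets + chunk_idx

# Buggy variant: build a square mask from dense indices and slice the first
# chunk_len rows, so suffix queries are treated as if they sat at 0, 1, 2.
dense = sliding_window_mask(range(prefix_len + chunk_len), key_positions, window)
buggy = dense[:chunk_len]

# Fixed variant: pass the suffix queries' absolute positions explicitly.
fixed = sliding_window_mask(query_positions, key_positions, window)

visible_buggy = [j for j, ok in enumerate(buggy[0]) if ok]
visible_fixed = [j for j, ok in enumerate(fixed[0]) if ok]
print(visible_buggy)  # keys the first suffix query sees with dense indices
print(visible_fixed)  # keys it sees with absolute positions
```

In this toy setup the dense-index mask anchors the first suffix query at position 0, so it sees only keys near the start of the prefix. With absolute positions the same query sees the keys adjacent to its true position, including the tail of the prefix.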
Tests:
* `TestSlidingWindowMask.test_cross_attention_uses_absolute_positions`:
regression guard for item 3 above.
* `TestGemma3WithExpertConfig.test_vision_image_size_matches_input_resolution`
and `.test_projector_accepts_448_inputs`: regression guards
for item 1 above (config invariant + end-to-end projector forward).
* `TestRopeThetaSymmetryDuringForward.test_expert_uses_backbone_per_layer_theta`:
regression guard for item 2 above; uses monkeypatched `apply_rope` to
observe exactly which θ each layer asks for.
Local pytest: 29 passed / 1 deselected (up from 25). Full policies/
CPU suite: 75 passed / 2 skipped / 6 deselected. Pre-commit: all
hooks green (ruff, ruff-format, pyupgrade, typos, bandit, gitleaks).
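The RoPE invariant from item 2 can be checked numerically. The sketch below is a toy model, not the project's `apply_rope`: it stores each rotated pair as one complex number, and the helper names `rope` and `score` and the example vectors are assumptions. It shows that the attention score depends only on the position offset when Q and K share a θ, and stops doing so when they use different bases.

```python
import cmath


def rope(vec, pos, theta):
    """Rotate complex pairs; pair i turns by pos * theta ** (-i / len(vec))."""
    n = len(vec)
    return [v * cmath.exp(1j * pos * theta ** (-i / n)) for i, v in enumerate(vec)]


def score(q, k):
    """Real dot product of two vectors stored as complex pairs."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))


# Arbitrary example vectors (two rotary pairs each).
q = [1 + 2j, 0.5 - 1j]
k = [0.3 + 0.7j, -1.2 + 0.4j]
offset = 5


def score_at(shift, theta_q, theta_k):
    # Query sits `offset` positions after the key; `shift` moves both together.
    return score(rope(q, shift + offset, theta_q), rope(k, shift, theta_k))


# Same base on both sides: shifting both positions leaves the score unchanged.
same_a = score_at(0, 1_000_000, 1_000_000)
same_b = score_at(100, 1_000_000, 1_000_000)

# Mixed bases (query at 1M, key at 10k): the score now drifts with absolute position.
mixed_a = score_at(0, 1_000_000, 10_000)
mixed_b = score_at(100, 1_000_000, 10_000)

print(abs(same_a - same_b), abs(mixed_a - mixed_b))
```

Under these assumptions the first difference is zero up to floating-point error and the second is not, which is the symptom the fix removes by using the backbone's per-layer θ for both streams.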
https://claude.ai/code/session_01MibvjbcZo38nxrx6n9giLi
This was referenced Apr 29, 2026
This was referenced May 4, 2026
This was referenced May 4, 2026
This was referenced May 12, 2026
This was referenced May 20, 2026
This was referenced May 23, 2026
shuheng-liu added a commit that referenced this pull request on May 27, 2026
…ring
Addresses inline review on PR #337:
- #1 (test tautology): extract the strip predicate to a module-level
  `_is_normalize_buffer_key` helper in `modeling_pi07_low_level.py`.
  The production filter now calls it; the two new tests
  (`test_predicate_matches_all_eight_saved_buffer_keys`,
  `test_predicate_anchored_at_key_start`) call the same helper
  directly, so any future edit to the predicate trips them instead of
  passing vacuously.
- #4 (config persistence): after the strip fires, reset
  `model.config.skip_normalization_weights = False` so that subsequent
  `save_pretrained` / resume / inference loads do not re-strip the
  now-correct finetuned buffers.
- #2 (inf-buffer trap) + #3 (other policies share the trap): expanded
  the field docstring to call out the `dataset_stats` precondition,
  the one-shot semantic, and the current pi07_paligemma_low_level
  scope.
This was referenced May 27, 2026
shuheng-liu added a commit that referenced this pull request on May 27, 2026
Addresses the new findings on PR #340:
- (#1, high priority) strict=True + skip_normalization_weights=True
  used to raise RuntimeError("Missing key(s) ...") in pi0/value
  because the strip removed Normalize buffer keys before
  load_state_dict, but the buffers were still registered on the model
  from __init__. The other eight policies hardcode strict=False on the
  load_state_dict call so they didn't hit this. Fix: pass
  effective_strict = strict and not stripped_keys so the strict
  semantics apply unchanged on the default-load path and only relax
  when the strip actually fired.
- (#2) pi0/value's _load_as_safetensor now resolves is_main_process
  via get_proc_accelerator() and threads it to the strip helper, so
  the INFO/WARNING fires once per load (not once per rank under
  DDP/FSDP).
- (#3, optional) Added a note in _assert_normalize_buffers_initialized
  explaining why "any normalize buffer is inf" is a safe proxy for
  "the stripped buffer is still inf" (create_stats_buffers rejects
  partial stats, so buffers are all-finite or all-inf, with no
  reachable mixed state).
Pre-commit clean on --all-files (CI parity); CPU suite 460 passed.
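The `effective_strict` rule from finding #1 can be sketched without PyTorch. Everything here is a stand-in under stated assumptions: `fake_load_state_dict` only mimics the strict-mode "missing keys" check, `is_normalize_key` is a hypothetical predicate (the real helper's matching rule is not shown in this thread), and the key names are invented for the example.

```python
def is_normalize_key(key):
    # Hypothetical predicate, for illustration only.
    return key.startswith("normalize_")


def fake_load_state_dict(model_keys, state_dict, strict):
    """Stand-in for a strict/non-strict load: strict mode rejects missing keys."""
    missing = sorted(set(model_keys) - set(state_dict))
    if strict and missing:
        raise RuntimeError(f"Missing key(s) in state_dict: {missing}")
    return missing


def load(model_keys, state_dict, strict=True, skip_normalization_weights=False):
    stripped_keys = []
    if skip_normalization_weights:
        stripped_keys = [k for k in state_dict if is_normalize_key(k)]
        state_dict = {k: v for k, v in state_dict.items() if not is_normalize_key(k)}
    # Relax strictness only when the strip actually removed something.
    effective_strict = strict and not stripped_keys
    return fake_load_state_dict(model_keys, state_dict, effective_strict)


# Invented key names; the model still registers its normalize buffers.
model_keys = ["normalize_inputs.mean", "normalize_inputs.std", "head.weight"]
checkpoint = {k: 0.0 for k in model_keys}

# Strip fired: the load tolerates the now-missing normalize keys.
missing_after_strip = load(
    model_keys, checkpoint, strict=True, skip_normalization_weights=True
)

# Default path: strict semantics are unchanged, so a genuinely missing key raises.
try:
    load(model_keys, {"head.weight": 0.0}, strict=True)
    default_path_raised = False
except RuntimeError:
    default_path_raised = True

print(missing_after_strip, default_path_raised)
```

The design choice mirrored here is that strictness is relaxed per load, based on whether any key was stripped, rather than being switched off globally.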
shuheng-liu added a commit that referenced this pull request on May 27, 2026
Addresses the third-pass review on PR #340:
- (#1) Defensive _materialize() before torch.isinf in the inline check
  inside _assert_normalize_buffers_initialized. Today from_pretrained
  runs before accelerator.prepare so params are plain Tensors, but
  exposing the helper without _materialize would break on a future
  caller invoking it on a fully_shard'd model. _materialize is a no-op
  on plain Tensors. Added a comment noting both the FSDP2 motivation
  and the named_parameters-vs-named_buffers convention (Normalize uses
  nn.Parameter(requires_grad=False), not register_buffer).
- (#2) Clarified the one-shot reset docstring on both
  PreTrainedConfig.skip_normalization_weights and the helper itself:
  the reset is in-memory only. Persistence requires save_pretrained;
  re-running from_pretrained on the original source path reloads True
  and re-strips. The common in-process resume path (train.py →
  save_checkpoint → cfg.save_pretrained) is unaffected.
- (#3) The named_parameters clarification is now part of the inline
  comment block in _assert_normalize_buffers_initialized.
- (#4 logging vs print inconsistency) Pre-existing in pi05; not in
  scope for this PR.
Pre-commit clean on --all-files; 30 targeted tests pass.
shuheng-liu added a commit that referenced this pull request on May 28, 2026
…ck warn)
Two low-priority items from the re-review of 85c535b:
- Replace `[{}] * n` with `[{} for _ in range(n)]` for `per_ds_info`.
  The multiplied-literal form creates N references to the same dict,
  so a future contributor doing `per_ds_info[idx]["override_X"] = ...`
  (a legitimate-looking accumulator pattern) would silently fan out
  across every slot. The `[None] * n` and `["<no-repo-id>"] * n`
  callsites stay as-is because their element types are immutable;
  comment notes the distinction so the pattern doesn't regress.
- Add a tighter fit-time warning for the specific divergence the
  deferred fallback-name finding (#3 in the original review) flagged.
  When a fallback-keyed `repo_id` appears more than once in the
  mixture, fit-time POOLS those rows under one shared key while
  training keeps them as separate singleton heads (via
  `_make_dataset_names`'s `#N` dedup). The new Counter-based warning
  surfaces the exact duplicate `repo_id -> count` map so an operator
  can fix it before the misalignment ships.
Mechanical; no new infra. Closes the divergence-detection gap while
the proper dedup follow-up (tracked separately) is pending.
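Both items in this commit message can be reproduced in a few lines of standard-library Python. The sketch below is illustrative only: the variable names other than `per_ds_info`-style usage, the example `repo_ids` list, and the way duplicates are collected are assumptions, not the project's code.

```python
from collections import Counter

n = 3

# Multiplied literal: all n slots reference the same dict object.
aliased = [{}] * n
aliased[0]["override_X"] = 1

# Comprehension: each slot gets its own dict.
independent = [{} for _ in range(n)]
independent[0]["override_X"] = 1

print(aliased)       # the write fans out to every slot
print(independent)   # the write stays in slot 0

# Immutable elements are safe to multiply: assignment rebinds one slot only.
names = ["<no-repo-id>"] * n
names[0] = "org/dataset_a"

# Counter-based duplicate map, in the spirit of the fit-time warning:
# report every repo_id that appears more than once in the mixture.
repo_ids = ["<no-repo-id>", "org/dataset_a", "<no-repo-id>"]
duplicates = {rid: count for rid, count in Counter(repo_ids).items() if count > 1}
print(duplicates)
```

A warning built on such a map can name the exact offending `repo_id` values and their counts, rather than only signalling that some duplicate exists.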
This was referenced May 30, 2026