diff --git a/.github/workflows/claude-implement-fixes.yml b/.github/workflows/claude-implement-fixes.yml index f88d9231..e3e97bc9 100644 --- a/.github/workflows/claude-implement-fixes.yml +++ b/.github/workflows/claude-implement-fixes.yml @@ -100,6 +100,7 @@ jobs: with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} claude_args: | + --model ${{ vars.CLAUDE_MODEL || 'claude-opus-4-7[1m]' }} --permission-mode bypassPermissions prompt: | A reviewer asked you to address review feedback on this PR. @@ -111,16 +112,33 @@ jobs: 2. For each actionable comment containing `@claude fix`, implement the fix on the PR's branch. 3. Skip comments that are questions, taste preferences, or already addressed. - 4. Run the test command from CLAUDE.md before pushing. + 4. Decide whether to run tests: + - If the diff is purely documentation, comments, formatting, + string-literal text, or otherwise CANNOT change runtime behavior, + you MAY skip tests. Be honest about confidence — when in doubt, run. + If you skip, the commit body MUST contain a line of the form: + tests: skipped — + - Otherwise (changes touching imports, function bodies, control + flow, types, configs read at runtime, dependencies, or build + manifests): run `pytest -m "not gpu" -n auto`. Scope to the + changed subtree where possible (e.g. `pytest tests/policies/test_pi05.py` + for a pi05 change) to keep the run fast. The commit body MUST contain: + tests: passed — + If tests fail: do NOT push; reply on the relevant PR comment + explaining the failure and stop. 5. Make ONE commit at the end of the session that addresses every comment you decided to act on — do NOT push one commit per comment. Subject line (must be < 80 chars per CLAUDE.md): [claude-fix] address review feedback on #${{ github.event.issue.number || github.event.pull_request.number }} - Commit body: a bulleted list, one bullet per addressed comment: + Commit body: a bulleted list, one bullet per addressed comment, + followed by the `tests:` line from step 4: - addresses @ (): + ... + tests: passed — pytest -m "not gpu" tests/policies/test_pi05.py + (or: tests: skipped — comment-only change, no runtime impact) Then push the single commit to the PR branch. 6. Reply individually to each addressed comment on the PR with diff --git a/.github/workflows/claude-pr-review.yml b/.github/workflows/claude-pr-review.yml index def19ebe..68ef60c8 100644 --- a/.github/workflows/claude-pr-review.yml +++ b/.github/workflows/claude-pr-review.yml @@ -50,7 +50,7 @@ jobs: with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} claude_args: | - --model claude-opus-4-7 + --model ${{ vars.CLAUDE_MODEL || 'claude-opus-4-7[1m]' }} --permission-mode bypassPermissions prompt: | You are reviewing PR #${{ github.event.pull_request.number }} in diff --git a/.github/workflows/cpu_test.yml b/.github/workflows/cpu_test.yml index 2ed9b5f5..bc6f8be9 100644 --- a/.github/workflows/cpu_test.yml +++ b/.github/workflows/cpu_test.yml @@ -102,7 +102,8 @@ jobs: python3 -c "import sys; print(sys.path)" python3 -c "import libero.libero" && echo "LIBERO config set successfully." echo "Running cpu based pytest and generating coverage report..." - pytest -m "not gpu" -n auto -v --cov=lerobot/ --cov-report=xml:cpu_test/cpu_test.xml --ignore=tests/planner/test_planner.py --ignore tests/utils/test_libero_utils.py --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_async_vector_env --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_sync_vector_env tests/ + # TODO(#210): drop --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py once pi07 migrates to SpaceTimeSiglipVideoEncoder (#192). + pytest -m "not gpu" -n auto -v --cov=lerobot/ --cov-report=xml:cpu_test/cpu_test.xml --ignore=tests/planner/test_planner.py --ignore tests/utils/test_libero_utils.py --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_async_vector_env --deselect=tests/envs/test_factory.py::TestMakeEnv::test_make_env_sync_vector_env tests/ echo "Pytest execution and coverage report generation completed." - name: Upload coverage reports diff --git a/.github/workflows/extract-claude-lessons.yml b/.github/workflows/extract-claude-lessons.yml index 5826321e..e0a7cbf0 100644 --- a/.github/workflows/extract-claude-lessons.yml +++ b/.github/workflows/extract-claude-lessons.yml @@ -25,9 +25,13 @@ permissions: jobs: extract-lessons: + # Gate on the head branch name, not user.login: in this repo Claude Code + # pushes to a `claude/*` branch and a human opens the PR, so the PR's + # `user.login` never contains 'claude'. The branch prefix is the reliable + # signal that Claude touched the PR. if: >- github.event.pull_request.merged == true - && contains(github.event.pull_request.user.login, 'claude') + && startsWith(github.event.pull_request.head.ref, 'claude/') && !startsWith(github.event.pull_request.title, 'chore(claude): learn from') runs-on: ubuntu-latest timeout-minutes: 20 @@ -39,6 +43,7 @@ jobs: with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} claude_args: | + --model ${{ vars.CLAUDE_MODEL || 'claude-opus-4-7[1m]' }} --permission-mode bypassPermissions prompt: | Review the comments on this merged PR. If any reviewer feedback diff --git a/.github/workflows/gpu_test.yml b/.github/workflows/gpu_test.yml index 453db4b5..9c52a9f5 100644 --- a/.github/workflows/gpu_test.yml +++ b/.github/workflows/gpu_test.yml @@ -91,7 +91,8 @@ jobs: source .venv/bin/activate mkdir -p /tmp/libero-assets/libero/libero export LIBERO_CONFIG_PATH="$(pwd)/.github/assets/libero" - pytest -m "gpu" -n 0 -v tests/ + # TODO(#210): drop --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py once pi07 migrates to SpaceTimeSiglipVideoEncoder (#192). + pytest -m "gpu" -n 0 -v --ignore=tests/policies/test_pi07_paligemma_low_level_planner.py tests/ stop-runner: name: Stop GPU Runner diff --git a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py index 2b8f372f..2af13448 100644 --- a/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py +++ b/src/opentau/policies/pi07/high_level_planner/modeling_pi07_high_level.py @@ -888,9 +888,16 @@ def embed_prefix( Gemma 3 embedding layer. **Concatenation order** (training when memory and response are provided): - ``[images | language | metadata | ";\\n " | "Updated Memory: " | memory_tokens | + ``[images | language | metadata? | ";\\n "? | "Updated Memory: " | memory_tokens | "Subtask: " | response_tokens]`` + ``";\\n "`` is gated on ``metadata_tokens`` — it serves as the metadata → + ``"Updated Memory:"`` separator, so when no metadata is provided there is + nothing to terminate and emitting it would dangle spurious tokens. The + ``"Updated Memory: "`` anchor itself is unconditional because inference + relies on it as the autoregressive starting point for memory decoding + (memory_tokens is None at inference by design). + When ``memory_tokens`` / ``response_tokens`` are omitted (inference), only the fixed spans before those segments are present; memory and subtask text are filled in via KV-cache decoding plus an explicit ``"Subtask: "`` injection before response AR. @@ -898,7 +905,7 @@ def embed_prefix( Attention pattern (via ``att_masks`` cumsums): - Image + language tokens: bidirectional (``0``). - Metadata (if present): new bidirectional block (``[1, 0, …, 0]``). - - ``";\\n "`` (same string as ``encode(";\n ", add_special_tokens=False)``): continues previous block (``0``). + - ``";\\n "`` (only when metadata present): continues previous block (``0``). - ``"Updated Memory: "``: new bidirectional block (``[1, 0, …, 0]``). - Memory token slots: causal segment (``1`` per slot). - ``"Subtask: "`` (training): new block then causal continuation within span. @@ -973,20 +980,25 @@ def embed_prefix( pad_masks.append(metadata_masks) att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) - prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) - prefix_end_tokens = torch.tensor( - [prefix_end_indicator_ids] * bsize, - device=lang_tokens.device, - dtype=torch.long, - ) - prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) + # ";\n " is the metadata -> "Updated Memory:" separator. With no metadata, + # there is nothing to terminate, so omit it; "Updated Memory:" still anchors + # AR memory decoding either way. + prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) + prefix_end_tokens = torch.tensor( + [prefix_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) - num_prefix_end_embs = prefix_end_emb.shape[1] - prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) + num_prefix_end_embs = prefix_end_emb.shape[1] + prefix_end_mask = torch.ones( + bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device + ) - embs.append(prefix_end_emb) - pad_masks.append(prefix_end_mask) - att_masks += [0] * num_prefix_end_embs + embs.append(prefix_end_emb) + pad_masks.append(prefix_end_mask) + att_masks += [0] * num_prefix_end_embs memory_start_indicator_ids = self.language_tokenizer.encode( "Updated Memory: ", add_special_tokens=False diff --git a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py index 01dfeb41..ae18a54e 100644 --- a/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level_planner/modeling_pi07_low_level.py @@ -1152,17 +1152,24 @@ def embed_prefix( Concatenation order: - ``[videos | language | State: | state(T) | ", " | response | - Subgoal: | subgoal_images… | ", " | metadata | ";\\n " | + ``[videos | language | State: | state(T) | | response? | + Subgoal: | subgoal_images… | ", " | metadata? | ";\\n "? | ("Action:" + discrete_actions only when training)]`` + ```` is ``", "`` when at least one optional middle block + (response / subgoal / metadata) contributes real tokens, else ``":\\n"``. + The trailing ``";\\n "`` prefix-end is only emitted in the former case; + without optional content the state-end already serves as the separator + before ``"Action: "``, so appending another would dangle spurious tokens. + Attention pattern (via ``att_masks`` cumsums): - Video + language: bidirectional (``0``). - - ``State:``, projected state timestep tokens, comma after state: bidirectional (``0``). + - ``State:``, projected state timestep tokens, state-end separator: bidirectional (``0``). - Response spans: prefix-LM style block opening (``[1, 0, …]`` inside the segment). - ``Subgoal:``: new bidirectional block (``[1, 0, …]``). - Subgoal image patches per camera: bidirectional blocks (``[1, 0, …]``). - Commas/metadata / ``";\\n "``: mostly continued prefix blocks (see code). + - ``Action:`` indicator: each token is its own causal block (``[1, 1, …]``). - Discrete actions (training): causal ``1`` per timestep after ``Action:``. Args: @@ -1196,6 +1203,22 @@ def embed_prefix( att_masks = [] bsize = lang_tokens.shape[0] + # Whether any optional middle block (response / subgoal / metadata) will + # actually contribute real tokens to the prefix. When all are dropped the + # state-end separator collapses to ":\n" and the trailing prefix-end is + # omitted, eliminating spurious dangling tokens that would otherwise break + # the cumsum at the indicator -> first-discrete boundary. + has_response = ( + response_tokens is not None and response_masks is not None and bool(response_masks.any()) + ) + has_subgoal = ( + bool(subgoal_images) and bool(subgoal_img_masks) and any(bool(m.any()) for m in subgoal_img_masks) + ) + has_metadata = ( + metadata_tokens is not None and metadata_masks is not None and bool(metadata_masks.any()) + ) + has_any_optional = bool(has_response or has_subgoal or has_metadata) + for vid, vid_mask in zip(videos, vid_masks, strict=True): vid_emb = self.embed_video(vid) # (B, num_video_tokens, vlm_hidden) vid_emb = vid_emb.to(dtype=_preferred_dtype()) @@ -1252,7 +1275,10 @@ def embed_prefix( pad_masks.append(state_mask) att_masks += [0] * num_state_tokens # full attention with video and language - state_end_indicator_ids = self.language_tokenizer.encode(", ", add_special_tokens=False) + # When optional middle blocks follow, use ", " as a state -> optional separator; + # otherwise collapse to ":\n" so the trailing prefix-end can be omitted. + state_end_str = ", " if has_any_optional else ":\n" + state_end_indicator_ids = self.language_tokenizer.encode(state_end_str, add_special_tokens=False) state_end_tokens = torch.tensor( [state_end_indicator_ids] * bsize, device=lang_tokens.device, @@ -1352,20 +1378,26 @@ def embed_prefix( pad_masks.append(metadata_masks) att_masks += [1] + [0] * (metadata_emb.shape[1] - 1) - prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) - prefix_end_tokens = torch.tensor( - [prefix_end_indicator_ids] * bsize, - device=lang_tokens.device, - dtype=torch.long, - ) - prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) + # Only emit the ";\n " prefix-end when at least one optional middle block was added + # above. With no optional content, the state-end already collapsed to ":\n" and acts + # as the separator before "Action: " — appending another would dangle spurious tokens. + if has_any_optional: + prefix_end_indicator_ids = self.language_tokenizer.encode(";\n ", add_special_tokens=False) + prefix_end_tokens = torch.tensor( + [prefix_end_indicator_ids] * bsize, + device=lang_tokens.device, + dtype=torch.long, + ) + prefix_end_emb = self.gemma3_with_expert.embed_language_tokens(prefix_end_tokens) - num_prefix_end_embs = prefix_end_emb.shape[1] - prefix_end_mask = torch.ones(bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device) + num_prefix_end_embs = prefix_end_emb.shape[1] + prefix_end_mask = torch.ones( + bsize, num_prefix_end_embs, dtype=torch.bool, device=lang_tokens.device + ) - embs.append(prefix_end_emb) - pad_masks.append(prefix_end_mask) - att_masks += [0] * num_prefix_end_embs + embs.append(prefix_end_emb) + pad_masks.append(prefix_end_mask) + att_masks += [0] * num_prefix_end_embs if discrete_actions is not None: discrete_action_start_indicator_ids = self.language_tokenizer.encode( @@ -1387,7 +1419,11 @@ def embed_prefix( embs.append(discrete_action_start_emb) pad_masks.append(discrete_action_start_mask) - att_masks += [1] + [0] * (num_discrete_action_start_embs - 1) + # Each "Action: " indicator token is its own causal block. Using + # [1] + [0]*(N-1) collapses them into a single bidirectional block, which + # shifts the cumsum at the indicator -> first-discrete boundary by N-1 + # and breaks the discrete-action CE loss. + att_masks += [1] * num_discrete_action_start_embs discrete_action_emb = self.gemma3_with_expert.embed_discrete_actions(discrete_actions) embs.append(discrete_action_emb.to(dtype=_preferred_dtype())) diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index 5bf32a57..f6b74538 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -513,7 +513,25 @@ def test_complete_pi06_pipeline_integration_smoke(lerobot_dataset_metadata): config.output_features = {k: ft for k, ft in features.items() if ft.type is FeatureType.ACTION} config.input_features = {k: ft for k, ft in features.items() if k not in config.output_features} - policy = PI06Policy(config, dataset_stats=lerobot_dataset_metadata.stats) + # The shared lerobot_dataset_metadata fixture carries actions stats shaped + # (50, 32) — matching the default PI06Config(chunk_size=50). This test + # uses chunk_size=10 to keep the model small, so override the actions + # stats to (10, 32) before constructing Normalize buffers; otherwise + # `(actions - min) / (max - min + EPS)` mismatches at dim=1 (actions is + # (B, 10, 32) but the buffer is (50, 32)). + import copy + + import numpy as np + + dataset_stats = copy.deepcopy(lerobot_dataset_metadata.stats) + for k in ("max", "mean", "min", "std"): + dataset_stats["actions"][k] = np.full( + (config.chunk_size, 32), + float(dataset_stats["actions"][k].flatten()[0]), + dtype=np.float32, + ) + + policy = PI06Policy(config, dataset_stats=dataset_stats) policy.to(dtype=torch.bfloat16, device="cuda") batch = { diff --git a/tests/policies/test_pi07_cpu.py b/tests/policies/test_pi07_cpu.py index 85dea0eb..35b49df9 100644 --- a/tests/policies/test_pi07_cpu.py +++ b/tests/policies/test_pi07_cpu.py @@ -634,7 +634,12 @@ class TestEmbedPrefixConditionalGuards: def test_all_optional_blocks_absent_skips_emission(self): """All-False response_masks + no subgoal_images + all-False metadata_masks → - the prefix collapses to ``videos + lang + State: + state + ", " + ";\\n "``. + the prefix collapses to ``videos + lang + State: + state + ":\\n"``. + + With ``has_any_optional == False`` the state-end separator collapses + from ``", "`` to ``":\\n"`` (same fake-tokenizer length: 2 tokens) and + the trailing ``";\\n "`` prefix-end is omitted entirely so it cannot + dangle as a spurious separator before ``"Action: "``. """ method = _embed_prefix_method() fake = _make_fake_flow_matching() @@ -658,12 +663,11 @@ def test_all_optional_blocks_absent_skips_emission(self): # lang: 3 (prompt_len) # "State: ": 2 # state(T=1): 1 - # ", ": 2 - # ";\n ": 2 - # Total = 13 - assert embs.shape == (bsize, 13, 4) - assert pad_masks.shape == (bsize, 13) - assert att_masks.shape == (bsize, 13) + # ":\n": 2 (state-end; collapsed from ", " because no optionals) + # Total = 11 (";\n " prefix-end is omitted with no optional content) + assert embs.shape == (bsize, 11, 4) + assert pad_masks.shape == (bsize, 11) + assert att_masks.shape == (bsize, 11) # No causal boundaries (1's) anywhere in the att_masks — every block # should remain bidirectional with all-skip optional blocks. assert int(att_masks[0].sum().item()) == 0, ( @@ -744,3 +748,57 @@ def test_response_mask_any_true_emits_block(self): assert per_sample_sum == 1, ( f"expected exactly one causal boundary from the response block opening, got {per_sample_sum}" ) + + def test_discrete_actions_indicator_uses_per_token_causal_blocks(self): + """The ``"Action: "`` indicator must use ``[1]*N`` (one causal block per + token), not ``[1] + [0]*(N-1)`` (single bidirectional block). + + The buggy ``[1] + [0]*(N-1)`` pattern collapses the indicator into one + bidirectional block, shifting the cumsum at the indicator -> first-discrete + boundary by N-1 and breaking the discrete-action CE loss. This test pins + the tail of ``att_masks`` to all 1's after the indicator, so a regression + to the old pattern fails immediately. + """ + method = _embed_prefix_method() + fake = _make_fake_flow_matching() + bsize = 2 + kwargs = _build_default_inputs(batch_size=bsize) + # No optional middle blocks — keeps the prefix layout deterministic so + # the indicator + discrete-action span sits exactly at the tail. + kwargs["subgoal_images"] = [] + kwargs["subgoal_img_masks"] = [] + + num_action_tokens = 3 + kwargs["discrete_actions"] = torch.zeros(bsize, num_action_tokens, dtype=torch.long) + kwargs["discrete_action_masks"] = torch.ones(bsize, num_action_tokens, dtype=torch.bool) + + _, pad_masks, att_masks = method(fake, **kwargs) + + # Expected layout (no optional blocks; fake tokenizer encodes every + # indicator phrase to 2 tokens; ``discrete_action_emb`` matches its + # input shape): + # videos(3) + lang(3) + "State: "(2) + state(1) + ":\n"(2) + # + "Action: "(2) + discrete_actions(3) = 16 + num_indicator_tokens = 2 + base_prefix_len = 11 + total_len = base_prefix_len + num_indicator_tokens + num_action_tokens + assert att_masks.shape == (bsize, total_len) + assert pad_masks.shape == (bsize, total_len) + + # The first ``base_prefix_len`` positions are bidirectional (no causal + # boundaries) — same invariant as test_all_optional_blocks_absent. + assert int(att_masks[0, :base_prefix_len].sum().item()) == 0 + + # Tail invariant: every position from the indicator onward must be 1 + # (per-token causal blocks for the indicator + one causal step per + # discrete action). A regression to ``[1] + [0]*(N-1)`` on the + # indicator would put zeros at indices ``base_prefix_len + 1 .. + # base_prefix_len + num_indicator_tokens - 1`` and the assertion below + # would fail. + tail = att_masks[0, base_prefix_len:] + assert int(tail.sum().item()) == num_indicator_tokens + num_action_tokens, ( + f"expected all-ones tail of length {num_indicator_tokens + num_action_tokens} " + f"(indicator + discrete actions), got {tail.tolist()} — a zero in the indicator " + "span signals a regression to the old [1]+[0]*(N-1) pattern, which shifts the " + "cumsum at the indicator -> first-discrete boundary and breaks the CE loss." + ) diff --git a/tests/policies/test_pi07_low_level_planner.py b/tests/policies/test_pi07_low_level_planner.py index 1b233c13..52ef1dc3 100644 --- a/tests/policies/test_pi07_low_level_planner.py +++ b/tests/policies/test_pi07_low_level_planner.py @@ -291,7 +291,14 @@ def _verify_position_ids( if inference_mode: prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] else: - prefix_offsets = torch.sum(prefix_pad_masks[:, :-DISCRETE_ACTION_MAX_LENGTH], dim=-1)[:, None] + # Training: model's prefix_offsets exclude both the "Action: " indicator + # and the discrete-action span from cross-attention (matches pi05's + # discrete_action_indicator_max_length logic) so the action expert sees + # the same prefix length at train and inference. + action_lead_len = self._indicator_lens(tokenizer)["action_lead"] + prefix_offsets = torch.sum( + prefix_pad_masks[:, : -(action_lead_len + DISCRETE_ACTION_MAX_LENGTH)], dim=-1 + )[:, None] expected_suffix = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 assert torch.equal(suffix_position_ids, expected_suffix) @@ -312,12 +319,16 @@ def _verify_action_expert_attention_mask( prefix_pad_masks, suffix_pad_masks, suffix_att_masks, + tokenizer, inference_mode=False, ): if inference_mode: num_cross = prefix_pad_masks.shape[1] else: - num_cross = prefix_pad_masks.shape[1] - DISCRETE_ACTION_MAX_LENGTH + # Training: cross-attention excludes both the "Action: " indicator and + # the discrete-action span (mirrors the prefix_offsets logic above). + action_lead_len = self._indicator_lens(tokenizer)["action_lead"] + num_cross = prefix_pad_masks.shape[1] - action_lead_len - DISCRETE_ACTION_MAX_LENGTH expected = make_att_2d_masks( suffix_pad_masks, @@ -478,6 +489,7 @@ def capture_embed_suffix(*args, **kwargs): captured["prefix_pad_masks"], captured["suffix_pad_masks"], captured["suffix_att_masks"], + tokenizer, ) assert isinstance(loss, dict) @@ -586,6 +598,7 @@ def capture_embed_suffix_infer(*args, **kwargs): captured_infer["prefix_pad_masks"], captured_infer["suffix_pad_masks"], captured_infer["suffix_att_masks"], + tokenizer, inference_mode=True, )