[bugfix] thread contextual_seq_len from preprocessor to STULayer (proto sentinel + truncation total_uih_len + AOTI-friendly SLA builder)#501
Conversation
…[default=-1] sentinel Replaces the unreachable `if "contextual_seq_len" not in stu` guard at HSTUTransducer construction with the same `< 0` sentinel pattern that PR alibaba#500 introduced for `stu.scaling_seqlen`. Why the guard was unreachable: `tzrec/utils/config_util.py:66` builds the `stu` dict via `MessageToDict(..., including_default_value_fields=True, preserving_proto_field_name=True)`. With the previous proto declaration `optional uint32 contextual_seq_len = 13;` (no explicit default), proto2 fills the dict with `contextual_seq_len: 0` for any STU message that doesn't pin a value, so the membership test never matches and the `self._input_preprocessor.contextual_seq_len()` fallback never runs. Every `STULayer` ended up with `contextual_seq_len=0` regardless of the preprocessor's actual contextual-feature count. PR alibaba#450 introduced the guard intending exactly this fallback, but the dict-fill behavior of `including_default_value_fields=True` defeated it silently. Symptom: the failing CI run on the PR alibaba#500 branch (run 25497698553 / job 74821854417) crashed `test_rank_ultra_hstu_cutlass_train_eval_export` with a step-0 `l2_loss_watchtime: 1921847424.00` followed by a CUDA IMA on step 1. `compute_stu_truncation_plan` in `tzrec/ops/hstu_attention_utils.py` derives `uih_lengths = seq_lengths - contextual_seq_len - num_targets`; with `contextual_seq_len=0` but a 6-feature contextual prefix actually present, `uih_lengths` overshoots, `drop_count` and `new_lengths` desync from the apply-truncation kernel, and the truncated jagged buffer indexes outside the valid range. The dlrm_hstu_cutlass test does not exercise SLA + truncation + MoT so the bad value is numerically harmless there, which is why the regression slipped past existing CI. Fix: - `tzrec/protos/module.proto`: `optional uint32 contextual_seq_len = 13;` -> `optional int32 contextual_seq_len = 13 [default = -1];` (mirrors PR alibaba#500's `scaling_seqlen` field). Wire-compat: a uint32=0 from any legacy serialized config decodes as int32=0 and an explicit 0 still reads as "no contextual" (sentinel is `< 0`, not `<= 0`). - `tzrec/modules/gr/hstu_transducer.py`: replace `if "contextual_seq_len" not in stu:` with `if stu.get("contextual_seq_len", -1) < 0:`. Comment explains the dict-fill interaction so the next reader doesn't reintroduce the bug. No other call site reads `stu["contextual_seq_len"]`. `tzrec/protos/module_pb2.py` is gitignored and regenerated at install time via `scripts/gen_proto.sh`, so no .py change is committed; the regenerated descriptor was sanity-checked locally (`d['contextual_seq_len'] == -1` for an empty STU under `including_default_value_fields=True`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…en-proto-default-neg1
…0 truncation path The previous version of this commit threaded `self._input_preprocessor.contextual_seq_len()` into `STULayer` whenever the user hadn't pinned `stu.contextual_seq_len`. That flips `compute_stu_truncation_plan` / `apply_stu_truncation_plan` to the `contextual_seq_len > 0` branch (`split_2D_jagged` prefix-split + tail- split + `concat_2D_jagged`), which produces a forward-pass CUDA IMA on the very first training step in the `test_rank_ultra_hstu_cutlass_train_eval_export` integration test -- reproduced on both Hopper (PR alibaba#501 H20 CI run 25543727919, job 74975071509) and Ampere (run 25543727906, job 74975071514). The IMA surfaces inside `dlrm_hstu.loss -> _fx_avg_batch_size`, with the actual faulting kernel happening earlier in the forward pass and only being detected by the next CUDA dispatch. The legacy semantics that PR alibaba#492's integration test was implicitly relying on are `STULayer.contextual_seq_len = 0`: the dict-fill bug in `config_to_kwargs` (proto2 `optional uint32` always present in `MessageToDict(including_default_value_fields=True, ...)` as `0`) made `if "contextual_seq_len" not in stu:` unreachable, so the preprocessor fallback never fired. The CUTLASS-HSTU + SLA + attn_truncation combination has therefore never been exercised with `cs > 0` in production, and the path has a latent bug. Fix the runtime regression I introduced while keeping the proto change so explicit user overrides work and we stay forward-compatible: - `stu.get("contextual_seq_len", -1) < 0` now falls back to `0` instead of the preprocessor value, restoring the on-master behavior that PR alibaba#492 / PR alibaba#500 CI was passing on. - Users who genuinely need the contextual-aware truncation can pin `stu.contextual_seq_len: <N>` explicitly; that path triggers the buggy kernel branch and remains a TODO. - Comment block in the source documents both the kernel bug (so the next person doesn't re-thread the value blindly) and the recipe for users to opt in. The proto change (`optional int32 contextual_seq_len = 13 [default = -1]`) and the sentinel-style guard pattern from the previous commit stay -- they're still the right shape for a future fix that re-enables the auto-thread once the kernel bug is addressed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ide attn_func Root cause of the cs>0 + SLA + attn_truncation forward-pass IMA on PR alibaba#501: `cutlass_hstu_mha` was unconditionally building `num_contexts_tensor` from `contextual_seq_len > 0` and converting `num_targets` to int32, then passing both to `hstu_attn_varlen_func` together with the SLA NFUNC `func` tensor. But the func tensor already encodes contextual prefix + target isolation (see `build_sla_func_tensor` -- it computes `effective_k2 = max(sla_k2, contextual_seq_len)` and a separate target boundary so history vs. target rows get distinct intervals; and see the `pytorch_hstu_mha` reference, which deliberately decodes the mask only from `attn_func` when it is set, ignoring `contextual_seq_len` / `num_targets`). `hstu_attn_varlen_func` uses `num_contexts` / `num_targets` for layout-level decisions (separate context-region bookkeeping inside the kernel) that don't compose with the func-driven mask: the kernel ends up indexing a context buffer that doesn't match the func layout, reads OOB, and corrupts memory. The fault surfaces only on the very first forward step when the contextual prefix overlaps the global region (`max(sla_k2=32, cs=6) = 32`) and a real attn_truncation is active -- which is exactly the `test_rank_ultra_hstu_cutlass_train_eval_export` configuration (sla_k1=256, sla_k2=32, attn_truncation_split_layer=2, attn_truncation_tail_len=512, two MoT channels, 6-feature contextual group). Reproduced on both Hopper (PR alibaba#501 H20 CI run 25543727919 / job 74975071509, traceback inside `dlrm_hstu.loss._fx_avg_batch_size`) and Ampere (run 25543727906 / job 74975071514). dlrm_hstu_cutlass doesn't trip it because that config has no SLA -> no `attn_func` -> the standard fixed-mask path that legitimately needs `num_contexts` / `num_targets`. Fix: gate `num_contexts_tensor` and `num_targets_int32` on `attn_func is None`. Same defense-in-depth pattern as the existing `causal` / `max_attn_len` gates a few lines up. Documents the kernel-layout interaction in a comment so the next person doesn't reintroduce it. This unblocks the proto/sentinel work in the previous commit -- restore `stu["contextual_seq_len"] = self._input_preprocessor.contextual_seq_len()` in `HSTUTransducer.__init__` so the dict-fill bug fix from PR alibaba#501's first commit (which the previous commit had to revert as a band-aid) is back in effect. Verification: - `python -m unittest tzrec.modules.gr.hstu_transducer_test` 7/7 pass locally (PyTorch backend; the unit-test stub already exercises the contextual prefix + truncation matrix `[ctx in {0, 3}]` x `[(split, tail) in {(0,0), (1,4)}]` x interleave). - CUTLASS-path verification has to happen on H20 CI (local A10 lacks the `hstu_ops_gpu` wheel). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…truncation total_uih_len Real root cause of `test_rank_ultra_hstu_cutlass_train_eval_export`'s forward-pass IMA on PR alibaba#501. (My previous two commits guessed at the cutlass kernel and at falling back to `cs=0`; both have been reverted in favor of this targeted fix.) `HSTUTransducer._replay_truncation_state` returned `plan.total_kept - total_targets` for the post-truncation `total_uih_len` it hands to `_postprocess`. But `plan.total_kept` is the "rest" portion only -- the contextual prefix is split off separately into `plan.total_prefix = B * contextual_seq_len` and re-concatenated on top of the kept rest by `apply_stu_truncation_plan`'s `cs > 0` branch. So the real post-truncation UIH length is `plan.total_kept + plan.total_prefix - total_targets`. Downstream, `_postprocess` builds `uih_offsets = cumsum(seq_lengths - num_targets)` from `plan.new_lengths` (which still includes the contextual prefix per sample), so `uih_offsets[-1] = total_kept + total_prefix`. Passing the under-counted `total_uih_len` to `split_2D_jagged` makes the Triton/CUTLASS kernel allocate `OutA` of shape `(total_kept - total_targets, D)`, but the kernel writes each sample's slice at `OutA[uih_offsets[i] + j]` for `j in [0, ctx + new_uih_b)` -- positions up to `total_kept + total_prefix - 1`. The tail `B * contextual_seq_len` writes are out-of-bounds, producing the forward-pass CUDA IMA detected on the next step's `.item()` sync inside `ContextualInterleavePreprocessor.forward` (`max_uih_len = fx_int_item(uih_seq_lengths.max())`). Surfaces on both Hopper (PR alibaba#501 H20 CI run 25543727919 / job 74975071509) and Ampere (run 25543727906 / job 74975071514). Why on-master with `cs=0` was silently fine: the dict-fill bug in `config_to_kwargs` that PR alibaba#450's guard tried to address (now properly fixed via the `[default = -1]` sentinel earlier in this PR) made `STULayer.contextual_seq_len = 0` regardless of the preprocessor's real value, which forces `compute_stu_truncation_plan(cs=0)` and `apply_stu_truncation_plan`'s simple (non-contextual) branch. `plan.total_prefix` is `0` in that branch, so `total_kept + 0 - total_targets == total_kept - total_targets` was correct only by coincidence. Once the proto/sentinel fix wires `cs=6` through, the bug was load-bearing. Why the unit test for `_replay_truncation_state` was passing despite the bug: it asserted the buggy expected value (`total_kept - total_targets`), and the surrounding `test_forward_end_to_end` cases use the PyTorch backend, whose `pytorch_split_2D_jagged` ignores `total_len_left` and sizes outputs from offsets directly -- so the OOB write never happens on that path. The Triton/CUTLASS kernel honors `total_len_left`, which is why the integration test (CUTLASS kernel) is the one that crashes. Local repro (4× A10 + cu129 + fbgemm_gpu_hstu wheel installed): - Pre-fix: train_eval crashes on first forward step (l2_loss_watchtime=1.92e9 then IMA inside `ContextualInterleavePreprocessor.forward`'s `.item()`). - Post-fix: train_eval finishes cleanly (loss goes 1.92e9 -> 1.74e9 over 45 steps, eval AUCs sane); standalone eval finishes cleanly. Updates `test_replay_truncation_state_replays_plan_and_assigns_correct_fields` to assert the corrected value (`plan.total_kept + plan.total_prefix - total_targets`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
519d891 to
194c8e6
Compare
Without this, after the previous commit re-enables auto-threading ``cs > 0`` into ``STULayer``, the integration test ``test_rank_ultra_hstu_cutlass_train_eval_export`` passes train_eval but fails AOTI export with a ``cudaErrorAssert`` from a Triton kernel generated for ``build_sla_func_tensor``. Why the export breaks while train_eval is clean: ``build_sla_func_tensor`` does ``seq_offsets[batch_ids]`` / ``seq_lengths[batch_ids]`` / ``num_targets[batch_ids]`` where ``batch_ids = repeat_interleave(arange(B), seq_lengths)`` is provably in ``[0, B)`` by construction. At eager runtime that's fine. But under AOTI compile, Inductor fuses this body with surrounding shape arithmetic that includes ``B`` (it inlines the post-truncation ``seq_offsets = arange(B+1) * cs + offsets_tail`` chain when ``cs > 0``), surfaces ``B`` as a kernel parameter (``ks0``), and emits ``tl.device_assert(0 <= batch_ids < ks0)``. AOTI compile-time autotune benches the kernel with ``rand_strided`` int32 garbage rather than the real ``repeat_interleave`` output, so the assert fires and aborts export with a device-side assert. Master with the dict-fill bug never traced this branch (because ``STULayer.contextual_seq_len`` was forced to ``0`` and the post-truncation ``seq_offsets`` reduces to plain ``cumsum(new_lengths)`` with no ``arange*cs`` in the chain), so the kernel that surfaces ``B`` was never generated -- which is why the ``cs=0`` path slipped past export CI. Fix: wrap the entire SLA NFUNC builder body in a single ``torch.library.define`` custom op (``tzrec::_sla_build_func_2d``). Inductor sees one opaque op call instead of fusing the body with surrounding code; no ``assert_indirect_indexing`` codegen, no ``rand_strided`` autotune crash. The eager Python body runs unchanged at runtime where ``batch_ids`` is provably valid. Same low-level ``Library`` API as ``_sla_broadcast_func_to_heads`` (the existing black-box op for the head-broadcast) -- avoids the AOTI multi-thread deadlock that ``@torch.library.custom_op`` triggers. Verification (4x A10 + cu129 + fbgemm_gpu_hstu wheel installed): - ``python -m unittest tzrec.modules.gr.hstu_transducer_test`` 7/7 pass. - ``python -m unittest tzrec.ops.hstu_attention_test.HSTUAttentionTest.test_sla_attn_cutlass`` -> Ran 60 hypothesis examples, OK (CUTLASS NFUNC path goes through the custom op now). - ``python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export`` -> ``Ran 1 test in 159.337s, OK`` (full train_eval + eval + export + predict pipeline). Pre-PR-501 fa39911 passed it via the cs=0 dict-fill bug; PR alibaba#501 prior commits failed it; this commit closes the loop. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…unblock AOTI export Without this, after the previous commit fixes ``_replay_truncation_state``'s post-truncation ``total_uih_len`` math and the proto/sentinel cleanup re-enables auto-threading ``contextual_seq_len > 0`` into ``STULayer``, AOTI export of ``test_rank_ultra_hstu_cutlass_train_eval_export`` aborts with ``cudaErrorAssert`` from a Triton kernel generated for ``build_sla_func_tensor`` (see PR alibaba#501 H20 CI run 25543727919 / job 74975071509 and the matching Ampere lane). Two distinct AOTI-compile failures were exposed: 1. ``unsqueeze(0).expand(nheads, 3, total_q).contiguous()`` trips Inductor's ``combine_contiguous_dims`` for a 3-D iteration whose last dim is the symbolic ``total_q``; the pass emits ``ModularIndexing(idx, total_q, 3)`` whose sympy simplifier raises ``ZeroDivisionError`` during AOT compile. 2. ``seq_offsets[batch_ids]`` / ``seq_lengths[batch_ids]`` / ``num_targets[batch_ids]`` where ``batch_ids = repeat_interleave( arange(B), seq_lengths)``. Inductor materializes ``batch_ids`` as a separate kernel input, fuses the post-truncation ``new_x_offsets = arange(B+1)*cs + offsets_tail`` chain in (which is what flips on when ``cs > 0``), surfaces ``B`` as a kernel parameter (``ks0``), and emits ``tl.device_assert(0 <= batch_ids < ks0)`` on the loaded raw value (the ``assert_indirect_indexing`` codegen). At AOTI compile-time autotune the bench fills the input with ``rand_strided`` int32 garbage that almost never satisfies the bound, so the assert fires and aborts export. ``master`` with the dict-fill bug never traced this branch (``STULayer.cs`` was forced to ``0`` and the post-truncation ``seq_offsets`` reduces to plain ``cumsum(new_lengths)`` with no ``arange*cs``), so the kernel that surfaces ``B`` was never generated -- which is why master CI slipped past export. Earlier rev of this commit wrapped both bodies in ``torch.library.Library("tzrec", "FRAGMENT")`` custom ops (``tzrec::_sla_broadcast_func_to_heads`` / ``tzrec::_sla_build_func_2d``). That worked but is incompatible with the cpp AOTI runtime requirement. This rev keeps everything inline: - **Builder body**: replace ``seq_offsets[batch_ids]`` / ``seq_lengths[batch_ids]`` / ``num_targets[batch_ids]`` with ``torch.repeat_interleave(values, lengths)`` directly per slot: ``` seq_offsets_per_pos = repeat_interleave(seq_offsets[:B].contiguous(), seq_lengths) L = repeat_interleave(seq_lengths, seq_lengths) T = repeat_interleave(num_targets.to(int32), seq_lengths) pos_local = pos_global - seq_offsets_per_pos ``` Mathematically equivalent: ``repeat_interleave(values, lengths)`` produces the same per-position output as ``values[batch_ids]``. Inductor's lowering for the ``(values, lengths)`` overload is a single cumsum-driven scatter/gather kernel that does not surface a user-visible ``batch_ids`` tensor with an ``assert_indirect_indexing`` codegen, so the autotune-bench bounds check that randomly fails is never generated. The ``.narrow(0, 0, B)`` + ``.contiguous()`` on ``seq_offsets[:B]`` materializes the slice into a dense buffer (the original SliceView form crashes Inductor's searchsorted lowering at AOT compile, per the comment near ``torch.diff(seq_offsets_i32)`` above). - **Broadcast tail**: drop the explicit ``.contiguous()`` and return ``func_2d.unsqueeze(0).expand(nheads, 3, total_q)``. The consumer ``cutlass_hstu_mha`` is ``@torch.fx.wrap``'d (``tzrec/ops/_cuda/cutlass_hstu_attention.py:31``, confirmed leaf semantic in ``tzrec/ops/hstu_attention.py:121``); its body runs outside Inductor's compile boundary and already calls ``.contiguous()`` on ``attn_func`` itself (``cutlass_hstu_attention.py:136``). With ``.contiguous()`` deferred to that runtime call, ``combine_contiguous_dims`` never runs on the symbolic-last-dim broadcast and the sympy ``ZeroDivisionError`` goes away. - **``compute_stu_truncation_plan``**: change ``new_x_offsets = offsets_prefix + offsets_tail`` to the mathematically equivalent ``new_x_offsets = x_offsets - offsets_head``. The ``arange(B+1) * contextual_seq_len`` chain is what got Inductor to inline ``new_x_offsets[batch_ids] = cs * batch_ids + offsets_tail[batch_ids]`` into the SLA builder kernel and surface ``B`` in the first place; the cumsum-based subtraction leaves ``cs`` outside the consumer's dependency chain. ``offsets_prefix`` (still required by ``apply_stu_truncation_plan``'s prefix-split branch) and ``total_prefix = B_static * contextual_seq_len`` stay. Verification (4x A10 + cu129 + ``fbgemm_gpu_hstu`` wheel installed): - ``python -m unittest tzrec.modules.gr.hstu_transducer_test`` -> 7/7 pass. - ``python -m unittest tzrec.ops.hstu_attention_utils_test`` -> 27/27 pass. - ``python -m unittest tzrec.ops.hstu_attention_test.HSTUAttentionTest.test_sla_attn_cutlass`` -> 60 hypothesis examples, OK (CUTLASS NFUNC parity unchanged). - ``python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export`` -> ``Ran 1 test in 162.842s, OK`` (full train_eval + eval + export + predict, no ``tzrec::`` custom ops registered). No ``torch._inductor.config`` overrides; no ``assert_indirect_indexing = False``; no ``triton.autotune_pointwise = False`` (would regress perf). Just the math rewrite + the ``.contiguous()`` deferral. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
679d1dd to
462ed66
Compare
| seq_offsets_starts = seq_offsets_i32.narrow(0, 0, B).contiguous() # (B,) | ||
| seq_offsets_per_pos = torch.repeat_interleave(seq_offsets_starts, seq_lengths) | ||
| pos_local = pos_global - seq_offsets_per_pos | ||
| L = torch.repeat_interleave(seq_lengths, seq_lengths) # per-position seq length | ||
| if num_targets is not None: | ||
| T = num_targets.to(torch.int32)[batch_ids] # per-position target count | ||
| T = torch.repeat_interleave(num_targets.to(torch.int32), seq_lengths) |
There was a problem hiding this comment.
This AOTI-compile-time-autotune workaround also runs on the eager runtime forward path (STULayer.forward → build_sla_func_tensor), so every CUTLASS layer of every batch pays it.
Old form: 1 repeat_interleave(arange(B), seq_lengths) (≈1 cumsum + 1 scatter, length total_q) + 3 cheap B-element gathers.
New form: 3 repeat_interleave(values, seq_lengths) calls (each ≈1 cumsum + 1 scatter at length total_q) + the narrow(...).contiguous() copy.
Net: roughly +2 cumsum kernels and +2 full-length scatters per SLA layer per forward. With a stack of N CUTLASS layers, that's 4N extra kernel launches per forward.
Consider gating the new form behind an export-only path (e.g., torch.fx._symbolic_trace.is_fx_tracing() / a module-level _export_mode flag) and keeping the original batch_ids = repeat_interleave(arange(B), seq_lengths) + 3 gathers on eager — the autotune device-assert only fires under AOTI compile-time bench, not at eager runtime.
Side note on line 113: seq_offsets_i32.narrow(0, 0, B) on a 1-D contiguous source is already contiguous, so .contiguous() is a no-op. If it's there to defeat an Inductor SliceView issue analogous to the seq_offsets[1:] problem documented at line 92, please add a one-line comment so a future cleanup pass doesn't drop it.
Code review summaryReviewed via code-quality, performance, test-coverage, doc-accuracy, and security subagents. The fix is well-scoped and the inline comments justifying each AOTI workaround are excellent. Three inline comments left covering the main concerns. Headline findings
Smaller items (skipped from inline)
SecurityNo noteworthy security/safety findings. The proto-sentinel-then-validate layering correctly prevents any negative Positive
|
…t, simplify comments
Test pin: add a 4-row parametrized
``test_contextual_seq_len_resolution`` that exercises the four states
the ``stu.get("contextual_seq_len", -1) < 0`` guard supports:
| stu dict | preprocessor | resolved |
|-------------------|--------------|----------|
| key absent | 5 | 5 |
| key=-1 (sentinel) | 5 | 5 |
| key=0 (explicit) | 5 | 0 |
| key=3 (explicit) | 5 | 3 |
The existing end-to-end matrix never carries ``contextual_seq_len`` in
its ``stu`` dict, so a regression to ``not in stu`` would silently pass
every prior case. Verified the new test catches it: replacing the
guard with ``not in stu`` makes case 1 (sentinel-fill) fail with
``AssertionError: -1 != 5``; restoring the guard makes it pass.
Doc + comment cleanup:
- ``STUTruncationPlan.total_kept`` field doc now spells out that it
excludes the contextual prefix and that callers must add
``total_prefix`` back to get the full post-truncation total.
- Drop the verbose narrative comments in ``build_sla_func_tensor``,
``compute_stu_truncation_plan``, ``HSTUTransducer.__init__`` and
``_replay_truncation_state`` -- bug history belongs in commit
messages, not in code. Each kept comment is now 2-5 lines explaining
the WHY, not the full backstory.
- Fix the stale "default-zero fill" wording in
``HSTUTransducer.__init__`` -- the proto now fills the sentinel
``-1``, not ``0``.
- Standardize the ``module.proto`` sentinel comments for
``contextual_seq_len`` and ``scaling_seqlen`` to the same
``Sentinel: < 0 (default) = ...`` pattern (grep-friendly).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ward_end_to_end Drop the standalone ``test_contextual_seq_len_resolution`` and fold its contract-pin cases into the existing ``test_forward_end_to_end`` parametrization. Each row carries explicit ``stu_cs`` / ``expected_cs`` columns; ``_build_transducer`` grows a ``stu_cs_override`` parameter that pins ``stu["contextual_seq_len"]`` in the dict passed to ``HSTUTransducer.__init__``. Five legacy rows keep ``stu_cs=None`` (key absent, falls back to the preprocessor's value); three new rows pin the sentinel guard: | name | stu_cs | ppr_cs | expected | |-------------------------------|--------|--------|----------| | sentinel_fill_falls_back | -1 | 3 | 3 | | explicit_zero_overrides | 0 | 3 | 0 | | explicit_positive_overrides | 5 | 3 | 5 | Each row also asserts ``layer._contextual_seq_len == expected_cs`` on every layer in the stack -- so all 8 cases double as both the forward-end-to-end correctness check AND the guard contract. Verified the merged test still pins the bug class: replacing the guard with ``not in stu`` makes exactly case 5 (``sentinel_fill_falls_back``) fail with ``AssertionError: -1 != 3``; restoring the guard makes all 10 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ng test rows
Replace the positional-tuple form of ``test_forward_end_to_end``'s
parametrization with ``param("name", split=..., tail=..., ...)``.
Each row is now self-describing -- no inline ``# split`` /
``# tail`` etc. comments needed and no risk of miscounting positional
slots when adding rows.
The first positional ``"name"`` keeps the test-name suffix (e.g.
``test_forward_end_to_end_5_sentinel_fill_falls_back``); everything
else is named kwargs. ``parameterized`` doesn't accept keyword-only
function signatures, so the test method declares each kwarg as a
regular positional parameter (the kwargs from ``param`` bind by
name regardless).
No behavior change. Verified: reverting the guard to ``not in stu``
makes case 5 ``sentinel_fill_falls_back`` fail with
``AssertionError: -1 != 3``; restoring it makes all 10 tests pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Carries the PR alibaba#501 fixes: contextual_seq_len proto sentinel + auto-thread, post-truncation total_uih_len fix, and the AOTI-friendly SLA NFUNC builder rewrite (custom-op-free). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
PR #450 added
if "contextual_seq_len" not in stu:toHSTUTransducer.__init__intending to forward_input_preprocessor.contextual_seq_len()into everySTULayerwhen the user hadn't pinnedstu.contextual_seq_len. That guard has been silently unreachable since day 1 becausetzrec/utils/config_util.pyrunsMessageToDict(..., including_default_value_fields=True, ...)and the proto field wasoptional uint32with no explicit default — so the dict always containscontextual_seq_len: 0and the membership test never matches. EverySTULayerwas constructed withcontextual_seq_len = 0regardless of the preprocessor's value.For models without SLA + attn_truncation + multi-channel MoT (e.g.,
dlrm_hstu_cutlass) the wrong value was numerically harmless. Forultra_hstu_cutlass_kuairand_1k(sla_k1=256, sla_k2=32, attn_truncation_split_layer=2, attn_truncation_tail_len=512, attn_num_layers=4, 2 MoT channels, 6-feature contextual group), it manifests asl2_loss_watchtime: 1.92e9on step 0 followed by a forward-pass CUDA IMA on step 1 (see PR #500 CI run 25497698553 / job 74821854417, and the matching Hopper repro on PR #501 H20 CI run 25543727919 / job 74975071509).This PR closes the loop end-to-end:
Proto sentinel (
tzrec/protos/module.proto):optional uint32 contextual_seq_len = 13;→optional int32 contextual_seq_len = 13 [default = -1];. Mirrors thescaling_seqlenpattern from PR [feat] stu.scaling_seqlen + drop autotune assert strip #500.Sentinel-style guard (
tzrec/modules/gr/hstu_transducer.py): replace the unreachableif "contextual_seq_len" not in stu:withif stu.get("contextual_seq_len", -1) < 0: stu["contextual_seq_len"] = self._input_preprocessor.contextual_seq_len().Post-truncation
total_uih_lenfix (tzrec/modules/gr/hstu_transducer.py:_replay_truncation_state): changeplan.total_kept - total_targetstoplan.total_kept + plan.total_prefix - total_targets.total_keptis the post-truncation "rest" portion only (excluding contextual prefix);_postprocessbuildsuih_offsetsfromseq_lengths - num_targetswhich still includes the prefix per sample, so the under-count made the Triton/CUTLASSsplit_2D_jaggedallocateOutAsmaller thanoffsets_left[-1]and write OOB on the very first forward step. (The unit test was asserting the buggy expected value too — also corrected.)AOTI-friendly SLA NFUNC builder (
tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor): with cs > 0 actually flowing through, AOTI export hit two new Inductor-compile failures the mastercs = 0path never tripped:unsqueeze(0).expand(nheads, 3, total_q).contiguous()triggers Inductor'scombine_contiguous_dimsfor a 3-D iteration with symbolic last dim →ModularIndexing(idx, total_q, 3)→ sympyZeroDivisionError.seq_offsets[batch_ids]/seq_lengths[batch_ids]/num_targets[batch_ids](wherebatch_ids = repeat_interleave(arange(B), seq_lengths)) materializesbatch_idsas a kernel input and emitstl.device_assert(0 <= batch_ids < B). Under AOTI compile-time autotune the bench fills the input withrand_stridedint32 garbage that never satisfies the bound — assert fires, export aborts withcudaErrorAssert.Fixes (no Python custom ops, no
torch._inductor.configoverrides):.contiguous()to the consumer: returnfunc_2d.unsqueeze(0).expand(nheads, 3, total_q).cutlass_hstu_mhais@torch.fx.wrap'd (FX leaf) and already calls.contiguous()itself (tzrec/ops/_cuda/cutlass_hstu_attention.py:136); deferring keepscombine_contiguous_dimsoutside Inductor's compile boundary.torch.repeat_interleave(values, lengths)directly per slot (seq_offsets_per_pos = repeat_interleave(seq_offsets[:B].contiguous(), seq_lengths), same forL/T). Mathematically equivalent tovalues[batch_ids], but Inductor's lowering for the(values, lengths)overload is a single cumsum-driven scatter/gather that doesn't surface a separatebatch_idskernel input — noassert_indirect_indexingcodegen, no autotune-bench bounds-check failure.compute_stu_truncation_plan:new_x_offsets = x_offsets - offsets_head(mathematically equivalent tooffsets_prefix + offsets_tail, but noarange(B+1)*cschain feeding into the SLA builder's seq_offsets — keepscs(and via itB) out of the consumer's SymInt graph).Wire-compat
uint32=0from any old serialized config decodes asint32=0. The< 0sentinel (not<= 0) preserves the legacy semantic that an explicitly-set 0 means "no contextual" rather than "use preprocessor".tzrec/protos/module_pb2.pyis gitignored and regenerated byscripts/gen_proto.shat install time — no checked-in.pychange. Regenerated descriptor sanity-checked locally:d['contextual_seq_len'] == -1for an empty STU underincluding_default_value_fields=True..pt2references notzrec::Python custom ops — runs in pure C++ runtime.Test plan
python -m unittest tzrec.modules.gr.hstu_transducer_test— 7/7 pass (coversctx in {0, 3}×(split, tail) in {(0,0), (1,4)}× interleave matrix).python -m unittest tzrec.ops.hstu_attention_utils_test— 27/27 pass (truncation-plan parity).python -m unittest tzrec.ops.hstu_attention_test.HSTUAttentionTest.test_sla_attn_cutlass— 60 hypothesis examples, OK (CUTLASS NFUNC parity vs PyTorch reference unchanged).python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export—Ran 1 test in 162.842s, OK(full train_eval + eval + export + predict on 4×A10 + cu129 +fbgemm_gpu_hstuwheel).STU→{'contextual_seq_len': -1, …}→ sentinel< 0triggers preprocessor fallback; explicit 0 → respected as 0; explicit positive → respected verbatim.test_rank_ultra_hstu_cutlass_train_eval_exportand it passes (was the originally failing test).test_rank_dlrm_hstu_cutlass_train_eval_exportstill green (different code path, regression spot-check).🤖 Generated with Claude Code