Skip to content

[bugfix] thread contextual_seq_len from preprocessor to STULayer (proto sentinel + truncation total_uih_len + AOTI-friendly SLA builder)#501

Merged
tiankongdeguiji merged 8 commits into
alibaba:masterfrom
tiankongdeguiji:fix/contextual-seqlen-proto-default-neg1
May 9, 2026
Merged

[bugfix] thread contextual_seq_len from preprocessor to STULayer (proto sentinel + truncation total_uih_len + AOTI-friendly SLA builder)#501
tiankongdeguiji merged 8 commits into
alibaba:masterfrom
tiankongdeguiji:fix/contextual-seqlen-proto-default-neg1

Conversation

@tiankongdeguiji
Copy link
Copy Markdown
Collaborator

@tiankongdeguiji tiankongdeguiji commented May 8, 2026

Summary

PR #450 added if "contextual_seq_len" not in stu: to HSTUTransducer.__init__ intending to forward _input_preprocessor.contextual_seq_len() into every STULayer when the user hadn't pinned stu.contextual_seq_len. That guard has been silently unreachable since day 1 because tzrec/utils/config_util.py runs MessageToDict(..., including_default_value_fields=True, ...) and the proto field was optional uint32 with no explicit default — so the dict always contains contextual_seq_len: 0 and the membership test never matches. Every STULayer was constructed with contextual_seq_len = 0 regardless of the preprocessor's value.

For models without SLA + attn_truncation + multi-channel MoT (e.g., dlrm_hstu_cutlass) the wrong value was numerically harmless. For ultra_hstu_cutlass_kuairand_1k (sla_k1=256, sla_k2=32, attn_truncation_split_layer=2, attn_truncation_tail_len=512, attn_num_layers=4, 2 MoT channels, 6-feature contextual group), it manifests as l2_loss_watchtime: 1.92e9 on step 0 followed by a forward-pass CUDA IMA on step 1 (see PR #500 CI run 25497698553 / job 74821854417, and the matching Hopper repro on PR #501 H20 CI run 25543727919 / job 74975071509).

This PR closes the loop end-to-end:

  1. Proto sentinel (tzrec/protos/module.proto): optional uint32 contextual_seq_len = 13;optional int32 contextual_seq_len = 13 [default = -1];. Mirrors the scaling_seqlen pattern from PR [feat] stu.scaling_seqlen + drop autotune assert strip #500.

  2. Sentinel-style guard (tzrec/modules/gr/hstu_transducer.py): replace the unreachable if "contextual_seq_len" not in stu: with if stu.get("contextual_seq_len", -1) < 0: stu["contextual_seq_len"] = self._input_preprocessor.contextual_seq_len().

  3. Post-truncation total_uih_len fix (tzrec/modules/gr/hstu_transducer.py:_replay_truncation_state): change plan.total_kept - total_targets to plan.total_kept + plan.total_prefix - total_targets. total_kept is the post-truncation "rest" portion only (excluding contextual prefix); _postprocess builds uih_offsets from seq_lengths - num_targets which still includes the prefix per sample, so the under-count made the Triton/CUTLASS split_2D_jagged allocate OutA smaller than offsets_left[-1] and write OOB on the very first forward step. (The unit test was asserting the buggy expected value too — also corrected.)

  4. AOTI-friendly SLA NFUNC builder (tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor): with cs > 0 actually flowing through, AOTI export hit two new Inductor-compile failures the master cs = 0 path never tripped:

    • unsqueeze(0).expand(nheads, 3, total_q).contiguous() triggers Inductor's combine_contiguous_dims for a 3-D iteration with symbolic last dim → ModularIndexing(idx, total_q, 3) → sympy ZeroDivisionError.
    • seq_offsets[batch_ids] / seq_lengths[batch_ids] / num_targets[batch_ids] (where batch_ids = repeat_interleave(arange(B), seq_lengths)) materializes batch_ids as a kernel input and emits tl.device_assert(0 <= batch_ids < B). Under AOTI compile-time autotune the bench fills the input with rand_strided int32 garbage that never satisfies the bound — assert fires, export aborts with cudaErrorAssert.

    Fixes (no Python custom ops, no torch._inductor.config overrides):

    • Defer .contiguous() to the consumer: return func_2d.unsqueeze(0).expand(nheads, 3, total_q). cutlass_hstu_mha is @torch.fx.wrap'd (FX leaf) and already calls .contiguous() itself (tzrec/ops/_cuda/cutlass_hstu_attention.py:136); deferring keeps combine_contiguous_dims outside Inductor's compile boundary.
    • Replace indirect indexing with torch.repeat_interleave(values, lengths) directly per slot (seq_offsets_per_pos = repeat_interleave(seq_offsets[:B].contiguous(), seq_lengths), same for L / T). Mathematically equivalent to values[batch_ids], but Inductor's lowering for the (values, lengths) overload is a single cumsum-driven scatter/gather that doesn't surface a separate batch_ids kernel input — no assert_indirect_indexing codegen, no autotune-bench bounds-check failure.
    • compute_stu_truncation_plan: new_x_offsets = x_offsets - offsets_head (mathematically equivalent to offsets_prefix + offsets_tail, but no arange(B+1)*cs chain feeding into the SLA builder's seq_offsets — keeps cs (and via it B) out of the consumer's SymInt graph).

Wire-compat

  • A uint32=0 from any old serialized config decodes as int32=0. The < 0 sentinel (not <= 0) preserves the legacy semantic that an explicitly-set 0 means "no contextual" rather than "use preprocessor".
  • tzrec/protos/module_pb2.py is gitignored and regenerated by scripts/gen_proto.sh at install time — no checked-in .py change. Regenerated descriptor sanity-checked locally: d['contextual_seq_len'] == -1 for an empty STU under including_default_value_fields=True.
  • AOTI exported .pt2 references no tzrec:: Python custom ops — runs in pure C++ runtime.

Test plan

  • Local python -m unittest tzrec.modules.gr.hstu_transducer_test — 7/7 pass (covers ctx in {0, 3} × (split, tail) in {(0,0), (1,4)} × interleave matrix).
  • Local python -m unittest tzrec.ops.hstu_attention_utils_test — 27/27 pass (truncation-plan parity).
  • Local python -m unittest tzrec.ops.hstu_attention_test.HSTUAttentionTest.test_sla_attn_cutlass — 60 hypothesis examples, OK (CUTLASS NFUNC parity vs PyTorch reference unchanged).
  • Local python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_exportRan 1 test in 162.842s, OK (full train_eval + eval + export + predict on 4×A10 + cu129 + fbgemm_gpu_hstu wheel).
  • Local proto round-trip: empty STU{'contextual_seq_len': -1, …} → sentinel < 0 triggers preprocessor fallback; explicit 0 → respected as 0; explicit positive → respected verbatim.
  • CI Hopper lane runs test_rank_ultra_hstu_cutlass_train_eval_export and it passes (was the originally failing test).
  • CI confirms test_rank_dlrm_hstu_cutlass_train_eval_export still green (different code path, regression spot-check).

🤖 Generated with Claude Code

tiankongdeguiji and others added 2 commits May 8, 2026 11:56
…[default=-1] sentinel

Replaces the unreachable `if "contextual_seq_len" not in stu` guard at
HSTUTransducer construction with the same `< 0` sentinel pattern that
PR alibaba#500 introduced for `stu.scaling_seqlen`.

Why the guard was unreachable: `tzrec/utils/config_util.py:66` builds the
`stu` dict via `MessageToDict(..., including_default_value_fields=True,
preserving_proto_field_name=True)`. With the previous proto declaration
`optional uint32 contextual_seq_len = 13;` (no explicit default), proto2
fills the dict with `contextual_seq_len: 0` for any STU message that
doesn't pin a value, so the membership test never matches and the
`self._input_preprocessor.contextual_seq_len()` fallback never runs.
Every `STULayer` ended up with `contextual_seq_len=0` regardless of the
preprocessor's actual contextual-feature count. PR alibaba#450 introduced the
guard intending exactly this fallback, but the dict-fill behavior of
`including_default_value_fields=True` defeated it silently.

Symptom: the failing CI run on the PR alibaba#500 branch (run 25497698553 /
job 74821854417) crashed `test_rank_ultra_hstu_cutlass_train_eval_export`
with a step-0 `l2_loss_watchtime: 1921847424.00` followed by a CUDA IMA
on step 1. `compute_stu_truncation_plan` in
`tzrec/ops/hstu_attention_utils.py` derives
`uih_lengths = seq_lengths - contextual_seq_len - num_targets`; with
`contextual_seq_len=0` but a 6-feature contextual prefix actually
present, `uih_lengths` overshoots, `drop_count` and `new_lengths` desync
from the apply-truncation kernel, and the truncated jagged buffer
indexes outside the valid range. The dlrm_hstu_cutlass test does not
exercise SLA + truncation + MoT so the bad value is numerically
harmless there, which is why the regression slipped past existing CI.

Fix:
- `tzrec/protos/module.proto`: `optional uint32 contextual_seq_len = 13;`
  -> `optional int32 contextual_seq_len = 13 [default = -1];` (mirrors
  PR alibaba#500's `scaling_seqlen` field). Wire-compat: a uint32=0 from any
  legacy serialized config decodes as int32=0 and an explicit 0 still
  reads as "no contextual" (sentinel is `< 0`, not `<= 0`).
- `tzrec/modules/gr/hstu_transducer.py`: replace `if "contextual_seq_len"
  not in stu:` with `if stu.get("contextual_seq_len", -1) < 0:`. Comment
  explains the dict-fill interaction so the next reader doesn't
  reintroduce the bug. No other call site reads `stu["contextual_seq_len"]`.

`tzrec/protos/module_pb2.py` is gitignored and regenerated at install
time via `scripts/gen_proto.sh`, so no .py change is committed; the
regenerated descriptor was sanity-checked locally
(`d['contextual_seq_len'] == -1` for an empty STU under
`including_default_value_fields=True`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tiankongdeguiji added a commit to tiankongdeguiji/TorchEasyRec that referenced this pull request May 8, 2026
…0 truncation path

The previous version of this commit threaded
`self._input_preprocessor.contextual_seq_len()` into `STULayer` whenever
the user hadn't pinned `stu.contextual_seq_len`. That flips
`compute_stu_truncation_plan` / `apply_stu_truncation_plan` to the
`contextual_seq_len > 0` branch (`split_2D_jagged` prefix-split + tail-
split + `concat_2D_jagged`), which produces a forward-pass CUDA IMA on
the very first training step in the
`test_rank_ultra_hstu_cutlass_train_eval_export` integration test --
reproduced on both Hopper (PR alibaba#501 H20 CI run 25543727919, job
74975071509) and Ampere (run 25543727906, job 74975071514). The IMA
surfaces inside `dlrm_hstu.loss -> _fx_avg_batch_size`, with the actual
faulting kernel happening earlier in the forward pass and only being
detected by the next CUDA dispatch.

The legacy semantics that PR alibaba#492's integration test was implicitly
relying on are `STULayer.contextual_seq_len = 0`: the dict-fill bug in
`config_to_kwargs` (proto2 `optional uint32` always present in
`MessageToDict(including_default_value_fields=True, ...)` as `0`) made
`if "contextual_seq_len" not in stu:` unreachable, so the preprocessor
fallback never fired. The CUTLASS-HSTU + SLA + attn_truncation
combination has therefore never been exercised with `cs > 0` in
production, and the path has a latent bug.

Fix the runtime regression I introduced while keeping the proto change
so explicit user overrides work and we stay forward-compatible:
- `stu.get("contextual_seq_len", -1) < 0` now falls back to `0` instead
  of the preprocessor value, restoring the on-master behavior that
  PR alibaba#492 / PR alibaba#500 CI was passing on.
- Users who genuinely need the contextual-aware truncation can pin
  `stu.contextual_seq_len: <N>` explicitly; that path triggers the
  buggy kernel branch and remains a TODO.
- Comment block in the source documents both the kernel bug (so the
  next person doesn't re-thread the value blindly) and the recipe for
  users to opt in.

The proto change (`optional int32 contextual_seq_len = 13 [default =
-1]`) and the sentinel-style guard pattern from the previous commit
stay -- they're still the right shape for a future fix that re-enables
the auto-thread once the kernel bug is addressed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tiankongdeguiji added a commit to tiankongdeguiji/TorchEasyRec that referenced this pull request May 8, 2026
…ide attn_func

Root cause of the cs>0 + SLA + attn_truncation forward-pass IMA on PR alibaba#501:

`cutlass_hstu_mha` was unconditionally building `num_contexts_tensor` from
`contextual_seq_len > 0` and converting `num_targets` to int32, then
passing both to `hstu_attn_varlen_func` together with the SLA NFUNC
`func` tensor. But the func tensor already encodes contextual prefix +
target isolation (see `build_sla_func_tensor` -- it computes
`effective_k2 = max(sla_k2, contextual_seq_len)` and a separate target
boundary so history vs. target rows get distinct intervals; and see the
`pytorch_hstu_mha` reference, which deliberately decodes the mask only
from `attn_func` when it is set, ignoring `contextual_seq_len` /
`num_targets`).

`hstu_attn_varlen_func` uses `num_contexts` / `num_targets` for
layout-level decisions (separate context-region bookkeeping inside the
kernel) that don't compose with the func-driven mask: the kernel ends
up indexing a context buffer that doesn't match the func layout, reads
OOB, and corrupts memory. The fault surfaces only on the very first
forward step when the contextual prefix overlaps the global region
(`max(sla_k2=32, cs=6) = 32`) and a real attn_truncation is active --
which is exactly the `test_rank_ultra_hstu_cutlass_train_eval_export`
configuration (sla_k1=256, sla_k2=32,
attn_truncation_split_layer=2, attn_truncation_tail_len=512, two MoT
channels, 6-feature contextual group). Reproduced on both Hopper (PR
alibaba#501 H20 CI run 25543727919 / job 74975071509, traceback inside
`dlrm_hstu.loss._fx_avg_batch_size`) and Ampere (run 25543727906 / job
74975071514). dlrm_hstu_cutlass doesn't trip it because that config has
no SLA -> no `attn_func` -> the standard fixed-mask path that
legitimately needs `num_contexts` / `num_targets`.

Fix: gate `num_contexts_tensor` and `num_targets_int32` on `attn_func is
None`. Same defense-in-depth pattern as the existing `causal` /
`max_attn_len` gates a few lines up. Documents the kernel-layout
interaction in a comment so the next person doesn't reintroduce it.

This unblocks the proto/sentinel work in the previous commit -- restore
`stu["contextual_seq_len"] = self._input_preprocessor.contextual_seq_len()`
in `HSTUTransducer.__init__` so the dict-fill bug fix from PR alibaba#501's
first commit (which the previous commit had to revert as a
band-aid) is back in effect.

Verification:
- `python -m unittest tzrec.modules.gr.hstu_transducer_test` 7/7 pass
  locally (PyTorch backend; the unit-test stub already exercises the
  contextual prefix + truncation matrix `[ctx in {0, 3}]` x `[(split,
  tail) in {(0,0), (1,4)}]` x interleave).
- CUTLASS-path verification has to happen on H20 CI (local A10 lacks
  the `hstu_ops_gpu` wheel).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…truncation total_uih_len

Real root cause of `test_rank_ultra_hstu_cutlass_train_eval_export`'s
forward-pass IMA on PR alibaba#501. (My previous two commits guessed at the
cutlass kernel and at falling back to `cs=0`; both have been reverted
in favor of this targeted fix.)

`HSTUTransducer._replay_truncation_state` returned
`plan.total_kept - total_targets` for the post-truncation
`total_uih_len` it hands to `_postprocess`. But `plan.total_kept` is the
"rest" portion only -- the contextual prefix is split off separately
into `plan.total_prefix = B * contextual_seq_len` and re-concatenated
on top of the kept rest by `apply_stu_truncation_plan`'s `cs > 0`
branch. So the real post-truncation UIH length is `plan.total_kept +
plan.total_prefix - total_targets`.

Downstream, `_postprocess` builds `uih_offsets = cumsum(seq_lengths -
num_targets)` from `plan.new_lengths` (which still includes the
contextual prefix per sample), so `uih_offsets[-1] = total_kept +
total_prefix`. Passing the under-counted `total_uih_len` to
`split_2D_jagged` makes the Triton/CUTLASS kernel allocate `OutA` of
shape `(total_kept - total_targets, D)`, but the kernel writes each
sample's slice at `OutA[uih_offsets[i] + j]` for `j in [0, ctx + new_uih_b)`
-- positions up to `total_kept + total_prefix - 1`. The tail `B *
contextual_seq_len` writes are out-of-bounds, producing the forward-pass
CUDA IMA detected on the next step's `.item()` sync inside
`ContextualInterleavePreprocessor.forward` (`max_uih_len =
fx_int_item(uih_seq_lengths.max())`). Surfaces on both Hopper (PR alibaba#501
H20 CI run 25543727919 / job 74975071509) and Ampere (run 25543727906 /
job 74975071514).

Why on-master with `cs=0` was silently fine: the dict-fill bug in
`config_to_kwargs` that PR alibaba#450's guard tried to address (now properly
fixed via the `[default = -1]` sentinel earlier in this PR) made
`STULayer.contextual_seq_len = 0` regardless of the preprocessor's
real value, which forces `compute_stu_truncation_plan(cs=0)` and
`apply_stu_truncation_plan`'s simple (non-contextual) branch.
`plan.total_prefix` is `0` in that branch, so
`total_kept + 0 - total_targets == total_kept - total_targets` was
correct only by coincidence. Once the proto/sentinel fix wires
`cs=6` through, the bug was load-bearing.

Why the unit test for `_replay_truncation_state` was passing despite the
bug: it asserted the buggy expected value (`total_kept - total_targets`),
and the surrounding `test_forward_end_to_end` cases use the PyTorch
backend, whose `pytorch_split_2D_jagged` ignores `total_len_left` and
sizes outputs from offsets directly -- so the OOB write never happens
on that path. The Triton/CUTLASS kernel honors `total_len_left`, which
is why the integration test (CUTLASS kernel) is the one that crashes.

Local repro (4× A10 + cu129 + fbgemm_gpu_hstu wheel installed):
- Pre-fix: train_eval crashes on first forward step
  (l2_loss_watchtime=1.92e9 then IMA inside
  `ContextualInterleavePreprocessor.forward`'s `.item()`).
- Post-fix: train_eval finishes cleanly (loss goes 1.92e9 -> 1.74e9
  over 45 steps, eval AUCs sane); standalone eval finishes cleanly.

Updates `test_replay_truncation_state_replays_plan_and_assigns_correct_fields`
to assert the corrected value
(`plan.total_kept + plan.total_prefix - total_targets`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the fix/contextual-seqlen-proto-default-neg1 branch from 519d891 to 194c8e6 Compare May 8, 2026 10:31
tiankongdeguiji added a commit to tiankongdeguiji/TorchEasyRec that referenced this pull request May 8, 2026
Without this, after the previous commit re-enables auto-threading
``cs > 0`` into ``STULayer``, the integration test
``test_rank_ultra_hstu_cutlass_train_eval_export`` passes train_eval
but fails AOTI export with a ``cudaErrorAssert`` from a Triton kernel
generated for ``build_sla_func_tensor``.

Why the export breaks while train_eval is clean:
``build_sla_func_tensor`` does ``seq_offsets[batch_ids]`` /
``seq_lengths[batch_ids]`` / ``num_targets[batch_ids]`` where
``batch_ids = repeat_interleave(arange(B), seq_lengths)`` is provably in
``[0, B)`` by construction. At eager runtime that's fine. But under
AOTI compile, Inductor fuses this body with surrounding shape
arithmetic that includes ``B`` (it inlines the post-truncation
``seq_offsets = arange(B+1) * cs + offsets_tail`` chain when
``cs > 0``), surfaces ``B`` as a kernel parameter (``ks0``), and
emits ``tl.device_assert(0 <= batch_ids < ks0)``. AOTI compile-time
autotune benches the kernel with ``rand_strided`` int32 garbage
rather than the real ``repeat_interleave`` output, so the assert
fires and aborts export with a device-side assert.

Master with the dict-fill bug never traced this branch (because
``STULayer.contextual_seq_len`` was forced to ``0`` and the
post-truncation ``seq_offsets`` reduces to plain
``cumsum(new_lengths)`` with no ``arange*cs`` in the chain), so the
kernel that surfaces ``B`` was never generated -- which is why the
``cs=0`` path slipped past export CI.

Fix: wrap the entire SLA NFUNC builder body in a single
``torch.library.define`` custom op (``tzrec::_sla_build_func_2d``).
Inductor sees one opaque op call instead of fusing the body with
surrounding code; no ``assert_indirect_indexing`` codegen, no
``rand_strided`` autotune crash. The eager Python body runs unchanged
at runtime where ``batch_ids`` is provably valid. Same low-level
``Library`` API as ``_sla_broadcast_func_to_heads`` (the existing
black-box op for the head-broadcast) -- avoids the AOTI multi-thread
deadlock that ``@torch.library.custom_op`` triggers.

Verification (4x A10 + cu129 + fbgemm_gpu_hstu wheel installed):
- ``python -m unittest tzrec.modules.gr.hstu_transducer_test`` 7/7 pass.
- ``python -m unittest tzrec.ops.hstu_attention_test.HSTUAttentionTest.test_sla_attn_cutlass``
  -> Ran 60 hypothesis examples, OK (CUTLASS NFUNC path goes through
  the custom op now).
- ``python -m unittest
  tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export``
  -> ``Ran 1 test in 159.337s, OK`` (full train_eval + eval + export +
  predict pipeline). Pre-PR-501 fa39911 passed it via the cs=0 dict-fill
  bug; PR alibaba#501 prior commits failed it; this commit closes the loop.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…unblock AOTI export

Without this, after the previous commit fixes
``_replay_truncation_state``'s post-truncation ``total_uih_len`` math
and the proto/sentinel cleanup re-enables auto-threading
``contextual_seq_len > 0`` into ``STULayer``, AOTI export of
``test_rank_ultra_hstu_cutlass_train_eval_export`` aborts with
``cudaErrorAssert`` from a Triton kernel generated for
``build_sla_func_tensor`` (see PR alibaba#501 H20 CI run 25543727919 / job
74975071509 and the matching Ampere lane).

Two distinct AOTI-compile failures were exposed:

1. ``unsqueeze(0).expand(nheads, 3, total_q).contiguous()`` trips
   Inductor's ``combine_contiguous_dims`` for a 3-D iteration whose
   last dim is the symbolic ``total_q``; the pass emits
   ``ModularIndexing(idx, total_q, 3)`` whose sympy simplifier raises
   ``ZeroDivisionError`` during AOT compile.

2. ``seq_offsets[batch_ids]`` / ``seq_lengths[batch_ids]`` /
   ``num_targets[batch_ids]`` where ``batch_ids = repeat_interleave(
   arange(B), seq_lengths)``. Inductor materializes ``batch_ids`` as a
   separate kernel input, fuses the post-truncation ``new_x_offsets =
   arange(B+1)*cs + offsets_tail`` chain in (which is what flips on
   when ``cs > 0``), surfaces ``B`` as a kernel parameter (``ks0``),
   and emits ``tl.device_assert(0 <= batch_ids < ks0)`` on the loaded
   raw value (the ``assert_indirect_indexing`` codegen). At
   AOTI compile-time autotune the bench fills the input with
   ``rand_strided`` int32 garbage that almost never satisfies the
   bound, so the assert fires and aborts export. ``master`` with the
   dict-fill bug never traced this branch (``STULayer.cs`` was forced
   to ``0`` and the post-truncation ``seq_offsets`` reduces to plain
   ``cumsum(new_lengths)`` with no ``arange*cs``), so the kernel that
   surfaces ``B`` was never generated -- which is why master CI
   slipped past export.

Earlier rev of this commit wrapped both bodies in
``torch.library.Library("tzrec", "FRAGMENT")`` custom ops
(``tzrec::_sla_broadcast_func_to_heads`` /
``tzrec::_sla_build_func_2d``). That worked but is incompatible with
the cpp AOTI runtime requirement. This rev keeps everything inline:

- **Builder body**: replace ``seq_offsets[batch_ids]`` /
  ``seq_lengths[batch_ids]`` / ``num_targets[batch_ids]`` with
  ``torch.repeat_interleave(values, lengths)`` directly per slot:
  ```
  seq_offsets_per_pos = repeat_interleave(seq_offsets[:B].contiguous(), seq_lengths)
  L = repeat_interleave(seq_lengths, seq_lengths)
  T = repeat_interleave(num_targets.to(int32), seq_lengths)
  pos_local = pos_global - seq_offsets_per_pos
  ```
  Mathematically equivalent: ``repeat_interleave(values, lengths)``
  produces the same per-position output as ``values[batch_ids]``.
  Inductor's lowering for the ``(values, lengths)`` overload is a
  single cumsum-driven scatter/gather kernel that does not surface a
  user-visible ``batch_ids`` tensor with an
  ``assert_indirect_indexing`` codegen, so the autotune-bench bounds
  check that randomly fails is never generated. The ``.narrow(0, 0, B)``
  + ``.contiguous()`` on ``seq_offsets[:B]`` materializes the slice
  into a dense buffer (the original SliceView form crashes Inductor's
  searchsorted lowering at AOT compile, per the comment near
  ``torch.diff(seq_offsets_i32)`` above).

- **Broadcast tail**: drop the explicit ``.contiguous()`` and return
  ``func_2d.unsqueeze(0).expand(nheads, 3, total_q)``. The consumer
  ``cutlass_hstu_mha`` is ``@torch.fx.wrap``'d
  (``tzrec/ops/_cuda/cutlass_hstu_attention.py:31``, confirmed leaf
  semantic in ``tzrec/ops/hstu_attention.py:121``); its body runs
  outside Inductor's compile boundary and already calls
  ``.contiguous()`` on ``attn_func`` itself
  (``cutlass_hstu_attention.py:136``). With ``.contiguous()`` deferred
  to that runtime call, ``combine_contiguous_dims`` never runs on the
  symbolic-last-dim broadcast and the sympy ``ZeroDivisionError``
  goes away.

- **``compute_stu_truncation_plan``**: change ``new_x_offsets =
  offsets_prefix + offsets_tail`` to the mathematically equivalent
  ``new_x_offsets = x_offsets - offsets_head``. The ``arange(B+1) *
  contextual_seq_len`` chain is what got Inductor to inline
  ``new_x_offsets[batch_ids] = cs * batch_ids + offsets_tail[batch_ids]``
  into the SLA builder kernel and surface ``B`` in the first place;
  the cumsum-based subtraction leaves ``cs`` outside the consumer's
  dependency chain. ``offsets_prefix`` (still required by
  ``apply_stu_truncation_plan``'s prefix-split branch) and
  ``total_prefix = B_static * contextual_seq_len`` stay.

Verification (4x A10 + cu129 + ``fbgemm_gpu_hstu`` wheel installed):
- ``python -m unittest tzrec.modules.gr.hstu_transducer_test`` -> 7/7 pass.
- ``python -m unittest tzrec.ops.hstu_attention_utils_test`` -> 27/27 pass.
- ``python -m unittest
  tzrec.ops.hstu_attention_test.HSTUAttentionTest.test_sla_attn_cutlass``
  -> 60 hypothesis examples, OK (CUTLASS NFUNC parity unchanged).
- ``python -m unittest
  tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export``
  -> ``Ran 1 test in 162.842s, OK`` (full train_eval + eval + export
  + predict, no ``tzrec::`` custom ops registered).

No ``torch._inductor.config`` overrides; no ``assert_indirect_indexing
= False``; no ``triton.autotune_pointwise = False`` (would regress
perf). Just the math rewrite + the ``.contiguous()`` deferral.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the fix/contextual-seqlen-proto-default-neg1 branch from 679d1dd to 462ed66 Compare May 9, 2026 03:31
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label May 9, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label May 9, 2026
@tiankongdeguiji tiankongdeguiji changed the title [bugfix] thread contextual_seq_len from preprocessor to STULayer via [default=-1] sentinel [bugfix] thread contextual_seq_len from preprocessor to STULayer (proto sentinel + truncation total_uih_len + AOTI-friendly SLA builder) May 9, 2026
Comment thread tzrec/ops/hstu_attention_utils.py Outdated
Comment on lines +113 to +118
seq_offsets_starts = seq_offsets_i32.narrow(0, 0, B).contiguous() # (B,)
seq_offsets_per_pos = torch.repeat_interleave(seq_offsets_starts, seq_lengths)
pos_local = pos_global - seq_offsets_per_pos
L = torch.repeat_interleave(seq_lengths, seq_lengths) # per-position seq length
if num_targets is not None:
T = num_targets.to(torch.int32)[batch_ids] # per-position target count
T = torch.repeat_interleave(num_targets.to(torch.int32), seq_lengths)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This AOTI-compile-time-autotune workaround also runs on the eager runtime forward path (STULayer.forwardbuild_sla_func_tensor), so every CUTLASS layer of every batch pays it.

Old form: 1 repeat_interleave(arange(B), seq_lengths) (≈1 cumsum + 1 scatter, length total_q) + 3 cheap B-element gathers.
New form: 3 repeat_interleave(values, seq_lengths) calls (each ≈1 cumsum + 1 scatter at length total_q) + the narrow(...).contiguous() copy.

Net: roughly +2 cumsum kernels and +2 full-length scatters per SLA layer per forward. With a stack of N CUTLASS layers, that's 4N extra kernel launches per forward.

Consider gating the new form behind an export-only path (e.g., torch.fx._symbolic_trace.is_fx_tracing() / a module-level _export_mode flag) and keeping the original batch_ids = repeat_interleave(arange(B), seq_lengths) + 3 gathers on eager — the autotune device-assert only fires under AOTI compile-time bench, not at eager runtime.

Side note on line 113: seq_offsets_i32.narrow(0, 0, B) on a 1-D contiguous source is already contiguous, so .contiguous() is a no-op. If it's there to defeat an Inductor SliceView issue analogous to the seq_offsets[1:] problem documented at line 92, please add a one-line comment so a future cleanup pass doesn't drop it.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 9, 2026

Code review summary

Reviewed via code-quality, performance, test-coverage, doc-accuracy, and security subagents. The fix is well-scoped and the inline comments justifying each AOTI workaround are excellent. Three inline comments left covering the main concerns.

Headline findings

  • Perf: the new 3× repeat_interleave(values, lengths) form in build_sla_func_tensor is an AOTI compile-time-autotune workaround but also runs on every layer of every eager forward (~4 extra kernels × N CUTLASS layers / forward). Consider gating behind an export-only path. (inline)
  • Test coverage: the headline bug fix (stu.get("contextual_seq_len", -1) < 0) isn't actually pinned by any test — every existing case has the key absent from the stu dict, so the old not in form would pass identically. A 3-row parameterization on the stu-dict key would pin the contract. (inline)
  • Doc: STUTruncationPlan.total_kept's field doc doesn't mention it excludes the contextual prefix, even though the new _replay_truncation_state + plan.total_prefix add-back depends on that. (inline)

Smaller items (skipped from inline)

  • hstu_transducer.py:108 comment says "default-zero fill" but the proto now fills -1, not 0 — minor wording cleanup.
  • The proto comments for contextual_seq_len (line 256-257) and scaling_seqlen (line 263) both use the < 0 (default) = sentinel pattern but with slightly different phrasing; standardizing helps grep-ability.
  • Mathematical equivalence of new_x_offsets = x_offsets - offsets_head vs. the old offsets_prefix + offsets_tail confirmed; the rerouting rationale (keeping cs out of the seq_offsets dependency chain for AOTI) is correctly captured in the inline comment at lines 258-270.

Security

No noteworthy security/safety findings. The proto-sentinel-then-validate layering correctly prevents any negative contextual_seq_len from reaching the kernels via the config path; the in-kernel < 0 validators in build_sla_func_tensor and compute_stu_truncation_plan retain defense-in-depth. Wire-compat claim (uint32→int32) holds for any value an old client could realistically encode.

Positive

  • Root cause analysis in the PR body is precise (the including_default_value_fields=True interaction is exactly what made the previous not in guard unreachable).
  • The + plan.total_prefix add-back fix in _replay_truncation_state is correctly justified by the layout mismatch between total_kept (rest-only) and _postprocess's seq_lengths - num_targets (includes prefix).
  • Removing the _sla_broadcast_func_to_heads custom op in favor of expand + caller-side .contiguous() is a wash perf-wise and reduces export-side complexity.

Comment thread tzrec/modules/gr/hstu_transducer.py Outdated
Comment thread tzrec/modules/gr/hstu_transducer.py
tiankongdeguiji and others added 4 commits May 9, 2026 12:18
…t, simplify comments

Test pin: add a 4-row parametrized
``test_contextual_seq_len_resolution`` that exercises the four states
the ``stu.get("contextual_seq_len", -1) < 0`` guard supports:

| stu dict          | preprocessor | resolved |
|-------------------|--------------|----------|
| key absent        | 5            | 5        |
| key=-1 (sentinel) | 5            | 5        |
| key=0  (explicit) | 5            | 0        |
| key=3  (explicit) | 5            | 3        |

The existing end-to-end matrix never carries ``contextual_seq_len`` in
its ``stu`` dict, so a regression to ``not in stu`` would silently pass
every prior case. Verified the new test catches it: replacing the
guard with ``not in stu`` makes case 1 (sentinel-fill) fail with
``AssertionError: -1 != 5``; restoring the guard makes it pass.

Doc + comment cleanup:
- ``STUTruncationPlan.total_kept`` field doc now spells out that it
  excludes the contextual prefix and that callers must add
  ``total_prefix`` back to get the full post-truncation total.
- Drop the verbose narrative comments in ``build_sla_func_tensor``,
  ``compute_stu_truncation_plan``, ``HSTUTransducer.__init__`` and
  ``_replay_truncation_state`` -- bug history belongs in commit
  messages, not in code. Each kept comment is now 2-5 lines explaining
  the WHY, not the full backstory.
- Fix the stale "default-zero fill" wording in
  ``HSTUTransducer.__init__`` -- the proto now fills the sentinel
  ``-1``, not ``0``.
- Standardize the ``module.proto`` sentinel comments for
  ``contextual_seq_len`` and ``scaling_seqlen`` to the same
  ``Sentinel: < 0 (default) = ...`` pattern (grep-friendly).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ward_end_to_end

Drop the standalone ``test_contextual_seq_len_resolution`` and fold
its contract-pin cases into the existing ``test_forward_end_to_end``
parametrization.  Each row carries explicit ``stu_cs`` /
``expected_cs`` columns; ``_build_transducer`` grows a
``stu_cs_override`` parameter that pins
``stu["contextual_seq_len"]`` in the dict passed to
``HSTUTransducer.__init__``.

Five legacy rows keep ``stu_cs=None`` (key absent, falls back to the
preprocessor's value); three new rows pin the sentinel guard:

| name                          | stu_cs | ppr_cs | expected |
|-------------------------------|--------|--------|----------|
| sentinel_fill_falls_back      |     -1 |      3 |        3 |
| explicit_zero_overrides       |      0 |      3 |        0 |
| explicit_positive_overrides   |      5 |      3 |        5 |

Each row also asserts ``layer._contextual_seq_len == expected_cs`` on
every layer in the stack -- so all 8 cases double as both the
forward-end-to-end correctness check AND the guard contract.

Verified the merged test still pins the bug class: replacing the
guard with ``not in stu`` makes exactly case 5
(``sentinel_fill_falls_back``) fail with ``AssertionError: -1 != 3``;
restoring the guard makes all 10 tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ng test rows

Replace the positional-tuple form of ``test_forward_end_to_end``'s
parametrization with ``param("name", split=..., tail=..., ...)``.
Each row is now self-describing -- no inline ``# split`` /
``# tail`` etc. comments needed and no risk of miscounting positional
slots when adding rows.

The first positional ``"name"`` keeps the test-name suffix (e.g.
``test_forward_end_to_end_5_sentinel_fill_falls_back``); everything
else is named kwargs.  ``parameterized`` doesn't accept keyword-only
function signatures, so the test method declares each kwarg as a
regular positional parameter (the kwargs from ``param`` bind by
name regardless).

No behavior change.  Verified: reverting the guard to ``not in stu``
makes case 5 ``sentinel_fill_falls_back`` fail with
``AssertionError: -1 != 3``; restoring it makes all 10 tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Carries the PR alibaba#501 fixes: contextual_seq_len proto sentinel +
auto-thread, post-truncation total_uih_len fix, and the AOTI-friendly
SLA NFUNC builder rewrite (custom-op-free).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji merged commit a446e12 into alibaba:master May 9, 2026
7 checks passed
@tiankongdeguiji tiankongdeguiji deleted the fix/contextual-seqlen-proto-default-neg1 branch May 11, 2026 02:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants