[bugfix] fix contextual_seq_len not passed from preprocessor to STULayer#450
Conversation
Code Review SummaryClean, well-targeted bugfix. The root cause is clear and the fix is minimal — reordering Inline Comments
Test Coverage GapThere are no tests for Pre-existing Doc Issue (not from this PR)In Overall: Approve with minor suggestions. The fix is correct and important — the Triton kernel was being compiled with 🤖 Generated with Claude Code |
Code Review SummaryClean, well-scoped bugfix. The root cause (proto default silently overriding the preprocessor-derived value) is correctly addressed, and the defensive shallow copy + deprecation warning are good additions. The docstring type fix ( Noteworthy Feedback1. Consider a regression test for No existing test verifies that # Construct HSTUTransducer with contextual features, then assert:
for layer in transducer._stu_module._stu_layers:
assert layer._contextual_seq_len == transducer._input_preprocessor.contextual_seq_len()2. Use proto 3. Minor: No Issues Found
LGTM with the test suggestion above as a nice-to-have. 🤖 Generated with Claude Code |
contextual_seq_len in STULayer was configured from the STU proto field (defaulting to 0), instead of being derived from the InputPreprocessor. This caused incorrect attention masks when contextual features were used but contextual_seq_len was not explicitly set in the proto config. The fix reorders HSTUTransducer.__init__ to create the InputPreprocessor first, then uses the preprocessor's computed value as fallback when contextual_seq_len is not explicitly set in proto. If the user explicitly sets contextual_seq_len in proto config, that value takes priority. Also fixes docstring type (bool -> int) and bumps version to 1.1.3. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
c0b7bf7 to
1c5b1e2
Compare
contextual_seq_len has no default value in proto, so check key presence instead of checking for zero value. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…truncation total_uih_len Real root cause of `test_rank_ultra_hstu_cutlass_train_eval_export`'s forward-pass IMA on PR alibaba#501. (My previous two commits guessed at the cutlass kernel and at falling back to `cs=0`; both have been reverted in favor of this targeted fix.) `HSTUTransducer._replay_truncation_state` returned `plan.total_kept - total_targets` for the post-truncation `total_uih_len` it hands to `_postprocess`. But `plan.total_kept` is the "rest" portion only -- the contextual prefix is split off separately into `plan.total_prefix = B * contextual_seq_len` and re-concatenated on top of the kept rest by `apply_stu_truncation_plan`'s `cs > 0` branch. So the real post-truncation UIH length is `plan.total_kept + plan.total_prefix - total_targets`. Downstream, `_postprocess` builds `uih_offsets = cumsum(seq_lengths - num_targets)` from `plan.new_lengths` (which still includes the contextual prefix per sample), so `uih_offsets[-1] = total_kept + total_prefix`. Passing the under-counted `total_uih_len` to `split_2D_jagged` makes the Triton/CUTLASS kernel allocate `OutA` of shape `(total_kept - total_targets, D)`, but the kernel writes each sample's slice at `OutA[uih_offsets[i] + j]` for `j in [0, ctx + new_uih_b)` -- positions up to `total_kept + total_prefix - 1`. The tail `B * contextual_seq_len` writes are out-of-bounds, producing the forward-pass CUDA IMA detected on the next step's `.item()` sync inside `ContextualInterleavePreprocessor.forward` (`max_uih_len = fx_int_item(uih_seq_lengths.max())`). Surfaces on both Hopper (PR alibaba#501 H20 CI run 25543727919 / job 74975071509) and Ampere (run 25543727906 / job 74975071514). Why on-master with `cs=0` was silently fine: the dict-fill bug in `config_to_kwargs` that PR alibaba#450's guard tried to address (now properly fixed via the `[default = -1]` sentinel earlier in this PR) made `STULayer.contextual_seq_len = 0` regardless of the preprocessor's real value, which forces `compute_stu_truncation_plan(cs=0)` and `apply_stu_truncation_plan`'s simple (non-contextual) branch. `plan.total_prefix` is `0` in that branch, so `total_kept + 0 - total_targets == total_kept - total_targets` was correct only by coincidence. Once the proto/sentinel fix wires `cs=6` through, the bug was load-bearing. Why the unit test for `_replay_truncation_state` was passing despite the bug: it asserted the buggy expected value (`total_kept - total_targets`), and the surrounding `test_forward_end_to_end` cases use the PyTorch backend, whose `pytorch_split_2D_jagged` ignores `total_len_left` and sizes outputs from offsets directly -- so the OOB write never happens on that path. The Triton/CUTLASS kernel honors `total_len_left`, which is why the integration test (CUTLASS kernel) is the one that crashes. Local repro (4× A10 + cu129 + fbgemm_gpu_hstu wheel installed): - Pre-fix: train_eval crashes on first forward step (l2_loss_watchtime=1.92e9 then IMA inside `ContextualInterleavePreprocessor.forward`'s `.item()`). - Post-fix: train_eval finishes cleanly (loss goes 1.92e9 -> 1.74e9 over 45 steps, eval AUCs sane); standalone eval finishes cleanly. Updates `test_replay_truncation_state_replays_plan_and_assigns_correct_fields` to assert the corrected value (`plan.total_kept + plan.total_prefix - total_targets`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
contextual_seq_leninSTULayerwas configured from the STU proto field (defaulting to 0), instead of being derived fromInputPreprocessor.contextual_seq_len(). This caused incorrect attention masks in both Triton and PyTorch kernels when contextual features were used butcontextual_seq_lenwas not explicitly set in the proto config.HSTUTransducer.__init__to createInputPreprocessorfirst, then overridestu["contextual_seq_len"]with the computed value before constructingSTULayerinstances.contextual_seq_lenfield in STU proto as deprecated (kept for backward compatibility).Test plan
python -m tzrec.modules.gr.stu_testpassespre-commit run -apasses🤖 Generated with Claude Code