Skip to content

[bugfix] fix contextual_seq_len not passed from preprocessor to STULayer#450

Merged
tiankongdeguiji merged 2 commits into
alibaba:masterfrom
tiankongdeguiji:fix/contextual_seq_len_bug
Mar 30, 2026
Merged

[bugfix] fix contextual_seq_len not passed from preprocessor to STULayer#450
tiankongdeguiji merged 2 commits into
alibaba:masterfrom
tiankongdeguiji:fix/contextual_seq_len_bug

Conversation

@tiankongdeguiji
Copy link
Copy Markdown
Collaborator

Summary

  • contextual_seq_len in STULayer was configured from the STU proto field (defaulting to 0), instead of being derived from InputPreprocessor.contextual_seq_len(). This caused incorrect attention masks in both Triton and PyTorch kernels when contextual features were used but contextual_seq_len was not explicitly set in the proto config.
  • Fixed by reordering HSTUTransducer.__init__ to create InputPreprocessor first, then override stu["contextual_seq_len"] with the computed value before constructing STULayer instances.
  • Marked the contextual_seq_len field in STU proto as deprecated (kept for backward compatibility).
  • Bumped version to 1.1.3.

Test plan

  • python -m tzrec.modules.gr.stu_test passes
  • pre-commit run -a passes
  • Verify attention mask correctness with contextual features enabled

🤖 Generated with Claude Code

@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Mar 29, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Mar 29, 2026
Comment thread tzrec/modules/gr/hstu_transducer.py Outdated
Comment thread tzrec/protos/module.proto Outdated
@github-actions
Copy link
Copy Markdown

Code Review Summary

Clean, well-targeted bugfix. The root cause is clear and the fix is minimal — reordering __init__ so InputPreprocessor is created before STULayer instances, allowing contextual_seq_len to be derived correctly. The proto deprecation comment is accurate and the version bump is appropriate.

Inline Comments

  1. Defensive dict copy (hstu_transducer.py:82) — The stu dict is mutated in-place, which works today but is fragile. Suggest stu = dict(stu) before the assignment.
  2. Runtime deprecation warning (module.proto:258) — Consider logging a warning when a user has explicitly set the now-ignored contextual_seq_len proto field.

Test Coverage Gap

There are no tests for HSTUTransducer anywhere in the test suite. The existing stu_test.py tests STULayer directly by passing contextual_seq_len as a constructor argument — it never exercises the config-to-module wiring where this bug lived. A minimal integration test that constructs HSTUTransducer with contextual features and asserts _stu_module layers have the correct _contextual_seq_len would guard against regression.

Pre-existing Doc Issue (not from this PR)

In tzrec/modules/gr/stu.py:190, the contextual_seq_len docstring says (bool) but the actual type is int. Worth fixing while in the neighborhood.


Overall: Approve with minor suggestions. The fix is correct and important — the Triton kernel was being compiled with HAS_CONTEXTUAL_SEQ_LEN=False when contextual features were present, producing wrong attention masks.

🤖 Generated with Claude Code

@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Mar 29, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Mar 29, 2026
Comment thread tzrec/protos/module.proto
Comment thread tzrec/modules/gr/hstu_transducer.py Outdated
@github-actions
Copy link
Copy Markdown

Code Review Summary

Clean, well-scoped bugfix. The root cause (proto default silently overriding the preprocessor-derived value) is correctly addressed, and the defensive shallow copy + deprecation warning are good additions. The docstring type fix (boolint) and proto comment update are both accurate.

Noteworthy Feedback

1. Consider a regression test for contextual_seq_len propagation (test coverage)

No existing test verifies that HSTUTransducer correctly passes contextual_seq_len from the preprocessor to the STULayer instances. The existing stu_test.py constructs STULayer directly (bypassing the transducer), and dlrm_hstu_test.py only checks output shapes. A simple CPU-only unit test would guard this fix:

# Construct HSTUTransducer with contextual features, then assert:
for layer in transducer._stu_module._stu_layers:
    assert layer._contextual_seq_len == transducer._input_preprocessor.contextual_seq_len()

2. Use proto [deprecated = true] annotation (see inline comment on module.proto)

3. Minor: warnings.warn vs logger.warning (see inline comment on hstu_transducer.py)

No Issues Found

  • Performance: no concerns — changes are init-only, no hot-path impact.
  • Security: no concerns — internal config plumbing, no external input.

LGTM with the test suggestion above as a nice-to-have.

🤖 Generated with Claude Code

contextual_seq_len in STULayer was configured from the STU proto field
(defaulting to 0), instead of being derived from the InputPreprocessor.
This caused incorrect attention masks when contextual features were used
but contextual_seq_len was not explicitly set in the proto config.

The fix reorders HSTUTransducer.__init__ to create the InputPreprocessor
first, then uses the preprocessor's computed value as fallback when
contextual_seq_len is not explicitly set in proto. If the user explicitly
sets contextual_seq_len in proto config, that value takes priority.

Also fixes docstring type (bool -> int) and bumps version to 1.1.3.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the fix/contextual_seq_len_bug branch from c0b7bf7 to 1c5b1e2 Compare March 29, 2026 13:55
contextual_seq_len has no default value in proto, so check key presence
instead of checking for zero value.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji merged commit eed874e into alibaba:master Mar 30, 2026
6 of 7 checks passed
tiankongdeguiji added a commit to tiankongdeguiji/TorchEasyRec that referenced this pull request May 8, 2026
…truncation total_uih_len

Real root cause of `test_rank_ultra_hstu_cutlass_train_eval_export`'s
forward-pass IMA on PR alibaba#501. (My previous two commits guessed at the
cutlass kernel and at falling back to `cs=0`; both have been reverted
in favor of this targeted fix.)

`HSTUTransducer._replay_truncation_state` returned
`plan.total_kept - total_targets` for the post-truncation
`total_uih_len` it hands to `_postprocess`. But `plan.total_kept` is the
"rest" portion only -- the contextual prefix is split off separately
into `plan.total_prefix = B * contextual_seq_len` and re-concatenated
on top of the kept rest by `apply_stu_truncation_plan`'s `cs > 0`
branch. So the real post-truncation UIH length is `plan.total_kept +
plan.total_prefix - total_targets`.

Downstream, `_postprocess` builds `uih_offsets = cumsum(seq_lengths -
num_targets)` from `plan.new_lengths` (which still includes the
contextual prefix per sample), so `uih_offsets[-1] = total_kept +
total_prefix`. Passing the under-counted `total_uih_len` to
`split_2D_jagged` makes the Triton/CUTLASS kernel allocate `OutA` of
shape `(total_kept - total_targets, D)`, but the kernel writes each
sample's slice at `OutA[uih_offsets[i] + j]` for `j in [0, ctx + new_uih_b)`
-- positions up to `total_kept + total_prefix - 1`. The tail `B *
contextual_seq_len` writes are out-of-bounds, producing the forward-pass
CUDA IMA detected on the next step's `.item()` sync inside
`ContextualInterleavePreprocessor.forward` (`max_uih_len =
fx_int_item(uih_seq_lengths.max())`). Surfaces on both Hopper (PR alibaba#501
H20 CI run 25543727919 / job 74975071509) and Ampere (run 25543727906 /
job 74975071514).

Why on-master with `cs=0` was silently fine: the dict-fill bug in
`config_to_kwargs` that PR alibaba#450's guard tried to address (now properly
fixed via the `[default = -1]` sentinel earlier in this PR) made
`STULayer.contextual_seq_len = 0` regardless of the preprocessor's
real value, which forces `compute_stu_truncation_plan(cs=0)` and
`apply_stu_truncation_plan`'s simple (non-contextual) branch.
`plan.total_prefix` is `0` in that branch, so
`total_kept + 0 - total_targets == total_kept - total_targets` was
correct only by coincidence. Once the proto/sentinel fix wires
`cs=6` through, the bug was load-bearing.

Why the unit test for `_replay_truncation_state` was passing despite the
bug: it asserted the buggy expected value (`total_kept - total_targets`),
and the surrounding `test_forward_end_to_end` cases use the PyTorch
backend, whose `pytorch_split_2D_jagged` ignores `total_len_left` and
sizes outputs from offsets directly -- so the OOB write never happens
on that path. The Triton/CUTLASS kernel honors `total_len_left`, which
is why the integration test (CUTLASS kernel) is the one that crashes.

Local repro (4× A10 + cu129 + fbgemm_gpu_hstu wheel installed):
- Pre-fix: train_eval crashes on first forward step
  (l2_loss_watchtime=1.92e9 then IMA inside
  `ContextualInterleavePreprocessor.forward`'s `.item()`).
- Post-fix: train_eval finishes cleanly (loss goes 1.92e9 -> 1.74e9
  over 45 steps, eval AUCs sane); standalone eval finishes cleanly.

Updates `test_replay_truncation_state_replays_plan_and_assigns_correct_fields`
to assert the corrected value
(`plan.total_kept + plan.total_prefix - total_targets`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants