[bugfix] share Dim across grouped-sequence tensors in legacy AOT export #485

Merged: tiankongdeguiji merged 3 commits into alibaba:master from tiankongdeguiji:bugfix/aot-share-seq-dims on Apr 26, 2026

tiankongdeguiji (Collaborator) commented Apr 25, 2026

Summary

  • Legacy export_model_aot (the two-stage ENABLE_AOT=1 path) gave each *__sequence tensor its own torch.export.Dim. When several JAGGED_SEQUENCE feature_groups draw their first feature from the same parent SequenceFeature (HSTU's uih, uih_action, uih_timestamp, uih_watchtime all come from uih_seq), the data parser gives them identical per-sample lengths, so their nnz is always equal at runtime. Because the export declared those lengths as independent Dims, cross-sequence ops in the dense graph (e.g. bitwise_or between seq_actions and seq_watchtimes in SimpleActionEncoder) triggered a ConstraintViolationError during torch.export.export.
  • split_model now derives a per-{group_name}__sequence share_key from the first feature's _is_grouped_seq / sequence_name, and export_model_aot reuses one Dim per share_key (separate caches for SEQUENCE's axis-1 max-seq-len Dim and JAGGED_SEQUENCE's axis-0 nnz Dim). Standalone (non-grouped) sequence features keep a per-group share_key, preserving the previous behavior.
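
The Dim reuse described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: `build_shared_dims`, `FakeDim`, the `dim_factory` parameter, and the `__len` suffix are hypothetical names chosen for the example. Only the `seq_<sequence_name>` / `fg_<group_name>` share keys and the `__batch` suffix are taken from the PR text and log output. `FakeDim` stands in for `torch.export.Dim` so the sketch runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class FakeDim:
    """Stand-in for torch.export.Dim (assumption: only the name matters here)."""
    name: str


def build_shared_dims(
    share: Dict[str, str],
    seq_names: List[str],
    jagged_names: List[str],
    dim_factory: Callable[[str], Any],
) -> Dict[str, Dict[int, Any]]:
    """Map each tensor name to {axis: Dim}, creating one Dim per share_key."""
    seq_cache: Dict[str, Any] = {}     # SEQUENCE: axis-1 max-seq-len Dims
    jagged_cache: Dict[str, Any] = {}  # JAGGED_SEQUENCE: axis-0 nnz Dims
    shapes: Dict[str, Dict[int, Any]] = {}
    for name in seq_names:
        key = share[name]
        if key not in seq_cache:
            # "__len" is a hypothetical suffix for this sketch.
            seq_cache[key] = dim_factory(f"{key}__len")
        shapes[name] = {1: seq_cache[key]}
    for name in jagged_names:
        key = share[name]
        if key not in jagged_cache:
            jagged_cache[key] = dim_factory(f"{key}__batch")
        shapes[name] = {0: jagged_cache[key]}
    return shapes


# Share keys as split_model would report them for two grouped parents.
demo_share = {
    "uih__sequence": "seq_uih_seq",
    "uih_action__sequence": "seq_uih_seq",
    "candidate__sequence": "seq_cand_seq",
}
demo_shapes = build_shared_dims(demo_share, [], list(demo_share), FakeDim)
for tensor_name, spec in demo_shapes.items():
    print(tensor_name, spec)
```

With PyTorch available, a wrapper around `torch.export.Dim` could be passed as `dim_factory` instead of `FakeDim`. Keeping two caches means a padded SEQUENCE group and a JAGGED_SEQUENCE group never end up sharing a Dim by accident, even if their share keys coincide.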

Repro of the original failure

Running the legacy AOT export against an HSTU pipeline with multiple JAGGED_SEQUENCE feature_groups derived from the same parent SequenceFeature (e.g. uih_seq -> uih, uih_action, uih_timestamp, uih_watchtime) and a SimpleActionEncoder configured with watchtime_to_action_thresholds / watchtime_to_action_weights previously aborted with:

torch._dynamo.exc.UserError: Constraints violated (uih_watchtime__sequence__batch)!
- The values of uih_watchtime__sequence__batch and uih_action__sequence__batch must always be equal.

After this PR, the [INFO] dynamic shapes=... line shows the four UIH-derived sequence values referencing one shared Dim and the two CAND-derived ones sharing another:

uih__sequence:           {0: Dim('seq_uih_seq__batch', ...)}
uih_action__sequence:    {0: Dim('seq_uih_seq__batch', ...)}
uih_timestamp__sequence: {0: Dim('seq_uih_seq__batch', ...)}
uih_watchtime__sequence: {0: Dim('seq_uih_seq__batch', ...)}
candidate__sequence:           {0: Dim('seq_cand_seq__batch', ...)}
candidate_timestamp__sequence: {0: Dim('seq_cand_seq__batch', ...)}

and the export runs through to a complete AOTI archive.

Why the existing DlrmHSTUTest didn't catch this

DlrmHSTUTest::test_dlrm_hstu already parameterizes (graph_type=AOT_INDUCTOR, has_watchtime=True) and exercises split_model + export_model_aot. It was silent on this bug for two reasons that this PR also fixes:

  1. The encoder was built without watchtime_to_action_thresholds / watchtime_to_action_weights, so SimpleActionEncoder.need_watchtime returned False and the bitwise_or(seq_actions, seq_watchtimes >= ...) branch was dead at trace time. Without that op, no constraint between uih_action__sequence and uih_watchtime__sequence was ever recorded.
  2. All sequence features were declared standalone (top-level sequence_id_feature / sequence_raw_feature), so _is_grouped_seq was always False and there was no sequence_name signal to merge separate JAGGED_SEQUENCE feature_groups under one share_key.

Both gaps are closed: features are wrapped inside SequenceFeature(sequence_name="uih_seq" / "cand_seq", features=[...]) parents and the encoder now sets watchtime_to_action_thresholds/watchtime_to_action_weights when has_watchtime=True. The hypothesis sample (AOT_INDUCTOR, has_watchtime=True) now reproduces the failure on master and validates the fix here.

Test plan

  • pytest -x tzrec/models/dlrm_hstu_test.py::DlrmHSTUTest::test_dlrm_hstu — passes after fix (1 passed, 141s).
  • Re-run the failing AOT export end-to-end — [INFO] dynamic shapes now logs seq_uih_seq__batch and seq_cand_seq__batch shared across the matching keys; export proceeds past the previous failure point and produces a complete aoti_model.pt2.
  • Smoke-test the produced AOTI archive end-to-end (predict path).

🤖 Generated with Claude Code

tiankongdeguiji force-pushed the bugfix/aot-share-seq-dims branch from 3f058ed to 076e529 on April 25, 2026 07:34
Legacy `export_model_aot` previously assigned an independent
`torch.export.Dim` to every entry in `meta_info["jagged_seq_tensor_names"]`
(and likewise to each padded-`SEQUENCE` group's per-axis Dim). When
multiple JAGGED_SEQUENCE feature_groups draw their first feature from
the same parent `SequenceFeature` (e.g. HSTU's `uih`, `uih_action`,
`uih_timestamp`, `uih_watchtime` all derived from `uih_seq`), their
per-sample lengths come from one shared data-parser source — so
runtime nnz across those tensors is necessarily equal. With independent
Dims, downstream ops (e.g. `bitwise_or` between `seq_actions` and
`seq_watchtimes` in `SimpleActionEncoder.forward`) trigger a
`ConstraintViolationError` during `torch.export`.

`split_model` now computes a per-`{group_name}__sequence` `share_key`
based on the first feature's `_is_grouped_seq` / `sequence_name`, and
`export_model_aot` uses one Dim per share_key (separately per axis for
SEQUENCE and JAGGED_SEQUENCE). Standalone (non-grouped) sequence
features keep their own share_key and retain the previous
independent-Dim behavior.

Regression coverage: `DlrmHSTUTest::test_dlrm_hstu` now wraps `uih_*`
and `cand_*` sub-features inside `SequenceFeature` parents and threads
`watchtime_to_action_thresholds` / `watchtime_to_action_weights` into
`SimpleActionEncoder` when `has_watchtime=True`, so the cross-sequence
bitwise-OR branch actually runs during tracing — without these two
changes neither the constraint nor the share-Dim signal would be
exercised in tests.

Verified end-to-end against an HSTU pipeline that exhibits this
failure: the `dynamic shapes` log now reports shared dims
`seq_uih_seq__batch` and `seq_cand_seq__batch` for the respective
groups, and `torch.export.export` proceeds past the previous failure
point through to a complete AOTI archive.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tiankongdeguiji force-pushed the bugfix/aot-share-seq-dims branch from 076e529 to e838368 on April 25, 2026 07:36
tiankongdeguiji added the claude-review (Let Claude Review) label on Apr 25, 2026
github-actions bot removed the claude-review (Let Claude Review) label on Apr 25, 2026
tiankongdeguiji and others added 2 commits April 25, 2026 15:41
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment on lines +1015 to +1020
first = feat_by_name.get(fg.feature_names[0])
if first is not None and getattr(first, "_is_grouped_seq", False):
    share_key = f"seq_{first.sequence_name}"
else:
    share_key = f"fg_{fg.group_name}"
share[f"{fg.group_name}__sequence"] = share_key

The "first feature decides the share_key" heuristic has a few silent correctness holes worth tightening:

  1. Mixed-parent groups silently coalesce. A JAGGED_SEQUENCE group whose feature_names[0] is from SequenceFeature(uih_seq) and whose second is from SequenceFeature(cand_seq) will map all its tensors to seq_uih_seq and forcibly share a Dim with uih's tensor — wrong.
  2. value_dim != 1 is not filtered. The unified path's _build_dynamic_shapes (export_util.py:268-274) explicitly excludes multi-valued grouped features from share groups because their nnz is independent. The legacy helper here doesn't, so a multi-valued grouped feature would be over-shared. Same model could pass unified AOT and fail legacy AOT.
  3. first.sequence_name is a bare attribute access. _is_grouped_seq is defensively read with getattr(..., False), but sequence_name would raise AttributeError if a BaseFeature subclass ever lacks it. Mirror the unified path: getattr(first, "sequence_name", None).

At minimum, iterate over all features in the group, assert that they share the same sequence_name and that value_dim == 1, warn (or fall back) on mixed groups, and guard sequence_name with getattr. The docstring should also state that only the first feature is inspected today.
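
A rough sketch of what the tightened check might look like is given below. It is an illustration of the three suggestions, not the helper in export_util.py: `compute_share_key` and its signature are hypothetical, and the feature objects are assumed to expose `_is_grouped_seq`, `sequence_name`, and `value_dim` as plain attributes. Falling back to the per-group key on any doubt is one possible policy; raising instead would be equally defensible.

```python
import logging
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


def compute_share_key(
    group_name: str,
    feature_names: List[str],
    feat_by_name: Dict[str, Any],
) -> str:
    """Return a shared key only if every feature has the same grouped parent."""
    fallback = f"fg_{group_name}"  # independent Dim for this group
    parents = set()
    for name in feature_names:
        feat = feat_by_name.get(name)
        if feat is None or not getattr(feat, "_is_grouped_seq", False):
            return fallback
        # Guarded reads: a feature class lacking the attribute must not raise.
        seq_name = getattr(feat, "sequence_name", None)
        if seq_name is None or getattr(feat, "value_dim", None) != 1:
            # Multi-valued features have their own nnz, so they are not shared.
            return fallback
        parents.add(seq_name)
    if len(parents) != 1:
        if parents:
            logger.warning(
                "group %s mixes parents %s; using an independent Dim",
                group_name,
                sorted(parents),
            )
        return fallback
    return f"seq_{parents.pop()}"
```

Under these assumptions, a group whose features come from both uih_seq and cand_seq keeps its own `fg_...` key rather than being merged with the first feature's parent.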

Comment on lines 254 to 263
  simple_action_encoder=module_pb2.GRSimpleActionEncoder(
-     action_embedding_dim=8, action_weights=[1, 2, 4]
+     action_embedding_dim=8,
+     action_weights=[1, 2, 4],
+     watchtime_to_action_thresholds=(
+         [60, 300] if has_watchtime else []
+     ),
+     watchtime_to_action_weights=(
+         [256, 512] if has_watchtime else []
+     ),
  )

The bugfix is gated on (graph_type=AOT_INDUCTOR, has_watchtime=True), but with max_examples=20 over 8 sampled_from axes (~1024 combos) hypothesis only hits this exact pair with ~93% probability per CI run — i.e. ~7% of runs won't exercise the regression at all, and seeds aren't fixed via derandomize=True.

Recommend pinning this combo deterministically with an @example(graph_type=TestGraphType.AOT_INDUCTOR, has_watchtime=True, kernel=Kernel.PYTORCH, ...) decorator on test_dlrm_hstu, or adding a small dedicated non-hypothesis test, so the regression can never be silently unsampled.
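
The ~93% figure can be sanity-checked with a short calculation. The sketch below assumes each generated example independently lands on the (AOT_INDUCTOR, has_watchtime=True) pair with probability 1/8, for instance four graph types times two has_watchtime values. That per-example probability is an assumption made for illustration, and Hypothesis does not promise uniform independent draws, so the result is only an estimate.

```python
def hit_probability(p_per_example: float, max_examples: int) -> float:
    """Chance that at least one of max_examples independent draws hits the target."""
    return 1.0 - (1.0 - p_per_example) ** max_examples


# Assumed: 1/8 chance per example, max_examples=20 as quoted in the review.
p_hit = hit_probability(1 / 8, 20)
print(f"hit: {p_hit:.3f}, miss: {1 - p_hit:.3f}")
```

Under that assumption the hit rate comes out at about 0.93 and the miss rate at about 0.07, in line with the numbers quoted above. A pinned @example removes the dependence on sampling altogether.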

@github-actions

Review summary

The fix itself is correct and the PR description does an excellent job explaining both the failure mode and why the existing test was silent — that level of root-cause detail is great. Performance impact is mildly positive: sharing one symbolic Dim across grouped-sequence tensors lets ShapeEnv skip equality discovery via runtime asserts, reduces emitted assert_scalar ops, and unblocks cross-sequence fusion in inductor. No security or multi-process safety concerns (_compute_seq_share_groups is a pure function over replicated config; Dim names come from operator-controlled protobufs).

Noteworthy items posted inline:

  1. _compute_seq_share_groups first-feature heuristic — silently mis-shares Dims for mixed-parent groups, lacks the value_dim==1 filter the unified _build_dynamic_shapes uses, and first.sequence_name should be getattr-guarded.
  2. Hypothesis sampling reliability — (AOT_INDUCTOR, has_watchtime=True) is only sampled ~93% of the time per CI run; pin with @example(...) so the regression cannot be silently unsampled.

Two smaller follow-ups (not posted inline):

  • After this PR moves all sequence features under SequenceFeature(...) parents, the standalone share_key = f"fg_{fg.group_name}" fallback branch (export_util.py:1019) is no longer covered by dlrm_hstu_test.py. Consider one variant or a direct unit test in a new tzrec/utils/export_util_test.py that also covers _compute_seq_share_groups on CPU.
  • Two paths now answer the same conceptual question ("which sequence tensors share an nnz Dim?") with diverging algorithms (_compute_seq_share_groups vs _build_dynamic_shapes). Worth either extracting a shared helper or leaving a TODO so they don't drift further.

🤖 Generated with Claude Code

tiankongdeguiji merged commit ce1ee1a into alibaba:master on Apr 26, 2026
6 of 7 checks passed
tiankongdeguiji deleted the bugfix/aot-share-seq-dims branch on April 27, 2026 01:21