[bugfix] share Dim across grouped-sequence tensors in legacy AOT export #485
Legacy `export_model_aot` previously assigned an independent
`torch.export.Dim` to every entry in `meta_info["jagged_seq_tensor_names"]`
(and likewise to each padded-`SEQUENCE` group's per-axis Dim). When
multiple JAGGED_SEQUENCE feature_groups draw their first feature from
the same parent `SequenceFeature` (e.g. HSTU's `uih`, `uih_action`,
`uih_timestamp`, `uih_watchtime` all derived from `uih_seq`), their
per-sample lengths come from one shared data-parser source — so
runtime nnz across those tensors is necessarily equal. With independent
Dims, downstream ops (e.g. `bitwise_or` between `seq_actions` and
`seq_watchtimes` in `SimpleActionEncoder.forward`) trigger a
`ConstraintViolationError` during `torch.export`.
`split_model` now computes a per-`{group_name}__sequence` `share_key`
based on the first feature's `_is_grouped_seq` / `sequence_name`, and
`export_model_aot` uses one Dim per share_key (separately per axis for
SEQUENCE and JAGGED_SEQUENCE). Standalone (non-grouped) sequence
features keep their own share_key and retain the previous
independent-Dim behavior.
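
The "one Dim per share_key" idea can be sketched without any PyTorch dependency. This is a minimal illustration, not the code in this PR: `assign_shared_dims`, `FakeDim`, and the example tensor names are hypothetical, and in a real export the `make_dim` callable would presumably wrap `torch.export.Dim`.

```python
from typing import Callable, Dict, TypeVar

D = TypeVar("D")


def assign_shared_dims(
    share: Dict[str, str], make_dim: Callable[[str], D]
) -> Dict[str, D]:
    """Map each tensor name to a dim object, creating one dim per share_key.

    Hypothetical helper: `share` maps "{group_name}__sequence" -> share_key.
    """
    cache: Dict[str, D] = {}
    dims: Dict[str, D] = {}
    for tensor_name, share_key in share.items():
        if share_key not in cache:
            # First tensor seen for this key creates the dim.
            cache[share_key] = make_dim(share_key)
        # Every later tensor with the same key reuses the same object.
        dims[tensor_name] = cache[share_key]
    return dims


class FakeDim:
    """Stand-in for a symbolic dimension; assumption for illustration only."""

    def __init__(self, name: str) -> None:
        self.name = name


share = {
    "uih__sequence": "seq_uih_seq",
    "uih_action__sequence": "seq_uih_seq",
    "cand__sequence": "seq_cand_seq",
    "ctx__sequence": "fg_ctx",  # standalone group keeps its own key
}
dims = assign_shared_dims(share, FakeDim)
print(dims["uih__sequence"] is dims["uih_action__sequence"])
```

Because the cache lives inside the call, invoking the helper once for SEQUENCE tensors and once for JAGGED_SEQUENCE tensors would keep the two axes' dims separate, which matches the per-axis behavior described above.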
Regression coverage: `DlrmHSTUTest::test_dlrm_hstu` now wraps `uih_*`
and `cand_*` sub-features inside `SequenceFeature` parents and threads
`watchtime_to_action_thresholds` / `watchtime_to_action_weights` into
`SimpleActionEncoder` when `has_watchtime=True`, so the cross-sequence
bitwise-OR branch actually runs during tracing — without these two
changes neither the constraint nor the share-Dim signal would be
exercised in tests.
Verified end-to-end against an HSTU pipeline that exhibits this
failure: the `dynamic shapes` log now reports shared dims
`seq_uih_seq__batch` and `seq_cand_seq__batch` for the respective
groups, and `torch.export.export` proceeds past the previous failure
point through to a complete AOTI archive.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
    first = feat_by_name.get(fg.feature_names[0])
    if first is not None and getattr(first, "_is_grouped_seq", False):
        share_key = f"seq_{first.sequence_name}"
    else:
        share_key = f"fg_{fg.group_name}"
    share[f"{fg.group_name}__sequence"] = share_key
The "first feature decides the share_key" heuristic has a few silent correctness holes worth tightening:
- Mixed-parent groups silently coalesce. A `JAGGED_SEQUENCE` group whose `feature_names[0]` is from `SequenceFeature(uih_seq)` and whose second feature is from `SequenceFeature(cand_seq)` will map all its tensors to `seq_uih_seq` and forcibly share a Dim with `uih`'s tensor, which is wrong.
- `value_dim != 1` is not filtered. The unified path's `_build_dynamic_shapes` (export_util.py:268-274) explicitly excludes multi-valued grouped features from share groups because their nnz is independent. The legacy helper here doesn't, so a multi-valued grouped feature would be over-shared. The same model could pass unified AOT and fail legacy AOT.
- `first.sequence_name` is a bare attribute access. `_is_grouped_seq` is defensively read with `getattr(..., False)`, but `sequence_name` would raise `AttributeError` if a `BaseFeature` subclass ever lacks it. Mirror the unified path: `getattr(first, "sequence_name", None)`.

At minimum, iterate all features in the group and assert they share the same `sequence_name` and have `value_dim == 1`, warn (or fall back) on mixed groups, and guard `sequence_name` with `getattr`. The docstring should also call out that only the first feature is inspected today.
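
One way the tightened check could look is sketched below. It is only an illustration of the points in this comment: `StubFeature` and `derive_share_key` are hypothetical names, the attribute names are taken from the snippet under review, and falling back to a per-group key for any mixed or multi-valued group is an assumed policy rather than what the unified path necessarily does.

```python
import logging
from dataclasses import dataclass
from typing import List, Optional, Set

logger = logging.getLogger(__name__)


@dataclass
class StubFeature:
    """Hypothetical stand-in for a feature object, not the real class."""

    name: str
    value_dim: int = 1
    sequence_name: Optional[str] = None
    _is_grouped_seq: bool = False


def derive_share_key(group_name: str, features: List[object]) -> str:
    """Return a shared key only when every feature agrees on one parent."""
    fallback = f"fg_{group_name}"
    parents: Set[str] = set()
    for feat in features:
        grouped = getattr(feat, "_is_grouped_seq", False)
        seq_name = getattr(feat, "sequence_name", None)  # guarded access
        single_valued = getattr(feat, "value_dim", None) == 1
        if not (grouped and seq_name and single_valued):
            # Standalone or multi-valued feature: keep an independent Dim.
            return fallback
        parents.add(seq_name)
    if len(parents) != 1:
        if parents:
            logger.warning(
                "group %s mixes parent sequences %s; not sharing a Dim",
                group_name,
                sorted(parents),
            )
        return fallback
    return f"seq_{parents.pop()}"


uih = StubFeature("uih", sequence_name="uih_seq", _is_grouped_seq=True)
act = StubFeature("uih_action", sequence_name="uih_seq", _is_grouped_seq=True)
cand = StubFeature("cand", sequence_name="cand_seq", _is_grouped_seq=True)
print(derive_share_key("uih", [uih, act]))
print(derive_share_key("mixed", [uih, cand]))
```

With this shape, a mixed-parent group degrades to the previous independent-Dim behavior and logs a warning instead of silently sharing with the first feature's parent.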
      simple_action_encoder=module_pb2.GRSimpleActionEncoder(
    -     action_embedding_dim=8, action_weights=[1, 2, 4]
    +     action_embedding_dim=8,
    +     action_weights=[1, 2, 4],
    +     watchtime_to_action_thresholds=(
    +         [60, 300] if has_watchtime else []
    +     ),
    +     watchtime_to_action_weights=(
    +         [256, 512] if has_watchtime else []
    +     ),
      )
The bugfix is gated on (graph_type=AOT_INDUCTOR, has_watchtime=True), but with max_examples=20 over 8 sampled_from axes (~1024 combos) hypothesis only hits this exact pair with ~93% probability per CI run — i.e. ~7% of runs won't exercise the regression at all, and seeds aren't fixed via derandomize=True.
Recommend pinning this combo deterministically with an @example(graph_type=TestGraphType.AOT_INDUCTOR, has_watchtime=True, kernel=Kernel.PYTORCH, ...) decorator on test_dlrm_hstu, or adding a small dedicated non-hypothesis test, so the regression can never be silently unsampled.
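
The ~93% / ~7% estimate in this comment can be reproduced with a short calculation. This is a rough sketch under two assumptions that are not stated in the PR: the target pair is hit with probability 1/8 per generated example (a value chosen because it reproduces the quoted figure), and examples are drawn independently and uniformly, which Hypothesis does not guarantee.

```python
def hit_probability(per_example: float, max_examples: int) -> float:
    """Chance that at least one of `max_examples` independent draws hits."""
    return 1.0 - (1.0 - per_example) ** max_examples


# Assumed per-example chance of (AOT_INDUCTOR, has_watchtime=True): 1/8.
p = hit_probability(1 / 8, 20)
print(f"hit: {p:.1%}, miss: {1 - p:.1%}")
```

Under those assumptions roughly one CI run in fourteen would skip the regression case entirely, which is the motivation for pinning the combination deterministically.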
Review summary
The fix itself is correct and the PR description does an excellent job explaining both the failure mode and why the existing test was silent; that level of root-cause detail is great. Performance impact is mildly positive: sharing one symbolic Dim across grouped-sequence tensors lets
Noteworthy items posted inline:
Two smaller follow-ups (not posted inline):
🤖 Generated with Claude Code
Summary
`export_model_aot` (the two-stage `ENABLE_AOT=1` path) gave each `*__sequence` tensor its own `torch.export.Dim`. When several JAGGED_SEQUENCE feature_groups draw their first feature from the same parent `SequenceFeature` (HSTU's `uih`, `uih_action`, `uih_timestamp`, `uih_watchtime` all from `uih_seq`), they share per-sample lengths in the data parser, so cross-sequence ops in the dense graph (e.g. `bitwise_or` between `seq_actions` and `seq_watchtimes` in `SimpleActionEncoder`) trigger `ConstraintViolationError` during `torch.export.export`.
`split_model` now derives a per-`{group_name}__sequence` `share_key` from the first feature's `_is_grouped_seq` / `sequence_name`, and `export_model_aot` reuses one Dim per share_key (separate caches for SEQUENCE's axis-1 max-seq-len Dim and JAGGED_SEQUENCE's axis-0 nnz Dim). Standalone (non-grouped) sequence features keep a per-group share_key, preserving the previous behavior.
Repro of the original failure
Running the legacy AOT export against an HSTU pipeline with multiple JAGGED_SEQUENCE feature_groups derived from the same parent `SequenceFeature` (e.g. `uih_seq` -> `uih`, `uih_action`, `uih_timestamp`, `uih_watchtime`) and a `SimpleActionEncoder` configured with `watchtime_to_action_thresholds` / `watchtime_to_action_weights` previously aborted with:
After this PR, the `[INFO] dynamic shapes=...` line shows the four UIH-derived sequence values referencing one shared Dim and the two CAND-derived ones sharing another:
and the export runs through to a complete AOTI archive.
Why the existing DlrmHSTUTest didn't catch this
`DlrmHSTUTest::test_dlrm_hstu` already parameterizes `(graph_type=AOT_INDUCTOR, has_watchtime=True)` and exercises `split_model` + `export_model_aot`. It was silent on this bug for two reasons that this PR also fixes:
- The encoder config did not set `watchtime_to_action_thresholds` / `watchtime_to_action_weights`, so `SimpleActionEncoder.need_watchtime` returned `False` and the `bitwise_or(seq_actions, seq_watchtimes >= ...)` branch was dead at trace time. Without that op, no constraint between `uih_action__sequence` and `uih_watchtime__sequence` was ever recorded.
- The test features were declared as standalone sequence features (`sequence_id_feature` / `sequence_raw_feature`), so `_is_grouped_seq` was always `False` and there was no `sequence_name` signal to merge separate JAGGED_SEQUENCE feature_groups under one share_key.

Both gaps are closed: features are wrapped inside `SequenceFeature(sequence_name="uih_seq" / "cand_seq", features=[...])` parents and the encoder now sets `watchtime_to_action_thresholds` / `watchtime_to_action_weights` when `has_watchtime=True`. The hypothesis sample `(AOT_INDUCTOR, has_watchtime=True)` now reproduces the failure on master and validates the fix here.
Test plan
- `pytest -x tzrec/models/dlrm_hstu_test.py::DlrmHSTUTest::test_dlrm_hstu`: passes after the fix (1 passed, 141s).
- `[INFO] dynamic shapes` now logs `seq_uih_seq__batch` and `seq_cand_seq__batch` shared across the matching keys; export proceeds past the previous failure point and produces a complete `aoti_model.pt2`.

🤖 Generated with Claude Code