[BugFix][Relax][Torch] Honor multi-axis dims in torch.flip converter #19511
Merged
tlopex merged 1 commit into apache:main on May 6, 2026
Conversation
The PyTorch frontend's `_flip` coerces a list of dims to a single int (`dims = dims[0]`) and forwards only one axis to `relax.op.flip`, which is itself single-axis. As a result `torch.flip(x, dims=[-1, -2])` silently flips just the last axis (`max_abs_diff = 8.0` vs PyTorch eager on a `(3, 4)` input). Iterate over `dims` instead, emitting one `relax.op.flip` per axis. Flips along distinct axes commute, so order is irrelevant. `relax.op.flip` itself is unchanged: it is used elsewhere as a single-axis op, and broadening its signature would expand scope beyond this bug.
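The decomposition is safe because single-axis flips along distinct axes compose to the multi-axis flip in either order; a quick check in eager PyTorch (variable names are illustrative):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# A multi-axis flip equals the composition of single-axis flips, in any order.
full = torch.flip(x, dims=[-1, -2])
composed_a = torch.flip(torch.flip(x, dims=[-1]), dims=[-2])
composed_b = torch.flip(torch.flip(x, dims=[-2]), dims=[-1])
assert torch.equal(full, composed_a) and torch.equal(full, composed_b)
```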
Contributor
Code Review
This pull request updates the `_flip` translator in the Relax Torch frontend to support multi-axis flipping, aligning with `torch.flip` semantics. The implementation now iterates through the provided dimensions and emits a `relax.op.flip` operation for each axis, whereas it previously only supported a single axis. New test cases have been added to the FX and exported-program frontend test suites to verify multi-axis and negative-axis support. I have no further feedback; there are no review comments.
Motivation
PyTorch's `torch.flip(x, dims=[...])` reverses every listed axis. The Relax converter `_flip` (`base_fx_graph_translator.py`) instead coerces the list to a single integer (`dims = dims[0]`).
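Restated as a standalone sketch (not the literal frontend code; `bb` stands in for the translator's block builder):

```python
from tvm import relax

def flip_single_axis_buggy(bb, x, dims):
    """Pre-fix behavior of the _flip converter, as a standalone sketch."""
    if isinstance(dims, (list, tuple)):
        dims = dims[0]  # BUG: every axis after the first is silently dropped
    # relax.op.flip reverses exactly one axis, so dims=[-1, -2] flips only -1
    return bb.emit(relax.op.flip(x, dims))
```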
Only the first axis is forwarded to `relax.op.flip`, which is itself single-axis. The remaining axes are silently dropped.

Minimal repro (vs PyTorch eager) on a `(3, 4)` input with `dims=[-1, -2]`: `max_abs_diff = 8.0`. Both the `torch.export` and legacy fx paths share this converter, so both are affected.
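The discrepancy can be reproduced without TVM by comparing the correct result against a flip of only the last axis, which is what the old converter effectively emitted (an `arange` input is assumed here for a deterministic diff):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

expected = torch.flip(x, dims=[-1, -2])  # what torch.flip computes
buggy = torch.flip(x, dims=[-1])         # what the old converter produced
print((expected - buggy).abs().max())    # tensor(8.)
```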
Fix
Iterate over `dims` in the converter and emit one `relax.op.flip` per axis (flips along distinct axes commute, so the order is irrelevant).

A scalar `dims` is wrapped into a single-element list; non-int / non-sequence arguments still raise `TypeError`. `relax.op.flip` itself is unchanged: it is used elsewhere as a single-axis op, and widening its signature would expand the scope of this fix beyond the PyTorch frontend.
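The fixed logic, in the same standalone-sketch form as above (`bb` again stands in for the translator's block builder; the actual handler may differ in argument extraction and error messages):

```python
from tvm import relax

def flip_multi_axis(bb, x, dims):
    """Decompose a multi-axis torch.flip into single-axis relax.op.flip ops."""
    if isinstance(dims, int):
        dims = [dims]  # a scalar dim is wrapped into a one-element list
    if not isinstance(dims, (list, tuple)):
        raise TypeError(f"Unsupported dims for torch.flip: {dims}")
    # Flips along distinct axes commute, so emitting one single-axis flip
    # per listed dim reproduces torch.flip regardless of iteration order.
    for dim in dims:
        x = bb.emit(relax.op.flip(x, dim))
    return x
```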