Skip to content

Commit

Permalink
[Relay][Bugfix] fix the wrong calculate logic of operator flip in PyT…
Browse files Browse the repository at this point in the history
…orch frontend (#15752)

The original implementation of Flip in PyTorch converter mistaken the type of attribute `axis` in the Flip operator as an integer. Thus, It only parses the first element of the `axis` and will give a wrong calculation result when the length of `axis` is more than one.  According to the PyTorch documentation [here](https://pytorch.org/docs/stable/generated/torch.flip.html), the type of `axis` is a list or tuple.

This PR corrected the incorrect implementation of the algorithm of `torch.flip` converter and added a regression test.
  • Loading branch information
jikechao committed Sep 16, 2023
1 parent 08a6ee5 commit 61d5be0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,10 @@ def nll_loss(self, inputs, input_types):
def flip(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])
out = data
for ax in axis:
out = _op.reverse(out, ax)
return out

def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
"""
Expand Down
11 changes: 6 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4899,13 +4899,14 @@ def __init__(self, axis=0):
self.axis = axis

def forward(self, x):
return x.flip([self.axis])
return x.flip(self.axis)

input_t = torch.randn(2, 3, 4)
verify_model(Flip(axis=0), input_data=input_t)
verify_model(Flip(axis=1), input_data=input_t)
verify_model(Flip(axis=2), input_data=input_t)
verify_model(Flip(axis=-1), input_data=input_t)
verify_model(Flip(axis=[0]), input_data=input_t)
verify_model(Flip(axis=[1]), input_data=input_t)
verify_model(Flip(axis=[2]), input_data=input_t)
verify_model(Flip(axis=[-1]), input_data=input_t)
verify_model(Flip(axis=[0, 1]), input_data=input_t)


def test_annotate_span():
Expand Down

0 comments on commit 61d5be0

Please sign in to comment.