Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend] Add support for aten::concat #16199

Merged
merged 5 commits into from Dec 9, 2023

Conversation

sweetcocoa
Copy link
Contributor

I think it is a quite simple problem,

aten::concat is just an alias of aten::cat, but it is not supported.

https://github.com/pytorch/pytorch/blob/3fbfa8cd0a5cefadb3f116c5cd0d60e96ab8c99e/aten/src/ATen/native/TensorShape.cpp#L667

If needed, I will add a minimal example to reproduce.

torch==2.1.1+cpu
numpy==1.23.5

@mshr-h
Copy link
Contributor

mshr-h commented Dec 4, 2023

@sweetcocoa Thank you for your PR!
It would be nice to include a test to https://github.com/apache/tvm/blob/main/tests/python/frontend/pytorch/test_forward.py
Maybe just adding a test case to test_forward_concatenate(): is fine?
What do you think? @jikechao @vvchernov

@jikechao
Copy link
Contributor

jikechao commented Dec 4, 2023

@mshr-h, I agree with your comments.
@sweetcocoa, could you add a test case about torch.concat inner test_forward_concatenate():

@vvchernov
Copy link
Contributor

Thank you @sweetcocoa for your PR! I agree the test is needed

Copy link
Contributor

@jikechao jikechao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks!

@@ -719,10 +719,27 @@ def forward(self, *args):
b = (args[0][:, :, 1] + 3) * 11
c = (args[0][:, :, 2] + 5) * 13
return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove trailing whitespaces.

@@ -2893,6 +2911,7 @@ def forward(self, inp):
@tvm.testing.uses_gpu
def test_simple_rnn():
"""test_simple_rnn"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I suppose that all such white spaces are redundant. Could you please remove them in this test and in the tests below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't see it, I've reverted it now.

@sweetcocoa
Copy link
Contributor Author

I don't think this CI error stems from this PR, can I try restarting it?

@jikechao
Copy link
Contributor

jikechao commented Dec 7, 2023

I don't think this CI error stems from this PR, can I try restarting it?

Yes, you can comment with @tvm-bot rerun in the chatbox to rerun the CI test.

@sweetcocoa
Copy link
Contributor Author

@tvm-bot rerun

Comment on lines 723 to 739
class Concatenate3(Module):
# pylint: disable=missing-class-docstring
def __init__(self):
super().__init__()

class _Concatenate(Module):
def forward(self, *args):
a = (args[0][:, :, 0] + 2) * 7
b = (args[0][:, :, 1] + 3) * 11
c = (args[0][:, :, 2] + 5) * 13
return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)

self.mod = _Concatenate()

def forward(self, *args):
return self.mod(*args)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you create a class in this way?
Will the following code work in the same way?

Suggested change
class Concatenate3(Module):
# pylint: disable=missing-class-docstring
def __init__(self):
super().__init__()
class _Concatenate(Module):
def forward(self, *args):
a = (args[0][:, :, 0] + 2) * 7
b = (args[0][:, :, 1] + 3) * 11
c = (args[0][:, :, 2] + 5) * 13
return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)
self.mod = _Concatenate()
def forward(self, *args):
return self.mod(*args)
class Concatenate3(Module):
def forward(self, *args):
a = (args[0][:, :, 0] + 2) * 7
b = (args[0][:, :, 1] + 3) * 11
c = (args[0][:, :, 2] + 5) * 13
return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@echuraev
The torch.concat is preserved as aten::concat only when it is in a nested module like this code. (In the most cases, It is converted to aten::cat instead of aten::concat.) I've tried to find a reason for this, but haven't found one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your reply! Could you please in this case specify it in the class description, instead of using # pylint: disable=missing-class-docstring?

Copy link
Contributor

@echuraev echuraev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you for your PR!

@masahi masahi merged commit 65121c8 into apache:main Dec 9, 2023
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants