Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add aten::pixel_shuffle implementation (#6328) #6468

Merged
merged 1 commit into from Sep 14, 2020

Conversation

wjliu1998
Copy link
Contributor

Following the implementation of aten::pixel_shuffle in pytorch. The implementation of pixel_shuffle has been added in pytorch frontend

@masahi
Copy link
Member

masahi commented Sep 14, 2020

thanks, please add a test.

@wjliu1998
Copy link
Contributor Author

thanks, please add a test.

Thanks! The test has been added!

out_shape = [b, oc, oh, ow]

data = _op.transform.reshape(data, new_shape)
axes = [0, 1, 4, 2, 5, 3]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you tell me what are these hardcoded axes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original shape order is [b, oc, upscale_factor, upscale_factor, h, w], the hardcoded axes transpose the order to [b, oc, h, upscale_factor, w, upscale_factor] for further reshape

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, please add your explanation above as a comment (so that people won't get confused).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment has been added.

@masahi
Copy link
Member

masahi commented Sep 14, 2020

@wjliu1998 lint is complaining

raise AssertionError(msg)

if isinstance(data, tvm.runtime.NDArray):
ndims = len(_infer_shape(data))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you copy pasted the above code from _transpose(). That one is a legacy code we don't want to repeat. Please leave only what is necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! The code has been deleted

Copy link
Member

@masahi masahi Sep 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, you have three if isinstance(...) right? You shouldn't need isintance, just get ndims directly (figure out which code path actually hit in your test). I guess all you need is just ndims = data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing out that. It seems ndims = len(_infer_shape(data, prelude.mod)) is enough

@masahi masahi merged commit e02dc69 into apache:master Sep 14, 2020
@masahi
Copy link
Member

masahi commented Sep 14, 2020

Thanks @wjliu1998

kevinthesun pushed a commit to kevinthesun/tvm that referenced this pull request Sep 17, 2020
kevinthesun pushed a commit to kevinthesun/tvm that referenced this pull request Sep 18, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 18, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants