New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add aten::pixel_shuffle implementation (#6328) #6468
Conversation
thanks, please add a test. |
Thanks! The test has been added! |
out_shape = [b, oc, oh, ow] | ||
|
||
data = _op.transform.reshape(data, new_shape) | ||
axes = [0, 1, 4, 2, 5, 3] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you tell me what are these hardcoded axes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original shape order is [b, oc, upscale_factor, upscale_factor, h, w], the hardcoded axes transpose the order to [b, oc, h, upscale_factor, w, upscale_factor] for further reshape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, please add your explanation above as a comment (so that people won't get confused).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment has been added.
@wjliu1998 lint is complaining |
python/tvm/relay/frontend/pytorch.py
Outdated
raise AssertionError(msg) | ||
|
||
if isinstance(data, tvm.runtime.NDArray): | ||
ndims = len(_infer_shape(data)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems you copy pasted the above code from _transpose()
. That one is a legacy code we don't want to repeat. Please leave only what is necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! The code has been deleted
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean, you have three if isinstance(...)
right? You shouldn't need isintance
, just get ndims directly (figure out which code path actually hit in your test). I guess all you need is just ndims = data
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing out that. It seems ndims = len(_infer_shape(data, prelude.mod))
is enough
Thanks @wjliu1998 |
Following the implementation of aten::pixel_shuffle in pytorch. The implementation of pixel_shuffle has been added in pytorch frontend