add aten::pixel_shuffle implementation (#6328) (#6468)
wjliu1998 committed Sep 14, 2020
1 parent cdd3206 commit e02dc69
Showing 2 changed files with 41 additions and 1 deletion.
28 changes: 28 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1230,6 +1230,33 @@ def _impl(inputs, input_types):

    return _impl

def _pixel_shuffle(prelude):
    def _impl(inputs, input_types):
        data = inputs[0]
        upscale_factor = inputs[1]
        upscale_squared = upscale_factor * upscale_factor
        b, c, h, w = _infer_shape(data)
        assert c % upscale_squared == 0, \
            "input channels should be divisible by the square of upscale_factor"

        oc = c // upscale_squared
        oh = h * upscale_factor
        ow = w * upscale_factor

        new_shape = [b, oc, upscale_factor, upscale_factor, h, w]
        out_shape = [b, oc, oh, ow]

        # Split the channel dim into (oc, upscale_factor, upscale_factor), then
        # transpose to [b, oc, h, upscale_factor, w, upscale_factor] so the final
        # reshape interleaves each upscale factor into its spatial dimension.
        data = _op.transform.reshape(data, new_shape)
        axes = [0, 1, 4, 2, 5, 3]
        data = _op.transform.transpose(data, axes)
        return _op.transform.reshape(data, out_shape)
    return _impl

def _clone():
    def _impl(inputs, input_types):
@@ -2162,6 +2189,7 @@ def _wrap_const(c):
# Operator mappings
def _get_convert_map(prelude, default_dtype):
    convert_map = {
        "aten::pixel_shuffle": _pixel_shuffle(prelude),
        "aten::device": _none(),
        "prim::device": _none(),
        "aten::sub": _elemwise("subtract"),
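Note: the reshape → transpose → reshape sequence in _pixel_shuffle reproduces torch.nn.PixelShuffle numerically. Below is a minimal standalone sketch of the same steps in NumPy, checked against PyTorch; pixel_shuffle_np and the sample shape are illustrative only and not part of this commit.

import numpy as np
import torch

def pixel_shuffle_np(data, r):
    # Same steps as the converter: split channels into (oc, r, r), move each
    # factor r next to its spatial dim, then flatten into (oc, h*r, w*r).
    b, c, h, w = data.shape
    oc = c // (r * r)
    out = data.reshape(b, oc, r, r, h, w)
    out = out.transpose(0, 1, 4, 2, 5, 3)  # [b, oc, h, r, w, r]
    return out.reshape(b, oc, h * r, w * r)

x = np.random.rand(1, 144, 16, 16).astype("float32")
expected = torch.nn.PixelShuffle(3)(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(pixel_shuffle_np(x, 3), expected, rtol=1e-5)

The transpose axes [0, 1, 4, 2, 5, 3] put each upscale factor beside its spatial dimension, which is why the final reshape yields height h * upscale_factor and width w * upscale_factor.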
14 changes: 13 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -182,7 +182,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

    with torch.no_grad():
        baseline_outputs = baseline_model(*baseline_input)

    if isinstance(baseline_outputs, tuple):
        baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
    else:
@@ -223,6 +223,17 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at


# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
    torch.set_grad_enabled(False)
    input_shape = [1, 144, 16, 16]

    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.PixelShuffle(2).float().eval(), input_data=input_data)
    verify_model(torch.nn.PixelShuffle(3).float().eval(), input_data=input_data)
    verify_model(torch.nn.PixelShuffle(4).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_add():
    torch.set_grad_enabled(False)
@@ -3163,6 +3174,7 @@ def test_forward_pretrained_bert_base_uncased():
    test_duplicate_weight_use()

    # Single operator tests
    test_forward_pixel_shuffle()
    test_forward_add()
    test_forward_subtract()
    test_forward_multiply()
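For context, verify_model in this test file roughly follows the standard from_pytorch flow: trace the module with torch.jit.trace, convert it with relay.frontend.from_pytorch, build, and compare against the PyTorch output. A minimal sketch of that flow for the new op, assuming a standard TVM build; the input name and the llvm target are illustrative choices, not taken from the commit.

import torch
import tvm
from tvm import relay

model = torch.nn.PixelShuffle(2).eval()
# 144 input channels are divisible by 2**2, 3**2, and 4**2, so every upscale
# factor used in test_forward_pixel_shuffle satisfies the converter's assert.
input_data = torch.rand([1, 144, 16, 16])
scripted = torch.jit.trace(model, input_data).eval()

shape_list = [("input0", [1, 144, 16, 16])]  # (name, shape) for each graph input
mod, params = relay.frontend.from_pytorch(scripted, shape_list)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)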
