From e02dc69fef294eb73dd65d18949ed9e108f60cda Mon Sep 17 00:00:00 2001
From: wjliu
Date: Mon, 14 Sep 2020 22:51:06 +0800
Subject: [PATCH] add aten::pixel_shuffle implementation (#6328) (#6468)

---
 python/tvm/relay/frontend/pytorch.py          | 28 +++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 14 +++++++++-
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 33ce58fd1d7d..2eff4153592d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1230,6 +1230,33 @@ def _impl(inputs, input_types):
     return _impl


+def _pixel_shuffle(prelude):
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        upscale_factor = inputs[1]
+        upscale_squared = upscale_factor * upscale_factor
+        b, c, h, w = _infer_shape(data)
+        assert c % upscale_squared == 0, \
+            "input channel should be divisible by square of upscale_factor"
+
+        ndims = len(_infer_shape(data, prelude.mod))
+        axes = list(range(ndims))
+        num_inputs = len(inputs)
+        oc = c // upscale_squared
+        oh = h * upscale_factor
+        ow = w * upscale_factor
+
+        new_shape = [b, oc, upscale_factor, upscale_factor, h, w]
+        out_shape = [b, oc, oh, ow]
+
+        data = _op.transform.reshape(data, new_shape)
+        # The data will be transposed to
+        # [b, oc, h, upscale_factor, w, upscale_factor]
+        # for further reshape
+        axes = [0, 1, 4, 2, 5, 3]
+        data = _op.transform.transpose(data, axes)
+        return _op.transform.reshape(data, out_shape)
+    return _impl
+
+
 def _clone():
     def _impl(inputs, input_types):
@@ -2162,6 +2189,7 @@ def _wrap_const(c):
 # Operator mappings
 def _get_convert_map(prelude, default_dtype):
     convert_map = {
+        "aten::pixel_shuffle": _pixel_shuffle(prelude),
         "aten::device": _none(),
         "prim::device": _none(),
         "aten::sub": _elemwise("subtract"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6661aad70c09..4192cf45737d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -182,7 +182,7 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

     with torch.no_grad():
         baseline_outputs = baseline_model(*baseline_input)
-
+
     if isinstance(baseline_outputs, tuple):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
@@ -223,6 +223,17 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at


 # Single operator tests
+@tvm.testing.uses_gpu
+def test_forward_pixel_shuffle():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 144, 16, 16]
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(torch.nn.PixelShuffle(2).float().eval(), input_data=input_data)
+    verify_model(torch.nn.PixelShuffle(3).float().eval(), input_data=input_data)
+    verify_model(torch.nn.PixelShuffle(4).float().eval(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_add():
     torch.set_grad_enabled(False)
@@ -3163,6 +3174,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_duplicate_weight_use()

     # Single operator tests
+    test_forward_pixel_shuffle()
     test_forward_add()
     test_forward_subtract()
     test_forward_multiply()
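
The converter lowers aten::pixel_shuffle to a reshape -> transpose -> reshape sequence. As a sanity check on the axis order [0, 1, 4, 2, 5, 3] used in _pixel_shuffle, the following is a minimal standalone sketch of the same decomposition in plain PyTorch; the helper name pixel_shuffle_ref and the example shape are illustrative and not part of the patch. It should agree exactly with torch.nn.PixelShuffle for the upscale factors exercised by the new test.

    import torch


    def pixel_shuffle_ref(x, r):
        # (b, c*r*r, h, w) -> (b, c, h*r, w*r), mirroring the Relay lowering
        b, c, h, w = x.shape
        oc = c // (r * r)
        x = x.reshape(b, oc, r, r, h, w)       # new_shape in the converter
        x = x.permute(0, 1, 4, 2, 5, 3)        # axes = [0, 1, 4, 2, 5, 3]
        return x.reshape(b, oc, h * r, w * r)  # out_shape in the converter


    x = torch.rand(1, 144, 16, 16)
    for r in (2, 3, 4):
        assert torch.equal(pixel_shuffle_ref(x, r), torch.nn.PixelShuffle(r)(x))

Because the operator only rearranges elements, exact equality (torch.equal) is expected; the end-to-end test above still compares the compiled Relay output against PyTorch with rtol/atol through verify_model.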