diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0e6d4caae0ee..8ed94c2a81c9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1071,13 +1071,17 @@ def maxpool_1d(self, inputs, input_types): def maxpool_3d(self, inputs, input_types): data = inputs[0] + need_squeeze = False + if len(self.get_dims(data)) == 4: + need_squeeze = True + data = _op.expand_dims(data, 0) pool_size = inputs[1] strides = inputs[2] if inputs[2] else pool_size padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) - return _op.nn.max_pool3d( + res = _op.nn.max_pool3d( data, pool_size=pool_size, strides=strides, @@ -1085,6 +1089,7 @@ def maxpool_3d(self, inputs, input_types): padding=padding, ceil_mode=ceil_mode, ) + return res if not need_squeeze else _op.squeeze(res, [0]) def hardtanh(self, inputs, input_types): a = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3828412de054..a030c5141a31 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -906,13 +906,17 @@ def forward(self, *args): def test_forward_maxpool3d(): """test_forward_maxpool3d""" torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10, 10] - input_data = torch.rand(input_shape).float() + for input_shape in [(1, 3, 10, 10, 10), (3, 10, 10, 10)]: + input_data = torch.rand(input_shape).float() - verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), input_data) - verify_model(torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 2, 3]).eval(), input_data) - verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), input_data) - verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], padding=2, stride=2).eval(), input_data) + verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), input_data) + verify_model( + torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 2, 3]).eval(), input_data + ) + verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), input_data) + verify_model( + torch.nn.MaxPool3d(kernel_size=[4, 4, 4], padding=2, stride=2).eval(), input_data + ) # A functional variant (default strides = None case) class MaxPool3D(Module):