Skip to content

Commit

Permalink
Fix Python IndexError of case19: paddle.nn.functional.conv2d_transpose (
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 committed Feb 2, 2023
1 parent f76a7c5 commit a34ecad
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,17 @@ def test_case(self):
paddle.enable_static()


class TestConv2dTranspose(unittest.TestCase):
def error_weight_input(self):
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [1, 1, 1, 1]), dtype='float32')
weight = paddle.to_tensor(np.reshape(array, [1]), dtype='float32')
paddle.nn.functional.conv2d_transpose(x, weight, bias=0)

def test_type_error(self):
self.assertRaises(ValueError, self.error_weight_input)


class TestTensorOutputSize1(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 8, 8]]
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/nn/functional/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,12 @@ def conv2d_transpose(
x.shape
)
)
if len(weight.shape) != 4:
raise ValueError(
"Input weight should be 4D tensor, but received weight with the shape of {}".format(
weight.shape
)
)
num_channels = x.shape[channel_dim]
if num_channels < 0:
raise ValueError(
Expand Down

0 comments on commit a34ecad

Please sign in to comment.