From 09f8559688d27c64569618befab80b94c5efd916 Mon Sep 17 00:00:00 2001 From: Ramon Emiliani Date: Mon, 11 Apr 2022 23:05:08 -0500 Subject: [PATCH 1/2] Update the existing functionality to comply with the `torchscript.jit.script` function. Signed-off-by: Ramon Emiliani --- monai/networks/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a6b0699107..197cefd97b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -281,14 +281,16 @@ def pixelshuffle( f"divisible by scale_factor ** dimensions ({factor}**{dim}={scale_divisor})." ) - org_channels = channels // scale_divisor + org_channels = int(channels // scale_divisor) output_size = [batch_size, org_channels] + [d * factor for d in input_size[2:]] - indices = tuple(range(2, 2 + 2 * dim)) - indices_factor, indices_dim = indices[:dim], indices[dim:] - permute_indices = (0, 1) + sum(zip(indices_dim, indices_factor), ()) + indices = list(range(2, 2 + 2 * dim)) + indices = indices[dim:] + indices[:dim] + permute_indices = [0, 1] + for idx in range(dim): + permute_indices.extend(indices[idx::dim]) - x = x.reshape(batch_size, org_channels, *([factor] * dim + input_size[2:])) + x = x.reshape([batch_size, org_channels] + [factor] * dim + input_size[2:]) x = x.permute(permute_indices).reshape(output_size) return x From 4ab097cfd73a3b71b17fe0869b121df62102c930 Mon Sep 17 00:00:00 2001 From: Ramon Emiliani Date: Tue, 12 Apr 2022 08:33:25 -0500 Subject: [PATCH 2/2] Add test to verify `pixelshuffle` is scriptable. Signed-off-by: Ramon Emiliani --- tests/test_subpixel_upsample.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py index 0216f164c3..fc302a16d3 100644 --- a/tests/test_subpixel_upsample.py +++ b/tests/test_subpixel_upsample.py @@ -18,6 +18,7 @@ from monai.networks import eval_mode from monai.networks.blocks import SubpixelUpsample from monai.networks.layers.factories import Conv +from tests.utils import test_script_save TEST_CASE_SUBPIXEL = [] for inch in range(1, 5): @@ -73,6 +74,12 @@ def test_subpixel_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, _ = TEST_CASE_SUBPIXEL[0] + net = SubpixelUpsample(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + if __name__ == "__main__": unittest.main()