From 912eaffbe9efd1c968a0bc865541d1db92b2c151 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 16 Dec 2021 20:32:00 +0800 Subject: [PATCH 1/2] [DLMED] correct kwargs Signed-off-by: Nic Ma --- monai/networks/utils.py | 4 +++- tests/test_convert_to_torchscript.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ef0cff0eed..531e25397e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -458,11 +458,13 @@ def convert_to_torchscript( device: target device to verify the model, if None, use CUDA if available. rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + kwargs: other arguments for `torch.jit.script()` to convert model except `obj`, for more details: + https://pytorch.org/docs/master/generated/torch.jit.script.html. """ model.eval() with torch.no_grad(): - script_module = torch.jit.script(model) + script_module = torch.jit.script(model, **kwargs) if filename_or_obj is not None: if not pytorch_after(1, 7): torch.jit.save(m=script_module, f=filename_or_obj) diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py index a772610a04..60e78cc0d4 100644 --- a/tests/test_convert_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -34,6 +34,7 @@ def test_value(self): device="cuda" if torch.cuda.is_available() else "cpu", rtol=1e-3, atol=1e-4, + optimize=None, ) self.assertTrue(isinstance(torchscript_model, torch.nn.Module)) From c4c33fb108820d36eb597412af911e5ca50f94bf Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 16 Dec 2021 20:45:41 +0800 Subject: [PATCH 2/2] [DLMED] fix grammar Signed-off-by: Nic Ma --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 531e25397e..0cff97cf27 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -458,7 +458,7 @@ def convert_to_torchscript( device: target device to verify the model, if None, use CUDA if available. rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. - kwargs: other arguments for `torch.jit.script()` to convert model except `obj`, for more details: + kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html. """