diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ef0cff0eed..0cff97cf27 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -458,11 +458,13 @@ def convert_to_torchscript( device: target device to verify the model, if None, use CUDA if available. rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. + kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: + https://pytorch.org/docs/master/generated/torch.jit.script.html. """ model.eval() with torch.no_grad(): - script_module = torch.jit.script(model) + script_module = torch.jit.script(model, **kwargs) if filename_or_obj is not None: if not pytorch_after(1, 7): torch.jit.save(m=script_module, f=filename_or_obj) diff --git a/tests/test_convert_to_torchscript.py b/tests/test_convert_to_torchscript.py index a772610a04..60e78cc0d4 100644 --- a/tests/test_convert_to_torchscript.py +++ b/tests/test_convert_to_torchscript.py @@ -34,6 +34,7 @@ def test_value(self): device="cuda" if torch.cuda.is_available() else "cpu", rtol=1e-3, atol=1e-4, + optimize=None, ) self.assertTrue(isinstance(torchscript_model, torch.nn.Module))