diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4fe63744fd..eae0580696 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -251,6 +251,10 @@ def set_determinism( for func in additional_settings: func(seed) + if torch.backends.flags_frozen(): + warnings.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.") + torch.backends.__allow_nonbracketed_mutation_flag = True + if seed is not None: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index 31b3254f5b..7d6c54909d 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -40,6 +40,8 @@ def test_values(self): self.assertEqual(seed, get_seed()) a = np.random.randint(seed) b = torch.randint(seed, (1,)) + # tset when global flag support is disabled + torch.backends.disable_global_flags() set_determinism(seed=seed) c = np.random.randint(seed) d = torch.randint(seed, (1,))