From e72ef7db394f50de4ce76fa4f526f48e7977a1e5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 21 Jun 2021 17:35:01 +0100 Subject: [PATCH] fixes jit concurrent build Signed-off-by: Wenqi Li --- monai/networks/layers/gmm.py | 13 ++++++++----- tests/test_gmm.py | 21 ++++++++++++++++++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/monai/networks/layers/gmm.py b/monai/networks/layers/gmm.py index 953702e1f8..3091f95458 100644 --- a/monai/networks/layers/gmm.py +++ b/monai/networks/layers/gmm.py @@ -26,12 +26,13 @@ class GaussianMixtureModel: https://en.wikipedia.org/wiki/Mixture_model """ - def __init__(self, channel_count, mixture_count, mixture_size): + def __init__(self, channel_count: int, mixture_count: int, mixture_size: int, verbose_build: bool = False): """ Args: - channel_count (int): The number of features per element. - mixture_count (int): The number of class distributions. - mixture_size (int): The number Gaussian components per class distribution. + channel_count: The number of features per element. + mixture_count: The number of class distributions. + mixture_size: The number Gaussian components per class distribution. + verbose_build: If ``True``, turns on verbose logging of load steps. """ if not torch.cuda.is_available(): raise NotImplementedError("GaussianMixtureModel is currently implemented for CUDA.") @@ -39,7 +40,9 @@ def __init__(self, channel_count, mixture_count, mixture_size): self.mixture_count = mixture_count self.mixture_size = mixture_size self.compiled_extension = load_module( - "gmm", {"CHANNEL_COUNT": channel_count, "MIXTURE_COUNT": mixture_count, "MIXTURE_SIZE": mixture_size} + "gmm", + {"CHANNEL_COUNT": channel_count, "MIXTURE_COUNT": mixture_count, "MIXTURE_SIZE": mixture_size}, + verbose_build=verbose_build, ) self.params, self.scratch = self.compiled_extension.init() diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 6cccdb7410..0e2401b452 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -9,6 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil +import tempfile import unittest import numpy as np @@ -309,6 +312,18 @@ @skip_if_no_cuda class GMMTestCase(unittest.TestCase): + def setUp(self): + self._var = os.environ.get("TORCH_EXTENSIONS_DIR", None) + self.tempdir = tempfile.mkdtemp() + os.environ["TORCH_EXTENSIONS_DIR"] = self.tempdir + + def tearDown(self) -> None: + if self._var is None: + os.environ.pop("TORCH_EXTENSIONS_DIR", None) + else: + os.environ["TORCH_EXTENSIONS_DIR"] = f"{self._var}" + shutil.rmtree(self.tempdir) + @parameterized.expand(TEST_CASES) def test_cuda(self, test_case_description, mixture_count, class_count, features, labels, expected): @@ -320,7 +335,11 @@ def test_cuda(self, test_case_description, mixture_count, class_count, features, labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device) # Create GMM - gmm = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count) + gmm = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count, verbose_build=True) + # reload GMM to confirm the build + _ = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count, verbose_build=False) + # reload quietly + _ = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count, verbose_build=True) # Apply GMM gmm.learn(features_tensor, labels_tensor)