diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index f3a3fba46e..fb68e04b44 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -100,6 +100,10 @@ jobs: python -m pip install -r requirements-dev.txt python -m pip list python setup.py develop # test no compile installation + shell: bash + - if: runner.os != 'windows' + name: Run compiled (${{ runner.os }}) + run: | python setup.py develop --uninstall BUILD_MONAI=1 python setup.py develop # compile the cpp extensions shell: bash diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 7688bdbcd5..4b9e6c4f6a 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -16,6 +16,7 @@ import numpy as np import torch +from monai.apps.auto3dseg.transforms import EnsureSameShaped from monai.apps.utils import get_logger from monai.auto3dseg import SegSummarizer from monai.auto3dseg.utils import datafold_read @@ -69,9 +70,11 @@ class DataAnalyzer: hist_range: ranges to compute histogram for each image channel. fmt: format used to save the analysis results. Defaults to "yaml". histogram_only: whether to only compute histograms. Defaults to False. + extra_params: other optional arguments. Currently supported arguments are : + 'allowed_shape_difference' (default 5) can be used to change the default tolerance of + the allowed shape differences between the image and label items. In case of shape mismatch below + the tolerance, the label image will be resized to match the image using nearest interpolation. - Raises: - ValueError if device is GPU and worker > 0. Examples: .. code-block:: python @@ -121,6 +124,7 @@ def __init__( hist_range: Optional[list] = None, fmt: Optional[str] = "yaml", histogram_only: bool = False, + **extra_params, ): if path.isfile(output_path): warnings.warn(f"File {output_path} already exists and will be overwritten.") @@ -139,6 +143,7 @@ def __init__( self.hist_range: list = [-500, 500] if hist_range is None else hist_range self.fmt = fmt self.histogram_only = histogram_only + self.extra_params = extra_params @staticmethod def _check_data_uniformity(keys: List[str], result: Dict): @@ -206,6 +211,17 @@ def get_all_case_stats(self, key="training", transform_list=None): EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float), Orientationd(keys=keys, axcodes="RAS"), ] + if self.label_key is not None: + + allowed_shape_difference = self.extra_params.pop("allowed_shape_difference", 5) + transform_list.append( + EnsureSameShaped( + keys=self.label_key, + source_key=self.image_key, + allowed_shape_difference=allowed_shape_difference, + ) + ) + transform = Compose(transform_list) files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key) diff --git a/monai/apps/auto3dseg/transforms.py b/monai/apps/auto3dseg/transforms.py new file mode 100644 index 0000000000..2793eb9202 --- /dev/null +++ b/monai/apps/auto3dseg/transforms.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Dict, Hashable, Mapping + +import numpy as np +import torch + +from monai.config import KeysCollection +from monai.networks.utils import pytorch_after +from monai.transforms import MapTransform + + +class EnsureSameShaped(MapTransform): + """ + Checks if segmentation label images (in keys) have the same spatial shape as the main image (in source_key), + and raise an error if the shapes are significantly different. + If the shapes are only slightly different (within an allowed_shape_difference in each dim), then resize the label using + nearest interpolation. This transform is designed to correct datasets with slight label shape mismatches. + Generally image and segmentation label must have the same spatial shape, however some public datasets are having slight + shape mismatches, which will cause potential crashes when calculating loss or metric functions. + """ + + def __init__( + self, + keys: KeysCollection = "label", + allow_missing_keys: bool = False, + source_key: str = "image", + allowed_shape_difference: int = 5, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be compared to the source_key item shape. + allow_missing_keys: do not raise exception if key is missing. + source_key: key of the item with the reference shape. + allowed_shape_difference: raises error if shapes are different more than this value in any dimension, + otherwise corrects for the shape mismatch using nearest interpolation. + + """ + super().__init__(keys=keys, allow_missing_keys=allow_missing_keys) + self.source_key = source_key + self.allowed_shape_difference = allowed_shape_difference + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + image_shape = d[self.source_key].shape[1:] + for key in self.key_iterator(d): + label_shape = d[key].shape[1:] + if label_shape != image_shape: + if np.allclose(list(label_shape), list(image_shape), atol=self.allowed_shape_difference): + warnings.warn( + f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}," + f"the meta-data was not updated." + ) + d[key] = torch.nn.functional.interpolate( + input=d[key].unsqueeze(0), + size=image_shape, + mode="nearest-exact" if pytorch_after(1, 11) else "nearest", + ).squeeze(0) + else: + raise ValueError(f"The {key} shape {label_shape} is different from the source shape {image_shape}.") + return d diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index d6612de02c..f1f5fc41c0 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -16,7 +16,6 @@ import numpy as np import torch -import torch.nn.functional as F from monai.apps.utils import get_logger from monai.auto3dseg.operations import Operations, SampleOperations, SummaryOperations @@ -33,7 +32,7 @@ from monai.data import MetaTensor, affine_to_spacing from monai.transforms.transform import MapTransform from monai.transforms.utils_pytorch_numpy_unification import sum, unique -from monai.utils import convert_to_numpy, pytorch_after +from monai.utils import convert_to_numpy from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys from monai.utils.misc import ImageMetaKey, label_union @@ -326,16 +325,7 @@ def __call__(self, data) -> dict: ndas_label = d[self.label_key] # (H,W,D) if ndas_label.shape != ndas[0].shape: - # if image and label shapes are different, check if they are close - if np.allclose(list(ndas_label.shape), list(ndas[0].shape), atol=10): - logger.info(f" Label shape {ndas_label.shape} is slightly different from image shape {ndas[0].shape}") - ndas_label = F.interpolate( - input=ndas_label.unsqueeze(0).unsqueeze(0), - size=list(ndas[0].shape), - mode="nearest-exact" if pytorch_after(1, 11) else "nearest", - )[0, 0] - else: - raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") + raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] @@ -465,16 +455,7 @@ def __call__(self, data): ndas_label = d[self.label_key] # (H,W,D) if ndas_label.shape != ndas[0].shape: - # if image and label shapes are different, check if they are close - if np.allclose(list(ndas_label.shape), list(ndas[0].shape), atol=10): - logger.info(f" Label shape {ndas_label.shape} is slightly different from image shape {ndas[0].shape}") - ndas_label = F.interpolate( - input=ndas_label.unsqueeze(0).unsqueeze(0), - size=list(ndas[0].shape), - mode="nearest-exact" if pytorch_after(1, 11) else "nearest", - )[0, 0] - else: - raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") + raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds] diff --git a/tests/test_spacing.py b/tests/test_spacing.py index ba44bf76f2..90fc05b40f 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -20,7 +20,7 @@ from monai.data.utils import affine_to_spacing from monai.transforms import Spacing from monai.utils import fall_back_tuple -from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose +from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, skip_if_quick TESTS = [] for device in TEST_DEVICES: @@ -261,6 +261,7 @@ TEST_INVERSE.append([*d, recompute, align, scale_extent]) +@skip_if_quick class TestSpacingCase(unittest.TestCase): @parameterized.expand(TESTS) def test_spacing(self, init_param, img, affine, data_param, expected_output, device):