Skip to content
Merged
4 changes: 4 additions & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ jobs:
python -m pip install -r requirements-dev.txt
python -m pip list
python setup.py develop # test no compile installation
shell: bash
- if: runner.os != 'windows'
name: Run compiled (${{ runner.os }})
run: |
python setup.py develop --uninstall
BUILD_MONAI=1 python setup.py develop # compile the cpp extensions
shell: bash
Expand Down
20 changes: 18 additions & 2 deletions monai/apps/auto3dseg/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import torch

from monai.apps.auto3dseg.transforms import EnsureSameShaped
from monai.apps.utils import get_logger
from monai.auto3dseg import SegSummarizer
from monai.auto3dseg.utils import datafold_read
Expand Down Expand Up @@ -69,9 +70,11 @@ class DataAnalyzer:
hist_range: ranges to compute histogram for each image channel.
fmt: format used to save the analysis results. Defaults to "yaml".
histogram_only: whether to only compute histograms. Defaults to False.
extra_params: other optional arguments. Currently supported arguments are :
'allowed_shape_difference' (default 5) can be used to change the default tolerance of
the allowed shape differences between the image and label items. In case of shape mismatch below
the tolerance, the label image will be resized to match the image using nearest interpolation.

Raises:
ValueError if device is GPU and worker > 0.

Examples:
.. code-block:: python
Expand Down Expand Up @@ -121,6 +124,7 @@ def __init__(
hist_range: Optional[list] = None,
fmt: Optional[str] = "yaml",
histogram_only: bool = False,
**extra_params,
):
if path.isfile(output_path):
warnings.warn(f"File {output_path} already exists and will be overwritten.")
Expand All @@ -139,6 +143,7 @@ def __init__(
self.hist_range: list = [-500, 500] if hist_range is None else hist_range
self.fmt = fmt
self.histogram_only = histogram_only
self.extra_params = extra_params

@staticmethod
def _check_data_uniformity(keys: List[str], result: Dict):
Expand Down Expand Up @@ -206,6 +211,17 @@ def get_all_case_stats(self, key="training", transform_list=None):
EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float),
Orientationd(keys=keys, axcodes="RAS"),
]
if self.label_key is not None:

allowed_shape_difference = self.extra_params.pop("allowed_shape_difference", 5)
transform_list.append(
EnsureSameShaped(
keys=self.label_key,
source_key=self.image_key,
allowed_shape_difference=allowed_shape_difference,
)
)

transform = Compose(transform_list)

files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key)
Expand Down
71 changes: 71 additions & 0 deletions monai/apps/auto3dseg/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Dict, Hashable, Mapping

import numpy as np
import torch

from monai.config import KeysCollection
from monai.networks.utils import pytorch_after
from monai.transforms import MapTransform


class EnsureSameShaped(MapTransform):
"""
Checks if segmentation label images (in keys) have the same spatial shape as the main image (in source_key),
and raise an error if the shapes are significantly different.
If the shapes are only slightly different (within an allowed_shape_difference in each dim), then resize the label using
nearest interpolation. This transform is designed to correct datasets with slight label shape mismatches.
Generally image and segmentation label must have the same spatial shape, however some public datasets are having slight
shape mismatches, which will cause potential crashes when calculating loss or metric functions.
"""

def __init__(
self,
keys: KeysCollection = "label",
allow_missing_keys: bool = False,
source_key: str = "image",
allowed_shape_difference: int = 5,
Comment thread
myron marked this conversation as resolved.
) -> None:
"""
Args:
keys: keys of the corresponding items to be compared to the source_key item shape.
allow_missing_keys: do not raise exception if key is missing.
source_key: key of the item with the reference shape.
allowed_shape_difference: raises error if shapes are different more than this value in any dimension,
otherwise corrects for the shape mismatch using nearest interpolation.

"""
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
self.source_key = source_key
self.allowed_shape_difference = allowed_shape_difference

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
image_shape = d[self.source_key].shape[1:]
for key in self.key_iterator(d):
label_shape = d[key].shape[1:]
if label_shape != image_shape:
if np.allclose(list(label_shape), list(image_shape), atol=self.allowed_shape_difference):
warnings.warn(
f"The {key} with shape {label_shape} was resized to match the source shape {image_shape},"
f"the meta-data was not updated."
)
d[key] = torch.nn.functional.interpolate(
input=d[key].unsqueeze(0),
size=image_shape,
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
).squeeze(0)
else:
raise ValueError(f"The {key} shape {label_shape} is different from the source shape {image_shape}.")
return d
25 changes: 3 additions & 22 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import numpy as np
import torch
import torch.nn.functional as F

from monai.apps.utils import get_logger
from monai.auto3dseg.operations import Operations, SampleOperations, SummaryOperations
Expand All @@ -33,7 +32,7 @@
from monai.data import MetaTensor, affine_to_spacing
from monai.transforms.transform import MapTransform
from monai.transforms.utils_pytorch_numpy_unification import sum, unique
from monai.utils import convert_to_numpy, pytorch_after
from monai.utils import convert_to_numpy
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
from monai.utils.misc import ImageMetaKey, label_union

Expand Down Expand Up @@ -326,16 +325,7 @@ def __call__(self, data) -> dict:
ndas_label = d[self.label_key] # (H,W,D)

if ndas_label.shape != ndas[0].shape:
# if image and label shapes are different, check if they are close
if np.allclose(list(ndas_label.shape), list(ndas[0].shape), atol=10):
logger.info(f" Label shape {ndas_label.shape} is slightly different from image shape {ndas[0].shape}")
ndas_label = F.interpolate(
input=ndas_label.unsqueeze(0).unsqueeze(0),
size=list(ndas[0].shape),
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
)[0, 0]
else:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
Expand Down Expand Up @@ -465,16 +455,7 @@ def __call__(self, data):
ndas_label = d[self.label_key] # (H,W,D)

if ndas_label.shape != ndas[0].shape:
# if image and label shapes are different, check if they are close
if np.allclose(list(ndas_label.shape), list(ndas[0].shape), atol=10):
logger.info(f" Label shape {ndas_label.shape} is slightly different from image shape {ndas[0].shape}")
ndas_label = F.interpolate(
input=ndas_label.unsqueeze(0).unsqueeze(0),
size=list(ndas[0].shape),
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
)[0, 0]
else:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
Expand Down
3 changes: 2 additions & 1 deletion tests/test_spacing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.data.utils import affine_to_spacing
from monai.transforms import Spacing
from monai.utils import fall_back_tuple
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, skip_if_quick

TESTS = []
for device in TEST_DEVICES:
Expand Down Expand Up @@ -261,6 +261,7 @@
TEST_INVERSE.append([*d, recompute, align, scale_extent])


@skip_if_quick
class TestSpacingCase(unittest.TestCase):
@parameterized.expand(TESTS)
def test_spacing(self, init_param, img, affine, data_param, expected_output, device):
Expand Down