diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index 6af46d11ea..785e8c7b88 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from itertools import chain from typing import List, Optional @@ -18,9 +19,10 @@ from monai.config import KeysCollection from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.data.meta_tensor import MetaTensor from monai.data.utils import affine_to_spacing from monai.transforms import concatenate -from monai.utils import PostFix, convert_data_type +from monai.utils import PostFix, convert_data_type, convert_to_tensor DEFAULT_POST_FIX = PostFix.meta() @@ -30,9 +32,9 @@ class DatasetSummary: This class provides a way to calculate a reasonable output voxel spacing according to the input dataset. The achieved values can used to resample the input in 3d segmentation tasks (like using as the `pixdim` parameter in `monai.transforms.Spacingd`). - In addition, it also supports to count the mean, std, min and max intensities of the input, + In addition, it also supports computing the mean, std, min and max intensities of the input, and these statistics are helpful for image normalization - (like using in `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`). + (as parameters of `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`). The algorithm for calculation refers to: `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. @@ -58,6 +60,7 @@ def __init__( for example, for data with key `image`, the metadata by default is in `image_meta_dict`. the metadata is a dictionary object which contains: filename, affine, original_shape, etc. if None, will try to construct meta_keys by `{image_key}_{meta_key_postfix}`. 
+ This is not required if `data[image_key]` is a MetaTensor. meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the metadata from dict, the metadata is a dictionary object (default: ``meta_dict``). num_workers: how many subprocesses to use for data loading. @@ -80,17 +83,22 @@ def collect_meta_data(self): """ for data in self.data_loader: - if self.meta_key not in data: - raise ValueError(f"To collect metadata for the dataset, key `{self.meta_key}` must exist in `data`.") - self.all_meta_data.append(data[self.meta_key]) + meta_dict = {} + if isinstance(data[self.image_key], MetaTensor): + meta_dict = data[self.image_key].meta + elif self.meta_key in data: + meta_dict = data[self.meta_key] + else: + warnings.warn(f"To collect metadata for the dataset, `{self.meta_key}` or `data.meta` must exist.") + self.all_meta_data.append(meta_dict) def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0): """ Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, decrease the spacing value of the maximum axis according to percentile. - So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". - After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. + The spacing is computed from `affine_to_spacing(data[spacing_key][0], 3)` if `data[spacing_key]` is a matrix, + otherwise, the `data[spacing_key]` must be a vector of pixdim values. Args: spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``). 
@@ -103,7 +110,15 @@ def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: self.collect_meta_data() if spacing_key not in self.all_meta_data[0]: raise ValueError("The provided spacing_key is not in self.all_meta_data.") - spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data] + spacings = [] + for data in self.all_meta_data: + spacing_vals = convert_to_tensor(data[spacing_key][0], track_meta=False, wrap_sequence=True) + if spacing_vals.ndim == 1: # vector + spacings.append(spacing_vals[:3][None]) + elif spacing_vals.ndim == 2: # matrix + spacings.append(affine_to_spacing(spacing_vals, 3)[None]) + else: + raise ValueError("data[spacing_key] must be a vector or a matrix.") all_spacings = concatenate(to_cat=spacings, axis=0) all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True) diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index d0531b28a0..a5b5eee28f 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -20,10 +20,8 @@ from monai.data import Dataset, DatasetSummary, create_test_image_3d from monai.transforms import LoadImaged from monai.transforms.compose import Compose -from monai.transforms.meta_utility.dictionary import FromMetaTensord from monai.transforms.utility.dictionary import ToNumpyd from monai.utils import set_determinism -from monai.utils.enums import PostFix def test_collate(batch): @@ -56,7 +54,6 @@ def test_spacing_intensity(self): t = Compose( [ LoadImaged(keys=["image", "label"]), - FromMetaTensord(keys=["image", "label"]), ToNumpyd(keys=["image", "label", "image_meta_dict", "label_meta_dict"]), ] ) @@ -65,7 +62,7 @@ def test_spacing_intensity(self): # test **kwargs of `DatasetSummary` for `DataLoader` calculator = DatasetSummary(dataset, num_workers=4, meta_key="image_meta_dict", collate_fn=test_collate) - target_spacing = calculator.get_target_spacing() + target_spacing = 
calculator.get_target_spacing(spacing_key="pixdim") self.assertEqual(target_spacing, (1.0, 1.0, 1.0)) calculator.calculate_statistics() np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5) @@ -93,10 +90,10 @@ def test_anisotropic_spacing(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - t = Compose([LoadImaged(keys=["image", "label"]), FromMetaTensord(keys=["image", "label"])]) + t = Compose([LoadImaged(keys=["image", "label"])]) dataset = Dataset(data=data_dicts, transform=t) - calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix=PostFix.meta()) + calculator = DatasetSummary(dataset, num_workers=4) target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0) np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))