From cc889f2acea405e5b9dc55dc83728d477192973b Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Fri, 31 Dec 2021 16:33:40 +0800
Subject: [PATCH 1/3] [DLMED] update dataset summary

Signed-off-by: Nic Ma
---
 monai/data/dataset_summary.py | 20 +++++++++++++-------
 tests/test_dataset_summary.py |  8 +++++---
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py
index 43178d2536..610c14df7e 100644
--- a/monai/data/dataset_summary.py
+++ b/monai/data/dataset_summary.py
@@ -15,6 +15,7 @@
 import numpy as np
 import torch

+from monai.config import KeysCollection
 from monai.data.dataloader import DataLoader
 from monai.data.dataset import Dataset

@@ -38,6 +39,7 @@ def __init__(
         dataset: Dataset,
         image_key: Optional[str] = "image",
         label_key: Optional[str] = "label",
+        meta_key: Optional[KeysCollection] = None,
         meta_key_postfix: str = "meta_dict",
         num_workers: int = 0,
         **kwargs,
@@ -47,11 +49,16 @@ def __init__(
             dataset: dataset from which to load the data.
             image_key: key name of images (default: ``image``).
             label_key: key name of labels (default: ``label``).
+            meta_key: explicitly indicate the key of the corresponding meta data dictionary.
+                for example, for data with key `image`, the meta data is in `image_meta_dict` by default.
+                the meta data is a dictionary object which contains: filename, affine, original_shape, etc.
+                if None, will try to construct the meta key as `{image_key}_{meta_key_postfix}`.
             meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict,
                 the meta data is a dictionary object (default: ``meta_dict``).
             num_workers: how many subprocesses to use for data loading.
                 ``0`` means that the data will be loaded in the main process (default: ``0``).
-            kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``).
+            kwargs: other parameters (except `batch_size` and `num_workers`) for DataLoader;
+                this class forces ``batch_size=1``.

         """

@@ -59,18 +66,17 @@
         self.image_key = image_key
         self.label_key = label_key
-        if image_key:
-            self.meta_key = f"{image_key}_{meta_key_postfix}"
+        self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}"
         self.all_meta_data: List = []

     def collect_meta_data(self):
         """
         This function is used to collect the meta data for all images of the dataset.
         """
-        if not self.meta_key:
-            raise ValueError("To collect meta data for the dataset, `meta_key` should exist.")
         for data in self.data_loader:
+            if self.meta_key not in data:
+                raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.")
             self.all_meta_data.append(data[self.meta_key])

     def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
         """
         Calculate the target spacing according to all spacings.
         If the target spacing is very anisotropic,
         decrease the spacing value of the maximum axis according to percentile.
-        So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". After loading
-        with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.
+        So far, this function only supports NIFTI images which store spacings in headers with key "pixdim".
+        After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.

         Args:
             spacing_key: key of spacing in meta data (default: ``pixdim``).

diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
index 667dc3f190..5c89835b48 100644
--- a/tests/test_dataset_summary.py
+++ b/tests/test_dataset_summary.py
@@ -40,9 +40,11 @@ def test_spacing_intensity(self):
             {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
         ]

-        dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
+        dataset = Dataset(
+            data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"])
+        )

-        calculator = DatasetSummary(dataset, num_workers=4)
+        calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1")

         target_spacing = calculator.get_target_spacing()
         self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
@@ -74,7 +76,7 @@ def test_anisotropic_spacing(self):

         dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))

-        calculator = DatasetSummary(dataset, num_workers=4)
+        calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix="meta_dict")

         target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
         np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))

From 26e8909d6dfee32bce102f152cc3932f015bcd2b Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Fri, 31 Dec 2021 17:09:51 +0800
Subject: [PATCH 2/3] [DLMED] enhance data type

Signed-off-by: Nic Ma
---
 monai/data/dataset_summary.py | 10 ++++++++--
 tests/test_dataset_summary.py | 11 ++++++++++-
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py
index 610c14df7e..dd8a94143b 100644
--- a/monai/data/dataset_summary.py
+++ b/monai/data/dataset_summary.py
@@ -18,6 +18,8 @@
 from monai.config import KeysCollection
 from monai.data.dataloader import DataLoader
 from monai.data.dataset import Dataset
+from monai.transforms import concatenate
+from monai.utils import convert_data_type


 class DatasetSummary:
@@ -98,8 +100,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold
             self.collect_meta_data()
         if spacing_key not in self.all_meta_data[0]:
             raise ValueError("The provided spacing_key is not in self.all_meta_data.")
-
-        all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()
+        all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0)
+        all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)
         target_spacing = np.median(all_spacings, axis=0)

         if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
@@ -132,6 +134,8 @@ def calculate_statistics(self, foreground_threshold: int = 0):
                 image, label = data[self.image_key], data[self.label_key]
             else:
                 image, label = data
+            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
+            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

             voxel_max.append(image.max().item())
             voxel_min.append(image.min().item())
@@ -175,6 +179,8 @@ def calculate_percentiles(
                 image, label = data[self.image_key], data[self.label_key]
             else:
                 image, label = data
+            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
+            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

             intensities = image[torch.where(label > foreground_threshold)].tolist()
             if sampling_flag:
diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
index 5c89835b48..7724a150d1 100644
--- a/tests/test_dataset_summary.py
+++ b/tests/test_dataset_summary.py
@@ -44,7 +44,16 @@ def test_spacing_intensity(self):
             data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"])
         )

-        calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1")
+        # test **kwargs of `DatasetSummary` for `DataLoader`
+        def _test_collate(batch):
+            elem = batch[0]
+            elem_type = type(elem)
+            if isinstance(elem, np.ndarray):
+                return np.stack(batch, 0)
+            elif isinstance(elem, dict):
+                return elem_type({key: _test_collate([d[key] for d in batch]) for key in elem})
+
+        calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=_test_collate)

         target_spacing = calculator.get_target_spacing()
         self.assertEqual(target_spacing, (1.0, 1.0, 1.0))

From f1b25afb57e6aa7aae5d06817b06768125d4fa16 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Sat, 1 Jan 2022 06:32:33 +0800
Subject: [PATCH 3/3] [DLMED] fix pickle issue

Signed-off-by: Nic Ma
---
 tests/test_dataset_summary.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py
index 7724a150d1..5569c51a0c 100644
--- a/tests/test_dataset_summary.py
+++ b/tests/test_dataset_summary.py
@@ -22,6 +22,15 @@
 from monai.utils import set_determinism


+def test_collate(batch):
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, np.ndarray):
+        return np.stack(batch, 0)
+    elif isinstance(elem, dict):
+        return elem_type({key: test_collate([d[key] for d in batch]) for key in elem})
+
+
 class TestDatasetSummary(unittest.TestCase):
     def test_spacing_intensity(self):
         set_determinism(seed=0)
@@ -45,15 +54,7 @@ def test_spacing_intensity(self):
         )

         # test **kwargs of `DatasetSummary` for `DataLoader`
-        def _test_collate(batch):
-            elem = batch[0]
-            elem_type = type(elem)
-            if isinstance(elem, np.ndarray):
-                return np.stack(batch, 0)
-            elif isinstance(elem, dict):
-                return elem_type({key: _test_collate([d[key] for d in batch]) for key in elem})
-
-        calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=_test_collate)
+        calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate)

         target_spacing = calculator.get_target_spacing()
         self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
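
Usage notes (illustrative sketches, not part of the patches above):

The `meta_key`/`meta_key_postfix` behavior introduced in patch 1/3 can be
exercised as below; a minimal sketch, assuming two small NIfTI files on disk
(the file names are hypothetical):

    from monai.data import Dataset, DatasetSummary
    from monai.transforms import LoadImaged

    data_dicts = [{"image": "img0.nii.gz", "label": "seg0.nii.gz"}]  # hypothetical files

    # Default lookup: meta data is fetched from f"{image_key}_{meta_key_postfix}",
    # i.e. "image_meta_dict" here.
    dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
    calculator = DatasetSummary(dataset)

    # Explicit meta keys: pass `meta_key` so the summary reads the matching entry.
    dataset = Dataset(
        data=data_dicts,
        transform=LoadImaged(keys=["image", "label"], meta_keys=["meta1", "meta2"]),
    )
    calculator = DatasetSummary(dataset, meta_key="meta1")
    print(calculator.get_target_spacing())  # e.g. (1.0, 1.0, 1.0) for isotropic data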
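Patch 2/3 replaces `torch.cat(...).numpy()` with MONAI's backend-agnostic
helpers, so spacing collection works whether the collated meta data holds
`torch.Tensor` or `np.ndarray` values. A small sketch of the two helpers,
using the same keyword arguments as the patch (spacing values made up for
illustration):

    import numpy as np
    import torch

    from monai.transforms import concatenate
    from monai.utils import convert_data_type

    # spacing batches as a DataLoader with batch_size=1 would yield them
    spacings = [torch.tensor([[1.0, 1.0, 1.0]]), torch.tensor([[1.0, 1.0, 2.0]])]

    all_spacings = concatenate(to_cat=spacings, axis=0)  # stays a torch.Tensor here
    all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray)

    print(np.median(all_spacings, axis=0))  # [1.  1.  1.5]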
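Patch 3/3 moves the collate function from a closure inside the test method to
module level because `DataLoader(num_workers=4, collate_fn=...)` must pickle
the collate function when workers start via the `spawn` method (e.g. on
Windows or macOS), and local functions cannot be pickled by reference. A
standalone sketch of the failure mode (function names are illustrative):

    import pickle

    def module_level_collate(batch):
        # picklable: resolvable by its qualified name at import time
        return batch

    def make_local_collate():
        def _local_collate(batch):  # defined inside a function -> not picklable
            return batch
        return _local_collate

    pickle.dumps(module_level_collate)  # fine
    try:
        pickle.dumps(make_local_collate())
    except (AttributeError, pickle.PicklingError) as exc:
        print(f"cannot pickle a local collate_fn: {exc}")

One caveat with the chosen name: a pytest run may collect the module-level
`test_collate` as a test case because of its `test_` prefix; a name without
that prefix would avoid the accidental collection.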