30 changes: 21 additions & 9 deletions monai/data/dataset_summary.py
@@ -15,8 +15,11 @@
import numpy as np
import torch

from monai.config import KeysCollection
from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.transforms import concatenate
from monai.utils import convert_data_type


class DatasetSummary:
@@ -38,6 +41,7 @@ def __init__(
dataset: Dataset,
image_key: Optional[str] = "image",
label_key: Optional[str] = "label",
meta_key: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
num_workers: int = 0,
**kwargs,
@@ -47,39 +51,43 @@ def __init__(
dataset: dataset from which to load the data.
image_key: key name of images (default: ``image``).
label_key: key name of labels (default: ``label``).
meta_key: explicitly indicate the key of the corresponding meta data dictionary.
for example, for data with key `image`, the meta data by default is in `image_meta_dict`.
the meta data is a dictionary object which contains: filename, affine, original_shape, etc.
if None, will try to construct the meta key as `{image_key}_{meta_key_postfix}`.
meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict;
the meta data is a dictionary object (default: ``meta_dict``).
num_workers: how many subprocesses to use for data loading.
``0`` means that the data will be loaded in the main process (default: ``0``).
kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``).
kwargs: other parameters (except `batch_size` and `num_workers`) for DataLoader;
this class forces ``batch_size=1`` and passes `num_workers` explicitly.

"""

self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)

self.image_key = image_key
self.label_key = label_key
if image_key:
self.meta_key = f"{image_key}_{meta_key_postfix}"
self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}"
self.all_meta_data: List = []

def collect_meta_data(self):
"""
This function is used to collect the meta data for all images of the dataset.
"""
if not self.meta_key:
raise ValueError("To collect meta data for the dataset, `meta_key` should exist.")

for data in self.data_loader:
if self.meta_key not in data:
raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.")
self.all_meta_data.append(data[self.meta_key])
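For orientation, here is a minimal usage sketch of the resolution order above: an explicit `meta_key` wins, otherwise the key falls back to `{image_key}_{meta_key_postfix}`. The file names below are hypothetical.

```python
from monai.data import Dataset, DatasetSummary
from monai.transforms import LoadImaged

# hypothetical NIfTI paths
data = [{"image": "img0.nii.gz", "label": "seg0.nii.gz"}]
ds = Dataset(data=data, transform=LoadImaged(keys=["image", "label"]))

# default: meta data is looked up under "image_meta_dict"
summary = DatasetSummary(ds)

# explicit: meta_key overrides the `{image_key}_{meta_key_postfix}` fallback
summary = DatasetSummary(ds, meta_key="image_meta_dict")
summary.collect_meta_data()  # appends one meta dict per image to all_meta_data
```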

def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
"""
Calculate the target spacing according to all spacings.
If the target spacing is very anisotropic,
decrease the spacing value of the maximum axis according to percentile.
So far, this function only supports NIFTI images which store spacings in headers with key "pixdim". After loading
with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.
So far, this function only supports NIFTI images which store spacings in headers with key "pixdim".
After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.

Args:
spacing_key: key of spacing in meta data (default: ``pixdim``).
@@ -92,8 +100,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
self.collect_meta_data()
if spacing_key not in self.all_meta_data[0]:
raise ValueError("The provided spacing_key is not in self.all_meta_data.")

all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()
all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0)
all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)

target_spacing = np.median(all_spacings, axis=0)
if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
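To make the heuristic concrete, below is an illustrative standalone re-implementation of the behavior the docstring describes (median spacing per axis; if the max/min ratio reaches `anisotropic_threshold`, the coarsest axis is lowered to the given percentile). It is a sketch of the documented logic, not a call into MONAI.

```python
import numpy as np

def sketch_target_spacing(all_spacings, anisotropic_threshold=3, percentile=10.0):
    target = np.median(all_spacings, axis=0)  # per-axis median spacing
    if max(target) / min(target) >= anisotropic_threshold:
        largest = int(np.argmax(target))  # most anisotropic axis
        target[largest] = np.percentile(all_spacings[:, largest], percentile)
    return tuple(float(s) for s in target)

spacings = np.array([[1.0, 1.0, 4.0], [1.0, 1.0, 5.0], [1.0, 1.0, 6.0]])
# medians are (1, 1, 5): the ratio 5 >= 3 triggers the adjustment, so the
# z spacing drops to the 10th percentile of [4, 5, 6], i.e. 4.2
print(sketch_target_spacing(spacings))  # (1.0, 1.0, 4.2)
```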
@@ -126,6 +134,8 @@ def calculate_statistics(self, foreground_threshold: int = 0):
image, label = data[self.image_key], data[self.label_key]
else:
image, label = data
image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

voxel_max.append(image.max().item())
voxel_min.append(image.min().item())
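`convert_data_type` (from `monai.utils`) is what makes the statistics robust to numpy inputs: it returns the converted data plus the original type and device, which the code above discards with `*_`. A small sketch:

```python
import numpy as np
import torch
from monai.utils import convert_data_type

img = np.ones((1, 8, 8), dtype=np.float32)
img_t, orig_type, orig_device = convert_data_type(data=img, output_type=torch.Tensor)
print(type(img_t))   # <class 'torch.Tensor'>
print(orig_type)     # <class 'numpy.ndarray'>
print(orig_device)   # None (the input was not a torch.Tensor)
```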
@@ -169,6 +179,8 @@ def calculate_percentiles(
image, label = data[self.image_key], data[self.label_key]
else:
image, label = data
image, *_ = convert_data_type(data=image, output_type=torch.Tensor)
label, *_ = convert_data_type(data=label, output_type=torch.Tensor)

intensities = image[torch.where(label > foreground_threshold)].tolist()
if sampling_flag:
18 changes: 15 additions & 3 deletions tests/test_dataset_summary.py
@@ -22,6 +22,15 @@
from monai.utils import set_determinism


def test_collate(batch):
    # customized collate: stack ndarrays along a new batch axis, recurse into dicts
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, np.ndarray):
        return np.stack(batch, 0)
    elif isinstance(elem, dict):
        return elem_type({key: test_collate([d[key] for d in batch]) for key in elem})
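A quick hypothetical check of this collate: ndarrays gain a leading batch axis, nested dicts are collated per key, and unsupported types (such as filename strings) implicitly collate to None, which is harmless here because only the spacing array is consumed downstream.

```python
import numpy as np

batch = [
    {"pixdim": np.ones(8), "filename_or_obj": "a.nii.gz"},
    {"pixdim": np.ones(8), "filename_or_obj": "b.nii.gz"},
]
collated = test_collate(batch)
assert collated["pixdim"].shape == (2, 8)   # stacked along a new batch axis
assert collated["filename_or_obj"] is None  # strings fall through to None
```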


class TestDatasetSummary(unittest.TestCase):
def test_spacing_intensity(self):
set_determinism(seed=0)
@@ -40,9 +49,12 @@ def test_spacing_intensity(self):
{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
]

dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))
dataset = Dataset(
data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"])
)

calculator = DatasetSummary(dataset, num_workers=4)
# test **kwargs of `DatasetSummary` for `DataLoader`
calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate)

target_spacing = calculator.get_target_spacing()
self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
@@ -74,7 +86,7 @@ def test_anisotropic_spacing(self):

dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))

calculator = DatasetSummary(dataset, num_workers=4)
calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix="meta_dict")

target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))