diff --git a/docs/source/data.rst b/docs/source/data.rst index 0ab64edb7b..e8c68de853 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -21,6 +21,12 @@ Generic Interfaces :members: :special-members: __next__ +`DatasetFunc` +~~~~~~~~~~~~~ +.. autoclass:: DatasetFunc + :members: + :special-members: __next__ + `ShuffleBuffer` ~~~~~~~~~~~~~~~ .. autoclass:: ShuffleBuffer diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e7fa2b3107..b12a307663 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -17,6 +17,7 @@ CacheNTransDataset, CSVDataset, Dataset, + DatasetFunc, LMDBDataset, NPZDictItemDataset, PersistentDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 38371f384b..ccd831ee0f 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -97,6 +97,56 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]): return self._transform(index) +class DatasetFunc(Dataset): + """ + Execute function on the input dataset and leverage the output to act as a new Dataset. + It can be used to load / fetch the basic dataset items, like the list of `image, label` paths. + Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc. + The `data` arg of `Dataset` will be applied to the first arg of callable `func`. + Usage example:: + + data_list = DatasetFunc( + data="path to file", + func=monai.data.load_decathlon_datalist, + data_list_key="validation", + base_dir="path to base dir", + ) + # partition dataset for every rank + data_partition = DatasetFunc( + data=data_list, + func=lambda x, **kwargs: monai.data.partition_dataset(data=x, **kwargs)[torch.distributed.get_rank()], + num_partitions=torch.distributed.get_world_size(), + ) + dataset = Dataset(data=data_partition, transform=transforms) + + Args: + data: input data for the func to process, will be applied to `func` as the first arg. + func: callable function to generate dataset items.
+ kwargs: other arguments for the `func` except for the first arg. + + """ + + def __init__(self, data: Any, func: Callable, **kwargs) -> None: + super().__init__(data=None, transform=None) # type:ignore + self.src = data + self.func = func + self.kwargs = kwargs + self.reset() + + def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs): + """ + Reset the dataset items with specified `func`. + + Args: + data: if not None, execute `func` on it, default to `self.src`. + func: if not None, execute `func` with specified `kwargs`; default to `self.func` with `self.kwargs`. + kwargs: other arguments for the `func` except for the first arg. + + """ + src = self.src if data is None else data + self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs) + + class PersistentDataset(Dataset): """ Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data, diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py new file mode 100644 index 0000000000..b3f6b95403 --- /dev/null +++ b/tests/test_dataset_func.py @@ -0,0 +1,52 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import json +import os +import tempfile +import unittest + +from monai.data import Dataset, DatasetFunc, load_decathlon_datalist, partition_dataset + + +class TestDatasetFunc(unittest.TestCase): + def test_seg_values(self): + with tempfile.TemporaryDirectory() as tempdir: + # prepare test datalist file + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"}, + {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"}, + ], + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + + data_list = DatasetFunc( + data=file_path, func=load_decathlon_datalist, data_list_key="training", base_dir=tempdir + ) + # partition dataset for train / validation + data_partition = DatasetFunc( + data=data_list, func=lambda x, **kwargs: partition_dataset(x, **kwargs)[0], num_partitions=2 + ) + dataset = Dataset(data=data_partition, transform=None) + self.assertEqual(dataset[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(dataset[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index ac2118d99f..0fcda21feb 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -35,7 +35,7 @@ def test_scaling(self): scaler = ScaleIntensityRangePercentilesd( keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max ) - assert_allclose(p(expected), scaler(data)["img"]) + assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4) def test_relative_scaling(self): img = self.imt