diff --git a/docs/source/data.rst b/docs/source/data.rst index c528deeacc..0ab64edb7b 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -200,6 +200,8 @@ DatasetSummary Decathlon Datalist ~~~~~~~~~~~~~~~~~~ .. autofunction:: monai.data.load_decathlon_datalist +.. autofunction:: monai.data.load_decathlon_properties +.. autofunction:: monai.data.check_missing_files DataLoader diff --git a/monai/data/__init__.py b/monai/data/__init__.py index df436ba667..fafe282358 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -24,7 +24,7 @@ ZipDataset, ) from .dataset_summary import DatasetSummary -from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties +from .decathlon_datalist import check_missing_files, load_decathlon_datalist, load_decathlon_properties from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 78440fe11c..40f51f0e75 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -13,6 +13,7 @@ import os from typing import Dict, List, Optional, Sequence, Union, overload +from monai.config import KeysCollection from monai.utils import ensure_tuple @@ -148,3 +149,40 @@ def load_decathlon_properties(data_property_file_path: str, property_keys: Union raise KeyError(f"key {key} is not in the data property file.") properties[key] = json_data[key] return properties + + +def check_missing_files( + datalist: List[Dict], keys: KeysCollection, root_dir: Optional[str] = None, allow_missing_keys: bool = False +): + """Checks whether some files in the Decathlon datalist are missing. + It would be helpful to check missing files before a heavy training run. + + Args: + datalist: a list of data items, every item is a dictionary. + ususally generated by `load_decathlon_datalist` API. + keys: expected keys to check in the datalist. + root_dir: if not None, provides the root dir for the relative file paths in `datalist`. + allow_missing_keys: whether allow missing keys in the datalist items. + if False, raise exception if missing. default to False. + + Returns: + A list of missing filenames. + + """ + missing_files = [] + for item in datalist: + for k in ensure_tuple(keys): + if k not in item: + if not allow_missing_keys: + raise ValueError(f"key `{k}` is missing in the datalist item: {item}") + continue + + for f in ensure_tuple(item[k]): + if not isinstance(f, str): + raise ValueError(f"filepath of key `{k}` must be a string or a list of strings, but got: {f}.") + if isinstance(root_dir, str): + f = os.path.join(root_dir, f) + if not os.path.exists(f): + missing_files.append(f) + + return missing_files diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 622a4865d1..bf8fe39ec6 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -627,6 +627,8 @@ def get_data(self, img): It computes `spatial_shape` and stores it in meta dict. When loading a list of files, they are stacked together at a new dimension as the first dimension, and the meta data of the first image is used to represent the output meta data. + Note that it will switch axis 0 and 1 after loading the array because the `HW` definition in PIL + is different from other common medical packages. Args: img: a PIL Image object loaded from a file or a list of PIL Image objects. diff --git a/tests/min_tests.py b/tests/min_tests.py index ed48f6986f..05bd6781c1 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -147,6 +147,7 @@ def run_testsuit(): "test_handler_mlflow", "test_prepare_batch_extra_input", "test_prepare_batch_default", + "test_check_missing_files", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_check_missing_files.py b/tests/test_check_missing_files.py new file mode 100644 index 0000000000..ff7a43b3c6 --- /dev/null +++ b/tests/test_check_missing_files.py @@ -0,0 +1,55 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np + +from monai.data import check_missing_files + + +class TestCheckMissingFiles(unittest.TestCase): + def test_content(self): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + + datalist = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": [os.path.join(tempdir, "test_label1.nii.gz"), os.path.join(tempdir, "test_extra1.nii.gz")], + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label_missing.nii.gz"), + }, + ] + + missings = check_missing_files(datalist=datalist, keys=["image", "label"]) + self.assertEqual(len(missings), 1) + self.assertEqual(missings[0], os.path.join(tempdir, "test_label_missing.nii.gz")) + + # test with missing key and relative path + datalist = [{"image": "test_image1.nii.gz", "label": "test_label_missing.nii.gz"}] + missings = check_missing_files( + datalist=datalist, keys=["image", "label", "test"], root_dir=tempdir, allow_missing_keys=True + ) + self.assertEqual(missings[0], os.path.join(tempdir, "test_label_missing.nii.gz")) + + +if __name__ == "__main__": + unittest.main()