def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
    """
    Recursively apply collate logic to `batch`, emitting detailed logs for debugging.

    Every message is reported at the 'critical' level, so this helper is suitable
    to call from within exception handlers, where less severe records may be
    filtered out by the logging configuration.

    Args:
        batch: batch input to collate
        level: current level of recursion for logging purposes
        logger_name: name of logger to use for logging

    See also: https://pytorch.org/docs/stable/data.html#working-with-collate-fn
    """
    logger = logging.getLogger(logger_name)
    first = batch[0]
    first_type = type(first)
    l_str = ">" * level  # depth marker prefixed to every log line
    # preview at most 10 elements of the batch in the diagnostic messages
    batch_str = f"{batch[:10]}{' ... ' if len(batch) > 10 else ''}"
    if isinstance(first, torch.Tensor):
        try:
            logger.critical(f"{l_str} collate/stack a list of tensors")
            return torch.stack(batch, 0)
        except TypeError as e:
            # stacking failed because the elements are of mixed types
            logger.critical(f"{l_str} E: {e}, type {[type(item).__name__ for item in batch]} in collate({batch_str})")
            return
        except RuntimeError as e:
            # stacking failed, typically because of inconsistent tensor shapes
            logger.critical(f"{l_str} E: {e}, shape {[item.shape for item in batch]} in collate({batch_str})")
            return
    elif first_type.__module__ == "numpy" and first_type.__name__ not in ("str_", "string_"):
        if first_type.__name__ in ["ndarray", "memmap"]:
            # convert the arrays to tensors and retry at the same recursion level
            logger.critical(f"{l_str} collate/stack a list of numpy arrays")
            return dev_collate([torch.as_tensor(b) for b in batch], level=level, logger_name=logger_name)
        elif first.shape == ():  # numpy scalars need no collation
            return batch
    elif isinstance(first, (float, int, str, bytes)):
        # builtin scalars and strings are returned untouched
        return batch
    elif isinstance(first, abc.Mapping):
        collated = {}
        for key in first:
            logger.critical(f'{l_str} collate dict key "{key}" out of {len(first)} keys')
            collated[key] = dev_collate([d[key] for d in batch], level=level + 1, logger_name=logger_name)
        return collated
    elif isinstance(first, abc.Sequence):
        items = list(batch)
        try:
            sizes = [len(item) for item in items]  # may not have `len`
        except TypeError:
            type_names = [type(item).__name__ for item in items]
            logger.critical(f"{l_str} E: type {type_names} in collate({batch_str})")
            return
        logger.critical(f"{l_str} collate list of sizes: {sizes}.")
        if len(set(sizes)) > 1:
            # warn but continue: zip below silently truncates to the shortest element
            logger.critical(f"{l_str} collate list inconsistent sizes, got size: {sizes}, in collate({batch_str})")
        return [dev_collate(samples, level=level + 1, logger_name=logger_name) for samples in zip(*batch)]
    logger.critical(f"{l_str} E: unsupported type in collate {batch_str}.")
    return
+ +import logging +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.utils import dev_collate + +TEST_CASES = [ + [ + [ + {"img": 2, "meta": {"shape": [torch.tensor(1.0)]}}, + {"img": 3, "meta": {"shape": [np.asarray(1.0)]}}, + {"img": 4, "meta": {"shape": [torch.tensor(1.0)]}}, + ], + "got numpy.ndarray", + ], + [[["img", np.array([2])], ["img", np.array([3, 4])], ["img", np.array([4])]], "size"], + [[["img", [2]], ["img", [3, 4]], ["img", 4]], "type"], + [[["img", [2, 2]], ["img", [3, 4]], ["img", 4]], "type"], +] + + +class DevCollateTest(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_dev_collate(self, inputs, msg): + with self.assertLogs(level=logging.CRITICAL) as log: + dev_collate(inputs) + self.assertRegex(" ".join(log.output), f"{msg}") + + +if __name__ == "__main__": + unittest.main()