From 59f067e17ee299a809c43c674e20173315f1f286 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 19 Jan 2022 23:49:58 +0000 Subject: [PATCH 1/2] add dev mode collate Signed-off-by: Wenqi Li --- monai/data/utils.py | 60 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 0f7194ec39..d4b54e3fab 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -11,11 +11,12 @@ import hashlib import json +import logging import math import os import pickle import warnings -from collections import defaultdict +from collections import abc, defaultdict from copy import deepcopy from functools import reduce from itertools import product, starmap, zip_longest @@ -254,6 +255,55 @@ def get_valid_patch_size(image_size: Sequence[int], patch_size: Union[Sequence[i return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_)) +def dev_collate(batch, level=1): + """collate with detailed error messages for debugging purposes.""" + + elem = batch[0] + elem_type = type(elem) + l_str = ">" * level + batch_str = f"{batch[:10]}{' ... ' if len(batch) > 10 else ''}" + if isinstance(elem, torch.Tensor): + try: + logging.getLogger("dev_collate").critical(f"{l_str} collate/stack a list of tensors") + return torch.stack(batch, 0) + except TypeError as e: + logging.getLogger("dev_collate").critical( + f"{l_str} E: {e}, while stacking {[type(elem).__name__ for elem in batch]} in collate({batch_str})" + ) + return + except RuntimeError as e: + logging.getLogger("dev_collate").critical( + f"{l_str} E: {e}, while stacking {[elem.shape for elem in batch]} in collate({batch_str})" + ) + return + elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": + if elem_type.__name__ in ["ndarray", "memmap"]: + logging.getLogger("dev_collate").critical(f"{l_str} collate/stack a list of numpy arrays") + return dev_collate([torch.as_tensor(b) for b in batch], level=level) + elif elem.shape == (): # scalars + return batch + elif isinstance(elem, (float, int, str, bytes)): + return batch + elif isinstance(elem, abc.Mapping): + out = {} + for key in elem: + logging.getLogger("dev_collate").critical(f'{l_str} collate dict key "{key}" out of {len(elem)} keys') + out[key] = dev_collate([d[key] for d in batch], level=level + 1) + return out + elif isinstance(elem, abc.Sequence): + it = iter(batch) + sizes = [len(elem) for elem in it] + logging.getLogger("dev_collate").critical(f"{l_str} collate list of sizes: {sizes}.") + if any(s != sizes[0] for s in sizes): + logging.getLogger("dev_collate").critical( + f"{l_str} collate list inconsistent sizes, " f"got size: {sizes}, in collate({batch_str})" + ) + transposed = zip(*batch) + return [dev_collate(samples, level=level + 1) for samples in transposed] + logging.getLogger("dev_collate").critical(f"{l_str} E: unsupported type in collate {batch_str}.") + return + + def list_data_collate(batch: Sequence): """ Enhancement for PyTorch DataLoader default collate. @@ -266,14 +316,14 @@ def list_data_collate(batch: Sequence): """ elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch - key = None + key, data_k = None, None try: - elem = batch[0] if isinstance(elem, Mapping): ret = {} for k in elem: key = k - ret[k] = default_collate([d[k] for d in data]) + data_k = [d[key] for d in data] + ret[key] = default_collate(data_k) return ret return default_collate(data) except RuntimeError as re: @@ -286,6 +336,7 @@ def list_data_collate(batch: Sequence): + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " + "documentation)." ) + _ = dev_collate(data_k) if data_k is not None else dev_collate(data) raise RuntimeError(re_str) from re except TypeError as re: re_str = str(re) @@ -297,6 +348,7 @@ def list_data_collate(batch: Sequence): + "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem " + "(check its documentation)." ) + _ = dev_collate(data_k) if data_k is not None else dev_collate(data) raise TypeError(re_str) from re From b6df1ad4efd80cab32035ff430acd397df4a5088 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 21 Jan 2022 17:21:03 +0000 Subject: [PATCH 2/2] update based on comments Signed-off-by: Wenqi Li --- monai/data/utils.py | 48 +++++++++++++++++++++++++-------------- tests/test_dev_collate.py | 45 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 tests/test_dev_collate.py diff --git a/monai/data/utils.py b/monai/data/utils.py index d4b54e3fab..64291fe56d 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -255,31 +255,39 @@ def get_valid_patch_size(image_size: Sequence[int], patch_size: Union[Sequence[i return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_)) -def dev_collate(batch, level=1): - """collate with detailed error messages for debugging purposes.""" +def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): + """ + Recursively run collate logic and provide detailed loggings for debugging purposes. + + Args: + batch: batch input to collate + level: current level of recursion for logging purposes + logger_name: name of logger to use for logging + See also: https://pytorch.org/docs/stable/data.html#working-with-collate-fn + """ elem = batch[0] elem_type = type(elem) l_str = ">" * level batch_str = f"{batch[:10]}{' ... ' if len(batch) > 10 else ''}" if isinstance(elem, torch.Tensor): try: - logging.getLogger("dev_collate").critical(f"{l_str} collate/stack a list of tensors") + logging.getLogger(logger_name).critical(f"{l_str} collate/stack a list of tensors") return torch.stack(batch, 0) except TypeError as e: - logging.getLogger("dev_collate").critical( - f"{l_str} E: {e}, while stacking {[type(elem).__name__ for elem in batch]} in collate({batch_str})" + logging.getLogger(logger_name).critical( + f"{l_str} E: {e}, type {[type(elem).__name__ for elem in batch]} in collate({batch_str})" ) return except RuntimeError as e: - logging.getLogger("dev_collate").critical( - f"{l_str} E: {e}, while stacking {[elem.shape for elem in batch]} in collate({batch_str})" + logging.getLogger(logger_name).critical( + f"{l_str} E: {e}, shape {[elem.shape for elem in batch]} in collate({batch_str})" ) return elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": if elem_type.__name__ in ["ndarray", "memmap"]: - logging.getLogger("dev_collate").critical(f"{l_str} collate/stack a list of numpy arrays") - return dev_collate([torch.as_tensor(b) for b in batch], level=level) + logging.getLogger(logger_name).critical(f"{l_str} collate/stack a list of numpy arrays") + return dev_collate([torch.as_tensor(b) for b in batch], level=level, logger_name=logger_name) elif elem.shape == (): # scalars return batch elif isinstance(elem, (float, int, str, bytes)): @@ -287,20 +295,26 @@ def dev_collate(batch, level=1): elif isinstance(elem, abc.Mapping): out = {} for key in elem: - logging.getLogger("dev_collate").critical(f'{l_str} collate dict key "{key}" out of {len(elem)} keys') - out[key] = dev_collate([d[key] for d in batch], level=level + 1) + logging.getLogger(logger_name).critical(f'{l_str} collate dict key "{key}" out of {len(elem)} keys') + out[key] = dev_collate([d[key] for d in batch], level=level + 1, logger_name=logger_name) return out elif isinstance(elem, abc.Sequence): it = iter(batch) - sizes = [len(elem) for elem in it] - logging.getLogger("dev_collate").critical(f"{l_str} collate list of sizes: {sizes}.") + els = list(it) + try: + sizes = [len(elem) for elem in els] # may not have `len` + except TypeError: + types = [type(elem).__name__ for elem in els] + logging.getLogger(logger_name).critical(f"{l_str} E: type {types} in collate({batch_str})") + return + logging.getLogger(logger_name).critical(f"{l_str} collate list of sizes: {sizes}.") if any(s != sizes[0] for s in sizes): - logging.getLogger("dev_collate").critical( - f"{l_str} collate list inconsistent sizes, " f"got size: {sizes}, in collate({batch_str})" + logging.getLogger(logger_name).critical( + f"{l_str} collate list inconsistent sizes, got size: {sizes}, in collate({batch_str})" ) transposed = zip(*batch) - return [dev_collate(samples, level=level + 1) for samples in transposed] - logging.getLogger("dev_collate").critical(f"{l_str} E: unsupported type in collate {batch_str}.") + return [dev_collate(samples, level=level + 1, logger_name=logger_name) for samples in transposed] + logging.getLogger(logger_name).critical(f"{l_str} E: unsupported type in collate {batch_str}.") return diff --git a/tests/test_dev_collate.py b/tests/test_dev_collate.py new file mode 100644 index 0000000000..83dbd71d28 --- /dev/null +++ b/tests/test_dev_collate.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.utils import dev_collate + +TEST_CASES = [ + [ + [ + {"img": 2, "meta": {"shape": [torch.tensor(1.0)]}}, + {"img": 3, "meta": {"shape": [np.asarray(1.0)]}}, + {"img": 4, "meta": {"shape": [torch.tensor(1.0)]}}, + ], + "got numpy.ndarray", + ], + [[["img", np.array([2])], ["img", np.array([3, 4])], ["img", np.array([4])]], "size"], + [[["img", [2]], ["img", [3, 4]], ["img", 4]], "type"], + [[["img", [2, 2]], ["img", [3, 4]], ["img", 4]], "type"], +] + + +class DevCollateTest(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_dev_collate(self, inputs, msg): + with self.assertLogs(level=logging.CRITICAL) as log: + dev_collate(inputs) + self.assertRegex(" ".join(log.output), f"{msg}") + + +if __name__ == "__main__": + unittest.main()