Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@

import hashlib
import json
import logging
import math
import os
import pickle
import warnings
from collections import defaultdict
from collections import abc, defaultdict
from copy import deepcopy
from functools import reduce
from itertools import product, starmap, zip_longest
Expand Down Expand Up @@ -254,6 +255,70 @@ def get_valid_patch_size(image_size: Sequence[int], patch_size: Union[Sequence[i
return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_))


def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
    """
    Recursively run collate logic and provide detailed loggings for debugging purposes.
    It reports results at the 'critical' level, is therefore suitable in the context of exception handling.

    Args:
        batch: batch input to collate
        level: current level of recursion for logging purposes
        logger_name: name of logger to use for logging

    See also: https://pytorch.org/docs/stable/data.html#working-with-collate-fn
    """
    logger = logging.getLogger(logger_name)  # hoisted: the same logger is used for every message below
    elem = batch[0]
    elem_type = type(elem)
    l_str = ">" * level  # marks the recursion depth in each log line
    batch_str = f"{batch[:10]}{' ... ' if len(batch) > 10 else ''}"
    if isinstance(elem, torch.Tensor):
        try:
            logger.critical(f"{l_str} collate/stack a list of tensors")
            return torch.stack(batch, 0)
        except TypeError as e:
            # mixed element types in the batch, e.g. a tensor next to a numpy array
            logger.critical(f"{l_str} E: {e}, type {[type(item).__name__ for item in batch]} in collate({batch_str})")
            return
        except RuntimeError as e:
            # typically tensors with inconsistent shapes that cannot be stacked
            logger.critical(f"{l_str} E: {e}, shape {[item.shape for item in batch]} in collate({batch_str})")
            return
    elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"):
        if elem_type.__name__ in ["ndarray", "memmap"]:
            logger.critical(f"{l_str} collate/stack a list of numpy arrays")
            # convert to tensors and retry at the same level (not a deeper recursion)
            return dev_collate([torch.as_tensor(b) for b in batch], level=level, logger_name=logger_name)
        elif elem.shape == ():  # scalars
            return batch
    elif isinstance(elem, (float, int, str, bytes)):
        return batch
    elif isinstance(elem, abc.Mapping):
        out = {}
        for key in elem:
            logger.critical(f'{l_str} collate dict key "{key}" out of {len(elem)} keys')
            out[key] = dev_collate([d[key] for d in batch], level=level + 1, logger_name=logger_name)
        return out
    elif isinstance(elem, abc.Sequence):
        try:
            sizes = [len(item) for item in batch]  # may not have `len`
        except TypeError:
            types = [type(item).__name__ for item in batch]
            logger.critical(f"{l_str} E: type {types} in collate({batch_str})")
            return
        logger.critical(f"{l_str} collate list of sizes: {sizes}.")
        if any(s != sizes[0] for s in sizes):
            logger.critical(f"{l_str} collate list inconsistent sizes, got size: {sizes}, in collate({batch_str})")
        # NOTE: zip silently truncates to the shortest element; acceptable here since the goal is diagnosis
        transposed = zip(*batch)
        return [dev_collate(samples, level=level + 1, logger_name=logger_name) for samples in transposed]
    logger.critical(f"{l_str} E: unsupported type in collate {batch_str}.")
    return


def list_data_collate(batch: Sequence):
"""
Enhancement for PyTorch DataLoader default collate.
Expand All @@ -268,12 +333,11 @@ def list_data_collate(batch: Sequence):
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
key = None
try:
elem = batch[0]
if isinstance(elem, Mapping):
ret = {}
for k in elem:
key = k
ret[k] = default_collate([d[k] for d in data])
ret[key] = default_collate([d[key] for d in data])
return ret
return default_collate(data)
except RuntimeError as re:
Expand All @@ -286,6 +350,7 @@ def list_data_collate(batch: Sequence):
+ "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its "
+ "documentation)."
)
_ = dev_collate(data)
raise RuntimeError(re_str) from re
except TypeError as re:
re_str = str(re)
Expand All @@ -297,6 +362,7 @@ def list_data_collate(batch: Sequence):
+ "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem "
+ "(check its documentation)."
)
_ = dev_collate(data)
raise TypeError(re_str) from re


Expand Down
45 changes: 45 additions & 0 deletions tests/test_dev_collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data.utils import dev_collate

TEST_CASES = [
[
[
{"img": 2, "meta": {"shape": [torch.tensor(1.0)]}},
{"img": 3, "meta": {"shape": [np.asarray(1.0)]}},
{"img": 4, "meta": {"shape": [torch.tensor(1.0)]}},
],
"got numpy.ndarray",
],
[[["img", np.array([2])], ["img", np.array([3, 4])], ["img", np.array([4])]], "size"],
[[["img", [2]], ["img", [3, 4]], ["img", 4]], "type"],
[[["img", [2, 2]], ["img", [3, 4]], ["img", 4]], "type"],
]


class DevCollateTest(unittest.TestCase):
    """Check that ``dev_collate`` emits the expected diagnostics for problematic batches."""

    @parameterized.expand(TEST_CASES)
    def test_dev_collate(self, batch_data, expected_msg):
        # dev_collate reports at CRITICAL level; capture the records and
        # verify the expected pattern appears in the combined output.
        with self.assertLogs(level=logging.CRITICAL) as captured:
            dev_collate(batch_data)
        self.assertRegex(" ".join(captured.output), f"{expected_msg}")


# Allow running this test module directly: `python tests/test_dev_collate.py`
if __name__ == "__main__":
    unittest.main()