Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu committed Oct 26, 2023
1 parent 85243f5 commit 13a914c
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from torch.utils.data._utils.collate import collate_tensor_fn, default_collate, default_collate_fn_map

from monai import config
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
Expand Down Expand Up @@ -444,29 +444,18 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
return data


def collate_meta_tensor(batch):
def collate_meta_tensor(batch, *, collate_fn_map=None):
"""collate a sequence of meta tensor sequences/dictionaries into
a single batched metatensor or a dictionary of batched metatensor"""
if not isinstance(batch, Sequence):
raise NotImplementedError()
elem_0 = first(batch)
if isinstance(elem_0, MetaObj):
collated = default_collate(batch)
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
if common_:
meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]
collated.meta = default_collate(meta_dicts)
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
return collated
if isinstance(elem_0, Mapping):
return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}
if isinstance(elem_0, (tuple, list)):
return [collate_meta_tensor([d[i] for d in batch]) for i in range(len(elem_0))]

# no more recursive search for MetaTensor
return default_collate(batch)
collated = collate_tensor_fn(batch)
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
if common_:
meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]
collated.meta = default_collate(meta_dicts)
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
return collated


def list_data_collate(batch: Sequence):
Expand All @@ -479,6 +468,9 @@ def list_data_collate(batch: Sequence):
Need to use this collate if apply some transforms that can generate batch data.
"""
from monai.data.meta_tensor import MetaTensor

default_collate_fn_map.update({MetaTensor: collate_meta_tensor})
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
key = None
Expand All @@ -490,9 +482,9 @@ def list_data_collate(batch: Sequence):
for k in elem:
key = k
data_for_batch = [d[key] for d in data]
ret[key] = collate_meta_tensor(data_for_batch)
ret[key] = default_collate(data_for_batch)
else:
ret = collate_meta_tensor(data)
ret = default_collate(data)
return ret
except RuntimeError as re:
re_str = str(re)
Expand Down

0 comments on commit 13a914c

Please sign in to comment.