diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 6fdd8d1081..5957a2d068 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -219,8 +219,8 @@ def __call__( `kwargs` supports other args for `Tensor.to()` API. """ image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs) - args_ = [] - kwargs_ = {} + args_: tuple = () + kwargs_: dict = {} def _get_data(key: str) -> torch.Tensor: data = batchdata[key] @@ -231,13 +231,13 @@ def _get_data(key: str) -> torch.Tensor: return data if isinstance(self.extra_keys, (str, list, tuple)): - for k in ensure_tuple(self.extra_keys): - args_.append(_get_data(k)) + args_ = tuple(_get_data(k) for k in ensure_tuple(self.extra_keys)) + elif isinstance(self.extra_keys, dict): - for k, v in self.extra_keys.items(): - kwargs_.update({k: _get_data(v)}) + kwargs_ = {k: _get_data(v) for k, v in self.extra_keys.items()} + - return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_ + return cast(torch.Tensor, image), cast(torch.Tensor, label), args_, kwargs_ class DiffusionPrepareBatch(PrepareBatch):