4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
@@ -241,6 +241,7 @@
AsChannelLast,
CastToType,
ConvertToMultiChannelBasedOnBratsClasses,
CopyToDevice,
DataStats,
FgBgToIndices,
Identity,
@@ -280,6 +281,9 @@
CopyItemsd,
CopyItemsD,
CopyItemsDict,
CopyToDeviced,
CopyToDeviceD,
CopyToDeviceDict,
DataStatsd,
DataStatsD,
DataStatsDict,
29 changes: 28 additions & 1 deletion monai/transforms/utility/array.py
@@ -22,7 +22,7 @@

from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
from monai.utils import ensure_tuple, min_version, optional_import
from monai.utils import copy_to_device, ensure_tuple, min_version, optional_import

__all__ = [
"Identity",
@@ -44,6 +44,7 @@
"ConvertToMultiChannelBasedOnBratsClasses",
"AddExtremePointsChannel",
"TorchVision",
"CopyToDevice",
]

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
@@ -671,3 +672,29 @@ def __call__(self, img: torch.Tensor):

"""
return self.trans(img)


class CopyToDevice(Transform):
"""
Copy to ``device`` where possible.
"""

def __init__(
self,
device: Optional[Union[str, torch.device]],
non_blocking: bool = True,
verbose: bool = False,
) -> None:
self.device = device
self.non_blocking = non_blocking
self.verbose = verbose

def __call__(
self,
img: Union[torch.Tensor, np.ndarray],
) -> Union[torch.Tensor, np.ndarray]:
"""
Args:
img: the image to be moved to ``device``.
"""
return copy_to_device(img, self.device, self.non_blocking, self.verbose)
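
Not part of the diff: a minimal usage sketch of the new CopyToDevice array transform, assuming a CUDA device ("cuda:0") is available. ToTensor converts the numpy input to a torch.Tensor first, since only objects exposing a .to() method are moved.

# Illustrative only, not part of this PR; mirrors tests/test_array_copytodevice.py.
import numpy as np
from monai.transforms import Compose, CopyToDevice, ToTensor

pipeline = Compose([ToTensor(), CopyToDevice(device="cuda:0", non_blocking=True)])
img = np.zeros((3, 3, 3))
out = pipeline(img)
print(out.device)  # cuda:0
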
27 changes: 27 additions & 0 deletions monai/transforms/utility/dictionary.py
@@ -30,6 +30,7 @@
AsChannelLast,
CastToType,
ConvertToMultiChannelBasedOnBratsClasses,
CopyToDevice,
DataStats,
FgBgToIndices,
Identity,
@@ -109,6 +110,9 @@
"TorchVisiond",
"TorchVisionD",
"TorchVisionDict",
"CopyToDeviceD",
"CopyToDeviceDict",
"CopyToDeviced",
]


@@ -801,6 +805,28 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
return d


class CopyToDeviced(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.CopyToDevice`.
"""

def __init__(
self,
keys: KeysCollection,
device: Optional[Union[str, torch.device]],
non_blocking: bool = True,
verbose: bool = False,
) -> None:
super().__init__(keys)
self.converter = CopyToDevice(device, non_blocking, verbose)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.keys:
d[key] = self.converter(d[key])
return d


IdentityD = IdentityDict = Identityd
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
@@ -823,3 +849,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
) = ConvertToMultiChannelBasedOnBratsClassesd
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
CopyToDeviceD = CopyToDeviceDict = CopyToDeviced
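
Not part of the diff: a minimal usage sketch of the dictionary-based wrapper, mirroring tests/test_dict_copytodevice.py. It assumes a CUDA device ("cuda:0") is available; only the listed keys are moved, other keys stay on the CPU.

# Illustrative only, not part of this PR: move "image" and "label" to the GPU, leave "other" on the CPU.
import numpy as np
from monai.transforms import Compose, CopyToDeviced, ToTensord

xform = Compose(
    [
        ToTensord(keys=["image", "label", "other"]),
        CopyToDeviced(keys=["image", "label"], device="cuda:0"),
    ]
)
data = {"image": np.zeros((3, 3, 3)), "label": np.zeros((3, 3, 3)), "other": np.zeros((3, 3, 3))}
out = xform(data)
print(out["image"].device, out["other"].device)  # cuda:0 cpu
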
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -32,6 +32,7 @@
)
from .misc import (
MAX_SEED,
copy_to_device,
dtype_numpy_to_torch,
dtype_torch_to_numpy,
ensure_tuple,
43 changes: 42 additions & 1 deletion monai/utils/misc.py
@@ -10,11 +10,14 @@
# limitations under the License.

import collections.abc
import inspect
import itertools
import random
import types
import warnings
from ast import literal_eval
from distutils.util import strtobool
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
@@ -37,6 +40,7 @@
"dtype_torch_to_numpy",
"dtype_numpy_to_torch",
"MAX_SEED",
"copy_to_device",
]

_seed = None
@@ -306,3 +310,40 @@ def dtype_torch_to_numpy(dtype):
def dtype_numpy_to_torch(dtype):
"""Convert a numpy dtype to its torch equivalent."""
return _np_to_torch_dtype[dtype]


def copy_to_device(
obj: Any,
device: Optional[Union[str, torch.device]],
non_blocking: bool = True,
verbose: bool = False,
) -> Any:
"""
Copy object or tuple/list/dictionary of objects to ``device``.

Args:
obj: object or tuple/list/dictionary of objects to move to ``device``.
device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``,
``cuda:0``, etc.) or of type ``torch.device``.
non_blocking: when `True`, moves data to device asynchronously if
possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
verbose: when `True`, will print a warning for any elements of incompatible type
not copied to ``device``.
Returns:
Same as input, copied to ``device`` where possible. Original input will be
unchanged.
"""

if hasattr(obj, "to"):
return obj.to(device, non_blocking=non_blocking)
elif isinstance(obj, tuple):
return tuple(copy_to_device(o, device, non_blocking, verbose) for o in obj)
elif isinstance(obj, list):
return [copy_to_device(o, device, non_blocking, verbose) for o in obj]
elif isinstance(obj, dict):
return {k: copy_to_device(o, device, non_blocking, verbose) for k, o in obj.items()}
elif verbose:
fn_name = cast(types.FrameType, inspect.currentframe()).f_code.co_name
warnings.warn(f"{fn_name} called with incompatible type: {type(obj)}. Data will be returned unchanged.")

return obj
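
Not part of the diff: a short sketch of how the new copy_to_device utility recurses into containers, mirroring the cases in tests/test_copy_to_device.py and assuming CUDA is available. Objects without a .to() method, such as numpy arrays, are returned unchanged.

# Illustrative only, not part of this PR.
import numpy as np
import torch

from monai.utils import copy_to_device

batch = {"x": torch.tensor([1.0]), "y": [torch.tensor([2.0])], "z": np.array([3])}
moved = copy_to_device(batch, "cuda:0")
print(moved["x"].device)     # cuda:0
print(moved["y"][0].device)  # cuda:0
print(type(moved["z"]))      # <class 'numpy.ndarray'>, returned unchanged
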
49 changes: 49 additions & 0 deletions tests/test_array_copytodevice.py
@@ -0,0 +1,49 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.data import ArrayDataset
from monai.transforms import Compose, CopyToDevice, ToTensor
from tests.utils import skip_if_no_cuda

DEVICE = "cuda:0"

TEST_CASE_0 = [
Compose([ToTensor(), CopyToDevice(device=DEVICE)]),
Compose([ToTensor()]),
DEVICE,
"cpu",
]


@skip_if_no_cuda
class TestArrayCopyToDevice(unittest.TestCase):
@parameterized.expand([TEST_CASE_0])
def test_array_copy_to_device(self, img_transform, seg_transform, img_device, seg_device):
numel = 2
test_imgs = [np.zeros((3, 3, 3)) for _ in range(numel)]
test_segs = [np.zeros((3, 3, 3)) for _ in range(numel)]

test_labels = [1, 1]
dataset = ArrayDataset(test_imgs, img_transform, test_segs, seg_transform, test_labels, None)
self.assertEqual(len(dataset), 2)
for data in dataset:
im, seg = data[0], data[1]
self.assertTrue(str(im.device) == img_device)
self.assertTrue(str(seg.device) == seg_device)


if __name__ == "__main__":
unittest.main()
73 changes: 73 additions & 0 deletions tests/test_copy_to_device.py
@@ -0,0 +1,73 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.utils import copy_to_device
from tests.utils import skip_if_no_cuda

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu:0"

TEST_CASE_TENSOR = [
torch.Tensor([1.0]).to(DEVICE),
"cuda:0",
"cpu",
]
TEST_CASE_LIST = [
2 * [torch.Tensor([1.0])],
"cpu",
"cuda:0",
]
TEST_CASE_TUPLE = [
2 * (torch.Tensor([1.0]),),
"cpu",
"cuda:0",
]
TEST_CASE_MIXED_LIST = [
[torch.Tensor([1.0]), np.array([1])],
"cpu",
"cuda:0",
]
TEST_CASE_DICT = [
{
"x": torch.Tensor([1.0]),
"y": 2 * [torch.Tensor([1.0])],
"z": np.array([1]),
},
"cpu",
"cuda:0",
]
TEST_CASES = [TEST_CASE_TENSOR, TEST_CASE_LIST, TEST_CASE_TUPLE, TEST_CASE_MIXED_LIST, TEST_CASE_DICT]


@skip_if_no_cuda
class TestCopyToDevice(unittest.TestCase):
def _check_on_device(self, obj, device):
if hasattr(obj, "device"):
self.assertTrue(str(obj.device) == device)
elif isinstance(obj, (list, tuple)):
_ = [self._check_on_device(o, device) for o in obj]
elif isinstance(obj, dict):
_ = [self._check_on_device(o, device) for o in obj.values()]

@parameterized.expand(TEST_CASES)
def test_copy(self, input, in_device, out_device):
out = copy_to_device(input, out_device)
self._check_on_device(input, in_device)
self._check_on_device(out, out_device)


if __name__ == "__main__":
unittest.main()
54 changes: 54 additions & 0 deletions tests/test_dict_copytodevice.py
@@ -0,0 +1,54 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.data import Dataset
from monai.transforms import Compose, CopyToDeviced, ToTensord
from tests.utils import skip_if_no_cuda

DEVICE = "cuda:0"

TEST_CASE_0 = [
Compose([ToTensord(keys=["image", "label", "other"]), CopyToDeviced(keys=["image", "label"], device=DEVICE)]),
DEVICE,
"cpu",
]


@skip_if_no_cuda
class TestDictCopyToDevice(unittest.TestCase):
@parameterized.expand([TEST_CASE_0])
def test_dict_copy_to_device(self, transform, modified_device, unmodified_device):

numel = 2
test_data = [
{
"image": np.zeros((3, 3, 3)),
"label": np.zeros((3, 3, 3)),
"other": np.zeros((3, 3, 3)),
}
for _ in range(numel)
]

dataset = Dataset(data=test_data, transform=transform)
self.assertEqual(len(dataset), 2)
for data in dataset:
self.assertTrue(str(data["image"].device) == modified_device)
self.assertTrue(str(data["label"].device) == modified_device)
self.assertTrue(str(data["other"].device) == unmodified_device)


if __name__ == "__main__":
unittest.main()