diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9eaedd6b15..874edbee42 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -241,6 +241,7 @@ AsChannelLast, CastToType, ConvertToMultiChannelBasedOnBratsClasses, + CopyToDevice, DataStats, FgBgToIndices, Identity, @@ -280,6 +281,9 @@ CopyItemsd, CopyItemsD, CopyItemsDict, + CopyToDeviced, + CopyToDeviceD, + CopyToDeviceDict, DataStatsd, DataStatsD, DataStatsDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5476e800f4..97fa78ca6f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,7 +22,7 @@ from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices -from monai.utils import ensure_tuple, min_version, optional_import +from monai.utils import copy_to_device, ensure_tuple, min_version, optional_import __all__ = [ "Identity", @@ -44,6 +44,7 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "CopyToDevice", ] # Generic type which can represent either a numpy.ndarray or a torch.Tensor @@ -671,3 +672,29 @@ def __call__(self, img: torch.Tensor): """ return self.trans(img) + + +class CopyToDevice(Transform): + """ + Copy to ``device`` where possible. + """ + + def __init__( + self, + device: Optional[Union[str, torch.device]], + non_blocking: bool = True, + verbose: bool = False, + ) -> None: + self.device = device + self.non_blocking = non_blocking + self.verbose = verbose + + def __call__( + self, + img: Union[torch.Tensor, np.ndarray], + ) -> np.ndarray: + """ + Args: + img: the image to be moved to ``device``. + """ + return copy_to_device(img, self.device, self.non_blocking, self.verbose) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1427f24356..091a9de8a5 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -30,6 +30,7 @@ AsChannelLast, CastToType, ConvertToMultiChannelBasedOnBratsClasses, + CopyToDevice, DataStats, FgBgToIndices, Identity, @@ -109,6 +110,9 @@ "TorchVisiond", "TorchVisionD", "TorchVisionDict", + "CopyToDeviceD", + "CopyToDeviceDict", + "CopyToDeviced", ] @@ -801,6 +805,28 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class CopyToDeviced(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.CopyToDevice`. + """ + + def __init__( + self, + keys: KeysCollection, + device: Optional[Union[str, torch.device]], + non_blocking: bool = True, + verbose: bool = False, + ) -> None: + super().__init__(keys) + self.converter = CopyToDevice(device, non_blocking, verbose) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.keys: + d[key] = self.converter(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -823,3 +849,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc ) = ConvertToMultiChannelBasedOnBratsClassesd AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond +CopyToDeviceD = CopyToDeviceDict = CopyToDeviced diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 9bb25d723a..6430fae75a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -32,6 +32,7 @@ ) from .misc import ( MAX_SEED, + copy_to_device, dtype_numpy_to_torch, dtype_torch_to_numpy, ensure_tuple, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bf1ff60cbc..2b31392a46 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -10,11 +10,14 @@ # limitations under the License. import collections.abc +import inspect import itertools import random +import types +import warnings from ast import literal_eval from distutils.util import strtobool -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -37,6 +40,7 @@ "dtype_torch_to_numpy", "dtype_numpy_to_torch", "MAX_SEED", + "copy_to_device", ] _seed = None @@ -306,3 +310,40 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" return _np_to_torch_dtype[dtype] + + +def copy_to_device( + obj: Any, + device: Optional[Union[str, torch.device]], + non_blocking: bool = True, + verbose: bool = False, +) -> Any: + """ + Copy object or tuple/list/dictionary of objects to ``device``. + + Args: + obj: object or tuple/list/dictionary of objects to move to ``device``. + device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``, + ``cuda:0``, etc.) or of type ``torch.device``. + non_blocking_transfer: when `True`, moves data to device asynchronously if + possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. + verbose: when `True`, will print a warning for any elements of incompatible type + not copied to ``device``. + Returns: + Same as input, copied to ``device`` where possible. Original input will be + unchanged. + """ + + if hasattr(obj, "to"): + return obj.to(device, non_blocking=non_blocking) + elif isinstance(obj, tuple): + return tuple(copy_to_device(o, device, non_blocking) for o in obj) + elif isinstance(obj, list): + return [copy_to_device(o, device, non_blocking) for o in obj] + elif isinstance(obj, dict): + return {k: copy_to_device(o, device, non_blocking) for k, o in obj.items()} + elif verbose: + fn_name = cast(types.FrameType, inspect.currentframe()).f_code.co_name + warnings.warn(f"{fn_name} called with incompatible type: " + f"{type(obj)}. Data will be returned unchanged.") + + return obj diff --git a/tests/test_array_copytodevice.py b/tests/test_array_copytodevice.py new file mode 100644 index 0000000000..ee73cad753 --- /dev/null +++ b/tests/test_array_copytodevice.py @@ -0,0 +1,49 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data import ArrayDataset +from monai.transforms import Compose, CopyToDevice, ToTensor +from tests.utils import skip_if_no_cuda + +DEVICE = "cuda:0" + +TEST_CASE_0 = [ + Compose([ToTensor(), CopyToDevice(device=DEVICE)]), + Compose([ToTensor()]), + DEVICE, + "cpu", +] + + +@skip_if_no_cuda +class TestArrayCopyToDevice(unittest.TestCase): + @parameterized.expand([TEST_CASE_0]) + def test_array_copy_to_device(self, img_transform, label_transform, img_device, label_device): + numel = 2 + test_imgs = [np.zeros((3, 3, 3)) for _ in range(numel)] + test_segs = [np.zeros((3, 3, 3)) for _ in range(numel)] + + test_labels = [1, 1] + dataset = ArrayDataset(test_imgs, img_transform, test_segs, label_transform, test_labels, None) + self.assertEqual(len(dataset), 2) + for data in dataset: + im, seg = data[0], data[1] + self.assertTrue(str(im.device) == img_device) + self.assertTrue(str(seg.device) == label_device) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_copy_to_device.py b/tests/test_copy_to_device.py new file mode 100644 index 0000000000..d9e0afc8f6 --- /dev/null +++ b/tests/test_copy_to_device.py @@ -0,0 +1,73 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.utils import copy_to_device +from tests.utils import skip_if_no_cuda + +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu:0" + +TEST_CASE_TENSOR = [ + torch.Tensor([1.0]).to(DEVICE), + "cuda:0", + "cpu", +] +TEST_CASE_LIST = [ + 2 * [torch.Tensor([1.0])], + "cpu", + "cuda:0", +] +TEST_CASE_TUPLE = [ + 2 * (torch.Tensor([1.0]),), + "cpu", + "cuda:0", +] +TEST_CASE_MIXED_LIST = [ + [torch.Tensor([1.0]), np.array([1])], + "cpu", + "cuda:0", +] +TEST_CASE_DICT = [ + { + "x": torch.Tensor([1.0]), + "y": 2 * [torch.Tensor([1.0])], + "z": np.array([1]), + }, + "cpu", + "cuda:0", +] +TEST_CASES = [TEST_CASE_TENSOR, TEST_CASE_LIST, TEST_CASE_TUPLE, TEST_CASE_MIXED_LIST, TEST_CASE_DICT] + + +@skip_if_no_cuda +class TestCopyToDevice(unittest.TestCase): + def _check_on_device(self, obj, device): + if hasattr(obj, "device"): + self.assertTrue(str(obj.device) == device) + elif any(isinstance(obj, x) for x in [list, tuple]): + _ = [self._check_on_device(o, device) for o in obj] + elif isinstance(obj, dict): + _ = [self._check_on_device(o, device) for o in obj.values()] + + @parameterized.expand(TEST_CASES) + def test_copy(self, input, in_device, out_device): + out = copy_to_device(input, out_device) + self._check_on_device(input, in_device) + self._check_on_device(out, out_device) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dict_copytodevice.py b/tests/test_dict_copytodevice.py new file mode 100644 index 0000000000..cdcb645c04 --- /dev/null +++ b/tests/test_dict_copytodevice.py @@ -0,0 +1,54 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data import Dataset +from monai.transforms import Compose, CopyToDeviced, ToTensord +from tests.utils import skip_if_no_cuda + +DEVICE = "cuda:0" + +TEST_CASE_0 = [ + Compose([ToTensord(keys=["image", "label", "other"]), CopyToDeviced(keys=["image", "label"], device=DEVICE)]), + DEVICE, + "cpu", +] + + +@skip_if_no_cuda +class TestDictCopyToDevice(unittest.TestCase): + @parameterized.expand([TEST_CASE_0]) + def test_dict_copy_to_device(self, transform, modified_device, unmodified_device): + + numel = 2 + test_data = [ + { + "image": np.zeros((3, 3, 3)), + "label": np.zeros((3, 3, 3)), + "other": np.zeros((3, 3, 3)), + } + for _ in range(numel) + ] + + dataset = Dataset(data=test_data, transform=transform) + self.assertEqual(len(dataset), 2) + for data in dataset: + self.assertTrue(str(data["image"].device) == modified_device) + self.assertTrue(str(data["label"].device) == modified_device) + self.assertTrue(str(data["other"].device) == unmodified_device) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dict_copytodevice_timecomparison.py b/tests/test_dict_copytodevice_timecomparison.py new file mode 100644 index 0000000000..aba907d8c4 --- /dev/null +++ b/tests/test_dict_copytodevice_timecomparison.py @@ -0,0 +1,95 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import unittest + +import torch +from torch.utils.data import DataLoader + +from monai.apps import MedNISTDataset +from monai.networks.nets import densenet121 +from monai.transforms import Compose, CopyToDeviced, ToTensord, LoadImaged, AddChanneld +from tests.utils import skip_if_no_cuda + +# This test is only run with cuda +DEVICE = "cuda:0" + +@skip_if_no_cuda +class TestDictCopyToDeviceTimeComparison(unittest.TestCase): + + @staticmethod + def get_data(use_copy_to_device_transform): + + root_dir = os.environ.get("MONAI_DATA_DIRECTORY") + if not root_dir: + root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + + transforms = Compose( + [ + LoadImaged(keys="image"), + AddChanneld(keys="image"), + ToTensord(keys="image"), + ] + ) + # If necessary, append the transform + if use_copy_to_device_transform: + transforms.transforms = transforms.transforms + (CopyToDeviced(keys="image", device=DEVICE),) + + train_ds = MedNISTDataset( + root_dir=root_dir, + transform=transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) + num_classes = train_ds.get_num_classes() + + model = densenet121(spatial_dims=2, in_channels=1, out_channels=num_classes).to(DEVICE) + + return train_loader, model + + def test_dict_copy_to_device_time_comparison(self): + + + for use_copy_transform in [True, False]: + start_time = time.time() + + train_loader, model = self.get_data(use_copy_transform) + + model.train() + for batch_data in train_loader: + inputs, labels = batch_data["image"], batch_data["label"] + # If using the copy transform, check they're on the GPU + if use_copy_transform: + self.assertEqual(str(inputs.device), DEVICE) + # Assert not already on device, and then copy them there + else: + self.assertNotEqual(str(inputs.device), DEVICE) + inputs = inputs.to(DEVICE) + labels = labels.to(DEVICE) + + loss_function = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), 1e-5) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + + print(f"--- {time.time() - start_time} seconds ---") + + +if __name__ == "__main__": + unittest.main()