diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 65b35b0dc8..9a538aae82 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -718,6 +718,12 @@ Utility :members: :special-members: __call__ +`ToDevice` +"""""""""" + .. autoclass:: ToDevice + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -1347,6 +1353,12 @@ Utility (Dict) :members: :special-members: __call__ +`ToDeviced` +""""""""""" + .. autoclass:: ToDeviced + :members: + :special-members: __call__ + Transform Adaptors ------------------ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b3b3b15a1f..180a690e5d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -377,6 +377,7 @@ SplitChannel, SqueezeDim, ToCupy, + ToDevice, ToNumpy, ToPIL, TorchVision, @@ -468,6 +469,9 @@ ToCupyd, ToCupyD, ToCupyDict, + ToDeviced, + ToDeviceD, + ToDeviceDict, ToNumpyd, ToNumpyD, ToNumpyDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index c41983787d..1871eedb6f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -73,6 +73,7 @@ "TorchVision", "MapLabelValue", "IntensityStats", + "ToDevice", ] @@ -1021,3 +1022,28 @@ def _compute(op: Callable, data: np.ndarray): raise ValueError("ops must be key string for predefined operations or callable function.") return img, meta_data + + +class ToDevice(Transform): + """ + Move PyTorch Tensor to the specified device. + It can help cache data into GPU and execute following logic on GPU directly. + + """ + + def __init__(self, device: Union[torch.device, str], **kwargs) -> None: + """ + Args: + device: target device to move the Tensor, for example: "cuda:1". + kwargs: other args for the PyTorch `Tensor.to()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.Tensor.to.html. + + """ + self.device = device + self.kwargs = kwargs + + def __call__(self, img: torch.Tensor): + if not isinstance(img, torch.Tensor): + raise ValueError("img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.") + + return img.to(self.device, **self.kwargs) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 67fe7653e0..9c0a709bbf 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -49,6 +49,7 @@ SplitChannel, SqueezeDim, ToCupy, + ToDevice, ToNumpy, ToPIL, TorchVision, @@ -141,6 +142,9 @@ "ToCupyD", "ToCupyDict", "ToCupyd", + "ToDeviced", + "ToDeviceD", + "ToDeviceDict", "ToNumpyD", "ToNumpyDict", "ToNumpyd", @@ -1354,6 +1358,37 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]: return d +class ToDeviced(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`. + """ + + def __init__( + self, + keys: KeysCollection, + device: Union[torch.device, str], + allow_missing_keys: bool = False, + **kwargs, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + device: target device to move the Tensor, for example: "cuda:1". + allow_missing_keys: don't raise exception if key is missing. + kwargs: other args for the PyTorch `Tensor.to()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.Tensor.to.html. + """ + super().__init__(keys, allow_missing_keys) + self.converter = ToDevice(device=device, **kwargs) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -1389,3 +1424,4 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]: RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd +ToDeviceD = ToDeviceDict = ToDeviced diff --git a/tests/test_to_device.py b/tests/test_to_device.py new file mode 100644 index 0000000000..9855a353f0 --- /dev/null +++ b/tests/test_to_device.py @@ -0,0 +1,40 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import ToDevice +from tests.utils import skip_if_no_cuda + +TEST_CASE_1 = ["cuda:0"] + +TEST_CASE_2 = ["cuda"] + +TEST_CASE_3 = [torch.device("cpu:0")] + +TEST_CASE_4 = ["cpu"] + + +@skip_if_no_cuda +class TestToDevice(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_value(self, device): + converter = ToDevice(device=device, non_blocking=True) + data = torch.tensor([1, 2, 3, 4]) + ret = converter(data) + torch.testing.assert_allclose(ret, data.to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py new file mode 100644 index 0000000000..0d5d1d1cdc --- /dev/null +++ b/tests/test_to_deviced.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.data import CacheDataset, ThreadDataLoader +from monai.transforms import ToDeviced +from tests.utils import skip_if_no_cuda + + +@skip_if_no_cuda +class TestToDeviced(unittest.TestCase): + def test_value(self): + device = "cuda:0" + data = [{"img": torch.tensor(i)} for i in range(4)] + dataset = CacheDataset( + data=data, + transform=ToDeviced(keys="img", device=device, non_blocking=True), + cache_rate=1.0, + ) + dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1) + for i, d in enumerate(dataloader): + torch.testing.assert_allclose(d["img"], torch.tensor([i], device=device)) + + +if __name__ == "__main__": + unittest.main()