From 64e14e36bac11fdc42c964b42f469eedcb8a4696 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 17 Aug 2021 11:49:14 +0800 Subject: [PATCH 1/4] [DLMED] add ToDevice transform Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 ++++++++ monai/transforms/__init__.py | 4 +++ monai/transforms/utility/array.py | 23 +++++++++++++++ monai/transforms/utility/dictionary.py | 33 +++++++++++++++++++++ tests/test_to_device.py | 40 ++++++++++++++++++++++++++ tests/test_to_deviced.py | 33 +++++++++++++++++++++ 6 files changed, 145 insertions(+) create mode 100644 tests/test_to_device.py create mode 100644 tests/test_to_deviced.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a1bafaf103..f5cfa8b0c4 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -723,6 +723,12 @@ Utility :members: :special-members: __call__ +`ToDevice` +"""""""""" + .. autoclass:: ToDevice + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -1352,6 +1358,12 @@ Utility (Dict) :members: :special-members: __call__ +`ToDeviced` +""""""""""" + .. autoclass:: ToDeviced + :members: + :special-members: __call__ + Transform Adaptors ------------------ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f259ff86bc..938d95bcb8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -385,6 +385,7 @@ SplitChannel, SqueezeDim, ToCupy, + ToDevice, ToNumpy, ToPIL, TorchVision, @@ -476,6 +477,9 @@ ToCupyd, ToCupyD, ToCupyDict, + ToDeviced, + ToDeviceD, + ToDeviceDict, ToNumpyd, ToNumpyD, ToNumpyDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index fe73c6189c..e64268f46e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -67,6 +67,7 @@ "TorchVision", "MapLabelValue", "IntensityStats", + "ToDevice", ] @@ -1015,3 +1016,25 @@ def _compute(op: Callable, data: np.ndarray): raise ValueError("ops must be key string for predefined operations or callable function.") return img, meta_data + + +class ToDevice: + """ + Move PyTorch Tensor to the specified device. + It can help cache data into GPU and execute following logic on GPU directly. + + """ + + def __init__(self, device: Union[torch.device, str]) -> None: + """ + Args: + device: target device to move the Tensor, for example: "cuda:1". + + """ + self.device = device + + def __call__(self, img: torch.Tensor): + if not isinstance(img, torch.Tensor): + raise ValueError("img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.") + + return img.to(self.device) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index fb9963601d..4fb3d68efe 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -49,6 +49,7 @@ SplitChannel, SqueezeDim, ToCupy, + ToDevice, ToNumpy, ToPIL, TorchVision, @@ -141,6 +142,9 @@ "ToCupyD", "ToCupyDict", "ToCupyd", + "ToDeviced", + "ToDeviceD", + "ToDeviceDict", "ToNumpyD", "ToNumpyDict", "ToNumpyd", @@ -1354,6 +1358,34 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]: return d +class ToDeviced(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`. + """ + + def __init__( + self, + keys: KeysCollection, + device: Union[torch.device, str], + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + device: target device to move the Tensor, for example: "cuda:1". + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.converter = ToDevice(device=device) + + def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -1389,3 +1421,4 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]: RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd +ToDeviceD = ToDeviceDict = ToDeviced diff --git a/tests/test_to_device.py b/tests/test_to_device.py new file mode 100644 index 0000000000..67b170bb01 --- /dev/null +++ b/tests/test_to_device.py @@ -0,0 +1,40 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import ToDevice +from tests.utils import skip_if_no_cuda + +TEST_CASE_1 = ["cuda:0"] + +TEST_CASE_2 = ["cuda"] + +TEST_CASE_3 = [torch.device("cpu:0")] + +TEST_CASE_4 = ["cpu"] + + +@skip_if_no_cuda +class TestToDevice(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_value(self, device): + converter = ToDevice(device=device) + data = torch.tensor([1, 2, 3, 4]) + ret = converter(data) + torch.testing.assert_allclose(ret, data.to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py new file mode 100644 index 0000000000..763d57b1b1 --- /dev/null +++ b/tests/test_to_deviced.py @@ -0,0 +1,33 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.data import CacheDataset, ThreadDataLoader +from monai.transforms import ToDeviced +from tests.utils import skip_if_no_cuda + + +@skip_if_no_cuda +class TestToDeviced(unittest.TestCase): + def test_value(self): + device = "cuda:0" + data = [{"img": torch.tensor(i)} for i in range(4)] + dataset = CacheDataset(data=data, transform=ToDeviced(keys="img", device=device), cache_rate=1.0) + dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1) + for i, d in enumerate(dataloader): + torch.testing.assert_allclose(d["img"], torch.tensor([i], device=device)) + + +if __name__ == "__main__": + unittest.main() From 9f97f0e3986d9a3ef43a3f14ae490307d594324a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 17 Aug 2021 12:01:33 +0800 Subject: [PATCH 2/4] [DLMED] fix type-hints Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 4fb3d68efe..7cab64cc2d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1379,7 +1379,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = ToDevice(device=device) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) From 23abe2a4646f254490049a8c13f6725ab2b33c43 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 17 Aug 2021 15:32:50 +0800 Subject: [PATCH 3/4] [DLMED] inherit Transform Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index e64268f46e..b54ad23d0e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1018,7 +1018,7 @@ def _compute(op: Callable, data: np.ndarray): return img, meta_data -class ToDevice: +class ToDevice(Transform): """ Move PyTorch Tensor to the specified device. It can help cache data into GPU and execute following logic on GPU directly. From c3bec64f5064f253f0834fb3aac7717d0ed6b16d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 18 Aug 2021 06:38:19 +0800 Subject: [PATCH 4/4] [DLMED] add kwargs Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 7 +++++-- monai/transforms/utility/dictionary.py | 5 ++++- tests/test_to_device.py | 2 +- tests/test_to_deviced.py | 6 +++++- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9289f073c6..1871eedb6f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1031,16 +1031,19 @@ class ToDevice(Transform): """ - def __init__(self, device: Union[torch.device, str]) -> None: + def __init__(self, device: Union[torch.device, str], **kwargs) -> None: """ Args: device: target device to move the Tensor, for example: "cuda:1". + kwargs: other args for the PyTorch `Tensor.to()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.Tensor.to.html. """ self.device = device + self.kwargs = kwargs def __call__(self, img: torch.Tensor): if not isinstance(img, torch.Tensor): raise ValueError("img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.") - return img.to(self.device) + return img.to(self.device, **self.kwargs) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index fdb4459f0d..9c0a709bbf 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1368,6 +1368,7 @@ def __init__( keys: KeysCollection, device: Union[torch.device, str], allow_missing_keys: bool = False, + **kwargs, ) -> None: """ Args: @@ -1375,9 +1376,11 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` device: target device to move the Tensor, for example: "cuda:1". allow_missing_keys: don't raise exception if key is missing. + kwargs: other args for the PyTorch `Tensor.to()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.Tensor.to.html. """ super().__init__(keys, allow_missing_keys) - self.converter = ToDevice(device=device) + self.converter = ToDevice(device=device, **kwargs) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) diff --git a/tests/test_to_device.py b/tests/test_to_device.py index 67b170bb01..9855a353f0 100644 --- a/tests/test_to_device.py +++ b/tests/test_to_device.py @@ -30,7 +30,7 @@ class TestToDevice(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value(self, device): - converter = ToDevice(device=device) + converter = ToDevice(device=device, non_blocking=True) data = torch.tensor([1, 2, 3, 4]) ret = converter(data) torch.testing.assert_allclose(ret, data.to(device)) diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py index 763d57b1b1..0d5d1d1cdc 100644 --- a/tests/test_to_deviced.py +++ b/tests/test_to_deviced.py @@ -23,7 +23,11 @@ class TestToDeviced(unittest.TestCase): def test_value(self): device = "cuda:0" data = [{"img": torch.tensor(i)} for i in range(4)] - dataset = CacheDataset(data=data, transform=ToDeviced(keys="img", device=device), cache_rate=1.0) + dataset = CacheDataset( + data=data, + transform=ToDeviced(keys="img", device=device, non_blocking=True), + cache_rate=1.0, + ) dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1) for i, d in enumerate(dataloader): torch.testing.assert_allclose(d["img"], torch.tensor([i], device=device))