12 changes: 12 additions & 0 deletions docs/source/transforms.rst
@@ -718,6 +718,12 @@ Utility
:members:
:special-members: __call__

`ToDevice`
""""""""""
.. autoclass:: ToDevice
:members:
:special-members: __call__


Dictionary Transforms
---------------------
@@ -1347,6 +1353,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ToDeviced`
"""""""""""
.. autoclass:: ToDeviced
:members:
:special-members: __call__


Transform Adaptors
------------------
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
@@ -377,6 +377,7 @@
SplitChannel,
SqueezeDim,
ToCupy,
ToDevice,
ToNumpy,
ToPIL,
TorchVision,
@@ -468,6 +469,9 @@
ToCupyd,
ToCupyD,
ToCupyDict,
ToDeviced,
ToDeviceD,
ToDeviceDict,
ToNumpyd,
ToNumpyD,
ToNumpyDict,
26 changes: 26 additions & 0 deletions monai/transforms/utility/array.py
@@ -73,6 +73,7 @@
"TorchVision",
"MapLabelValue",
"IntensityStats",
"ToDevice",
]


@@ -1021,3 +1022,28 @@ def _compute(op: Callable, data: np.ndarray):
raise ValueError("ops must be key string for predefined operations or callable function.")

return img, meta_data


class ToDevice(Transform):
"""
Move a PyTorch Tensor to the specified device.
It can help cache data on the GPU so that subsequent logic executes directly on the GPU.

"""

def __init__(self, device: Union[torch.device, str], **kwargs) -> None:
"""
Args:
device: target device to move the Tensor to, for example: "cuda:1".
kwargs: other args for the PyTorch `Tensor.to()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.

"""
self.device = device
self.kwargs = kwargs

def __call__(self, img: torch.Tensor):
if not isinstance(img, torch.Tensor):
raise ValueError("img must be a PyTorch Tensor; consider converting it with the `EnsureType` transform first.")

return img.to(self.device, **self.kwargs)
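A minimal usage sketch for the new array transform (not part of this diff): it assumes a
CUDA device "cuda:0" is available and chains `EnsureType` in front, since `ToDevice`
accepts only Tensor inputs.

import numpy as np

from monai.transforms import Compose, EnsureType, ToDevice

# EnsureType converts the numpy input to a Tensor; ToDevice then moves it to the GPU.
# Extra kwargs (e.g. non_blocking) are forwarded to `Tensor.to()`.
transform = Compose([EnsureType(data_type="tensor"), ToDevice(device="cuda:0", non_blocking=True)])
img = transform(np.zeros((1, 32, 32), dtype=np.float32))
print(img.device)  # cuda:0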
36 changes: 36 additions & 0 deletions monai/transforms/utility/dictionary.py
@@ -49,6 +49,7 @@
SplitChannel,
SqueezeDim,
ToCupy,
ToDevice,
ToNumpy,
ToPIL,
TorchVision,
@@ -141,6 +142,9 @@
"ToCupyD",
"ToCupyDict",
"ToCupyd",
"ToDeviced",
"ToDeviceD",
"ToDeviceDict",
"ToNumpyD",
"ToNumpyDict",
"ToNumpyd",
@@ -1354,6 +1358,37 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]:
return d


class ToDeviced(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`.
"""

def __init__(
self,
keys: KeysCollection,
device: Union[torch.device, str],
allow_missing_keys: bool = False,
**kwargs,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
device: target device to move the Tensor to, for example: "cuda:1".
allow_missing_keys: don't raise exception if key is missing.
kwargs: other args for the PyTorch `Tensor.to()` API, for more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.
"""
super().__init__(keys, allow_missing_keys)
self.converter = ToDevice(device=device, **kwargs)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d


IdentityD = IdentityDict = Identityd
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
@@ -1389,3 +1424,4 @@ def __call__(self, data) -> Dict[Hashable, np.ndarray]:
RandLambdaD = RandLambdaDict = RandLambdad
MapLabelValueD = MapLabelValueDict = MapLabelValued
IntensityStatsD = IntensityStatsDict = IntensityStatsd
ToDeviceD = ToDeviceDict = ToDeviced
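A corresponding dictionary-based sketch (again assuming "cuda:0" is available; the keys
and shapes are illustrative only):

import torch

from monai.transforms import ToDeviced

data = {"img": torch.zeros(1, 32, 32), "seg": torch.zeros(1, 32, 32)}
# moves the Tensor under every listed key onto the target device in one call
transform = ToDeviced(keys=["img", "seg"], device="cuda:0", non_blocking=True)
data = transform(data)
print(data["img"].device, data["seg"].device)  # cuda:0 cuda:0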
40 changes: 40 additions & 0 deletions tests/test_to_device.py
@@ -0,0 +1,40 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import ToDevice
from tests.utils import skip_if_no_cuda

TEST_CASE_1 = ["cuda:0"]

TEST_CASE_2 = ["cuda"]

TEST_CASE_3 = [torch.device("cpu:0")]

TEST_CASE_4 = ["cpu"]


@skip_if_no_cuda
class TestToDevice(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, device):
converter = ToDevice(device=device, non_blocking=True)
data = torch.tensor([1, 2, 3, 4])
ret = converter(data)
torch.testing.assert_allclose(ret, data.to(device))


if __name__ == "__main__":
unittest.main()
37 changes: 37 additions & 0 deletions tests/test_to_deviced.py
@@ -0,0 +1,37 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from monai.data import CacheDataset, ThreadDataLoader
from monai.transforms import ToDeviced
from tests.utils import skip_if_no_cuda


@skip_if_no_cuda
class TestToDeviced(unittest.TestCase):
def test_value(self):
device = "cuda:0"
data = [{"img": torch.tensor(i)} for i in range(4)]
dataset = CacheDataset(
data=data,
transform=ToDeviced(keys="img", device=device, non_blocking=True),
cache_rate=1.0,
)
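# note: num_workers=0 loads data in the main process; PyTorch advises against
# returning CUDA tensors from multiprocessing workers, hence ThreadDataLoader here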
dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1)
for i, d in enumerate(dataloader):
torch.testing.assert_allclose(d["img"], torch.tensor([i], device=device))


if __name__ == "__main__":
unittest.main()