
Commit

to_device
justusschock committed Dec 28, 2022
1 parent d7c0753 commit 2b951cf
Showing 2 changed files with 37 additions and 2 deletions.
rising/loading/loader.py: 4 changes (3 additions, 1 deletion)
@@ -103,6 +103,7 @@ def __init__(
multiprocessing_context=None,
auto_convert: bool = True,
transform_call: Callable[[Any, Callable], Any] = default_transform_call,
to_gpu_trafo: Optional[ToDevice] = None,
**kwargs
):
"""
@@ -161,6 +162,7 @@ def __init__(
transform_call: function which determines how transforms are
called. By default Mappings and Sequences are unpacked during
the transform.
to_gpu_trafo: transform used to move the batch to the device; if set to ``None``, only the 'data' key will be moved to the GPU.
"""
super().__init__(
dataset=dataset,
@@ -194,7 +196,7 @@ def __init__(
if device is None:
device = torch.cuda.current_device()

to_gpu_trafo = ToDevice(device=device, non_blocking=pin_memory)
to_gpu_trafo = to_gpu_trafo or ToDevice(device=device, non_blocking=pin_memory)

gpu_transforms = Compose(to_gpu_trafo, gpu_transforms)
gpu_transforms = gpu_transforms.to(device)
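The new to_gpu_trafo argument lets callers replace the default device transfer, which only moves the 'data' key. A minimal usage sketch based on this diff and the tests below; the names dataset and my_gpu_transforms are illustrative placeholders, and the rising.loading import path is assumed:

from rising.loading import DataLoader
from rising.transforms import ToDevice

# default: to_gpu_trafo is None, so only the 'data' key is moved to the GPU
loader = DataLoader(dataset, batch_size=4, gpu_transforms=my_gpu_transforms)

# explicit: move both 'data' and 'label' to the GPU before the gpu_transforms run
loader = DataLoader(
    dataset,
    batch_size=4,
    gpu_transforms=my_gpu_transforms,
    to_gpu_trafo=ToDevice(device="cuda", keys=("data", "label")),
)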
tests/loading/test_loader.py: 35 changes (34 additions, 1 deletion)
@@ -1,4 +1,5 @@
import unittest
from functools import partial
from typing import Mapping, Sequence
from unittest.mock import Mock, patch

@@ -15,7 +16,17 @@
_SingleProcessDataLoaderIter,
default_transform_call,
)
from rising.transforms import Mirror
from rising.transforms import Mirror, BaseTransform, Compose, ToDevice


def check_on_device(x: torch.Tensor, device: str):
assert x.device.type == device
return x


class DeviceChecker(BaseTransform):
def __init__(self, device: str, keys=('data',)):
super().__init__(augment_fn=partial(check_on_device, device=device), keys=keys)


class TestLoader(unittest.TestCase):
@@ -74,6 +85,28 @@ def check_output_device(self, device):
expected = data[None].flip([2]).to(device=device)
self.assertTrue(torch.allclose(outp["data"], expected))

@unittest.skipUnless(torch.cuda.is_available(), "No cuda gpu available")
def test_data_moved_to_gpu(self):
data = [
{'data': 1, 'label': 1},
{'data': 2, 'label': 2},
{'data': 3, 'label': 3}
]
loader = DataLoader(data, gpu_transforms=Compose(DeviceChecker(keys=("data",), device="cuda"), DeviceChecker(keys=("label",), device="cpu")))
for x in loader:
pass

@unittest.skipUnless(torch.cuda.is_available(), "No cuda gpu available")
def test_label_and_data_moved_to_gpu(self):
data = [
{'data': 1, 'label': 1},
{'data': 2, 'label': 2},
{'data': 3, 'label': 3}
]
loader = DataLoader(data, gpu_transforms=DeviceChecker(keys=("data", "label"), device="cuda"), to_gpu_trafo=ToDevice(device="cuda", keys=("data", "label")))
for x in loader:
pass


class BatchTransformerTest(unittest.TestCase):
def check_batch_transformer(self, collate_output):
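For reference, the check_on_device helper introduced in the test file above can be exercised on its own; a minimal sketch that runs without a GPU (torch only):

import torch

def check_on_device(x: torch.Tensor, device: str):
    # fail if the tensor does not live on the expected device type ('cpu' or 'cuda')
    assert x.device.type == device
    return x

check_on_device(torch.zeros(3), "cpu")       # passes on any machine
# check_on_device(torch.zeros(3), "cuda")    # raises AssertionError unless the tensor is on a GPU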
