Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Add compose test
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Feb 17, 2020
1 parent 7a1c8ba commit 50ce626
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/transforms/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from rising.transforms.spatial import Mirror
from rising.transforms.compose import Compose, DropoutCompose
from rising.transforms.compose import Compose, DropoutCompose, AbstractTransform


class TestCompose(unittest.TestCase):
Expand Down Expand Up @@ -47,6 +47,26 @@ def test_dropout_compose_error(self):
with self.assertRaises(TypeError):
compose = DropoutCompose(self.transforms, dropout=[1.0])

def test_device_dtype_change(self):
class DummyTrafo(AbstractTransform):
def __init__(self, a):
super().__init__(False)
self.register_buffer('tmp', a)

def __call__(self, *args, **kwargs):
return self.tmp

trafo_a = DummyTrafo(torch.tensor([1.], dtype=torch.float32))
trafo_a = trafo_a.to(torch.float32)
trafo_b = DummyTrafo(torch.tensor([2.], dtype=torch.float32))
trafo_b = trafo_b.to(torch.float32)
self.assertEquals(trafo_a.tmp.dtype, torch.float32)
self.assertEquals(trafo_b.tmp.dtype, torch.float32)
compose = Compose(trafo_a, trafo_b)
compose = compose.to(torch.float64)

self.assertEquals(compose.transforms[0].tmp.dtype, torch.float64)


if __name__ == '__main__':
unittest.main()

0 comments on commit 50ce626

Please sign in to comment.