Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Wrap non-module trafos (commit title appears truncated in this capture)
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Feb 17, 2020
1 parent 50ce626 commit 5053ca7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
33 changes: 32 additions & 1 deletion rising/transforms/compose.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Sequence, Union, Callable, Any, Mapping
from rising.utils import check_scalar
from rising.transforms import AbstractTransform, RandomProcess
import torch


__all__ = ["Compose", "DropoutCompose"]
Expand All @@ -25,6 +26,31 @@ def dict_call(batch: dict, transform: Callable) -> Any:
return transform(**batch)


class _TransformWrapper(torch.nn.Module):
    """
    Helper class wrapping an arbitrary callable into a
    ``torch.nn.Module`` so that it can be stored inside a
    ``torch.nn.ModuleList`` next to real module transforms. Module-wide
    calls such as ``.to()`` are then forwarded to every container entry.
    """

    def __init__(self, trafo: Callable):
        """
        Parameters
        ----------
        trafo : Callable
            the transform to wrap; since it is not a
            ``torch.nn.Module`` subclass, its internal state is not
            affected by module-specific calls
        """
        super().__init__()

        self.trafo = trafo

    def forward(self, *args, **kwargs) -> Any:
        """
        Delegates the call to the wrapped transform.

        Returns
        -------
        Any
            whatever the wrapped transform returns
        """
        return self.trafo(*args, **kwargs)


class Compose(AbstractTransform):
def __init__(self, *transforms,
             transform_call: Callable[[Any, Callable], Any] = dict_call):
    """
    Compose multiple transforms into a single callable unit.

    Parameters
    ----------
    transforms : positional arguments or a single Sequence
        one or multiple transforms applied in consecutive order
    transform_call : Callable[[Any, Callable], Any], optional
        function determining how each transform is invoked on the data
        (defaults to ``dict_call``)
    """
    super().__init__(grad=True)
    if isinstance(transforms[0], Sequence):
        transforms = transforms[0]

    # BUGFIX: ``transforms`` is a tuple when collected via *args (and may
    # be an immutable or caller-owned sequence otherwise); the item
    # assignment below would raise TypeError without this mutable copy.
    transforms = list(transforms)
    for idx, trafo in enumerate(transforms):
        if not isinstance(trafo, torch.nn.Module):
            transforms[idx] = _TransformWrapper(trafo)

    # ModuleList so module-specific calls (e.g. ``.to()``) propagate to
    # all contained transforms
    self.transforms = torch.nn.ModuleList(transforms)
    self.transform_call = transform_call

def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
Expand Down
17 changes: 16 additions & 1 deletion tests/transforms/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import torch

from rising.transforms.spatial import Mirror
from rising.transforms.compose import Compose, DropoutCompose, AbstractTransform
from rising.transforms.compose import Compose, DropoutCompose, \
AbstractTransform, _TransformWrapper


class TestCompose(unittest.TestCase):
Expand Down Expand Up @@ -67,6 +68,20 @@ def __call__(self, *args, **kwargs):

self.assertEquals(compose.transforms[0].tmp.dtype, torch.float64)

def test_wrapping_non_module_trafos(self):
    """Non-module callables given to Compose must be wrapped into modules."""
    class DummyTrafo:
        def __init__(self):
            self.a = 5

        def __call__(self, *args, **kwargs):
            return 5

    compose = Compose([DummyTrafo()])
    wrapped = compose.transforms[0]

    self.assertIsInstance(wrapped, _TransformWrapper)
    self.assertIsInstance(wrapped.trafo, DummyTrafo)


# Allow running this test module directly (``python test_compose.py``)
if __name__ == '__main__':
    unittest.main()

0 comments on commit 5053ca7

Please sign in to comment.