from random import shuffle
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import torch

from rising.random import ContinuousParameter, UniformParameter
from rising.transforms import AbstractTransform
from rising.utils import check_scalar

__all__ = ["Compose", "DropoutCompose", "OneOf"]

def dict_call(batch: dict, transform: Callable) -> Any:
    """
    Unpacks the dict for every transformation

    Args:
        batch: current batch which is passed to transform
        transform: transform to perform

    Returns:
        Any: transformed batch
    """
    return transform(**batch)
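

# A minimal sketch of what ``dict_call`` does: it turns a batch dict into
# keyword arguments for the transform. ``add_offset`` and ``_example_dict_call``
# are hypothetical names used purely for illustration; they are not part of
# the rising API.
def _example_dict_call() -> dict:
    def add_offset(data: torch.Tensor, offset: float = 1.0) -> dict:
        # a toy transform: shifts the batch by a constant offset
        return {"data": data + offset}

    batch = {"data": torch.zeros(2, 2)}
    # equivalent to add_offset(**batch)
    return dict_call(batch, add_offset)

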
class _TransformWrapper(torch.nn.Module):
    """
    Helper class to wrap all non-module transforms into modules, so that a
    ``torch.nn.ModuleList`` can be used as container for the transforms.
    This enables forwarding of all module-specific calls such as ``.to()``
    to all transforms.
    """

    def __init__(self, trafo: Callable):
        """
        Args:
            trafo: the actual transform, which will be wrapped by this class.
                Since this transform is no subclass of ``torch.nn.Module``,
                its internal state won't be affected by module-specific calls.
        """
        super().__init__()
        self.trafo = trafo

    def forward(self, *args, **kwargs) -> Any:
        """
        Forwards calls to this wrapper to the internal transform

        Args:
            *args: positional arguments
            **kwargs: keyword arguments

        Returns:
            Any: return value of the wrapped transform
        """
        return self.trafo(*args, **kwargs)
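

# A brief sketch of why ``_TransformWrapper`` exists: a plain callable cannot
# be stored in a ``torch.nn.ModuleList``, but its wrapped counterpart can, so
# module-wide calls like ``.to()`` reach every transform of a ``Compose``.
# ``_example_transform_wrapper`` is an illustrative name, not part of rising.
def _example_transform_wrapper() -> torch.Tensor:
    wrapped = _TransformWrapper(lambda x: x * 2)  # wrap a non-module callable
    container = torch.nn.ModuleList([wrapped])  # would raise for the raw lambda
    container.to("cpu")  # a no-op here, but propagates to stateful transforms
    return container[0](torch.ones(1))  # tensor([2.])

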
class Compose(AbstractTransform):
    """
    Compose multiple transforms
    """

    def __init__(
        self,
        *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
        shuffle: bool = False,
        transform_call: Callable[[Any, Callable], Any] = dict_call,
    ):
        """
        Args:
            *transforms: one or multiple transformations which are applied
                in consecutive order
            shuffle: apply transforms in random order
            transform_call: function which determines how transforms are
                called. By default, mappings and sequences are unpacked
                during the transform.
        """
        super().__init__(grad=True)
        if len(transforms) > 0 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]
        if not transforms:
            raise ValueError("At least one transformation needs to be selected.")
        self.transforms = transforms
        self.transform_call = transform_call
        self.shuffle = shuffle

    def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
        """
        Apply transforms in consecutive order. Can handle either
        sequence-like or mapping-like data.

        Args:
            *seq_like: data which is unpacked like a sequence
            **map_like: data which is unpacked like a dict

        Returns:
            Union[Sequence, Mapping]: transformed data
        """
        assert not (seq_like and map_like)
        assert len(self.transforms) == len(self.transform_order)
        data = seq_like if seq_like else map_like

        if self.shuffle:
            shuffle(self.transform_order)

        for idx in self.transform_order:
            data = self.transform_call(data, self.transforms[idx])
        return data

    @property
    def transforms(self) -> torch.nn.ModuleList:
        """
        Transforms getter

        Returns:
            torch.nn.ModuleList: transforms to compose
        """
        return self._transforms

    @transforms.setter
    def transforms(self, transforms: Union[AbstractTransform, Sequence[AbstractTransform]]):
        """
        Transforms setter

        Args:
            transforms: one or multiple transformations which are applied in
                consecutive order
        """
        # make transforms a list to be mutable;
        # otherwise the enforced type conversion below might fail
        if isinstance(transforms, tuple):
            transforms = list(transforms)

        for idx, trafo in enumerate(transforms):
            if not isinstance(trafo, torch.nn.Module):
                transforms[idx] = _TransformWrapper(trafo)

        self._transforms = torch.nn.ModuleList(transforms)
        self.transform_order = list(range(len(self.transforms)))

    @property
    def shuffle(self) -> bool:
        """
        Getter for attribute shuffle

        Returns:
            bool: True if shuffle is enabled, False otherwise
        """
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle: bool):
        """
        Setter for shuffle

        Args:
            shuffle: new status of shuffle
        """
        self._shuffle = shuffle
        self.transform_order = list(range(len(self.transforms)))
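

# A minimal usage sketch for ``Compose``. The small functions below stand in
# for real rising transforms; any callable works, because non-modules are
# wrapped in ``_TransformWrapper`` by the ``transforms`` setter.
# ``_example_compose`` is an illustrative name, not part of rising.
def _example_compose() -> dict:
    def double(data: torch.Tensor) -> dict:
        return {"data": data * 2}

    def add_one(data: torch.Tensor) -> dict:
        return {"data": data + 1}

    pipeline = Compose(double, add_one)  # applied in order: double, then add_one
    return pipeline(data=torch.ones(1))  # {"data": tensor([3.])}

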
class DropoutCompose(Compose):
    """
    Compose multiple transforms into one and randomly apply them
    """

    def __init__(
        self,
        *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
        dropout: Union[float, Sequence[float]] = 0.5,
        shuffle: bool = False,
        random_sampler: Optional[ContinuousParameter] = None,
        transform_call: Callable[[Any, Callable], Any] = dict_call,
        **kwargs,
    ):
        """
        Args:
            *transforms: one or multiple transformations which are applied in
                consecutive order
            dropout: if provided as float, each transform is skipped with the
                given probability.
                If :attr:`dropout` is a sequence, it needs to specify the
                dropout probability for each given transform.
            shuffle: apply transforms in random order
            random_sampler: a continuous parameter sampler. Samples a
                random value for each of the transforms.
            transform_call: function which determines how transforms are
                called. By default, mappings and sequences are unpacked
                during the transform.

        Raises:
            ValueError: if :attr:`dropout` is a sequence, it must have the
                same length as :attr:`transforms`
        """
        super().__init__(*transforms, transform_call=transform_call, shuffle=shuffle, **kwargs)
        if random_sampler is None:
            random_sampler = UniformParameter(0.0, 1.0)

        self.register_sampler("prob", random_sampler, size=(len(self.transforms),))

        if check_scalar(dropout):
            dropout = [dropout] * len(self.transforms)
        if len(dropout) != len(self.transforms):
            raise ValueError(
                f"If dropout is a sequence it must specify the "
                f"dropout probability for each transform, "
                f"found {len(dropout)} probabilities "
                f"and {len(self.transforms)} transforms."
            )
        self.dropout = dropout

    def forward(self, *seq_like, **map_like) -> Union[Sequence, Mapping]:
        """
        Apply transforms in consecutive order. Can handle either
        sequence-like or mapping-like data.

        Args:
            *seq_like: data which is unpacked like a sequence
            **map_like: data which is unpacked like a dict

        Returns:
            Union[Sequence, Mapping]: transformed data
        """
        assert not (seq_like and map_like)
        assert len(self.transforms) == len(self.transform_order)
        data = seq_like if seq_like else map_like

        rand = self.prob
        for idx in self.transform_order:
            if rand[idx] > self.dropout[idx]:
                data = self.transform_call(data, self.transforms[idx])
        return data
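

# A usage sketch for ``DropoutCompose``: with ``dropout=0.5`` each transform is
# skipped independently with probability 0.5 on every call. Seeding is assumed
# to make the outcome reproducible, provided the default ``UniformParameter``
# draws through torch's RNG. ``_example_dropout_compose`` is an illustrative
# name, not part of rising.
def _example_dropout_compose() -> dict:
    def add_one(data: torch.Tensor) -> dict:
        return {"data": data + 1}

    torch.manual_seed(0)
    aug = DropoutCompose(add_one, add_one, dropout=0.5)
    return aug(data=torch.zeros(1))  # "data" ends up as 0., 1. or 2.

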
class OneOf(AbstractTransform):
    """
    Apply one of the given transforms
    """

    def __init__(
        self,
        *transforms: Union[AbstractTransform, Sequence[AbstractTransform]],
        weights: Optional[Sequence[float]] = None,
        p: float = 1.0,
        transform_call: Callable[[Any, Callable], Any] = dict_call,
    ):
        """
        Args:
            *transforms: transforms to choose from
            weights: sampling weights for the transforms; defaults to a
                uniform distribution
            p: probability that one transform is applied at all
            transform_call: function which determines how transforms are
                called. By default, mappings and sequences are unpacked
                during the transform.
        """
        super().__init__(grad=True)
        if len(transforms) > 0 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]
        if not transforms:
            raise ValueError("At least one transformation needs to be selected.")
        self.transforms = transforms

        if weights is not None and len(weights) != len(transforms):
            raise ValueError(
                "If weights are provided, every transform needs a weight. "
                f"Found {len(weights)} weights and {len(transforms)} transforms."
            )
        if weights is None:
            self.weights = torch.tensor([1 / len(self.transforms)] * len(self.transforms))
        else:
            self.weights = torch.tensor(weights)

        self.p = p
        self.transform_call = transform_call

    def forward(self, **data) -> dict:
        """
        Apply one randomly chosen transform with probability :attr:`p`

        Args:
            **data: data which is unpacked like a dict

        Returns:
            dict: transformed data
        """
        if torch.rand(1) < self.p:
            index = torch.multinomial(self.weights, 1)
            data = self.transform_call(data, self.transforms[int(index)])
        return data
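

# A usage sketch for ``OneOf``: exactly one transform is drawn per call
# according to ``weights``, and with probability ``1 - p`` the data passes
# through unchanged. ``_example_one_of`` is an illustrative name, not part
# of rising.
def _example_one_of() -> dict:
    def negate(data: torch.Tensor) -> dict:
        return {"data": -data}

    def add_ten(data: torch.Tensor) -> dict:
        return {"data": data + 10}

    pick = OneOf(negate, add_ten, weights=[0.9, 0.1], p=0.8)
    return pick(data=torch.ones(1))  # one of -1., 11., or an unchanged 1.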