-
Notifications
You must be signed in to change notification settings - Fork 1k
/
compose.py
301 lines (233 loc) · 12.3 KB
/
compose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of generic interfaces for MONAI transforms.
"""
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Hashable, Optional, Sequence, Tuple, Union
import numpy as np
from monai.config import KeysCollection
from monai.transforms.utils import apply_transform
from monai.utils import MAX_SEED, ensure_tuple, get_seed
# Public names of this module, i.e. what ``from <this module> import *`` exposes.
__all__ = ["Transform", "Randomizable", "Compose", "MapTransform"]
class Transform(ABC):
    """
    An abstract class of a ``Transform``.
    A transform is a callable that processes ``data``.

    It could be stateful and may modify ``data`` in place,
    so the implementation should be aware of:

        #. thread safety when mutating its own states.
           When used from a multi-process context, transform's instance variables are read-only.
        #. ``data`` content unused by this transform may still be used in the
           subsequent transforms in a composed transform.
        #. storing too much information in ``data`` may not scale.

    See Also

        :py:class:`monai.transforms.Compose`
    """

    @abstractmethod
    def __call__(self, data: Any):
        """
        ``data`` is an element which often comes from an iteration over an
        iterable, such as :py:class:`torch.utils.data.Dataset`. This method should
        return an updated version of ``data``.
        To simplify the input validations, most of the transforms assume that

        - ``data`` is a Numpy ndarray, PyTorch Tensor or string
        - the data shape can be:

          #. string data without shape, `LoadImage` transform expects file paths
          #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
             except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and
             `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
          #. most of the post-processing transforms expect
             ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])``

        - the channel dimension is not omitted even if number of channels is one

        This method can optionally take additional arguments to help execute transformation operation.

        Args:
            data: the element to transform.

        Returns:
            The updated version of ``data`` (defined by the concrete subclass).

        Raises:
            NotImplementedError: When this base implementation is executed, i.e. when a
                subclass override delegates to it. A subclass that does not override this
                method at all cannot be instantiated, because the method is abstract.

        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
class Randomizable(ABC):
    """
    An interface for handling random state locally, currently based on a class variable `R`,
    which is an instance of `np.random.RandomState`.
    This is mainly for randomized data augmentation transforms. For example::

        class RandShiftIntensity(Randomizable):
            def randomize(self, data=None):
                self._offset = self.R.uniform(low=0, high=100)

            def __call__(self, img):
                self.randomize()
                return img + self._offset

        transform = RandShiftIntensity()
        transform.set_random_state(seed=0)

    """

    # Class-level default generator: it is shared by every instance until
    # ``set_random_state`` assigns an instance-level ``R``.
    R: np.random.RandomState = np.random.RandomState()

    def set_random_state(
        self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
    ) -> "Randomizable":
        """
        Set the random state locally, to control the randomness, the derived
        classes should use :py:attr:`self.R` instead of `np.random` to introduce random
        factors.

        ``seed`` takes precedence: when it is not None, ``state`` is ignored (and not
        validated). When both are None, a new unseeded `np.random.RandomState` is created.

        Args:
            seed: set the random state with an integer seed. The value is reduced
                modulo ``MAX_SEED``; a non-integer value is replaced by ``id(seed)`` first.
            state: set the random state with a `np.random.RandomState` object.

        Raises:
            TypeError: When ``state`` is not an ``Optional[np.random.RandomState]``.

        Returns:
            a Randomizable instance.

        """
        if seed is not None:
            # Non-integer seeds cannot be passed to RandomState directly, so the
            # object's identity is used as the integer seed instead.
            _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed
            # Fold the value into the range accepted as a seed.
            _seed = _seed % MAX_SEED
            self.R = np.random.RandomState(_seed)
            return self

        if state is not None:
            if not isinstance(state, np.random.RandomState):
                raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.")
            # The given generator is stored as-is (not copied).
            self.R = state
            return self

        # Neither seed nor state was given: start from a fresh generator.
        self.R = np.random.RandomState()
        return self

    @abstractmethod
    def randomize(self, data: Any) -> None:
        """
        Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors.

        All :py:attr:`self.R` calls happen here so that we have a better chance to
        identify errors of sync the random state.

        This method can generate the random factors based on properties of the input data.

        Args:
            data: the input that the random factors may be derived from.

        Raises:
            NotImplementedError: When this base implementation is executed, i.e. when a
                subclass override delegates to it. A subclass that does not override this
                method at all cannot be instantiated, because the method is abstract.

        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
class Compose(Randomizable, Transform):
    """
    ``Compose`` provides the ability to chain a series of calls together in a
    sequence. Each transform in the sequence must take a single argument and
    return a single value, so that the transforms can be called in a chain.

    ``Compose`` can be used in two ways:

    #. With a series of transforms that accept and return a single
       ndarray / tensor / tensor-like parameter.
    #. With a series of transforms that accept and return a dictionary that
       contains one or more parameters. Such transforms must have pass-through
       semantics; unused values in the dictionary must be copied to the return
       dictionary. It is required that the dictionary is copied between input
       and output of each transform.

    If some transform generates a list batch of data in the transform chain,
    every item in the list is still a dictionary, and all the following
    transforms will apply to every item of the list, for example:

    #. transformA normalizes the intensity of 'img' field in the dict data.
    #. transformB crops out a list batch of images on 'img' and 'seg' field.
       And constructs a list of dict data, other fields are copied::

            {                          [{                   {
                'img': [1, 2],              'img': [1],         'img': [2],
                'seg': [1, 2],              'seg': [1],         'seg': [2],
                'extra': 123,    -->        'extra': 123,       'extra': 123,
                'shape': 'CHWD'             'shape': 'CHWD'     'shape': 'CHWD'
            }                          },                   }]

    #. transformC then randomly rotates or flips 'img' and 'seg' fields of
       every dictionary item in the list.

    The composed transforms will be set the same global random seed if user called
    `set_determinism()`.

    When using the pass-through dictionary operation, you can make use of
    :class:`monai.transforms.adaptors.adaptor` to wrap transforms that don't conform
    to the requirements. This approach allows you to use transforms from
    otherwise incompatible libraries with minimal additional work.

    Note:

        In many cases, Compose is not the best way to create pre-processing
        pipelines. Pre-processing is often not a strictly sequential series of
        operations, and much of the complexity arises when a not-sequential
        set of functions must be called as if it were a sequence.

        Example: images and labels
        Images typically require some kind of normalization that labels do not.
        Both are then typically augmented through the use of random rotations,
        flips, and deformations.
        Compose can be used with a series of transforms that take a dictionary
        that contains 'image' and 'label' entries. This might require wrapping
        `torchvision` transforms before passing them to compose.
        Alternatively, one can create a class with a `__call__` function that
        calls your pre-processing functions taking into account that not all of
        them are called on the labels.
    """

    def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = None) -> None:
        """
        Args:
            transforms: a single callable or a sequence of callables to chain;
                ``None`` is treated as an empty sequence. The value is normalized
                with ``ensure_tuple`` and stored in ``self.transforms``.

        """
        if transforms is None:
            transforms = []
        self.transforms = ensure_tuple(transforms)
        # Seed this object, and through it every Randomizable child, from the
        # value returned by ``get_seed()``.
        self.set_random_state(seed=get_seed())

    def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
        """
        Set the random state of this ``Compose`` and re-seed its ``Randomizable`` transforms.

        The random state of ``self`` is set first (see ``Randomizable.set_random_state``);
        then, in sequence order, every transform that is a ``Randomizable`` receives
        its own integer seed drawn from ``self.R``. Other transforms are skipped.

        Args:
            seed: set the random state with an integer seed.
            state: set the random state with a `np.random.RandomState` object.

        Returns:
            this ``Compose`` instance.

        """
        super().set_random_state(seed=seed, state=state)
        for _transform in self.transforms:
            if isinstance(_transform, Randomizable):
                # One draw per child, so each child gets its own seed.
                _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32"))
        return self

    def randomize(self, data: Optional[Any] = None) -> None:
        """
        Call ``randomize(data)`` on every ``Randomizable`` transform of the sequence.

        A transform whose ``randomize`` raises ``TypeError`` is not randomized: a
        ``RuntimeWarning`` is emitted for it and the remaining transforms are still processed.

        Args:
            data: forwarded unchanged to each transform's ``randomize``.

        """
        for _transform in self.transforms:
            if not isinstance(_transform, Randomizable):
                continue
            try:
                _transform.randomize(data)
            except TypeError as type_error:
                # Best effort: report the transform that could not be randomized
                # instead of aborting the whole pass.
                tfm_name: str = type(_transform).__name__
                warnings.warn(
                    f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning
                )

    def __call__(self, input_):
        """
        Run ``input_`` through every transform, in sequence order.

        Each transform is invoked through ``apply_transform`` and its result becomes
        the input of the next one. With no transforms, ``input_`` is returned unchanged.

        Args:
            input_: the data to process.

        Returns:
            the output of the last transform.

        """
        for _transform in self.transforms:
            input_ = apply_transform(_transform, input_)
        return input_
class MapTransform(Transform):
    """
    A subclass of :py:class:`monai.transforms.Transform` with an assumption
    that the ``data`` input of ``self.__call__`` is a MutableMapping such as ``dict``.

    The ``keys`` parameter will be used to get and set the actual data
    item to transform. That is, the callable of this transform should
    follow the pattern:

        .. code-block:: python

            def __call__(self, data):
                for key in self.keys:
                    if key in data:
                        # update output data with some_transform_function(data[key]).
                    else:
                        # do nothing or some exceptions handling.
                return data

    Raises:
        ValueError: When ``keys`` is an empty iterable.
        TypeError: When ``keys`` type is not in ``Union[Hashable, Iterable[Hashable]]``.

    """

    def __init__(self, keys: KeysCollection) -> None:
        """
        Args:
            keys: a single key or a collection of keys selecting the items to transform.
                The value is normalized with ``ensure_tuple`` and stored in ``self.keys``.

        Raises:
            ValueError: When the normalized ``keys`` is empty.
            TypeError: When any normalized key is not ``Hashable``.

        """
        self.keys: Tuple[Hashable, ...] = ensure_tuple(keys)
        if not self.keys:
            raise ValueError("keys must be non empty.")
        for key in self.keys:
            if not isinstance(key, Hashable):
                # The message reports the type of the ``keys`` argument as passed in,
                # not the type of the offending element.
                raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.")

    @abstractmethod
    def __call__(self, data):
        """
        ``data`` often comes from an iteration over an iterable,
        such as :py:class:`torch.utils.data.Dataset`.

        To simplify the input validations, this method assumes:

        - ``data`` is a Python dictionary
        - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element
          of ``self.keys``, the data shape can be:

          #. string data without shape, `LoadImaged` transform expects file paths
          #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,
             except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and
             `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels)
          #. most of the post-processing transforms expect
             ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])``

        - the channel dimension is not omitted even if number of channels is one

        Raises:
            NotImplementedError: When this base implementation is executed, i.e. when a
                subclass override delegates to it. A subclass that does not override this
                method at all cannot be instantiated, because the method is abstract.

        Returns:
            An updated dictionary version of ``data`` by applying the transform.

        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")