/
base.py
210 lines (183 loc) · 7.92 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
import abc
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Sequence
from typing import TypeVar
import apache_beam as beam
# Public names exported by a star-import of this module.
__all__ = ['MLTransform', 'ProcessHandler', 'BaseOperation']
# Type variables for a transformed dataset and its metadata. They are not
# referenced by the definitions visible in this file.
TransformedDatasetT = TypeVar('TransformedDatasetT')
TransformedMetadataT = TypeVar('TransformedMetadataT')
# Element types of the input and output PCollections of MLTransform.
ExampleT = TypeVar('ExampleT')
MLTransformOutputT = TypeVar('MLTransformOutputT')
# Input to the apply_transform() method of BaseOperation.
OperationInputT = TypeVar('OperationInputT')
# Value type of the dict returned by the apply_transform() method of
# BaseOperation.
OperationOutputT = TypeVar('OperationOutputT')
class ArtifactMode:
  """String constants naming the artifact modes accepted by MLTransform.

  The values are plain strings so they can be compared against, and used
  interchangeably with, the literals 'produce' and 'consume'.
  """
  # Artifacts are computed and stored in the artifact location.
  PRODUCE = 'produce'
  # Previously stored artifacts are read back from the artifact location.
  CONSUME = 'consume'
class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
  """
  Abstract base class for data processing operations.

  Subclasses implement apply_transform() and get_artifacts(). Calling an
  instance runs both and returns their results merged into a single dict.
  """
  def __init__(self, columns: List[str]) -> None:
    """
    Base Operation class for data processing transformations.

    Args:
      columns: List of column names to apply the transformation.
    """
    self.columns = columns

  @abc.abstractmethod
  def apply_transform(
      self,
      data: OperationInputT,
      output_column_name: str,
  ) -> Dict[str, OperationOutputT]:
    """
    Define the processing logic of the operation.

    Args:
      data: input data.
      output_column_name: name associated with the output of the operation.

    Returns:
      A dict mapping string keys to the transformed output.
    """

  @abc.abstractmethod
  def get_artifacts(
      self,
      data: OperationInputT,
      output_column_prefix: str,
  ) -> Optional[Dict[str, OperationOutputT]]:
    """
    Return the artifacts generated by the operation, if any.

    Args:
      data: input data.
      output_column_prefix: __call__() passes the same name it gives to
        apply_transform() here.

    Returns:
      A dict of artifacts, or None when the operation generates none.
    """

  def __call__(
      self,
      data: OperationInputT,
      output_column_name: str,
  ) -> Dict[str, OperationOutputT]:
    """
    Run the operation when the instance is called.

    Invokes apply_transform() and then get_artifacts() with the same
    arguments.

    Returns:
      The dict returned by apply_transform(), updated with the entries from
      get_artifacts() when that result is non-empty. On a key collision the
      artifact value wins.
    """
    outputs = self.apply_transform(data, output_column_name)
    extra = self.get_artifacts(data, output_column_name)
    if not extra:
      # None or an empty dict: nothing to merge.
      return outputs
    return {**outputs, **extra}
class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC):
  """
  Abstract interface that MLTransform delegates its data processing to.

  Only for internal use. No backwards compatibility guarantees.
  """
  @abc.abstractmethod
  def process_data(
      self,
      pcoll: beam.PCollection[ExampleT],
  ) -> beam.PCollection[MLTransformOutputT]:
    """
    Process the incoming data.

    This is the entrypoint that MLTransform.expand() calls.

    Args:
      pcoll: the PCollection to process.

    Returns:
      The PCollection of processed data.
    """

  @abc.abstractmethod
  def append_transform(self, transform: BaseOperation):
    """
    Append a transform to the ProcessHandler.

    Args:
      transform: the BaseOperation to add.
    """
class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
                                  beam.PCollection[MLTransformOutputT]],
                  Generic[ExampleT, MLTransformOutputT]):
  def __init__(
      self,
      *,
      artifact_location: str,
      artifact_mode: str = ArtifactMode.PRODUCE,
      transforms: Optional[Sequence[BaseOperation]] = None):
    """
    MLTransform is a Beam PTransform that can be used to apply
    transformations to the data. MLTransform is used to wrap the
    data processing transforms provided by Apache Beam. MLTransform
    works in two modes: produce and consume. In the produce mode,
    MLTransform will apply the transforms to the data and store the
    artifacts in the artifact_location. In the consume mode, MLTransform
    will read the artifacts from the artifact_location and apply the
    transforms to the data. The artifact_location should be a valid
    storage path where the artifacts can be written to or read from.

    Note that when consuming artifacts, it is not necessary to pass the
    transforms since they are inherently stored within the artifacts
    themselves.

    Args:
      artifact_location: A storage location for artifacts resulting from
        MLTransform. These artifacts include transformations applied to
        the dataset and generated values like min, max from ScaleTo01,
        and mean, var from ScaleToZScore. Artifacts are produced and stored
        in this location when the `artifact_mode` is set to 'produce'.
        Conversely, when `artifact_mode` is set to 'consume', artifacts are
        retrieved from this location. The value assigned to
        `artifact_location` should be a valid storage path where the
        artifacts can be written to or read from.
      artifact_mode: Whether to produce or consume artifacts. If set to
        'consume', MLTransform will assume that the artifacts are already
        computed and stored in the artifact_location. Pass the same artifact
        location that was passed during produce phase to ensure that the
        right artifacts are read. If set to 'produce', MLTransform
        will compute the artifacts and store them in the artifact_location.
        The artifacts will be read from this location during the consume
        phase.
      transforms: A list of transforms to apply to the data. All the
        transforms are applied in the order they are specified. The input of
        the i-th transform is the output of the (i-1)-th transform.
        Multi-input transforms are not supported yet.

    Raises:
      TypeError: if any element of `transforms` is not a BaseOperation.
    """
    if transforms:
      # Validate eagerly so a bad transform fails at construction time.
      for transform in transforms:
        self._validate_transform(transform)

    # avoid circular import
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.ml.transforms.handlers import TFTProcessHandler
    # TODO: When new ProcessHandlers(eg: JaxProcessHandler) are introduced,
    # create a mapping between transforms and ProcessHandler since
    # ProcessHandler is not exposed to the user.
    process_handler: ProcessHandler = TFTProcessHandler(
        artifact_location=artifact_location,
        artifact_mode=artifact_mode,
        transforms=transforms)  # type: ignore[arg-type]

    self._process_handler = process_handler

  def expand(
      self, pcoll: beam.PCollection[ExampleT]
  ) -> beam.PCollection[MLTransformOutputT]:
    """
    This is the entrypoint for the MLTransform. This method will
    invoke the process_data() method of the ProcessHandler instance
    to process the incoming data.

    process_data takes in a PCollection and applies the PTransforms
    necessary to process the data and returns a PCollection of
    transformed data.

    Args:
      pcoll: A PCollection of ExampleT type.

    Returns:
      A PCollection of MLTransformOutputT type.
    """
    return self._process_handler.process_data(pcoll)

  def with_transform(self, transform: BaseOperation) -> 'MLTransform':
    """
    Add a transform to the MLTransform pipeline.

    The transform is validated and then appended to the underlying
    ProcessHandler. This instance is modified in place.

    Args:
      transform: A BaseOperation instance.

    Returns:
      This MLTransform instance, so that calls can be chained.

    Raises:
      TypeError: if `transform` is not a BaseOperation.
    """
    self._validate_transform(transform)
    self._process_handler.append_transform(transform)
    return self

  def _validate_transform(self, transform) -> None:
    """
    Check that `transform` is a BaseOperation instance.

    Raises:
      TypeError: if `transform` is not an instance of BaseOperation.
    """
    if not isinstance(transform, BaseOperation):
      raise TypeError(
          'transform must be a subclass of BaseOperation. '
          'Got: %s instead.' % type(transform))