#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABCMeta, abstractmethod

import pandas as pd

from typing import (
    Any,
    Generic,
    List,
    Optional,
    TypeVar,
    Union,
    TYPE_CHECKING,
    Tuple,
    Callable,
)

from pyspark import since
from pyspark.ml.common import inherit_doc
from pyspark.sql.dataframe import DataFrame
from pyspark.ml.param import Params
from pyspark.mlv2.util import transform_dataframe_column

if TYPE_CHECKING:
    from pyspark.ml._typing import ParamMap


M = TypeVar("M", bound="Transformer")


@inherit_doc
class Estimator(Params, Generic[M], metaclass=ABCMeta):
    """
    Abstract class for estimators that fit models to data.

    .. versionadded:: 3.5.0
    """

    @abstractmethod
    def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> M:
        """
        Fits a model to the input dataset. This is called by the default implementation
        of :py:meth:`fit`.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame` or :py:class:`pandas.DataFrame`
            input dataset

        Returns
        -------
        :class:`Transformer`
            fitted model
        """
        raise NotImplementedError()
    def fit(
        self,
        dataset: Union[DataFrame, pd.DataFrame],
        params: Optional["ParamMap"] = None,
    ) -> Union[M, List[M]]:
        """
        Fits a model to the input dataset with optional parameters.

        .. versionadded:: 3.5.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame` or :py:class:`pandas.DataFrame`
            input dataset; it can be either a pandas DataFrame or a Spark DataFrame.
        params : dict, optional
            an optional param map that overrides embedded params.

        Returns
        -------
        :py:class:`Transformer`
            fitted model
        """
        if params is None:
            params = dict()
        if isinstance(params, dict):
            if params:
                return self.copy(params)._fit(dataset)
            else:
                return self._fit(dataset)
        else:
            raise TypeError(
                "Params must be either a param map or a list/tuple of param maps, "
                "but got %s." % type(params)
            )
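

# A minimal sketch of a concrete ``Estimator``, kept as a comment for illustration
# only: the class name, the "feature" column, and the learned statistic are
# hypothetical, not part of this module's API, and it assumes a ``MyScalerModel``
# Transformer subclass defined elsewhere. Only ``_fit`` needs to be implemented;
# the public ``fit`` and its param-map handling are inherited from this base class.
#
#     class MyScalerEstimator(Estimator["MyScalerModel"]):
#         def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "MyScalerModel":
#             pdf = dataset if isinstance(dataset, pd.DataFrame) else dataset.toPandas()
#             mean = float(pdf["feature"].mean())  # learn a statistic from the data
#             return MyScalerModel(mean=mean)  # wrap it in a fitted Transformer
#
#     model = MyScalerEstimator().fit(train_df)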
_SPARKML_TRANSFORMER_TMP_OUTPUT_COLNAME = "_sparkML_transformer_tmp_output"


@inherit_doc
class Transformer(Params, metaclass=ABCMeta):
    """
    Abstract class for transformers that transform one dataset into another.

    .. versionadded:: 3.5.0
    """

    def _input_column_name(self) -> str:
        """
        Return the name of the input column that is transformed.
        """
        raise NotImplementedError()
    def _output_columns(self) -> List[Tuple[str, str]]:
        """
        Return a list of output transformed columns; each element in the list
        is a tuple of (column_name, column_spark_type).
        """
        raise NotImplementedError()
    def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
        """
        Return a transformation function that accepts an instance of `pd.Series` as input
        and returns the transformed result as an instance of `pd.Series` or `pd.DataFrame`.
        If there is only one output column, the transformed result must be an instance of
        `pd.Series`; if there are multiple output columns, the transformed result must be
        an instance of `pd.DataFrame` with column names matching the output schema
        returned by the `_output_columns` interface.
        """
        raise NotImplementedError()
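
    # To illustrate the contract above (a hedged sketch; these functions are
    # hypothetical, not part of the API): a single output column maps a pd.Series
    # to a pd.Series,
    #
    #     def fn(s: pd.Series) -> pd.Series:
    #         return s * 2.0
    #
    # while multiple output columns require a pd.DataFrame whose column names match
    # ``_output_columns``,
    #
    #     def fn(s: pd.Series) -> pd.DataFrame:
    #         return pd.DataFrame({"doubled": s * 2.0, "squared": s * s})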

    def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
        """
        Transforms the input dataset.
        The dataset can be either a pandas DataFrame or a Spark DataFrame; if it is a
        pandas DataFrame, the transformation is applied locally without creating any
        Spark jobs.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame` or :py:class:`pandas.DataFrame`
            input dataset.

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame` or :py:class:`pandas.DataFrame`
            transformed dataset; the type of the output dataframe is consistent with
            the type of the input dataframe.
        """
        input_col_name = self._input_column_name()
        transform_fn = self._get_transform_fn()
        output_cols = self._output_columns()

        return transform_dataframe_column(
            dataset,
            input_col_name=input_col_name,
            transform_fn=transform_fn,
            output_cols=output_cols,
        )
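

# A minimal sketch of a concrete ``Transformer``, kept as a comment for illustration
# only; the class name, columns, and centering logic are hypothetical. Implementing
# the three hooks above is all that is needed: ``transform`` then works unchanged on
# both pandas and Spark inputs via ``transform_dataframe_column``. A fitted model
# would typically subclass ``Model`` (defined at the end of this file) instead.
#
#     class MyScalerModel(Transformer):
#         def __init__(self, mean: float) -> None:
#             super().__init__()
#             self.mean = mean
#
#         def _input_column_name(self) -> str:
#             return "feature"
#
#         def _output_columns(self) -> List[Tuple[str, str]]:
#             return [("scaled_feature", "double")]  # single output column
#
#         def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
#             mean = self.mean
#             return lambda s: s - mean  # one output column -> return a pd.Series
#
#     scaled = MyScalerModel(mean=1.5).transform(df)  # pandas in -> pandas out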


@inherit_doc
class Evaluator(Params, metaclass=ABCMeta):
    """
    Base class for evaluators that compute metrics from predictions.

    .. versionadded:: 3.5.0
    """
    @abstractmethod
    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
        """
        Evaluates the output.

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame` or :py:class:`pandas.DataFrame`
            a dataset that contains labels/observations and predictions

        Returns
        -------
        float
            metric
        """
        raise NotImplementedError()
    def evaluate(
        self,
        dataset: Union[DataFrame, pd.DataFrame],
        params: Optional["ParamMap"] = None,
    ) -> float:
        """
        Evaluates the output with optional parameters.

        .. versionadded:: 3.5.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame` or :py:class:`pandas.DataFrame`
            a dataset that contains labels/observations and predictions
        params : dict, optional
            an optional param map that overrides embedded params

        Returns
        -------
        float
            metric
        """
        if params is None:
            params = dict()
        if isinstance(params, dict):
            if params:
                return self.copy(params)._evaluate(dataset)
            else:
                return self._evaluate(dataset)
        else:
            raise TypeError("Params must be a param map but got %s." % type(params))
@since("1.5.0")
def isLargerBetter(self) -> bool:
"""
Indicates whether the metric returned by :py:meth:`evaluate` should be maximized
(True, default) or minimized (False).
A given evaluator may support multiple metrics which may be maximized or minimized.
"""
return True
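

# A minimal sketch of a concrete ``Evaluator``, kept as a comment for illustration
# only; the class name and the "prediction"/"label" columns are hypothetical. Only
# ``_evaluate`` needs to be implemented, and ``isLargerBetter`` is overridden because
# a lower squared error is better.
#
#     class MySquaredErrorEvaluator(Evaluator):
#         def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
#             pdf = dataset if isinstance(dataset, pd.DataFrame) else dataset.toPandas()
#             return float(((pdf["prediction"] - pdf["label"]) ** 2).mean())
#
#         def isLargerBetter(self) -> bool:
#             return False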


@inherit_doc
class Model(Transformer, metaclass=ABCMeta):
    """
    Abstract class for models that are fitted by estimators.

    .. versionadded:: 3.5.0
    """

    pass
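

# Putting the pieces together (a hedged usage sketch reusing the hypothetical classes
# from the comments above): ``fit`` produces a fitted model, and ``transform`` keeps
# the output type consistent with the input, pandas in/pandas out or Spark in/Spark out.
#
#     model = MyScalerEstimator().fit(train_df)
#     predictions = model.transform(test_df)
#     score = MySquaredErrorEvaluator().evaluate(predictions)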