Skip to content

Commit

Permalink
Add TargetEncoder wrapper (as of scikit-learn 1.3)
Browse files Browse the repository at this point in the history
More sklearn schema updates

Signed-off-by: Avi Shinnar <shinnar@us.ibm.com>
  • Loading branch information
shinnar committed Feb 7, 2024
1 parent df4225b commit fe00ebd
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ graphviz
hyperopt
jsonschema
jsonsubschema
scikit-learn>=1.0.0,<=1.2.0
scikit-learn>=1.0.0,<1.4
scipy
pandas
decorator
Expand Down
73 changes: 71 additions & 2 deletions lale/lib/autogen/fast_ica.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.decomposition import FastICA as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _FastICAImpl:
Expand Down Expand Up @@ -173,4 +173,73 @@ def transform(self, X):
}
FastICA = make_operator(_FastICAImpl, _combined_schemas)

if sklearn_version >= version.Version("1.1"):
FastICA = FastICA.customize_schema(
whiten={
"anyOf": [
{
"enum": [False],
"description": "The data is already considered to be whitened, and no whitening is performed.",
},
{
"enum": ["arbitrary-variance"],
"description": "(default) A whitening with variance arbitrary is used",
},
{
"enum": ["unit-variance"],
"description": "The whitening matrix is rescaled to ensure that each recovered source has unit variance.",
},
{
"enum": [True, "warn"],
"description": "deprecated. Use 'arbitrary-variance' instead",
},
],
"description": "Specify the whitening strategy to use.",
"default": "warn",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.1"):
FastICA = FastICA.customize_schema(
whiten_solver={
"anyOf": [
{
"enum": ["eigh"],
"description": "Generally more memory efficient when n_samples >= n_features, and can be faster when n_samples >= 50 * n_features.",
},
{
"enum": ["svd"],
"description": "More stable numerically if the problem is degenerate, and often faster when n_samples <= n_features.",
},
],
"description": "The solver to use for whitening.",
"default": "svd",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.3"):
FastICA = FastICA.customize_schema(
whiten={
"anyOf": [
{
"enum": [False],
"description": "The data is already considered to be whitened, and no whitening is performed.",
},
{
"enum": ["arbitrary-variance"],
"description": "A whitening with variance arbitrary is used",
},
{
"enum": ["unit-variance"],
"description": "The whitening matrix is rescaled to ensure that each recovered source has unit variance.",
},
],
"description": "Specify the whitening strategy to use.",
"default": "arbitrary-variance",
},
set_as_available=True,
)

set_docstrings(FastICA)
5 changes: 5 additions & 0 deletions lale/lib/autogen/mini_batch_sparse_pca.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from numpy import inf, nan
from packaging import version
from sklearn.decomposition import MiniBatchSparsePCA as Op
from sklearn.utils._available_if import available_if

from lale.docstrings import set_docstrings
from lale.operators import make_operator, sklearn_version
Expand All @@ -21,6 +22,10 @@ def fit(self, X, y=None):
def transform(self, X):
return self._wrapped_model.transform(X)

@available_if(lambda self: (hasattr(self._wrapped_model, "inverse_transform")))
def inverse_transform(self, X):
return self._wrapped_model.inverse_transform(X)


_hyperparams_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
Expand Down
5 changes: 5 additions & 0 deletions lale/lib/autogen/sparse_pca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from numpy import inf, nan
from sklearn.decomposition import SparsePCA as Op
from sklearn.utils._available_if import available_if

from lale.docstrings import set_docstrings
from lale.operators import make_operator
Expand All @@ -20,6 +21,10 @@ def fit(self, X, y=None):
def transform(self, X):
return self._wrapped_model.transform(X)

@available_if(lambda self: (hasattr(self._wrapped_model, "inverse_transform")))
def inverse_transform(self, X):
return self._wrapped_model.inverse_transform(X)


_hyperparams_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
Expand Down
3 changes: 3 additions & 0 deletions lale/lib/sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
* lale.lib.sklearn. `SelectKBest`_
* lale.lib.sklearn. `SimpleImputer`_
* lale.lib.sklearn. `StandardScaler`_
* lale.lib.sklearn. `TargetEncoder`_
* lale.lib.sklearn. `TfidfVectorizer`_
* lale.lib.sklearn. `VarianceThreshold`_
Expand Down Expand Up @@ -130,6 +131,7 @@
.. _`Nystroem`: lale.lib.sklearn.nystroem.html
.. _`OneHotEncoder`: lale.lib.sklearn.one_hot_encoder.html
.. _`OrdinalEncoder`: lale.lib.sklearn.ordinal_encoder.html
.. _`TargetEncoder`: lale.lib.sklearn.target_encoder.html
.. _`PassiveAggressiveClassifier`: lale.lib.sklearn.passive_aggressive_classifier.html
.. _`PCA`: lale.lib.sklearn.pca.html
.. _`Perceptron`: lale.lib.sklearn.perceptron.html
Expand Down Expand Up @@ -231,6 +233,7 @@
from .standard_scaler import StandardScaler as StandardScaler
from .svc import SVC as SVC
from .svr import SVR as SVR
from .target_encoder import TargetEncoder as TargetEncoder
from .tfidf_vectorizer import TfidfVectorizer as TfidfVectorizer
from .variance_threshold import VarianceThreshold as VarianceThreshold
from .voting_classifier import VotingClassifier as VotingClassifier
Expand Down
225 changes: 225 additions & 0 deletions lale/lib/sklearn/target_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2019 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sklearn.preprocessing
from packaging import version

import lale.docstrings
import lale.operators


class _TargetEncoderNotFoundImpl:
def __init__(self, **hyperparams):
raise NotImplementedError(
"TargetEncoder is only available with scikit-learn versions >= 1.3"
)

def transform(self, X):
raise NotImplementedError(
"TargetEncoder is only available with scikit-learn versions >= 1.3"
)


_hyperparams_schema = {
"description": "Hyperparameter schema for the TargetEncoder model from scikit-learn.",
"allOf": [
{
"type": "object",
"additionalProperties": False,
"required": ["categories", "target_type"],
"relevantToOptimizer": [],
"properties": {
"categories": {
"anyOf": [
{
"description": "Determine categories automatically from training data.",
"enum": ["auto"],
},
{
"description": "The ith list element holds the categories expected in the ith column.",
"type": "array",
"items": {
"anyOf": [
{
"type": "array",
"items": {"type": "string"},
},
{
"type": "array",
"items": {"type": "number"},
"description": "Should be sorted.",
},
]
},
},
],
"default": "auto",
"description": "Categories (unique values) per feature.",
},
"target_type": {
"anyOf": [
{
"enum": ["auto"],
"description": "Type of target is inferred with type_of_target.",
},
{"enum": ["continuous"], "description": "Continuous target"},
{"enum": ["binary"], "description": "Binary target"},
],
"description": "Type of target.",
"default": "auto",
},
"smooth": {
"anyOf": [
{
"enum": ["auto"],
"description": "Set to an empirical Bayes estimate.",
},
{
"type": "number",
"minimum": 0.0,
"maximum": 1.0,
"description": "A larger smooth value will put more weight on the global target mean",
},
],
"description": "The amount of mixing of the target mean conditioned on the value of the category with the global target mean.",
"default": "auto",
},
"cv": {
"type": "integer",
"minimum": 1,
"description": "Determines the number of folds in the cross fitting strategy used in fit_transform. For classification targets, StratifiedKFold is used and for continuous targets, KFold is used.",
"default": 5,
},
"shuffle": {
"type": "boolean",
"description": "Whether to shuffle the data in fit_transform before splitting into folds. Note that the samples within each split will not be shuffled.",
"default": True,
},
"random_state": {
"description": "When shuffle is True, random_state affects the ordering of the indices, which controls the randomness of each fold. Otherwise, this parameter has no effect. Pass an int for reproducible output across multiple function calls.",
"anyOf": [
{
"enum": [None],
},
{
"description": "Use the provided random state, only affecting other users of that same random state instance.",
"laleType": "numpy.random.RandomState",
},
{"description": "Explicit seed.", "type": "integer"},
],
"default": None,
},
},
}
],
}

_input_fit_schema = {
"type": "object",
"required": ["X"],
"additionalProperties": False,
"properties": {
"X": {
"description": "Features; the outer array is over samples.",
"type": "array",
"items": {
"anyOf": [
{"type": "array", "items": {"type": "number"}},
{"type": "array", "items": {"type": "string"}},
]
},
},
"y": {
"description": "The target data used to encode the categories.",
"type": "array",
},
},
}

_input_transform_schema = {
"type": "object",
"required": ["X"],
"additionalProperties": False,
"properties": {
"X": {
"description": "Features; the outer array is over samples.",
"type": "array",
"items": {
"anyOf": [
{"type": "array", "items": {"type": "number"}},
{"type": "array", "items": {"type": "string"}},
]
},
}
},
}

_output_transform_schema = {
"description": "Transformed input; the outer array is over samples.",
"type": "array",
"items": {
"anyOf": [
{"type": "array", "items": {"type": "number"}},
{"type": "array", "items": {"type": "string"}},
]
},
}

_combined_schemas = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": """`Target encoder`_ for regression and classification targets..
.. _`Target encoder`: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html
""",
"documentation_url": "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.target_encoder.html",
"import_from": "sklearn.preprocessing",
"type": "object",
"tags": {"pre": ["categoricals"], "op": ["transformer"], "post": []},
"properties": {
"hyperparams": _hyperparams_schema,
"input_fit": _input_fit_schema,
"input_transform": _input_transform_schema,
"output_transform": _output_transform_schema,
},
}

if lale.operators.sklearn_version >= version.Version("1.3"):
TargetEncoder = lale.operators.make_operator(
sklearn.preprocessing.TargetEncoder, _combined_schemas
)
else:
TargetEncoder = lale.operators.make_operator(
_TargetEncoderNotFoundImpl, _combined_schemas
)


if lale.operators.sklearn_version >= version.Version("1.4"):
TargetEncoder = TargetEncoder.customize_schema(
target_type={
"anyOf": [
{
"enum": ["auto"],
"description": "Type of target is inferred with type_of_target.",
},
{"enum": ["continuous"], "description": "Continuous target"},
{"enum": ["binary"], "description": "Binary target"},
{"enum": ["multiclass"], "description": "Multiclass target"},
],
"description": "Type of target.",
"default": "auto",
},
set_as_available=True,
)

lale.docstrings.set_docstrings(TargetEncoder)
Loading

0 comments on commit fe00ebd

Please sign in to comment.