Skip to content

Commit

Permalink
Add support for scikit-learn 1.4
Browse files Browse the repository at this point in the history
Signed-off-by: Avi Shinnar <shinnar@us.ibm.com>
  • Loading branch information
shinnar committed Feb 12, 2024
1 parent c43aabd commit 2b4c300
Show file tree
Hide file tree
Showing 35 changed files with 625 additions and 97 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ graphviz
hyperopt
jsonschema
jsonsubschema
scikit-learn>=1.0.0,<1.4
scikit-learn>=1.0.0,<1.5.0
scipy
pandas
decorator
Expand Down
6 changes: 5 additions & 1 deletion lale/datasets/openml/openml_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,8 +733,12 @@ def fetch(
]
txm1 = ColumnTransformer(transformers1, sparse_threshold=0.0)

if sklearn_version >= version.Version("1.2"):
ohe2 = OneHotEncoder(sparse_output=False)
else:
ohe2 = OneHotEncoder(sparse=False)
transformers2 = [
("ohe", OneHotEncoder(sparse=False), list(range(len(categorical_cols)))),
("ohe", ohe2, list(range(len(categorical_cols)))),
(
"no_op",
"passthrough",
Expand Down
24 changes: 24 additions & 0 deletions lale/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,30 @@ def get_sklearn_estimator_name() -> str:
return "estimator"


def with_fixed_estimator_name(**kwargs):
"""Some higher order sklearn operators changed the name of the nested estimator in later versions.
This fixes up the arguments, renaming estimator and base_estimator appropriately.
"""

if "base_estimator" in kwargs or "estimator" in kwargs:
from packaging import version

import lale.operators

if lale.operators.sklearn_version < version.Version("1.2"):
return {
"base_estimator" if k == "estimator" else k: v
for k, v in kwargs.items()
}
else:
return {
"estimator" if k == "base_estimator" else k: v
for k, v in kwargs.items()
}

return kwargs


def get_estimator_param_name_from_hyperparams(hyperparams):
be = hyperparams.get("base_estimator", "deprecated")
if be == "deprecated" or (be is None and "estimator" in hyperparams):
Expand Down
11 changes: 7 additions & 4 deletions lale/lib/aif360/bagging_orbis_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import lale.operators
from lale.lib.imblearn._common_schemas import _hparam_n_jobs, _hparam_random_state

from ...helpers import with_fixed_estimator_name
from .orbis import Orbis
from .orbis import _hyperparams_schema as orbis_hyperparams_schema
from .util import (
Expand Down Expand Up @@ -115,10 +116,12 @@ def _repair_dtypes(inner_X): # for some reason BaggingClassifier spoils dtypes

repair_dtypes = lale.lib.sklearn.FunctionTransformer(func=_repair_dtypes)
trainable_ensemble = lale.lib.sklearn.BaggingClassifier(
base_estimator=repair_dtypes >> orbis,
n_estimators=self.n_estimators,
n_jobs=self.sampler_hparams["n_jobs"],
random_state=self.sampler_hparams["random_state"],
**with_fixed_estimator_name(
estimator=repair_dtypes >> orbis,
n_estimators=self.n_estimators,
n_jobs=self.sampler_hparams["n_jobs"],
random_state=self.sampler_hparams["random_state"],
)
)
encoded_y = pd.Series(self.lab_enc.transform(y), index=y.index)
self.trained_ensemble = trainable_ensemble.fit(X, encoded_y)
Expand Down
22 changes: 21 additions & 1 deletion lale/lib/autogen/kernel_pca.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from numpy import inf, nan
from packaging import version
from sklearn.decomposition import KernelPCA as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _KernelPCAImpl:
Expand Down Expand Up @@ -239,4 +240,23 @@ def transform(self, X):
}
KernelPCA = make_operator(_KernelPCAImpl, _combined_schemas)

if sklearn_version >= version.Version("1.4"):

KernelPCA = KernelPCA.customize_schema(
degree={
"anyOf": [
{
"type": "integer",
"minimumForOptimizer": 2,
"maximumForOptimizer": 3,
"distribution": "uniform",
},
{"type": "number", "forOptimizer": False},
],
"default": 3,
"description": "Degree for poly kernels",
},
set_as_available=True,
)

set_docstrings(KernelPCA)
19 changes: 19 additions & 0 deletions lale/lib/autogen/kernel_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,23 @@ def predict(self, X):
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):

KernelRidge = KernelRidge.customize_schema(
degree={
"anyOf": [
{
"type": "integer",
"minimumForOptimizer": 0,
"maximumForOptimizer": 100,
"distribution": "uniform",
},
{"type": "number", "forOptimizer": False},
],
"default": 3,
"description": "Degree of the polynomial kernel",
},
set_as_available=True,
)

set_docstrings(KernelRidge)
24 changes: 22 additions & 2 deletions lale/lib/autogen/lars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import Lars as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _LarsImpl:
Expand Down Expand Up @@ -197,4 +197,24 @@ def predict(self, X):
}
Lars = make_operator(_LarsImpl, _combined_schemas)

if sklearn_version >= version.Version("1.2"):
Lars = Lars.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
Lars = Lars.customize_schema(normalize=None, set_as_available=True)


set_docstrings(Lars)
23 changes: 21 additions & 2 deletions lale/lib/autogen/lars_cv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import LarsCV as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _LarsCVImpl:
Expand Down Expand Up @@ -203,4 +203,23 @@ def predict(self, X):
}
LarsCV = make_operator(_LarsCVImpl, _combined_schemas)

if sklearn_version >= version.Version("1.2"):
LarsCV = LarsCV.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
LarsCV = LarsCV.customize_schema(normalize=None, set_as_available=True)

set_docstrings(LarsCV)
23 changes: 21 additions & 2 deletions lale/lib/autogen/lasso_lars.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import LassoLars as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _LassoLarsImpl:
Expand Down Expand Up @@ -197,4 +197,23 @@ def predict(self, X):
}
LassoLars = make_operator(_LassoLarsImpl, _combined_schemas)

if sklearn_version >= version.Version("1.2"):
LassoLars = LassoLars.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
LassoLars = LassoLars.customize_schema(normalize=None, set_as_available=True)

set_docstrings(LassoLars)
24 changes: 22 additions & 2 deletions lale/lib/autogen/lasso_lars_cv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import LassoLarsCV as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _LassoLarsCVImpl:
Expand Down Expand Up @@ -203,4 +203,24 @@ def predict(self, X):
}
LassoLarsCV = make_operator(_LassoLarsCVImpl, _combined_schemas)


if sklearn_version >= version.Version("1.2"):
LassoLarsCV = LassoLarsCV.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
LassoLarsCV = LassoLarsCV.customize_schema(normalize=None, set_as_available=True)

set_docstrings(LassoLarsCV)
23 changes: 21 additions & 2 deletions lale/lib/autogen/lasso_lars_ic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import LassoLarsIC as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _LassoLarsICImpl:
Expand Down Expand Up @@ -201,4 +201,23 @@ def predict(self, X):
}
LassoLarsIC = make_operator(_LassoLarsICImpl, _combined_schemas)

if sklearn_version >= version.Version("1.2"):
LassoLarsIC = LassoLarsIC.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
LassoLarsIC = LassoLarsIC.customize_schema(normalize=None, set_as_available=True)

set_docstrings(LassoLarsIC)
25 changes: 23 additions & 2 deletions lale/lib/autogen/orthogonal_matching_pursuit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import OrthogonalMatchingPursuit as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _OrthogonalMatchingPursuitImpl:
Expand Down Expand Up @@ -156,4 +156,25 @@ def predict(self, X):
_OrthogonalMatchingPursuitImpl, _combined_schemas
)

if sklearn_version >= version.Version("1.2"):
OrthogonalMatchingPursuit = OrthogonalMatchingPursuit.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
OrthogonalMatchingPursuit = OrthogonalMatchingPursuit.customize_schema(
normalize=None, set_as_available=True
)

set_docstrings(OrthogonalMatchingPursuit)
26 changes: 24 additions & 2 deletions lale/lib/autogen/orthogonal_matching_pursuit_cv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from numpy import inf, nan
from packaging import version
from sklearn.linear_model import OrthogonalMatchingPursuitCV as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.operators import make_operator, sklearn_version


class _OrthogonalMatchingPursuitCVImpl:
Expand Down Expand Up @@ -175,4 +175,26 @@ def predict(self, X):
_OrthogonalMatchingPursuitCVImpl, _combined_schemas
)

if sklearn_version >= version.Version("1.2"):
OrthogonalMatchingPursuitCV = OrthogonalMatchingPursuitCV.customize_schema(
normalize={
"anyOf": [
{
"type": "boolean",
"description": "This parameter is ignored when ``fit_intercept`` is set to False",
},
{"enum": ["deprecated"]},
],
"default": "deprecated",
"description": "Deprecated",
},
set_as_available=True,
)

if sklearn_version >= version.Version("1.4"):
OrthogonalMatchingPursuitCV = OrthogonalMatchingPursuitCV.customize_schema(
normalize=None, set_as_available=True
)


set_docstrings(OrthogonalMatchingPursuitCV)
Loading

0 comments on commit 2b4c300

Please sign in to comment.