diff --git a/pyproject.toml b/pyproject.toml index 1284e7c..a451e21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ classifiers = [ requires-python = ">=3.8,<3.14" dependencies = [ "scikit-learn>=1.2.2", + "typing-extensions>=4.1.0; python_full_version < '3.11'" ] [dependency-groups] diff --git a/requirements.txt b/requirements.txt index 1fb8114..b52fd23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ scikit-learn>=1.2.2 +typing-extensions>=4.1.0; python_version < "3.11" \ No newline at end of file diff --git a/src/linearboost/linear_boost.py b/src/linearboost/linear_boost.py index 1f233cf..b99dd74 100644 --- a/src/linearboost/linear_boost.py +++ b/src/linearboost/linear_boost.py @@ -1,7 +1,14 @@ from __future__ import annotations +import sys +import warnings from numbers import Integral, Real +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + import numpy as np from sklearn.base import clone from sklearn.ensemble import AdaBoostClassifier @@ -65,13 +72,18 @@ class LinearBoostClassifier(AdaBoostClassifier): algorithm : {'SAMME', 'SAMME.R'}, default='SAMME' If 'SAMME' then use the SAMME discrete boosting algorithm. - If 'SAMME.R' then use the SAMME.R real boosting algorithm. + If 'SAMME.R' then use the SAMME.R real boosting algorithm + (only available in scikit-learn < 1.6). The SAMME.R algorithm typically converges faster than SAMME, achieving a lower test error with fewer boosting iterations. - .. deprecated:: sklearn 1.6 - `algorithm` is deprecated and will be removed in sklearn 1.8. This - estimator only implements the 'SAMME' algorithm. + .. deprecated:: scikit-learn 1.4 + `"SAMME.R"` is deprecated and will be removed in scikit-learn 1.6. + '"SAMME"' will become the default. + + .. deprecated:: scikit-learn 1.6 + `algorithm` is deprecated and will be removed in scikit-learn 1.8. + This estimator only implements the 'SAMME' algorithm in scikit-learn >= 1.6. scaler : str, default='minmax' Specifies the scaler to apply to the data. Options include: @@ -111,21 +123,21 @@ class LinearBoostClassifier(AdaBoostClassifier): where: - y_true: Ground truth (correct) target values. - y_pred: Estimated target values. - - sample_weight: Sample weights. + - sample_weight: Sample weights (optional). Attributes ---------- estimator_ : estimator The base estimator (SEFR) from which the ensemble is grown. - .. versionadded:: sklearn 1.2 + .. versionadded:: scikit-learn 1.2 `base_estimator_` was renamed to `estimator_`. base_estimator_ : estimator The base estimator from which the ensemble is grown. - .. deprecated:: sklearn 1.2 - `base_estimator_` is deprecated and will be removed in sklearn 1.4. + .. deprecated:: scikit-learn 1.2 + `base_estimator_` is deprecated and will be removed in scikit-learn 1.4. Use `estimator_` instead. estimators_ : list of classifiers @@ -176,10 +188,9 @@ class LinearBoostClassifier(AdaBoostClassifier): _parameter_constraints: dict = { "n_estimators": [Interval(Integral, 1, None, closed="left")], "learning_rate": [Interval(Real, 0, None, closed="neither")], - "algorithm": [ - StrOptions({"SAMME", "SAMME.R"}), - Hidden(StrOptions({"deprecated"})), - ], + "algorithm": [StrOptions({"SAMME"}), Hidden(StrOptions({"deprecated"}))] + if SKLEARN_V1_6_OR_LATER + else [StrOptions({"SAMME", "SAMME.R"})], "scaler": [StrOptions({s for s in _scalers})], "class_weight": [ StrOptions({"balanced_subsample", "balanced"}), @@ -257,7 +268,7 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: return X, y - def fit(self, X, y, sample_weight=None) -> "LinearBoostClassifier": + def fit(self, X, y, sample_weight=None) -> Self: X, y = self._check_X_y(X, y) self.classes_ = np.unique(y) self.n_classes_ = self.classes_.shape[0] @@ -291,7 +302,14 @@ def fit(self, X, y, sample_weight=None) -> "LinearBoostClassifier": else: sample_weight = expanded_class_weight - return super().fit(X_transformed, y, sample_weight) + with warnings.catch_warnings(): + if SKLEARN_V1_6_OR_LATER: + warnings.filterwarnings( + "ignore", + category=FutureWarning, + message=".*parameter 'algorithm' is deprecated.*", + ) + return super().fit(X_transformed, y, sample_weight) def _boost(self, iboost, X, y, sample_weight, random_state): estimator = self._make_estimator(random_state=random_state) diff --git a/src/linearboost/sefr.py b/src/linearboost/sefr.py index 29915ff..db6477f 100644 --- a/src/linearboost/sefr.py +++ b/src/linearboost/sefr.py @@ -1,5 +1,12 @@ from __future__ import annotations +import sys + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + import numpy as np from sklearn.base import BaseEstimator from sklearn.linear_model._base import LinearClassifierMixin @@ -139,7 +146,7 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: return X, y @_fit_context(prefer_skip_nested_validation=True) - def fit(self, X, y, sample_weight=None) -> "SEFR": + def fit(self, X, y, sample_weight=None) -> Self: """ Fit the model according to the given training data. diff --git a/uv.lock b/uv.lock index d7c393a..095ddbb 100644 --- a/uv.lock +++ b/uv.lock @@ -74,6 +74,7 @@ source = { editable = "." } dependencies = [ { name = "scikit-learn", version = "1.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "scikit-learn", version = "1.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] [package.dev-dependencies] @@ -85,7 +86,10 @@ dev = [ ] [package.metadata] -requires-dist = [{ name = "scikit-learn", specifier = ">=1.2.2" }] +requires-dist = [ + { name = "scikit-learn", specifier = ">=1.2.2" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.1.0" }, +] [package.metadata.requires-dev] dev = [ @@ -575,3 +579,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383 }, { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, ] + +[[package]] +name = "typing-extensions" +version = "4.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/3e/b00a62db91a83fff600de219b6ea9908e6918664899a2d85db222f4fbf19/typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b", size = 106520 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/86/39b65d676ec5732de17b7e3c476e45bb80ec64eb50737a8dce1a4178aba1/typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5", size = 45683 }, +]