Skip to content

Commit

Permalink
ENH boosting early_termination control (fix #7)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiningLiu1998 committed Nov 25, 2021
1 parent 1cc4201 commit 5b17c6c
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 14 deletions.
23 changes: 17 additions & 6 deletions imbalanced_ensemble/ensemble/_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ def __init__(self,
sampling_type:str,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

self._sampling_type = sampling_type
self.base_sampler = base_sampler
self.early_termination = early_termination

super(ResampleBoostClassifier, self).__init__(
base_estimator=base_estimator,
Expand Down Expand Up @@ -319,6 +321,9 @@ def _fit(self, X, y,

self.sampler_kwargs_ = check_type(
sampler_kwargs, 'sampler_kwargs', dict)

early_termination_ = check_type(
self.early_termination, 'early_termination', bool)

# Check that algorithm is supported.
if self.algorithm not in ('SAMME', 'SAMME.R'):
Expand Down Expand Up @@ -435,14 +440,14 @@ def _fit(self, X, y,
self._training_log_to_console(iboost, y_resampled)

# Early termination.
if sample_weight is None:
if sample_weight is None and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (sample_weight is None).")
break

# Stop if error is zero.
if estimator_error == 0:
if estimator_error == 0 and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (training error is 0).")
Expand All @@ -451,7 +456,7 @@ def _fit(self, X, y,
sample_weight_sum = np.sum(sample_weight)

# Stop if the sum of sample weights has become non-positive.
if sample_weight_sum <= 0:
if sample_weight_sum <= 0 and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (sample_weight_sum <= 0).")
Expand Down Expand Up @@ -521,8 +526,11 @@ def __init__(self,
n_estimators:int,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

self.early_termination = early_termination

super(ReweightBoostClassifier, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
Expand Down Expand Up @@ -737,6 +745,9 @@ def _fit(self, X, y,
eval_metrics:dict,
train_verbose:bool or int or dict,
):

early_termination_ = check_type(
self.early_termination, 'early_termination', bool)

# Check that algorithm is supported.
if self.algorithm not in ('SAMME', 'SAMME.R'):
Expand Down Expand Up @@ -834,14 +845,14 @@ def _fit(self, X, y,
self._training_log_to_console(iboost, y)

# Early termination.
if sample_weight is None:
if sample_weight is None and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (sample_weight is None).")
break

# Stop if error is zero.
if estimator_error == 0:
if estimator_error == 0 and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (training error is 0).")
Expand All @@ -850,7 +861,7 @@ def _fit(self, X, y,
sample_weight_sum = np.sum(sample_weight)

# Stop if the sum of sample weights has become non-positive.
if sample_weight_sum <= 0:
if sample_weight_sum <= 0 and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (sample_weight_sum <= 0).")
Expand Down
19 changes: 14 additions & 5 deletions imbalanced_ensemble/ensemble/compatible/adaboost_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from ..base import ImbalancedEnsembleClassifierMixin, MAX_INT
from ...utils._validation_data import check_eval_datasets
from ...utils._validation_param import (check_train_verbose,
check_eval_metrics)
check_eval_metrics,
check_type)
from ...utils._validation import _deprecate_positional_args
from ...utils._docstring import (Substitution, FuncSubstitution,
FuncGlossarySubstitution,
Expand Down Expand Up @@ -49,6 +50,7 @@


@Substitution(
early_termination=_get_parameter_docstring('early_termination', **_properties),
example=_get_example_docstring(_method_name)
)
class CompatibleAdaBoostClassifier(ImbalancedEnsembleClassifierMixin,
Expand All @@ -70,7 +72,6 @@ class CompatibleAdaBoostClassifier(ImbalancedEnsembleClassifierMixin,
Support for sample weighting is required, as well as proper
``classes_`` and ``n_classes_`` attributes. If ``None``, then
the base estimator is :class:`~sklearn.tree.DecisionTreeClassifier`
initialized with `max_depth=1`.
n_estimators : int, default=50
The maximum number of estimators at which boosting is terminated.
Expand All @@ -87,6 +88,8 @@ class CompatibleAdaBoostClassifier(ImbalancedEnsembleClassifierMixin,
If 'SAMME' then use the SAMME discrete boosting algorithm.
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
random_state : int, RandomState instance or None, default=None
Controls the random seed given at each `base_estimator` at each
Expand Down Expand Up @@ -148,8 +151,11 @@ def __init__(self,
n_estimators:int=50,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

self.early_termination = early_termination

super(CompatibleAdaBoostClassifier, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
Expand Down Expand Up @@ -200,6 +206,9 @@ def fit(self, X, y,
self : object
"""

early_termination_ = check_type(
self.early_termination, 'early_termination', bool)

# Check that algorithm is supported.
if self.algorithm not in ('SAMME', 'SAMME.R'):
raise ValueError("algorithm %s is not supported" % self.algorithm)
Expand Down Expand Up @@ -281,14 +290,14 @@ def fit(self, X, y,
self._training_log_to_console(iboost, y)

# Early termination.
if sample_weight is None:
if sample_weight is None and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (sample_weight is None).")
break

# Stop if error is zero.
if estimator_error == 0:
if estimator_error == 0 and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (training error is 0).")
Expand All @@ -297,7 +306,7 @@ def fit(self, X, y,
sample_weight_sum = np.sum(sample_weight)

# Stop if the sum of sample weights has become non-positive.
if sample_weight_sum <= 0:
if sample_weight_sum <= 0 and early_termination_:
print (f"Training early-stop at iteration"
f" {iboost+1}/{self.n_estimators}"
f" (sample_weight_sum <= 0).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

@Substitution(
n_jobs_sampler=_get_parameter_docstring('n_jobs_sampler', **_properties),
early_termination=_get_parameter_docstring('early_termination', **_properties),
random_state=_get_parameter_docstring('random_state', **_properties),
example=_get_example_docstring(_method_name)
)
Expand Down Expand Up @@ -108,6 +109,8 @@ class KmeansSMOTEBoostClassifier(ResampleBoostClassifier):
If 'SAMME' then use the SAMME discrete boosting algorithm.
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
{random_state}
Expand Down Expand Up @@ -177,6 +180,7 @@ def __init__(self,
density_exponent="auto",
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

base_sampler = _sampler_class()
Expand All @@ -189,6 +193,7 @@ def __init__(self,
sampling_type=sampling_type,
learning_rate=learning_rate,
algorithm=algorithm,
early_termination=early_termination,
random_state=random_state)

self.__name__ = _method_name
Expand Down
6 changes: 5 additions & 1 deletion imbalanced_ensemble/ensemble/over_sampling/over_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@


@Substitution(
n_jobs_sampler=_get_parameter_docstring('n_jobs_sampler', **_properties),
early_termination=_get_parameter_docstring('early_termination', **_properties),
random_state=_get_parameter_docstring('random_state', **_properties),
example=_get_example_docstring(_method_name)
)
Expand Down Expand Up @@ -81,6 +81,8 @@ class OverBoostClassifier(ResampleBoostClassifier):
If 'SAMME' then use the SAMME discrete boosting algorithm.
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
{random_state}
Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(self,
*,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

base_sampler = _sampler_class()
Expand All @@ -157,6 +160,7 @@ def __init__(self,
sampling_type=sampling_type,
learning_rate=learning_rate,
algorithm=algorithm,
early_termination=early_termination,
random_state=random_state)

self.__name__ = _method_name
Expand Down
6 changes: 5 additions & 1 deletion imbalanced_ensemble/ensemble/over_sampling/smote_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@


@Substitution(
n_jobs_sampler=_get_parameter_docstring('n_jobs_sampler', **_properties),
early_termination=_get_parameter_docstring('early_termination', **_properties),
random_state=_get_parameter_docstring('random_state', **_properties),
example=_get_example_docstring(_method_name)
)
Expand Down Expand Up @@ -91,6 +91,8 @@ class balancing by SMOTE over-sampling the sample at each
If 'SAMME' then use the SAMME discrete boosting algorithm.
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
{random_state}
Expand Down Expand Up @@ -156,6 +158,7 @@ def __init__(self,
k_neighbors:int=5,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

base_sampler = _sampler_class()
Expand All @@ -168,6 +171,7 @@ def __init__(self,
sampling_type=sampling_type,
learning_rate=learning_rate,
algorithm=algorithm,
early_termination=early_termination,
random_state=random_state)

self.__name__ = _method_name
Expand Down
5 changes: 5 additions & 0 deletions imbalanced_ensemble/ensemble/reweighting/adacost.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@


@Substitution(
early_termination=_get_parameter_docstring('early_termination', **_properties),
random_state=_get_parameter_docstring('random_state', **_properties),
example=_get_example_docstring(_method_name)
)
Expand Down Expand Up @@ -82,6 +83,8 @@ class AdaCostClassifier(ReweightBoostClassifier):
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
{random_state}
Attributes
Expand Down Expand Up @@ -143,13 +146,15 @@ def __init__(self,
*,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

super(AdaCostClassifier, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
learning_rate=learning_rate,
algorithm=algorithm,
early_termination=early_termination,
random_state=random_state)

self.__name__ = _method_name
Expand Down
5 changes: 5 additions & 0 deletions imbalanced_ensemble/ensemble/reweighting/adauboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@


@Substitution(
early_termination=_get_parameter_docstring('early_termination', **_properties),
random_state=_get_parameter_docstring('random_state', **_properties),
example=_get_example_docstring(_method_name)
)
Expand Down Expand Up @@ -83,6 +84,8 @@ class AdaUBoostClassifier(ReweightBoostClassifier):
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
{random_state}
Attributes
Expand Down Expand Up @@ -147,6 +150,7 @@ def __init__(self,
*,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

self.__name__ = 'AdaUBoostClassifier'
Expand All @@ -156,6 +160,7 @@ def __init__(self,
n_estimators=n_estimators,
learning_rate=learning_rate,
algorithm=algorithm,
early_termination=early_termination,
random_state=random_state)


Expand Down
5 changes: 5 additions & 0 deletions imbalanced_ensemble/ensemble/reweighting/asymmetric_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@


@Substitution(
early_termination=_get_parameter_docstring('early_termination', **_properties),
random_state=_get_parameter_docstring('random_state', **_properties),
example=_get_example_docstring(_method_name)
)
Expand Down Expand Up @@ -78,6 +79,8 @@ class AsymBoostClassifier(ReweightBoostClassifier):
The SAMME.R algorithm typically converges faster than SAMME,
achieving a lower test error with fewer boosting iterations.
{early_termination}
{random_state}
Attributes
Expand Down Expand Up @@ -137,13 +140,15 @@ def __init__(self,
*,
learning_rate:float=1.,
algorithm:str='SAMME.R',
early_termination:bool=False,
random_state=None):

super(AsymBoostClassifier, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
learning_rate=learning_rate,
algorithm=algorithm,
early_termination=early_termination,
random_state=random_state)

self.__name__ = _method_name
Expand Down

0 comments on commit 5b17c6c

Please sign in to comment.