diff --git a/sklearn/_fast_gradient_boosting/gradient_boosting.py b/sklearn/_fast_gradient_boosting/gradient_boosting.py
index 4fd5148555ce0..c29a2673831ca 100644
--- a/sklearn/_fast_gradient_boosting/gradient_boosting.py
+++ b/sklearn/_fast_gradient_boosting/gradient_boosting.py
@@ -24,20 +24,20 @@ class BaseFastGradientBoosting(BaseEstimator, ABC):
     """Base class for fast gradient boosting estimators."""
 
     @abstractmethod
-    def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
+    def __init__(self, loss, learning_rate, n_estimators, max_leaf_nodes,
                  max_depth, min_samples_leaf, l2_regularization, max_bins,
-                 scoring, validation_split, n_iter_no_change, tol, verbose,
+                 scoring, validation_fraction, n_iter_no_change, tol, verbose,
                  random_state):
         self.loss = loss
         self.learning_rate = learning_rate
-        self.max_iter = max_iter
+        self.n_estimators = n_estimators
         self.max_leaf_nodes = max_leaf_nodes
         self.max_depth = max_depth
         self.min_samples_leaf = min_samples_leaf
         self.l2_regularization = l2_regularization
         self.max_bins = max_bins
         self.n_iter_no_change = n_iter_no_change
-        self.validation_split = validation_split
+        self.validation_fraction = validation_fraction
         self.scoring = scoring
         self.tol = tol
         self.verbose = verbose
@@ -58,14 +58,14 @@ def _validate_parameters(self):
         if self.learning_rate <= 0:
             raise ValueError(f'learning_rate={self.learning_rate} must '
                              f'be strictly positive')
-        if self.max_iter < 1:
-            raise ValueError(f'max_iter={self.max_iter} must '
+        if self.n_estimators < 1:
+            raise ValueError(f'n_estimators={self.n_estimators} must '
                              f'not be smaller than 1.')
         if self.n_iter_no_change is not None and self.n_iter_no_change < 0:
             raise ValueError(f'n_iter_no_change={self.n_iter_no_change} '
                              f'must be positive.')
-        if self.validation_split is not None and self.validation_split <= 0:
-            raise ValueError(f'validation_split={self.validation_split} '
+        if self.validation_fraction is not None and self.validation_fraction <= 0:
+            raise ValueError(f'validation_fraction={self.validation_fraction} '
                              f'must be strictly positive, or None.')
         if self.tol is not None and self.tol < 0:
             raise ValueError(f'tol={self.tol} '
@@ -116,19 +116,19 @@ def fit(self, X, y):
         self.do_early_stopping_ = (self.n_iter_no_change is not None and
                                    self.n_iter_no_change > 0)
 
-        if self.do_early_stopping_ and self.validation_split is not None:
+        if self.do_early_stopping_ and self.validation_fraction is not None:
             # stratify for classification
             stratify = y if hasattr(self.loss_, 'predict_proba') else None
 
             X_binned_train, X_binned_val, y_train, y_val = train_test_split(
-                X_binned, y, test_size=self.validation_split,
+                X_binned, y, test_size=self.validation_fraction,
                 stratify=stratify, random_state=rng)
             if X_binned_train.size == 0 or X_binned_val.size == 0:
                 raise ValueError(
                     f'Not enough data (n_samples={X_binned.shape[0]}) to '
-                    f'perform early stopping with validation_split='
-                    f'{self.validation_split}. Use more training data or '
-                    f'adjust validation_split.'
+                    f'perform early stopping with validation_fraction='
+                    f'{self.validation_fraction}. Use more training data or '
+                    f'adjust validation_fraction.'
                 )
             # Predicting is faster of C-contiguous arrays, training is faster
             # on Fortran arrays.
@@ -138,15 +138,15 @@ def fit(self, X, y):
             X_binned_train, y_train = X_binned, y
             X_binned_val, y_val = None, None
 
-        # Subsample the training set for score-based monitoring.
+        # Subsample the training set for early stopping
         if self.do_early_stopping_:
-            subsample_size = 10000
+            subsample_size = 10000  # should we expose this?
             indices = np.arange(X_binned_train.shape[0])
             if X_binned_train.shape[0] > subsample_size:
                 indices = rng.choice(indices, subsample_size)
             X_binned_small_train = X_binned_train[indices]
             y_small_train = y_train[indices]
-            # Predicting is faster of C-contiguous arrays.
+            # Predicting is faster on C-contiguous arrays.
             X_binned_small_train = np.ascontiguousarray(X_binned_small_train)
 
         if self.verbose:
@@ -170,8 +170,8 @@ def fit(self, X, y):
             prediction_dim=self.n_trees_per_iteration_
         )
 
-        # predictors_ is a matrix of TreePredictor objects with shape
-        # (n_iter_, n_trees_per_iteration)
+        # predictors_ is a matrix (list of lists) of TreePredictor objects
+        # with shape (n_iter_, n_trees_per_iteration)
         self.predictors_ = predictors = []
 
         # scorer_ is a callable with signature (est, X, y) and calls
@@ -184,15 +184,15 @@ def fit(self, X, y):
             self.train_scores_.append(
                 self._get_scores(X_binned_train, y_train))
 
-            if self.validation_split is not None:
+            if self.validation_fraction is not None:
                 self.validation_scores_.append(
                     self._get_scores(X_binned_val, y_val))
 
-        for iteration in range(self.max_iter):
+        for iteration in range(self.n_estimators):
 
             if self.verbose:
                 iteration_start_time = time()
-                print(f"[{iteration + 1}/{self.max_iter}] ", end='',
+                print(f"[{iteration + 1}/{self.n_estimators}] ", end='',
                       flush=True)
 
             # Update gradients and hessians, inplace
@@ -277,7 +277,7 @@ def _check_early_stopping(self, X_binned_train, y_train,
         self.train_scores_.append(
             self._get_scores(X_binned_train, y_train))
 
-        if self.validation_split is not None:
+        if self.validation_fraction is not None:
             self.validation_scores_.append(
                 self._get_scores(X_binned_val, y_val))
             return self._should_stop(self.validation_scores_)
@@ -342,7 +342,7 @@ def _print_iteration_stats(self, iteration_start_time):
         if self.do_early_stopping_:
             log_msg += f"{self.scoring} train: {self.train_scores_[-1]:.5f}, "
 
-            if self.validation_split is not None:
+            if self.validation_fraction is not None:
                 log_msg += (f"{self.scoring} val: "
                             f"{self.validation_scores_[-1]:.5f}, ")
 
@@ -357,8 +357,7 @@ def _raw_predict(self, X):
         Parameters
         ----------
         X : array-like, shape=(n_samples, n_features)
-            The input samples. If ``X.dtype == np.uint8``, the data is assumed
-            to be pre-binned.
+            The input samples.
 
         Returns
         -------
@@ -409,7 +408,7 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
         The learning rate, also known as *shrinkage*. This is used as a
         multiplicative factor for the leaves values. Use ``1`` for no
         shrinkage.
-    max_iter : int, optional(default=100)
+    n_estimators : int, optional(default=100)
         The maximum number of iterations of the boosting process, i.e. the
         maximum number of trees.
     max_leaf_nodes : int or None, optional(default=None)
@@ -428,25 +427,26 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
         allows for a much faster training stage. Features with a small number
         of unique values may use less than ``max_bins`` bins. Must be no
         larger than 256.
-    scoring : str or callable or None, \
-        optional (default=None)
-        Scoring parameter to use for early stopping (see sklearn.metrics for
-        available options). If None, early stopping is check w.r.t the loss
-        value.
-    validation_split : int or float or None, optional(default=0.1)
+    scoring : str or callable or None, optional (default=None)
+        Scoring parameter to use for early stopping. It can be a single
+        string (see :ref:`scoring_parameter`) or a callable (see
+        :ref:`scoring`). If None, the estimator's default scorer (if
+        available) is used.
+        If ``scoring='loss'``, early stopping is checked w.r.t the loss
+        value. Only used if ``n_iter_no_change`` is not None.
+    validation_fraction : int or float or None, optional(default=0.1)
         Proportion (or absolute size) of training data to set aside as
         validation data for early stopping. If None, early stopping is done on
-        the training data.
+        the training data. Only used if ``n_iter_no_change`` is not None.
     n_iter_no_change : int or None, optional (default=5)
         Used to determine when to "early stop". The fitting process is
         stopped when none of the last ``n_iter_no_change`` scores are better
         than the ``n_iter_no_change - 1``th-to-last one, up to some
         tolerance. If None or 0, no early-stopping is done.
     tol : float or None optional (default=1e-7)
-        The absolute tolerance to use when comparing scores. The higher the
-        tolerance, the more likely we are to early stop: higher tolerance
-        means that it will be harder for subsequent iterations to be
-        considered an improvement upon the reference score.
+        The absolute tolerance to use when comparing scores during early
+        stopping. The higher the tolerance, the more likely we are to early
+        stop: higher tolerance means that it will be harder for subsequent
+        iterations to be considered an improvement upon the reference score.
     verbose: int, optional (default=0)
         The verbosity level. If not zero, print some information about the
         fitting process.
@@ -454,9 +454,7 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
         optional (default=None)
         Pseudo-random number generator to control the subsampling in the
         binning process, and the train/validation data split if early stopping
-        is enabled. See
-        `scikit-learn glossary
-        `_.
+        is enabled. See :term:`random_state`.
 
     Examples
@@ -472,16 +470,16 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
 
     _VALID_LOSSES = ('least_squares',)
 
     def __init__(self, loss='least_squares', learning_rate=0.1,
-                 max_iter=100, max_leaf_nodes=31, max_depth=None,
+                 n_estimators=100, max_leaf_nodes=31, max_depth=None,
                  min_samples_leaf=20, l2_regularization=0., max_bins=256,
-                 scoring=None, validation_split=0.1, n_iter_no_change=5,
+                 scoring=None, validation_fraction=0.1, n_iter_no_change=5,
                  tol=1e-7, verbose=0, random_state=None):
         super(FastGradientBoostingRegressor, self).__init__(
-            loss=loss, learning_rate=learning_rate, max_iter=max_iter,
+            loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
             max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
             min_samples_leaf=min_samples_leaf,
             l2_regularization=l2_regularization, max_bins=max_bins,
-            scoring=scoring, validation_split=validation_split,
+            scoring=scoring, validation_fraction=validation_fraction,
             n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
             random_state=random_state)
@@ -491,8 +489,7 @@ def predict(self, X):
 
         Parameters
         ----------
         X : array-like, shape=(n_samples, n_features)
-            The input samples. If ``X.dtype == np.uint8``, the data is assumed
-            to be pre-binned.
+            The input samples.
 
         Returns
         -------
@@ -504,7 +501,7 @@ def predict(self, X):
         return self._raw_predict(X).ravel()
 
     def _encode_y(self, y):
-        # Just convert y to float32
+        # Just convert y to the expected dtype
         self.n_trees_per_iteration_ = 1
         y = y.astype(Y_DTYPE, copy=False)
         return y
@@ -530,7 +527,7 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
         The learning rate, also known as *shrinkage*. This is used as a
         multiplicative factor for the leaves values. Use ``1`` for no
         shrinkage.
-    max_iter : int, optional(default=100)
+    n_estimators : int, optional(default=100)
         The maximum number of iterations of the boosting process, i.e. the
         maximum number of trees for binary classification. For multiclass
         classification, `n_classes` trees per iteration are built.
@@ -551,10 +548,12 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
         number of unique values may use less than ``max_bins`` bins. Must be
         no larger than 256.
     scoring : str or callable or None, optional (default=None)
-        Scoring parameter to use for early stopping (see sklearn.metrics for
-        available options). If None, early stopping is check w.r.t the loss
-        value.
-    validation_split : int or float or None, optional(default=0.1)
+        Scoring parameter to use for early stopping. It can be a single
+        string (see :ref:`scoring_parameter`) or a callable (see
+        :ref:`scoring`). If None, the estimator's default scorer (if
+        available) is used. If ``scoring='loss'``, early stopping is checked
+        w.r.t the loss value. Only used if ``n_iter_no_change`` is not None.
+    validation_fraction : int or float or None, optional(default=0.1)
         Proportion (or absolute size) of training data to set aside as
         validation data for early stopping. If None, early stopping is done on
         the training data.
@@ -575,8 +574,7 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
         optional(default=None)
         Pseudo-random number generator to control the subsampling in the
         binning process, and the train/validation data split if early stopping
-        is enabled. See `scikit-learn glossary
-        `_.
+        is enabled. See :term:`random_state`.
 
     Examples
     --------
@@ -591,17 +589,17 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
 
     _VALID_LOSSES = ('binary_crossentropy', 'categorical_crossentropy',
                      'auto')
 
-    def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
+    def __init__(self, loss='auto', learning_rate=0.1, n_estimators=100,
                  max_leaf_nodes=31, max_depth=None, min_samples_leaf=20,
                  l2_regularization=0., max_bins=256, scoring=None,
-                 validation_split=0.1, n_iter_no_change=5, tol=1e-7,
+                 validation_fraction=0.1, n_iter_no_change=5, tol=1e-7,
                  verbose=0, random_state=None):
         super(FastGradientBoostingClassifier, self).__init__(
-            loss=loss, learning_rate=learning_rate, max_iter=max_iter,
+            loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
             max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
             min_samples_leaf=min_samples_leaf,
             l2_regularization=l2_regularization, max_bins=max_bins,
-            scoring=scoring, validation_split=validation_split,
+            scoring=scoring, validation_fraction=validation_fraction,
             n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
             random_state=random_state)
@@ -611,8 +609,7 @@ def predict(self, X):
 
         Parameters
         ----------
         X : array-like, shape=(n_samples, n_features)
-            The input samples. If ``X.dtype == np.uint8``, the data is assumed
-            to be pre-binned.
+            The input samples.
 
         Returns
         -------
@@ -629,8 +626,7 @@ def predict_proba(self, X):
 
         Parameters
         ----------
         X : array-like, shape=(n_samples, n_features)
-            The input samples. If ``X.dtype == np.uint8``, the data is assumed
-            to be pre-binned.
+            The input samples.
 
         Returns
         -------
diff --git a/sklearn/_fast_gradient_boosting/grower.py b/sklearn/_fast_gradient_boosting/grower.py
index 0f9fdc69b90aa..a50bb7ff715da 100644
--- a/sklearn/_fast_gradient_boosting/grower.py
+++ b/sklearn/_fast_gradient_boosting/grower.py
@@ -211,8 +211,7 @@ def _validate_parameters(self, X_binned, max_leaf_nodes, max_depth,
                              l2_regularization, min_hessian_to_split):
         """Validate parameters passed to __init__.
 
-        Also validate parameters passed to splitter because we cannot
-        raise exceptions in a jitclass.
+        Also validate parameters passed to splitter.
         """
         if X_binned.dtype != np.uint8:
             raise NotImplementedError(
diff --git a/sklearn/_fast_gradient_boosting/histogram.pxd b/sklearn/_fast_gradient_boosting/histogram.pxd
index 0b1b8e61bd4f0..ce9c10a48e3c1 100644
--- a/sklearn/_fast_gradient_boosting/histogram.pxd
+++ b/sklearn/_fast_gradient_boosting/histogram.pxd
@@ -1,9 +1,17 @@
 # cython: language_level=3
-"""This module contains njitted routines for building histograms.
+"""This module contains routines for building histograms.
 
 A histogram is an array with n_bins entry of type HISTOGRAM_DTYPE. Each
 feature has its own histogram. A histogram contains the sum of gradients and
 hessians of all the samples belonging to each bin.
+
+There are different ways to build a histogram:
+- by subtraction: hist(child) = hist(parent) - hist(sibling)
+- from scratch. In this case we have routines that update the hessians or not
+  (not useful when hessians are constant for some losses e.g. least squares).
+  Also, there's a special case for the root which contains all the samples,
+  leading to some possible optimizations. Overall all the implementations look
+  the same, and are optimized for cache hits.
 """
 import numpy as np
 cimport numpy as np
diff --git a/sklearn/_fast_gradient_boosting/histogram.pyx b/sklearn/_fast_gradient_boosting/histogram.pyx
index eefc0c84b6951..39176fc770daa 100644
--- a/sklearn/_fast_gradient_boosting/histogram.pyx
+++ b/sklearn/_fast_gradient_boosting/histogram.pyx
@@ -2,7 +2,7 @@
 # cython: boundscheck=False
 # cython: wraparound=False
 # cython: language_level=3
-"""This module contains njitted routines for building histograms.
+"""This module contains routines for building histograms.
 
 A histogram is an array with n_bins entry of type HISTOGRAM_DTYPE. Each
 feature has its own histogram.
 A histogram contains the sum of gradients and
diff --git a/sklearn/_fast_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/_fast_gradient_boosting/tests/test_compare_lightgbm.py
index 05ba2d36a5e84..887cf059dd2ff 100644
--- a/sklearn/_fast_gradient_boosting/tests/test_compare_lightgbm.py
+++ b/sklearn/_fast_gradient_boosting/tests/test_compare_lightgbm.py
@@ -39,7 +39,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
 
     rng = np.random.RandomState(seed=seed)
     n_samples = n_samples
-    max_iter = 1
+    n_estimators = 1
     max_bins = 256
 
     X, y = make_regression(n_samples=n_samples, n_features=5,
@@ -53,7 +53,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
 
     est_sklearn = FastGradientBoostingRegressor(
-        max_iter=max_iter,
+        n_estimators=n_estimators,
         max_bins=max_bins,
         learning_rate=1,
         n_iter_no_change=None,
@@ -91,7 +91,7 @@ def test_same_predictions_classification(seed, min_samples_leaf, n_samples,
 
     rng = np.random.RandomState(seed=seed)
     n_samples = n_samples
-    max_iter = 1
+    n_estimators = 1
     max_bins = 256
 
     X, y = make_classification(n_samples=n_samples, n_classes=2, n_features=5,
@@ -106,7 +106,7 @@ def test_same_predictions_classification(seed, min_samples_leaf, n_samples,
 
     est_pygbm = FastGradientBoostingClassifier(
         loss='binary_crossentropy',
-        max_iter=max_iter,
+        n_estimators=n_estimators,
         max_bins=max_bins,
         learning_rate=1,
         n_iter_no_change=None,
@@ -151,7 +151,7 @@ def test_same_predictions_multiclass_classification(
 
     rng = np.random.RandomState(seed=seed)
     n_samples = n_samples
-    max_iter = 1
+    n_estimators = 1
     max_bins = 256
     lr = 1
 
@@ -168,7 +168,7 @@ def test_same_predictions_multiclass_classification(
 
     est_pygbm = FastGradientBoostingClassifier(
         loss='categorical_crossentropy',
-        max_iter=max_iter,
+        n_estimators=n_estimators,
         max_bins=max_bins,
         learning_rate=lr,
         n_iter_no_change=None,
diff --git a/sklearn/_fast_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/_fast_gradient_boosting/tests/test_gradient_boosting.py
index a56fa0ccb0d0f..20a2fee690f61 100644
--- a/sklearn/_fast_gradient_boosting/tests/test_gradient_boosting.py
+++ b/sklearn/_fast_gradient_boosting/tests/test_gradient_boosting.py
@@ -32,8 +32,8 @@ def test_init_parameters_validation(GradientBoosting, X, y):
 
     assert_raises_regex(
         ValueError,
-        f"max_iter=0 must not be smaller than 1",
-        GradientBoosting(max_iter=0).fit, X, y
+        f"n_estimators=0 must not be smaller than 1",
+        GradientBoosting(n_estimators=0).fit, X, y
     )
 
     assert_raises_regex(
@@ -73,11 +73,11 @@ def test_init_parameters_validation(GradientBoosting, X, y):
         GradientBoosting(n_iter_no_change=-1).fit, X, y
     )
 
-    for validation_split in (-1, 0):
+    for validation_fraction in (-1, 0):
         assert_raises_regex(
             ValueError,
-            f"validation_split={validation_split} must be strictly positive",
-            GradientBoosting(validation_split=validation_split).fit, X, y
+            f"validation_fraction={validation_fraction} must be strictly positive",
+            GradientBoosting(validation_fraction=validation_fraction).fit, X, y
         )
 
     assert_raises_regex(
@@ -87,66 +87,66 @@ def test_init_parameters_validation(GradientBoosting, X, y):
     )
 
 
-@pytest.mark.parametrize('scoring, validation_split, n_iter_no_change, tol', [
+@pytest.mark.parametrize('scoring, validation_fraction, n_iter_no_change, tol', [
     ('neg_mean_squared_error', .1, 5, 1e-7),  # use scorer
     ('neg_mean_squared_error', None, 5, 1e-1),  # use scorer on training data
     (None, .1, 5, 1e-7),  # use loss
     (None, None, 5, 1e-1),  # use loss on training data
     (None, None, None, None),  # no early stopping
 ])
-def test_early_stopping_regression(scoring, validation_split,
+def test_early_stopping_regression(scoring, validation_fraction,
                                    n_iter_no_change, tol):
 
-    max_iter = 500
+    n_estimators = 500
 
     X, y = make_regression(random_state=0)
 
     gb = FastGradientBoostingRegressor(verbose=1,  # just for coverage
                                        scoring=scoring,
                                        tol=tol,
-                                       validation_split=validation_split,
-                                       max_iter=max_iter,
+                                       validation_fraction=validation_fraction,
+                                       n_estimators=n_estimators,
                                        n_iter_no_change=n_iter_no_change,
                                        random_state=0)
     gb.fit(X, y)
 
     if n_iter_no_change is not None:
-        assert n_iter_no_change <= gb.n_iter_ < max_iter
+        assert n_iter_no_change <= gb.n_iter_ < n_estimators
     else:
-        assert gb.n_iter_ == max_iter
+        assert gb.n_iter_ == n_estimators
 
 
 @pytest.mark.parametrize('data', (
     make_classification(random_state=0),
     make_classification(n_classes=3, n_clusters_per_class=1, random_state=0)
 ))
-@pytest.mark.parametrize('scoring, validation_split, n_iter_no_change, tol', [
+@pytest.mark.parametrize('scoring, validation_fraction, n_iter_no_change, tol', [
     ('accuracy', .1, 5, 1e-7),  # use scorer
     ('accuracy', None, 5, 1e-1),  # use scorer on training data
     (None, .1, 5, 1e-7),  # use loss
    (None, None, 5, 1e-1),  # use loss on training data
     (None, None, None, None),  # no early stopping
 ])
-def test_early_stopping_classification(data, scoring, validation_split,
+def test_early_stopping_classification(data, scoring, validation_fraction,
                                         n_iter_no_change, tol):
 
-    max_iter = 500
+    n_estimators = 500
 
     X, y = data
 
     gb = FastGradientBoostingClassifier(verbose=1,  # just for coverage
                                         scoring=scoring,
                                         tol=tol,
-                                        validation_split=validation_split,
-                                        max_iter=max_iter,
+                                        validation_fraction=validation_fraction,
+                                        n_estimators=n_estimators,
                                         n_iter_no_change=n_iter_no_change,
                                         random_state=0)
     gb.fit(X, y)
 
     if n_iter_no_change is not None:
-        assert n_iter_no_change <= gb.n_iter_ < max_iter
+        assert n_iter_no_change <= gb.n_iter_ < n_estimators
     else:
-        assert gb.n_iter_ == max_iter
+        assert gb.n_iter_ == n_estimators
 
 
 def test_should_stop():
@@ -178,7 +178,7 @@ def should_stop(scores, n_iter_no_change, tol):
 
 @pytest.mark.parametrize('Estimator', (
     FastGradientBoostingRegressor(),
-    FastGradientBoostingClassifier(scoring=None, validation_split=None,
+    FastGradientBoostingClassifier(scoring=None, validation_fraction=None,
                                    min_samples_leaf=5),
     ))
 def test_estimator_checks(Estimator):
@@ -187,7 +187,7 @@ def test_estimator_checks(Estimator):
     # Default parameters to the estimators have to be changed to pass the
     # tests:
     # - Can't do early stopping with classifier because often
-    #   validation_split=.1 leads to test_size=2 < n_classes and
+    #   validation_fraction=.1 leads to test_size=2 < n_classes and
     #   train_test_split raises an error.
     # - Also, need to set a low min_samples_leaf for
     #   check_classifiers_classes() to pass: with only 30 samples on the
diff --git a/sklearn/_fast_gradient_boosting/utils.py b/sklearn/_fast_gradient_boosting/utils.py
index 3481cba080f8d..f9c9b59f42849 100644
--- a/sklearn/_fast_gradient_boosting/utils.py
+++ b/sklearn/_fast_gradient_boosting/utils.py
@@ -1,4 +1,5 @@
 """This module contains utility routines."""
+from .binning import BinMapper
 
 
 def get_lightgbm_estimator(pygbm_estimator):
@@ -30,7 +31,7 @@ def get_lightgbm_estimator(pygbm_estimator):
     lgbm_params = {
         'objective': loss_mapping[pygbm_params['loss']],
         'learning_rate': pygbm_params['learning_rate'],
-        'n_estimators': pygbm_params['max_iter'],
+        'n_estimators': pygbm_params['n_estimators'],
         'num_leaves': pygbm_params['max_leaf_nodes'],
         'max_depth': pygbm_params['max_depth'],
         'min_data_in_leaf': pygbm_params['min_samples_leaf'],
@@ -41,6 +42,9 @@ def get_lightgbm_estimator(pygbm_estimator):
         'min_gain_to_split': 0,
         'verbosity': 10 if pygbm_params['verbose'] else 0,
         'boost_from_average': True,
+        'enable_bundle': False,  # also makes feature order consistent
+        'min_data_in_bin': 1,
+        'bin_construct_sample_cnt': BinMapper().subsample,
     }
     # TODO: change hardcoded values when / if they're arguments to the
     # estimator.
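
As a quick, illustrative sketch only (not part of the patch): the snippet below exercises the renamed parameters documented above (n_estimators, validation_fraction, scoring, n_iter_no_change) roughly the way the early-stopping tests do. The import path is an assumption based on the module layout in this diff; adjust it to wherever the estimators are actually exposed.

# Illustrative sketch only -- the import path is assumed from the module
# layout shown in this diff.
from sklearn._fast_gradient_boosting.gradient_boosting import (
    FastGradientBoostingRegressor)
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=1000, random_state=0)

# Hold out 10% of the training data for early stopping; fitting stops once the
# last 5 'neg_mean_squared_error' scores fail to improve by more than tol.
gb = FastGradientBoostingRegressor(n_estimators=500,
                                   scoring='neg_mean_squared_error',
                                   validation_fraction=0.1,
                                   n_iter_no_change=5,
                                   tol=1e-7,
                                   random_state=0)
gb.fit(X, y)
print(gb.n_iter_)  # iterations actually performed, at most n_estimators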
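
The histogram.pxd docstring above describes building a child node's histogram by subtraction: hist(child) = hist(parent) - hist(sibling). Below is a minimal NumPy sketch of that identity, with plain float arrays standing in for HISTOGRAM_DTYPE and purely illustrative names.

# Sketch of the "histogram by subtraction" identity from the histogram.pxd
# docstring; the real code works on structured HISTOGRAM_DTYPE arrays.
import numpy as np

rng = np.random.RandomState(0)
n_bins = 8
binned_feature = rng.randint(0, n_bins, size=100).astype(np.uint8)
gradients = rng.normal(size=100)

def build_histogram(sample_indices):
    # "from scratch" build: accumulate each sample's gradient into its bin
    hist = np.zeros(n_bins)
    for i in sample_indices:
        hist[binned_feature[i]] += gradients[i]
    return hist

parent = np.arange(100)                  # samples reaching the parent node
left, right = parent[:60], parent[60:]   # samples sent to each child

hist_parent = build_histogram(parent)
hist_left = build_histogram(left)
hist_right = hist_parent - hist_left     # sibling histogram comes for free

assert np.allclose(hist_right, build_histogram(right))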