Skip to content

Commit

Permalink
checkpoint before changing scoring param
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jan 17, 2019
1 parent 2fd29e1 commit 5a82534
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 94 deletions.
120 changes: 58 additions & 62 deletions sklearn/_fast_gradient_boosting/gradient_boosting.py
Expand Up @@ -24,20 +24,20 @@ class BaseFastGradientBoosting(BaseEstimator, ABC):
"""Base class for fast gradient boosting estimators."""

@abstractmethod
def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
def __init__(self, loss, learning_rate, n_estimators, max_leaf_nodes,
max_depth, min_samples_leaf, l2_regularization, max_bins,
scoring, validation_split, n_iter_no_change, tol, verbose,
scoring, validation_fraction, n_iter_no_change, tol, verbose,
random_state):
self.loss = loss
self.learning_rate = learning_rate
self.max_iter = max_iter
self.n_estimators = n_estimators
self.max_leaf_nodes = max_leaf_nodes
self.max_depth = max_depth
self.min_samples_leaf = min_samples_leaf
self.l2_regularization = l2_regularization
self.max_bins = max_bins
self.n_iter_no_change = n_iter_no_change
self.validation_split = validation_split
self.validation_fraction = validation_fraction
self.scoring = scoring
self.tol = tol
self.verbose = verbose
Expand All @@ -58,14 +58,14 @@ def _validate_parameters(self):
if self.learning_rate <= 0:
raise ValueError(f'learning_rate={self.learning_rate} must '
f'be strictly positive')
if self.max_iter < 1:
raise ValueError(f'max_iter={self.max_iter} must '
if self.n_estimators < 1:
raise ValueError(f'n_estimators={self.n_estimators} must '
f'not be smaller than 1.')
if self.n_iter_no_change is not None and self.n_iter_no_change < 0:
raise ValueError(f'n_iter_no_change={self.n_iter_no_change} '
f'must be positive.')
if self.validation_split is not None and self.validation_split <= 0:
raise ValueError(f'validation_split={self.validation_split} '
if self.validation_fraction is not None and self.validation_fraction <= 0:
raise ValueError(f'validation_fraction={self.validation_fraction} '
f'must be strictly positive, or None.')
if self.tol is not None and self.tol < 0:
raise ValueError(f'tol={self.tol} '
Expand Down Expand Up @@ -116,19 +116,19 @@ def fit(self, X, y):
self.do_early_stopping_ = (self.n_iter_no_change is not None and
self.n_iter_no_change > 0)

if self.do_early_stopping_ and self.validation_split is not None:
if self.do_early_stopping_ and self.validation_fraction is not None:
# stratify for classification
stratify = y if hasattr(self.loss_, 'predict_proba') else None

X_binned_train, X_binned_val, y_train, y_val = train_test_split(
X_binned, y, test_size=self.validation_split,
X_binned, y, test_size=self.validation_fraction,
stratify=stratify, random_state=rng)
if X_binned_train.size == 0 or X_binned_val.size == 0:
raise ValueError(
f'Not enough data (n_samples={X_binned.shape[0]}) to '
f'perform early stopping with validation_split='
f'{self.validation_split}. Use more training data or '
f'adjust validation_split.'
f'perform early stopping with validation_fraction='
f'{self.validation_fraction}. Use more training data or '
f'adjust validation_fraction.'
)
# Predicting is faster of C-contiguous arrays, training is faster
# on Fortran arrays.
Expand All @@ -138,15 +138,15 @@ def fit(self, X, y):
X_binned_train, y_train = X_binned, y
X_binned_val, y_val = None, None

# Subsample the training set for score-based monitoring.
# Subsample the training set for early stopping
if self.do_early_stopping_:
subsample_size = 10000
subsample_size = 10000 # should we expose this?
indices = np.arange(X_binned_train.shape[0])
if X_binned_train.shape[0] > subsample_size:
indices = rng.choice(indices, subsample_size)
X_binned_small_train = X_binned_train[indices]
y_small_train = y_train[indices]
# Predicting is faster of C-contiguous arrays.
# Predicting is faster on C-contiguous arrays.
X_binned_small_train = np.ascontiguousarray(X_binned_small_train)

if self.verbose:
Expand All @@ -170,8 +170,8 @@ def fit(self, X, y):
prediction_dim=self.n_trees_per_iteration_
)

# predictors_ is a matrix of TreePredictor objects with shape
# (n_iter_, n_trees_per_iteration)
# predictors_ is a matrix (list of lists) of TreePredictor objects
# with shape (n_iter_, n_trees_per_iteration)
self.predictors_ = predictors = []

# scorer_ is a callable with signature (est, X, y) and calls
Expand All @@ -184,15 +184,15 @@ def fit(self, X, y):
self.train_scores_.append(
self._get_scores(X_binned_train, y_train))

if self.validation_split is not None:
if self.validation_fraction is not None:
self.validation_scores_.append(
self._get_scores(X_binned_val, y_val))

for iteration in range(self.max_iter):
for iteration in range(self.n_estimators):

if self.verbose:
iteration_start_time = time()
print(f"[{iteration + 1}/{self.max_iter}] ", end='',
print(f"[{iteration + 1}/{self.n_estimators}] ", end='',
flush=True)

# Update gradients and hessians, inplace
Expand Down Expand Up @@ -277,7 +277,7 @@ def _check_early_stopping(self, X_binned_train, y_train,
self.train_scores_.append(
self._get_scores(X_binned_train, y_train))

if self.validation_split is not None:
if self.validation_fraction is not None:
self.validation_scores_.append(
self._get_scores(X_binned_val, y_val))
return self._should_stop(self.validation_scores_)
Expand Down Expand Up @@ -342,7 +342,7 @@ def _print_iteration_stats(self, iteration_start_time):

if self.do_early_stopping_:
log_msg += f"{self.scoring} train: {self.train_scores_[-1]:.5f}, "
if self.validation_split is not None:
if self.validation_fraction is not None:
log_msg += (f"{self.scoring} val: "
f"{self.validation_scores_[-1]:.5f}, ")

Expand All @@ -357,8 +357,7 @@ def _raw_predict(self, X):
Parameters
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
The input samples.
Returns
-------
Expand Down Expand Up @@ -409,7 +408,7 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
The learning rate, also known as *shrinkage*. This is used as a
multiplicative factor for the leaves values. Use ``1`` for no
shrinkage.
max_iter : int, optional(default=100)
n_estimators : int, optional(default=100)
The maximum number of iterations of the boosting process, i.e. the
maximum number of trees.
max_leaf_nodes : int or None, optional(default=None)
Expand All @@ -428,35 +427,34 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
allows for a much faster training stage. Features with a small
number of unique values may use less than ``max_bins`` bins. Must be no
larger than 256.
scoring : str or callable or None, \
optional (default=None)
Scoring parameter to use for early stopping (see sklearn.metrics for
available options). If None, early stopping is check w.r.t the loss
value.
validation_split : int or float or None, optional(default=0.1)
scoring : str or callable or None, optional (default=None)
Scoring parameter to use for early stopping. It can be a single
string (see :ref:`scoring_parameter`) or a callable (see
:ref:`scoring`). If None, the estimator's default scorer (if
available) is used. If ``scoring='loss'``, early stopping is checked
w.r.t the loss value. Only used if ``n_iter_no_change`` is not None.
validation_fraction : int or float or None, optional(default=0.1)
Proportion (or absolute size) of training data to set aside as
validation data for early stopping. If None, early stopping is done on
the training data.
the training data. Only used if ``n_iter_no_change`` is not None.
n_iter_no_change : int or None, optional (default=5)
Used to determine when to "early stop". The fitting process is
stopped when none of the last ``n_iter_no_change`` scores are better
than the ``n_iter_no_change - 1``th-to-last one, up to some
tolerance. If None or 0, no early-stopping is done.
tol : float or None optional (default=1e-7)
The absolute tolerance to use when comparing scores. The higher the
tolerance, the more likely we are to early stop: higher tolerance
means that it will be harder for subsequent iterations to be
considered an improvement upon the reference score.
The absolute tolerance to use when comparing scores during early
stopping. The higher the tolerance, the more likely we are to early
stop: higher tolerance means that it will be harder for subsequent
iterations to be considered an improvement upon the reference score.
verbose: int, optional (default=0)
The verbosity level. If not zero, print some information about the
fitting process.
random_state : int, np.random.RandomStateInstance or None, \
optional (default=None)
Pseudo-random number generator to control the subsampling in the
binning process, and the train/validation data split if early stopping
is enabled. See
`scikit-learn glossary
<https://scikit-learn.org/stable/glossary.html#term-random-state>`_.
is enabled. See :term:`random_state`.
Examples
Expand All @@ -472,16 +470,16 @@ class FastGradientBoostingRegressor(BaseFastGradientBoosting, RegressorMixin):
_VALID_LOSSES = ('least_squares',)

def __init__(self, loss='least_squares', learning_rate=0.1,
max_iter=100, max_leaf_nodes=31, max_depth=None,
n_estimators=100, max_leaf_nodes=31, max_depth=None,
min_samples_leaf=20, l2_regularization=0., max_bins=256,
scoring=None, validation_split=0.1, n_iter_no_change=5,
scoring=None, validation_fraction=0.1, n_iter_no_change=5,
tol=1e-7, verbose=0, random_state=None):
super(FastGradientBoostingRegressor, self).__init__(
loss=loss, learning_rate=learning_rate, max_iter=max_iter,
loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
l2_regularization=l2_regularization, max_bins=max_bins,
scoring=scoring, validation_split=validation_split,
scoring=scoring, validation_fraction=validation_fraction,
n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
random_state=random_state)

Expand All @@ -491,8 +489,7 @@ def predict(self, X):
Parameters
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
The input samples.
Returns
-------
Expand All @@ -504,7 +501,7 @@ def predict(self, X):
return self._raw_predict(X).ravel()

def _encode_y(self, y):
# Just convert y to float32
# Just convert y to the expected dtype
self.n_trees_per_iteration_ = 1
y = y.astype(Y_DTYPE, copy=False)
return y
Expand All @@ -530,7 +527,7 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
The learning rate, also known as *shrinkage*. This is used as a
multiplicative factor for the leaves values. Use ``1`` for no
shrinkage.
max_iter : int, optional(default=100)
n_estimators : int, optional(default=100)
The maximum number of iterations of the boosting process, i.e. the
maximum number of trees for binary classification. For multiclass
classification, `n_classes` trees per iteration are built.
Expand All @@ -551,10 +548,12 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
number of unique values may use less than ``max_bins`` bins. Must be no
larger than 256.
scoring : str or callable or None, optional (default=None)
Scoring parameter to use for early stopping (see sklearn.metrics for
available options). If None, early stopping is check w.r.t the loss
value.
validation_split : int or float or None, optional(default=0.1)
Scoring parameter to use for early stopping. It can be a single
string (see :ref:`scoring_parameter`) or a callable (see
:ref:`scoring`). If None, the estimator's default scorer (if
available) is used. If ``scoring='loss'``, early stopping is checked
w.r.t the loss value. Only used if ``n_iter_no_change`` is not None.
validation_fraction : int or float or None, optional(default=0.1)
Proportion (or absolute size) of training data to set aside as
validation data for early stopping. If None, early stopping is done on
the training data.
Expand All @@ -575,8 +574,7 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
optional(default=None)
Pseudo-random number generator to control the subsampling in the
binning process, and the train/validation data split if early stopping
is enabled. See `scikit-learn glossary
<https://scikit-learn.org/stable/glossary.html#term-random-state>`_.
is enabled. See :term:`random_state`.
Examples
--------
Expand All @@ -591,17 +589,17 @@ class FastGradientBoostingClassifier(BaseFastGradientBoosting,
_VALID_LOSSES = ('binary_crossentropy', 'categorical_crossentropy',
'auto')

def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
def __init__(self, loss='auto', learning_rate=0.1, n_estimators=100,
max_leaf_nodes=31, max_depth=None, min_samples_leaf=20,
l2_regularization=0., max_bins=256, scoring=None,
validation_split=0.1, n_iter_no_change=5, tol=1e-7,
validation_fraction=0.1, n_iter_no_change=5, tol=1e-7,
verbose=0, random_state=None):
super(FastGradientBoostingClassifier, self).__init__(
loss=loss, learning_rate=learning_rate, max_iter=max_iter,
loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
l2_regularization=l2_regularization, max_bins=max_bins,
scoring=scoring, validation_split=validation_split,
scoring=scoring, validation_fraction=validation_fraction,
n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
random_state=random_state)

Expand All @@ -611,8 +609,7 @@ def predict(self, X):
Parameters
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
The input samples.
Returns
-------
Expand All @@ -629,8 +626,7 @@ def predict_proba(self, X):
Parameters
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
The input samples.
Returns
-------
Expand Down
3 changes: 1 addition & 2 deletions sklearn/_fast_gradient_boosting/grower.py
Expand Up @@ -211,8 +211,7 @@ def _validate_parameters(self, X_binned, max_leaf_nodes, max_depth,
l2_regularization, min_hessian_to_split):
"""Validate parameters passed to __init__.
Also validate parameters passed to splitter because we cannot
raise exceptions in a jitclass.
Also validate parameters passed to splitter.
"""
if X_binned.dtype != np.uint8:
raise NotImplementedError(
Expand Down
10 changes: 9 additions & 1 deletion sklearn/_fast_gradient_boosting/histogram.pxd
@@ -1,9 +1,17 @@
# cython: language_level=3
"""This module contains njitted routines for building histograms.
"""This module contains routines for building histograms.
A histogram is an array with n_bins entry of type HISTOGRAM_DTYPE. Each
feature has its own histogram. A histogram contains the sum of gradients and
hessians of all the samples belonging to each bin.
There are different ways to build a histogram:
- by subtraction: hist(child) = hist(parent) - hist(sibling)
- from scratch. In this case we have rountines that update the hessians or not
(not useful when hessians are constant for some losses e.g. least squares).
Also, there's a special case for the root which contains all the samples,
leading to some possible optimizations. Overall all the implementations look
the same, and are optimized for cache hit.
"""
import numpy as np
cimport numpy as np
Expand Down
2 changes: 1 addition & 1 deletion sklearn/_fast_gradient_boosting/histogram.pyx
Expand Up @@ -2,7 +2,7 @@
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
"""This module contains njitted routines for building histograms.
"""This module contains routines for building histograms.
A histogram is an array with n_bins entry of type HISTOGRAM_DTYPE. Each
feature has its own histogram. A histogram contains the sum of gradients and
Expand Down

0 comments on commit 5a82534

Please sign in to comment.