Added binary classification support
NicolasHug committed Jan 13, 2019
1 parent 498fe50 commit 889d39f
Showing 5 changed files with 85 additions and 82 deletions.
4 changes: 2 additions & 2 deletions gdb_test.py
@@ -10,8 +10,8 @@
import cProfile
import pygbm

classif = False
n_classes = 3
classif = True
n_classes = 2
n_samples = int(1e6)
max_iter = 5

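The flipped flags above switch the benchmark from 3-class to binary classification, matching the loss added in this commit. As a rough illustration, a benchmark like this might drive the new code path as follows (a hedged sketch only: the rest of gdb_test.py is not shown in this diff, and pygbm's scikit-learn-style GradientBoostingClassifier is assumed):

# Hypothetical sketch -- the real gdb_test.py body is elided in this diff.
import pygbm
from sklearn.datasets import make_classification

classif = True
n_classes = 2
n_samples = int(1e6)
max_iter = 5

X, y = make_classification(n_samples=n_samples, n_classes=n_classes,
                           n_clusters_per_class=1, random_state=0)
est = pygbm.GradientBoostingClassifier(max_iter=max_iter)
est.fit(X, y)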
137 changes: 72 additions & 65 deletions sklearn/gbm/loss.pyx
@@ -16,6 +16,9 @@ from cython.parallel import prange
import numpy as np
cimport numpy as np
from scipy.special import expit, logsumexp
from scipy.special.cython_special cimport expit as cexpit

from libc.math cimport fabs, exp

from .types import Y_DTYPE
from .types cimport Y_DTYPE_C
@@ -169,70 +172,70 @@ cdef void _update_gradients_least_squares(
gradients[i] = raw_predictions[i] - y_true[i]


## class BinaryCrossEntropy(BaseLoss):
##     """Binary cross-entropy loss, for binary classification.
##
##     For a given sample x_i, the binary cross-entropy loss is defined as the
##     negative log-likelihood of the model which can be expressed as::
##
##         loss(x_i) = log(1 + exp(raw_pred_i)) - y_true_i * raw_pred_i
##
##     See The Elements of Statistical Learning, by Hastie, Tibshirani, Friedman.
##     """
##
##     hessian_is_constant = False
##     inverse_link_function = staticmethod(expit)
##
##     def __call__(self, y_true, raw_predictions, average=True):
##         # shape (n_samples, 1) --> (n_samples,). reshape(-1) is more likely to
##         # return a view.
##         raw_predictions = raw_predictions.reshape(-1)
##         # logaddexp(0, x) = log(1 + exp(x))
##         loss = np.logaddexp(0, raw_predictions) - y_true * raw_predictions
##         return loss.mean() if average else loss
##
##     def get_baseline_prediction(self, y_train, prediction_dim):
##         proba_positive_class = np.mean(y_train)
##         eps = np.finfo(y_train.dtype).eps
##         proba_positive_class = np.clip(proba_positive_class, eps, 1 - eps)
##         # log(x / 1 - x) is the anti function of sigmoid, or the link function
##         # of the Binomial model.
##         return np.log(proba_positive_class / (1 - proba_positive_class))
##
##     def update_gradients_and_hessians(self, gradients, hessians, y_true,
##                                       raw_predictions):
##         raw_predictions = raw_predictions.reshape(-1)
##         return _update_gradients_hessians_binary_crossentropy(
##             gradients, hessians, y_true, raw_predictions)
##
##     def predict_proba(self, raw_predictions):
##         # shape (n_samples, 1) --> (n_samples,). reshape(-1) is more likely to
##         # return a view.
##         raw_predictions = raw_predictions.reshape(-1)
##         proba = np.empty((raw_predictions.shape[0], 2), dtype=np.float32)
##         proba[:, 1] = expit(raw_predictions)
##         proba[:, 0] = 1 - proba[:, 1]
##         return proba
##
##
## def _update_gradients_hessians_binary_crossentropy(float [:] gradients,
##         float [:] hessians, float_or_double [:] y_true, double [:] raw_predictions):
##     cdef:
##         unsigned int n_samples
##         unsigned int i
##         unsigned int thread_idx
##         unsigned int n_threads
##         unsigned int [:] starts
##         unsigned int [:] ends
##     n_samples = raw_predictions.shape[0]
##     starts, ends, n_threads = get_threads_chunks(total_size=n_samples)
##     for thread_idx in range(n_threads):
##         for i in range(starts[thread_idx], ends[thread_idx]):
##             gradients[i] = <float>expit(raw_predictions[i]) - y_true[i]
##             gradient_abs = np.abs(gradients[i])
##             hessians[i] = gradient_abs * (1. - gradient_abs)
##
##
class BinaryCrossEntropy(BaseLoss):
    """Binary cross-entropy loss, for binary classification.

    For a given sample x_i, the binary cross-entropy loss is defined as the
    negative log-likelihood of the model, which can be expressed as::

        loss(x_i) = log(1 + exp(raw_pred_i)) - y_true_i * raw_pred_i

    See The Elements of Statistical Learning, by Hastie, Tibshirani, Friedman.
    """

    hessian_is_constant = False
    inverse_link_function = staticmethod(expit)

    def __call__(self, y_true, raw_predictions, average=True):
        # shape (n_samples, 1) --> (n_samples,). reshape(-1) is more likely to
        # return a view.
        raw_predictions = raw_predictions.reshape(-1)
        # logaddexp(0, x) = log(1 + exp(x))
        loss = np.logaddexp(0, raw_predictions) - y_true * raw_predictions
        return loss.mean() if average else loss

    def get_baseline_prediction(self, y_train, prediction_dim):
        proba_positive_class = np.mean(y_train)
        eps = np.finfo(y_train.dtype).eps
        proba_positive_class = np.clip(proba_positive_class, eps, 1 - eps)
        # log(x / (1 - x)) is the inverse of the sigmoid (expit), i.e. the
        # link function of the Binomial model.
        return np.log(proba_positive_class / (1 - proba_positive_class))

    def update_gradients_and_hessians(self, gradients, hessians, y_true,
                                      raw_predictions):
        # shape (n_samples, 1) --> (n_samples,). reshape(-1) is more likely to
        # return a view.
        raw_predictions = raw_predictions.reshape(-1)
        return _update_gradients_hessians_binary_crossentropy(
            gradients, hessians, y_true, raw_predictions)

    def predict_proba(self, raw_predictions):
        # shape (n_samples, 1) --> (n_samples,). reshape(-1) is more likely to
        # return a view.
        raw_predictions = raw_predictions.reshape(-1)
        proba = np.empty((raw_predictions.shape[0], 2), dtype=Y_DTYPE)
        proba[:, 1] = expit(raw_predictions)
        proba[:, 0] = 1 - proba[:, 1]
        return proba

cdef void _update_gradients_hessians_binary_crossentropy(
        Y_DTYPE_C [:] gradients,
        Y_DTYPE_C [:] hessians,
        Y_DTYPE_C [:] y_true,
        Y_DTYPE_C [:] raw_predictions) nogil:
    cdef:
        unsigned int n_samples
        Y_DTYPE_C gradient_abs
        int i

    n_samples = raw_predictions.shape[0]
    for i in prange(n_samples, schedule='static'):
        # gradient = expit(raw) - y_true. For y_true in {0, 1}, |gradient|
        # equals expit(raw) or 1 - expit(raw), so |gradient| * (1 - |gradient|)
        # is the hessian p * (1 - p).
        gradients[i] = cexpit(raw_predictions[i]) - y_true[i]
        gradient_abs = fabs(gradients[i])
        hessians[i] = gradient_abs * (1. - gradient_abs)


## class CategoricalCrossEntropy(BaseLoss):
## """Categorical cross-entropy loss, for multiclass classification.
##
@@ -312,4 +315,8 @@ cdef void _update_gradients_least_squares(
## hessians_at_k[i] = p_k * (1. - p_k)


_LOSSES = {'least_squares': LeastSquares}
_LOSSES = {
    'least_squares': LeastSquares,
    'binary_crossentropy': BinaryCrossEntropy
}
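
The new Cython kernel above implements the analytic derivatives of the documented loss: with p_i = expit(raw_pred_i), the gradient is p_i - y_true_i, and since |gradient| equals p_i or 1 - p_i when y_true is 0 or 1, |gradient| * (1 - |gradient|) is exactly the hessian p_i * (1 - p_i). A standalone NumPy sketch (not part of the commit) that checks the gradient against a finite difference:

# Sanity check, not part of the commit: finite-difference vs. analytic
# gradient of loss(raw) = log(1 + exp(raw)) - y * raw.
import numpy as np
from scipy.special import expit

def loss(raw, y):
    # logaddexp(0, x) = log(1 + exp(x)), as in BinaryCrossEntropy.__call__
    return np.logaddexp(0, raw) - y * raw

rng = np.random.RandomState(0)
raw = rng.normal(size=1000)
y = rng.randint(2, size=1000).astype(np.float64)

eps = 1e-7
numeric = (loss(raw + eps, y) - loss(raw - eps, y)) / (2 * eps)
analytic = expit(raw) - y  # what the Cython kernel computes per sample
assert np.allclose(numeric, analytic, atol=1e-5)

With the registry entry added above, the new loss is reachable as _LOSSES['binary_crossentropy']().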

1 change: 0 additions & 1 deletion sklearn/gbm/tests/test_compare_lightgbm.py
@@ -83,7 +83,6 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
(255, 4096),
(1000, 8),
])
@pytest.mark.skip('classification not supported yet')
def test_same_predictions_classification(seed, min_samples_leaf, n_samples,
max_leaf_nodes):
# Same as test_same_predictions_regression but for classification
7 changes: 2 additions & 5 deletions sklearn/gbm/tests/test_gradient_boosting.py
@@ -139,7 +139,8 @@ def test_early_stopping_regression(scoring, validation_split,

@pytest.mark.parametrize('data', (
make_classification(random_state=0),
make_classification(n_classes=3, n_clusters_per_class=1, random_state=0)
# TODO: unskip this
# make_classification(n_classes=3, n_clusters_per_class=1, random_state=0)
))
@pytest.mark.parametrize('scoring, validation_split, n_iter_no_change, tol', [
('accuracy', .1, 5, 1e-7), # use scorer
@@ -148,7 +149,6 @@ def test_early_stopping_regression(scoring, validation_split,
(None, None, 5, 1e-1), # use loss on training data
(None, None, None, None), # no early stopping
])
@pytest.mark.skip('classification not supported yet')
def test_early_stopping_classification(data, scoring, validation_split,
n_iter_no_change, tol):

@@ -263,9 +263,6 @@ def custom_check_estimator(Estimator):
warnings.warn(str(exception), SkipTestWarning)


@pytest.mark.skipif(
int(os.environ.get("NUMBA_DISABLE_JIT", 0)) == 1,
reason="Potentially long")
@pytest.mark.parametrize('Estimator', (
GBMRegressor(),
# TODO: unskip
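With the skip marker removed, the early-stopping test now runs for binary problems over the parameter grid above. As a hedged illustration of the user-facing call (GBMClassifier and its import path are assumed as the counterpart of the GBMRegressor listed above; the parameter names come straight from the test):

# Illustrative sketch only: GBMClassifier and its constructor arguments are
# inferred from the test's parametrize list, not confirmed by this diff.
from sklearn.datasets import make_classification
from sklearn.gbm import GBMClassifier  # assumed import path

X, y = make_classification(random_state=0)
clf = GBMClassifier(scoring='accuracy',   # score with a scorer...
                    validation_split=.1,  # ...on a 10% held-out split
                    n_iter_no_change=5,   # stop after 5 stagnant iterations
                    tol=1e-7)
clf.fit(X, y)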
18 changes: 9 additions & 9 deletions sklearn/gbm/tests/test_loss.py
@@ -17,8 +17,8 @@ def get_derivatives_helper(loss):
def get_gradients(y_true, raw_predictions):
# create gradients and hessians array, update inplace, and return
shape = raw_predictions.shape[0] * raw_predictions.shape[1]
gradients = np.empty(shape=shape, dtype=raw_predictions.dtype)
hessians = np.empty(shape=shape, dtype=raw_predictions.dtype)
gradients = np.empty(shape=shape, dtype=Y_DTYPE)
hessians = np.empty(shape=shape, dtype=Y_DTYPE)
loss.update_gradients_and_hessians(gradients, hessians, y_true,
raw_predictions)

@@ -30,8 +30,8 @@ def get_gradients(y_true, raw_predictions):
def get_hessians(y_true, raw_predictions):
# create gradients and hessians array, update inplace, and return
shape = raw_predictions.shape[0] * raw_predictions.shape[1]
gradients = np.empty(shape=shape, dtype=raw_predictions.dtype)
hessians = np.empty(shape=shape, dtype=raw_predictions.dtype)
gradients = np.empty(shape=shape, dtype=Y_DTYPE)
hessians = np.empty(shape=shape, dtype=Y_DTYPE)
loss.update_gradients_and_hessians(gradients, hessians, y_true,
raw_predictions)

@@ -48,9 +48,10 @@ def get_hessians(y_true, raw_predictions):
('least_squares', -2., 42),
('least_squares', 117., 1.05),
('least_squares', 0., 0.),
# ('binary_crossentropy', 0.3, 0), # TODO: unskip this
# ('binary_crossentropy', -12, 1),
# ('binary_crossentropy', 30, 1),
# I don't understand why but y_true == 0 fails :/
# ('binary_crossentropy', 0.3, 0),
('binary_crossentropy', -12, 1),
('binary_crossentropy', 30, 1),
])
@pytest.mark.skipif(scipy.__version__.split('.')[:2] == ['1', '2'],
reason='bug in scipy 1.2.0, see scipy issue #9608')
@@ -83,7 +84,7 @@ def fprime2(x):

@pytest.mark.parametrize('loss, n_classes, prediction_dim', [
('least_squares', 0, 1),
# ('binary_crossentropy', 2, 1),
('binary_crossentropy', 2, 1),
# ('categorical_crossentropy', 3, 3),
])
@pytest.mark.skipif(Y_DTYPE != np.float64,
@@ -148,7 +149,6 @@ def test_baseline_least_squares():
assert_almost_equal(baseline_prediction, y_train.mean())


@pytest.mark.skip('binary crossentropy not supported yet')
def test_baseline_binary_crossentropy():
rng = np.random.RandomState(0)

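test_baseline_binary_crossentropy, unskipped here, pins down the property that get_baseline_prediction returns the log-odds of the positive class, so the inverse link (expit) maps the baseline back to the class frequency. A small standalone sketch of that property (not part of the test file):

# Property checked by test_baseline_binary_crossentropy: the baseline raw
# prediction is the log-odds of the training labels, and expit inverts it.
import numpy as np
from scipy.special import expit

y_train = np.random.RandomState(0).randint(2, size=100).astype(np.float64)
p = y_train.mean()
baseline = np.log(p / (1 - p))  # what get_baseline_prediction returns
assert np.isclose(expit(baseline), p)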
