[python-package] reorganize early stopping callback (microsoft#6114)
jameslamb authored and Ten0 committed Jan 12, 2024
1 parent 051e223 commit d0b29e0
Showing 3 changed files with 58 additions and 21 deletions.
64 changes: 45 additions & 19 deletions python-package/lightgbm/callback.py
@@ -229,7 +229,12 @@ def __call__(self, env: CallbackEnv) -> None:
             if new_param != env.params.get(key, None):
                 new_parameters[key] = new_param
         if new_parameters:
-            env.model.reset_parameter(new_parameters)
+            if isinstance(env.model, Booster):
+                env.model.reset_parameter(new_parameters)
+            else:
+                # CVBooster holds a list of Booster objects, each needs to be updated
+                for booster in env.model.boosters:
+                    booster.reset_parameter(new_parameters)
             env.params.update(new_parameters)

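The hunk above is what lets `reset_parameter` work under `lgb.cv()`, where `env.model` is a `CVBooster` rather than a single `Booster`. A minimal standalone sketch of that dispatch pattern (the helper name `_reset_on_any_model` is illustrative, not part of the library):

```python
import lightgbm as lgb

def _reset_on_any_model(model, new_parameters: dict) -> None:
    # lgb.train() gives callbacks a single Booster...
    if isinstance(model, lgb.Booster):
        model.reset_parameter(new_parameters)
    # ...while lgb.cv() gives them a CVBooster, a container of per-fold Boosters
    else:
        for booster in model.boosters:
            booster.reset_parameter(new_parameters)
```
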
@@ -267,6 +272,10 @@ def __init__(
         verbose: bool = True,
         min_delta: Union[float, List[float]] = 0.0
     ) -> None:
+
+        if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
+            raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
+
         self.order = 30
         self.before_iteration = False

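Moving this check into `__init__` means invalid values fail at callback construction rather than partway into training. A quick illustration of the new behavior (error text taken from the hunk above):

```python
import lightgbm as lgb

# a positive integer is accepted
callback = lgb.early_stopping(stopping_rounds=10)

# anything else now raises immediately, before any training happens
try:
    lgb.early_stopping(stopping_rounds=0)
except ValueError as err:
    print(err)  # stopping_rounds should be an integer and greater than 0. got: 0
```
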
@@ -291,32 +300,45 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

-    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
-        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
+    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
+        """Check, by name, if a given Dataset is the training data."""
+        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
+        # and those metrics are considered for early stopping
+        if ds_name == "cv_agg" and eval_name == "train":
+            return True
+
+        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
+        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
+            return True
+
+        return False

     def _init(self, env: CallbackEnv) -> None:
         if env.evaluation_result_list is None or env.evaluation_result_list == []:
             raise ValueError(
                 "For early stopping, at least one dataset and eval metric is required for evaluation"
             )
+
         is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
-        only_train_set = (
-            len(env.evaluation_result_list) == 1
-            and self._is_train_set(
-                ds_name=env.evaluation_result_list[0][0],
-                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
-                train_name=env.model._train_data_name)
-        )
-        self.enabled = not is_dart and not only_train_set
-        if not self.enabled:
-            if is_dart:
-                _log_warning('Early stopping is not available in dart mode')
-            elif only_train_set:
-                _log_warning('Only training set found, disabling early stopping.')
+        if is_dart:
+            self.enabled = False
+            _log_warning('Early stopping is not available in dart mode')
             return

-        if self.stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
+        # validation sets are guaranteed to not be identical to the training data in cv()
+        if isinstance(env.model, Booster):
+            only_train_set = (
+                len(env.evaluation_result_list) == 1
+                and self._is_train_set(
+                    ds_name=env.evaluation_result_list[0][0],
+                    eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                    env=env
+                )
+            )
+            if only_train_set:
+                self.enabled = False
+                _log_warning('Only training set found, disabling early stopping.')
+                return

         if self.verbose:
             _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
@@ -395,7 +417,11 @@ def __call__(self, env: CallbackEnv) -> None:
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
-            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
+            if self._is_train_set(
+                ds_name=env.evaluation_result_list[i][0],
+                eval_name=eval_name_splitted[0],
+                env=env
+            ):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
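The stopping rule itself is unchanged: a metric stops training once its best score is `stopping_rounds` iterations old. A toy restatement of the condition on the `elif` line above:

```python
def should_stop(current_iteration: int, best_iteration: int, stopping_rounds: int) -> bool:
    # mirrors `env.iteration - self.best_iter[i] >= self.stopping_rounds`
    return current_iteration - best_iteration >= stopping_rounds

assert should_stop(current_iteration=15, best_iteration=10, stopping_rounds=5)
assert not should_stop(current_iteration=14, best_iteration=10, stopping_rounds=5)
```
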
11 changes: 11 additions & 0 deletions tests/python_package_test/test_callback.py
@@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
     assert callback.stopping_rounds == rounds


+def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
+        lgb.early_stopping(stopping_rounds=0)
+
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
+        lgb.early_stopping(stopping_rounds=-1)
+
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
+        lgb.early_stopping(stopping_rounds="neverrrr")
+
+
 @pytest.mark.parametrize('serializer', SERIALIZERS)
 def test_log_evaluation_callback_is_picklable(serializer):
     periods = 42
4 changes: 2 additions & 2 deletions tests/python_package_test/test_engine.py
@@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_object

 def test_train_raises_informative_error_for_params_of_wrong_type():
     X, y = make_synthetic_regression()
-    params = {"early_stopping_round": "too-many"}
+    params = {"num_leaves": "too-many"}
     dtrain = lgb.Dataset(X, label=y)
-    with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
+    with pytest.raises(lgb.basic.LightGBMError, match="Parameter num_leaves should be of type int, got \"too-many\""):
         lgb.train(params, dtrain)

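This test previously used `early_stopping_round` to exercise the C++ parameter-type check; it switches to `num_leaves`, presumably because the params-based early stopping path is now intercepted by the Python-side validation before reaching the C++ parser. The same error can be reproduced directly (synthetic data, for illustration):

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
dtrain = lgb.Dataset(rng.normal(size=(100, 3)), label=rng.normal(size=100))

# num_leaves is still parsed on the C++ side, so a wrong-typed value
# surfaces as a LightGBMError there
try:
    lgb.train({"num_leaves": "too-many"}, dtrain)
except lgb.basic.LightGBMError as err:
    print(err)  # Parameter num_leaves should be of type int, got "too-many"
```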
