Skip to content

Commit

Permalink
Add early_stopping_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma committed Jul 19, 2024
1 parent aeb82c1 commit 445b998
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 0 deletions.
2 changes: 2 additions & 0 deletions core/src/autogluon/core/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from ._abstract_callback import AbstractCallback
from ._early_stopping_callback import EarlyStoppingCallback
from ._example_callback import ExampleCallback
14 changes: 14 additions & 0 deletions core/src/autogluon/core/callbacks/_abstract_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@


class AbstractCallback(object, metaclass=ABCMeta):
"""
Abstract callback class for AutoGluon's TabularPredictor.
The inner API and logic within `trainer` is considered private API. It may change without warning between releases.
Examples
--------
>>> from autogluon.core.callbacks import ExampleCallback
>>> from autogluon.tabular import TabularDataset, TabularPredictor
>>> callbacks = [ExampleCallback()]
>>> train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')
>>> label = 'class'
>>> predictor = TabularPredictor(label=label).fit(train_data, callbacks=callbacks)
"""

@abstractmethod
def before_fit(
self,
Expand Down
68 changes: 68 additions & 0 deletions core/src/autogluon/core/callbacks/_early_stopping_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from logging import Logger
from typing import Tuple

from ..trainer import AbstractTrainer
from ._abstract_callback import AbstractCallback


class EarlyStoppingCallback(AbstractCallback):
    """
    A simple early stopping callback.

    Will early stop AutoGluon's training process after `patience` number of models
    fitted sequentially without improvement to score_val.

    Parameters
    ----------
    patience : int, default = 10
        The number of models fit in a row without improvement in score_val
        before early stopping the training process.
    verbose : bool, default = False
        If True, will log a stopping message when early stopping triggers.
    """

    def __init__(self, patience: int = 10, verbose: bool = False):
        self.patience = patience
        # Number of models fit in a row without a new best score_val.
        self.last_improvement = 0
        # Best score_val observed so far; None until a valid score is seen.
        self.score_best = None
        self.verbose = verbose

    def before_fit(self, logger: Logger, **kwargs) -> Tuple[bool, bool]:
        """Return (early_stop, skip_model); this callback never skips individual models."""
        return self._check_early_stop(logger=logger), False

    def after_fit(self, trainer: AbstractTrainer, logger: Logger, **kwargs) -> bool:
        """Refresh the improvement counter from the trainer's leaderboard, then check stopping."""
        self._calc_new_best(trainer=trainer)
        return self._check_early_stop(logger=logger)

    def _check_early_stop(self, logger: Logger) -> bool:
        # Shared by before_fit/after_fit: evaluate the stop condition and optionally log it.
        early_stop = self._early_stop()
        if self.verbose and early_stop:
            msg = f"Stopping trainer fit due to callback early stopping. Reason: No score_val improvement in the past {self.last_improvement} models."
            self._log(logger, 20, msg=msg)
        return early_stop

    def _calc_new_best(self, trainer: AbstractTrainer):
        """Update `score_best` and `last_improvement` from the trainer's current leaderboard."""
        leaderboard = trainer.leaderboard()
        if len(leaderboard) == 0:
            score_cur = None
        else:
            score_cur = leaderboard["score_val"].max()
            # FIX: pandas' max() returns NaN when every score_val is missing, and
            # `score_cur is None` does not catch NaN. Storing NaN as score_best would
            # make every future `score_cur > score_best` comparison False, so
            # improvement becomes permanently undetectable. Treat NaN as "no score".
            if score_cur is not None and score_cur != score_cur:
                score_cur = None
        if score_cur is None:
            self.last_improvement += 1
        elif self.score_best is None or score_cur > self.score_best:
            self.score_best = score_cur
            self.last_improvement = 0
        else:
            self.last_improvement += 1

    def _early_stop(self) -> bool:
        # True once `patience` consecutive models have failed to improve score_val.
        return self.last_improvement >= self.patience

    def _log(self, logger: Logger, level: int, msg: str):
        # Prefix with the class name so callback output is identifiable in trainer logs.
        msg = f"{self.__class__.__name__}: {msg}"
        logger.log(
            level,
            msg,
        )
60 changes: 60 additions & 0 deletions core/src/autogluon/core/callbacks/_example_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import time
from logging import Logger
from typing import List, Tuple

import pandas as pd

from ..models import AbstractModel
from ..trainer import AbstractTrainer
from ._abstract_callback import AbstractCallback


class ExampleCallback(AbstractCallback):
    """
    Example callback showcasing how to access and log information from the trainer.
    """

    def before_fit(
        self,
        trainer: AbstractTrainer,
        model: AbstractModel,
        logger: Logger,
        time_limit: float | None = None,
        stack_name: str = "core",
        level: int = 1,
        **kwargs,
    ) -> Tuple[bool, bool]:
        """
        Log timing and progress information before each model fit.

        Returns
        -------
        Tuple[bool, bool]
            (early_stop, skip_model) — always (False, False); this callback
            never alters the fit process.
        """
        # NOTE(review): reads trainer private attributes (_time_limit, _time_train_start);
        # the trainer's inner API may change without warning between releases.
        time_limit_trainer = trainer._time_limit
        if time_limit_trainer is not None and trainer._time_train_start is not None:
            time_left_total = time_limit_trainer - (time.time() - trainer._time_train_start)
        else:
            time_left_total = None

        time_limit_log = f"\ttime_limit = {time_limit:.1f}\t(model)\n" if time_limit else ""
        time_limit_trainer_log = f"\ttime_limit = {time_limit_trainer:.1f}\t(trainer)\n" if time_limit_trainer else ""
        time_left_log = f"\ttime_left = {time_left_total:.1f}\t(trainer)\n" if time_left_total else ""
        # FIX: guard on time_left_total, not time_limit_trainer. When the trainer has a
        # time limit but `_time_train_start` is unset, time_left_total is None and the
        # original `time_limit_trainer - time_left_total` raised a TypeError.
        time_used_log = (
            f"\ttime_used = {time_limit_trainer - time_left_total:.1f}\t(trainer)\n"
            if time_left_total is not None
            else ""
        )
        logger.log(
            20,
            f"{self.__class__.__name__}: before_fit\n"
            f"\tmodel = {model.name}\n"
            f"{time_limit_log}"
            f"{time_limit_trainer_log}"
            f"{time_left_log}"
            f"{time_used_log}"
            f"\tmodels_fit = {len(trainer.get_model_names())}\n"
            f"\tstack_name = {stack_name}\n"
            f"\tlevel = {level}",
        )

        return False, False

    def after_fit(
        self,
        trainer: AbstractTrainer,
        logger: Logger,
        **kwargs,
    ) -> bool:
        """Log the full trainer leaderboard after each model fit; never early stops."""
        # Widen pandas display options so the leaderboard is not truncated in the log.
        with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
            logger.log(20, f"{self.__class__.__name__}: after_fit | Leaderboard:\n{trainer.leaderboard()}")
        return False
4 changes: 4 additions & 0 deletions core/src/autogluon/core/trainer/abstract_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,8 @@ def stack_new_level_core(
If self.bagged_mode, then models will be trained as StackerEnsembleModels.
The data provided in this method should not contain stack features, as they will be automatically generated if necessary.
"""
if self._callback_early_stop:
return []
if get_models_func is None:
get_models_func = self.construct_model_templates
if base_model_names is None:
Expand Down Expand Up @@ -826,6 +828,8 @@ def stack_new_level_aux(
Level must be greater than the level of any of the base models.
Auxiliary models never use the original features and only train with the predictions of other models as features.
"""
if self._callback_early_stop:
return []
if fit_weighted_ensemble is False:
# Skip fitting of aux models
return []
Expand Down

0 comments on commit 445b998

Please sign in to comment.