forked from autogluon/autogluon
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
148 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from ._abstract_callback import AbstractCallback | ||
from ._early_stopping_callback import EarlyStoppingCallback | ||
from ._example_callback import ExampleCallback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
core/src/autogluon/core/callbacks/_early_stopping_callback.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from logging import Logger | ||
from typing import Tuple | ||
|
||
from ..trainer import AbstractTrainer | ||
from ._abstract_callback import AbstractCallback | ||
|
||
|
||
class EarlyStoppingCallback(AbstractCallback): | ||
""" | ||
A simple early stopping callback. | ||
Will early stop AutoGluon's training process after `patience` number of models fitted sequentially without improvement to score_val. | ||
Parameters | ||
---------- | ||
patience : int, default = 10 | ||
The number of models fit in a row without improvement in score_val before early stopping the training process. | ||
verbose : bool, default = False | ||
If True, will log a stopping message when early stopping triggers. | ||
""" | ||
|
||
def __init__(self, patience: int = 10, verbose: bool = False): | ||
self.patience = patience | ||
self.last_improvement = 0 | ||
self.score_best = None | ||
self.verbose = verbose | ||
|
||
def before_fit(self, logger: Logger, **kwargs) -> Tuple[bool, bool]: | ||
early_stop = self._early_stop() | ||
if self.verbose and early_stop: | ||
msg = f"Stopping trainer fit due to callback early stopping. Reason: No score_val improvement in the past {self.last_improvement} models." | ||
self._log(logger, 20, msg=msg) | ||
return early_stop, False | ||
|
||
def after_fit(self, trainer: AbstractTrainer, logger: Logger, **kwargs) -> bool: | ||
self._calc_new_best(trainer=trainer) | ||
early_stop = self._early_stop() | ||
if self.verbose and early_stop: | ||
msg = f"Stopping trainer fit due to callback early stopping. Reason: No score_val improvement in the past {self.last_improvement} models." | ||
self._log(logger, 20, msg=msg) | ||
return early_stop | ||
|
||
def _calc_new_best(self, trainer: AbstractTrainer): | ||
leaderboard = trainer.leaderboard() | ||
if len(leaderboard) == 0: | ||
score_cur = None | ||
else: | ||
score_cur = leaderboard["score_val"].max() | ||
if score_cur is None: | ||
self.last_improvement += 1 | ||
elif self.score_best is None or score_cur > self.score_best: | ||
self.score_best = score_cur | ||
self.last_improvement = 0 | ||
else: | ||
self.last_improvement += 1 | ||
|
||
def _early_stop(self): | ||
if self.last_improvement >= self.patience: | ||
return True | ||
else: | ||
return False | ||
|
||
def _log(self, logger: Logger, level, msg: str): | ||
msg = f"{self.__class__.__name__}: {msg}" | ||
logger.log( | ||
level, | ||
msg, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import time | ||
from logging import Logger | ||
from typing import List, Tuple | ||
|
||
import pandas as pd | ||
|
||
from ..models import AbstractModel | ||
from ..trainer import AbstractTrainer | ||
from ._abstract_callback import AbstractCallback | ||
|
||
|
||
class ExampleCallback(AbstractCallback): | ||
""" | ||
Example callback showcasing how to access and log information from the trainer. | ||
""" | ||
|
||
def before_fit( | ||
self, | ||
trainer: AbstractTrainer, | ||
model: AbstractModel, | ||
logger: Logger, | ||
time_limit: float | None = None, | ||
stack_name: str = "core", | ||
level: int = 1, | ||
**kwargs, | ||
) -> Tuple[bool, bool]: | ||
time_limit_trainer = trainer._time_limit | ||
if time_limit_trainer is not None and trainer._time_train_start is not None: | ||
time_left_total = time_limit_trainer - (time.time() - trainer._time_train_start) | ||
else: | ||
time_left_total = None | ||
|
||
time_limit_log = f"\ttime_limit = {time_limit:.1f}\t(model)\n" if time_limit else "" | ||
time_limit_trainer_log = f"\ttime_limit = {time_limit_trainer:.1f}\t(trainer)\n" if time_limit_trainer else "" | ||
time_left_log = f"\ttime_left = {time_left_total:.1f}\t(trainer)\n" if time_left_total else "" | ||
time_used_log = f"\ttime_used = {time_limit_trainer - time_left_total:.1f}\t(trainer)\n" if time_limit_trainer else "" | ||
logger.log( | ||
20, | ||
f"{self.__class__.__name__}: before_fit\n" | ||
f"\tmodel = {model.name}\n" | ||
f"{time_limit_log}" | ||
f"{time_limit_trainer_log}" | ||
f"{time_left_log}" | ||
f"{time_used_log}" | ||
f"\tmodels_fit = {len(trainer.get_model_names())}\n" | ||
f"\tstack_name = {stack_name}\n" | ||
f"\tlevel = {level}", | ||
) | ||
|
||
return False, False | ||
|
||
def after_fit( | ||
self, | ||
trainer: AbstractTrainer, | ||
logger: Logger, | ||
**kwargs, | ||
) -> bool: | ||
with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000): | ||
logger.log(20, f"{self.__class__.__name__}: after_fit | Leaderboard:\n{trainer.leaderboard()}") | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters