[FIX] remove .pth for early stopping (#321)
* remove .pth for early stopping

* Fix tests

* Update test/test_pipeline/components/training/test_training.py

* refactor to function
ravinkohli committed Nov 15, 2021
1 parent cab1a73 commit 28e1d47
Showing 2 changed files with 75 additions and 10 deletions.
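
In outline, the change makes the trainer reload the best weights once training ends and then delete the whole checkpoint directory, rather than leaving a stray best.pth on disk. Below is a minimal, self-contained sketch of that save-best/restore/clean-up pattern, with illustrative names only (train_with_early_stopping and its arguments are not autoPyTorch's API):

# A minimal sketch of the save-best/restore/clean-up pattern this commit
# adopts (illustrative names only; not autoPyTorch's actual API).
import os
import shutil
import tempfile

import torch
import torch.nn as nn


def train_with_early_stopping(model: nn.Module, losses, patience: int = 2) -> nn.Module:
    # Checkpoints live in a throw-away directory; tempfile.mkdtemp() does not
    # delete it for us, so it must be cleaned up explicitly at the end.
    checkpoint_dir = tempfile.mkdtemp()
    best_path = os.path.join(checkpoint_dir, 'best.pth')
    best_loss, epochs_since_best = float('inf'), 0

    for loss in losses:  # stand-in for real train/evaluate epochs
        if loss < best_loss:
            best_loss, epochs_since_best = loss, 0
            torch.save(model.state_dict(), best_path)
        else:
            epochs_since_best += 1
        if epochs_since_best > patience:
            break  # early stop

    # After training, reload the best weights and remove every checkpoint
    # file so no stray best.pth is left behind.
    model.load_state_dict(torch.load(best_path))
    shutil.rmtree(checkpoint_dir)
    return model


if __name__ == '__main__':
    train_with_early_stopping(nn.Linear(4, 2), losses=[0.9, 0.5, 0.6, 0.7, 0.8])
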
34 changes: 24 additions & 10 deletions autoPyTorch/pipeline/components/training/trainer/__init__.py
@@ -1,6 +1,7 @@
import collections
import logging.handlers
import os
import shutil
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, cast
@@ -351,13 +352,35 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoice':
            )
            self.save_model_for_ensemble()

        # As training has finished, load the best weights
        if self.checkpoint_dir is not None:
            self._load_best_weights_and_clean_checkpoints(X)

        self.logger.info(f"Finished training with {self.run_summary.repr_last_epoch()}")

        # Tag as fitted
        self.fitted_ = True

        return self

    def _load_best_weights_and_clean_checkpoints(self, X: Dict[str, Any]) -> None:
        """
        Load the best model found up to the last epoch and delete all checkpoint files.

        Args:
            X (Dict[str, Any]): Dependencies needed by current component to perform fit
        """
        assert self.checkpoint_dir is not None  # mypy
        assert self.run_summary is not None  # mypy

        best_path = os.path.join(self.checkpoint_dir, 'best.pth')
        self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}")
        # Training has finished; load the best performing weights seen so far
        X['network'].load_state_dict(torch.load(best_path))

        # Clean the temp dir
        shutil.rmtree(self.checkpoint_dir)

    def early_stop_handler(self, X: Dict[str, Any]) -> bool:
        """
        If early stopping is enabled, this procedure stops the training after a
@@ -387,16 +410,7 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
        if epochs_since_best == 0:
            torch.save(X['network'].state_dict(), best_path)

        if epochs_since_best > X['early_stopping']:
            self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}")
            # We will stop the training. Load the last best performing weights
            X['network'].load_state_dict(torch.load(best_path))

            # Let the tempfile module clean the temp dir
            self.checkpoint_dir = None
            return True

        return False
        return epochs_since_best > cast(int, X['early_stopping'])

    def eval_valid_each_epoch(self, X: Dict[str, Any]) -> bool:
        """
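
The removed branch only set self.checkpoint_dir = None, with a comment expecting the tempfile module to clean the directory; directories created with tempfile.mkdtemp() (which is how the test below creates its temporary directory) are not removed automatically, hence the explicit shutil.rmtree in the new helper. A short stdlib-only illustration of that behaviour, using a hypothetical checkpoint file name:

# Directories from tempfile.mkdtemp() persist until removed explicitly.
import os
import shutil
import tempfile

checkpoint_dir = tempfile.mkdtemp()
open(os.path.join(checkpoint_dir, 'best.pth'), 'wb').close()  # leftover checkpoint

print(os.path.exists(checkpoint_dir))  # True: dropping the reference does not delete it

shutil.rmtree(checkpoint_dir)          # explicit cleanup, as the trainer now does
print(os.path.exists(checkpoint_dir))  # False
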
51 changes: 51 additions & 0 deletions test/test_pipeline/components/training/test_training.py
@@ -1,6 +1,9 @@
import copy
import glob
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock

@@ -382,5 +385,53 @@ def test_get_set_config_space(self):
            self.assertEqual(value, trainer_choice.choice.__dict__[key])


def test_early_stopping():
    dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
    trainer_choice = TrainerChoice(dataset_properties=dataset_properties)

    # Fake the training so that every epoch scores worse than the previous one,
    # making the first epoch the best one
    import time

    def dummy_performance(*args, **kwargs):
        return (-time.time(), {'accuracy': -time.time()})

    trainer_choice.choice = unittest.mock.MagicMock()
    trainer_choice.choice.train_epoch = dummy_performance
    trainer_choice.choice.evaluate = dummy_performance
    trainer_choice.choice.on_epoch_end.return_value = False

    fit_dictionary = {
        'logger_port': 1000,
        'budget_type': 'epochs',
        'epochs': 6,
        'budget': 10,
        'num_run': 1,
        'torch_num_threads': 1,
        'early_stopping': 5,
        'metrics_during_training': True,
        'dataset_properties': dataset_properties,
        'split_id': 0,
        'step_interval': StepIntervalUnit.batch
    }
    for item in ['backend', 'lr_scheduler', 'network', 'optimizer', 'train_data_loader', 'val_data_loader',
                 'device', 'y_train']:
        fit_dictionary[item] = unittest.mock.MagicMock()

    fit_dictionary['backend'].temporary_directory = tempfile.mkdtemp()
    fit_dictionary['network'].state_dict.return_value = {'dummy': 1}
    trainer_choice.fit(fit_dictionary)
    epochs_since_best = trainer_choice.run_summary.get_last_epoch() - trainer_choice.run_summary.get_best_epoch()

    # Six epochs ran
    assert len(trainer_choice.run_summary.performance_tracker['val_metrics']) == 6

    # But the best performance was achieved on the first epoch
    assert epochs_since_best == 0

    # No files are left after training
    left_files = glob.glob(f"{fit_dictionary['backend'].temporary_directory}/*")
    assert len(left_files) == 0
    shutil.rmtree(fit_dictionary['backend'].temporary_directory)


if __name__ == '__main__':
    unittest.main()
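
Assuming pytest is installed and the repository root is the working directory, the new test can be run on its own with something like:

import pytest

# Run only the early-stopping test added in this commit (pytest-style node id)
pytest.main(['-q', 'test/test_pipeline/components/training/test_training.py::test_early_stopping'])
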
