Skip to content

Commit

Permalink
remove .pth for early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Nov 11, 2021
1 parent e493d99 commit a6e9317
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
22 changes: 12 additions & 10 deletions autoPyTorch/pipeline/components/training/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import logging.handlers
import os
import shutil
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple, cast
Expand Down Expand Up @@ -351,6 +352,16 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
)
self.save_model_for_ensemble()

# As training have finished, load the best weight
if self.checkpoint_dir is not None:
best_path = os.path.join(self.checkpoint_dir, 'best.pth')
self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}")
# We will stop the training. Load the last best performing weights
X['network'].load_state_dict(torch.load(best_path))

# Clean the temp dir
shutil.rmtree(self.checkpoint_dir)

self.logger.info(f"Finished training with {self.run_summary.repr_last_epoch()}")

# Tag as fitted
Expand Down Expand Up @@ -387,16 +398,7 @@ def early_stop_handler(self, X: Dict[str, Any]) -> bool:
if epochs_since_best == 0:
torch.save(X['network'].state_dict(), best_path)

if epochs_since_best > X['early_stopping']:
self.logger.debug(f" Early stopped model {X['num_run']} on epoch {self.run_summary.get_best_epoch()}")
# We will stop the training. Load the last best performing weights
X['network'].load_state_dict(torch.load(best_path))

# Let the tempfile module clean the temp dir
self.checkpoint_dir = None
return True

return False
return epochs_since_best > cast(int, X['early_stopping'])

def eval_valid_each_epoch(self, X: Dict[str, Any]) -> bool:
"""
Expand Down
50 changes: 50 additions & 0 deletions test/test_pipeline/components/training/test_training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import copy
import glob
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock

Expand Down Expand Up @@ -382,5 +385,52 @@ def test_get_set_config_space(self):
self.assertEqual(value, trainer_choice.choice.__dict__[key])


def test_early_stopping():
dataset_properties = {'task_type': 'tabular_classification', 'output_type': 'binary'}
trainer_choice = TrainerChoice(dataset_properties=dataset_properties)

def dummy_performance(*args, **kwargs):
return (-time.time(), {'accuracy': -time.time()})

# Fake the training so that the first epoch was the best one
import time
trainer_choice.choice = unittest.mock.MagicMock()
trainer_choice.choice.train_epoch = dummy_performance
trainer_choice.choice.evaluate = dummy_performance
trainer_choice.choice.on_epoch_end.return_value = False

fit_dictionary = {
'logger_port': 1000,
'budget_type': 'epochs',
'epochs': 6,
'budget': 10,
'num_run': 1,
'torch_num_threads': 1,
'early_stopping': 5,
'metrics_during_training': True,
'dataset_properties': dataset_properties,
'split_id': 0,
}
for item in ['backend', 'lr_scheduler', 'network', 'optimizer', 'train_data_loader', 'val_data_loader',
'device', 'y_train']:
fit_dictionary[item] = unittest.mock.MagicMock()

fit_dictionary['backend'].temporary_directory = tempfile.mkdtemp()
fit_dictionary['network'].state_dict.return_value = {'dummy': 1}
trainer_choice.fit(fit_dictionary)
epochs_since_best = trainer_choice.run_summary.get_last_epoch() - trainer_choice.run_summary.get_best_epoch()

# Six epochs ran
assert len(trainer_choice.run_summary.performance_tracker['val_metrics']) == 6

# But the best performance was achieved on the first epoch
assert epochs_since_best == 0

# No files are left after training
left_files = glob.glob(f"{fit_dictionary['backend'].temporary_directory}/*")
assert len(left_files) == 0
shutil.rmtree(fit_dictionary['backend'].temporary_directory)


if __name__ == '__main__':
unittest.main()

0 comments on commit a6e9317

Please sign in to comment.