Fix 361 #367

Merged (3 commits, Jan 24, 2022)
10 changes: 10 additions & 0 deletions autoPyTorch/pipeline/components/training/trainer/__init__.py
@@ -293,6 +293,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
            writer=writer,
        )

        # it's fine if train_loss is None due to `is_max_time_reached()`
        if train_loss is None:
            if self.budget_tracker.is_max_time_reached():
                break
            else:
                raise RuntimeError("Got an unexpected None in `train_loss`.")

        val_loss, val_metrics, test_loss, test_metrics = None, {}, None, {}
        if self.eval_valid_each_epoch(X):
            val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
@@ -334,6 +341,9 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
        if 'cuda' in X['device']:
            torch.cuda.empty_cache()

        if self.run_summary.is_empty():
            raise RuntimeError("Budget exhausted without finishing an epoch.")

        # wrap up -- add score if not evaluating every epoch
        if not self.eval_valid_each_epoch(X):
            val_loss, val_metrics = self.choice.evaluate(X['val_data_loader'], epoch, writer)
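For orientation, here is a minimal, self-contained sketch of the control flow the two hunks above introduce. All names (BudgetTrackerSketch, fit_loop, epoch_losses) are hypothetical stand-ins for the real BudgetTracker, the _fit epoch loop, and the per-epoch train_epoch results; this is not the actual implementation.

import time
from typing import List, Optional


class BudgetTrackerSketch:
    """Hypothetical stand-in for autoPyTorch's runtime BudgetTracker."""

    def __init__(self, max_runtime: float) -> None:
        self.start_time = time.time()
        self.max_runtime = max_runtime

    def is_max_time_reached(self) -> bool:
        return time.time() - self.start_time >= self.max_runtime


def fit_loop(tracker: BudgetTrackerSketch,
             epoch_losses: List[Optional[float]]) -> List[float]:
    finished: List[float] = []
    for train_loss in epoch_losses:  # each entry mimics one train_epoch() result
        if train_loss is None:
            if tracker.is_max_time_reached():
                break  # ran out of time mid-epoch: stop gracefully
            raise RuntimeError("Got an unexpected None in `train_loss`.")
        finished.append(train_loss)
    if not finished:  # mirrors run_summary.is_empty()
        raise RuntimeError("Budget exhausted without finishing an epoch.")
    return finished


# fit_loop(BudgetTrackerSketch(max_runtime=0.0), [None]) raises the second
# RuntimeError: the budget expired before a single epoch finished.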
15 changes: 14 additions & 1 deletion autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -179,6 +179,16 @@ def repr_last_epoch(self) -> str:
        string += '=' * 40
        return string

    def is_empty(self) -> bool:
        """
        Checks whether this run summary contains any recorded epochs.

        Returns:
            bool: True if no epoch has been recorded yet.
        """
        # if train_loss is empty, we can be sure that RunSummary is empty.
        return not bool(self.performance_tracker['train_loss'])


class BaseTrainerComponent(autoPyTorchTrainingComponent):

@@ -277,7 +287,7 @@ def _scheduler_step(

    def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
                    writer: Optional[SummaryWriter],
                    ) -> Tuple[float, Dict[str, float]]:
                    ) -> Tuple[Optional[float], Dict[str, float]]:
        """
        Train the model for a single epoch.

@@ -317,6 +327,9 @@ def train_epoch(self, train_loader: torch.utils.data.DataLoader, epoch: int,
                        epoch * len(train_loader) + step,
                    )

        if N == 0:
            return None, {}
Contributor comment on lines +330 to +331: Could you add a simple test?

        self._scheduler_step(step_interval=StepIntervalUnit.epoch, loss=loss_sum / N)

        if self.metrics_during_training:
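To make the N == 0 early return concrete, here is a condensed sketch of the averaging logic in train_epoch (heavily simplified and hypothetical; batch_losses stands in for the per-step losses a real DataLoader pass would produce):

from typing import Callable, Dict, Iterable, Optional, Tuple


def train_epoch_sketch(batch_losses: Iterable[float],
                       time_is_up: Callable[[], bool],
                       ) -> Tuple[Optional[float], Dict[str, float]]:
    loss_sum, N = 0.0, 0
    for step_loss in batch_losses:
        if time_is_up():
            break  # budget exhausted mid-epoch
        loss_sum += step_loss
        N += 1
    if N == 0:
        # no train_step completed, so there is no loss to average;
        # returning None here avoids a ZeroDivisionError below
        return None, {}
    return loss_sum / N, {}


# train_epoch_sketch([], lambda: False) -> (None, {})
# train_epoch_sketch([1.0, 3.0], lambda: False) -> (2.0, {})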
37 changes: 37 additions & 0 deletions test/test_pipeline/components/training/test_training.py
@@ -236,6 +236,43 @@ def test_train_step(self):
        lr = optimizer.param_groups[0]['lr']
        assert lr == target_lr

    def test_train_epoch_no_step(self):
        """
        Checks the behaviour when the maximum runtime is reached before any
        train_step completes within the epoch. In that case, train_epoch
        should return None for the train loss and an empty dict for the
        metrics.
        """
        device = torch.device('cpu')
        model = torch.nn.Linear(1, 1).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1)
        data_loader = unittest.mock.MagicMock(spec=torch.utils.data.DataLoader)
        ms = [3, 5, 6]
        params = {
            'metrics': [],
            'device': device,
            'task_type': constants.TABULAR_REGRESSION,
            'labels': torch.Tensor([]),
            'metrics_during_training': False,
            'budget_tracker': BudgetTracker(budget_type='runtime', max_runtime=0),
            'criterion': torch.nn.MSELoss,
            'optimizer': optimizer,
            'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=ms, gamma=2),
            'model': model,
            'step_interval': StepIntervalUnit.epoch
        }
        trainer = StandardTrainer()
        trainer.prepare(**params)

        loss, metrics = trainer.train_epoch(
            train_loader=data_loader,
            epoch=0,
            writer=None
        )
        assert loss is None
        assert metrics == {}


class TestStandardTrainer(BaseTraining):
    def test_regression_epoch_training(self, n_samples):
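Worth noting about the test above: a MagicMock iterates as an empty sequence by default, so train_epoch sees zero batches from the mocked loader, regardless of the max_runtime=0 budget. A quick self-contained check of that assumption:

import unittest.mock

import torch

loader = unittest.mock.MagicMock(spec=torch.utils.data.DataLoader)
# __iter__ is on the DataLoader spec, and MagicMock's default __iter__
# yields nothing, so iterating the mock produces no batches:
assert list(loader) == []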
28 changes: 28 additions & 0 deletions test/test_pipeline/test_tabular_classification.py
@@ -1,6 +1,7 @@
import os
import re
import unittest
import unittest.mock

from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter,
@@ -491,3 +492,30 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):

    # More than 200 epochs would have passed in 5 seconds for this dataset
    assert len(run_summary.performance_tracker['start_time']) > 100


@pytest.mark.parametrize("fit_dictionary_tabular_dummy", ["classification"], indirect=True)
def test_train_pipeline_with_runtime_max_reached(fit_dictionary_tabular_dummy):
    """
    Makes sure that the pipeline raises an error when no epoch has finished
    successfully because the maximum runtime was reached.
    """

    # Convert the training budget to runtime
    fit_dictionary_tabular_dummy.pop('epochs', None)
    fit_dictionary_tabular_dummy['budget_type'] = 'runtime'
    fit_dictionary_tabular_dummy['runtime'] = 5
    fit_dictionary_tabular_dummy['early_stopping'] = -1

    pipeline = TabularClassificationPipeline(
        dataset_properties=fit_dictionary_tabular_dummy['dataset_properties'])

    cs = pipeline.get_hyperparameter_search_space()
    config = cs.get_default_configuration()
    pipeline.set_hyperparameters(config)

    with unittest.mock.patch('autoPyTorch.pipeline.components.training.trainer.BudgetTracker') as patch:
        patch.is_max_time_reached.return_value = True
        with pytest.raises(RuntimeError):
            pipeline.fit(fit_dictionary_tabular_dummy)
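A note on the mocking here: unittest.mock.patch replaces the BudgetTracker class as looked up in the trainer module, so every tracker the pipeline builds during fit() is a MagicMock, and an unconfigured mock method returns a truthy MagicMock by default, which is why the budget reads as exhausted throughout the run. A minimal demonstration of that default, outside the pipeline:

import unittest.mock

mocked_cls = unittest.mock.MagicMock()  # what patch installs in place of BudgetTracker
tracker = mocked_cls(budget_type='runtime', max_runtime=5)
# an unconfigured method call returns a MagicMock, which is truthy:
assert bool(tracker.is_max_time_reached())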