Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test configuration check and testing #1804

Merged
merged 6 commits into from
May 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,20 +334,13 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
return eval_results

def run_evaluation(self, test_mode: bool = False):
# when testing make sure user defined a test step
if test_mode and not self.is_overridden('test_step'):
raise MisconfigurationException(
"You called `.test()` without defining model's `.test_step()`."
" Please define and try again")

# hook
model = self.get_model()
model.on_pre_performance_check()

# select dataloaders
if test_mode:
if self.test_dataloaders is None:
self.reset_test_dataloader(model)
self.reset_test_dataloader(model)

dataloaders = self.test_dataloaders
max_batches = self.num_test_batches
Expand Down
91 changes: 37 additions & 54 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,9 +1021,6 @@ def test(
else:
self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)

# give proper warnings if user only passed in loader without hooks
self.check_testing_model_configuration(model if model else self.model)

if model is not None:
self.model = model
self.fit(model)
Expand All @@ -1042,44 +1039,45 @@ def test(

def check_model_configuration(self, model: LightningModule):
r"""
Checks that the model is configured correctly before training is started.
Checks that the model is configured correctly before training or testing is started.

Args:
model: The model to test.
model: The model to check the configuration.

"""
# Check training_step, train_dataloader, configure_optimizer methods
if not self.is_overridden('training_step', model):
raise MisconfigurationException(
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

if not self.is_overridden('train_dataloader', model):
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

if not self.is_overridden('configure_optimizers', model):
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

# Check val_dataloader, validation_step and validation_epoch_end
if self.is_overridden('val_dataloader', model):
if not self.is_overridden('validation_step', model):
raise MisconfigurationException('You have passed in a `val_dataloader()`'
' but have not defined `validation_step()`.')
if not self.testing:
if not self.is_overridden('training_step', model):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider moving this to __check_model_configuration_test and alternatively bellow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest removing check_testing_model_configuration because it's just a repeated code and testing configuration can be handled within check_model_configuration. Also during testing check_model_configuration is called which is unnecessary and was not working in #1720. The PR fixes both the issues.

raise MisconfigurationException(
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.')

if not self.is_overridden('train_dataloader', model):
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.')

if not self.is_overridden('configure_optimizers', model):
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.')

# Check val_dataloader, validation_step and validation_epoch_end
if self.is_overridden('val_dataloader', model):
if not self.is_overridden('validation_step', model):
raise MisconfigurationException('You have passed in a `val_dataloader()`'
' but have not defined `validation_step()`.')
else:
if not self.is_overridden('validation_epoch_end', model):
rank_zero_warn(
'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
' you may also want to define `validation_epoch_end()` for accumulating stats.',
RuntimeWarning
)
else:
if not self.is_overridden('validation_epoch_end', model):
rank_zero_warn(
'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
' you may also want to define `validation_epoch_end()` for accumulating stats.',
RuntimeWarning
)
else:
if self.is_overridden('validation_step', model):
raise MisconfigurationException('You have defined `validation_step()`,'
' but have not passed in a val_dataloader().')
if self.is_overridden('validation_step', model):
raise MisconfigurationException('You have defined `validation_step()`,'
' but have not passed in a `val_dataloader()`.')

# Check test_dataloader, test_step and test_epoch_end
if self.is_overridden('test_dataloader', model):
Expand All @@ -1092,25 +1090,10 @@ def check_model_configuration(self, model: LightningModule):
'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
)

def check_testing_model_configuration(self, model: LightningModule):

has_test_step = self.is_overridden('test_step', model)
has_test_epoch_end = self.is_overridden('test_epoch_end', model)
gave_test_loader = self.is_overridden('test_dataloader', model)

if gave_test_loader and not has_test_step:
raise MisconfigurationException('You passed in a `test_dataloader` but did not implement `test_step()`')

if has_test_step and not gave_test_loader:
raise MisconfigurationException('You defined `test_step()` but did not implement'
' `test_dataloader` nor passed in `.fit(test_dataloaders`.')

if has_test_step and gave_test_loader and not has_test_epoch_end:
rank_zero_warn(
'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to'
' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
)
else:
if self.testing and self.is_overridden('test_step', model):
raise MisconfigurationException('You have defined `test_step()` but did not'
' implement `test_dataloader` nor passed in `.test(test_dataloader)`.')


class _PatchDataLoader(object):
Expand Down