Fix inconsistent outputs in on_*_end and *_end #6969

Merged: 30 commits merged into master from bugfix/inconsistent_outputs on Apr 13, 2021. Changes shown from 23 of the 30 commits.

Commits
ad8647c  Fix output consistency (ethanwharris, Apr 12, 2021)
de66f3d  Updates (ethanwharris, Apr 12, 2021)
2bc7035  Fix tests (ethanwharris, Apr 12, 2021)
243ad46  Remove commented code (ethanwharris, Apr 12, 2021)
e4efdb7  Remove commented code (ethanwharris, Apr 12, 2021)
7ce977a  Remove unused imports (ethanwharris, Apr 12, 2021)
6c79531  Fix broken test (ethanwharris, Apr 12, 2021)
e6642be  Fix broken test (ethanwharris, Apr 12, 2021)
61a777e  Fix broken test (ethanwharris, Apr 12, 2021)
34bec23  Fix broken docs (ethanwharris, Apr 12, 2021)
2e57f48  Update CHANGELOG.md (ethanwharris, Apr 12, 2021)
57ce113  Fix test (ethanwharris, Apr 12, 2021)
274846b  Update CHANGELOG.md (ethanwharris, Apr 12, 2021)
f86f141  Add some comments (ethanwharris, Apr 12, 2021)
5a8ea62  Merge branch 'bugfix/inconsistent_outputs' of https://github.com/PyTo… (ethanwharris, Apr 12, 2021)
7d641f6  Merge branch 'master' into bugfix/inconsistent_outputs (carmocca, Apr 13, 2021)
b77b691  Apply suggestions from code review (ananthsub, Apr 13, 2021)
d4a0fae  Add typing (ethanwharris, Apr 13, 2021)
b049b6a  Update CHANGELOG.md (ethanwharris, Apr 13, 2021)
47c9119  Fix typing bug (ethanwharris, Apr 13, 2021)
4b30fbc  Merge branch 'master' into bugfix/inconsistent_outputs (kaushikb11, Apr 13, 2021)
4fac95a  Add tests (ethanwharris, Apr 13, 2021)
c968cb4  Merge branch 'bugfix/inconsistent_outputs' of https://github.com/PyTo… (ethanwharris, Apr 13, 2021)
4dfc5be  Update pytorch_lightning/trainer/training_loop.py (ethanwharris, Apr 13, 2021)
7b84f02  Remove unused imports (ethanwharris, Apr 13, 2021)
35a3365  Fix test (ethanwharris, Apr 13, 2021)
597e173  Fix broken test (ethanwharris, Apr 13, 2021)
793a69b  Fix test name and doc (ethanwharris, Apr 13, 2021)
036fc15  Remove unused import (ethanwharris, Apr 13, 2021)
d3b1735  Merge branch 'master' into bugfix/inconsistent_outputs (ethanwharris, Apr 13, 2021)
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -240,6 +240,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed a bug where `LightningModule.training_epoch_end` was called after the `on_train_epoch_end` hook ([#6969](https://github.com/PyTorchLightning/pytorch-lightning/pull/6969))


- Fixed a bug where the outputs object passed to `LightningModule.training_epoch_end` was different from the object passed to the `on_train_epoch_end` hook ([#6969](https://github.com/PyTorchLightning/pytorch-lightning/pull/6969))


- Fixed a bug where the outputs passed to the `on_train_batch_end` hook would be lists even when using a single optimizer and no truncated backprop through time steps ([#6969](https://github.com/PyTorchLightning/pytorch-lightning/pull/6969))


- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))


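To make the #6969 changelog entries above concrete, here is a hedged sketch (the model, metric names, and keys are assumed for illustration, not taken from this PR) of a LightningModule whose training hooks now see consistent outputs: with a single optimizer and no truncated backprop, `on_train_batch_end` should receive the step result itself rather than a nested list, and `training_epoch_end` and the `on_train_epoch_end` hook should receive the same outputs object. `training_epoch_end` must return None.

import torch
import pytorch_lightning as pl


class OutputsModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # extra keys returned here travel with the step output into the epoch-end hooks
        return {"loss": loss, "batch_idx": batch_idx}

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # single optimizer, no tbptt: `outputs` is the step result itself,
        # not a list of lists (per the changelog entry above)
        pass

    def training_epoch_end(self, outputs):
        # one entry per training step; must return None
        self.log("num_train_steps", float(len(outputs)))

    def on_train_epoch_end(self, outputs):
        # after this PR, `outputs` here matches what training_epoch_end received
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)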
2 changes: 1 addition & 1 deletion docs/source/common/lightning_module.rst
@@ -1046,7 +1046,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
val_loop()

# end training epoch
outs = training_epoch_end(outs)
training_epoch_end(outs)
on_train_epoch_end(outs)
on_epoch_end()

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -14,7 +14,7 @@
import os
from copy import deepcopy
from pprint import pprint
from typing import Dict, Iterable, Optional, Union
from typing import Dict, Iterable, List, Optional, Union

import torch

@@ -26,8 +26,6 @@
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.utilities import DeviceType, flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden


class LoggerConnector:
@@ -328,16 +326,11 @@ def on_train_epoch_end(self):
# inform cached logger connector epoch finished
self.cached_results.has_batch_loop_finished = True

def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
def log_train_epoch_end_metrics(self, epoch_output: List[List[List[Result]]]) -> None:
# epoch output is a list. Each item in that list has all the outputs per optimizer
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)

model = self.trainer.lightning_module

# lightning module hook
self.training_epoch_end(model, epoch_output, num_optimizers)

# log/aggregate metrics automatically
epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
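As a hedged illustration (values assumed, not taken from the diff), the nesting convention described in the comment above, epoch_output[optimizer_idx][training_step_idx][tbptt_index], looks like this:

# one inner list per optimizer, then per training step, then per tbptt split;
# without truncated backprop each innermost list has length 1
epoch_output = [
    [
        [{"loss": 0.5}],  # optimizer 0, training step 0, single tbptt split
        [{"loss": 0.4}],  # optimizer 0, training step 1
    ],
]
first_step_result = epoch_output[0][0][0]  # {"loss": 0.5}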

@@ -365,29 +358,6 @@ def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
# reset epoch loop result for next epoch
self.cached_results.reset()

def training_epoch_end(self, model, epoch_output, num_optimizers):
if not is_overridden('training_epoch_end', model=model):
return

# run training_epoch_end
# refresh the result for custom logging at the epoch level
model._current_fx_name = 'training_epoch_end'
epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
epoch_output = epoch_output[0]

# lightningmodule hook
epoch_output = model.training_epoch_end(epoch_output)

if epoch_output is not None:
raise MisconfigurationException(
'training_epoch_end expects a return of None. '
'HINT: remove the return statement in training_epoch_end'
)
# capture logging
self.trainer.logger_connector.cache_logged_metrics()

def __auto_reduce_results_on_epoch_end(self, epoch_output):
epoch_log_metrics = {}
epoch_progress_bar_metrics = {}
@@ -413,33 +383,6 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):

return epoch_log_metrics, epoch_progress_bar_metrics

def __prepare_epoch_end_inputs(self, epoch_output):
"""
Pulls out only the "extra" information for epoch end

Return:
a single list, each element per optimizer then batch then time
"""
gathered_epoch_outputs = []
for opt_outputs in epoch_output:
# gather across time first
time_gathered_outputs = []
for tbptt_outs in opt_outputs:
result = []
for x in tbptt_outs:
out = x.extra
out['loss'] = x.minimize
result.append(out)

# when time = 0, pass in the literal dict instead of array
if len(result) == 1:
result = result[0]
time_gathered_outputs.append(result)

gathered_epoch_outputs.append(time_gathered_outputs)

return gathered_epoch_outputs

def log_train_step_metrics(self, batch_output):
if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
return
44 changes: 15 additions & 29 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Union

import torch

@@ -186,12 +187,12 @@ def evaluation_step_end(self, *args, **kwargs):
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
return output

def evaluation_epoch_end(self):
def evaluation_epoch_end(self, outputs):
# unset dataloader_idx in model
self.trainer.logger_connector.evaluation_epoch_end()

# call the model epoch end
deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders)
deprecated_results = self.__run_eval_epoch_end(outputs)

# enable returning anything
for i, r in enumerate(deprecated_results):
@@ -205,46 +206,40 @@ def log_epoch_metrics_on_evaluation_end(self):
eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
return eval_loop_results

def __run_eval_epoch_end(self, num_dataloaders):
def __run_eval_epoch_end(self, outputs):
model = self.trainer.lightning_module

# with a single dataloader don't pass an array
outputs = self.outputs

eval_results = outputs
if num_dataloaders == 1:
eval_results = outputs[0]

epoch_end_output = None
user_reduced = False

if self.trainer.testing:
if is_overridden('test_epoch_end', model=model):
model._current_fx_name = 'test_epoch_end'
eval_results = model.test_epoch_end(eval_results)
epoch_end_output = model.test_epoch_end(outputs)
user_reduced = True

else:
if is_overridden('validation_epoch_end', model=model):
model._current_fx_name = 'validation_epoch_end'
eval_results = model.validation_epoch_end(eval_results)
epoch_end_output = model.validation_epoch_end(outputs)
user_reduced = True

# capture logging
self.trainer.logger_connector.cache_logged_metrics()
# depre warning
if eval_results is not None and user_reduced:
if epoch_end_output is not None and user_reduced:
step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
self.warning_cache.warn(
f'The {step} should not return anything as of 9.1.'
' To log, use self.log(...) or self.write(...) directly in the LightningModule'
)

if not isinstance(eval_results, list):
eval_results = [eval_results]
if not isinstance(outputs, list):
outputs = [outputs]

self.trainer.logger_connector._track_callback_metrics(eval_results)
self.trainer.logger_connector._track_callback_metrics(outputs)

return eval_results
return outputs

def __gather_epoch_end_eval_results(self, outputs):
eval_results = []
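For context, a hedged sketch (model and metric names are assumed, not from this PR) of the pattern that deprecation warning targets: validation_epoch_end should log via self.log and return None rather than return a reduced result.

import torch
import pytorch_lightning as pl


class EvalLoggingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def validation_step(self, batch, batch_idx):
        return {"val_loss": self.layer(batch).sum()}

    def validation_epoch_end(self, outputs):
        # returning a value here triggers the warning above; log the reduced
        # metric instead and return None
        self.log("val_loss_epoch", torch.stack([o["val_loss"] for o in outputs]).mean())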
@@ -307,18 +302,7 @@ def store_predictions(self, output, batch_idx, dataloader_idx):
# track debug metrics
self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)

def on_evaluation_epoch_end(self, *args, **kwargs):
# call the callback hook
self.call_on_evaluation_epoch_end_hook()

self.trainer.call_hook('on_epoch_end')

def call_on_evaluation_epoch_end_hook(self):
outputs = self.outputs

# free memory
self.outputs = []

def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
model_ref = self.trainer.lightning_module
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

@@ -343,6 +327,8 @@ def call_on_evaluation_epoch_end_hook(self):

self.trainer._cache_logged_metrics()

self.trainer.call_hook('on_epoch_end')

def log_evaluation_step_metrics(self, output, batch_idx):
if self.trainer.sanity_checking:
return
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -711,11 +711,20 @@ def run_evaluation(self, on_epoch=False):
# store batch level output per dataloader
self.evaluation_loop.outputs.append(dl_outputs)

outputs = self.evaluation_loop.outputs

# reset outputs
self.evaluation_loop.outputs = []

# with a single dataloader don't pass a 2D list
if self.evaluation_loop.num_dataloaders == 1:
outputs = outputs[0]

# lightning module method
deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()
deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end(outputs)

# hook
self.evaluation_loop.on_evaluation_epoch_end()
self.evaluation_loop.on_evaluation_epoch_end(outputs)

# update epoch-level lr_schedulers
if on_epoch:
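Finally, a hedged illustration of the shapes implied by the "with a single dataloader don't pass a 2D list" comment above; the helper below is an assumption for illustration, not part of the PR, and it presumes each step returned a tensor under "val_loss".

from typing import Dict, List, Union

import torch


def summarize_val_outputs(outputs: Union[List[Dict], List[List[Dict]]]) -> torch.Tensor:
    # single val dataloader: a flat list with one dict per validation_step call
    # multiple val dataloaders: one such list per dataloader
    if outputs and isinstance(outputs[0], list):
        outputs = [step_out for dl_outputs in outputs for step_out in dl_outputs]
    return torch.stack([out["val_loss"] for out in outputs]).mean()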