Eval epoch can now log independently #3843

Merged
merged 4 commits on Oct 4, 2020
54 changes: 45 additions & 9 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -22,6 +22,7 @@
from pprint import pprint
from typing import Iterable
from copy import deepcopy
from collections import ChainMap


class LoggerConnector:
@@ -105,12 +106,12 @@ def add_progress_bar_metrics(self, metrics):

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
self._track_callback_metrics(eval_results, using_eval_result)
self._log_on_evaluation_epoch_end_metrics()
def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs, using_eval_result, test_mode):
self._track_callback_metrics(deprecated_eval_results, using_eval_result)
self._log_on_evaluation_epoch_end_metrics(epoch_logs)

# TODO: deprecate parts of this for 1.0 (when removing results)
self.__process_eval_epoch_end_results_and_log_legacy(eval_results, test_mode)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)

# get the final loop results
eval_loop_results = self._get_evaluate_epoch_results(test_mode)
@@ -131,15 +132,43 @@ def _get_evaluate_epoch_results(self, test_mode):
self.eval_loop_results = []
return results

def _log_on_evaluation_epoch_end_metrics(self):
def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
step_metrics = self.trainer.evaluation_loop.step_metrics

num_loaders = len(step_metrics)

# clear mem
self.trainer.evaluation_loop.step_metrics = []

num_loaders = len(step_metrics)
if self.trainer.running_sanity_check:
return

# track all metrics we want to log
metrics_to_log = []

# process metrics per dataloader
# ---------------------------
# UPDATE EPOCH LOGGED METRICS
# ---------------------------
# (ie: in methods at the val_epoch_end level)
# union the epoch logs with whatever was returned from loaders and reduced
epoch_logger_metrics = epoch_logs.get_epoch_log_metrics()
epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics()

self.logged_metrics.update(epoch_logger_metrics)
self.progress_bar_metrics.update(epoch_pbar_metrics)

# enable the metrics to be monitored
self.callback_metrics.update(epoch_logger_metrics)
self.callback_metrics.update(epoch_pbar_metrics)

if len(epoch_logger_metrics) > 0:
metrics_to_log.append(epoch_logger_metrics)

# --------------------------------
# UPDATE METRICS PER DATALOADER
# --------------------------------
# each dataloader aggregated metrics
# now we log all of them
for dl_idx, dl_metrics in enumerate(step_metrics):
if len(dl_metrics) == 0:
continue
@@ -162,7 +191,13 @@ def _log_on_evaluation_epoch_end_metrics(self):
self.eval_loop_results.append(deepcopy(self.callback_metrics))

# actually log
self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

# log all the metrics as a single dict
metrics_to_log = dict(ChainMap(*metrics_to_log))
if len(metrics_to_log) > 0:
self.log_metrics(metrics_to_log, {}, step=self.trainer.global_step)

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
@@ -240,7 +275,8 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

# log metrics
self.trainer.logger_connector.log_metrics(log_metrics, {})
if len(log_metrics) > 0:
self.trainer.logger_connector.log_metrics(log_metrics, {})

# track metrics for callbacks (all prog bar, logged and callback metrics)
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
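The ChainMap import added at the top of this file is what lets the connector emit a single logger call per evaluation epoch: the epoch-level dict and the per-dataloader dicts are collected into a list and flattened once. A minimal standalone sketch of that merge, with invented metric names and values:

```python
from collections import ChainMap

# invented examples: one epoch-level dict plus one reduced dict per dataloader
epoch_logger_metrics = {'val_epoch_acc': 0.91}
per_loader_metrics = [
    {'val_loss/dataloader_idx_0': 0.32},
    {'val_loss/dataloader_idx_1': 0.35},
]

metrics_to_log = [epoch_logger_metrics, *per_loader_metrics]

# ChainMap chains the dicts; wrapping it in dict() materialises one mapping,
# so log_metrics runs once per evaluation epoch instead of once per dataloader.
merged = dict(ChainMap(*metrics_to_log))
# merged contains all three keys; if a key appeared twice, the earlier dict would win
```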
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -171,19 +171,23 @@ def evaluation_epoch_end(self, num_dataloaders):
using_eval_result = self.is_using_eval_results()

# call the model epoch end
eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
deprecated_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)

# 1.0
epoch_logs = self.trainer.get_model()._results

# enable returning anything
for r in eval_results:
for i, r in enumerate(deprecated_results):
if not isinstance(r, (dict, Result, torch.Tensor)):
return []
deprecated_results[i] = []

return eval_results
return deprecated_results, epoch_logs

def log_epoch_metrics(self, eval_results, test_mode):
def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode):
using_eval_result = self.is_using_eval_results()
eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
eval_results,
deprecated_eval_results,
epoch_logs,
using_eval_result,
test_mode
)
@@ -228,10 +232,6 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)

result = model._results
if len(result) > 0 and eval_results is None:
eval_results = result.get_epoch_log_metrics()

if not isinstance(eval_results, list):
eval_results = [eval_results]

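One behavioural detail in the hunk above that is easy to miss: previously, a single unsupported return type from the epoch-end hook discarded the whole result list (return []); now only the offending entry is blanked out. A toy sketch with made-up inputs (the real check also accepts Result and torch.Tensor):

```python
# made-up per-dataloader epoch-end results; the second entry is unsupported
results = [{'val_loss': 0.123}, object(), {'val_loss': 0.456}]

for i, r in enumerate(results):
    if not isinstance(r, dict):  # simplified: the real code also allows Result and torch.Tensor
        results[i] = []

# results is now [{'val_loss': 0.123}, [], {'val_loss': 0.456}];
# before this change the whole method would have returned [] instead
```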
8 changes: 5 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -602,10 +602,12 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
self.evaluation_loop.step_metrics.append(dl_step_metrics)

# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
deprecated_eval_results, epoch_logs = self.evaluation_loop.evaluation_epoch_end(
num_dataloaders=len(dataloaders)
)

# bookkeeping
eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
eval_loop_results = self.evaluation_loop.log_epoch_metrics(deprecated_eval_results, epoch_logs, test_mode)
self.evaluation_loop.predictions.to_disk()

# hook
@@ -618,7 +620,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
# hook
self.evaluation_loop.on_evaluation_end()

return eval_loop_results, eval_results
return eval_loop_results, deprecated_eval_results

def run_test(self):
# only load test dataloader for testing
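With run_evaluation now threading epoch_logs through to the connector, anything logged from validation_epoch_end reaches the logger, the progress bar and callback_metrics, which is what the new test below exercises. A minimal user-level sketch (module, metric name and value are illustrative, not taken from this PR):

```python
import torch
from pytorch_lightning import LightningModule


class LitSketch(LightningModule):
    """Illustrative module: only the epoch-end hook matters here."""

    def validation_step(self, batch, batch_idx):
        # the per-step return is irrelevant for this example
        return {'batch_idx': batch_idx}

    def validation_epoch_end(self, outputs):
        # epoch-level logging from the hook itself now reaches the logger
        self.log('val_epoch_acc', torch.tensor(0.9), on_epoch=True, prog_bar=True, logger=True)
```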
@@ -63,7 +63,7 @@ def test_val_step_result_callbacks(tmpdir):

# did not request any metrics to log (except the metrics saying which epoch we are on)
assert len(trainer.logger_connector.progress_bar_metrics) == 0
assert len(trainer.dev_debugger.logged_metrics) == 5
assert len(trainer.dev_debugger.logged_metrics) == 0


def test_val_step_using_train_callbacks(tmpdir):
@@ -112,7 +112,7 @@ def test_val_step_using_train_callbacks(tmpdir):

# did not request any metrics to log (except the metrics saying which epoch we are on)
assert len(trainer.logger_connector.progress_bar_metrics) == 0
assert len(trainer.dev_debugger.logged_metrics) == expected_epochs
assert len(trainer.dev_debugger.logged_metrics) == 0


def test_val_step_only_epoch_metrics(tmpdir):
@@ -210,40 +210,9 @@ def test_val_step_only_step_metrics(tmpdir):
assert len(trainer.dev_debugger.early_stopping_history) == 0

# make sure we logged the exact number of metrics
assert len(trainer.dev_debugger.logged_metrics) == epochs * batches + (epochs)
assert len(trainer.dev_debugger.logged_metrics) == epochs * batches
assert len(trainer.dev_debugger.pbar_added_metrics) == epochs * batches + (epochs)

# make sure we logged the correct epoch metrics
total_empty_epoch_metrics = 0
epoch = 0
for metric in trainer.dev_debugger.logged_metrics:
if 'epoch' in metric:
epoch += 1
if len(metric) > 2:
assert 'no_val_no_pbar' not in metric
assert 'val_step_pbar_acc' not in metric
assert metric[f'val_step_log_acc/epoch_{epoch}']
assert metric[f'val_step_log_pbar_acc/epoch_{epoch}']
else:
total_empty_epoch_metrics += 1

assert total_empty_epoch_metrics == 3

# make sure we logged the correct epoch pbar metrics
total_empty_epoch_metrics = 0
for metric in trainer.dev_debugger.pbar_added_metrics:
if 'epoch' in metric:
epoch += 1
if len(metric) > 2:
assert 'no_val_no_pbar' not in metric
assert 'val_step_log_acc' not in metric
assert metric['val_step_log_pbar_acc']
assert metric['val_step_pbar_acc']
else:
total_empty_epoch_metrics += 1

assert total_empty_epoch_metrics == 3

# only 1 checkpoint expected since values didn't change after that
assert len(trainer.dev_debugger.checkpoint_callback_history) == 1

50 changes: 49 additions & 1 deletion tests/trainer/logging/test_eval_loop_logging_1_0.py
@@ -4,9 +4,10 @@
from pytorch_lightning import Trainer
from pytorch_lightning import callbacks
from tests.base.deterministic_model import DeterministicModel
from tests.base import SimpleModule
from tests.base import SimpleModule, BoringModel
import os
import torch
import pytest


def test__validation_step__log(tmpdir):
@@ -148,6 +149,53 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
assert expected_cb_metrics == callback_metrics


@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)])
def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs):
"""
Tests that metrics logged in validation_epoch_end are forwarded to the logger
"""
os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):
def validation_epoch_end(self, outputs):
self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True)
self.log('d/e/f', 2)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
max_epochs=max_epochs,
row_log_interval=log_interval,
weights_summary=None,
)
trainer.fit(model)

# make sure all the metrics are available for callbacks
logged_metrics = set(trainer.logged_metrics.keys())
expected_logged_metrics = {
'c',
'd/e/f',
}
assert logged_metrics == expected_logged_metrics

pbar_metrics = set(trainer.progress_bar_metrics.keys())
expected_pbar_metrics = {'c'}
assert pbar_metrics == expected_pbar_metrics

callback_metrics = set(trainer.callback_metrics.keys())
expected_callback_metrics = set()
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
callback_metrics.remove('debug_epoch')
assert callback_metrics == expected_callback_metrics

# assert the loggers received the expected number
assert len(trainer.dev_debugger.logged_metrics) == max_epochs


def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
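Because the epoch-end values are also copied into callback_metrics (the "enable the metrics to be monitored" block in the connector diff), callbacks can monitor them directly. A hedged usage sketch, assuming a module like LitSketch above that logs 'val_epoch_acc' and a Lightning version from around this PR:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# 'val_epoch_acc' is assumed to be logged in the model's validation_epoch_end
trainer = Trainer(
    max_epochs=2,
    callbacks=[EarlyStopping(monitor='val_epoch_acc', mode='max')],
    checkpoint_callback=ModelCheckpoint(monitor='val_epoch_acc', mode='max'),
)
# trainer.fit(LitSketch())  # fit call omitted: the sketch module has no training loop
```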