Skip to content

Commit

Permalink
Added fix to ensure that custom logged metrics within test_epoch_end …
Browse files Browse the repository at this point in the history
…are appended to the result object even without step reduced metrics (#4251)
  • Loading branch information
SeanNaren committed Oct 20, 2020
1 parent 10a5b58 commit c336881
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/connectors/logger_connector.py
Expand Up @@ -173,6 +173,9 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
# now we log all of them
for dl_idx, dl_metrics in enumerate(step_metrics):
if len(dl_metrics) == 0:
# Ensure custom logged metrics are included if not included with step metrics
if len(epoch_logger_metrics) > 0:
self.eval_loop_results.append(epoch_logger_metrics)
continue

reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
Expand Down
33 changes: 33 additions & 0 deletions tests/trainer/logging/test_eval_loop_logging_1_0.py
Expand Up @@ -311,6 +311,39 @@ def validation_epoch_end(self, outputs) -> None:
assert len(logged_val) == 6


@pytest.mark.parametrize(['batches', 'log_interval', 'max_epochs'], [(1, 1, 1), (64, 32, 2)])
def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs):
"""
Tests that only test_epoch_end can be used to log, and we return them in the results.
"""
os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):
def test_epoch_end(self, outputs):
self.log('c', torch.tensor(2), on_epoch=True, prog_bar=True, logger=True)
self.log('d/e/f', 2)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
max_epochs=max_epochs,
log_every_n_steps=log_interval,
weights_summary=None,
)
trainer.fit(model)
results = trainer.test(model)

expected_result_metrics = {
'c': torch.tensor(2),
'd/e/f': 2,
}
for result in results:
assert result == expected_result_metrics


def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
Expand Down

0 comments on commit c336881

Please sign in to comment.