Inconsistent outputs format between training_epoch_end and on_train_epoch_end #6921

Closed · Fixed by #6969

carmocca opened this issue Apr 9, 2021 · 3 comments
Labels: bug (Something isn't working) · help wanted (Open to be worked on) · logging (Related to the `LoggerConnector` and `log()`) · priority: 0 (High priority task)

@carmocca (Member) commented Apr 9, 2021

🐛 Bug

The `outputs` object passed to `on_train_epoch_end` should not include the `extra` field; it should match the format that `training_epoch_end` receives.

To Reproduce

from pl_examples.bug_report_model import BoringModel
from pytorch_lightning import Callback, Trainer


def test_bug(tmpdir):
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            output = self(batch)
            loss = self.loss(batch, output)
            return {"loss": loss, "foo": 123}

        def training_epoch_end(self, outputs):
            print("training_epoch_end:", outputs)

        def on_train_epoch_end(self, outputs):
            print("on_train_epoch_end:", outputs)

    class TestCallback(Callback):
        def on_train_epoch_end(self, trainer, pl_module, outputs):
            print("callback on_train_epoch_end:", outputs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        callbacks=[TestCallback()],
        progress_bar_refresh_rate=0,
    )
    trainer.fit(TestModel())
Output:

callback on_train_epoch_end: [[[{'extra': {'foo': 123}, 'minimize': tensor(1.1792)}]]]
on_train_epoch_end: [[[{'extra': {'foo': 123}, 'minimize': tensor(1.1792)}]]]
training_epoch_end: [{'foo': 123, 'loss': tensor(1.1792)}]

Expected behavior

`on_train_epoch_end` receives outputs in the same format as `training_epoch_end`.
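
In the meantime, a minimal workaround sketch that normalizes the current nested format into what `training_epoch_end` receives (`_flatten_outputs` is a hypothetical helper written against the nested shape printed above, not part of the Lightning API):

def _flatten_outputs(outputs):
    # Hypothetical helper: unwrap the extra list nesting and merge the
    # `extra` dict back into the top level, renaming `minimize` to `loss`.
    flat = []
    for item in outputs:
        if isinstance(item, list):
            flat.extend(_flatten_outputs(item))
        else:
            flat.append({"loss": item["minimize"], **item.get("extra", {})})
    return flat

# _flatten_outputs([[[{'extra': {'foo': 123}, 'minimize': tensor(1.1792)}]]])
# -> [{'loss': tensor(1.1792), 'foo': 123}]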

Environment

master

Additional context

Reported by Marek O in Slack

@carmocca carmocca added bug Something isn't working help wanted Open to be worked on priority: 0 High priority task logging Related to the `LoggerConnector` and `log()` labels Apr 9, 2021
@carmocca carmocca added this to the 1.2.x milestone Apr 9, 2021
@ethanwharris ethanwharris self-assigned this Apr 9, 2021
@breznak commented Apr 9, 2021

Hi, I'm Marek from Slack, thank you for the reproducible example!

I'll add an extended test that also shows the difference in `on_*_batch_end()` between train and val/test:

Test

from typing import Any

import torch
from pl_examples.bug_report_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer, Callback, LightningModule


def test_bug(tmpdir):
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            output = self(batch)
            loss = self.loss(batch, output)
            return {"loss": loss, "foo": 123}

        def validation_step(self, batch, batch_idx):
            output = self(batch)
            loss = self.loss(batch, output)
            return {"loss": loss, "foo": 'from_val_step'}

        def training_epoch_end(self, outputs):
            print("training_epoch_end:", outputs)

        def validation_epoch_end(self, outputs) -> None:
            print("validation_epoch_end:", outputs)

        def on_train_epoch_end(self, outputs):
            print("on_train_epoch_end:", outputs)

        def on_validation_epoch_end(self) -> None:  # FIXME: should take (optional) outputs, like validation_epoch_end(self, outputs)
            # print("on_validation_epoch_end:", outputs)
            print("on_validation_epoch_end:", "FIXME should have outputs")

    class TestCallback(Callback):
        def on_train_epoch_end(self, trainer, pl_module, outputs):
            print("callback on_train_epoch_end:", outputs)

        def on_train_batch_end(self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
            print('callback on_train_batch_end:', outputs)


        def on_validation_batch_end(self, trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
            print("callback on_validation_batch_end:", outputs)

    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=False,
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        callbacks=[TestCallback()],
        progress_bar_refresh_rate=0,
        weights_summary=None,
    )

    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)



if __name__ == '__main__':
    test_bug('/tmp')

Output:

callback on_validation_batch_end: {'loss': tensor(2.7839), 'foo': 'from_val_step'}
validation_epoch_end: [{'loss': tensor(2.7839), 'foo': 'from_val_step'}]
on_validation_epoch_end: FIXME should have outputs
callback on_train_batch_end: [[{'extra': {'foo': 123}, 'minimize': tensor(1.2255)}]]
callback on_train_epoch_end: [[[{'extra': {'foo': 123}, 'minimize': tensor(1.2255)}]]]
on_train_epoch_end: [[[{'extra': {'foo': 123}, 'minimize': tensor(1.2255)}]]]
training_epoch_end: [{'foo': 123, 'loss': tensor(1.2255)}]
callback on_validation_batch_end: {'loss': tensor(2.1134), 'foo': 'from_val_step'}
validation_epoch_end: [{'loss': tensor(2.1134), 'foo': 'from_val_step'}]
on_validation_epoch_end: FIXME should have outputs

Expected behavior:

`on_validation_epoch_end(self)` (and `on_test_epoch_end(self)`) should have the same API as `on_train_epoch_end(self, outputs)`.
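
A minimal sketch of what the unified hooks could look like, assuming an `outputs` argument is added to mirror `on_train_epoch_end` (these signatures are the proposal, not the current API):

from pl_examples.bug_report_model import BoringModel

class ProposedModel(BoringModel):
    # Proposed (hypothetical) signatures: every *_epoch_end hook receives
    # outputs in the same flat format that training_epoch_end gets.
    def on_train_epoch_end(self, outputs):
        print("on_train_epoch_end:", outputs)

    def on_validation_epoch_end(self, outputs):
        print("on_validation_epoch_end:", outputs)

    def on_test_epoch_end(self, outputs):
        print("on_test_epoch_end:", outputs)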

@ananthsub (Contributor) commented:

@carmocca I'm running into this bug as well in #6944.
The issue is that the pseudocode here is not right! https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks

# end training epoch
outs = training_epoch_end(outs)
on_train_epoch_end(outs)
on_epoch_end()

From the code: https://github.com/PyTorchLightning/pytorch-lightning/blob/b85cfbe8f350a89c93ff967fa416859be7ebb4f3/pytorch_lightning/trainer/training_loop.py#L480-L484

the model's `training_epoch_end` is called inside the logger connector: https://github.com/PyTorchLightning/pytorch-lightning/blob/b85cfbe8f350a89c93ff967fa416859be7ebb4f3/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L331-L339

which runs after `on_train_epoch_end` is called, so the hooks fire in the opposite order from the documented pseudocode.
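
In other words, the actual order on master looks like this (a sketch inferred from the linked code, in the same style as the docs pseudocode, not verbatim):

# actual order on master (inferred)
on_train_epoch_end(epoch_output)   # hook fires first, with the nested format
training_epoch_end(outs)           # called later, inside the logger connector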

@edenafek @tchaton

@carmocca (Member, Author) commented:

Thanks for looking into it, @ananthsub! So we should fix the order, and maybe also call `logger_connector.training_epoch_end` in the training loop instead, so the logic is a bit cleaner.
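
A rough sketch of that proposed reordering (illustrative names only, not the actual diff that landed in #6969):

# at the end of the epoch, inside the training loop (sketch):
processed_outputs = logger_connector.training_epoch_end(epoch_output)  # run training_epoch_end first
call_hook("on_train_epoch_end", processed_outputs)  # then the hooks, with the same flat format
call_hook("on_epoch_end")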
