Skip to content

Commit

Permalink
Add on_exception to DataModule (#19601)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
  • Loading branch information
clumsy and azzhipa committed Mar 22, 2024
1 parent 6cfc590 commit d5a9b77
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))

- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))

-

### Changed
Expand Down
8 changes: 8 additions & 0 deletions src/lightning/pytorch/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def val_dataloader(self):
def test_dataloader(self):
return data.DataLoader(self.test)
def on_exception(self, exception):
# clean up state after the trainer encountered an exception
...
def teardown(self):
# clean up state after the trainer stops, delete files...
# called on every process in DDP
Expand Down Expand Up @@ -161,6 +165,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
pass

def on_exception(self, exception: BaseException) -> None:
    """Hook invoked when an exception interrupts the trainer's execution.

    The base implementation intentionally does nothing. Override it in a subclass to
    clean up any state held by the data module after the failure.

    Args:
        exception: The exception that interrupted the trainer.

    """

@_restricted_classmethod
def load_from_checkpoint(
cls,
Expand Down
22 changes: 12 additions & 10 deletions src/lightning/pytorch/trainer/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,25 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not trainer.interrupted:
trainer.state.status = TrainerStatus.INTERRUPTED
_call_callback_hooks(trainer, "on_exception", exception)
trainer.strategy.on_exception(exception)
for logger in trainer.loggers:
logger.finalize("failed")
_interrupt(trainer, exception)
except BaseException as exception:
trainer.state.status = TrainerStatus.INTERRUPTED
_call_callback_hooks(trainer, "on_exception", exception)
trainer.strategy.on_exception(exception)
for logger in trainer.loggers:
logger.finalize("failed")
_interrupt(trainer, exception)
trainer._teardown()
# teardown might access the stage so we reset it after
trainer.state.stage = None
raise


def _interrupt(trainer: "pl.Trainer", exception: BaseException) -> None:
    """Flag the trainer as interrupted and let every component react to ``exception``.

    The handlers run in a fixed order: callbacks first, then the attached data module
    (if there is one), then the strategy. Finally, every logger is finalized with the
    ``"failed"`` status.

    Args:
        trainer: The trainer whose run was interrupted.
        exception: The exception that caused the interruption; it is forwarded unchanged
            to each ``on_exception`` handler.

    """
    hook_name = "on_exception"
    trainer.state.status = TrainerStatus.INTERRUPTED

    _call_callback_hooks(trainer, hook_name, exception)
    # The data module is optional, so its hook is only dispatched when one is attached.
    if trainer.datamodule is not None:
        _call_lightning_datamodule_hook(trainer, hook_name, exception)
    trainer.strategy.on_exception(exception)

    for active_logger in trainer.loggers:
        active_logger.finalize("failed")


def _call_setup_hook(trainer: "pl.Trainer") -> None:
assert trainer.state.fn is not None
fn = trainer.state.fn
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from lightning.pytorch.demos.boring_classes import (
BoringDataModule,
BoringModel,
RandomDataset,
RandomIterableDataset,
Expand Down Expand Up @@ -2050,6 +2051,24 @@ def on_fit_start(self):
on_exception_mock.assert_called_once_with(exception)


@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
def test_trainer_calls_datamodule_on_exception(exception_type):
    """Check that an exception raised during ``fit`` is forwarded to the data module's ``on_exception``."""
    raised = exception_type("Test exception")

    class RaisingModel(BoringModel):
        def on_fit_start(self):
            # Fail as soon as fitting starts so the trainer's exception handling kicks in.
            raise raised

    hook_mock = Mock()
    dm = BoringDataModule()
    dm.on_exception = hook_mock
    trainer = Trainer()

    with suppress(Exception):
        trainer.fit(RaisingModel(), datamodule=dm)

    # The very same exception instance must have been handed to the hook, exactly once.
    hook_mock.assert_called_once_with(raised)


def test_init_module_context(monkeypatch):
"""Test that the strategy returns the context manager for initializing the module."""
trainer = Trainer(accelerator="cpu", devices=1)
Expand Down

0 comments on commit d5a9b77

Please sign in to comment.