Skip to content

Commit

Permalink
bugfix: update MLFlowLogger's status to be FAILED when training raises…
Browse files Browse the repository at this point in the history
… an error
  • Loading branch information
Ritsuki Yamada committed Mar 10, 2022
1 parent d31126c commit 2bf6244
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
@rank_zero_only
def finalize(self, status: str = "FINISHED") -> None:
    """Finalize the logger and mark the tracked run as terminated.

    The parent class's ``finalize`` is called first with the status exactly
    as received. The status is then translated before it is handed to
    ``self.experiment.set_terminated``.

    Args:
        status: Outcome reported by the caller. ``"success"`` becomes
            ``"FINISHED"`` and ``"failed"`` becomes ``"FAILED"``; any other
            value is forwarded unchanged.
    """
    super().finalize(status)

    # Translate the caller's lowercase outcome strings to the upper-case run
    # status names. Values that match neither are passed through as given.
    if status == "success":
        status = "FINISHED"
    elif status == "failed":
        status = "FAILED"

    # Only terminate the run when the lookup for this run id returns
    # something truthy.
    if self.experiment.get_run(self.run_id):
        self.experiment.set_terminated(self.run_id, status)

Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,12 +730,16 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
self.state.status = TrainerStatus.INTERRUPTED
self._call_callback_hooks("on_keyboard_interrupt")
self._call_callback_hooks("on_exception", exception)
for logger in self.loggers:
logger.finalize("failed")
except BaseException as exception:
self.state.status = TrainerStatus.INTERRUPTED
if distributed_available() and self.world_size > 1:
# try syncing remaining processes, kill otherwise
self.strategy.reconciliate_processes(traceback.format_exc())
self._call_callback_hooks("on_exception", exception)
for logger in self.loggers:
logger.finalize("failed")
self._teardown()
# teardown might access the stage so we reset it after
self.state.stage = None
Expand Down
19 changes: 19 additions & 0 deletions tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,22 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_run_status_failed(client, mlflow):
    """Check that a run is terminated with ``"FAILED"`` when ``fit`` raises.

    A model whose ``training_step`` raises ``BaseException`` is fitted with
    an ``MLFlowLogger`` backed by mocks; the mocked client must then have
    received exactly one ``set_terminated`` call carrying ``"FAILED"``.
    """

    class CustomModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # Run the regular step first, then fail unconditionally.
            super().training_step(batch, batch_idx)
            raise BaseException

    # Fake run object handed back by the mocked client's ``create_run``.
    fake_run = MagicMock()
    fake_run.info.run_id = "run_id"

    mlf_logger = MLFlowLogger("test")
    mlf_logger._mlflow_client.create_run = MagicMock(return_value=fake_run)

    failing_model = CustomModel()
    trainer = Trainer(logger=mlf_logger)

    # Only the ``fit`` call is expected to raise; the assertion below must
    # run after the context manager has swallowed the exception.
    with pytest.raises(BaseException):
        trainer.fit(failing_model)

    client.return_value.set_terminated.assert_called_once_with(mlf_logger.run_id, "FAILED")

0 comments on commit 2bf6244

Please sign in to comment.