Skip to content
13 changes: 10 additions & 3 deletions monai/handlers/classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ def attach(self, engine: Engine) -> None:
if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)

def _started(self, engine: Engine) -> None:
def _started(self, _engine: Engine) -> None:
"""
Initialize internal buffers.

Args:
_engine: Ignite Engine, unused argument.

"""
self._outputs = []
self._filenames = []

Expand All @@ -120,12 +127,12 @@ def __call__(self, engine: Engine) -> None:
o = o.detach()
self._outputs.append(o)

def _finalize(self, engine: Engine) -> None:
def _finalize(self, _engine: Engine) -> None:
"""
All gather classification results from ranks and save to CSV file.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
_engine: Ignite Engine, unused argument.
"""
ws = idist.get_world_size()
if self.save_rank >= ws:
Expand Down
9 changes: 8 additions & 1 deletion monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,14 @@ def attach(self, engine: Engine) -> None:
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
engine.add_event_handler(Events.EPOCH_COMPLETED, self)

def _started(self, engine: Engine) -> None:
def _started(self, _engine: Engine) -> None:
"""
Initialize internal buffers.

Args:
_engine: Ignite Engine, unused argument.

"""
self._filenames = []

def _get_filenames(self, engine: Engine) -> None:
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ def iteration_completed(self, engine: Engine) -> None:
else:
self._default_iteration_print(engine)

def exception_raised(self, engine: Engine, e: Exception) -> None:
def exception_raised(self, _engine: Engine, e: Exception) -> None:
"""
Handler for train or validation/evaluation exception raised Event.
Print the exception information and traceback. This callback may be skipped because the logic
with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
_engine: Ignite Engine, unused argument.
e: the exception caught in Ignite during engine.run().

"""
Expand Down
35 changes: 29 additions & 6 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,21 @@ def iteration_completed(self, engine: Engine) -> None:
else:
self._default_iteration_writer(engine, self._writer)

def _write_scalar(self, _engine: Engine, writer: SummaryWriter, tag: str, value: Any, step: int) -> None:
"""
Write scale value into TensorBoard.
Default to call `SummaryWriter.add_scalar()`.

Args:
_engine: Ignite Engine, unused argument.
writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.
tag: tag name in the TensorBoard.
value: value of the scalar data for current step.
step: index of current step.

"""
writer.add_scalar(tag, value, step)

def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
"""
Execute epoch level event write operation.
Expand All @@ -188,11 +203,11 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
summary_dict = engine.state.metrics
for name, value in summary_dict.items():
if is_scalar(value):
writer.add_scalar(name, value, current_epoch)
self._write_scalar(engine, writer, name, value, current_epoch)

if self.state_attributes is not None:
for attr in self.state_attributes:
writer.add_scalar(attr, getattr(engine.state, attr, None), current_epoch)
self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch)
writer.flush()

def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None:
Expand Down Expand Up @@ -221,12 +236,20 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
" {}:{}".format(name, type(value))
)
continue # not plot multi dimensional output
writer.add_scalar(
name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration
self._write_scalar(
_engine=engine,
writer=writer,
tag=name,
value=value.item() if isinstance(value, torch.Tensor) else value,
step=engine.state.iteration,
)
elif is_scalar(loss): # not printing multi dimensional output
writer.add_scalar(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration
self._write_scalar(
_engine=engine,
writer=writer,
tag=self.tag_name,
value=loss.item() if isinstance(loss, torch.Tensor) else loss,
step=engine.state.iteration,
)
else:
warnings.warn(
Expand Down