Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
from .utils import from_engine, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .validation_handler import ValidationHandler
10 changes: 8 additions & 2 deletions monai/handlers/mlflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class MLFlowHandler:
a HTTP/HTTPS URI for a remote server, a database connection string, or a local path
to log data to a directory. The URI defaults to path `mlruns`.
for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri.
iteration_log: whether to log data to MLFlow when iteration completed, default to `True`.
epoch_log: whether to log data to MLFlow when epoch completed, default to `True`.
epoch_logger: customized callable logger for epoch level logging with MLFlow.
Must accept parameter "engine", use default logger if None.
iteration_logger: customized callable logger for iteration level logging with MLFlow.
Expand All @@ -76,6 +78,8 @@ class MLFlowHandler:
def __init__(
self,
tracking_uri: Optional[str] = None,
iteration_log: bool = True,
epoch_log: bool = True,
epoch_logger: Optional[Callable[[Engine], Any]] = None,
iteration_logger: Optional[Callable[[Engine], Any]] = None,
output_transform: Callable = lambda x: x[0],
Expand All @@ -86,6 +90,8 @@ def __init__(
if tracking_uri is not None:
mlflow.set_tracking_uri(tracking_uri)

self.iteration_log = iteration_log
self.epoch_log = epoch_log
self.epoch_logger = epoch_logger
self.iteration_logger = iteration_logger
self.output_transform = output_transform
Expand All @@ -103,9 +109,9 @@ def attach(self, engine: Engine) -> None:
"""
if not engine.has_event_handler(self.start, Events.STARTED):
engine.add_event_handler(Events.STARTED, self.start)
if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)

def start(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class StatsHandler:

def __init__(
self,
iteration_log: bool = True,
epoch_log: bool = True,
epoch_print_logger: Optional[Callable[[Engine], Any]] = None,
iteration_print_logger: Optional[Callable[[Engine], Any]] = None,
output_transform: Callable = lambda x: x[0],
Expand All @@ -73,6 +75,8 @@ def __init__(
"""

Args:
iteration_log: whether to log data when iteration completed, default to `True`.
epoch_log: whether to log data when epoch completed, default to `True`.
epoch_print_logger: customized callable printer for epoch level logging.
Must accept parameter "engine", use default printer if None.
iteration_print_logger: customized callable printer for iteration level logging.
Expand All @@ -98,6 +102,8 @@ def __init__(

"""

self.iteration_log = iteration_log
self.epoch_log = epoch_log
self.epoch_print_logger = epoch_print_logger
self.iteration_print_logger = iteration_print_logger
self.output_transform = output_transform
Expand All @@ -123,9 +129,9 @@ def attach(self, engine: Engine) -> None:
"the effective log level of engine logger or RootLogger is higher than INFO, may not record log,"
" please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
)
if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
Expand Down
10 changes: 8 additions & 2 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def __init__(
self,
summary_writer: Optional[SummaryWriter] = None,
log_dir: str = "./runs",
iteration_log: bool = True,
epoch_log: bool = True,
epoch_event_writer: Optional[Callable[[Engine, SummaryWriter], Any]] = None,
epoch_interval: int = 1,
iteration_event_writer: Optional[Callable[[Engine, SummaryWriter], Any]] = None,
Expand All @@ -98,6 +100,8 @@ def __init__(
summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,
default to create a new TensorBoard writer.
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
iteration_log: whether to write data to TensorBoard when iteration completed, default to `True`.
epoch_log: whether to write data to TensorBoard when epoch completed, default to `True`.
epoch_event_writer: customized callable TensorBoard writer for epoch level.
Must accept parameter "engine" and "summary_writer", use default event writer if None.
epoch_interval: the epoch interval at which the epoch_event_writer is called. Defaults to 1.
Expand All @@ -121,6 +125,8 @@ def __init__(
tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
"""
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
self.iteration_log = iteration_log
self.epoch_log = epoch_log
self.epoch_event_writer = epoch_event_writer
self.epoch_interval = epoch_interval
self.iteration_event_writer = iteration_event_writer
Expand All @@ -138,11 +144,11 @@ def attach(self, engine: Engine) -> None:
engine: Ignite Engine, it can be a trainer, validator or evaluator.

"""
if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(
Events.ITERATION_COMPLETED(every=self.iteration_interval), self.iteration_completed
)
if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed)

def epoch_completed(self, engine: Engine) -> None:
Expand Down
11 changes: 10 additions & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -194,3 +194,12 @@ def _wrapper(data):
return tuple(ret) if len(ret) > 1 else ret[0]

return _wrapper


def ignore_data(x: Any):
    """
    Discard the given input and always return `None`.

    Typically used as an output transform so that the engine output of
    every iteration is not logged during evaluation.

    """
    return None
4 changes: 3 additions & 1 deletion tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def _update_metric(engine):

# set up testing handler
test_path = os.path.join(tempdir, "mlflow_test")
handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri(), state_attributes=["test"])
handler = MLFlowHandler(
iteration_log=False, epoch_log=True, tracking_uri=Path(test_path).as_uri(), state_attributes=["test"]
)
handler.attach(engine)
engine.run(range(3), max_epochs=2)
handler.close()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _update_metric(engine):
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(name=key_to_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
Expand Down Expand Up @@ -78,7 +78,7 @@ def _train_func(engine, batch):
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print)
stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_handler_tb_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _update_metric(engine):
engine.state.metrics["acc"] = current_metric + 0.1

# set up testing handler
stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=True)
stats_handler.attach(engine)
engine.run(range(3), max_epochs=2)
stats_handler.close()
Expand All @@ -63,6 +63,8 @@ def _update_metric(engine):
writer = SummaryWriter(log_dir=tempdir)
stats_handler = TensorBoardStatsHandler(
summary_writer=writer,
iteration_log=True,
epoch_log=False,
output_transform=lambda x: {"loss": x[0] * 2.0},
global_epoch_transform=lambda x: x * 3.0,
state_attributes=["test"],
Expand Down
6 changes: 3 additions & 3 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def _forward_completed(self, engine):
pass

val_handlers = [
StatsHandler(output_transform=lambda x: None),
TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None),
StatsHandler(iteration_log=False),
TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False),
TensorBoardImageHandler(
log_dir=root_dir, batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred")
),
Expand Down Expand Up @@ -250,7 +250,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor
]
)
val_handlers = [
StatsHandler(output_transform=lambda x: None),
StatsHandler(iteration_log=False),
CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
SegmentationSaver(
output_dir=root_dir,
Expand Down