From 56aaf5d183a1fe51450205143466f6598e7fc16b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 11 Oct 2021 16:41:03 +0800 Subject: [PATCH 1/4] [DLMED] add MLFlowHandler Signed-off-by: Nic Ma --- docs/source/handlers.rst | 5 + monai/handlers/__init__.py | 1 + monai/handlers/mlflow_handler.py | 179 +++++++++++++++++++++++++++++++ tests/min_tests.py | 1 + tests/test_handler_mlflow.py | 47 ++++++++ 5 files changed, 233 insertions(+) create mode 100644 monai/handlers/mlflow_handler.py create mode 100644 tests/test_handler_mlflow.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index cb4333d1da..c48ffd412c 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -165,6 +165,11 @@ NVTX Handlers .. automodule:: monai.handlers.nvtx_handlers :members: +MLFlow handler +-------------- +.. autoclass:: MLFlowHandler + :members: + Utilities --------- .. automodule:: monai.handlers.utils diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index bf1a9d3f89..520af0a94c 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -22,6 +22,7 @@ from .mean_dice import MeanDice from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver +from .mlflow_handler import MLFlowHandler from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler from .parameter_scheduler import ParamSchedulerHandler from .postprocessing import PostProcessing diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py new file mode 100644 index 0000000000..41e69f4e79 --- /dev/null +++ b/monai/handlers/mlflow_handler.py @@ -0,0 +1,179 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch + +from monai.config import IgniteInfo +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +mlflow, _ = optional_import("mlflow") + +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + +DEFAULT_TAG = "Loss" + + +class MLFlowHandler: + """ + MLFlowHandler defines a set of Ignite Event-handlers for the MLFlow tracking logic. + It can be used for any Ignite Engine (trainer, validator and evaluator). + And it can track both epoch level and iteration level logging, then MLFlow can store + the data and visualize. + The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. + Default behaviors: + - When EPOCH_COMPLETED, track each dictionary item in + ``engine.state.metrics`` in MLFlow. + - When ITERATION_COMPLETED, track expected item in + ``self.output_transform(engine.state.output)`` in MLFlow, default to `Loss`. + + For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html. 
+ + """ + + def __init__( + self, + tracking_uri: Optional[str] = None, + epoch_logger: Optional[Callable[[Engine], Any]] = None, + iteration_logger: Optional[Callable[[Engine], Any]] = None, + output_transform: Callable = lambda x: x[0], + global_epoch_transform: Callable = lambda x: x, + tag_name: str = DEFAULT_TAG, + ) -> None: + """ + + Args: + tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment + variable to have MLflow find a URI from there. in both cases, the URI can either be + a HTTP/HTTPS URI for a remote server, a database connection string, or a local path + to log data to a directory. The URI defaults to path `mlruns`. + for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. + epoch_logger: customized callable logger for epoch level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + iteration_logger: customized callable logger for iteration level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + output_transform: a callable that is used to transform the + ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}. + By default this value logging happens when every iteration completed. + The default behavior is to track loss from output[0] as output is a decollated list + and we replicated loss value for every item of the decollated list. + global_epoch_transform: a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to track synced epoch number + with the trainer engine. + tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`. 
+ + """ + if tracking_uri is not None: + mlflow.set_tracking_uri(tracking_uri) + + self.epoch_logger = epoch_logger + self.iteration_logger = iteration_logger + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + self.tag_name = tag_name + + def attach(self, engine: Engine) -> None: + """ + Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if not engine.has_event_handler(self.start, Events.STARTED): + engine.add_event_handler(Events.STARTED, self.start) + if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + + def start(self) -> None: + """ + Check MLFlow status and start if not active. + + """ + if mlflow.active_run() is None: + mlflow.start_run() + + def close(self) -> None: + """ + Stop current running logger of MLFlow. + + """ + mlflow.end_run() + + def epoch_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation epoch completed Event. + Track epoch level log, default values are from Ignite state.metrics dict. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.epoch_logger is not None: + self.epoch_logger(engine) + else: + self._default_epoch_log(engine) + + def iteration_completed(self, engine: Engine) -> None: + """ + Handler for train or validation/evaluation iteration completed Event. + Track iteration level log. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. 
+ + """ + if self.iteration_logger is not None: + self.iteration_logger(engine) + else: + self._default_iteration_log(engine) + + def _default_epoch_log(self, engine: Engine) -> None: + """ + Execute epoch level log operation based on Ignite engine.state data. + Track the values from Ignite state.metrics dict. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + log_dict = engine.state.metrics + if not log_dict: + return + + current_epoch = self.global_epoch_transform(engine.state.epoch) + mlflow.log_metrics(log_dict, step=current_epoch) + + def _default_iteration_log(self, engine: Engine) -> None: + """ + Execute iteration log operation based on Ignite engine.state data. + The default behavior is to track loss from output[0] as output is a decollated list + and we replicated loss value for every item of the decollated list. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + + """ + loss = self.output_transform(engine.state.output) + if loss is None: + return + + if not isinstance(loss, dict): + loss = {self.tag_name: loss.item() if isinstance(loss, torch.Tensor) else loss} + + mlflow.log_metrics(loss, step=engine.state.iteration) diff --git a/tests/min_tests.py b/tests/min_tests.py index 5e188a828e..27f45c062a 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_transchex", + "test_handler_mlflow", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py new file mode 100644 index 0000000000..d2bd10baf5 --- /dev/null +++ b/tests/test_handler_mlflow.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import tempfile +import unittest + +from ignite.engine import Engine, Events + +from monai.handlers import MLFlowHandler + + +class TestHandlerMLFlow(unittest.TestCase): + def test_metrics_track(self): + with tempfile.TemporaryDirectory() as tempdir: + + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + + # set up testing handler + handler = MLFlowHandler(tracking_uri="file:/" + tempdir) + handler.attach(engine) + engine.run(range(3), max_epochs=2) + handler.close() + # check logging output + self.assertTrue(len(glob.glob(tempdir)) > 0) + + +if __name__ == "__main__": + unittest.main() From 59ca3f79e1cb5becc6afc5d903d91630c03bb23b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 11 Oct 2021 16:48:45 +0800 Subject: [PATCH 2/4] [DLMED] add optional import Signed-off-by: Nic Ma --- docs/requirements.txt | 1 + docs/source/installation.md | 4 ++-- monai/config/deviceconfig.py | 1 + requirements-dev.txt | 1 + setup.cfg | 3 +++ 5 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 53eb6d3c0d..cefb47e7e0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,3 +21,4 @@ sphinx-autodoc-typehints==1.11.1 pandas einops transformers +mlflow diff --git a/docs/source/installation.md b/docs/source/installation.md index 
4bc4aa700a..6936c0bf49 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops` and `transformers`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers` and `mlflow`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index db786a88ef..e542da14ab 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -74,6 +74,7 @@ def get_optional_config_values(): output["pandas"] = get_package_version("pandas") output["einops"] = get_package_version("einops") output["transformers"] = get_package_version("transformers") + output["mlflow"] = get_package_version("mlflow") return output diff --git a/requirements-dev.txt b/requirements-dev.txt index 254cb06d27..9338306d90 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -39,3 +39,4 @@ pandas requests einops transformers +mlflow diff --git a/setup.cfg b/setup.cfg index 19f04de526..aa015d8ec7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,7 @@ all = pandas einops transformers + mlflow nibabel = nibabel skimage = @@ -77,6 +78,8 @@ einops = einops transformers = transformers +mlflow = + mlflow [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 From 
b7c7bfce1465af3c5adac89e16f2c3dcb8d9fcfb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 11 Oct 2021 17:09:57 +0800 Subject: [PATCH 3/4] [DLMED] fix doc format Signed-off-by: Nic Ma --- docs/source/handlers.rst | 10 ++++---- monai/handlers/mlflow_handler.py | 44 +++++++++++++++----------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index c48ffd412c..d32b6d88e3 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -160,16 +160,16 @@ Decollate batch .. autoclass:: DecollateBatch :members: -NVTX Handlers -------------- -.. automodule:: monai.handlers.nvtx_handlers - :members: - MLFlow handler -------------- .. autoclass:: MLFlowHandler :members: +NVTX Handlers +------------- +.. automodule:: monai.handlers.nvtx_handlers + :members: + Utilities --------- .. automodule:: monai.handlers.utils diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index 41e69f4e79..8c847e0521 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -34,12 +34,33 @@ class MLFlowHandler: And it can track both epoch level and iteration level logging, then MLFlow can store the data and visualize. The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``. + Default behaviors: - When EPOCH_COMPLETED, track each dictionary item in ``engine.state.metrics`` in MLFlow. - When ITERATION_COMPLETED, track expected item in ``self.output_transform(engine.state.output)`` in MLFlow, default to `Loss`. + Args: + tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment + variable to have MLflow find a URI from there. in both cases, the URI can either be + a HTTP/HTTPS URI for a remote server, a database connection string, or a local path + to log data to a directory. The URI defaults to path `mlruns`. 
+ for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. + epoch_logger: customized callable logger for epoch level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + iteration_logger: customized callable logger for iteration level logging with MLFlow. + Must accept parameter "engine", use default logger if None. + output_transform: a callable that is used to transform the + ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}. + By default this value logging happens when every iteration completed. + The default behavior is to track loss from output[0] as output is a decollated list + and we replicated loss value for every item of the decollated list. + global_epoch_transform: a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to track synced epoch number + with the trainer engine. + tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`. + For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html. """ @@ -53,29 +74,6 @@ def __init__( global_epoch_transform: Callable = lambda x: x, tag_name: str = DEFAULT_TAG, ) -> None: - """ - - Args: - tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment - variable to have MLflow find a URI from there. in both cases, the URI can either be - a HTTP/HTTPS URI for a remote server, a database connection string, or a local path - to log data to a directory. The URI defaults to path `mlruns`. - for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. - epoch_logger: customized callable logger for epoch level logging with MLFlow. - Must accept parameter "engine", use default logger if None. - iteration_logger: customized callable logger for iteration level logging with MLFlow. 
- Must accept parameter "engine", use default logger if None. - output_transform: a callable that is used to transform the - ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}. - By default this value logging happens when every iteration completed. - The default behavior is to track loss from output[0] as output is a decollated list - and we replicated loss value for every item of the decollated list. - global_epoch_transform: a callable that is used to customize global epoch number. - For example, in evaluation, the evaluator engine might want to track synced epoch number - with the trainer engine. - tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`. - - """ if tracking_uri is not None: mlflow.set_tracking_uri(tracking_uri) From ee886aa9b043578e469dfefe0420f0bab9bdc9cd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 11 Oct 2021 18:33:16 +0800 Subject: [PATCH 4/4] [DLMED] fix CI test Signed-off-by: Nic Ma --- tests/test_handler_mlflow.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index d2bd10baf5..f210ebfacc 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -10,8 +10,10 @@ # limitations under the License. import glob +import os import tempfile import unittest +from pathlib import Path from ignite.engine import Engine, Events @@ -35,12 +37,13 @@ def _update_metric(engine): engine.state.metrics["acc"] = current_metric + 0.1 # set up testing handler - handler = MLFlowHandler(tracking_uri="file:/" + tempdir) + test_path = os.path.join(tempdir, "mlflow_test") + handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri()) handler.attach(engine) engine.run(range(3), max_epochs=2) handler.close() # check logging output - self.assertTrue(len(glob.glob(tempdir)) > 0) + self.assertTrue(len(glob.glob(test_path)) > 0) if __name__ == "__main__":