From 1e868cf2dfaee844f1a14e49bed1193c7d57921b Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Wed, 14 Dec 2022 00:15:35 +0000 Subject: [PATCH 1/2] Support MLFlow Handler for single process/multi task environment Signed-off-by: Sachidanand Alle --- monai/handlers/mlflow_handler.py | 76 ++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index ee63c951a8..79b7e39a88 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -21,6 +21,7 @@ Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") mlflow, _ = optional_import("mlflow") +mlflow.entities, _ = optional_import("mlflow.entities") if TYPE_CHECKING: from ignite.engine import Engine @@ -109,9 +110,6 @@ def __init__( optimizer_param_names: Union[str, Sequence[str]] = "lr", close_on_complete: bool = False, ) -> None: - if tracking_uri is not None: - mlflow.set_tracking_uri(tracking_uri) - self.iteration_log = iteration_log self.epoch_log = epoch_log self.epoch_logger = epoch_logger @@ -125,8 +123,10 @@ def __init__( self.experiment_param = experiment_param self.artifacts = ensure_tuple(artifacts) self.optimizer_param_names = ensure_tuple(optimizer_param_names) - self.client = mlflow.MlflowClient() + self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None) self.close_on_complete = close_on_complete + self.experiment = None + self.cur_run = None def _delete_exist_param_in_dict(self, param_dict: Dict) -> None: """ @@ -135,9 +135,11 @@ def _delete_exist_param_in_dict(self, param_dict: Dict) -> None: Args: param_dict: parameter dict to be logged to mlflow. 
""" + if self.cur_run is None: + return + key_list = list(param_dict.keys()) - cur_run = mlflow.active_run() - log_data = self.client.get_run(cur_run.info.run_id).data + log_data = self.client.get_run(self.cur_run.info.run_id).data log_param_dict = log_data.params for key in key_list: if key in log_param_dict: @@ -167,17 +169,52 @@ def start(self, engine: Engine) -> None: Check MLFlow status and start if not active. """ - mlflow.set_experiment(self.experiment_name) - if mlflow.active_run() is None: + self._set_experiment() + if not self.experiment: + raise ValueError(f"Failed to experiment '{self.experiment_name}' as the active experiment") + + if not self.cur_run: run_name = f"run_{time.strftime('%Y%m%d_%H%M%S')}" if self.run_name is None else self.run_name - mlflow.start_run(run_name=run_name) + runs = self.client.search_runs(self.experiment.experiment_id) + runs = [r for r in runs if r.info.run_name == run_name or not self.run_name] + if runs: + self.cur_run = self.client.get_run(runs[-1].info.run_id) # pick latest active run + else: + self.cur_run = self.client.create_run(experiment_id=self.experiment.experiment_id, run_name=run_name) if self.experiment_param: - mlflow.log_params(self.experiment_param) + self._log_params(self.experiment_param) attrs = {attr: getattr(engine.state, attr, None) for attr in self.default_tracking_params} self._delete_exist_param_in_dict(attrs) - mlflow.log_params(attrs) + self._log_params(attrs) + + def _set_experiment(self): + experiment = self.experiment + if not experiment: + experiment = self.client.get_experiment_by_name(self.experiment_name) + if not experiment: + experiment_id = self.client.create_experiment(self.experiment_name) + experiment = self.client.get_experiment(experiment_id) + + if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE: + raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment") + self.experiment = experiment + + def _log_params(self, 
params: Dict[str, Any]) -> None: + if not self.cur_run: + raise ValueError("Current Run is not Active to log params") + params_arr = [mlflow.entities.Param(key, str(value)) for key, value in params.items()] + self.client.log_batch(run_id=self.cur_run.info.run_id, metrics=[], params=params_arr, tags=[]) + + def _log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: + if not self.cur_run: + raise ValueError("Current Run is not Active to log metrics") + + run_id = self.cur_run.info.run_id + timestamp = int(time.time() * 1000) + metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()] + self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[]) def _parse_artifacts(self): """ @@ -202,17 +239,20 @@ def complete(self) -> None: """ Handler for train or validation/evaluation completed Event. """ - if self.artifacts: + if self.artifacts and self.cur_run: artifact_list = self._parse_artifacts() for artifact in artifact_list: - mlflow.log_artifact(artifact) + self.client.log_artifact(self.cur_run.info.run_id, artifact) def close(self) -> None: """ Stop current running logger of MLFlow. 
""" - mlflow.end_run() + if self.cur_run: + status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED) + self.client.set_terminated(self.cur_run.info.run_id, status) + self.cur_run = None def epoch_completed(self, engine: Engine) -> None: """ @@ -257,11 +297,11 @@ def _default_epoch_log(self, engine: Engine) -> None: return current_epoch = self.global_epoch_transform(engine.state.epoch) - mlflow.log_metrics(log_dict, step=current_epoch) + self._log_metrics(log_dict, step=current_epoch) if self.state_attributes is not None: attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes} - mlflow.log_metrics(attrs, step=current_epoch) + self._log_metrics(attrs, step=current_epoch) def _default_iteration_log(self, engine: Engine) -> None: """ @@ -281,7 +321,7 @@ def _default_iteration_log(self, engine: Engine) -> None: if not isinstance(loss, dict): loss = {self.tag_name: loss.item() if isinstance(loss, torch.Tensor) else loss} - mlflow.log_metrics(loss, step=engine.state.iteration) + self._log_metrics(loss, step=engine.state.iteration) # If there is optimizer attr in engine, then record parameters specified in init function. 
if hasattr(engine, "optimizer"): @@ -291,4 +331,4 @@ def _default_iteration_log(self, engine: Engine) -> None: f"{param_name} group_{i}": float(param_group[param_name]) for i, param_group in enumerate(cur_optimizer.param_groups) } - mlflow.log_metrics(params, step=engine.state.iteration) + self._log_metrics(params, step=engine.state.iteration) From 121ead0ac409dc5c3e0faef5aa0d88c95fa74a4a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 15 Dec 2022 09:01:45 +0000 Subject: [PATCH 2/2] fixes typos Signed-off-by: Wenqi Li --- monai/handlers/mlflow_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index 79b7e39a88..340138a372 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -53,7 +53,7 @@ class MLFlowHandler: Args: tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment variable to have MLflow find a URI from there. in both cases, the URI can either be - a HTTP/HTTPS URI for a remote server, a database connection string, or a local path + an HTTP/HTTPS URI for a remote server, a database connection string, or a local path to log data to a directory. The URI defaults to path `mlruns`. for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. iteration_log: whether to log data to MLFlow when iteration completed, default to `True`. @@ -171,7 +171,7 @@ def start(self, engine: Engine) -> None: """ self._set_experiment() if not self.experiment: - raise ValueError(f"Failed to experiment '{self.experiment_name}' as the active experiment") + raise ValueError(f"Failed to set experiment '{self.experiment_name}' as the active experiment") if not self.cur_run: run_name = f"run_{time.strftime('%Y%m%d_%H%M%S')}" if self.run_name is None else self.run_name