Add synchronous parameter to MLflowLogger (#19639)
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
clumsy and azzhipa committed Apr 3, 2024
1 parent 8947d13 commit ce88483
Showing 2 changed files with 67 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/lightning/pytorch/loggers/mlflow.py
@@ -42,6 +42,7 @@
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow")


class MLFlowLogger(Logger):
@@ -100,6 +101,8 @@ def any_lightning_module_function_or_hook(self):
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
default.
run_id: The run identifier of the experiment. If not provided, a new run is started.
synchronous: Hints mlflow whether to block the execution for every logging call until complete where
applicable. Requires mlflow >= 2.8.0
Raises:
ModuleNotFoundError:
@@ -120,9 +123,12 @@ def __init__(
prefix: str = "",
artifact_location: Optional[str] = None,
run_id: Optional[str] = None,
synchronous: Optional[bool] = None,
):
if not _MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE:
raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0")
super().__init__()
if not tracking_uri:
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
@@ -138,7 +144,7 @@ def __init__(
self._checkpoint_callback: Optional[ModelCheckpoint] = None
self._prefix = prefix
self._artifact_location = artifact_location

self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
self._initialized = False

from mlflow.tracking import MlflowClient
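The `_log_batch_kwargs` dict above is what keeps the change backward compatible: when `synchronous` is left at `None`, an empty dict is splatted into `log_batch`, so the call is identical to the pre-2.8 behaviour and an older MLflow client never sees an unknown keyword. A minimal sketch of that pattern in isolation (the stand-in `log_batch` and the sample values are illustrative, not part of this diff):

```python
from typing import Any, Dict, List, Optional, Tuple


def _optional_kwargs(synchronous: Optional[bool]) -> Dict[str, Any]:
    # None means "not specified": pass nothing through, so the call stays
    # exactly as it was before the parameter existed.
    return {} if synchronous is None else {"synchronous": synchronous}


def log_batch(run_id: str, metrics: List[Tuple[str, float]], **kwargs: Any) -> None:
    # Stand-in for MlflowClient.log_batch; an older client simply has no
    # `synchronous` keyword, which is why it must never receive one.
    print(f"log_batch(run_id={run_id!r}, metrics={metrics}, extra={kwargs})")


log_batch("run-1", [("loss", 0.1)], **_optional_kwargs(None))   # -> extra={}
log_batch("run-1", [("loss", 0.1)], **_optional_kwargs(False))  # -> extra={'synchronous': False}
```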
@@ -233,7 +239,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:

# Log in chunks of 100 parameters (the maximum allowed by MLflow).
for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs)

@override
@rank_zero_only
@@ -261,7 +267,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
k = new_k
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)

@override
@rank_zero_only
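From the caller's side, the new argument is just another `MLFlowLogger` constructor parameter. A minimal usage sketch, assuming mlflow>=2.8.0 is installed and using a local file-based tracking URI (the experiment name and trainer settings are illustrative only):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import MLFlowLogger

logger = MLFlowLogger(
    experiment_name="sync-demo",
    tracking_uri="file:./ml-runs",  # local file store; no tracking server required
    synchronous=False,              # hint mlflow not to block training on logging calls
)

trainer = Trainer(max_epochs=1, limit_train_batches=2, logger=logger)
trainer.fit(BoringModel())
```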
59 changes: 58 additions & 1 deletion tests/tests_pytorch/loggers/test_mlflow.py
@@ -18,7 +18,12 @@
import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger, _get_resolve_tags
from lightning.pytorch.loggers.mlflow import (
_MLFLOW_AVAILABLE,
_MLFLOW_SYNCHRONOUS_AVAILABLE,
MLFlowLogger,
_get_resolve_tags,
)


def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -260,6 +265,58 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
)


@pytest.mark.parametrize("synchronous", [False, True])
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
"""Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""
if not _MLFLOW_SYNCHRONOUS_AVAILABLE:
pytest.skip("this test requires mlflow>=2.8.0")

time = mlflow_mock.entities.time
metric = mlflow_mock.entities.Metric
param = mlflow_mock.entities.Param
time.return_value = 1

mlflow_client = mlflow_mock.tracking.MlflowClient.return_value
mlflow_client.get_experiment_by_name.return_value = None
logger = MLFlowLogger(
"test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=synchronous
)

params = {"test": "test_param"}
logger.log_hyperparams(params)

mlflow_client.log_batch.assert_called_once_with(
run_id=logger.run_id, params=[param(key="test", value="test_param")], synchronous=synchronous
)
param.assert_called_with(key="test", value="test_param")

metrics = {"some_metric": 10}
logger.log_metrics(metrics)

mlflow_client.log_batch.assert_called_with(
run_id=logger.run_id,
metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)],
synchronous=synchronous,
)
metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)

mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location")


@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False})
def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
"""Test that the logger does not support synchronous flag."""
time = mlflow_mock.entities.time
time.return_value = 1

mlflow_client = mlflow_mock.tracking.MlflowClient.return_value
mlflow_client.get_experiment_by_name.return_value = None
with pytest.raises(ModuleNotFoundError):
MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True)


@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""
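For reference, the flag is forwarded verbatim to `MlflowClient.log_batch` (the keyword exists in mlflow>=2.8.0, as gated above). A rough sketch of the equivalent direct call against a local file store; the experiment and metric names are illustrative, and the non-blocking behaviour is as described in the docstring above, i.e. a hint rather than a guarantee:

```python
import time

from mlflow.entities import Metric
from mlflow.tracking import MlflowClient

# Local file-based tracking store, so the sketch runs without a server.
client = MlflowClient(tracking_uri="file:./mlruns-sync-demo")
experiment_id = client.create_experiment("sync-demo")
run = client.create_run(experiment_id)

client.log_batch(
    run_id=run.info.run_id,
    metrics=[Metric(key="loss", value=0.1, timestamp=int(time.time() * 1000), step=0)],
    synchronous=False,  # hint mlflow not to block until the batch is persisted (mlflow>=2.8.0)
)
```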