From a975379513d7bc3e39be80c6505c02c73700a613 Mon Sep 17 00:00:00 2001 From: Matthias Burbach <> Date: Tue, 3 Jan 2023 15:02:52 +0100 Subject: [PATCH] Provide resolved run name --- justmltools/repo/mlflow_repo_downloader.py | 23 ++++++++++++++++++++-- tests/repo/test_mlflow_repo_downloader.py | 4 ++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/justmltools/repo/mlflow_repo_downloader.py b/justmltools/repo/mlflow_repo_downloader.py index 82def62..2b7f691 100644 --- a/justmltools/repo/mlflow_repo_downloader.py +++ b/justmltools/repo/mlflow_repo_downloader.py @@ -24,6 +24,7 @@ def __init__( self.__experiment_name = experiment_name self.__run_id = run_id self.__resolved_run_id = None + self.__resolved_run_name = None self.__resolved_experiment_id = None @property @@ -38,6 +39,13 @@ def resolved_run_id(self): self.__resolved_run_id = self.__resolve_run_id(self.__run_id) return self.__resolved_run_id + @property + def resolved_run_name(self): + if self.__resolved_run_name is None: + run_info: RunInfo = self.__get_run_info() + self.__resolved_run_name = run_info.run_name + return self.__resolved_run_name + @property def relative_run_url(self): return f"#/experiments/{self.resolved_experiment_id}/runs/{self.resolved_run_id}" @@ -65,11 +73,22 @@ def _download_object(self, remote_path: str, target_path: str): @lru_cache(maxsize=1) def __get_run_data(self) -> RunData: - client: MlflowClient = self.__get_mlflow_client() - run: Run = client.get_run(run_id=self.resolved_run_id) + run: Run = self.__get_run() run_data: RunData = run.data return run_data + @lru_cache(maxsize=1) + def __get_run_info(self) -> RunInfo: + run: Run = self.__get_run() + run_info: RunInfo = run.info + return run_info + + @lru_cache(maxsize=1) + def __get_run(self) -> Run: + client: MlflowClient = self.__get_mlflow_client() + run: Run = client.get_run(run_id=self.resolved_run_id) + return run + def __resolve_experiment_id(self, experiment_name: str) -> str: """ :param experiment_name: the name of the experiment diff --git a/tests/repo/test_mlflow_repo_downloader.py b/tests/repo/test_mlflow_repo_downloader.py index e8e203a..18a2142 100644 --- a/tests/repo/test_mlflow_repo_downloader.py +++ b/tests/repo/test_mlflow_repo_downloader.py @@ -27,7 +27,7 @@ def test_find_or_download_input_config_object(self): expected_path = "my_config_dir/my_config.json" with patch.object( sut, '_AbstractRepoDownloader__find_or_download_object', return_value=expected_path) as method: - actual_path: str = sut.find_or_download_input_config_object("my_config.json") + actual_path: str = sut.find_or_download_input_config_object(["my_config.json"]) self.assertEqual(expected_path, actual_path) def test_find_or_download_output_object(self): @@ -37,7 +37,7 @@ def test_find_or_download_output_object(self): expected_path = "my_output_dir/my_output.csv" with patch.object( sut, '_AbstractRepoDownloader__find_or_download_object', return_value=expected_path) as method: - actual_path: str = sut.find_or_download_output_object("my_output.csv") + actual_path: str = sut.find_or_download_output_object(["my_output.csv"]) self.assertEqual(expected_path, actual_path) def test_find_or_download_params(self):