Skip to content

Commit

Permalink
Provide resolved run name
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthias Burbach committed Jan 3, 2023
1 parent 760c6fe commit a975379
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
23 changes: 21 additions & 2 deletions justmltools/repo/mlflow_repo_downloader.py
Expand Up @@ -24,6 +24,7 @@ def __init__(
self.__experiment_name = experiment_name
self.__run_id = run_id
self.__resolved_run_id = None
self.__resolved_run_name = None
self.__resolved_experiment_id = None

@property
Expand All @@ -38,6 +39,13 @@ def resolved_run_id(self):
self.__resolved_run_id = self.__resolve_run_id(self.__run_id)
return self.__resolved_run_id

@property
def resolved_run_name(self):
if self.__resolved_run_name is None:
run_info: RunInfo = self.__get_run_info()
self.__resolved_run_name = run_info.run_name
return self.__resolved_run_name

@property
def relative_run_url(self):
return f"#/experiments/{self.resolved_experiment_id}/runs/{self.resolved_run_id}"
Expand Down Expand Up @@ -65,11 +73,22 @@ def _download_object(self, remote_path: str, target_path: str):

@lru_cache(maxsize=1)
def __get_run_data(self) -> RunData:
client: MlflowClient = self.__get_mlflow_client()
run: Run = client.get_run(run_id=self.resolved_run_id)
run: Run = self.__get_run()
run_data: RunData = run.data
return run_data

@lru_cache(maxsize=1)
def __get_run_info(self) -> RunInfo:
run: Run = self.__get_run()
run_info: RunInfo = run.info
return run_info

@lru_cache(maxsize=1)
def __get_run(self) -> Run:
client: MlflowClient = self.__get_mlflow_client()
run: Run = client.get_run(run_id=self.resolved_run_id)
return run

def __resolve_experiment_id(self, experiment_name: str) -> str:
"""
:param experiment_name: the name of the experiment
Expand Down
4 changes: 2 additions & 2 deletions tests/repo/test_mlflow_repo_downloader.py
Expand Up @@ -27,7 +27,7 @@ def test_find_or_download_input_config_object(self):
expected_path = "my_config_dir/my_config.json"
with patch.object(
sut, '_AbstractRepoDownloader__find_or_download_object', return_value=expected_path) as method:
actual_path: str = sut.find_or_download_input_config_object("my_config.json")
actual_path: str = sut.find_or_download_input_config_object(["my_config.json"])
self.assertEqual(expected_path, actual_path)

def test_find_or_download_output_object(self):
Expand All @@ -37,7 +37,7 @@ def test_find_or_download_output_object(self):
expected_path = "my_output_dir/my_output.csv"
with patch.object(
sut, '_AbstractRepoDownloader__find_or_download_object', return_value=expected_path) as method:
actual_path: str = sut.find_or_download_output_object("my_output.csv")
actual_path: str = sut.find_or_download_output_object(["my_output.csv"])
self.assertEqual(expected_path, actual_path)

def test_find_or_download_params(self):
Expand Down

0 comments on commit a975379

Please sign in to comment.