-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_mlflow_repo_downloader.py
50 lines (43 loc) · 2.44 KB
/
test_mlflow_repo_downloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from unittest import TestCase
from unittest.mock import Mock, patch
from mlflow.entities import RunData, Param
from typing import Any, Dict
from justmltools.repo.mlflow_repo_downloader import MlflowRepoDownloader
from justmltools.config.local_data_path_config import LocalDataPathConfig
from justmltools.config.mlflow_data_path_config import MlflowDataPathConfig
class TestMlflowRepoDownloader(TestCase):
    """Unit tests for ``MlflowRepoDownloader``.

    The downloader's private lookup helpers are patched with fixed return
    values, so each test checks only how the public ``find_or_download_*``
    methods pass those results through to the caller.
    """

    # Double-underscore methods are name-mangled by Python into
    # ``_<DefiningClass>__<name>``, so patch.object must use the mangled names.
    _FIND_OR_DOWNLOAD_OBJECT = '_AbstractRepoDownloader__find_or_download_object'
    _GET_RUN_DATA = '_MlflowRepoDownloader__get_run_data'

    def create_downloader(self, run_id: str) -> MlflowRepoDownloader:
        """Build a downloader for ``run_id`` with test configs and mocked AWS credentials.

        :param run_id: MLflow run id handed to the downloader.
        :return: a freshly constructed ``MlflowRepoDownloader``.
        """
        return MlflowRepoDownloader(
            local_data_path_config=LocalDataPathConfig(prefix="my_local_test_prefix"),
            remote_data_path_config=MlflowDataPathConfig(),
            aws_credentials=Mock(),
            experiment_name="my_test_experiment",
            run_id=run_id,
        )

    def test_find_or_download_input_config_object(self):
        """Input config lookup returns the path produced by the object finder."""
        run_id: str = "my_test_run_id"
        sut: MlflowRepoDownloader = self.create_downloader(run_id)
        self.assertEqual(run_id, sut.resolved_run_id)

        expected_path = "my_config_dir/my_config.json"
        with patch.object(sut, self._FIND_OR_DOWNLOAD_OBJECT, return_value=expected_path) as find_mock:
            actual_path: str = sut.find_or_download_input_config_object(["my_config.json"])

        # The expected path can only come from the mock, so it must have been used.
        find_mock.assert_called()
        self.assertEqual(expected_path, actual_path)

    def test_find_or_download_output_object(self):
        """Output object lookup returns the path produced by the object finder."""
        run_id: str = "my_test_run_id"
        sut: MlflowRepoDownloader = self.create_downloader(run_id)
        self.assertEqual(run_id, sut.resolved_run_id)

        expected_path = "my_output_dir/my_output.csv"
        with patch.object(sut, self._FIND_OR_DOWNLOAD_OBJECT, return_value=expected_path) as find_mock:
            actual_path: str = sut.find_or_download_output_object(["my_output.csv"])

        find_mock.assert_called()
        self.assertEqual(expected_path, actual_path)

    def test_find_or_download_params(self):
        """Run params are exposed as a key -> value mapping taken from the run data."""
        run_id: str = "my_test_run_id"
        sut: MlflowRepoDownloader = self.create_downloader(run_id)

        params = [Param(key="test_param", value="test_param_value")]
        run_data: RunData = RunData(metrics=None, params=params, tags=None)
        with patch.object(sut, self._GET_RUN_DATA, return_value=run_data) as get_run_data_mock:
            run_params: Dict[str, Any] = sut.find_or_download_run_params()

        # "test_param_value" exists only in the mocked run data, so it must have been fetched.
        get_run_data_mock.assert_called()
        self.assertEqual("test_param_value", run_params["test_param"])