diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index 3f43f97fbe..f41957840f 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -11,8 +11,10 @@ import glob import os +import shutil import tempfile import unittest +from concurrent.futures import ThreadPoolExecutor import numpy as np from ignite.engine import Engine, Events @@ -21,7 +23,38 @@ from monai.utils import path_to_uri +def dummy_train(tracking_folder): + tempdir = tempfile.mkdtemp() + + # set up engine + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + # set up testing handler + test_path = os.path.join(tempdir, tracking_folder) + handler = MLFlowHandler( + iteration_log=False, + epoch_log=True, + tracking_uri=path_to_uri(test_path), + state_attributes=["test"], + close_on_complete=True, + ) + handler.attach(engine) + engine.run(range(3), max_epochs=2) + return test_path + + class TestHandlerMLFlow(unittest.TestCase): + def setUp(self): + self.tmpdir_list = [] + + def tearDown(self): + for tmpdir in self.tmpdir_list: + if tmpdir and os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + def test_metrics_track(self): experiment_param = {"backbone": "efficientnet_b0"} with tempfile.TemporaryDirectory() as tempdir: @@ -61,6 +94,18 @@ def _update_metric(engine): # check logging output self.assertTrue(len(glob.glob(test_path)) > 0) + def test_multi_thread(self): + test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"] + with ThreadPoolExecutor(2, "Training") as executor: + futures = {} + for t in test_uri_list: + futures[t] = executor.submit(dummy_train, t) + + for _, future in futures.items(): + res = future.result() + self.tmpdir_list.append(res) + self.assertTrue(len(glob.glob(res)) > 0) + if __name__ == "__main__": unittest.main()