From b90b33fc91a446102df5d3b7679bf074ba3dc0b2 Mon Sep 17 00:00:00 2001 From: Phoevos Kalemkeris Date: Tue, 17 Dec 2024 18:28:21 +0000 Subject: [PATCH] train: Return MLflow tracking information Extend the API to return MLflow tracking information as part of a training or evaluation response, including the experiment and run IDs, and update the tests accordingly. If training is already in progress, the API returns the experiment and run IDs of the current training run. This affects the following routes: * POST /train_supervised * POST /train_unsupervised * POST /train_unsupervised_with_hf_hub_dataset * POST /train_metacat * POST /evaluate Signed-off-by: Phoevos Kalemkeris --- app/api/routers/evaluation.py | 21 ++++++++++++--- app/api/routers/metacat_training.py | 26 +++++++++++++----- app/api/routers/supervised_training.py | 26 +++++++++++++----- app/api/routers/unsupervised_training.py | 30 +++++++++++++++------ app/model_services/base.py | 6 ++--- app/model_services/huggingface_ner_model.py | 4 +-- app/model_services/medcat_model.py | 6 ++--- app/model_services/medcat_model_deid.py | 4 +-- app/trainers/base.py | 28 ++++++++++++------- tests/app/api/test_serving_common.py | 17 +++++++----- tests/app/api/test_serving_hf_ner.py | 3 ++- 11 files changed, 122 insertions(+), 49 deletions(-) diff --git a/app/api/routers/evaluation.py b/app/api/routers/evaluation.py index 97ff305..6cc49a5 100644 --- a/app/api/routers/evaluation.py +++ b/app/api/routers/evaluation.py @@ -57,11 +57,26 @@ async def get_evaluation_with_trainer_export(request: Request, data_file.flush() data_file.seek(0) evaluation_id = tracking_id or str(uuid.uuid4()) - evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names)) + evaluation_accepted, experiment_id, run_id = model_service.train_supervised( + data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names) + ) if evaluation_accepted: - return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your evaluation started successfully.", + "evaluation_id": evaluation_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) @router.post("/sanity-check", diff --git a/app/api/routers/metacat_training.py b/app/api/routers/metacat_training.py index 1e16e20..5e60adb 100644 --- a/app/api/routers/metacat_training.py +++ b/app/api/routers/metacat_training.py @@ -2,7 +2,7 @@ import uuid import json import logging -from typing import List, Union +from typing import List, Tuple, Union from typing_extensions import Annotated from fastapi import APIRouter, Depends, UploadFile, Query, Request, File @@ -53,7 +53,7 @@ async def train_metacat(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_metacat(data_file, + training_response = model_service.train_metacat(data_file, epochs, log_frequency, training_id, @@ -65,13 +65,27 @@ async def train_metacat(request: Request, for file in files: file.close() - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) -def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: +def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: + training_accepted, experiment_id, run_id = training_response if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your training started successfully.", + "training_id": training_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry your training later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index fd66443..9a49c60 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -2,7 +2,7 @@ import uuid import json import logging -from typing import List, Union +from typing import List, Tuple, Union from typing_extensions import Annotated from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form @@ -55,7 +55,7 @@ async def train_supervised(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_supervised(data_file, + training_response = model_service.train_supervised(data_file, epochs, log_frequency, training_id, @@ -69,13 +69,27 @@ async def train_supervised(request: Request, for file in files: file.close() - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) -def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: +def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: + training_accepted, experiment_id, run_id = training_response if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your training started successfully.", + "training_id": training_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry your training later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/app/api/routers/unsupervised_training.py b/app/api/routers/unsupervised_training.py index c3925aa..831cd64 100644 --- a/app/api/routers/unsupervised_training.py +++ b/app/api/routers/unsupervised_training.py @@ -5,7 +5,7 @@ import logging import datasets import zipfile -from typing import List, Union +from typing import List, Tuple, Union from typing_extensions import Annotated from fastapi import APIRouter, Depends, UploadFile, Query, Request, File @@ -65,7 +65,7 @@ async def train_unsupervised(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_unsupervised(data_file, + training_response = model_service.train_unsupervised(data_file, epochs, log_frequency, training_id, @@ -79,7 +79,7 @@ async def train_unsupervised(request: Request, for file in files: file.close() - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) @router.post("/train_unsupervised_with_hf_hub_dataset", @@ -133,7 +133,7 @@ async def train_unsupervised_with_hf_dataset(request: Request, hf_dataset.save_to_disk(data_dir.name) training_id = tracking_id or str(uuid.uuid4()) - training_accepted = model_service.train_unsupervised(data_dir, + training_response = model_service.train_unsupervised(data_dir, epochs, log_frequency, training_id, @@ -143,13 +143,27 @@ async def train_unsupervised_with_hf_dataset(request: Request, lr_override=lr_override, test_size=test_size, description=description) - return _get_training_response(training_accepted, training_id) + return _get_training_response(training_response, training_id) -def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: +def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse: + training_accepted, experiment_id, run_id = training_response if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your training started successfully.", + "training_id": training_id, + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_202_ACCEPTED + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": "Another training or evaluation on this model is still active. 
Please retry later.", + "experiment_id": experiment_id, + "run_id": run_id, + }, status_code=HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/app/model_services/base.py b/app/model_services/base.py index fceb9a8..b431eff 100644 --- a/app/model_services/base.py +++ b/app/model_services/base.py @@ -56,11 +56,11 @@ def batch_annotate(self, texts: List[str]) -> List[List[Dict[str, Any]]]: def init_model(self) -> None: raise NotImplementedError - def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool: + def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]: raise NotImplementedError - def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool: + def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]: raise NotImplementedError - def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool: + def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]: raise NotImplementedError diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index 2c8b0d4..afd8e24 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -156,7 +156,7 @@ def train_supervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._supervised_trainer is None: raise ConfigurationException("The supervised trainer is not enabled") return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) @@ -170,7 +170,7 @@ def train_unsupervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._unsupervised_trainer is None: raise ConfigurationException("The unsupervised trainer is not enabled") return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py index 214414a..928c48a 100644 --- a/app/model_services/medcat_model.py +++ b/app/model_services/medcat_model.py @@ -119,7 +119,7 @@ def train_supervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._supervised_trainer is None: raise ConfigurationException("The supervised trainer is not enabled") return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) @@ -133,7 +133,7 @@ def train_unsupervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._unsupervised_trainer is None: raise ConfigurationException("The unsupervised trainer is not enabled") return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, 
synchronised, **hyperparams) @@ -147,7 +147,7 @@ def train_metacat(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._metacat_trainer is None: raise ConfigurationException("The metacat trainer is not enabled") return self._metacat_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py index deba5fe..43ed794 100644 --- a/app/model_services/medcat_model_deid.py +++ b/app/model_services/medcat_model_deid.py @@ -2,7 +2,7 @@ import inspect import threading import torch -from typing import Dict, List, TextIO, Optional, Any, final, Callable +from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable from functools import partial from transformers import pipeline from medcat.cat import CAT @@ -147,7 +147,7 @@ def train_supervised(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: if self._supervised_trainer is None: raise ConfigurationException("Trainers are not enabled") return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) diff --git a/app/trainers/base.py b/app/trainers/base.py index 2cc22fe..da4f795 100644 --- a/app/trainers/base.py +++ b/app/trainers/base.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import TextIO, Callable, Dict, Optional, Any, List, Union, final +from typing import TextIO, Callable, Dict, Tuple, Optional, Any, List, Union, final from config import Settings from management.tracker_client import TrackerClient from data import doc_dataset, anno_dataset @@ -26,6 +26,8 @@ def __init__(self, config: Settings, model_name: str) -> None: self._model_name = model_name self._training_lock = threading.Lock() self._training_in_progress = False + self._experiment_id = None + self._run_id = None self._tracker_client = TrackerClient(self._config.MLFLOW_TRACKING_URI) self._executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(max_workers=1) @@ -37,6 +39,14 @@ def model_name(self) -> str: def model_name(self, model_name: str) -> None: self._model_name = model_name + @property + def experiment_id(self) -> str: + return self._experiment_id or "" + + @property + def run_id(self) -> str: + return self._run_id or "" + @final def start_training(self, run: Callable, @@ -48,13 +58,13 @@ def start_training(self, input_file_name: str, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, - synchronised: bool = False) -> bool: + synchronised: bool = False) -> Tuple[bool, str, str]: with self._training_lock: if self._training_in_progress: - return False + return False, self.experiment_id, self.run_id else: loop = asyncio.get_event_loop() - experiment_id, run_id = self._tracker_client.start_tracking( + self._experiment_id, self._run_id = self._tracker_client.start_tracking( model_name=self._model_name, input_file_name=input_file_name, base_model_original=self._config.BASE_MODEL_FULL_PATH, @@ -101,15 +111,15 @@ def start_training(self, else: 
raise ValueError(f"Unknown training type: {training_type}") - logger.info("Starting training job: %s with experiment ID: %s", training_id, experiment_id) + logger.info("Starting training job: %s with experiment ID: %s", training_id, self.experiment_id) self._training_in_progress = True training_task = asyncio.ensure_future(loop.run_in_executor(self._executor, - partial(run, self, training_params, data_file, log_frequency, run_id, description))) + partial(run, self, training_params, data_file, log_frequency, self.run_id, description))) if synchronised: loop.run_until_complete(training_task) - return True + return True, self.experiment_id, self.run_id @staticmethod def _make_model_file_copy(model_file_path: str, run_id: str) -> str: @@ -161,7 +171,7 @@ def train(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: training_type = TrainingType.SUPERVISED.value training_params = { "data_path": data_file.name, @@ -204,7 +214,7 @@ def train(self, raw_data_files: Optional[List[TextIO]] = None, description: Optional[str] = None, synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]: training_type = TrainingType.UNSUPERVISED.value training_params = { "nepochs": epochs, diff --git a/tests/app/api/test_serving_common.py b/tests/app/api/test_serving_common.py index b0be036..91b156c 100644 --- a/tests/app/api/test_serving_common.py +++ b/tests/app/api/test_serving_common.py @@ -11,7 +11,7 @@ from utils import get_settings from model_services.medcat_model import MedCATModel from management.model_manager import ModelManager -from unittest.mock import create_autospec +from unittest.mock import create_autospec, patch config = get_settings() config.ENABLE_TRAINING_APIS = "true" @@ -258,13 +258,14 @@ def test_preview_trainer_export_on_missing_project_or_document(pid, did, client) def test_train_supervised(model_service, client): + model_service.train_supervised.return_value = (True, "experiment_id", "run_id") with open(TRAINER_EXPORT_PATH, "rb") as f: response = client.post("/train_supervised", files=[("trainer_export", f)]) model_service.train_supervised.assert_called() assert response.status_code == 202 assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -278,13 +279,14 @@ def test_train_supervised(model_service, client): def test_train_unsupervised(model_service, client): + model_service.train_unsupervised.return_value = (True, "experiment_id", "run_id") with tempfile.TemporaryFile("r+b") as f: f.write(str.encode("[\"Spinal stenosis\"]")) response = client.post("/train_unsupervised", files=[("training_data", f)]) model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." 
- assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID with tempfile.TemporaryFile("r+b") as f: @@ -305,12 +307,13 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): "model_card": None, }) model_service.info.return_value = model_card + model_service.train_unsupervised.return_value = (True, "experiment_id", "run_id") response = client.post("/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb") model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID response = client.post(f"/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb&tracking_id={TRACKING_ID}") @@ -322,13 +325,14 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): def test_train_metacat(model_service, client): + model_service.train_metacat.return_value = (True, "experiment_id", "run_id") with open(TRAINER_EXPORT_PATH, "rb") as f: response = client.post("/train_metacat", files=[("trainer_export", f)]) model_service.train_metacat.assert_called() assert response.status_code == 202 assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all(key in response.json() for key in ["training_id", "experiment_id", "run_id"]) # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -341,7 +345,8 @@ def test_train_metacat(model_service, client): assert response.json().get("training_id") == TRACKING_ID -def test_evaluate_with_trainer_export(client): +def test_evaluate_with_trainer_export(model_service, client): + model_service.train_supervised.return_value = (True, "experiment_id", "run_id") with open(TRAINER_EXPORT_PATH, "rb") as f: response = client.post("/evaluate", files=[("trainer_export", f)]) diff --git a/tests/app/api/test_serving_hf_ner.py b/tests/app/api/test_serving_hf_ner.py index c6ac825..c0f8dfa 100644 --- a/tests/app/api/test_serving_hf_ner.py +++ b/tests/app/api/test_serving_hf_ner.py @@ -44,9 +44,10 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): "model_card": None, }) model_service.info.return_value = model_card + model_service.train_unsupervised.return_value = (True, "experiment_id", "run_id") response = client.post("/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb") model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." - assert "training_id" in response.json() + assert all([key in response.json() for key in ["training_id", "experiment_id", "run_id"]])