From ee9f369421471ae3a39c0ebe40539132a0de23cf Mon Sep 17 00:00:00 2001 From: Phoevos Kalemkeris Date: Tue, 17 Dec 2024 15:20:25 +0000 Subject: [PATCH 1/3] tests: Ensure tracking ID is included in responses Signed-off-by: Phoevos Kalemkeris --- tests/app/api/test_serving_common.py | 133 +++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/tests/app/api/test_serving_common.py b/tests/app/api/test_serving_common.py index d05e205..b0be036 100644 --- a/tests/app/api/test_serving_common.py +++ b/tests/app/api/test_serving_common.py @@ -20,6 +20,7 @@ config.ENABLE_PREVIEWS_APIS = "true" config.AUTH_USER_ENABLED = "true" +TRACKING_ID = "123e4567-e89b-12d3-a456-426614174000" TRAINER_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json") NOTE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "note.txt") ANOTHER_TRAINER_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "another_trainer_export.json") @@ -196,6 +197,19 @@ def test_preview_trainer_export(client): assert response.headers["Content-Type"] == "application/octet-stream" assert len(response.text.split("
")) == 4 + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f1: + with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: + response = client.post(f"/preview_trainer_export?tracking_id={TRACKING_ID}", files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ]) + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/octet-stream" + assert len(response.text.split("
")) == 4 + assert TRACKING_ID in response.headers["Content-Disposition"] + def test_preview_trainer_export_str(client): with open(TRAINER_EXPORT_PATH, "r") as f: @@ -252,6 +266,16 @@ def test_train_supervised(model_service, client): assert response.json()["message"] == "Your training started successfully." assert "training_id" in response.json() + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f: + response = client.post(f"/train_supervised?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + + model_service.train_supervised.assert_called() + assert response.status_code == 202 + assert response.json()["message"] == "Your training started successfully." + assert "training_id" in response.json() + assert response.json().get("training_id") == TRACKING_ID + def test_train_unsupervised(model_service, client): with tempfile.TemporaryFile("r+b") as f: @@ -262,6 +286,16 @@ def test_train_unsupervised(model_service, client): assert response.json()["message"] == "Your training started successfully." assert "training_id" in response.json() + # test with provided tracking ID + with tempfile.TemporaryFile("r+b") as f: + f.write(str.encode("[\"Spinal stenosis\"]")) + response = client.post(f"/train_unsupervised?tracking_id={TRACKING_ID}", files=[("training_data", f)]) + + model_service.train_unsupervised.assert_called() + assert response.json()["message"] == "Your training started successfully." + assert "training_id" in response.json() + assert response.json().get("training_id") == TRACKING_ID + def test_train_unsupervised_with_hf_hub_dataset(model_service, client): model_card = ModelCard.parse_obj({ @@ -278,6 +312,14 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): assert response.json()["message"] == "Your training started successfully." assert "training_id" in response.json() + # test with provided tracking ID + response = client.post(f"/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb&tracking_id={TRACKING_ID}") + + model_service.train_unsupervised.assert_called() + assert response.json()["message"] == "Your training started successfully." + assert "training_id" in response.json() + assert response.json().get("training_id") == TRACKING_ID + def test_train_metacat(model_service, client): with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -288,6 +330,16 @@ def test_train_metacat(model_service, client): assert response.json()["message"] == "Your training started successfully." assert "training_id" in response.json() + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f: + response = client.post(f"/train_metacat?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + + model_service.train_metacat.assert_called() + assert response.status_code == 202 + assert response.json()["message"] == "Your training started successfully." + assert "training_id" in response.json() + assert response.json().get("training_id") == TRACKING_ID + def test_evaluate_with_trainer_export(client): with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -297,6 +349,15 @@ def test_evaluate_with_trainer_export(client): assert response.json()["message"] == "Your evaluation started successfully." assert "evaluation_id" in response.json() + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f: + response = client.post(f"/evaluate?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + + assert response.status_code == 202 + assert response.json()["message"] == "Your evaluation started successfully." + assert "evaluation_id" in response.json() + assert response.json().get("evaluation_id") == TRACKING_ID + def test_sanity_check_with_trainer_export(client): with open(TRAINER_EXPORT_PATH, "rb") as f: @@ -306,6 +367,15 @@ def test_sanity_check_with_trainer_export(client): assert response.headers["Content-Type"] == "text/csv; charset=utf-8" assert response.text.split("\n")[0] == "concept,name,precision,recall,f1" + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f: + response = client.post(f"/sanity-check?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/csv; charset=utf-8" + assert response.text.split("\n")[0] == "concept,name,precision,recall,f1" + assert TRACKING_ID in response.headers["Content-Disposition"] + def test_inter_annotator_agreement_scores_per_concept(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: @@ -319,6 +389,19 @@ def test_inter_annotator_agreement_scores_per_concept(client): assert response.headers["Content-Type"] == "text/csv; charset=utf-8" assert response.text.split("\n")[0] == "concept,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f1: + with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: + response = client.post(f"/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_concept&tracking_id={TRACKING_ID}", files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ]) + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/csv; charset=utf-8" + assert response.text.split("\n")[0] == "concept,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + assert TRACKING_ID in response.headers["Content-Disposition"] + @pytest.mark.parametrize("pid_a,pid_b,error_message", [(0, 2, "Cannot find the project with ID: 0"), (1, 3, "Cannot find the project with ID: 3")]) def test_project_not_found_on_getting_iaa_scores(pid_a, pid_b, error_message, client): @@ -381,6 +464,19 @@ def test_concat_trainer_exports(client): assert response.headers["Content-Type"] == "application/json; charset=utf-8" assert len(response.text) == 36918 + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f1: + with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: + response = client.post(f"/concat_trainer_exports?tracking_id={TRACKING_ID}", files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ]) + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/json; charset=utf-8" + assert len(response.text) == 36918 + assert TRACKING_ID in response.headers["Content-Disposition"] + def test_get_annotation_stats(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: @@ -394,6 +490,19 @@ def test_get_annotation_stats(client): assert response.headers["Content-Type"] == "text/csv; charset=utf-8" assert response.text.split("\n")[0] == "concept,anno_count,anno_unique_counts,anno_ignorance_counts" + # test with provided tracking ID + with open(TRAINER_EXPORT_PATH, "rb") as f1: + with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: + response = client.post(f"/annotation-stats?tracking_id={TRACKING_ID}", files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ]) + + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/csv; charset=utf-8" + assert response.text.split("\n")[0] == "concept,anno_count,anno_unique_counts,anno_ignorance_counts" + assert TRACKING_ID in response.headers["Content-Disposition"] + def test_extract_entities_from_text_list_file_as_json_file(model_service, client): annotations_list = [ @@ -435,3 +544,27 @@ def test_extract_entities_from_text_list_file_as_json_file(model_service, client }, }] }] * 15 + + # test with provided tracking ID + with open(MULTI_TEXTS_FILE_PATH, "rb") as f: + response = client.post(f"/process_bulk_file?tracking_id={TRACKING_ID}", files=[("multi_text_file", f)]) + + assert isinstance(response, httpx.Response) + assert json.loads(response.content) == [{ + "text": "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", + "annotations": [{ + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": { + "value": "Affirmed", + "confidence": 0.9999833106994629, + "name": "Status" + } + }, + }] + }] * 15 + assert TRACKING_ID in response.headers["Content-Disposition"] From a8b8a21600370c78ca201e2cf38c836eeab3da46 Mon Sep 17 00:00:00 2001 From: Phoevos Kalemkeris Date: Tue, 17 Dec 2024 15:34:20 +0000 Subject: [PATCH 2/3] api: Add tracking_id query param to API endpoints Signed-off-by: Phoevos Kalemkeris --- app/api/routers/evaluation.py | 27 ++++++++++++++++-------- app/api/routers/invocation.py | 4 +++- app/api/routers/metacat_training.py | 3 ++- app/api/routers/preview.py | 10 ++++++--- app/api/routers/supervised_training.py | 3 ++- app/api/routers/unsupervised_training.py | 6 ++++-- 6 files changed, 36 insertions(+), 17 deletions(-) diff --git a/app/api/routers/evaluation.py b/app/api/routers/evaluation.py index 6a39e91..c07069a 100644 --- a/app/api/routers/evaluation.py +++ b/app/api/routers/evaluation.py @@ -4,7 +4,7 @@ import uuid import tempfile -from typing import List +from typing import List, Union from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE from typing_extensions import Annotated from fastapi import APIRouter, Query, Depends, UploadFile, Request, File @@ -34,6 +34,7 @@ description="Evaluate the model being served with a trainer export") async def get_evaluation_with_trainer_export(request: Request, trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the evaluation task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: files = [] file_names = [] @@ -54,7 +55,7 @@ async def get_evaluation_with_trainer_export(request: Request, json.dump(concatenated, data_file) data_file.flush() data_file.seek(0) - evaluation_id = str(uuid.uuid4()) + evaluation_id = tracking_id or str(uuid.uuid4()) evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names)) if evaluation_accepted: return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED) @@ -69,6 +70,7 @@ async def get_evaluation_with_trainer_export(request: Request, description="Sanity check the model being served with a trainer export") def get_sanity_check_with_trainer_export(request: Request, trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the sanity check task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: files = [] file_names = [] @@ -88,8 +90,9 @@ def get_sanity_check_with_trainer_export(request: Request, metrics = sanity_check_model_with_trainer_export(concatenated, model_service, return_df=True, include_anchors=False) stream = io.StringIO() metrics.to_csv(stream, index=False) + tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv") - response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{str(uuid.uuid4())}.csv"' + response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{tracking_id}.csv"' return response @@ -102,7 +105,8 @@ def get_inter_annotator_agreement_scores(request: Request, trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")], annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")], annotator_b_project_id: Annotated[int, Query(description="The project ID from another annotator")], - scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")]) -> StreamingResponse: + scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the IAA task")] = None) -> StreamingResponse: files = [] for te in trainer_export: temp_te = tempfile.NamedTemporaryFile() @@ -126,8 +130,9 @@ def get_inter_annotator_agreement_scores(request: Request, raise AnnotationException(f'Unknown scope: "{scope}"') stream = io.StringIO() iaa_scores.to_csv(stream, index=False) + tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv") - response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{str(uuid.uuid4())}.csv"' + response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{tracking_id}.csv"' return response @@ -137,7 +142,8 @@ def get_inter_annotator_agreement_scores(request: Request, dependencies=[Depends(cms_globals.props.current_active_user)], description="Concatenate multiple trainer export files into a single file for download") def get_concatenated_trainer_exports(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")]) -> JSONResponse: + trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the concatenation task")] = None) -> JSONResponse: files = [] for te in trainer_export: temp_te = tempfile.NamedTemporaryFile() @@ -148,8 +154,9 @@ def get_concatenated_trainer_exports(request: Request, concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) for file in files: file.close() + tracking_id = tracking_id or str(uuid.uuid4()) response = JSONResponse(concatenated, media_type="application/json; charset=utf-8") - response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"' + response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"' return response @@ -159,7 +166,8 @@ def get_concatenated_trainer_exports(request: Request, dependencies=[Depends(cms_globals.props.current_active_user)], description="Get annotation stats of trainer export files") def get_annotation_stats(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")]) -> StreamingResponse: + trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the annotation stats task")] = None) -> StreamingResponse: files = [] file_names = [] for te in trainer_export: @@ -177,6 +185,7 @@ def get_annotation_stats(request: Request, stats = get_stats_from_trainer_export(concatenated, return_df=True) stream = io.StringIO() stats.to_csv(stream, index=False) + tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv") - response.headers["Content-Disposition"] = f'attachment ; filename="stats_{str(uuid.uuid4())}.csv"' + response.headers["Content-Disposition"] = f'attachment ; filename="stats_{tracking_id}.csv"' return response diff --git a/app/api/routers/invocation.py b/app/api/routers/invocation.py index dee5a8b..f681355 100644 --- a/app/api/routers/invocation.py +++ b/app/api/routers/invocation.py @@ -132,6 +132,7 @@ def get_entities_from_multiple_texts(request: Request, description="Upload a file containing a list of plain text and extract the NER entities in JSON") def extract_entities_from_multi_text_file(request: Request, multi_text_file: Annotated[UploadFile, File(description="A file containing a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the bulk processing task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: with tempfile.NamedTemporaryFile() as data_file: for line in multi_text_file.file: @@ -160,8 +161,9 @@ def extract_entities_from_multi_text_file(request: Request, output = json.dumps(body) logger.debug(output) json_file = BytesIO(output.encode()) + tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(json_file, media_type="application/json") - response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"' + response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"' return response diff --git a/app/api/routers/metacat_training.py b/app/api/routers/metacat_training.py index d8a80c6..7d4a971 100644 --- a/app/api/routers/metacat_training.py +++ b/app/api/routers/metacat_training.py @@ -29,6 +29,7 @@ async def train_metacat(request: Request, epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1, description: Annotated[Union[str, None], Query(description="The description on the training or change logs")] = None, + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: files = [] file_names = [] @@ -49,7 +50,7 @@ async def train_metacat(request: Request, json.dump(concatenated, data_file) data_file.flush() data_file.seek(0) - training_id = str(uuid.uuid4()) + training_id = tracking_id or str(uuid.uuid4()) try: training_accepted = model_service.train_metacat(data_file, epochs, diff --git a/app/api/routers/preview.py b/app/api/routers/preview.py index 831f0f1..f5c36c3 100644 --- a/app/api/routers/preview.py +++ b/app/api/routers/preview.py @@ -27,14 +27,16 @@ description="Extract the NER entities in HTML for preview") async def get_rendered_entities_from_text(request: Request, text: Annotated[str, Body(description="The text to be sent to the model for NER", media_type="text/plain")], + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the preview task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: annotations = model_service.annotate(text) entities = annotations_to_entities(annotations, model_service.model_name) logger.debug("Entities extracted for previewing %s", entities) ent_input = Doc(text=text, ents=entities) data = displacy.render(ent_input.dict(), style="ent", manual=True) + tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(BytesIO(data.encode()), media_type="application/octet-stream") - response.headers["Content-Disposition"] = f'attachment ; filename="preview_{str(uuid.uuid4())}.html"' + response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"' return response @@ -47,7 +49,8 @@ def get_rendered_entities_from_trainer_export(request: Request, trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")] = [], trainer_export_str: Annotated[str, Form(description="The trainer export raw JSON string")] = "{\"projects\": []}", project_id: Annotated[Union[int, None], Query(description="The target project ID, and if not provided, all projects will be included")] = None, - document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None) -> Response: + document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None, + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the trainer export preview task")] = None) -> Response: data: Dict = {"projects": []} if trainer_export is not None: files = [] @@ -88,8 +91,9 @@ def get_rendered_entities_from_trainer_export(request: Request, doc = Doc(text=document["text"], ents=entities, title=f"P{project['id']}/D{document['id']}") htmls.append(displacy.render(doc.dict(), style="ent", manual=True)) if htmls: + tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(BytesIO("
".join(htmls).encode()), media_type="application/octet-stream") - response.headers["Content-Disposition"] = f'attachment ; filename="preview_{str(uuid.uuid4())}.html"' + response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"' else: logger.debug("Cannot find any matching documents to preview") return JSONResponse(content={"message": "Cannot find any matching documents to preview"}, status_code=HTTP_404_NOT_FOUND) diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index 9d24880..837fde7 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -32,6 +32,7 @@ async def train_supervised(request: Request, test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1, description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None, + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: files = [] file_names = [] @@ -51,7 +52,7 @@ async def train_supervised(request: Request, json.dump(concatenated, data_file) data_file.flush() data_file.seek(0) - training_id = str(uuid.uuid4()) + training_id = tracking_id or str(uuid.uuid4()) try: training_accepted = model_service.train_supervised(data_file, epochs, diff --git a/app/api/routers/unsupervised_training.py b/app/api/routers/unsupervised_training.py index a88f18a..1680000 100644 --- a/app/api/routers/unsupervised_training.py +++ b/app/api/routers/unsupervised_training.py @@ -33,6 +33,7 @@ async def train_unsupervised(request: Request, test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage", ge=0.0)] = 0.2, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000, description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None, + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: """ Upload one or more plain text files and trigger the unsupervised training @@ -61,7 +62,7 @@ async def train_unsupervised(request: Request, logger.debug("Training data concatenated") data_file.flush() data_file.seek(0) - training_id = str(uuid.uuid4()) + training_id = tracking_id or str(uuid.uuid4()) try: training_accepted = model_service.train_unsupervised(data_file, epochs, @@ -96,6 +97,7 @@ async def train_unsupervised_with_hf_dataset(request: Request, test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage will only take effect if the dataset does not have predefined validation or test splits", ge=0.0)] = 0.2, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000, description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None, + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: """ Trigger the unsupervised training with a dataset from Hugging Face Hub @@ -129,7 +131,7 @@ async def train_unsupervised_with_hf_dataset(request: Request, logger.debug("Training dataset downloaded and transformed") hf_dataset.save_to_disk(data_dir.name) - training_id = str(uuid.uuid4()) + training_id = tracking_id or str(uuid.uuid4()) training_accepted = model_service.train_unsupervised(data_dir, epochs, log_frequency, From d9fde4fb7383e0372e55eb743293addbcc7e8ecd Mon Sep 17 00:00:00 2001 From: Phoevos Kalemkeris Date: Wed, 18 Dec 2024 14:25:39 +0000 Subject: [PATCH 3/3] api: Add tracking ID validation Validate the tracking ID in the API endpoints that require it, ensuring it's an alphanumeric string of length 1-256. The implementation and tests are based on MLflow's internal run ID validation: https://github.com/mlflow/mlflow/blob/92a1664ddbd7ef59f8db45e988e41437d179c3b1/mlflow/utils/validation.py#L374-L377 Signed-off-by: Phoevos Kalemkeris --- app/api/dependencies.py | 19 +++++++++++++++ app/api/routers/evaluation.py | 11 +++++---- app/api/routers/invocation.py | 3 ++- app/api/routers/metacat_training.py | 3 ++- app/api/routers/preview.py | 5 ++-- app/api/routers/supervised_training.py | 3 ++- app/api/routers/unsupervised_training.py | 5 ++-- tests/app/api/test_dependencies.py | 30 +++++++++++++++++++++++- 8 files changed, 66 insertions(+), 13 deletions(-) diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 176a478..1fc1082 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -1,4 +1,10 @@ import logging +import re +from typing import Union +from typing_extensions import Annotated + +from fastapi import HTTPException, Query +from starlette.status import HTTP_400_BAD_REQUEST from typing import Optional from config import Settings @@ -6,6 +12,8 @@ from model_services.base import AbstractModelService from management.model_manager import ModelManager +TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$") + logger = logging.getLogger("cms") @@ -45,3 +53,14 @@ def __init__(self, model_service: AbstractModelService) -> None: def __call__(self) -> ModelManager: return self._model_manager + + +def validate_tracking_id( + tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None, +) -> Union[str, None]: + if tracking_id is not None and TRACKING_ID_REGEX.match(tracking_id) is None: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256", + ) + return tracking_id diff --git a/app/api/routers/evaluation.py b/app/api/routers/evaluation.py index c07069a..97ff305 100644 --- a/app/api/routers/evaluation.py +++ b/app/api/routers/evaluation.py @@ -11,6 +11,7 @@ from fastapi.responses import StreamingResponse, JSONResponse import api.globals as cms_globals +from api.dependencies import validate_tracking_id from domain import Tags, Scope from model_services.base import AbstractModelService from processors.metrics_collector import ( @@ -34,7 +35,7 @@ description="Evaluate the model being served with a trainer export") async def get_evaluation_with_trainer_export(request: Request, trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the evaluation task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: files = [] file_names = [] @@ -70,7 +71,7 @@ async def get_evaluation_with_trainer_export(request: Request, description="Sanity check the model being served with a trainer export") def get_sanity_check_with_trainer_export(request: Request, trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the sanity check task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: files = [] file_names = [] @@ -106,7 +107,7 @@ def get_inter_annotator_agreement_scores(request: Request, annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")], annotator_b_project_id: Annotated[int, Query(description="The project ID from another annotator")], scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the IAA task")] = None) -> StreamingResponse: + tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse: files = [] for te in trainer_export: temp_te = tempfile.NamedTemporaryFile() @@ -143,7 +144,7 @@ def get_inter_annotator_agreement_scores(request: Request, description="Concatenate multiple trainer export files into a single file for download") def get_concatenated_trainer_exports(request: Request, trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the concatenation task")] = None) -> JSONResponse: + tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> JSONResponse: files = [] for te in trainer_export: temp_te = tempfile.NamedTemporaryFile() @@ -167,7 +168,7 @@ def get_concatenated_trainer_exports(request: Request, description="Get annotation stats of trainer export files") def get_annotation_stats(request: Request, trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the annotation stats task")] = None) -> StreamingResponse: + tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse: files = [] file_names = [] for te in trainer_export: diff --git a/app/api/routers/invocation.py b/app/api/routers/invocation.py index f681355..0a8114e 100644 --- a/app/api/routers/invocation.py +++ b/app/api/routers/invocation.py @@ -20,6 +20,7 @@ from domain import TextWithAnnotations, TextWithPublicKey, TextStreamItem, ModelCard, Tags from model_services.base import AbstractModelService from utils import get_settings +from api.dependencies import validate_tracking_id from api.utils import get_rate_limiter, encrypt from management.prometheus_metrics import ( cms_doc_annotations, @@ -132,7 +133,7 @@ def get_entities_from_multiple_texts(request: Request, description="Upload a file containing a list of plain text and extract the NER entities in JSON") def extract_entities_from_multi_text_file(request: Request, multi_text_file: Annotated[UploadFile, File(description="A file containing a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the bulk processing task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: with tempfile.NamedTemporaryFile() as data_file: for line in multi_text_file.file: diff --git a/app/api/routers/metacat_training.py b/app/api/routers/metacat_training.py index 7d4a971..1e16e20 100644 --- a/app/api/routers/metacat_training.py +++ b/app/api/routers/metacat_training.py @@ -10,6 +10,7 @@ from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE import api.globals as cms_globals +from api.dependencies import validate_tracking_id from domain import Tags from model_services.base import AbstractModelService from processors.metrics_collector import concat_trainer_exports @@ -29,7 +30,7 @@ async def train_metacat(request: Request, epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1, description: Annotated[Union[str, None], Query(description="The description on the training or change logs")] = None, - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: files = [] file_names = [] diff --git a/app/api/routers/preview.py b/app/api/routers/preview.py index f5c36c3..2def28b 100644 --- a/app/api/routers/preview.py +++ b/app/api/routers/preview.py @@ -11,6 +11,7 @@ from starlette.status import HTTP_404_NOT_FOUND import api.globals as cms_globals +from api.dependencies import validate_tracking_id from domain import Doc, Tags from model_services.base import AbstractModelService from processors.metrics_collector import concat_trainer_exports @@ -27,7 +28,7 @@ description="Extract the NER entities in HTML for preview") async def get_rendered_entities_from_text(request: Request, text: Annotated[str, Body(description="The text to be sent to the model for NER", media_type="text/plain")], - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the preview task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: annotations = model_service.annotate(text) entities = annotations_to_entities(annotations, model_service.model_name) @@ -50,7 +51,7 @@ def get_rendered_entities_from_trainer_export(request: Request, trainer_export_str: Annotated[str, Form(description="The trainer export raw JSON string")] = "{\"projects\": []}", project_id: Annotated[Union[int, None], Query(description="The target project ID, and if not provided, all projects will be included")] = None, document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None, - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the trainer export preview task")] = None) -> Response: + tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> Response: data: Dict = {"projects": []} if trainer_export is not None: files = [] diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index 837fde7..fd66443 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -10,6 +10,7 @@ from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE import api.globals as cms_globals +from api.dependencies import validate_tracking_id from domain import Tags from model_services.base import AbstractModelService from processors.metrics_collector import concat_trainer_exports @@ -32,7 +33,7 @@ async def train_supervised(request: Request, test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1, description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None, - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: files = [] file_names = [] diff --git a/app/api/routers/unsupervised_training.py b/app/api/routers/unsupervised_training.py index 1680000..c3925aa 100644 --- a/app/api/routers/unsupervised_training.py +++ b/app/api/routers/unsupervised_training.py @@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE import api.globals as cms_globals +from api.dependencies import validate_tracking_id from domain import Tags, ModelType from model_services.base import AbstractModelService from utils import get_settings @@ -33,7 +34,7 @@ async def train_unsupervised(request: Request, test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage", ge=0.0)] = 0.2, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000, description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None, - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: """ Upload one or more plain text files and trigger the unsupervised training @@ -97,7 +98,7 @@ async def train_unsupervised_with_hf_dataset(request: Request, test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage will only take effect if the dataset does not have predefined validation or test splits", ge=0.0)] = 0.2, log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000, description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None, - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the training task")] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: """ Trigger the unsupervised training with a dataset from Hugging Face Hub diff --git a/tests/app/api/test_dependencies.py b/tests/app/api/test_dependencies.py index 883eff5..87be16d 100644 --- a/tests/app/api/test_dependencies.py +++ b/tests/app/api/test_dependencies.py @@ -1,4 +1,7 @@ -from api.dependencies import ModelServiceDep +import pytest +from fastapi import HTTPException + +from api.dependencies import ModelServiceDep, validate_tracking_id from config import Settings from model_services.medcat_model import MedCATModel from model_services.medcat_model_icd10 import MedCATModelIcd10 @@ -36,3 +39,28 @@ def test_transformer_deid_dep(): def test_huggingface_ner_dep(): model_service_dep = ModelServiceDep("huggingface_ner", Settings()) assert isinstance(model_service_dep(), HuggingFaceNerModel) + + +@pytest.mark.parametrize( + "run_id", + [ + "a" * 32, + "A" * 32, + "a" * 256, + "f0" * 16, + "abcdef0123456789" * 2, + "abcdefghijklmnopqrstuvqxyz", + "123e4567-e89b-12d3-a456-426614174000", + "123e4567e89b12d3a45642661417400", + ], +) +def test_validate_tracking_id(run_id): + assert validate_tracking_id(run_id) == run_id + + +@pytest.mark.parametrize("run_id", ["a/bc" * 8, "", "a" * 400, "*" * 5]) +def test_validate_tracking_id_invalid(run_id): + with pytest.raises(HTTPException) as exc_info: + validate_tracking_id(run_id) + assert exc_info.value.status_code == 400 + assert "Invalid tracking ID" in exc_info.value.detail