diff --git a/medcat-service/medcat_service/config.py b/medcat-service/medcat_service/config.py index 9eeaa918b..0a9e9dec4 100644 --- a/medcat-service/medcat_service/config.py +++ b/medcat-service/medcat_service/config.py @@ -3,6 +3,15 @@ class Settings(BaseSettings): + class Config: + frozen = True + app_root_path: str = Field( default="/", description="The Root Path for the FastAPI App", examples=["/medcat-service"] ) + + deid_mode: bool = Field(default=False, description="Enable DEID mode") + deid_redact: bool = Field( + default=True, + description="Enable DEID redaction. Returns text like [***] instead of [ANNOTATION]", + ) diff --git a/medcat-service/medcat_service/dependencies.py b/medcat-service/medcat_service/dependencies.py index 0d0d29c66..f9cbab061 100644 --- a/medcat-service/medcat_service/dependencies.py +++ b/medcat-service/medcat_service/dependencies.py @@ -1,3 +1,4 @@ +import logging from functools import lru_cache from typing import Annotated @@ -6,15 +7,20 @@ from medcat_service.config import Settings from medcat_service.nlp_processor.medcat_processor import MedCatProcessor +log = logging.getLogger(__name__) + @lru_cache -def get_medcat_processor() -> MedCatProcessor: - return MedCatProcessor() +def get_settings() -> Settings: + settings = Settings() + log.debug("Using settings: %s", settings) + return settings @lru_cache -def get_settings() -> Settings: - return Settings() +def get_medcat_processor(settings: Annotated[Settings, Depends(get_settings)]) -> MedCatProcessor: + log.debug("Creating new Medcat Processsor using settings: %s", settings) + return MedCatProcessor(settings) MedCatProcessorDep = Annotated[MedCatProcessor, Depends(get_medcat_processor)] diff --git a/medcat-service/medcat_service/nlp_processor/medcat_processor.py b/medcat-service/medcat_service/nlp_processor/medcat_processor.py index e0a1aaec2..5b5b33b2f 100644 --- a/medcat-service/medcat_service/nlp_processor/medcat_processor.py +++ b/medcat-service/medcat_service/nlp_processor/medcat_processor.py @@ -15,16 +15,17 @@ from medcat.config.config_meta_cat import ConfigMetaCAT from medcat.vocab import Vocab +from medcat_service.config import Settings from medcat_service.types import HealthCheckResponse, ModelCardInfo, ProcessErrorsResult, ProcessResult, ServiceInfo -class MedCatProcessor(): +class MedCatProcessor: """" MedCAT Processor class is wrapper over MedCAT that implements annotations extractions functionality (both single and bulk processing) that can be easily exposed for an API. """ - def __init__(self): + def __init__(self, settings: Settings): app_log_level = os.getenv("APP_LOG_LEVEL", logging.INFO) medcat_log_level = os.getenv("LOG_LEVEL", logging.INFO) @@ -46,8 +47,8 @@ def __init__(self): self.bulk_nproc = int(os.getenv("APP_BULK_NPROC", 8)) self.torch_threads = int(os.getenv("APP_TORCH_THREADS", -1)) - self.DEID_MODE = eval(os.getenv("DEID_MODE", "False")) - self.DEID_REDACT = eval(os.getenv("DEID_REDACT", "True")) + self.DEID_MODE = settings.deid_mode + self.DEID_REDACT = settings.deid_redact self.model_card_info = ModelCardInfo( ontologies=None, meta_cat_model_names=[], model_last_modified_on=None) @@ -209,13 +210,13 @@ def process_content_bulk(self, content): start_time_ns = time.time_ns() try: + text_input = MedCatProcessor._generate_input_doc(content, invalid_doc_ids) if self.DEID_MODE: - # TODO 2025-07-21: deid_multi_texts doesnt exist in medcat 2? - ann_res = self.cat.deid_multi_texts(MedCatProcessor._generate_input_doc(content, invalid_doc_ids), - redact=self.DEID_REDACT) + text_to_deid_from_tuple = (x[1] for x in text_input) + + ann_res = self.cat.deid_multi_text(list(text_to_deid_from_tuple), + redact=self.DEID_REDACT, n_process=self.bulk_nproc) else: - text_input = MedCatProcessor._generate_input_doc( - content, invalid_doc_ids) ann_res = { ann_id: res for ann_id, res in self.cat.get_entities_multi_texts( @@ -426,9 +427,11 @@ def _generate_result(self, in_documents, annotations, elapsed_time): footer=in_ct.get("footer"), ) elif self.DEID_MODE: - out_res = ProcessResult( - text=str(in_ct["text"]), + # TODO: DEID mode is passing the resulting text in the annotations field here but shouldnt. + text=str(annotations[i]), + # TODO: DEID bulk mode should also be able to return the list of annotations found, + # to match the features of the singular api. CU-869a6wc6z annotations=[], success=True, timestamp=self._get_timestamp(), diff --git a/medcat-service/medcat_service/test/common.py b/medcat-service/medcat_service/test/common.py index 4937dbf2a..e72d9dcb5 100644 --- a/medcat-service/medcat_service/test/common.py +++ b/medcat-service/medcat_service/test/common.py @@ -85,3 +85,6 @@ def setup_medcat_processor(): os.environ["APP_BULK_NPROC"] = "8" os.environ["APP_TRAINING_MODE"] = "False" + + os.environ["DEID_MODE"] = "False" + os.environ["DEID_REDACT"] = "False" diff --git a/medcat-service/medcat_service/test/test_deid.py b/medcat-service/medcat_service/test/test_deid.py index 5c8e16453..24fe8063a 100644 --- a/medcat-service/medcat_service/test/test_deid.py +++ b/medcat-service/medcat_service/test/test_deid.py @@ -1,43 +1,39 @@ +import os import unittest from fastapi.testclient import TestClient import medcat_service.test.common as common +from medcat_service.config import Settings +from medcat_service.dependencies import get_settings +from medcat_service.main import app -class TestMedcatServiceDeId(unittest.TestCase): - """ - Implementation of test cases for MedCAT service - """ +def get_settings_override(): + return Settings(deid_mode=True, deid_redact=True) + - # Available endpoints - # +class TestMedcatServiceDeId(unittest.TestCase): ENDPOINT_PROCESS_SINGLE = "/api/process" ENDPOINT_PROCESS_BULK = "/api/process_bulk" client: TestClient - # Static initialization methods - # @classmethod def setUpClass(cls): - pass - # Enable when test enabled. Complexity around env vars being shared accross tests, - # Should instead move to use pydantic settings for easy test overrides. + common.setup_medcat_processor() - # common.setup_medcat_processor() - # os.environ["DEID_MODE"] = "True" - # os.environ["DEID_REDACT"] = "True" + if "APP_MEDCAT_MODEL_PACK" not in os.environ: + os.environ["APP_MEDCAT_MODEL_PACK"] = "./models/examples/example-deid-model-pack.zip" - # if "APP_MEDCAT_MODEL_PACK" not in os.environ: - # os.environ["APP_MEDCAT_MODEL_PACK"] = "./models/example-deid-model-pack.zip" + app.dependency_overrides[get_settings] = get_settings_override + cls.client = TestClient(app) - # cls.client = TestClient(app) - - @unittest.skip("Disabled until deid model is committed") - def testDeidProcess(self): + def test_deid_process_api(self): payload = common.create_payload_content_from_doc_single( "John had been diagnosed with acute Kidney Failure the week before" ) + app.dependency_overrides[get_settings] = get_settings_override + response = self.client.post(self.ENDPOINT_PROCESS_SINGLE, json=payload) self.assertEqual(response.status_code, 200) @@ -58,3 +54,37 @@ def testDeidProcess(self): self.assertEqual(ann["pretty_name"], expected["pretty_name"]) self.assertEqual(ann["source_value"], expected["source_value"]) self.assertEqual(ann["cui"], expected["cui"]) + app.dependency_overrides = {} + + def test_deid_process_bulk_api(self): + payload = common.create_payload_content_from_doc_bulk([ + "John had been diagnosed with acute Kidney Failure the week before" + ]) + app.dependency_overrides[get_settings] = get_settings_override + + response = self.client.post(self.ENDPOINT_PROCESS_BULK, json=payload) + self.assertEqual(response.status_code, 200) + + actual = response.json() + + expected = { + "pretty_name": "PATIENT", + "source_value": "John", + "cui": "PATIENT", + "text": "[****] had been diagnosed with acute Kidney Failure the week before", + } + self.assertEqual(len(actual["result"]), 1) + self.assertEqual(actual["result"][0]["text"], expected["text"]) + + self.assertEqual( + len(actual["result"][0]["annotations"]), + 0, + "CU-869a6wc6z No annotations are currently returned by the bulk API", + ) + + # Note: CU-869a6wc6z commended out these asserts until annations are returned + # ann = actual["result"][0]["annotations"][0]["0"] + # self.assertEqual(ann["pretty_name"], expected["pretty_name"]) + # self.assertEqual(ann["source_value"], expected["source_value"]) + # self.assertEqual(ann["cui"], expected["cui"]) + app.dependency_overrides = {} diff --git a/medcat-service/medcat_service/test/test_medcat_processor.py b/medcat-service/medcat_service/test/test_medcat_processor.py index 6a5c5c89e..f338755c9 100644 --- a/medcat-service/medcat_service/test/test_medcat_processor.py +++ b/medcat-service/medcat_service/test/test_medcat_processor.py @@ -1,5 +1,6 @@ import unittest +from medcat_service.config import Settings from medcat_service.nlp_processor import MedCatProcessor from medcat_service.test.common import setup_medcat_processor @@ -7,7 +8,7 @@ class TestMedCatProcessorReadiness(unittest.TestCase): def setUp(self): setup_medcat_processor() - self.processor = MedCatProcessor() + self.processor = MedCatProcessor(Settings()) def test_readiness_is_ok(self): result = self.processor._check_medcat_readiness() diff --git a/medcat-service/scripts/integration_test_functions.sh b/medcat-service/scripts/integration_test_functions.sh index 3482e0b59..b5c78ee8f 100644 --- a/medcat-service/scripts/integration_test_functions.sh +++ b/medcat-service/scripts/integration_test_functions.sh @@ -70,26 +70,37 @@ integration_test_medcat_service() { # Test /api/process_bulk - if [[ "$expected_annotation" == "PATIENT" ]]; then - echo "Skipping Process_bulk test for DeID Mode testing " - echo "Process_bulk in DeID mode appears to have a bug making it return the text without deid" - return 0 - fi - local api="http://${localhost_name}:${port}/api/process_bulk" local input_text="Patient J. Smith had been diagnosed with acute kidney failure the week before" local input_payload="{\"content\": [{\"text\":\"${input_text}\"}]}" - local expected_annotation="Kidney Failure" - + local expected_annotation=${3:-Kidney Failure} echo "Calling POST $api with payload '$input_payload'" local actual - actual=$(curl -s -X POST $api \ + # Capture both body and HTTP code + response=$(curl -s -w "\n%{http_code}" -X POST "$api" \ -H 'Content-Type: application/json' \ -d "$input_payload") - echo "Recieved result '$actual'" + # Split body and code + http_code=$(echo "$response" | tail -n1) + actual=$(echo "$response" | sed '$d') + + echo "HTTP status: $http_code" + echo "Response body: '$actual'" + + if [[ "$http_code" != "200" ]]; then + echo "ERROR: Expected HTTP 200, got $http_code" + echo -e "Actual response was:\n${actual}" + return 1 + fi + + if [[ "$expected_annotation" == "PATIENT" ]]; then + echo "CU-869a6wc6z Skipping Process_bulk annotation test for DeID Mode testing " + echo "Process_bulk in DeID mode has missing feature making it not return the annotations, just the deid text" + return 0 + fi local actual_annotation actual_annotation=$(echo "$actual" | jq -r '.result[0].annotations[0]["0"].pretty_name') diff --git a/medcat-v2/tests/utils/ner/test_deid.py b/medcat-v2/tests/utils/ner/test_deid.py index cc5a3e968..c776f229c 100644 --- a/medcat-v2/tests/utils/ner/test_deid.py +++ b/medcat-v2/tests/utils/ner/test_deid.py @@ -16,7 +16,6 @@ import shutil import unittest -# import timeout_decorator FILE_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -31,6 +30,8 @@ TEST_DATA = os.path.join(FILE_DIR, "..", "..", "resources", "deid_test_data.json") +USE_CACHE_TEST_MODEL = False # Choose True for local dev to skip train/test every run + cnf = Config() cnf.general.nlp.provider = 'spacy' @@ -85,6 +86,16 @@ def _create_model() -> deid.DeIdModel: def _train_model_once() -> tuple[tuple[Any, Any, Any], deid.DeIdModel]: + if USE_CACHE_TEST_MODEL: + print("Using cached model for local repeats instead of create/train each time") + model_path = "tmp/deid_test_model.zip" + if not os.path.exists(model_path): + model = _create_model() + retval = model.train(TRAIN_DATA) + model.cat.save_model_pack("tmp", pack_name="deid_test_model", make_archive=True) + model = deid.DeIdModel.load_model_pack(model_path) + return retval, model + model = _create_model() retval = model.train(TRAIN_DATA) # mpp = 'temp/deid_multiprocess/dumps/temp_model_save' @@ -166,26 +177,21 @@ class DeIDModelWorks(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.deid_model = train_model_once()[1] + # import torch + # torch.set_num_threads(1) def tearDown(self): if os.path.exists(self.save_folder): shutil.rmtree(self.save_folder) - def test_model_works_deid_text(self): - anon_text = self.deid_model.deid_text(input_text) + def assert_deid_annotations(self, anon_text: str): self.assertIn("[DOCTOR]", anon_text) self.assertNotIn("M. Sully", anon_text) self.assertIn("[HOSPITAL]", anon_text) # self.assertNotIn("Dublin", anon_text) self.assertNotIn("7 Eccles Street", anon_text) - def test_model_works_dunder_call(self): - anon_doc = self.deid_model(input_text) - self.assertIsInstance(anon_doc, runtime_checkable(MutableDocument)) - self.assertTrue(anon_doc.ner_ents) - - def test_model_works_deid_text_redact(self): - anon_text = self.deid_model.deid_text(input_text, redact=True) + def assert_deid_redact(self, anon_text: str): self.assertIn("****", anon_text) self.assertNotIn("[DOCTOR]", anon_text) self.assertNotIn("M. Sully", anon_text) @@ -193,58 +199,59 @@ def test_model_works_deid_text_redact(self): # self.assertNotIn("Dublin", anon_text) self.assertNotIn("7 Eccles Street", anon_text) + def test_model_works_deid_text(self): + anon_text = self.deid_model.deid_text(input_text) + self.assert_deid_annotations(anon_text) + + def test_model_works_dunder_call(self): + anon_doc = self.deid_model(input_text) + self.assertIsInstance(anon_doc, runtime_checkable(MutableDocument)) + self.assertTrue(anon_doc.ner_ents) -# class DeIDModelMultiprocessingWorks(unittest.TestCase): -# processes = 2 - -# @classmethod -# def setUpClass(cls) -> None: -# Span.set_extension('link_candidates', default=None, force=True) -# _add_model(cls) -# cls.deid_model = train_model_once(cls.deid_model)[1] -# with open(TEST_DATA) as f: -# raw_data = json.load(f) -# cls.data = [] -# for project in raw_data['projects']: -# for doc in project['documents']: -# cls.data.append( -# (f"{project['name']}_{doc['name']}", doc['text'])) -# # NOTE: Comment and subsequent code -# # copied from CAT.multiprocessing_batch_char_size -# # (lines 1234 - 1237) -# # Hack for torch using multithreading, which is not good if not -# # separate_nn_components, need for CPU runs only -# import torch -# torch.set_num_threads(1) - -# def assertTextHasBeenDeIded(self, text: str, redacted: bool): -# if not redacted: -# for cui in self.deid_model.cdb.cui2names: -# cui_name = self.deid_model.cdb.get_name(cui) -# if cui_name in text: -# # all good -# return -# else: -# # if redacted, only check once... -# if "******" in text: -# # all good -# return -# raise AssertionError("None of the CUIs found") - - # # @timeout_decorator.timeout(3 * 60) # 3 minutes max - # def test_model_can_multiprocess_no_redact(self): - # processed = self.deid_model.deid_multi_texts( - # self.data, n_process=self.processes) - # self.assertEqual(len(processed), 5) - # for tid, new_text in enumerate(processed): - # with self.subTest(str(tid)): - # self.assertTextHasBeenDeIded(new_text, redacted=False) - - # # @timeout_decorator.timeout(3 * 60) # 3 minutes max - # def test_model_can_multiprocess_redact(self): - # processed = self.deid_model.deid_multi_texts( - # self.data, n_process=self.processes, redact=True) - # self.assertEqual(len(processed), 5) - # for tid, new_text in enumerate(processed): - # with self.subTest(str(tid)): - # self.assertTextHasBeenDeIded(new_text, redacted=True) + def test_model_works_deid_text_redact(self): + anon_text = self.deid_model.deid_text(input_text, redact=True) + self.assert_deid_redact(anon_text) + + def test_model_works_deid_multi_text_single_threaded(self): + processed = self.deid_model.deid_multi_text([input_text, input_text], n_process=1) + self.assertEqual(len(processed), 2) + for anon_text in processed: + self.assert_deid_annotations(anon_text) + + def test_model_works_deid_multi_text_single_threaded_redact(self): + processed = self.deid_model.deid_multi_text([input_text, input_text], + n_process=1, redact=True) + self.assertEqual(len(processed), 2) + for anon_text in processed: + self.assert_deid_redact(anon_text) + + # @timeout_decorator.timeout(3 * 60) # 3 minutes max + @unittest.skip("Deid Multiprocess is broken. Exits the process, no errors shown") + def test_model_can_multiprocess_no_redact(self): + + processed = self.deid_model.deid_multi_text( + [input_text, input_text], n_process=2) + self.assertEqual(len(processed), 2) + for tid, new_text in enumerate(processed): + with self.subTest(str(tid)): + self.assert_deid_annotations(new_text) + + # @timeout_decorator.timeout(3 * 60) # 3 minutes max + @unittest.skip("Deid Multiprocess is broken. Exits the process, no errors shown.") + def test_model_can_multiprocess_redact(self): + """ + deid_multi_text is broken for n_process >1. + Issue: Running this just exits the whole process, no errors or exceptions shown + """ + try: + print("Calling test_model_can_multiprocess_redact") + processed = self.deid_model.deid_multi_text( + [input_text, input_text], n_process=2, redact=True + ) + print("Finished processing") + self.assertEqual(len(processed), 5) + for tid, new_text in enumerate(processed): + with self.subTest(str(tid)): + self.assert_deid_redact(new_text) + except Exception as e: + self.fail(f"Multiprocessing redact test raised an exception: {e}")