diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py index 8b723648e..1b48afa77 100644 --- a/medcat-v2/medcat/cat.py +++ b/medcat-v2/medcat/cat.py @@ -532,8 +532,9 @@ def _get_entity(self, ent: MutableEntity, # addons: out_dict.update(self.get_addon_output(ent)) # type: ignore # other ontologies - if self.config.general.map_to_other_ontologies: - for ont in self.config.general.map_to_other_ontologies: + other_onts = self._set_and_get_mapped_ontologies() + if other_onts: + for ont in other_onts: if ont in out_dict: logger.warning( "Trying to map to ontology '%s', but it already " @@ -553,6 +554,28 @@ def _get_entity(self, ent: MutableEntity, out_dict[ont] = ont_values # type: ignore return out_dict + def _set_and_get_mapped_ontologies( + self, + ignore_set: set[str] = {"ontologies", "original_names", + "description", "group"}, + ignore_empty: bool = True) -> list[str]: + other_onts = self.config.general.map_to_other_ontologies + if other_onts == "auto": + self.config.general.map_to_other_ontologies = other_onts = [ + npkey + for key, val in self.cdb.addl_info.items() + if key.startswith("cui2") and + # ignore empty if required / expected + (not ignore_empty or val) and + # these are things that get auto-populated in addl_info + # but don't generally contain ontology mapping information + # directly + (npkey := key.removeprefix("cui2")) not in ignore_set + ] + logger.info( + "Automatically finding ontologies to map to: %s", other_onts) + return other_onts + def get_addon_output(self, ent: MutableEntity) -> dict[str, dict]: """Get the addon output for the entity. @@ -809,6 +832,8 @@ def load_model_pack(cls, model_pack_path: str, # will be dealt with upon pipeline creation automatically if not isinstance(cat, CAT): raise ValueError(f"Unable to load CAT. 
Got: {cat}") + # reset mapped ontologies at load time but after CDB load + cat._set_and_get_mapped_ontologies() return cat @classmethod diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index 5c8c4261d..8b75a1caa 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -1,6 +1,6 @@ import os from typing import (Optional, Iterator, Iterable, TypeVar, cast, Type, Any, - Literal) + Literal, Union) from typing import Protocol, runtime_checkable from typing_extensions import Self import logging @@ -252,13 +252,17 @@ class General(SerialisableBaseModel): map_cui_to_group: bool = False """If the cdb.addl_info['cui2group'] is provided and this option enabled, each CUI will be mapped to the group""" - map_to_other_ontologies: list[str] = ["opcs4", "icd10"] + map_to_other_ontologies: Union[Literal["auto"], list[str]] = "auto" """Which other ontologies to map to if possible. This will force medcat to include mapping for other ontologies in its outputs. It will use the mappings in `cdb.addl_info["cui2<ont>"]` if they are present. + If set to "auto" (or missing), the value will be inferred from available + data at first init time. That is to say, it'll map to all ontologies + available. + NB! This will only work if the `cdb.addl_info["cui2<ont>"]` exists. Otherwise, no mapping will be done.
diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py index 9ada8deae..ff518f21a 100644 --- a/medcat-v2/tests/test_cat.py +++ b/medcat-v2/tests/test_cat.py @@ -133,6 +133,78 @@ def test_can_merge_config(self): model.config.general.nlp.modelname, self.spacy_model_name) +class OntologiesMapTests(TrainedModelTests): + + def test_does_not_have_auto(self): + self.assertNotEqual(self.model.config.general.map_to_other_ontologies, + "auto") + + def test_is_empty(self): + self.assertFalse(self.model.config.general.map_to_other_ontologies) + + +class OntologiesMapWithOntologiesTests(TrainedModelTests): + MY_ONT_NAME = "My_Ontology" + EXP_GET = [MY_ONT_NAME] + MY_ONT_MAPPING = { + # mapping doesn't matter here, really + "ABC": "BBC" + } + + @classmethod + def reset_mappings(cls): + # set to auto + cls.model.config.general.map_to_other_ontologies = "auto" + # redo process + cls.model._set_and_get_mapped_ontologies() + + @classmethod + def setUpClass(cls): + super().setUpClass() + # add "mapping" + cls.model.cdb.addl_info[f"cui2{cls.MY_ONT_NAME}"] = cls.MY_ONT_MAPPING + cls.reset_mappings() + + def test_has_correct_results(self): + got = sorted(self.model.config.general.map_to_other_ontologies) + self.assertEqual(len(got), len(self.EXP_GET)) + self.assertEqual(got, self.EXP_GET) + + +class OntologiesMapWithOntologiesAndNoIgnoresTests( + OntologiesMapWithOntologiesTests): + EXTRA_ONTS = ["original_names"] + + @classmethod + def reset_mappings(cls): + # set to auto + cls.model.config.general.map_to_other_ontologies = "auto" + # redo process + cls.model._set_and_get_mapped_ontologies(ignore_set=set()) + + @classmethod + def setUpClass(cls): + super().setUpClass() + # I need to redefine for specific class + # instead of changing instance in base class + cls.EXP_GET = OntologiesMapWithOntologiesTests.EXP_GET.copy() + cls.EXP_GET.extend(cls.EXTRA_ONTS) + cls.EXP_GET.sort() + cls.reset_mappings() + + +class OntologiesMapWithOntologiesAndAllowEmpty( +
OntologiesMapWithOntologiesAndNoIgnoresTests): + EXTRA_ONTS = ["icd10", "opcs4"] + + @classmethod + def reset_mappings(cls): + # set to auto + cls.model.config.general.map_to_other_ontologies = "auto" + # redo process + cls.model._set_and_get_mapped_ontologies(ignore_empty=False) + + class InferenceFromLoadedTests(TrainedModelTests): def test_can_load_model(self):