diff --git a/medcat-v2/medcat/cdb/cdb.py b/medcat-v2/medcat/cdb/cdb.py index 82af709c2..d1592a2ff 100644 --- a/medcat-v2/medcat/cdb/cdb.py +++ b/medcat-v2/medcat/cdb/cdb.py @@ -521,7 +521,25 @@ def save(self, save_path: str, serialise(serialiser, self, save_path, overwrite=overwrite) @classmethod - def load(cls, path: str) -> 'CDB': + def load(cls, path: str, perform_fixes: bool = True) -> 'CDB': + """Load the CDB off disk. + + This can load a legacy (v1) CDB (.dat) or a v2 CDB either in its folder + format or the .zip format. The distinction is made automatically. + + Args: + path (str): The path to the CDB. + perform_fixes (bool): Whether to perform fixes such as + original names issue. Defaults to True. + + Raises: + LegacyConversionDisabledError: + If when a legacy model is found and conversion is not allowed. + ValueError: If the loaded object isn't a CDB. + + Returns: + CDB: The loaded CDB. + """ if should_serialise_as_zip(path, 'auto'): cdb = deserialise_from_zip(path) elif os.path.isfile(path) and path.endswith('.dat'): @@ -535,4 +553,9 @@ def load(cls, path: str) -> 'CDB': cdb = deserialise(path) if not isinstance(cdb, CDB): raise ValueError(f"The path '{path}' is not a CDB!") + if perform_fixes: + # perform fix(es) + from medcat.utils.legacy.fixes import ( + fix_cui2original_names_if_needed) + fix_cui2original_names_if_needed(cdb) return cdb diff --git a/medcat-v2/medcat/utils/legacy/convert_cdb.py b/medcat-v2/medcat/utils/legacy/convert_cdb.py index 6399c5848..119533599 100644 --- a/medcat-v2/medcat/utils/legacy/convert_cdb.py +++ b/medcat-v2/medcat/utils/legacy/convert_cdb.py @@ -121,6 +121,7 @@ def _add_cui_info(cdb: CDB, data: dict) -> CDB: cui2tags, cui2type_ids = data['cui2tags'], data['cui2type_ids'] cui2prefname = data['cui2preferred_name'] cui2av_conf = data['cui2average_confidence'] + cui2orig_names = data["addl_info"].get("cui2original_names", {}) for cui in all_cuis: names = cui2names.get(cui, set()) snames = cui2snames.get(cui, set()) @@ -134,8 +135,14 @@ def _add_cui_info(cdb: CDB, data: dict) -> CDB: cui=cui, preferred_name=prefname, names=names, subnames=snames, type_ids=type_ids, tags=tags, count_train=count_train, context_vectors=vecs, average_confidence=av_conf, + original_names=cui2orig_names.get(cui, None), ) cdb.cui2info[cui] = info + # remove cui2original_names from addl_info - we've already used it + if "cui2original_names" in data["addl_info"]: + logger.info("Deleting 'cui2original_names' in addl_info - " + "it was used in CUIInfo already") + del data["addl_info"]["cui2original_names"] all_cui_tuis = set((ci['cui'], tui) for ci in cdb.cui2info.values() for tui in ci['type_ids']) all_tuis = set(tui for _, tui in all_cui_tuis) diff --git a/medcat-v2/medcat/utils/legacy/fixes.py b/medcat-v2/medcat/utils/legacy/fixes.py new file mode 100644 index 000000000..13c8a864e --- /dev/null +++ b/medcat-v2/medcat/utils/legacy/fixes.py @@ -0,0 +1,48 @@ +import logging + +from medcat.cdb import CDB + + +logger = logging.getLogger(__name__) + + +def _fix_cui2original_names(cdb: CDB) -> None: + cui2on = cdb.addl_info["cui2original_names"] + num_cuis = len(cui2on) + used_cuis = 0 + for ci in cdb.cui2info.values(): + orig_names: set[str] = cui2on.get(ci["cui"], None) + if orig_names is not None: + if ci["original_names"] is None: + ci["original_names"] = orig_names + else: + ci["original_names"].update(orig_names) + used_cuis += 1 + logger.info( + "Used %d out of %d CUIs in the 'cui2original_names' map", + used_cuis, num_cuis) + # delete existing data in cui2original_names + del cdb.addl_info["cui2original_names"] + + +def fix_cui2original_names_if_needed(cdb: CDB) -> bool: + """Fix the cui2original names in CDB if needed. + + This was an issue caused by faulty legacy conversion + where the data wasn't moved correctly from addl_info. + + Args: + cdb (CDB): The CDB in question. + + Returns: + bool: Whether the fix / change was made. + """ + if "cui2original_names" in cdb.addl_info: + logger.info( + "CDB addl_info contains legacy data: " + "'cui2original_names' . Moving it to cui2info") + _fix_cui2original_names(cdb) + return True + else: + logger.debug("CDB does not contain legacy 'cui2original_names") + return False diff --git a/medcat-v2/tests/utils/legacy/test_convert_cdb.py b/medcat-v2/tests/utils/legacy/test_convert_cdb.py index 13bc3d38e..ca7873f97 100644 --- a/medcat-v2/tests/utils/legacy/test_convert_cdb.py +++ b/medcat-v2/tests/utils/legacy/test_convert_cdb.py @@ -32,6 +32,12 @@ def test_all_cui_names_in_names(self): with self.subTest(f"{cui}: {name}"): self.assertIn(name, self.cdb.name2info) + def test_all_cuis_have_original_names(self): + for cui, ci in self.cdb.cui2info.items(): + with self.subTest(cui): + print(cui, ":", ci["original_names"]) + self.assertTrue(ci["original_names"]) + def test_all_name_cuis_in_per_cui_status(self): for name, nameinfo in self.cdb.name2info.items(): for cui in nameinfo['per_cui_status']: diff --git a/medcat-v2/tests/utils/legacy/test_fixes.py b/medcat-v2/tests/utils/legacy/test_fixes.py new file mode 100644 index 000000000..0d8bdc419 --- /dev/null +++ b/medcat-v2/tests/utils/legacy/test_fixes.py @@ -0,0 +1,63 @@ +import os + +from medcat.cdb import CDB +from medcat.utils.legacy import fixes +from medcat.utils.cdb_state import captured_state_cdb + +import unittest + +from ... import UNPACKED_EXAMPLE_MODEL_PACK_PATH + + +CONVERTED_CDB_PATH = os.path.join( + UNPACKED_EXAMPLE_MODEL_PACK_PATH, "cdb") + + +class TestCUI2OriginalNamesFix(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.converted_cdb = CDB.load(CONVERTED_CDB_PATH, + perform_fixes=False) + + def test_converted_model_does_not_have_orig_names(self): + for ci in self.converted_cdb.cui2info.values(): + with self.subTest(ci["cui"]): + self.assertFalse(ci["original_names"]) + + def test_model_has_orig_names_after_fix(self): + # to make sure this is agnostic to the order + with captured_state_cdb(self.converted_cdb): + changed = fixes.fix_cui2original_names_if_needed( + self.converted_cdb) + self.assertTrue(changed) + # has not cui2original_names + self.assertNotIn("cui2original_names", + self.converted_cdb.addl_info) + for ci in self.converted_cdb.cui2info.values(): + with self.subTest(ci["cui"]): + self.assertTrue(ci["original_names"]) + + def test_will_not_fix_twice(self): + with captured_state_cdb(self.converted_cdb): + fixes.fix_cui2original_names_if_needed( + self.converted_cdb) + changed_twice = fixes.fix_cui2original_names_if_needed( + self.converted_cdb) + self.assertFalse(changed_twice) + + +class TestCUI2OriginalNamesFixAuto(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.converted_cdb = CDB.load(CONVERTED_CDB_PATH, + perform_fixes=True) + + def test_cui2orig_names_fixed_automatically(self): + for ci in self.converted_cdb.cui2info.values(): + with self.subTest(ci["cui"]): + self.assertTrue(ci["original_names"]) + + def test_addl_info_has_no_cui2original_names(self): + self.assertNotIn("cui2original_names", self.converted_cdb.addl_info)