Skip to content
Merged
25 changes: 24 additions & 1 deletion medcat-v2/medcat/cdb/cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,25 @@ def save(self, save_path: str,
serialise(serialiser, self, save_path, overwrite=overwrite)

@classmethod
def load(cls, path: str) -> 'CDB':
def load(cls, path: str, perform_fixes: bool = True) -> 'CDB':
"""Load the CDB off disk.

This can load a legacy (v1) CDB (.dat) or a v2 CDB either in its folder
format or the .zip format. The distinction is made automatically.

Args:
path (str): The path to the CDB.
perform_fixes (bool): Whether to apply known fixes after loading,
such as the cui2original_names fix. Defaults to True.

Raises:
LegacyConversionDisabledError:
If a legacy model is found and conversion is not allowed.
ValueError: If the loaded object isn't a CDB.

Returns:
CDB: The loaded CDB.
"""
if should_serialise_as_zip(path, 'auto'):
cdb = deserialise_from_zip(path)
elif os.path.isfile(path) and path.endswith('.dat'):
Expand All @@ -535,4 +553,9 @@ def load(cls, path: str) -> 'CDB':
cdb = deserialise(path)
if not isinstance(cdb, CDB):
raise ValueError(f"The path '{path}' is not a CDB!")
if perform_fixes:
# perform fix(es)
from medcat.utils.legacy.fixes import (
fix_cui2original_names_if_needed)
fix_cui2original_names_if_needed(cdb)
return cdb
7 changes: 7 additions & 0 deletions medcat-v2/medcat/utils/legacy/convert_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _add_cui_info(cdb: CDB, data: dict) -> CDB:
cui2tags, cui2type_ids = data['cui2tags'], data['cui2type_ids']
cui2prefname = data['cui2preferred_name']
cui2av_conf = data['cui2average_confidence']
cui2orig_names = data["addl_info"].get("cui2original_names", {})
for cui in all_cuis:
names = cui2names.get(cui, set())
snames = cui2snames.get(cui, set())
Expand All @@ -134,8 +135,14 @@ def _add_cui_info(cdb: CDB, data: dict) -> CDB:
cui=cui, preferred_name=prefname, names=names, subnames=snames,
type_ids=type_ids, tags=tags, count_train=count_train,
context_vectors=vecs, average_confidence=av_conf,
original_names=cui2orig_names.get(cui, None),
)
cdb.cui2info[cui] = info
# remove cui2original_names from addl_info - we've already used it
if "cui2original_names" in data["addl_info"]:
logger.info("Deleting 'cui2original_names' in addl_info - "
"it was used in CUIInfo already")
del data["addl_info"]["cui2original_names"]
all_cui_tuis = set((ci['cui'], tui) for ci in cdb.cui2info.values()
for tui in ci['type_ids'])
all_tuis = set(tui for _, tui in all_cui_tuis)
Expand Down
48 changes: 48 additions & 0 deletions medcat-v2/medcat/utils/legacy/fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging

from medcat.cdb import CDB


# Module-level logger, named after this module; used by the fix helpers below.
logger = logging.getLogger(__name__)


def _fix_cui2original_names(cdb: CDB) -> None:
    """Move legacy 'cui2original_names' data from addl_info into cui2info.

    For every concept in ``cdb.cui2info`` whose CUI has an entry in
    ``cdb.addl_info["cui2original_names"]``, the legacy names either
    become the concept's ``original_names`` (when it had none) or are
    merged into the existing value via ``update``. Afterwards the legacy
    entry is removed from ``addl_info``.

    Args:
        cdb (CDB): The CDB to fix in place. Its addl_info is expected to
            contain 'cui2original_names'; the lookup is not guarded here.
    """
    legacy_names_by_cui = cdb.addl_info["cui2original_names"]
    total_legacy_cuis = len(legacy_names_by_cui)
    matched_cuis = 0
    for cui_info in cdb.cui2info.values():
        legacy_names = legacy_names_by_cui.get(cui_info["cui"])
        if legacy_names is None:
            # nothing recorded for this concept in the legacy map
            continue
        current_names = cui_info["original_names"]
        if current_names is None:
            # the legacy object itself is adopted (no copy is made)
            cui_info["original_names"] = legacy_names
        else:
            current_names.update(legacy_names)
        matched_cuis += 1
    logger.info(
        "Used %d out of %d CUIs in the 'cui2original_names' map",
        matched_cuis, total_legacy_cuis)
    # the data has been moved into cui2info, so drop the legacy entry
    del cdb.addl_info["cui2original_names"]


def fix_cui2original_names_if_needed(cdb: CDB) -> bool:
    """Apply the cui2original_names fix when the CDB still needs it.

    The fix is needed when 'cui2original_names' is still present in the
    CDB's addl_info. In that case the data is moved over to cui2info;
    otherwise the CDB is left untouched.

    Args:
        cdb (CDB): The CDB to check and, if needed, fix in place.

    Returns:
        bool: True if the fix was applied, False if nothing was done.
    """
    has_legacy_data = "cui2original_names" in cdb.addl_info
    if not has_legacy_data:
        logger.debug("CDB does not contain legacy 'cui2original_names")
        return False
    logger.info(
        "CDB addl_info contains legacy data: "
        "'cui2original_names' . Moving it to cui2info")
    _fix_cui2original_names(cdb)
    return True
6 changes: 6 additions & 0 deletions medcat-v2/tests/utils/legacy/test_convert_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def test_all_cui_names_in_names(self):
with self.subTest(f"{cui}: {name}"):
self.assertIn(name, self.cdb.name2info)

def test_all_cuis_have_original_names(self):
    """Check that every concept in the CDB has original names.

    Each CUI is checked in its own sub-test, so one failing concept
    does not hide the others.
    """
    for cui, ci in self.cdb.cui2info.items():
        with self.subTest(cui):
            # report the offending value only on failure instead of
            # printing a line for every CUI on every run
            self.assertTrue(
                ci["original_names"],
                f"CUI {cui!r} has no original names: "
                f"{ci['original_names']!r}")

def test_all_name_cuis_in_per_cui_status(self):
for name, nameinfo in self.cdb.name2info.items():
for cui in nameinfo['per_cui_status']:
Expand Down
63 changes: 63 additions & 0 deletions medcat-v2/tests/utils/legacy/test_fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os

from medcat.cdb import CDB
from medcat.utils.legacy import fixes
from medcat.utils.cdb_state import captured_state_cdb

import unittest

from ... import UNPACKED_EXAMPLE_MODEL_PACK_PATH


# Location of the "cdb" entry inside the unpacked example model pack;
# the test classes below load their CDB from here.
CONVERTED_CDB_PATH = os.path.join(
UNPACKED_EXAMPLE_MODEL_PACK_PATH, "cdb")


class TestCUI2OriginalNamesFix(unittest.TestCase):
    """Exercise the cui2original_names fix when it is applied by hand.

    The CDB is loaded with ``perform_fixes=False`` so the tests can look
    at it before the fix and then run the fix explicitly.
    """

    @classmethod
    def setUpClass(cls):
        # skip the automatic fixes so the un-fixed state can be observed
        cls.converted_cdb = CDB.load(
            CONVERTED_CDB_PATH, perform_fixes=False)

    def test_converted_model_does_not_have_orig_names(self):
        cdb = self.converted_cdb
        for cui_info in cdb.cui2info.values():
            with self.subTest(cui_info["cui"]):
                self.assertFalse(cui_info["original_names"])

    def test_model_has_orig_names_after_fix(self):
        cdb = self.converted_cdb
        # captured_state_cdb presumably restores the CDB on exit, so
        # this test does not depend on the order the tests run in
        with captured_state_cdb(cdb):
            was_changed = fixes.fix_cui2original_names_if_needed(cdb)
            self.assertTrue(was_changed)
            # the legacy key must be gone from addl_info
            self.assertNotIn("cui2original_names", cdb.addl_info)
            for cui_info in cdb.cui2info.values():
                with self.subTest(cui_info["cui"]):
                    self.assertTrue(cui_info["original_names"])

    def test_will_not_fix_twice(self):
        cdb = self.converted_cdb
        with captured_state_cdb(cdb):
            # first call applies the fix, second one must be a no-op
            fixes.fix_cui2original_names_if_needed(cdb)
            changed_again = fixes.fix_cui2original_names_if_needed(cdb)
            self.assertFalse(changed_again)


class TestCUI2OriginalNamesFixAuto(unittest.TestCase):
    """Check that CDB.load applies the cui2original_names fix itself.

    The CDB is loaded with ``perform_fixes=True`` and no fix is called
    explicitly by the tests.
    """

    @classmethod
    def setUpClass(cls):
        cls.converted_cdb = CDB.load(
            CONVERTED_CDB_PATH, perform_fixes=True)

    def test_cui2orig_names_fixed_automatically(self):
        for cui_info in self.converted_cdb.cui2info.values():
            with self.subTest(cui_info["cui"]):
                self.assertTrue(cui_info["original_names"])

    def test_addl_info_has_no_cui2original_names(self):
        remaining_addl_info = self.converted_cdb.addl_info
        self.assertNotIn("cui2original_names", remaining_addl_info)
Loading