Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions medcat-v2/medcat/components/addons/meta_cat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,10 +728,12 @@ def prepare_document(self, doc: MutableDocument, input_ids: list,
# Checking if we've reached the start of the entity
if start <= pair[0] or start <= pair[1]:
if end <= pair[1]:
ctoken_idx.append(ind) # End reached
# End reached; update for correct index
ctoken_idx.append(last_ind + ind)
break
else:
ctoken_idx.append(ind) # Keep going
# Keep going; update for correct index
ctoken_idx.append(last_ind + ind)

# Start where the last ent was found, cannot be before it as we've
# sorted
Expand Down
5 changes: 5 additions & 0 deletions v1/medcat/examples/cdb_meta.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
cui,name,ontologies,name_status,type_ids,description
C0000039,"gastroesophageal reflux",,,T234,
C0000239,"heartburn",,,,
C0000339,"hypertension",,,,
C0000439,"stroke",,,,
Binary file added v1/medcat/examples/cdb_meta.dat
Binary file not shown.
Binary file added v1/medcat/examples/vocab_meta.dat
Binary file not shown.
4 changes: 2 additions & 2 deletions v1/medcat/medcat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,10 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe
# Checking if we've reached the start of the entity
if start <= pair[0] or start <= pair[1]:
if end <= pair[1]:
ctoken_idx.append(ind) # End reached
ctoken_idx.append(last_ind+ind) # End reached; update the index to reflect the correct position since iteration does not start from the beginning
break
else:
ctoken_idx.append(ind) # Keep going
ctoken_idx.append(last_ind+ind) # Keep going; update the index to reflect the correct position since iteration does not start from the beginning

# Start where the last ent was found, cannot be before it as we've sorted
last_ind += ind # If we did not start from 0 in the for loop
Expand Down
10,671 changes: 10,671 additions & 0 deletions v1/medcat/tests/resources/mct_export_for_meta_cat_full_text.json

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions v1/medcat/tests/test_meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

from transformers import AutoTokenizer

from medcat.vocab import Vocab
from medcat.cdb import CDB
from medcat.cat import CAT
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
import tempfile
import spacy
from spacy.tokens import Span

Expand Down Expand Up @@ -117,6 +121,70 @@ def test_two_phase(self):

self.meta_cat.config.model['phase_number'] = 0

class CAT_METACATTests(unittest.TestCase):
    """Run a freshly trained MetaCAT model through a full CAT pipeline.

    The class-level fixture trains a BERT-based "Status" MetaCAT from the
    MedCATtrainer export at ``META_CAT_JSON_PATH``, attaches it to a CAT
    instance built from the example CDB/vocab, and the test then checks the
    meta-annotation value reported for each entity found in a multi-sentence
    text.
    """

    # Directory containing this test module; anchor for all fixture paths.
    _TESTS_DIR = os.path.dirname(os.path.realpath(__file__))
    META_CAT_JSON_PATH = os.path.join(_TESTS_DIR, "resources",
                                      "mct_export_for_meta_cat_full_text.json")

    @classmethod
    def _get_meta_cat(cls, meta_cat_dir):
        """Train a MetaCAT model and return it.

        Args:
            meta_cat_dir: Directory passed to ``train_from_json`` as
                ``save_dir_path``; created if it does not exist.

        Returns:
            The ``MetaCAT`` instance after ``train_from_json`` has run on
            ``META_CAT_JSON_PATH``.
        """
        config = ConfigMetaCAT()
        config.general["category_name"] = "Status"
        config.general['category_value2id'] = {'Other': 0, 'Confirmed': 1}
        config.model['model_name'] = 'bert'
        config.model['model_freeze_layers'] = False
        config.model['num_layers'] = 10
        config.train['lr'] = 0.001
        config.train["nepochs"] = 20
        config.train.class_weights = [0.75, 0.3]
        config.train['metric']['base'] = 'macro avg'

        tokenizer = TokenizerWrapperBERT(
            AutoTokenizer.from_pretrained("bert-base-uncased"))
        meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)

        os.makedirs(meta_cat_dir, exist_ok=True)
        meta_cat.train_from_json(cls.META_CAT_JSON_PATH,
                                 save_dir_path=meta_cat_dir)
        return meta_cat

    @classmethod
    def setUpClass(cls) -> None:
        """Load the example CDB/vocab, train the MetaCAT and build the CAT."""
        examples_dir = os.path.join(cls._TESTS_DIR, "..", "examples")
        cls.cdb = CDB.load(os.path.join(examples_dir, "cdb_meta.dat"))
        cls.vocab = Vocab.load(os.path.join(examples_dir, "vocab_meta.dat"))
        cls.vocab.make_unigram_table()

        # Both temporary directories are released again in tearDownClass.
        cls._temp_logs_folder = tempfile.TemporaryDirectory()
        cls.temp_dir = tempfile.TemporaryDirectory()

        cdb_config = cls.cdb.config
        cdb_config.general.spacy_model = os.path.join(cls.temp_dir.name,
                                                      "en_core_web_md")
        cdb_config.ner.min_name_len = 2
        cdb_config.ner.upper_case_limit_len = 3
        cdb_config.general.spell_check = True
        cdb_config.linking.train_count_threshold = 10
        cdb_config.linking.similarity_threshold = 0.3
        cdb_config.linking.train = True
        cdb_config.linking.disamb_length_limit = 5
        cdb_config.general.full_unlink = True
        cdb_config.general.usage_monitor.enabled = True
        cdb_config.general.usage_monitor.log_folder = cls._temp_logs_folder.name

        cls.meta_cat_dir = os.path.join(cls._TESTS_DIR, "tmp")
        cls.meta_cat = cls._get_meta_cat(cls.meta_cat_dir)
        cls.cat = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab,
                      meta_cats=[cls.meta_cat])

    @classmethod
    def tearDownClass(cls) -> None:
        """Tear down the pipeline and remove every directory the fixture made."""
        cls.cat.destroy_pipe()
        if os.path.exists(cls.meta_cat_dir):
            shutil.rmtree(cls.meta_cat_dir)
        cls._temp_logs_folder.cleanup()
        # Previously never cleaned up explicitly, so the directory was only
        # removed implicitly when the object was finalised.
        cls.temp_dir.cleanup()

    def test_meta_cat_through_cat(self):
        """Check the 'Status' meta-annotation of each entity, in output order."""
        text = "This information is just to add text. The patient denied history of heartburn and/or gastroesophageal reflux disorder. He recently had a stroke in the last week."
        entities = self.cat.get_entities(text)
        meta_status_values = [
            entity['meta_anns']['Status']['value']
            for entity in entities['entities'].values()
        ]

        self.assertEqual(meta_status_values, ['Other', 'Other', 'Confirmed'])

# Configure root logging at INFO level whenever this module is loaded, so
# log output is visible during the test run.
import logging

logging.basicConfig(level=logging.INFO)

# Allow the test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading