Skip to content

Commit 7442cec

Browse files
authored
refactor(medcat): CU-869b44wz8 Better internal components (#219)
* CU-869b44wz8: Create new abstraction layer for entity providing components (e.g. NER and Linker) * CU-869b44wz8: Use new abstraction for linkers * CU-869b44wz8: Use new abstraction for DeID * CU-869b44wz8: Fix setting of linker entities - do it all in one place * Fix NER tests * Fix postprocessing tests * CU-869b44wz8: Update NER components with new abstraction * CU-869b44wz8: Fix issue with wrong base class * CU-869b44wz8: Add missing base class init call * CU-869b44wz8: Fix typo * CU-869b44wz8: Avoid implicit use of doc.ner_ents * CU-869b44wz8: Fix issue with entity IDs * Update tutorial with up-to-date example * CU-869b44wz8: Fix issue with wrong base class in tutorial * CU-869b44wz8: Reinstate old signature of create_main_ann and use new one * CU-869b44wz8: Deprecate old create_main_ann method * CU-869b44wz8: Use correct syntax in tutorials for maybe_annotate_name * CU-869b44wz8: Allow None for current ID and produce a unique ID if needed * CU-869b44wz8: Add entity to doc.ner_ents during annotate_name if no ID (i.e. old API) is used to preserve previous functionality * CU-869b44wz8: Add a few tests for old and new API for maybe_annotate_name * CU-869b44wz8: Fix old behaviour of create_main_ann * CU-869b44wz8: Add a few small tests for create_main_ann and filter_linked_annotations * CU-869b44wz8: Add a baseline test
1 parent 12d60ad commit 7442cec

File tree

12 files changed

+346
-99
lines changed

12 files changed

+346
-99
lines changed

medcat-v2-tutorials/notebooks/advanced/2._Create_and_use_component.ipynb

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
"from medcat.cdb.cdb import CDB\n",
3939
"from medcat.config.config import Ner\n",
4040
"# for the component itself\n",
41-
"from medcat.components.types import AbstractCoreComponent, CoreComponentType\n",
41+
"from medcat.components.types import CoreComponentType\n",
42+
"from medcat.components.types import AbstractEntityProvidingComponent\n",
4243
"from medcat.tokenizing.tokens import MutableDocument, MutableEntity\n",
4344
"from medcat.components.ner.vocab_based_annotator import maybe_annotate_name\n",
4445
"\n",
@@ -59,7 +60,7 @@
5960
" return min(max(self.min, num), self.max)\n",
6061
"\n",
6162
"\n",
62-
"class RandomNER(AbstractCoreComponent):\n",
63+
"class RandomNER(AbstractEntityProvidingComponent):\n",
6364
" # NOTE: NEED TO IMPLEMENT\n",
6465
" name = \"RANDOM_NER\"\n",
6566
"\n",
@@ -73,6 +74,7 @@
7374
" # NOTE: NEED TO IMPLEMENT\n",
7475
" # you can specify whatever init args as long as you define them above\n",
7576
" def __init__(self, tokenizer: BaseTokenizer, cdb: CDB):\n",
77+
" super().__init__()\n",
7678
" self.tokenizer = tokenizer\n",
7779
" self.cdb = cdb\n",
7880
"\n",
@@ -90,7 +92,9 @@
9092
" return CoreComponentType.ner\n",
9193
"\n",
9294
" # NOTE: NEED TO IMPLEMENT\n",
93-
" def __call__(self, doc: MutableDocument) -> MutableDocument:\n",
95+
" def predict_entities(self, doc: MutableDocument,\n",
96+
" ents: list[MutableEntity] | None = None\n",
97+
" ) -> list[MutableEntity]:\n",
9498
" \"\"\"Detect candidates for concepts - linker will then be able\n",
9599
" to do the rest. It adds `entities` to the doc.entities and each\n",
96100
" entity can have the entity.link_candidates - that the linker\n",
@@ -99,6 +103,8 @@
99103
" Args:\n",
100104
" doc (MutableDocument):\n",
101105
" Spacy document to be annotated with named entities.\n",
106+
"            ents (list[MutableEntity] | None):\n",
107+
"                The entities to use. None expected here.\n",
102108
"\n",
103109
" Returns:\n",
104110
" doc (MutableDocument):\n",
@@ -113,6 +119,7 @@
113119
" for start in start_tkn_indices]\n",
114120
" choose_from = list(self.cdb.name2info.keys())\n",
115121
" chosen_name = [random.choice(choose_from) for _ in start_tkn_indices]\n",
122+
" ner_ents: list[MutableEntity] = []\n",
116123
" for tkn_start_idx, tkn_end_idx, linked_name in zip(start_tkn_indices, end_tkn_indices, chosen_name):\n",
117124
" char_start_idx = doc[tkn_start_idx].base.char_index\n",
118125
" # NOTE: can only do this since we're never selecting the last token\n",
@@ -123,8 +130,10 @@
123130
" # safe to assume that these are all lists of tokens\n",
124131
"\n",
125132
"            # this checks the config (i.e. length and stuff) and then annotates\n",
126-
" maybe_annotate_name(self.tokenizer, linked_name, cur_tokens, doc, self.cdb, self.cdb.config)\n",
127-
" return doc\n",
133+
" ent = maybe_annotate_name(self.tokenizer, linked_name, cur_tokens, doc, self.cdb, self.cdb.config, len(ner_ents))\n",
134+
" if ent:\n",
135+
" ner_ents.append(ent)\n",
136+
" return ner_ents\n",
128137
"\n"
129138
]
130139
},

medcat-v2/medcat/components/linking/context_based_linker.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@
22
import logging
33
from typing import Iterator, Optional, Union
44

5-
from medcat.components.types import CoreComponentType, AbstractCoreComponent
5+
from medcat.components.types import CoreComponentType
6+
from medcat.components.types import AbstractEntityProvidingComponent
67
from medcat.tokenizing.tokens import MutableEntity, MutableDocument
78
from medcat.components.linking.vector_context_model import (
89
ContextModel, PerDocumentTokenCache)
910
from medcat.cdb import CDB
1011
from medcat.vocab import Vocab
1112
from medcat.config.config import Config, ComponentConfig
1213
from medcat.utils.defaults import StatusTypes as ST
13-
from medcat.utils.postprocessing import create_main_ann
14+
from medcat.utils.postprocessing import filter_linked_annotations
1415
from medcat.tokenizing.tokenizers import BaseTokenizer
1516

1617

1718
logger = logging.getLogger(__name__)
1819

1920

2021
# class Linker(PipeRunner):
21-
class Linker(AbstractCoreComponent):
22+
class Linker(AbstractEntityProvidingComponent):
2223
"""Link to a biomedical database.
2324
2425
Args:
@@ -32,6 +33,7 @@ class Linker(AbstractCoreComponent):
3233

3334
# Override
3435
def __init__(self, cdb: CDB, vocab: Vocab, config: Config) -> None:
36+
super().__init__()
3537
self.cdb = cdb
3638
self.vocab = vocab
3739
self.config = config
@@ -105,9 +107,11 @@ def _process_entity_train(self, doc: MutableDocument,
105107
entity.context_similarity = 1
106108
yield entity
107109

108-
def _train_on_doc(self, doc: MutableDocument) -> Iterator[MutableEntity]:
110+
def _train_on_doc(self, doc: MutableDocument,
111+
ner_ents: list[MutableEntity]
112+
) -> Iterator[MutableEntity]:
109113
# Run training
110-
for entity in doc.ner_ents:
114+
for entity in ner_ents:
111115
yield from self._process_entity_train(
112116
doc, entity, PerDocumentTokenCache())
113117

@@ -186,35 +190,41 @@ def _process_entity_inference(
186190
entity.context_similarity = context_similarity
187191
yield entity
188192

189-
def _inference(self, doc: MutableDocument) -> Iterator[MutableEntity]:
193+
def _inference(self, doc: MutableDocument,
194+
ner_ents: list[MutableEntity]
195+
) -> Iterator[MutableEntity]:
190196
per_doc_valid_token_cache = PerDocumentTokenCache()
191-
for entity in doc.ner_ents:
197+
for entity in ner_ents:
192198
logger.debug("Linker started with entity: %s", entity.base.text)
193199
yield from self._process_entity_inference(
194200
doc, entity, per_doc_valid_token_cache)
195201

196-
def __call__(self, doc: MutableDocument) -> MutableDocument:
202+
def predict_entities(self, doc: MutableDocument,
203+
ents: list[MutableEntity] | None = None
204+
) -> list[MutableEntity]:
197205
# Reset main entities, will be recreated later
198-
doc.linked_ents.clear()
199206
cnf_l = self.config.components.linking
200207

208+
if ents is None:
209+
raise ValueError("Need to have NER'ed entities provided")
210+
201211
if cnf_l.train:
202-
linked_entities = self._train_on_doc(doc)
212+
linked_entities = self._train_on_doc(doc, ents)
203213
else:
204-
linked_entities = self._inference(doc)
214+
linked_entities = self._inference(doc, ents)
205215
# evaluating generator here because the `all_ents` list gets
206216
# cleared afterwards otherwise
207217
le = list(linked_entities)
208218

209-
doc.ner_ents.clear()
210-
doc.ner_ents.extend(le)
211-
create_main_ann(doc, self.config.general.show_nested_entities)
219+
# doc.ner_ents.clear()
220+
# doc.ner_ents.extend(le)
212221

213222
# TODO - reintroduce pretty labels? and apply here?
214223

215224
# TODO - reintroduce groups? and map here?
216225

217-
return doc
226+
return filter_linked_annotations(
227+
doc, le, self.config.general.show_nested_entities)
218228

219229
def train(self, cui: str,
220230
entity: MutableEntity,

medcat-v2/medcat/components/linking/embedding_linker.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from medcat.cdb import CDB
22
from medcat.config.config import Config, ComponentConfig, EmbeddingLinking
3-
from medcat.components.types import CoreComponentType, AbstractCoreComponent
3+
from medcat.components.types import CoreComponentType
4+
from medcat.components.types import AbstractEntityProvidingComponent
45
from medcat.tokenizing.tokens import MutableEntity, MutableDocument
56
from medcat.tokenizing.tokenizers import BaseTokenizer
67
from typing import Optional, Iterator, Set
78
from medcat.vocab import Vocab
8-
from medcat.utils.postprocessing import create_main_ann
9+
from medcat.utils.postprocessing import filter_linked_annotations
910
from tqdm import tqdm
1011
from collections import defaultdict
1112
import logging
@@ -27,7 +28,7 @@
2728
logger = logging.getLogger(__name__)
2829

2930

30-
class Linker(AbstractCoreComponent):
31+
class Linker(AbstractEntityProvidingComponent):
3132
name = "embedding_linker"
3233

3334
def __init__(self, cdb: CDB, config: Config) -> None:
@@ -36,6 +37,7 @@ def __init__(self, cdb: CDB, config: Config) -> None:
3637
cdb (CDB): The concept database to use.
3738
config (Config): The base config.
3839
"""
40+
super().__init__()
3941
self.cdb = cdb
4042
self.config = config
4143
if not isinstance(config.components.linking, EmbeddingLinking):
@@ -92,7 +94,7 @@ def create_embeddings(self,
9294
using the chosen embedding model."""
9395
if embedding_model_name is None:
9496
embedding_model_name = self.cnf_l.embedding_model_name # fallback
95-
97+
9698
if max_length is not None and max_length != self.max_length:
9799
logger.info(
98100
"Updating max_length from %s to %s", self.max_length, max_length
@@ -548,10 +550,9 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]:
548550
to_infer.append(entity)
549551
return le, to_infer
550552

551-
def __call__(self, doc: MutableDocument) -> MutableDocument:
552-
# Reset main entities, will be recreated later
553-
doc.linked_ents.clear()
554-
553+
def predict_entities(self, doc: MutableDocument,
554+
ents: list[MutableEntity] | None = None
555+
) -> list[MutableEntity]:
555556
if self.cdb.is_dirty:
556557
logging.warning(
557558
"CDB has been modified since last save/load. "
@@ -580,11 +581,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
580581
for entities in self._batch_data(to_infer, self.cnf_l.linking_batch_size):
581582
le.extend(list(self._inference(doc, entities)))
582583

583-
doc.ner_ents.clear()
584-
doc.ner_ents.extend(le)
585-
create_main_ann(doc)
586-
587-
return doc
584+
return filter_linked_annotations(doc, le)
588585

589586
@property
590587
def names_context_matrix(self):

medcat-v2/medcat/components/ner/dict_based_ner.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import Optional
22

33
import logging
4-
from medcat.tokenizing.tokens import MutableDocument
5-
from medcat.components.types import CoreComponentType, AbstractCoreComponent
4+
from medcat.tokenizing.tokens import MutableDocument, MutableEntity
5+
from medcat.components.types import CoreComponentType
6+
from medcat.components.types import AbstractEntityProvidingComponent
67
from medcat.components.ner.vocab_based_annotator import maybe_annotate_name
78
from medcat.utils.import_utils import ensure_optional_extras_installed
89
from medcat.tokenizing.tokenizers import BaseTokenizer
@@ -24,11 +25,12 @@
2425
logger = logging.getLogger(__name__)
2526

2627

27-
class NER(AbstractCoreComponent):
28+
class NER(AbstractEntityProvidingComponent):
2829
name = 'cat_dict_ner'
2930

3031
def __init__(self, tokenizer: BaseTokenizer,
3132
cdb: CDB) -> None:
33+
super().__init__()
3234
self.tokenizer = tokenizer
3335
self.cdb = cdb
3436
self.config = self.cdb.config
@@ -60,7 +62,9 @@ def _rebuild_automaton(self):
6062
def get_type(self) -> CoreComponentType:
6163
return CoreComponentType.ner
6264

63-
def __call__(self, doc: MutableDocument) -> MutableDocument:
65+
def predict_entities(self, doc: MutableDocument,
66+
ents: list[MutableEntity] | None = None
67+
) -> list[MutableEntity]:
6468
"""Detect candidates for concepts - linker will then be able
6569
to do the rest. It adds `entities` to the doc.entities and each
6670
entity can have the entity.link_candidates - that the linker
@@ -69,15 +73,20 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
6973
Args:
7074
doc (MutableDocument):
7175
Spacy document to be annotated with named entities.
76+
ents (list[MutableEntity] | None):
77+
The entities given. This should be None.
7278
7379
Returns:
74-
doc (MutableDocument):
75-
Spacy document with detected entities.
80+
list[MutableEntity]:
81+
The NER'ed entities.
7682
"""
83+
if ents is not None:
84+
ValueError(f"Unexpected entities sent to NER: {ents}")
7785
if self.cdb.has_changed_names:
7886
self.cdb._reset_subnames()
7987
self._rebuild_automaton()
8088
text = doc.base.text.lower()
89+
ner_ents: list[MutableEntity] = []
8190
for end_idx, raw_name in self.automaton.iter(text):
8291
start_idx = end_idx - len(raw_name) + 1
8392
cur_tokens = doc.get_tokens(start_idx, end_idx)
@@ -96,9 +105,12 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
96105
continue
97106
preprocessed_name = raw_name.replace(
98107
' ', self.config.general.separator)
99-
maybe_annotate_name(self.tokenizer, preprocessed_name, cur_tokens,
100-
doc, self.cdb, self.config)
101-
return doc
108+
ent = maybe_annotate_name(
109+
self.tokenizer, preprocessed_name, cur_tokens,
110+
doc, self.cdb, self.config, len(ner_ents))
111+
if ent:
112+
ner_ents.append(ent)
113+
return ner_ents
102114

103115
@classmethod
104116
def create_new_component(

0 commit comments

Comments
 (0)