Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
623aa00
CU-869b44wz8: Create new abstraction layer for entity providing compo…
mart-r Nov 10, 2025
fc5ee2d
CU-869b44wz8: Use new abstraction for linkers
mart-r Nov 10, 2025
03b3ead
CU-869b44wz8: Use new abstraaction for DeID
mart-r Nov 10, 2025
15a5ad8
CU-869b44wz8: Fix setting of linker entities - do it all in one place
mart-r Nov 10, 2025
8a91f9f
Fix NER tests
mart-r Nov 10, 2025
7bc4dd5
Fix postporcesing tests
mart-r Nov 10, 2025
bcd7f18
CU-869b44wz8: Update NER components with new abstraction
mart-r Nov 10, 2025
19f6db8
CU-869b44wz8: Fix issue with wrong base class
mart-r Nov 10, 2025
6d0612c
CU-869b44wz8: Add missing base class init call
mart-r Nov 10, 2025
3747dca
CU-869b44wz8: Fix typo
mart-r Nov 10, 2025
a9fa26a
CU-869b44wz8: Avoid implicit use of doc.ner_ents
mart-r Nov 10, 2025
c4583e0
CU-869b44wz8: Fix issue with entity IDs
mart-r Nov 10, 2025
4926d43
Update tutorial with up to date example
mart-r Nov 11, 2025
7e202a4
CU-869b44wz8: Fix issue with wrong base class in tutorial
mart-r Nov 11, 2025
0b0d698
CU-869b44wz8: Reinstate old signature of create_main_ann and use new one
mart-r Nov 11, 2025
4c4113b
CU-869b44wz8: Deprecate old create_main_ann method
mart-r Nov 11, 2025
4a56bbd
CU-869b44wz8: Use correct syntax in tutorials for maybe_annotate_name
mart-r Nov 11, 2025
f7fe6e9
CU-869b44wz8: Allow None for current ID and produce a unique ID if ne…
mart-r Nov 11, 2025
dafc986
CU-869b44wz8: Add entity to doc.ner_ents during annotate_name if no I…
mart-r Nov 12, 2025
82577ef
CU-869b44wz8: Add a few tests for old and new API for maybe_annnotate…
mart-r Nov 12, 2025
0723230
CU-869b44wz8: Fix old behaviour of create_main_ann
mart-r Nov 12, 2025
5dd5739
CU-869b44wz8: Add a few small tests fro create_main_ann and filter_li…
mart-r Nov 12, 2025
14fc9ae
CU-869b44wz8: Add a baseline test
mart-r Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"from medcat.cdb.cdb import CDB\n",
"from medcat.config.config import Ner\n",
"# for the component itself\n",
"from medcat.components.types import AbstractCoreComponent, CoreComponentType\n",
"from medcat.components.types import CoreComponentType\n",
"from medcat.components.types import AbstractEntityProvidingComponent\n",
"from medcat.tokenizing.tokens import MutableDocument, MutableEntity\n",
"from medcat.components.ner.vocab_based_annotator import maybe_annotate_name\n",
"\n",
Expand All @@ -59,7 +60,7 @@
" return min(max(self.min, num), self.max)\n",
"\n",
"\n",
"class RandomNER(AbstractCoreComponent):\n",
"class RandomNER(AbstractEntityProvidingComponent):\n",
" # NOTE: NEED TO IMPLEMENT\n",
" name = \"RANDOM_NER\"\n",
"\n",
Expand All @@ -73,6 +74,7 @@
" # NOTE: NEED TO IMPLEMENT\n",
" # you can specify whatever init args as long as you define them above\n",
" def __init__(self, tokenizer: BaseTokenizer, cdb: CDB):\n",
" super().__init__()\n",
" self.tokenizer = tokenizer\n",
" self.cdb = cdb\n",
"\n",
Expand All @@ -90,7 +92,9 @@
" return CoreComponentType.ner\n",
"\n",
" # NOTE: NEED TO IMPLEMENT\n",
" def __call__(self, doc: MutableDocument) -> MutableDocument:\n",
" def predict_entities(self, doc: MutableDocument,\n",
" ents: list[MutableEntity] | None = None\n",
" ) -> list[MutableEntity]:\n",
" \"\"\"Detect candidates for concepts - linker will then be able\n",
" to do the rest. It adds `entities` to the doc.entities and each\n",
" entity can have the entity.link_candidates - that the linker\n",
Expand All @@ -99,6 +103,8 @@
" Args:\n",
" doc (MutableDocument):\n",
" Spacy document to be annotated with named entities.\n",
" ents list[MutableEntity] | None = None:\n",
" The entties to use. None expected here.\n",
"\n",
" Returns:\n",
" doc (MutableDocument):\n",
Expand All @@ -113,6 +119,7 @@
" for start in start_tkn_indices]\n",
" choose_from = list(self.cdb.name2info.keys())\n",
" chosen_name = [random.choice(choose_from) for _ in start_tkn_indices]\n",
" ner_ents: list[MutableEntity] = []\n",
" for tkn_start_idx, tkn_end_idx, linked_name in zip(start_tkn_indices, end_tkn_indices, chosen_name):\n",
" char_start_idx = doc[tkn_start_idx].base.char_index\n",
" # NOTE: can only do this since we're never selecting the last token\n",
Expand All @@ -123,8 +130,10 @@
" # safe to assume that these are all lists of tokens\n",
"\n",
" # this checks the config (i.e length and stuff) and then annotes\n",
" maybe_annotate_name(self.tokenizer, linked_name, cur_tokens, doc, self.cdb, self.cdb.config)\n",
" return doc\n",
" ent = maybe_annotate_name(self.tokenizer, linked_name, cur_tokens, doc, self.cdb, self.cdb.config, len(ner_ents))\n",
" if ent:\n",
" ner_ents.append(ent)\n",
" return ner_ents\n",
"\n"
]
},
Expand Down
40 changes: 25 additions & 15 deletions medcat-v2/medcat/components/linking/context_based_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@
import logging
from typing import Iterator, Optional, Union

from medcat.components.types import CoreComponentType, AbstractCoreComponent
from medcat.components.types import CoreComponentType
from medcat.components.types import AbstractEntityProvidingComponent
from medcat.tokenizing.tokens import MutableEntity, MutableDocument
from medcat.components.linking.vector_context_model import (
ContextModel, PerDocumentTokenCache)
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.config.config import Config, ComponentConfig
from medcat.utils.defaults import StatusTypes as ST
from medcat.utils.postprocessing import create_main_ann
from medcat.utils.postprocessing import filter_linked_annotations
from medcat.tokenizing.tokenizers import BaseTokenizer


logger = logging.getLogger(__name__)


# class Linker(PipeRunner):
class Linker(AbstractCoreComponent):
class Linker(AbstractEntityProvidingComponent):
"""Link to a biomedical database.

Args:
Expand All @@ -32,6 +33,7 @@ class Linker(AbstractCoreComponent):

# Override
def __init__(self, cdb: CDB, vocab: Vocab, config: Config) -> None:
super().__init__()
self.cdb = cdb
self.vocab = vocab
self.config = config
Expand Down Expand Up @@ -105,9 +107,11 @@ def _process_entity_train(self, doc: MutableDocument,
entity.context_similarity = 1
yield entity

def _train_on_doc(self, doc: MutableDocument) -> Iterator[MutableEntity]:
def _train_on_doc(self, doc: MutableDocument,
ner_ents: list[MutableEntity]
) -> Iterator[MutableEntity]:
# Run training
for entity in doc.ner_ents:
for entity in ner_ents:
yield from self._process_entity_train(
doc, entity, PerDocumentTokenCache())

Expand Down Expand Up @@ -186,35 +190,41 @@ def _process_entity_inference(
entity.context_similarity = context_similarity
yield entity

def _inference(self, doc: MutableDocument) -> Iterator[MutableEntity]:
def _inference(self, doc: MutableDocument,
ner_ents: list[MutableEntity]
) -> Iterator[MutableEntity]:
per_doc_valid_token_cache = PerDocumentTokenCache()
for entity in doc.ner_ents:
for entity in ner_ents:
logger.debug("Linker started with entity: %s", entity.base.text)
yield from self._process_entity_inference(
doc, entity, per_doc_valid_token_cache)

def __call__(self, doc: MutableDocument) -> MutableDocument:
def predict_entities(self, doc: MutableDocument,
ents: list[MutableEntity] | None = None
) -> list[MutableEntity]:
# Reset main entities, will be recreated later
doc.linked_ents.clear()
cnf_l = self.config.components.linking

if ents is None:
raise ValueError("Need to have NER'ed entities provided")

if cnf_l.train:
linked_entities = self._train_on_doc(doc)
linked_entities = self._train_on_doc(doc, ents)
else:
linked_entities = self._inference(doc)
linked_entities = self._inference(doc, ents)
# evaluating generator here because the `all_ents` list gets
# cleared afterwards otherwise
le = list(linked_entities)

doc.ner_ents.clear()
doc.ner_ents.extend(le)
create_main_ann(doc, self.config.general.show_nested_entities)
# doc.ner_ents.clear()
# doc.ner_ents.extend(le)

# TODO - reintroduce pretty labels? and apply here?

# TODO - reintroduce groups? and map here?

return doc
return filter_linked_annotations(
doc, le, self.config.general.show_nested_entities)

def train(self, cui: str,
entity: MutableEntity,
Expand Down
23 changes: 10 additions & 13 deletions medcat-v2/medcat/components/linking/embedding_linker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from medcat.cdb import CDB
from medcat.config.config import Config, ComponentConfig, EmbeddingLinking
from medcat.components.types import CoreComponentType, AbstractCoreComponent
from medcat.components.types import CoreComponentType
from medcat.components.types import AbstractEntityProvidingComponent
from medcat.tokenizing.tokens import MutableEntity, MutableDocument
from medcat.tokenizing.tokenizers import BaseTokenizer
from typing import Optional, Iterator, Set
from medcat.vocab import Vocab
from medcat.utils.postprocessing import create_main_ann
from medcat.utils.postprocessing import filter_linked_annotations
from tqdm import tqdm
from collections import defaultdict
import logging
Expand All @@ -27,7 +28,7 @@
logger = logging.getLogger(__name__)


class Linker(AbstractCoreComponent):
class Linker(AbstractEntityProvidingComponent):
name = "embedding_linker"

def __init__(self, cdb: CDB, config: Config) -> None:
Expand All @@ -36,6 +37,7 @@ def __init__(self, cdb: CDB, config: Config) -> None:
cdb (CDB): The concept database to use.
config (Config): The base config.
"""
super().__init__()
self.cdb = cdb
self.config = config
if not isinstance(config.components.linking, EmbeddingLinking):
Expand Down Expand Up @@ -92,7 +94,7 @@ def create_embeddings(self,
using the chosen embedding model."""
if embedding_model_name is None:
embedding_model_name = self.cnf_l.embedding_model_name # fallback

if max_length is not None and max_length != self.max_length:
logger.info(
"Updating max_length from %s to %s", self.max_length, max_length
Expand Down Expand Up @@ -548,10 +550,9 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]:
to_infer.append(entity)
return le, to_infer

def __call__(self, doc: MutableDocument) -> MutableDocument:
# Reset main entities, will be recreated later
doc.linked_ents.clear()

def predict_entities(self, doc: MutableDocument,
ents: list[MutableEntity] | None = None
) -> list[MutableEntity]:
if self.cdb.is_dirty:
logging.warning(
"CDB has been modified since last save/load. "
Expand Down Expand Up @@ -580,11 +581,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
for entities in self._batch_data(to_infer, self.cnf_l.linking_batch_size):
le.extend(list(self._inference(doc, entities)))

doc.ner_ents.clear()
doc.ner_ents.extend(le)
create_main_ann(doc)

return doc
return filter_linked_annotations(doc, le)

@property
def names_context_matrix(self):
Expand Down
30 changes: 21 additions & 9 deletions medcat-v2/medcat/components/ner/dict_based_ner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional

import logging
from medcat.tokenizing.tokens import MutableDocument
from medcat.components.types import CoreComponentType, AbstractCoreComponent
from medcat.tokenizing.tokens import MutableDocument, MutableEntity
from medcat.components.types import CoreComponentType
from medcat.components.types import AbstractEntityProvidingComponent
from medcat.components.ner.vocab_based_annotator import maybe_annotate_name
from medcat.utils.import_utils import ensure_optional_extras_installed
from medcat.tokenizing.tokenizers import BaseTokenizer
Expand All @@ -24,11 +25,12 @@
logger = logging.getLogger(__name__)


class NER(AbstractCoreComponent):
class NER(AbstractEntityProvidingComponent):
name = 'cat_dict_ner'

def __init__(self, tokenizer: BaseTokenizer,
cdb: CDB) -> None:
super().__init__()
self.tokenizer = tokenizer
self.cdb = cdb
self.config = self.cdb.config
Expand Down Expand Up @@ -60,7 +62,9 @@ def _rebuild_automaton(self):
def get_type(self) -> CoreComponentType:
return CoreComponentType.ner

def __call__(self, doc: MutableDocument) -> MutableDocument:
def predict_entities(self, doc: MutableDocument,
ents: list[MutableEntity] | None = None
) -> list[MutableEntity]:
"""Detect candidates for concepts - linker will then be able
to do the rest. It adds `entities` to the doc.entities and each
entity can have the entity.link_candidates - that the linker
Expand All @@ -69,15 +73,20 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
Args:
doc (MutableDocument):
Spacy document to be annotated with named entities.
ents (list[MutableEntity] | None):
The entities given. This should be None.

Returns:
doc (MutableDocument):
Spacy document with detected entities.
list[MutableEntity]:
The NER'ed entities.
"""
if ents is not None:
ValueError(f"Unexpected entities sent to NER: {ents}")
if self.cdb.has_changed_names:
self.cdb._reset_subnames()
self._rebuild_automaton()
text = doc.base.text.lower()
ner_ents: list[MutableEntity] = []
for end_idx, raw_name in self.automaton.iter(text):
start_idx = end_idx - len(raw_name) + 1
cur_tokens = doc.get_tokens(start_idx, end_idx)
Expand All @@ -96,9 +105,12 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
continue
preprocessed_name = raw_name.replace(
' ', self.config.general.separator)
maybe_annotate_name(self.tokenizer, preprocessed_name, cur_tokens,
doc, self.cdb, self.config)
return doc
ent = maybe_annotate_name(
self.tokenizer, preprocessed_name, cur_tokens,
doc, self.cdb, self.config, len(ner_ents))
if ent:
ner_ents.append(ent)
return ner_ents

@classmethod
def create_new_component(
Expand Down
Loading
Loading