diff --git a/medcat-v2-tutorials/notebooks/advanced/2._Create_and_use_component.ipynb b/medcat-v2-tutorials/notebooks/advanced/2._Create_and_use_component.ipynb index 24a2bb7c3..17f571c5f 100644 --- a/medcat-v2-tutorials/notebooks/advanced/2._Create_and_use_component.ipynb +++ b/medcat-v2-tutorials/notebooks/advanced/2._Create_and_use_component.ipynb @@ -38,7 +38,8 @@ "from medcat.cdb.cdb import CDB\n", "from medcat.config.config import Ner\n", "# for the component itself\n", - "from medcat.components.types import AbstractCoreComponent, CoreComponentType\n", + "from medcat.components.types import CoreComponentType\n", + "from medcat.components.types import AbstractEntityProvidingComponent\n", "from medcat.tokenizing.tokens import MutableDocument, MutableEntity\n", "from medcat.components.ner.vocab_based_annotator import maybe_annotate_name\n", "\n", @@ -59,7 +60,7 @@ " return min(max(self.min, num), self.max)\n", "\n", "\n", - "class RandomNER(AbstractCoreComponent):\n", + "class RandomNER(AbstractEntityProvidingComponent):\n", " # NOTE: NEED TO IMPLEMENT\n", " name = \"RANDOM_NER\"\n", "\n", @@ -73,6 +74,7 @@ " # NOTE: NEED TO IMPLEMENT\n", " # you can specify whatever init args as long as you define them above\n", " def __init__(self, tokenizer: BaseTokenizer, cdb: CDB):\n", + " super().__init__()\n", " self.tokenizer = tokenizer\n", " self.cdb = cdb\n", "\n", @@ -90,7 +92,9 @@ " return CoreComponentType.ner\n", "\n", " # NOTE: NEED TO IMPLEMENT\n", - " def __call__(self, doc: MutableDocument) -> MutableDocument:\n", + " def predict_entities(self, doc: MutableDocument,\n", + " ents: list[MutableEntity] | None = None\n", + " ) -> list[MutableEntity]:\n", " \"\"\"Detect candidates for concepts - linker will then be able\n", " to do the rest. It adds `entities` to the doc.entities and each\n", " entity can have the entity.link_candidates - that the linker\n", @@ -99,6 +103,8 @@ " Args:\n", " doc (MutableDocument):\n", " Spacy document to be annotated with named entities.\n", + " ents list[MutableEntity] | None = None:\n", + " The entties to use. None expected here.\n", "\n", " Returns:\n", " doc (MutableDocument):\n", @@ -113,6 +119,7 @@ " for start in start_tkn_indices]\n", " choose_from = list(self.cdb.name2info.keys())\n", " chosen_name = [random.choice(choose_from) for _ in start_tkn_indices]\n", + " ner_ents: list[MutableEntity] = []\n", " for tkn_start_idx, tkn_end_idx, linked_name in zip(start_tkn_indices, end_tkn_indices, chosen_name):\n", " char_start_idx = doc[tkn_start_idx].base.char_index\n", " # NOTE: can only do this since we're never selecting the last token\n", @@ -123,8 +130,10 @@ " # safe to assume that these are all lists of tokens\n", "\n", " # this checks the config (i.e length and stuff) and then annotes\n", - " maybe_annotate_name(self.tokenizer, linked_name, cur_tokens, doc, self.cdb, self.cdb.config)\n", - " return doc\n", + " ent = maybe_annotate_name(self.tokenizer, linked_name, cur_tokens, doc, self.cdb, self.cdb.config, len(ner_ents))\n", + " if ent:\n", + " ner_ents.append(ent)\n", + " return ner_ents\n", "\n" ] }, diff --git a/medcat-v2/medcat/components/linking/context_based_linker.py b/medcat-v2/medcat/components/linking/context_based_linker.py index 860258bb5..f171a931b 100644 --- a/medcat-v2/medcat/components/linking/context_based_linker.py +++ b/medcat-v2/medcat/components/linking/context_based_linker.py @@ -2,7 +2,8 @@ import logging from typing import Iterator, Optional, Union -from medcat.components.types import CoreComponentType, AbstractCoreComponent +from medcat.components.types import CoreComponentType +from medcat.components.types import AbstractEntityProvidingComponent from medcat.tokenizing.tokens import MutableEntity, MutableDocument from medcat.components.linking.vector_context_model import ( ContextModel, PerDocumentTokenCache) @@ -10,7 +11,7 @@ from medcat.vocab import Vocab from medcat.config.config import Config, ComponentConfig from medcat.utils.defaults import StatusTypes as ST -from medcat.utils.postprocessing import create_main_ann +from medcat.utils.postprocessing import filter_linked_annotations from medcat.tokenizing.tokenizers import BaseTokenizer @@ -18,7 +19,7 @@ # class Linker(PipeRunner): -class Linker(AbstractCoreComponent): +class Linker(AbstractEntityProvidingComponent): """Link to a biomedical database. Args: @@ -32,6 +33,7 @@ class Linker(AbstractCoreComponent): # Override def __init__(self, cdb: CDB, vocab: Vocab, config: Config) -> None: + super().__init__() self.cdb = cdb self.vocab = vocab self.config = config @@ -105,9 +107,11 @@ def _process_entity_train(self, doc: MutableDocument, entity.context_similarity = 1 yield entity - def _train_on_doc(self, doc: MutableDocument) -> Iterator[MutableEntity]: + def _train_on_doc(self, doc: MutableDocument, + ner_ents: list[MutableEntity] + ) -> Iterator[MutableEntity]: # Run training - for entity in doc.ner_ents: + for entity in ner_ents: yield from self._process_entity_train( doc, entity, PerDocumentTokenCache()) @@ -186,35 +190,41 @@ def _process_entity_inference( entity.context_similarity = context_similarity yield entity - def _inference(self, doc: MutableDocument) -> Iterator[MutableEntity]: + def _inference(self, doc: MutableDocument, + ner_ents: list[MutableEntity] + ) -> Iterator[MutableEntity]: per_doc_valid_token_cache = PerDocumentTokenCache() - for entity in doc.ner_ents: + for entity in ner_ents: logger.debug("Linker started with entity: %s", entity.base.text) yield from self._process_entity_inference( doc, entity, per_doc_valid_token_cache) - def __call__(self, doc: MutableDocument) -> MutableDocument: + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: # Reset main entities, will be recreated later - doc.linked_ents.clear() cnf_l = self.config.components.linking + if ents is None: + raise ValueError("Need to have NER'ed entities provided") + if cnf_l.train: - linked_entities = self._train_on_doc(doc) + linked_entities = self._train_on_doc(doc, ents) else: - linked_entities = self._inference(doc) + linked_entities = self._inference(doc, ents) # evaluating generator here because the `all_ents` list gets # cleared afterwards otherwise le = list(linked_entities) - doc.ner_ents.clear() - doc.ner_ents.extend(le) - create_main_ann(doc, self.config.general.show_nested_entities) + # doc.ner_ents.clear() + # doc.ner_ents.extend(le) # TODO - reintroduce pretty labels? and apply here? # TODO - reintroduce groups? and map here? - return doc + return filter_linked_annotations( + doc, le, self.config.general.show_nested_entities) def train(self, cui: str, entity: MutableEntity, diff --git a/medcat-v2/medcat/components/linking/embedding_linker.py b/medcat-v2/medcat/components/linking/embedding_linker.py index cd72e7652..c13a48fb6 100644 --- a/medcat-v2/medcat/components/linking/embedding_linker.py +++ b/medcat-v2/medcat/components/linking/embedding_linker.py @@ -1,11 +1,12 @@ from medcat.cdb import CDB from medcat.config.config import Config, ComponentConfig, EmbeddingLinking -from medcat.components.types import CoreComponentType, AbstractCoreComponent +from medcat.components.types import CoreComponentType +from medcat.components.types import AbstractEntityProvidingComponent from medcat.tokenizing.tokens import MutableEntity, MutableDocument from medcat.tokenizing.tokenizers import BaseTokenizer from typing import Optional, Iterator, Set from medcat.vocab import Vocab -from medcat.utils.postprocessing import create_main_ann +from medcat.utils.postprocessing import filter_linked_annotations from tqdm import tqdm from collections import defaultdict import logging @@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) -class Linker(AbstractCoreComponent): +class Linker(AbstractEntityProvidingComponent): name = "embedding_linker" def __init__(self, cdb: CDB, config: Config) -> None: @@ -36,6 +37,7 @@ def __init__(self, cdb: CDB, config: Config) -> None: cdb (CDB): The concept database to use. config (Config): The base config. """ + super().__init__() self.cdb = cdb self.config = config if not isinstance(config.components.linking, EmbeddingLinking): @@ -92,7 +94,7 @@ def create_embeddings(self, using the chosen embedding model.""" if embedding_model_name is None: embedding_model_name = self.cnf_l.embedding_model_name # fallback - + if max_length is not None and max_length != self.max_length: logger.info( "Updating max_length from %s to %s", self.max_length, max_length @@ -548,10 +550,9 @@ def _pre_inference(self, doc: MutableDocument) -> tuple[list, list]: to_infer.append(entity) return le, to_infer - def __call__(self, doc: MutableDocument) -> MutableDocument: - # Reset main entities, will be recreated later - doc.linked_ents.clear() - + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: if self.cdb.is_dirty: logging.warning( "CDB has been modified since last save/load. " @@ -580,11 +581,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: for entities in self._batch_data(to_infer, self.cnf_l.linking_batch_size): le.extend(list(self._inference(doc, entities))) - doc.ner_ents.clear() - doc.ner_ents.extend(le) - create_main_ann(doc) - - return doc + return filter_linked_annotations(doc, le) @property def names_context_matrix(self): diff --git a/medcat-v2/medcat/components/ner/dict_based_ner.py b/medcat-v2/medcat/components/ner/dict_based_ner.py index 463c62201..eefec3070 100644 --- a/medcat-v2/medcat/components/ner/dict_based_ner.py +++ b/medcat-v2/medcat/components/ner/dict_based_ner.py @@ -1,8 +1,9 @@ from typing import Optional import logging -from medcat.tokenizing.tokens import MutableDocument -from medcat.components.types import CoreComponentType, AbstractCoreComponent +from medcat.tokenizing.tokens import MutableDocument, MutableEntity +from medcat.components.types import CoreComponentType +from medcat.components.types import AbstractEntityProvidingComponent from medcat.components.ner.vocab_based_annotator import maybe_annotate_name from medcat.utils.import_utils import ensure_optional_extras_installed from medcat.tokenizing.tokenizers import BaseTokenizer @@ -24,11 +25,12 @@ logger = logging.getLogger(__name__) -class NER(AbstractCoreComponent): +class NER(AbstractEntityProvidingComponent): name = 'cat_dict_ner' def __init__(self, tokenizer: BaseTokenizer, cdb: CDB) -> None: + super().__init__() self.tokenizer = tokenizer self.cdb = cdb self.config = self.cdb.config @@ -60,7 +62,9 @@ def _rebuild_automaton(self): def get_type(self) -> CoreComponentType: return CoreComponentType.ner - def __call__(self, doc: MutableDocument) -> MutableDocument: + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: """Detect candidates for concepts - linker will then be able to do the rest. It adds `entities` to the doc.entities and each entity can have the entity.link_candidates - that the linker @@ -69,15 +73,20 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: Args: doc (MutableDocument): Spacy document to be annotated with named entities. + ents (list[MutableEntity] | None): + The entities given. This should be None. Returns: - doc (MutableDocument): - Spacy document with detected entities. + list[MutableEntity]: + The NER'ed entities. """ + if ents is not None: + ValueError(f"Unexpected entities sent to NER: {ents}") if self.cdb.has_changed_names: self.cdb._reset_subnames() self._rebuild_automaton() text = doc.base.text.lower() + ner_ents: list[MutableEntity] = [] for end_idx, raw_name in self.automaton.iter(text): start_idx = end_idx - len(raw_name) + 1 cur_tokens = doc.get_tokens(start_idx, end_idx) @@ -96,9 +105,12 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: continue preprocessed_name = raw_name.replace( ' ', self.config.general.separator) - maybe_annotate_name(self.tokenizer, preprocessed_name, cur_tokens, - doc, self.cdb, self.config) - return doc + ent = maybe_annotate_name( + self.tokenizer, preprocessed_name, cur_tokens, + doc, self.cdb, self.config, len(ner_ents)) + if ent: + ner_ents.append(ent) + return ner_ents @classmethod def create_new_component( diff --git a/medcat-v2/medcat/components/ner/trf/transformers_ner.py b/medcat-v2/medcat/components/ner/trf/transformers_ner.py index 459ca88b2..513c50ec8 100644 --- a/medcat-v2/medcat/components/ner/trf/transformers_ner.py +++ b/medcat-v2/medcat/components/ner/trf/transformers_ner.py @@ -12,7 +12,7 @@ from medcat.cdb.cdb import CDB from medcat.components.addons.meta_cat.ml_utils import set_all_seeds from medcat.utils.ner import transformers_ner -from medcat.utils.postprocessing import create_main_ann +from medcat.utils.postprocessing import filter_linked_annotations from medcat.utils.hasher import Hasher from medcat.config.config_transformers_ner import ConfigTransformersNER from medcat.config.config import ComponentConfig @@ -26,7 +26,8 @@ serialise, AvailableSerialisers, deserialise) from medcat.storage.serialisables import SerialisingStrategy from medcat.preprocessors.cleaners import NameDescriptor -from medcat.components.types import CoreComponentType, AbstractCoreComponent +from medcat.components.types import CoreComponentType +from medcat.components.types import AbstractEntityProvidingComponent from medcat.vocab import Vocab from medcat.utils.defaults import COMPONENTS_FOLDER @@ -44,7 +45,7 @@ logger = logging.getLogger(__name__) -class TransformersNER(AbstractCoreComponent): +class TransformersNER(AbstractEntityProvidingComponent): name = 'transformers_ner' _def_serialiser = AvailableSerialisers.dill @@ -53,6 +54,7 @@ def __init__(self, cdb: CDB, component: 'TransformersNERComponent', config: Optional[ConfigTransformersNER] = None, training_arguments=None,) -> None: + super().__init__(write_to_linked_ents=True) self._component = component @classmethod @@ -106,8 +108,13 @@ def save(self, folder: str, overwrite: bool = False) -> None: folder, serialiser=self._def_serialiser, overwrite=overwrite) - def __call__(self, doc: MutableDocument) -> MutableDocument: - return self._component(doc) + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: + if ents: + raise ValueError( + "This method should ne be called with pre-defined entities") + return self._component(doc)[1] # for manual serialisability @@ -687,7 +694,8 @@ def batch_generator(stream: Iterable[MutableDocument], yield docs def pipe(self, stream: Iterable[Union[MutableDocument, None]], - *args, **kwargs) -> Iterator[MutableDocument]: + *args, **kwargs) -> Iterator[tuple[MutableDocument, + list[MutableEntity]]]: """Process many documents at once. Args: @@ -700,7 +708,8 @@ def pipe(self, stream: Iterable[Union[MutableDocument, None]], Doc: The same document. Returns: - Iterator[MutableDocument]: If the stream is None or empty. + Iterator[tuple[MutableDocument, list[MutableEntity]]]: The stream + of documents and entities """ # Just in case if stream is None or not stream: @@ -710,11 +719,11 @@ def pipe(self, stream: Iterable[Union[MutableDocument, None]], batch_size_chars = self.config.general.pipe_batch_size_in_chars yield from self._process(stream, batch_size_chars) # type: ignore - def _process_doc(self, doc: MutableDocument): + def _process_doc(self, doc: MutableDocument) -> list[MutableEntity]: aggr_strat = self.config.general.ner_aggregation_strategy res = self.ner_pipe(doc.base.text, aggregation_strategy=aggr_strat) - doc.ner_ents = [] # type: ignore + ents: list[MutableEntity] = [] for r in res: inds = [] for ind, word in enumerate(doc): @@ -732,15 +741,16 @@ def _process_doc(self, doc: MutableDocument): label=r['entity_group']) entity.cui = r['entity_group'] entity.context_similarity = r['score'] - entity.id = len(doc.ner_ents) + entity.id = len(ents) entity.confidence = r['score'] - doc.ner_ents.append(entity) - create_main_ann(doc) + ents.append(entity) + return filter_linked_annotations(doc, ents) def _process(self, stream: Iterable[Union[MutableDocument, None]], - batch_size_chars: int) -> Iterator[Optional[MutableDocument]]: + batch_size_chars: int) -> Iterator[ + tuple[MutableDocument, list[MutableEntity]]]: if not hasattr(self, "ner_pipe"): self.create_eval_pipeline() for docs in self.batch_generator( @@ -748,11 +758,12 @@ def _process(self, # For now we will process the documents one by one, should be # improved in the future to use batching for doc in docs: - self._process_doc(doc) - yield from docs + ents = self._process_doc(doc) + yield doc, ents # Override - def __call__(self, doc: MutableDocument) -> MutableDocument: + def __call__(self, doc: MutableDocument, + ) -> tuple[MutableDocument, list[MutableEntity]]: """Process one document, used in the spacy pipeline for sequential document processing. @@ -761,13 +772,10 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: A spacy document Returns: - Doc: The same spacy document. + tuple[MutableDocument, list[MutableEntity]]: The document and + the corresponding entities. """ - - # Just call the pipe method - doc = next(self.pipe(iter([doc]))) - - return doc + return next(self.pipe(iter([doc]))) # NOTE: Only needed for datasets backwards compatibility diff --git a/medcat-v2/medcat/components/ner/vocab_based_annotator.py b/medcat-v2/medcat/components/ner/vocab_based_annotator.py index 6e65651ad..6e28e6306 100644 --- a/medcat-v2/medcat/components/ner/vocab_based_annotator.py +++ b/medcat-v2/medcat/components/ner/vocab_based_annotator.py @@ -13,9 +13,13 @@ logger = logging.getLogger(__name__) +_START_INDEX_MULT = 1000 + + def annotate_name(tokenizer: BaseTokenizer, name: str, tkns: list[MutableToken], doc: MutableDocument, cdb: CDB, + cur_id: int | None, label: str): entity: MutableEntity = tokenizer.create_entity( doc, tkns[0].base.index, tkns[-1].base.index + 1, label=label) @@ -24,10 +28,29 @@ def annotate_name(tokenizer: BaseTokenizer, name: str, # All standard name entity recognition models will not set this. entity.detected_name = name entity.link_candidates = list(cdb.name2info[name]['per_cui_status']) - entity.id = len(doc.ner_ents) + + if cur_id is None: + logger.warning( + "`medcat.components.ner.vocab_based_annotator.annotate_name` " + "was called with no `cur_id`. This behaviour is not fully " + "supported anymore.") + start_index = entity.base.start_char_index + span_len = len(name) + cur_id = start_index * _START_INDEX_MULT + span_len + # NOTE: These will be unique if the maximum length of each + # entity does not exceed _START_INDEX_MULT (1000) + logger.warning( + "Using the text start index %d (multiplied by %d) and adding " + "the span length %d to get the id of %d", start_index, + _START_INDEX_MULT, span_len, cur_id) + logger.warning( + "Setting MutableDocument.ner_ents during the method " + "`medcat.components.ner.vocab_based_annotator.annotate_name` " + "because the old API (without an ID) was used") + doc.ner_ents.append(entity) # TODO: remove this + + entity.id = cur_id entity.confidence = -1 # This does not calculate confidence - # Append the entity to the document - doc.ner_ents.append(entity) # Not necessary, but why not logger.debug("NER detected an entity.\n\tDetected name: %s" + @@ -39,6 +62,7 @@ def annotate_name(tokenizer: BaseTokenizer, name: str, def maybe_annotate_name(tokenizer: BaseTokenizer, name: str, tkns: list[MutableToken], doc: MutableDocument, cdb: CDB, config: Config, + cur_id: int | None = None, label: str = 'concept' ) -> Optional[MutableEntity]: """Given a name it will check should it be annotated based on config rules. @@ -57,6 +81,8 @@ def maybe_annotate_name(tokenizer: BaseTokenizer, name: str, Concept database. config (Config): Global config for medcat. + cur_id (int | None): + The potential ID for the entity. Defaults to None. label (str): Label for this name (usually `concept` if we are using a vocab based approach). @@ -85,6 +111,7 @@ def maybe_annotate_name(tokenizer: BaseTokenizer, name: str, if (len(name) >= config.components.ner.upper_case_limit_len or (len(tkns) == 1 and tkns[0].base.is_upper)): # Everything is fine, mark name - return annotate_name(tokenizer, name, tkns, doc, cdb, label) + return annotate_name( + tokenizer, name, tkns, doc, cdb, cur_id, label) return None diff --git a/medcat-v2/medcat/components/ner/vocab_based_ner.py b/medcat-v2/medcat/components/ner/vocab_based_ner.py index afd12e41e..d214714a0 100644 --- a/medcat-v2/medcat/components/ner/vocab_based_ner.py +++ b/medcat-v2/medcat/components/ner/vocab_based_ner.py @@ -1,8 +1,9 @@ from typing import Optional import logging -from medcat.tokenizing.tokens import MutableDocument -from medcat.components.types import CoreComponentType, AbstractCoreComponent +from medcat.tokenizing.tokens import MutableDocument, MutableEntity +from medcat.components.types import CoreComponentType +from medcat.components.types import AbstractEntityProvidingComponent from medcat.components.ner.vocab_based_annotator import maybe_annotate_name from medcat.tokenizing.tokenizers import BaseTokenizer from medcat.vocab import Vocab @@ -13,11 +14,12 @@ logger = logging.getLogger(__name__) -class NER(AbstractCoreComponent): +class NER(AbstractEntityProvidingComponent): name = 'cat_ner' def __init__(self, tokenizer: BaseTokenizer, cdb: CDB) -> None: + super().__init__() self.tokenizer = tokenizer self.cdb = cdb self.config = self.cdb.config @@ -25,7 +27,9 @@ def __init__(self, tokenizer: BaseTokenizer, def get_type(self) -> CoreComponentType: return CoreComponentType.ner - def __call__(self, doc: MutableDocument) -> MutableDocument: + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: """Detect candidates for concepts - linker will then be able to do the rest. It adds `entities` to the doc.entities and each entity can have the entity.link_candidates - that the linker @@ -34,15 +38,18 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: Args: doc (MutableDocument): Spacy document to be annotated with named entities. + ents (list[MutableEntity] | None): + The entities given. This should be None. Returns: - doc (MutableDocument): - Spacy document with detected entities. + list[MutableEntity]: + The NER'ed entities. """ max_skip_tokens = self.config.components.ner.max_skip_tokens _sep = self.config.general.separator # Just take the tokens we need _doc = [tkn for tkn in doc if not tkn.to_skip] + ner_ents: list[MutableEntity] = [] for i, tkn in enumerate(_doc): tkn = _doc[i] tkns = [tkn] @@ -60,8 +67,11 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: break # if name is in CDB if name in self.cdb.name2info and not tkn.base.is_stop: - maybe_annotate_name(self.tokenizer, name, tkns, doc, - self.cdb, self.config) + ent = maybe_annotate_name( + self.tokenizer, name, tkns, doc, + self.cdb, self.config, len(ner_ents)) + if ent: + ner_ents.append(ent) # if name is not a subname CDB (explicitly) if not name: # There has to be at least something appended to the name @@ -97,16 +107,21 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: if name_changed: if name in self.cdb.name2info: - maybe_annotate_name(self.tokenizer, name, tkns, doc, - self.cdb, self.config) + ent = maybe_annotate_name( + self.tokenizer, name, tkns, doc, + self.cdb, self.config, len(ner_ents)) + if ent: + ner_ents.append(ent) elif name_reverse is not None: if name_reverse in self.cdb.name2info: - maybe_annotate_name(self.tokenizer, name_reverse, tkns, - doc, self.cdb, self.config) + ent = maybe_annotate_name( + self.tokenizer, name_reverse, tkns, + doc, self.cdb, self.config, len(ner_ents)) + if ent: + ner_ents.append(ent) else: break - - return doc + return ner_ents @classmethod def create_new_component( diff --git a/medcat-v2/medcat/components/types.py b/medcat-v2/medcat/components/types.py index e2c28706d..a3aa549eb 100644 --- a/medcat-v2/medcat/components/types.py +++ b/medcat-v2/medcat/components/types.py @@ -1,6 +1,8 @@ from typing import Optional, Protocol, Callable, runtime_checkable, Union +from typing import Literal from typing_extensions import Self from enum import Enum, auto +from abc import ABC, abstractmethod from medcat.utils.registry import Registry, MedCATRegistryException from medcat.tokenizing.tokens import MutableDocument, MutableEntity @@ -69,7 +71,7 @@ def get_type(self) -> CoreComponentType: pass -class AbstractCoreComponent(CoreComponent): +class AbstractCoreComponent(ABC, CoreComponent): NAME_PREFIX = "core_" @property @@ -80,6 +82,78 @@ def is_core(self) -> bool: return True +class AbstractEntityProvidingComponent(AbstractCoreComponent): + """This is an abstract NER or linker component. + + The class simplifies some things so that they don't have to be + re-implemented in each implementation. + """ + + def __init__(self, + read_from_linked_ents: bool | Literal['auto'] = 'auto', + write_to_linked_ents: bool | Literal['auto'] = 'auto'): + is_linker = self.get_type() == CoreComponentType.linking + if read_from_linked_ents == 'auto': + self._read_from_linked_ents = is_linker + else: + self._read_from_linked_ents = read_from_linked_ents + if write_to_linked_ents == 'auto': + self._write_to_linked_ents = is_linker + else: + self._write_to_linked_ents = write_to_linked_ents + + # NOTE: These 2 are separated as methods to allow for custom behaviour + # when deeriving from this class + def get_ents_in(self, doc: MutableDocument) -> list[MutableEntity] | None: + return doc.ner_ents.copy() if self._read_from_linked_ents else None + + def set_ents(self, doc: MutableDocument, ents: list[MutableEntity] + ) -> None: + if self._write_to_linked_ents: + self.set_linked_ents(doc, ents) + else: + self.set_ner_ents(doc, ents) + + @classmethod + def set_ner_ents(cls, doc: MutableDocument, ents: list[MutableEntity] + ) -> None: + doc.ner_ents.clear() + doc.ner_ents.extend(ents) + + @classmethod + def set_linked_ents(cls, doc: MutableDocument, ents: list[MutableEntity] + ) -> None: + doc.linked_ents.clear() + doc.linked_ents.extend(ents) + + @abstractmethod + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: + """Predict the relevant entities for the document. + + This is meant to be used for the NER or the Linker component. + The idea is that this is the specific implementation only really + needs to implement this method for inference to work. + + Args: + doc (MutableDocument): The document. + ents (list[MutableEntity] | None, optional): The entities to + consider (if any). If None, all possible entities in the + document are considered. Defaults to None. + + Returns: + list[MutableEntity]: The predicted entities in document. + """ + pass + + def __call__(self, doc: MutableDocument) -> MutableDocument: + in_ents = self.get_ents_in(doc) + out_ents = self.predict_entities(doc, in_ents) + self.set_ents(doc, out_ents) + return doc + + @runtime_checkable class HashableComponet(Protocol): diff --git a/medcat-v2/medcat/utils/postprocessing.py b/medcat-v2/medcat/utils/postprocessing.py index 3e313825f..817876b8b 100644 --- a/medcat-v2/medcat/utils/postprocessing.py +++ b/medcat-v2/medcat/utils/postprocessing.py @@ -1,29 +1,51 @@ +import warnings + from medcat.tokenizing.tokenizers import MutableDocument, MutableEntity +def create_main_ann(doc: MutableDocument, show_nested_entities: bool = False) -> None: + warnings.warn( + "The `medcat.utils.postprocessing.create_main_ann` method is" + "depreacated and subject to removal in a future release. Please " + "use `medcat.utils.postprocessing.filter_linked_annotations` instead.", + DeprecationWarning, + stacklevel=2 + ) + doc.linked_ents = filter_linked_annotations( # type: ignore + doc, doc.ner_ents, show_nested_entities=show_nested_entities) + + # NOTE: the following used (in medcat v1) check tuis # but they were never passed to the method so # I've omitted it now -def create_main_ann(doc: MutableDocument, show_nested_entities: bool = False) -> None: +def filter_linked_annotations( + doc: MutableDocument, + linked_ents: list[MutableEntity], + show_nested_entities: bool = False + ) -> list[MutableEntity]: """Creates annotation in the spacy ents list from all the annotations for this document. Args: doc (Doc): Spacy document. + linked_ents (list[MutableEntity]): The linked entities. show_nested_entities (bool): Whether to keep overlapping/nested entities. If True, keeps all entities. If False, filters overlapping entities keeping only the longest matches. Defaults to False. + + Returns: + list[MutbaleEntity]: The resulting entities """ if show_nested_entities: - doc.linked_ents = sorted(list(doc.linked_ents) + doc.ner_ents, # type: ignore - key=lambda ent: ent.base.start_char_index) + return sorted(list(linked_ents), + key=lambda ent: ent.base.start_char_index) else: # Filter overlapping entities using token indices (not object identity) - doc.ner_ents.sort(key=lambda x: len(x.base.text), reverse=True) + linked_ents.sort(key=lambda x: len(x.base.text), reverse=True) tkns_in = set() # Set of token indices main_anns: list[MutableEntity] = [] - for ent in doc.ner_ents: + for ent in linked_ents: to_add = True for tkn in ent: if tkn.base.index in tkns_in: # Use token index instead @@ -34,7 +56,5 @@ def create_main_ann(doc: MutableDocument, show_nested_entities: bool = False) -> tkns_in.add(tkn.base.index) main_anns.append(ent) - # unclear why the original doc.linked_ents needs to be preserved here. - doc.linked_ents = sorted(list(doc.linked_ents) + main_anns, # type: ignore - key=lambda ent: ent.base.start_char_index) - + return sorted(main_anns, + key=lambda ent: ent.base.start_char_index) diff --git a/medcat-v2/tests/components/ner/test_vocab_based_annotator.py b/medcat-v2/tests/components/ner/test_vocab_based_annotator.py new file mode 100644 index 000000000..877cd683a --- /dev/null +++ b/medcat-v2/tests/components/ner/test_vocab_based_annotator.py @@ -0,0 +1,41 @@ +from collections import defaultdict + +from medcat.components.ner import vocab_based_annotator +from medcat.tokenizing.tokenizers import create_tokenizer +from medcat.config import Config + +import unittest +import unittest.mock + + +class MaybeAnnotateNameTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + cls.tokenizer = create_tokenizer("regex", cls.cnf) + cls.example_name = "some long name" + cls.tokens = list(cls.tokenizer(cls.example_name)[:]) + cls.mock_cdb = unittest.mock.Mock() + cls.mock_cdb.name2info = defaultdict(lambda: defaultdict(lambda: "P")) + + def setUp(self): + self.mock_doc = unittest.mock.Mock() + # self.mock_doc.ner_ents = unittest.mock.Mock() + self.mock_doc._tokens = self.tokens + self.mock_doc.ner_ents.append = unittest.mock.Mock() + self.mock_doc.ner_ents.__len__ = unittest.mock.Mock(return_value=0) + + def test_old_API_has_side_effects(self): + vocab_based_annotator.maybe_annotate_name( + self.tokenizer, self.example_name, + tkns=self.tokens, doc=self.mock_doc, cdb=self.mock_cdb, + config=self.cnf) + self.mock_doc.ner_ents.append.assert_called_once() + + def test_new_API_has_no_side_effects(self): + vocab_based_annotator.maybe_annotate_name( + self.tokenizer, self.example_name, + tkns=self.tokens, doc=self.mock_doc, cdb=self.mock_cdb, + config=self.cnf, cur_id=1) + self.mock_doc.ner_ents.append.assert_not_called() diff --git a/medcat-v2/tests/utils/ner/test_deid.py b/medcat-v2/tests/utils/ner/test_deid.py index df684d9de..35289083e 100644 --- a/medcat-v2/tests/utils/ner/test_deid.py +++ b/medcat-v2/tests/utils/ner/test_deid.py @@ -215,7 +215,7 @@ def test_model_works_deid_text(self): def test_model_works_dunder_call(self): anon_doc = self.deid_model(input_text) self.assertIsInstance(anon_doc, runtime_checkable(MutableDocument)) - self.assertTrue(anon_doc.ner_ents) + self.assertTrue(anon_doc.linked_ents) def test_model_works_deid_text_redact(self): anon_text = self.deid_model.deid_text(input_text, redact=True) diff --git a/medcat-v2/tests/utils/test_postprocessing.py b/medcat-v2/tests/utils/test_postprocessing.py index 12b56ed9f..42c1410c8 100644 --- a/medcat-v2/tests/utils/test_postprocessing.py +++ b/medcat-v2/tests/utils/test_postprocessing.py @@ -1,8 +1,11 @@ import unittest +import unittest.mock from unittest.mock import Mock, MagicMock from typing import List -from medcat.utils.postprocessing import create_main_ann +from medcat.utils.postprocessing import filter_linked_annotations, create_main_ann +from medcat.components.types import AbstractEntityProvidingComponent + def create_mock_entity(text: str, start_char: int, end_char: int, cui: str = None, tokens: List = None): """Helper function to create a mock entity with minimal setup.""" @@ -60,7 +63,8 @@ def test_show_nested_entities_false_should_filter_overlaps(self): self.doc.ner_ents = [self.entity_chest_pain, self.entity_chest, self.entity_pain] - create_main_ann(self.doc, show_nested_entities=False) + AbstractEntityProvidingComponent.set_linked_ents( + self.doc, filter_linked_annotations(self.doc, self.doc.ner_ents, show_nested_entities=False)) entity_texts = [ent.base.text for ent in self.doc.linked_ents] @@ -75,7 +79,8 @@ def test_show_nested_entities_true_should_keep_overlaps(self): self.doc.ner_ents = [self.entity_chest_pain, self.entity_chest, self.entity_pain] - create_main_ann(self.doc, show_nested_entities=True) + AbstractEntityProvidingComponent.set_linked_ents( + self.doc, filter_linked_annotations(self.doc, self.doc.ner_ents, show_nested_entities=True)) entity_texts = [ent.base.text for ent in self.doc.linked_ents] @@ -96,7 +101,8 @@ def test_non_overlapping_entities_always_kept(self): self.doc.ner_ents = [self.entity_chest_pain, entity_dm] # Test with show_nested_entities=False - create_main_ann(self.doc, show_nested_entities=False) + AbstractEntityProvidingComponent.set_linked_ents( + self.doc, filter_linked_annotations(self.doc, self.doc.ner_ents, show_nested_entities=False)) entity_texts = [ent.base.text for ent in self.doc.linked_ents] @@ -130,7 +136,8 @@ def test_same_concept_multiple_locations(self): # Test with show_nested_entities=False self.doc.ner_ents = [entity_chest_pain_1, entity_chest_pain_2, entity_chest_1, entity_pain_1_overlap] - create_main_ann(self.doc, show_nested_entities=False) + AbstractEntityProvidingComponent.set_linked_ents( + self.doc, filter_linked_annotations(self.doc, self.doc.ner_ents, show_nested_entities=False)) entity_texts = [ent.base.text for ent in self.doc.linked_ents] entity_positions = [(ent.base.text, ent.base.start_char_index, ent.base.end_char_index) @@ -170,7 +177,8 @@ def test_same_concept_multiple_locations_with_nested_true(self): # Test with show_nested_entities=True self.doc.ner_ents = [entity_chest_pain_1, entity_chest_pain_2, entity_chest_1, entity_pain_1_overlap] - create_main_ann(self.doc, show_nested_entities=True) + AbstractEntityProvidingComponent.set_linked_ents( + self.doc, filter_linked_annotations(self.doc, self.doc.ner_ents, show_nested_entities=True)) entity_texts = [ent.base.text for ent in self.doc.linked_ents] @@ -181,5 +189,31 @@ def test_same_concept_multiple_locations_with_nested_true(self): self.assertIn("pain", entity_texts, "Should keep overlapping 'pain' entity") +class TestCreateMainAnn(unittest.TestCase): + + def setUp(self): + # self.mock_doc = unittest.mock.Mock() + # self.mock_doc.linked_ents.__iter__ = unittest.mock.Mock( + # return_value=iter([])) + self.mock_doc = create_mock_document( + f"{'st0':10s}{'st1':10s}{'st2':10s}{'st3':10s}") + # self.mock_doc.linked_ents.append = unittest.mock.Mock() + self.mock_entities = [create_mock_entity( + f"st{index}", index * 10, index * 10 + 3, cui="C1" + ) for index in range(4)] + self.mock_doc.ner_ents = self.mock_entities + + def test_init_doc_has_no_linked_ents(self): + self.assertEqual(len(self.mock_doc.linked_ents), 0) + + def test_create_main_ann_has_side_effect(self): + create_main_ann(self.mock_doc) + self.assertGreaterEqual(len(self.mock_doc.linked_ents), 1) + + def test_filter_linked_annotations_has_no_side_effect(self): + filter_linked_annotations(self.mock_doc, self.mock_entities) + self.assertEqual(len(self.mock_doc.linked_ents), 0) + + if __name__ == '__main__': unittest.main() \ No newline at end of file