Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions medcat-v2/medcat/components/addons/addons.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Callable, Protocol, Any, runtime_checkable
from typing import Callable, Protocol, Any, runtime_checkable, Optional

from medcat.components.types import BaseComponent, MutableEntity
from medcat.utils.registry import Registry
from medcat.config.config import ComponentConfig
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.tokenizing.tokenizers import BaseTokenizer


@runtime_checkable
Expand All @@ -19,9 +22,15 @@ def addon_type(self) -> str:
def is_core(self) -> bool:
return False

@classmethod
def get_folder_name_for_addon_and_name(
cls, addon_type: str, name: str) -> str:
# Build the folder name from explicit parts, so it can be computed
# from the class alone (no addon instance needed):
# NAME_PREFIX + addon_type + NAME_SPLITTER + name.
return (cls.NAME_PREFIX + addon_type +
cls.NAME_SPLITTER + name)

def get_folder_name(self) -> str:
# Folder name for this addon instance, composed as
# NAME_PREFIX + addon_type + NAME_SPLITTER + name.
return (self.NAME_PREFIX + self.addon_type +
self.NAME_SPLITTER + self.name)
# Same composition, delegated to the classmethod helper using this
# instance's own addon type and name.
return self.get_folder_name_for_addon_and_name(
self.addon_type, self.name)

@property
def full_name(self) -> str:
Expand All @@ -36,51 +45,64 @@ def get_output_key_val(self, ent: MutableEntity
pass


# Call signature of an addon creator: it receives the component config,
# the tokenizer, the CDB, the vocab and an optional string (used as the
# model load path by `create_addon`), and returns the addon component.
AddonClass = Callable[[ComponentConfig, BaseTokenizer,
CDB, Vocab, Optional[str]], AddonComponent]


# Built-in addons, keyed by addon name. Each value is a pair of strings:
# a dotted module path and the dotted name of the creator inside it.
# NOTE(review): presumably the registry resolves these on demand rather
# than importing them up front - confirm against medcat.utils.registry.
_DEFAULT_ADDONS: dict[str, tuple[str, str]] = {
'meta_cat': ('medcat.components.addons.meta_cat.meta_cat',
'MetaCATAddon.create_new'),
'MetaCATAddon.create_new_component'),
'rel_cat': ('medcat.components.addons.relation_extraction.rel_cat',
'RelCATAddon.create_new')
'RelCATAddon.create_new_component')
}

# NOTE: type error due to non-concrete type
_ADDON_REGISTRY = Registry(AddonComponent, _DEFAULT_ADDONS) # type: ignore


def register_addon(addon_name: str,
addon_cls: Callable[..., AddonComponent]) -> None:
addon_cls: AddonClass) -> None:
"""Register a new addon creator under the given name.

The creator is stored in the module-level addon registry.

Args:
addon_name (str): The addon name.
addon_cls (Callable[..., AddonComponent]): The addon creator.
addon_cls (AddonClass): The addon creator.
"""
_ADDON_REGISTRY.register(addon_name, addon_cls)


def get_addon_creator(addon_name: str) -> Callable[..., AddonComponent]:
def get_addon_creator(addon_name: str) -> AddonClass:
"""Get the creator for an addon.

The creator is looked up in the module-level addon registry.

Args:
addon_name (str): The name of the addon.

Returns:
Callable[..., AddonComponent]: The creator of the addon.
AddonClass: The creator of the addon.
"""
return _ADDON_REGISTRY.get_component(addon_name)


def create_addon(addon_name: str, cnf: ComponentConfig,
*args, **kwargs) -> AddonComponent:
def create_addon(
addon_name: str, cnf: ComponentConfig,
tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> AddonComponent:
"""Create an addon of the specified name with the specified arguments.

The config and every further argument are passed on to the creator
that is registered under `addon_name`.

Args:
addon_name (str): The name of the addon.
cnf (ComponentConfig): The addon config.
tokenizer (BaseTokenizer): The base tokenizer to be passed to creator.
cdb (CDB): The CDB to be passed to creator.
vocab (Vocab): The Vocab to be passed to creator.
model_load_path (Optional[str]): The optional model load path to be
passed to creator.


Returns:
AddonComponent: The resulting / created addon.
"""
return get_addon_creator(addon_name)(cnf, *args, **kwargs)
return get_addon_creator(addon_name)(
cnf, tokenizer, cdb, vocab, model_load_path)
37 changes: 23 additions & 14 deletions medcat-v2/medcat/components/addons/meta_cat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
from torch import nn, Tensor
from medcat.tokenizing.tokenizers import BaseTokenizer
from medcat.config.config import ComponentConfig
from medcat.config.config_meta_cat import ConfigMetaCAT
from medcat.components.addons.meta_cat.ml_utils import (
predict, train_model, set_all_seeds, eval_model)
Expand All @@ -25,6 +26,7 @@
from medcat.tokenizing.tokens import MutableDocument, MutableEntity
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.utils.defaults import COMPONENTS_FOLDER
from peft import get_peft_model, LoraConfig, TaskType

# It should be safe to do this always, as all other multiprocessing
Expand Down Expand Up @@ -84,6 +86,23 @@ def create_new(cls, config: ConfigMetaCAT, base_tokenizer: BaseTokenizer,
meta_cat = MetaCAT(tokenizer, embeddings=None, config=config)
return cls(config, base_tokenizer, meta_cat)

@classmethod
def create_new_component(
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
) -> 'MetaCATAddon':
# Creator with the common addon-creator signature. `cdb` and `vocab`
# are accepted but not used in this method.
# Raises ValueError when the config is not a ConfigMetaCAT.
if not isinstance(cnf, ConfigMetaCAT):
raise ValueError(f"Incompatible config: {cnf}")
if model_load_path is not None:
# Load a saved addon from
# <model_load_path>/<COMPONENTS_FOLDER>/<addon folder name>, where the
# folder name is derived from the addon type and the configured
# category name.
components_folder = os.path.join(
model_load_path, COMPONENTS_FOLDER)
folder_name = cls.get_folder_name_for_addon_and_name(
cls.addon_type, str(cnf.general.category_name))
load_path = os.path.join(components_folder, folder_name)
return cls.load_existing(cnf, tokenizer, load_path)
# TODO: tokenizer preprocessing for (e.g) BPE tokenizer (see PR #67)
# No load path given: build a fresh addon instead of loading one.
return cls.create_new(cnf, tokenizer, None)

@classmethod
def load_existing(cls, cnf: ConfigMetaCAT,
base_tokenizer: BaseTokenizer,
Expand All @@ -100,18 +119,6 @@ def name(self) -> str:
def __call__(self, doc: MutableDocument) -> MutableDocument:
# Processing is delegated to the object stored in `self.mc`.
return self.mc(doc)

@classmethod
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> list[Any]:
# NOTE: cnf is silent init parameter
return []

@classmethod
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> dict[str, Any]:
# cls.init_tokenizer(cnf, model_load_path)
return {'base_tokenizer': tokenizer}

def load(self, folder_path: str) -> 'MetaCAT':
mc_path, tokenizer_folder = self._get_meta_cat_and_tokenizer_paths(
folder_path)
Expand Down Expand Up @@ -169,8 +176,10 @@ def serialise_to(self, folder_path: str) -> None:
@classmethod
def deserialise_from(cls, folder_path: str, **init_kwargs
) -> 'MetaCATAddon':
# NOTE: model load path sent by kwargs
return cls.load_existing(load_path=folder_path, **init_kwargs)
# Only the 'cnf' and 'tokenizer' entries of init_kwargs are read; the
# tokenizer is handed on as `base_tokenizer`. A missing entry raises
# KeyError; any other entries are ignored.
return cls.load_existing(
load_path=folder_path,
cnf=init_kwargs['cnf'],
base_tokenizer=init_kwargs['tokenizer'])

def get_strategy(self) -> SerialisingStrategy:
# This addon always reports the MANUAL serialising strategy.
return SerialisingStrategy.MANUAL
Expand Down
46 changes: 26 additions & 20 deletions medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import random
from typing import Optional, Any
from typing import Optional

from sklearn.utils import compute_class_weight
import torch
Expand All @@ -18,7 +18,7 @@

from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.config import Config
from medcat.config.config import Config, ComponentConfig
from medcat.config.config_rel_cat import ConfigRelCAT
from medcat.storage.serialisers import deserialise
from medcat.storage.serialisables import SerialisingStrategy
Expand All @@ -32,6 +32,7 @@
from medcat.components.addons.relation_extraction.rel_dataset import RelData
from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer
from medcat.tokenizing.tokens import MutableDocument
from medcat.utils.defaults import COMPONENTS_FOLDER


logger = logging.getLogger(__name__)
Expand All @@ -54,6 +55,20 @@ def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer,
return cls(config,
RelCAT(base_tokenizer, cdb, config=config, init_model=True))

@classmethod
def create_new_component(
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
) -> 'RelCATAddon':
# Creator with the common addon-creator signature. `vocab` is accepted
# but not used in this method.
# Raises ValueError when the config is not a ConfigRelCAT.
if not isinstance(cnf, ConfigRelCAT):
raise ValueError(f"Incompatible config: {cnf}")
config = cnf
if model_load_path is not None:
# Load a saved addon from
# <model_load_path>/<COMPONENTS_FOLDER>/<NAME_PREFIX + addon_type>.
# NOTE(review): this folder name has no NAME_SPLITTER + name suffix,
# unlike AddonComponent.get_folder_name - confirm it matches the
# folder the addon is actually saved to.
load_path = os.path.join(model_load_path, COMPONENTS_FOLDER,
cls.NAME_PREFIX + cls.addon_type)
return cls.load_existing(config, tokenizer, cdb, load_path)
# No load path given: build a fresh addon instead of loading one.
return cls.create_new(config, tokenizer, cdb)

@classmethod
def load_existing(cls, cnf: ConfigRelCAT,
base_tokenizer: BaseTokenizer,
Expand All @@ -70,21 +85,6 @@ def serialise_to(self, folder_path: str) -> None:
os.mkdir(folder_path)
self._rel_cat.save(folder_path)

@classmethod
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> list[Any]:
# NOTE: cnf is silent init parameter
return []

@classmethod
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> dict[str, Any]:
# cls.init_tokenizer(cnf, model_load_path)
return {
'base_tokenizer': tokenizer,
"cdb": cdb
}

@property
def name(self) -> str:
# The addon's name is its addon type, converted to a string.
return str(self.addon_type)
Expand All @@ -95,7 +95,12 @@ def name(self) -> str:
def deserialise_from(cls, folder_path: str, **init_kwargs
) -> 'RelCATAddon':
# NOTE: model load path sent by kwargs
return cls.load_existing(load_path=folder_path, **init_kwargs)
# Only the 'tokenizer', 'cnf' and 'cdb' entries of init_kwargs are read;
# the tokenizer is handed on as `base_tokenizer`. A missing entry raises
# KeyError; any other entries are ignored.
return cls.load_existing(
load_path=folder_path,
base_tokenizer=init_kwargs['tokenizer'],
cnf=init_kwargs['cnf'],
cdb=init_kwargs['cdb'],
)

def get_strategy(self) -> SerialisingStrategy:
# This addon always reports the MANUAL serialising strategy.
return SerialisingStrategy.MANUAL
Expand Down Expand Up @@ -232,7 +237,7 @@ def load(cls, load_path: str = "./") -> "RelCAT":

rel_cat = RelCAT(
# NOTE: this is a throwaway tokenizer just for registrations
create_tokenizer(cdb.config.general.nlp.provider),
create_tokenizer(cdb.config.general.nlp.provider, cdb.config),
cdb=cdb, config=component.relcat_config, task=component.task)
rel_cat.device = device
rel_cat.component = component
Expand Down Expand Up @@ -883,7 +888,8 @@ def predict_text_with_anns(self, text: str, annotations: list[dict]
Doc: spacy doc with the relations.
"""
# NOTE: This runs the specified language, not an empty one
base_tokenizer = create_tokenizer(self.cdb.config.general.nlp.provider)
base_tokenizer = create_tokenizer(
self.cdb.config.general.nlp.provider, self.cdb.config)
doc = base_tokenizer(text)

for ann in annotations:
Expand Down
17 changes: 7 additions & 10 deletions medcat-v2/medcat/components/linking/context_based_linker.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import random
import logging
from typing import Iterator, Optional, Union, Any
from typing import Iterator, Optional, Union

from medcat.components.types import CoreComponentType, AbstractCoreComponent
from medcat.tokenizing.tokens import MutableEntity, MutableDocument
from medcat.components.linking.vector_context_model import (
ContextModel, PerDocumentTokenCache)
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.config import Config
from medcat.config.config import Config, ComponentConfig
from medcat.utils.defaults import StatusTypes as ST
from medcat.utils.postprocessing import create_main_ann
from medcat.tokenizing.tokenizers import BaseTokenizer
Expand Down Expand Up @@ -245,11 +245,8 @@ def train(self, cui: str,
cui, entity, doc, per_doc_valid_token_cache, negative, names)

@classmethod
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> list[Any]:
return [cdb, vocab, cdb.config]

@classmethod
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> dict[str, Any]:
return {}
def create_new_component(
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
) -> 'Linker':
# Creator with the common component-creator signature. Only `cdb` and
# `vocab` are used: the instance is built from them plus the CDB's own
# config. `cnf`, `tokenizer` and `model_load_path` are not used here.
return cls(cdb, vocab, cdb.config)
16 changes: 7 additions & 9 deletions medcat-v2/medcat/components/linking/no_action_linker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Optional
from typing import Optional

from medcat.components.types import CoreComponentType, AbstractCoreComponent
from medcat.tokenizing.tokens import MutableDocument
from medcat.tokenizing.tokenizers import BaseTokenizer
from medcat.cdb.cdb import CDB
from medcat.vocab import Vocab
from medcat.config.config import ComponentConfig


class NoActionLinker(AbstractCoreComponent):
Expand All @@ -17,11 +18,8 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
return doc

@classmethod
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> list[Any]:
return []

@classmethod
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> dict[str, Any]:
return {}
def create_new_component(
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
) -> 'NoActionLinker':
# Creator with the common component-creator signature. None of the
# arguments are used: the linker is constructed without parameters.
return cls()
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from medcat.cdb.cdb import CDB
from medcat.vocab import Vocab
from medcat.config.config import Config, SerialisableBaseModel
from medcat.config.config import Config, SerialisableBaseModel, ComponentConfig
from medcat.utils.defaults import StatusTypes as ST
from medcat.utils.matutils import sigmoid
from medcat.utils.config_utils import temp_changed_config
Expand Down Expand Up @@ -255,14 +255,11 @@ def train(self, cui: str,
per_doc_valid_token_cache=pdc)

@classmethod
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> list[Any]:
return [cdb, vocab, cdb.config]

@classmethod
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
model_load_path: Optional[str]) -> dict[str, Any]:
return {}
def create_new_component(
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
) -> 'TwoStepLinker':
# Creator with the common component-creator signature. Only `cdb` and
# `vocab` are used: the instance is built from them plus the CDB's own
# config. `cnf`, `tokenizer` and `model_load_path` are not used here.
return cls(cdb, vocab, cdb.config)

@property
def two_step_config(self) -> 'TwoStepLinkerConfig':
Expand Down
Loading