Skip to content
15 changes: 11 additions & 4 deletions medcat-v2/medcat/components/addons/meta_cat/meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
def load(self, folder_path: str) -> 'MetaCAT':
    """Load a saved MetaCAT and its tokenizer from ``folder_path``.

    Args:
        folder_path: Folder the MetaCAT was saved to. It is resolved into
            the MetaCAT path and the tokenizer folder, and is also handed
            to ``deserialise`` as ``save_dir_path``.

    Returns:
        The deserialised MetaCAT with its tokenizer attached.
    """
    mc_path, tokenizer_folder = self._get_meta_cat_and_tokenizer_paths(
        folder_path)
    # Deserialise exactly once, forwarding the folder as ``save_dir_path``.
    # NOTE(review): presumably this lets the model pick up files stored
    # alongside it (e.g. 'bert_config.json') - confirm against MetaCAT.
    mc = cast(MetaCAT, deserialise(mc_path, save_dir_path=folder_path))
    mc.tokenizer = self._load_tokenizer(self.config, tokenizer_folder)
    return mc

Expand All @@ -150,6 +150,11 @@ def save(self, folder_path: str) -> None:
raise MisconfiguredMetaCATException(
"Unable to save MetaCAT without a tokenizer")
self.mc.tokenizer.save(tokenizer_folder)
if self.config.model.model_name == 'bert':
model_config_save_path = os.path.join(
folder_path, 'bert_config.json')
self._mc.model.bert_config.to_json_file( # type: ignore
model_config_save_path)

def _init_data_paths(self, base_tokenizer: BaseTokenizer):
# a dictionary like {category_name: value, ...}
Expand Down Expand Up @@ -293,7 +298,7 @@ def get_init_attrs(cls) -> list[str]:

@classmethod
def ignore_attrs(cls) -> list[str]:
    """Return the names of the attributes to ignore.

    NOTE(review): presumably consumed by the serialisation layer to skip
    these attributes - confirm against the caller.

    Returns:
        ``['model', 'save_dir_path']``.
    """
    return ['model', 'save_dir_path']

@classmethod
def include_properties(cls) -> list[str]:
Expand All @@ -308,10 +313,12 @@ def __init__(self,
tokenizer: Optional[TokenizerWrapperBase] = None,
embeddings: Optional[Union[Tensor, numpy.ndarray]] = None,
config: Optional[ConfigMetaCAT] = None,
_model_state_dict: Optional[dict[str, Any]] = None) -> None:
_model_state_dict: Optional[dict[str, Any]] = None,
save_dir_path: Optional[str] = None) -> None:
if config is None:
config = ConfigMetaCAT()
self.config = config
self.save_dir_path = save_dir_path
set_all_seeds(config.general.seed)

self.tokenizer = tokenizer
Expand Down Expand Up @@ -355,7 +362,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
elif config.model.model_name == 'bert':
from medcat.components.addons.meta_cat.models import (
BertForMetaAnnotation)
model = BertForMetaAnnotation(config)
model = BertForMetaAnnotation(config, self.save_dir_path)

if not config.model.model_freeze_layers:
peft_config = LoraConfig(
Expand Down
44 changes: 38 additions & 6 deletions medcat-v2/medcat/components/addons/meta_cat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,54 @@ def forward(self,
class BertForMetaAnnotation(nn.Module):
_keys_to_ignore_on_load_unexpected: list[str] = [r"pooler"] # type: ignore

def __init__(self, config: ConfigMetaCAT):
def __init__(self, config: ConfigMetaCAT,
save_dir_path: Optional[str] = None):
super(BertForMetaAnnotation, self).__init__()
_bertconfig = AutoConfig.from_pretrained(
config.model.model_variant,
num_hidden_layers=config.model.num_layers)
if save_dir_path:
try:
_bertconfig = AutoConfig.from_pretrained(
save_dir_path + "/bert_config.json",
num_hidden_layers=config.model.num_layers)
except Exception as e:
_bertconfig = AutoConfig.from_pretrained(
config.model.model_variant,
num_hidden_layers=config.model.num_layers)
logger.info("BERT config not found locally — "
"downloaded successfully from Hugging Face.")
raise e
else:
_bertconfig = AutoConfig.from_pretrained(
config.model.model_variant,
num_hidden_layers=config.model.num_layers)

if config.model.input_size != _bertconfig.hidden_size:
logger.warning(
"Input size for %s model should be %d, provided input size is "
"%d. Input size changed to %d", config.model.model_variant,
_bertconfig.hidden_size, config.model.input_size,
_bertconfig.hidden_size)

bert = BertModel.from_pretrained(config.model.model_variant,
config=_bertconfig)
try:
bert = BertModel.from_pretrained(
config.model.model_variant,
config=_bertconfig)
except Exception as e:
bert = BertModel(_bertconfig)
if save_dir_path:
logger.info(
"Could not load BERT pretrained weights from Hugging Face."
" BERT model was loaded with random weights.\n"
"This will work the weights will be loaded off disk.")
else:
logger.warning(
"Could not load BERT pretrained weights from Hugging Face."
" BERT model was loaded with random weights.\n"
"DO NOT use this model without loading the model state!",
exc_info=e)

self.config = config
self.bert = bert
self.bert_config = _bertconfig
self.num_labels = config.model.nclasses
for param in self.bert.parameters():
param.requires_grad = not config.model.model_freeze_layers
Expand Down
78 changes: 78 additions & 0 deletions medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import socket
from contextlib import contextmanager

from medcat.components.addons.meta_cat import meta_cat
from medcat.storage.serialisers import serialise, deserialise

import unittest
import tempfile
import os
from functools import partial

import transformers

from .test_meta_cat import FakeTokenizer


@contextmanager
def assert_tries_network():
    """Block socket creation inside the ``with`` body and require an attempt.

    While the context is active, ``socket.socket`` is replaced by a stub
    that records every call and raises
    ``OSError("Network disabled for test")``. The real ``socket.socket`` is
    always put back on exit.

    Raises:
        AssertionError: If the body finished without ever calling
            ``socket.socket``.
    """
    original_socket = socket.socket
    attempts = []

    def blocked_socket(*args, **kwargs):
        # Only the argument counts are recorded; what matters afterwards is
        # whether any attempt happened at all.
        attempts.append((len(args), len(kwargs)))
        raise OSError("Network disabled for test")

    socket.socket = blocked_socket
    try:
        yield
    finally:
        socket.socket = original_socket
    # Explicit raise rather than ``assert`` so the check is not stripped
    # when Python runs with -O.
    if not attempts:
        raise AssertionError("No network calls were made during the test")


# NOTE: the Hugging Face cache has to be bypassed here.
#       Otherwise another part of the test suite may
#       already have downloaded and cached the model,
#       in which case loading would make no network
#       calls at all.
@contextmanager
def force_hf_download():
    """Make ``BertModel.from_pretrained`` pass ``force_download=True``.

    For the duration of the ``with`` body,
    ``transformers.BertModel.from_pretrained`` is replaced by a partial of
    the current callable with ``force_download=True`` pre-bound. The patch
    is undone on exit, whether or not the body raised.
    """
    # Local import: this module imports ``unittest`` but not its ``mock``
    # submodule.
    from unittest import mock

    forced = partial(
        transformers.BertModel.from_pretrained, force_download=True)
    # patch.object puts the class back exactly as it was. If
    # ``from_pretrained`` is inherited rather than defined on BertModel,
    # a plain re-assignment on exit would instead leave a bound method
    # pinned on BertModel for the rest of the test run.
    with mock.patch.object(
            transformers.BertModel, 'from_pretrained', forced):
        yield


class BERTMetaCATTests(unittest.TestCase):
    """Checks that a serialised BERT MetaCAT addon loads with sockets blocked."""

    @classmethod
    def setUpClass(cls):
        """Create a BERT MetaCAT addon and serialise it to a temp folder."""
        cls.cnf = config = meta_cat.ConfigMetaCAT()
        general, model = config.general, config.model
        model.model_name = 'bert'
        general.vocab_size = 10
        model.padding_idx = 5
        general.tokenizer_name = 'bert-tokenizer'
        model.model_variant = 'prajjwal1/bert-tiny'
        general.category_name = 'FAKE_category'
        general.category_value2id = {'Future': 0, 'Past': 2, 'Recent': 1}

        cls.tokenizer = FakeTokenizer()
        cls.meta_cat = meta_cat.MetaCATAddon.create_new(
            config, cls.tokenizer)

        # The serialised addon lives in a class-wide temporary directory
        # that tearDownClass removes.
        cls.temp_dir = tempfile.TemporaryDirectory()
        cls.mc_save_path = os.path.join(
            cls.temp_dir.name, "bert_meta_cat")
        serialise('dill', cls.meta_cat, cls.mc_save_path)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in setUpClass."""
        cls.temp_dir.cleanup()

    def test_no_network_load(self):
        """Deserialising must attempt a socket call yet still succeed."""
        with assert_tries_network(), force_hf_download():
            loaded = deserialise(self.mc_save_path)
        self.assertIsInstance(loaded, meta_cat.MetaCATAddon)
19 changes: 19 additions & 0 deletions medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from medcat.storage.serialisers import serialise, AvailableSerialisers
from medcat.config.config_meta_cat import ConfigMetaCAT
from medcat.config.config import Config
from medcat.utils.defaults import COMPONENTS_FOLDER

import os
import unittest.mock
import unittest
import tempfile

Expand Down Expand Up @@ -104,6 +107,22 @@ def test_can_save_and_load(self):
cat2 = CAT.load_model_pack(file_name)
self.assert_has_meta_cat(cat2, False)

def test_loading_uses_save_dir_path(self):
    """A MetaCAT loaded from a model pack has the expected save_dir_path."""
    with tempfile.TemporaryDirectory() as pack_dir:
        pack_path = self.cat.save_model_pack(
            pack_dir, serialiser_type=self.SER_TYPE)
        # Expected location: <pack without .zip>/<components>/<addon folder>
        unpacked_root = pack_path.removesuffix(".zip")
        expected_dir = os.path.join(
            unpacked_root,
            COMPONENTS_FOLDER,
            self.meta_cat.get_folder_name(),
        )
        loaded_cat = CAT.load_model_pack(pack_path)
        addons = loaded_cat.get_addons_of_type(meta_cat.MetaCATAddon)
        self.assertEqual(len(addons), 1)
        addon = addons[0]
        self.assertIsNotNone(addon.mc.save_dir_path)
        self.assertEqual(addon.mc.save_dir_path, expected_dir)

def test_turns_up_in_output(self):
ents = self.cat.get_entities(
"This is a fit text for rich and chronic disease like fittest.")
Expand Down
Loading