From 9fa9dca910db8060310b7c614a1f0ea5c71175b6 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 12:38:22 +0100 Subject: [PATCH 1/9] CU-8699wc4zb: Port PR 67 (offline BERT-based MetaCAT load) to v2 --- .../components/addons/meta_cat/meta_cat.py | 8 ++-- .../components/addons/meta_cat/models.py | 43 ++++++++++++++++--- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index d91ecc90..05ba28a8 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -125,7 +125,7 @@ def __call__(self, doc: MutableDocument) -> MutableDocument: def load(self, folder_path: str) -> 'MetaCAT': mc_path, tokenizer_folder = self._get_meta_cat_and_tokenizer_paths( folder_path) - mc = cast(MetaCAT, deserialise(mc_path)) + mc = cast(MetaCAT, deserialise(mc_path, save_dir_path=folder_path)) mc.tokenizer = self._load_tokenizer(self.config, tokenizer_folder) return mc @@ -308,10 +308,12 @@ def __init__(self, tokenizer: Optional[TokenizerWrapperBase] = None, embeddings: Optional[Union[Tensor, numpy.ndarray]] = None, config: Optional[ConfigMetaCAT] = None, - _model_state_dict: Optional[dict[str, Any]] = None) -> None: + _model_state_dict: Optional[dict[str, Any]] = None, + save_dir_path: Optional[str] = None) -> None: if config is None: config = ConfigMetaCAT() self.config = config + self.save_dir_path = save_dir_path set_all_seeds(config.general.seed) self.tokenizer = tokenizer @@ -355,7 +357,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module: elif config.model.model_name == 'bert': from medcat.components.addons.meta_cat.models import ( BertForMetaAnnotation) - model = BertForMetaAnnotation(config) + model = BertForMetaAnnotation(config, self.save_dir_path) if not config.model.model_freeze_layers: peft_config = LoraConfig( diff --git a/medcat-v2/medcat/components/addons/meta_cat/models.py b/medcat-v2/medcat/components/addons/meta_cat/models.py index 0f6074db..914e5a93 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/models.py +++ b/medcat-v2/medcat/components/addons/meta_cat/models.py @@ -98,11 +98,26 @@ def forward(self, class BertForMetaAnnotation(nn.Module): _keys_to_ignore_on_load_unexpected: list[str] = [r"pooler"] # type: ignore - def __init__(self, config: ConfigMetaCAT): + def __init__(self, config: ConfigMetaCAT, + save_dir_path: Optional[str] = None): super(BertForMetaAnnotation, self).__init__() - _bertconfig = AutoConfig.from_pretrained( - config.model.model_variant, - num_hidden_layers=config.model.num_layers) + if save_dir_path: + try: + _bertconfig = AutoConfig.from_pretrained( + save_dir_path + "/bert_config.json", + num_hidden_layers=config.model.num_layers) + except Exception as e: + _bertconfig = AutoConfig.from_pretrained( + config.model.model_variant, + num_hidden_layers=config.model.num_layers) + logger.info("BERT config not found locally — " + "downloaded successfully from Hugging Face.") + raise e + else: + _bertconfig = AutoConfig.from_pretrained( + config.model.model_variant, + num_hidden_layers=config.model.num_layers) + if config.model.input_size != _bertconfig.hidden_size: logger.warning( "Input size for %s model should be %d, provided input size is " @@ -110,8 +125,24 @@ def __init__(self, config: ConfigMetaCAT): _bertconfig.hidden_size, config.model.input_size, _bertconfig.hidden_size) - bert = BertModel.from_pretrained(config.model.model_variant, - config=_bertconfig) + try: + bert = BertModel.from_pretrained( + config.model.model_variant, + config=_bertconfig) + except Exception as e: + bert = BertModel(_bertconfig) + if save_dir_path: + logger.info( + "Could not load BERT pretrained weights from Hugging Face." + " BERT model was loaded with random weights.\n" + "This will work the weights will be loaded off disk.") + else: + logger.warning( + "Could not load BERT pretrained weights from Hugging Face." + " BERT model was loaded with random weights.\n" + "DO NOT use this model without loading the model state!", + exc_info=e) + self.config = config self.bert = bert self.num_labels = config.model.nclasses From 018446332544f3f913ffc0f8cc5e645cba1aab27 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 12:38:58 +0100 Subject: [PATCH 2/9] CU-8699wc4zb: Add a simple test to make sure that offline loading uses correct path to load model --- .../addons/meta_cat/test_meta_cat.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py index 71a49e7e..75f0126f 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py @@ -6,7 +6,10 @@ from medcat.storage.serialisers import serialise, AvailableSerialisers from medcat.config.config_meta_cat import ConfigMetaCAT from medcat.config.config import Config +from medcat.utils.defaults import COMPONENTS_FOLDER +import os +import unittest.mock import unittest import tempfile @@ -104,6 +107,26 @@ def test_can_save_and_load(self): cat2 = CAT.load_model_pack(file_name) self.assert_has_meta_cat(cat2, False) + def test_loading_uses_save_dir_path(self): + with tempfile.TemporaryDirectory() as temp_dir: + file_name = self.cat.save_model_pack( + temp_dir, serialiser_type=self.SER_TYPE) + exp_meta_cat_path = os.path.join( + file_name.removesuffix(".zip"), + COMPONENTS_FOLDER, + self.meta_cat.get_folder_name() + ) + with unittest.mock.patch.object( + meta_cat.MetaCAT, "__init__", + wraps=meta_cat.MetaCAT.__init__, + autospec=True) as mock_load: + CAT.load_model_pack(file_name) + mock_load.assert_called_once() + _, call_kwargs = mock_load.call_args + self.assertIn('save_dir_path', call_kwargs) + self.assertEqual( + call_kwargs['save_dir_path'], exp_meta_cat_path) + def test_turns_up_in_output(self): ents = self.cat.get_entities( "This is a fit text for rich and chronic disease like fittest.") From fc043fa2f9faf77a728cc28f6c796ce2078d716c Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 13:46:04 +0100 Subject: [PATCH 3/9] CU-8699wc4zb: Fix offline BERT based model load --- medcat-v2/medcat/components/addons/meta_cat/meta_cat.py | 5 +++++ medcat-v2/medcat/components/addons/meta_cat/models.py | 1 + 2 files changed, 6 insertions(+) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index 05ba28a8..b63d3dd4 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -150,6 +150,11 @@ def save(self, folder_path: str) -> None: raise MisconfiguredMetaCATException( "Unable to save MetaCAT without a tokenizer") self.mc.tokenizer.save(tokenizer_folder) + if self.config.model.model_name == 'bert': + model_config_save_path = os.path.join( + folder_path, 'bert_config.json') + self._mc.model.bert_config.to_json_file( # type: ignore + model_config_save_path) def _init_data_paths(self, base_tokenizer: BaseTokenizer): # a dictionary like {category_name: value, ...} diff --git a/medcat-v2/medcat/components/addons/meta_cat/models.py b/medcat-v2/medcat/components/addons/meta_cat/models.py index 914e5a93..5a7ff660 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/models.py +++ b/medcat-v2/medcat/components/addons/meta_cat/models.py @@ -145,6 +145,7 @@ def __init__(self, config: ConfigMetaCAT, self.config = config self.bert = bert + self.bert_config = _bertconfig self.num_labels = config.model.nclasses for param in self.bert.parameters(): param.requires_grad = not config.model.model_freeze_layers From 533c9583d529ab9a8c75ff659ab4f677a8a2ffe4 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 13:46:36 +0100 Subject: [PATCH 4/9] CU-8699wc4zb: Add test to make sure BERT MetaCATs can be loaded when offline --- .../addons/meta_cat/test_bert_meta_cat.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py diff --git a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py new file mode 100644 index 00000000..4770d8f8 --- /dev/null +++ b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py @@ -0,0 +1,55 @@ +import socket +from contextlib import contextmanager + +from medcat.components.addons.meta_cat import meta_cat +from medcat.storage.serialisers import serialise, deserialise + +import unittest +import tempfile +import os + +from .test_meta_cat import FakeTokenizer + + +@contextmanager +def no_network(): + real_socket = socket.socket + + def guard(*args, **kwargs): + raise OSError("Network disabled for test") + + socket.socket = guard + try: + yield + finally: + socket.socket = real_socket + + +class BERTMetaCATTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.cnf = meta_cat.ConfigMetaCAT() + cls.cnf.model.model_name = 'bert' + cls.cnf.general.vocab_size = 10 + cls.cnf.model.padding_idx = 5 + cls.cnf.general.tokenizer_name = 'bert-tokenizer' + cls.cnf.model.model_variant = 'prajjwal1/bert-tiny' + cls.cnf.general.category_name = 'FAKE_category' + cls.cnf.general.category_value2id = { + 'Future': 0, 'Past': 2, 'Recent': 1} + cls.tokenizer = FakeTokenizer() + cls.meta_cat = meta_cat.MetaCATAddon.create_new(cls.cnf, cls.tokenizer) + + cls.temp_dir = tempfile.TemporaryDirectory() + cls.mc_save_path = os.path.join(cls.temp_dir.name, "bert_meta_cat") + serialise('dill', cls.meta_cat, cls.mc_save_path) + + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + + def test_no_network_load(self): + with no_network(): + mc = deserialise(self.mc_save_path) + self.assertIsInstance(mc, meta_cat.MetaCATAddon) From ae1f6a6256fc91e41780cace3d50617019e3eeeb Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 14:30:29 +0100 Subject: [PATCH 5/9] CU-8699wc4zb: Update Bert MetaCAT online test to include checking that online calls were made --- .../tests/components/addons/meta_cat/test_bert_meta_cat.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py index 4770d8f8..d98c03a5 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py @@ -12,10 +12,12 @@ @contextmanager -def no_network(): +def assert_tries_network(): real_socket = socket.socket + calls = [] def guard(*args, **kwargs): + calls.append((len(args), len(kwargs))) raise OSError("Network disabled for test") socket.socket = guard @@ -23,6 +25,7 @@ def guard(*args, **kwargs): yield finally: socket.socket = real_socket + assert calls, "No network calls were made during the test" class BERTMetaCATTests(unittest.TestCase): @@ -50,6 +53,6 @@ def tearDownClass(cls): cls.temp_dir.cleanup() def test_no_network_load(self): - with no_network(): + with assert_tries_network(): mc = deserialise(self.mc_save_path) self.assertIsInstance(mc, meta_cat.MetaCATAddon) From b45b6512d4e168bbbf12f05769c48fd94dc0c727 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 14:38:57 +0100 Subject: [PATCH 6/9] CU-8699wc4zb: Force HF model download during test time --- .../addons/meta_cat/test_bert_meta_cat.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py index d98c03a5..0e82610a 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_bert_meta_cat.py @@ -7,6 +7,9 @@ import unittest import tempfile import os +from functools import partial + +import transformers from .test_meta_cat import FakeTokenizer @@ -28,6 +31,22 @@ def guard(*args, **kwargs): assert calls, "No network calls were made during the test" +# NOTE: need to disable the usage of the cache +# otherwise other parts of the test suite +# might have already downloaded and cached +# the model and no network calls may be made +# in such a situation +@contextmanager +def force_hf_download(): + orig_from_pretrained = transformers.BertModel.from_pretrained + transformers.BertModel.from_pretrained = partial( + orig_from_pretrained, force_download=True) + try: + yield + finally: + transformers.BertModel.from_pretrained = orig_from_pretrained + + class BERTMetaCATTests(unittest.TestCase): @classmethod @@ -54,5 +73,6 @@ def tearDownClass(cls): def test_no_network_load(self): with assert_tries_network(): - mc = deserialise(self.mc_save_path) + with force_hf_download(): + mc = deserialise(self.mc_save_path) self.assertIsInstance(mc, meta_cat.MetaCATAddon) From 714c8c7639143d429892808584d819451e5042f2 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 15:23:57 +0100 Subject: [PATCH 7/9] CU-8699wc4zb: Make sure not to overwrite save_dir_path when loading MetaCAT --- medcat-v2/medcat/components/addons/meta_cat/meta_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py index b63d3dd4..483af253 100644 --- a/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py +++ b/medcat-v2/medcat/components/addons/meta_cat/meta_cat.py @@ -298,7 +298,7 @@ def get_init_attrs(cls) -> list[str]: @classmethod def ignore_attrs(cls) -> list[str]: - return ['model'] + return ['model', 'save_dir_path'] @classmethod def include_properties(cls) -> list[str]: From 15979a1b9ec272fbcbf942b2381969d289727a2c Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 15:24:43 +0100 Subject: [PATCH 8/9] CU-8699wc4zb: Simplify MetaCAT save_dir_path checking test --- .../components/addons/meta_cat/test_meta_cat.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py index 75f0126f..9543af6f 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py @@ -116,16 +116,12 @@ def test_loading_uses_save_dir_path(self): COMPONENTS_FOLDER, self.meta_cat.get_folder_name() ) - with unittest.mock.patch.object( - meta_cat.MetaCAT, "__init__", - wraps=meta_cat.MetaCAT.__init__, - autospec=True) as mock_load: - CAT.load_model_pack(file_name) - mock_load.assert_called_once() - _, call_kwargs = mock_load.call_args - self.assertIn('save_dir_path', call_kwargs) - self.assertEqual( - call_kwargs['save_dir_path'], exp_meta_cat_path) + cat = CAT.load_model_pack(file_name) + meta_cats = cat.get_addons_of_type(meta_cat.MetaCATAddon) + self.assertEqual(len(meta_cats), 1) + mc = meta_cats[0] + self.assertIsNotNone(mc.mc.save_dir_path) + self.assertEquals(mc.mc.save_dir_path, exp_meta_cat_path) def test_turns_up_in_output(self): ents = self.cat.get_entities( From b960ec73aca2d08c6afb7163f97c20b2e93556b5 Mon Sep 17 00:00:00 2001 From: mart-r Date: Tue, 12 Aug 2025 15:38:59 +0100 Subject: [PATCH 9/9] CU-8699wc4zb: Fix typo in assert method --- medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py index 9543af6f..f638246c 100644 --- a/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py +++ b/medcat-v2/tests/components/addons/meta_cat/test_meta_cat.py @@ -121,7 +121,7 @@ def test_loading_uses_save_dir_path(self): self.assertEqual(len(meta_cats), 1) mc = meta_cats[0] self.assertIsNotNone(mc.mc.save_dir_path) - self.assertEquals(mc.mc.save_dir_path, exp_meta_cat_path) + self.assertEqual(mc.mc.save_dir_path, exp_meta_cat_path) def test_turns_up_in_output(self): ents = self.cat.get_entities(