diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py index 2226819d7..cd45d4896 100644 --- a/medcat-v2/medcat/cat.py +++ b/medcat-v2/medcat/cat.py @@ -31,6 +31,7 @@ from medcat.utils.defaults import doing_legacy_conversion_message from medcat.utils.defaults import LegacyConversionDisabledError from medcat.utils.usage_monitoring import UsageMonitor +from medcat.utils.import_utils import MissingDependenciesError logger = logging.getLogger(__name__) @@ -157,6 +158,25 @@ def _mp_worker_func( self, texts_and_indices: list[tuple[str, str, bool]] ) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]: + # NOTE: this is needed for subprocesses as otherwise they wouldn't have + # any of these set + # NOTE: these need to be dynamic in case the extras aren't included + try: + from medcat.components.addons.meta_cat import MetaCATAddon + has_meta_cat = True + except MissingDependenciesError: + has_meta_cat = False + try: + from medcat.components.addons.relation_extraction.rel_cat import ( + RelCATAddon) + has_rel_cat = True + except MissingDependenciesError: + has_rel_cat = False + for addon in self._pipeline.iter_addons(): + if has_meta_cat and isinstance(addon, MetaCATAddon): + addon._init_data_paths(self._pipeline.tokenizer) + elif has_rel_cat and isinstance(addon, RelCATAddon): + addon._rel_cat._init_data_paths() return [ (text, text_index, self.get_entities(text, only_cui=only_cui)) for text, text_index, only_cui in texts_and_indices] @@ -180,7 +200,7 @@ def _generate_batches_by_char_length( yield docs docs = [] char_count = clen - docs.append((doc_index, doc, only_cui)) + docs.append((doc, doc_index, only_cui)) if len(docs) > 0: yield docs @@ -326,7 +346,7 @@ def get_entities_multi_texts( if n_process == 1: # just do in series for batch in batch_iter: - for text_index, _, result in self._mp_worker_func(batch): + for _, text_index, result in self._mp_worker_func(batch): yield text_index, result return diff --git a/medcat-v2/tests/test_cat.py 
b/medcat-v2/tests/test_cat.py index 4c8263b54..f945903cc 100644 --- a/medcat-v2/tests/test_cat.py +++ b/medcat-v2/tests/test_cat.py @@ -728,7 +728,7 @@ def test_batching_gets_full_char(self): # has all texts self.assertEqual(sum(len(batch) for batch in batches), self.NUM_TEXTS) # has all characters - self.assertEqual(sum(len(text[1]) for text in batches[0]), + self.assertEqual(sum(len(text[0]) for text in batches[0]), self.total_text_length) def test_batching_gets_all_half_at_a_time(self): @@ -746,7 +746,7 @@ def test_batching_gets_all_half_at_a_time(self): # has all texts self.assertEqual(sum(len(batch) for batch in batches), self.NUM_TEXTS) # has all characters - self.assertEqual(sum(len(text[1]) + self.assertEqual(sum(len(text[0]) for batch in batches for text in batch), self.total_text_length)