Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions medcat-v2/medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from medcat.utils.defaults import doing_legacy_conversion_message
from medcat.utils.defaults import LegacyConversionDisabledError
from medcat.utils.usage_monitoring import UsageMonitor
from medcat.utils.import_utils import MissingDependenciesError


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -157,6 +158,25 @@ def _mp_worker_func(
self,
texts_and_indices: list[tuple[str, str, bool]]
) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]:
# NOTE: this is needed for subprocess as otherwise they wouldn't have
# any of these set
# NOTE: these need to by dynamic in case the extra's aren't included
try:
from medcat.components.addons.meta_cat import MetaCATAddon
has_meta_cat = True
except MissingDependenciesError:
has_meta_cat = False
try:
from medcat.components.addons.relation_extraction.rel_cat import (
RelCATAddon)
has_rel_cat = True
except MissingDependenciesError:
has_rel_cat = False
for addon in self._pipeline.iter_addons():
if has_meta_cat and isinstance(addon, MetaCATAddon):
addon._init_data_paths(self._pipeline.tokenizer)
elif has_rel_cat and isinstance(addon, RelCATAddon):
addon._rel_cat._init_data_paths()
return [
(text, text_index, self.get_entities(text, only_cui=only_cui))
for text, text_index, only_cui in texts_and_indices]
Expand All @@ -180,7 +200,7 @@ def _generate_batches_by_char_length(
yield docs
docs = []
char_count = clen
docs.append((doc_index, doc, only_cui))
docs.append((doc, doc_index, only_cui))

if len(docs) > 0:
yield docs
Expand Down Expand Up @@ -326,7 +346,7 @@ def get_entities_multi_texts(
if n_process == 1:
# just do in series
for batch in batch_iter:
for text_index, _, result in self._mp_worker_func(batch):
for _, text_index, result in self._mp_worker_func(batch):
yield text_index, result
return

Expand Down
4 changes: 2 additions & 2 deletions medcat-v2/tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def test_batching_gets_full_char(self):
# has all texts
self.assertEqual(sum(len(batch) for batch in batches), self.NUM_TEXTS)
# has all characters
self.assertEqual(sum(len(text[1]) for text in batches[0]),
self.assertEqual(sum(len(text[0]) for text in batches[0]),
self.total_text_length)

def test_batching_gets_all_half_at_a_time(self):
Expand All @@ -746,7 +746,7 @@ def test_batching_gets_all_half_at_a_time(self):
# has all texts
self.assertEqual(sum(len(batch) for batch in batches), self.NUM_TEXTS)
# has all characters
self.assertEqual(sum(len(text[1])
self.assertEqual(sum(len(text[0])
for batch in batches for text in batch),
self.total_text_length)

Expand Down