Skip to content
Merged
61 changes: 61 additions & 0 deletions medcat-v2/medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from concurrent.futures import ProcessPoolExecutor, as_completed, Future
import itertools
from contextlib import contextmanager
from collections import deque

import shutil
import zipfile
Expand Down Expand Up @@ -318,6 +319,57 @@ def _mp_one_batch_per_process(
# Yield all results from this batch
yield from cur_results

def save_entities_multi_texts(
        self,
        texts: Union[Iterable[str], Iterable[tuple[str, str]]],
        save_dir_path: str,
        only_cui: bool = False,
        n_process: int = 1,
        batch_size: int = -1,
        batch_size_chars: int = 1_000_000,
        batches_per_save: int = 20,
) -> None:
    """Saves the resulting entities on disk and allows multiprocessing.

    This uses `get_entities_multi_texts` under the hood. But it is designed
    to save the data on disk as it comes through.

    Args:
        texts (Union[Iterable[str], Iterable[tuple[str, str]]]):
            The input text. Either an iterable of raw text or one
            with in the format of `(text_index, text)`.
        save_dir_path (str):
            The path where the results are saved. The directory will have
            a `annotated_ids.pickle` file containing the
            `tuple[list[str], int]` with a list of indices already saved
            and the number of parts already saved. In addition there will
            be (usually multiple) files in the `part_<num>.pickle` format
            with the partial outputs.
        only_cui (bool):
            Whether to only return CUIs rather than other information
            like start/end and annotated value. Defaults to False.
        n_process (int):
            Number of processes to use. Defaults to 1.
        batch_size (int):
            The number of texts to batch at a time. A batch of the
            specified size will be given to each worker process.
            Defaults to -1 and in this case the character count will
            be used instead.
        batch_size_chars (int):
            The maximum number of characters to process in a batch.
            Each process will be given batch of texts with a total
            number of characters not exceeding this value. Defaults
            to 1,000,000 characters. Set to -1 to disable.
        batches_per_save (int):
            The number of annotated batches to accumulate before each
            partial result file is written to disk. Defaults to 20.

    Raises:
        ValueError: If no save path (`save_dir_path`) is specified.
    """
    if save_dir_path is None:
        raise ValueError("Need to specify a save path (`save_dir_path`), "
                         f"got {save_dir_path}")
    out_iter = self.get_entities_multi_texts(
        texts, only_cui=only_cui, n_process=n_process,
        batch_size=batch_size, batch_size_chars=batch_size_chars,
        save_dir_path=save_dir_path, batches_per_save=batches_per_save)
    # NOTE: not keeping anything since it'll be saved on disk.
    # deque with maxlen=0 exhausts the iterator at C speed without
    # retaining any of the yielded results in memory.
    deque(out_iter, maxlen=0)

def get_entities_multi_texts(
self,
texts: Union[Iterable[str], Iterable[tuple[str, str]]],
Expand Down Expand Up @@ -376,6 +428,15 @@ def get_entities_multi_texts(
saver = BatchAnnotationSaver(save_dir_path, batches_per_save)
else:
saver = None
yield from self._get_entities_multi_texts(
n_process=n_process, batch_iter=batch_iter, saver=saver)

def _get_entities_multi_texts(
self,
n_process: int,
batch_iter: Iterator[list[tuple[str, str, bool]]],
saver: Optional[BatchAnnotationSaver],
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
if n_process == 1:
# just do in series
for batch in batch_iter:
Expand Down
28 changes: 26 additions & 2 deletions medcat-v2/tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,14 @@ def _do_mp_run_with_save(
for name in self.cdb.name2info
for negname in self.cdb.name2info if name != negname
]
out_data = list(self.cat.get_entities_multi_texts(
out_data = self.cat.get_entities_multi_texts(
in_data,
save_dir_path=save_to,
batch_size_chars=chars_per_batch,
batches_per_save=batches_per_save,
n_process=n_process,
))
)
out_data = list(out_data)
out_dict_all = {
key: cdata for key, cdata in out_data
}
Expand Down Expand Up @@ -658,6 +659,29 @@ def test_mp_saves_correct_data_with_3_proc(self):
self.assert_correct_loaded_output(
in_data, out_dict_all, all_loaded_output)

def test_get_entities_multi_texts_with_save_dir_lazy(self):
    """The generator must not write to disk until it is iterated."""
    texts = ["text1", "text2"]
    with tempfile.TemporaryDirectory() as tmp_dir:
        result_gen = self.cat.get_entities_multi_texts(
            texts,
            save_dir_path=tmp_dir)
        # no files may appear before the generator is consumed
        self.assertFalse(os.listdir(tmp_dir))
        collected = [item for item in result_gen]
        # consuming the generator triggers the on-disk save
        self.assertTrue(os.listdir(tmp_dir))
        # and one result was yielded per input text
        self.assertEqual(len(collected), len(texts))

def test_save_entities_multi_texts(self):
    """The save variant writes results to disk before returning."""
    texts = ["text1", "text2"]
    with tempfile.TemporaryDirectory() as tmp_dir:
        self.cat.save_entities_multi_texts(
            texts,
            save_dir_path=tmp_dir)
        # output files must already exist once the call has returned
        self.assertTrue(os.listdir(tmp_dir))


class CATWithDocAddonTests(CATIncludingTests):
EXAMPLE_TEXT = "Example text to tokenize"
Expand Down
11 changes: 6 additions & 5 deletions medcat-v2/tests/utils/ner/test_deid.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,15 @@ def test_model_works_deid_text_redact(self):
self.assert_deid_redact(anon_text)

def test_model_works_deid_multi_text_single_threaded(self):
    """Single-process multi-text de-id annotates every input text."""
    docs = [input_text, input_text]
    processed = self.deid_model.deid_multi_texts(docs, n_process=1)
    self.assertEqual(len(processed), 2)
    for anon_text in processed:
        self.assert_deid_annotations(anon_text)

def test_model_works_deid_multi_text_single_threaded_redact(self):
    """Single-process multi-text de-id with redaction redacts every text."""
    docs = [input_text, input_text]
    processed = self.deid_model.deid_multi_texts(
        docs, n_process=1, redact=True)
    self.assertEqual(len(processed), 2)
    for anon_text in processed:
        self.assert_deid_redact(anon_text)
Expand All @@ -229,7 +230,7 @@ def test_model_works_deid_multi_text_single_threaded_redact(self):
@unittest.skip("Deid Multiprocess is broken. Exits the process, no errors shown")
def test_model_can_multiprocess_no_redact(self):

processed = self.deid_model.deid_multi_text(
processed = self.deid_model.deid_multi_texts(
[input_text, input_text], n_process=2)
self.assertEqual(len(processed), 2)
for tid, new_text in enumerate(processed):
Expand All @@ -245,7 +246,7 @@ def test_model_can_multiprocess_redact(self):
"""
try:
print("Calling test_model_can_multiprocess_redact")
processed = self.deid_model.deid_multi_text(
processed = self.deid_model.deid_multi_texts(
[input_text, input_text], n_process=2, redact=True
)
print("Finished processing")
Expand Down
Loading