diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py index 43c31f975..d89eafcac 100644 --- a/medcat-v2/medcat/cat.py +++ b/medcat-v2/medcat/cat.py @@ -6,6 +6,7 @@ from concurrent.futures import ProcessPoolExecutor, as_completed, Future import itertools from contextlib import contextmanager +from collections import deque import shutil import zipfile @@ -318,6 +319,58 @@ def _mp_one_batch_per_process( # Yield all results from this batch yield from cur_results + def save_entities_multi_texts( + self, + texts: Union[Iterable[str], Iterable[tuple[str, str]]], + save_dir_path: str, + only_cui: bool = False, + n_process: int = 1, + batch_size: int = -1, + batch_size_chars: int = 1_000_000, + batches_per_save: int = 20, + ) -> None: + """Saves the resulting entities on disk and allows multiprocessing. + + This uses `get_entities_multi_texts` under the hood. But it is designed + to save the data on disk as it comes through. + + Args: + texts (Union[Iterable[str], Iterable[tuple[str, str]]]): + The input text. Either an iterable of raw text or one + in the format of `(text_index, text)`. + save_dir_path (str): + The path where the results are saved. The directory will have + an `annotated_ids.pickle` file containing the + `tuple[list[str], int]` with a list of indices already saved + and the number of parts already saved. In addition there will + be (usually multiple) files in the `part_.pickle` format + with the partial outputs. + only_cui (bool): + Whether to only return CUIs rather than other information + like start/end and annotated value. Defaults to False. + n_process (int): + Number of processes to use. Defaults to 1. + batch_size (int): + The number of texts to batch at a time. A batch of the + specified size will be given to each worker process. + Defaults to -1 and in this case the character count will + be used instead. + batch_size_chars (int): + The maximum number of characters to process in a batch. 
+ Each process will be given a batch of texts with a total + number of characters not exceeding this value. Defaults + to 1,000,000 characters. Set to -1 to disable. + """ + if save_dir_path is None: + raise ValueError("Need to specify a save path (`save_dir_path`), " + f"got {save_dir_path}") + out_iter = self.get_entities_multi_texts( + texts, only_cui=only_cui, n_process=n_process, + batch_size=batch_size, batch_size_chars=batch_size_chars, + save_dir_path=save_dir_path, batches_per_save=batches_per_save) + # NOTE: not keeping anything since it'll be saved on disk + deque(out_iter, maxlen=0) + def get_entities_multi_texts( self, texts: Union[Iterable[str], Iterable[tuple[str, str]]], @@ -376,6 +429,15 @@ def get_entities_multi_texts( saver = BatchAnnotationSaver(save_dir_path, batches_per_save) else: saver = None + yield from self._get_entities_multi_texts( + n_process=n_process, batch_iter=batch_iter, saver=saver) + + def _get_entities_multi_texts( + self, + n_process: int, + batch_iter: Iterator[list[tuple[str, str, bool]]], + saver: Optional[BatchAnnotationSaver], + ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]: + if n_process == 1: # just do in series for batch in batch_iter: diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py index cb37b648a..7a44e42b9 100644 --- a/medcat-v2/tests/test_cat.py +++ b/medcat-v2/tests/test_cat.py @@ -534,13 +534,14 @@ def _do_mp_run_with_save( for name in self.cdb.name2info for negname in self.cdb.name2info if name != negname ] - out_data = list(self.cat.get_entities_multi_texts( + out_data = self.cat.get_entities_multi_texts( in_data, save_dir_path=save_to, batch_size_chars=chars_per_batch, batches_per_save=batches_per_save, n_process=n_process, - )) + ) + out_data = list(out_data) out_dict_all = { key: cdata for key, cdata in out_data } @@ -658,6 +659,29 @@ def test_mp_saves_correct_data_with_3_proc(self): self.assert_correct_loaded_output( in_data, out_dict_all, all_loaded_output) + def 
test_get_entities_multi_texts_with_save_dir_lazy(self): + texts = ["text1", "text2"] + with tempfile.TemporaryDirectory() as tmp_dir: + out = self.cat.get_entities_multi_texts( + texts, + save_dir_path=tmp_dir) + # nothing before manual iter + self.assertFalse(os.listdir(tmp_dir)) + out_list = list(out) + # something was saved + self.assertTrue(os.listdir(tmp_dir)) + # and something was yielded + self.assertEqual(len(out_list), len(texts)) + + def test_save_entities_multi_texts(self): + texts = ["text1", "text2"] + with tempfile.TemporaryDirectory() as tmp_dir: + self.cat.save_entities_multi_texts( + texts, + save_dir_path=tmp_dir) + # stuff was already saved + self.assertTrue(os.listdir(tmp_dir)) + class CATWithDocAddonTests(CATIncludingTests): EXAMPLE_TEXT = "Example text to tokenize" diff --git a/medcat-v2/tests/utils/ner/test_deid.py b/medcat-v2/tests/utils/ner/test_deid.py index c776f229c..f11cc2dec 100644 --- a/medcat-v2/tests/utils/ner/test_deid.py +++ b/medcat-v2/tests/utils/ner/test_deid.py @@ -213,14 +213,15 @@ def test_model_works_deid_text_redact(self): self.assert_deid_redact(anon_text) def test_model_works_deid_multi_text_single_threaded(self): - processed = self.deid_model.deid_multi_text([input_text, input_text], n_process=1) + processed = self.deid_model.deid_multi_texts([input_text, input_text], + n_process=1) self.assertEqual(len(processed), 2) for anon_text in processed: self.assert_deid_annotations(anon_text) def test_model_works_deid_multi_text_single_threaded_redact(self): - processed = self.deid_model.deid_multi_text([input_text, input_text], - n_process=1, redact=True) + processed = self.deid_model.deid_multi_texts([input_text, input_text], + n_process=1, redact=True) self.assertEqual(len(processed), 2) for anon_text in processed: self.assert_deid_redact(anon_text) @@ -229,7 +230,7 @@ def test_model_works_deid_multi_text_single_threaded_redact(self): @unittest.skip("Deid Multiprocess is broken. 
Exits the process, no errors shown") def test_model_can_multiprocess_no_redact(self): - processed = self.deid_model.deid_multi_text( + processed = self.deid_model.deid_multi_texts( [input_text, input_text], n_process=2) self.assertEqual(len(processed), 2) for tid, new_text in enumerate(processed): @@ -245,7 +246,7 @@ def test_model_can_multiprocess_redact(self): """ try: print("Calling test_model_can_multiprocess_redact") - processed = self.deid_model.deid_multi_text( + processed = self.deid_model.deid_multi_texts( [input_text, input_text], n_process=2, redact=True ) print("Finished processing")