diff --git a/Dockerfile b/Dockerfile
index 256b009..217df14 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,6 +17,11 @@ WORKDIR /work
 
 # install python packages
 COPY requirements.in .
+RUN pip install -r requirements.in
+RUN pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_sm-0.4.0.tar.gz
+RUN python -m spacy download en_core_web_sm
+RUN python -m spacy download en_core_web_md
+
 # add the code as the final step so that when we modify the code
 # we don't bust the cached layers holding the dependencies and
 # system packages.
@@ -25,9 +30,4 @@ COPY scripts/ scripts/
 COPY tests/ tests/
 COPY .flake8 .flake8
 
-RUN pip install -r requirements.in
-RUN pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_sm-0.4.0.tar.gz
-RUN python -m spacy download en_core_web_sm
-RUN python -m spacy download en_core_web_md
-
 CMD [ "/bin/bash" ]
diff --git a/requirements.in b/requirements.in
index 5d099ce..4009dd4 100644
--- a/requirements.in
+++ b/requirements.in
@@ -16,6 +16,7 @@ pytest-cov
 flake8
 black
 mypy
+types-requests
 
 # Required for releases.
 twine
diff --git a/scispacy/abbreviation.py b/scispacy/abbreviation.py
index 732b068..cf5e4aa 100644
--- a/scispacy/abbreviation.py
+++ b/scispacy/abbreviation.py
@@ -151,14 +151,22 @@ class AbbreviationDetector:
 
     nlp: `Language`, a required argument for spacy to use this as a factory
     name: `str`, a required argument for spacy to use this as a factory
+    make_serializable: `bool`, an optional argument (default `False`) controlling whether
+        `doc._.abbreviations` holds serializable dicts rather than `Span` objects.
     """
 
-    def __init__(self, nlp: Language, name: str = "abbreviation_detector") -> None:
+    def __init__(
+        self,
+        nlp: Language,
+        name: str = "abbreviation_detector",
+        make_serializable: bool = False,
+    ) -> None:
         Doc.set_extension("abbreviations", default=[], force=True)
         Span.set_extension("long_form", default=None, force=True)
 
         self.matcher = Matcher(nlp.vocab)
         self.matcher.add("parenthesis", [[{"ORTH": "("}, {"OP": "+"}, {"ORTH": ")"}]])
+        self.make_serializable = make_serializable
         self.global_matcher = Matcher(nlp.vocab)
 
     def find(self, span: Span, doc: Doc) -> Tuple[Span, Set[Span]]:
@@ -186,6 +194,12 @@ def __call__(self, doc: Doc) -> Doc:
             for short in short_forms:
                 short._.long_form = long_form
                 doc._.abbreviations.append(short)
+        if self.make_serializable:
+            abbreviations = doc._.abbreviations
+            doc._.abbreviations = [
+                self.make_short_form_serializable(abbreviation)
+                for abbreviation in abbreviations
+            ]
         return doc
 
     def find_matches_for(
@@ -223,3 +237,24 @@ def find_matches_for(
             self.global_matcher.remove(key)
 
         return list((k, v) for k, v in all_occurences.items())
+
+    def make_short_form_serializable(self, abbreviation: Span):
+        """
+        Converts an abbreviation `Span` into a serializable dict, to enable multiprocessing.
+
+        Parameters
+        ----------
+        abbreviation: Span
+            The abbreviation span identified by the detector
+        """
+        long_form = abbreviation._.long_form
+        abbreviation._.long_form = long_form.text
+        serializable_abbr = {
+            "short_text": abbreviation.text,
+            "short_start": abbreviation.start,
+            "short_end": abbreviation.end,
+            "long_text": long_form.text,
+            "long_start": long_form.start,
+            "long_end": long_form.end,
+        }
+        return serializable_abbr
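Usage sketch (not part of the patch): the new flag is consumed through the standard spacy factory config. This assumes the en_core_sci_sm model installed by the Dockerfile above:

    import spacy
    from scispacy.abbreviation import AbbreviationDetector  # noqa: F401, registers the factory

    if __name__ == "__main__":  # needed for n_process > 1 on spawn-based platforms
        nlp = spacy.load("en_core_sci_sm")
        nlp.add_pipe("abbreviation_detector", config={"make_serializable": True})

        text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets."
        # With make_serializable=True the docs survive multiprocessing, and
        # doc._.abbreviations holds plain dicts instead of Span objects.
        for doc in nlp.pipe([text, text], n_process=2):
            for abbrev in doc._.abbreviations:
                print(abbrev["short_text"], "->", abbrev["long_text"])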
diff --git a/scispacy/candidate_generation.py b/scispacy/candidate_generation.py
index 30e6fa8..7bcf934 100644
--- a/scispacy/candidate_generation.py
+++ b/scispacy/candidate_generation.py
@@ -384,7 +384,11 @@ def create_tfidf_ann_index(
     # Default values resulted in very low recall.
 
     # set to the maximum recommended value. Improves recall at the expense of longer indexing time.
-    # TODO: This variable name is so hot because I don't actually know what this parameter does.
+    # We use the HNSW (Hierarchical Navigable Small World) graph representation, which is constructed
+    # by inserting elements one by one in random order and connecting each to its M closest neighbours
+    # among the previously inserted elements. These links later become bridges between the network
+    # hubs and improve overall graph connectivity. (bigger M -> higher recall, slower index creation)
+    # For more details see: https://arxiv.org/pdf/1603.09320.pdf
     m_parameter = 100
     # `C` for Construction. Set to the maximum recommended value
     # Improves recall at the expense of longer indexing time
diff --git a/scispacy/custom_tokenizer.py b/scispacy/custom_tokenizer.py
index 4731805..3b5ca60 100644
--- a/scispacy/custom_tokenizer.py
+++ b/scispacy/custom_tokenizer.py
@@ -131,6 +131,6 @@ def combined_rule_tokenizer(nlp: Language) -> Tokenizer:
         prefix_search=prefix_re.search,
         suffix_search=suffix_re.search,
         infix_finditer=infix_re.finditer,
-        token_match=nlp.tokenizer.token_match,
+        token_match=nlp.tokenizer.token_match,  # type: ignore
     )
     return tokenizer
diff --git a/scispacy/linking.py b/scispacy/linking.py
index 974bbea..1e247bd 100644
--- a/scispacy/linking.py
+++ b/scispacy/linking.py
@@ -96,20 +96,23 @@ def __init__(
         self.umls = self.kb
 
     def __call__(self, doc: Doc) -> Doc:
-        mentions = []
+        mention_strings = []
         if self.resolve_abbreviations and Doc.has_extension("abbreviations"):
-
+            # TODO: This is possibly sub-optimal - we might
+            # prefer to look up both the long and short forms.
             for ent in doc.ents:
-                # TODO: This is possibly sub-optimal - we might
-                # prefer to look up both the long and short forms.
-                if ent._.long_form is not None:
-                    mentions.append(ent._.long_form)
+                if isinstance(ent._.long_form, Span):
+                    # Long form as a Span (non-serializable detector output)
+                    mention_strings.append(ent._.long_form.text)
+                elif isinstance(ent._.long_form, str):
+                    # Long form already flattened to a string (serializable detector output)
+                    mention_strings.append(ent._.long_form)
                 else:
-                    mentions.append(ent)
+                    # No abbreviation long form; fall back to the entity text
+                    mention_strings.append(ent.text)
         else:
-            mentions = doc.ents
+            mention_strings = [x.text for x in doc.ents]
 
-        mention_strings = [x.text for x in mentions]
         batch_candidates = self.candidate_generator(mention_strings, self.k)
 
         for mention, candidates in zip(doc.ents, batch_candidates):
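To make the role of m_parameter in candidate_generation.py concrete, a self-contained sketch of building an HNSW index with nmslib (dense vectors and illustrative values only; create_tfidf_ann_index itself indexes sparse tf-idf vectors):

    import nmslib
    import numpy as np

    vectors = np.random.rand(1000, 64).astype(np.float32)

    # method="hnsw" selects the graph structure described in the comment above.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(vectors)

    # M: neighbours linked per inserted element (bigger -> higher recall, slower build).
    # efConstruction: the `C` for Construction parameter.
    index.createIndex({"M": 100, "efConstruction": 2000, "post": 0})

    neighbour_ids, distances = index.knnQuery(vectors[0], k=10)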
""" - options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter) + options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter, with_serializable_abbreviation_detector) if options not in LOADED_SPACY_MODELS: disable = ["vectors", "textcat"] if not pos_tags: @@ -46,6 +48,8 @@ def get_spacy_model( spacy_model.tokenizer = combined_rule_tokenizer(spacy_model) if with_sentence_segmenter: spacy_model.add_pipe("pysbd_sentencizer", first=True) + if with_serializable_abbreviation_detector is not None: + spacy_model.add_pipe("abbreviation_detector", config={"make_serializable": with_serializable_abbreviation_detector}) LOADED_SPACY_MODELS[options] = spacy_model return LOADED_SPACY_MODELS[options] @@ -97,9 +101,13 @@ def test_model_dir(): @pytest.fixture() def combined_all_model_fixture(): - nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False) + nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False, with_serializable_abbreviation_detector=True) return nlp +@pytest.fixture() +def combined_all_model_fixture_non_serializable_abbrev(): + nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False, with_serializable_abbreviation_detector=False) + return nlp @pytest.fixture() def combined_rule_prefixes_fixture(): diff --git a/tests/custom_tests/test_all_model.py b/tests/custom_tests/test_all_model.py index 56991a3..1da5d0e 100644 --- a/tests/custom_tests/test_all_model.py +++ b/tests/custom_tests/test_all_model.py @@ -3,6 +3,7 @@ import spacy from spacy.vocab import Vocab import shutil +import pytest def test_custom_segmentation(combined_all_model_fixture): @@ -36,6 +37,31 @@ def test_custom_segmentation(combined_all_model_fixture): ] actual_tokens = [t.text for t in doc] assert expected_tokens == actual_tokens - assert doc.is_parsed + assert doc.has_annotation("DEP") assert doc[0].dep_ == "ROOT" assert doc[0].tag_ == "NN" + +def test_full_pipe_serializable(combined_all_model_fixture): + text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes." + doc = [doc for doc in combined_all_model_fixture.pipe([text, text], n_process = 2)][0] + # If we got here this means that both model is serializable and there is an abbreviation that would break if it wasn't + assert len(doc._.abbreviations) > 0 + abbrev = doc._.abbreviations[0] + assert abbrev["short_text"] == "CEIL" + assert abbrev["long_text"] == "cytokine expression in leukocytes" + assert doc[abbrev["short_start"] : abbrev["short_end"]].text == abbrev["short_text"] + assert doc[abbrev["long_start"] : abbrev["long_end"]].text == abbrev["long_text"] + +def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev): + text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes." 
diff --git a/tests/custom_tests/test_all_model.py b/tests/custom_tests/test_all_model.py
index 56991a3..1da5d0e 100644
--- a/tests/custom_tests/test_all_model.py
+++ b/tests/custom_tests/test_all_model.py
@@ -3,6 +3,7 @@
 import spacy
 from spacy.vocab import Vocab
 import shutil
+import pytest
 
 
 def test_custom_segmentation(combined_all_model_fixture):
@@ -36,6 +37,31 @@ def test_custom_segmentation(combined_all_model_fixture):
     ]
     actual_tokens = [t.text for t in doc]
     assert expected_tokens == actual_tokens
-    assert doc.is_parsed
+    assert doc.has_annotation("DEP")
     assert doc[0].dep_ == "ROOT"
     assert doc[0].tag_ == "NN"
+
+def test_full_pipe_serializable(combined_all_model_fixture):
+    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
+    doc = [doc for doc in combined_all_model_fixture.pipe([text, text], n_process=2)][0]
+    # Getting here means the pipeline survived multiprocessing; the text contains an abbreviation that would otherwise have broken serialization
+    assert len(doc._.abbreviations) > 0
+    abbrev = doc._.abbreviations[0]
+    assert abbrev["short_text"] == "CEIL"
+    assert abbrev["long_text"] == "cytokine expression in leukocytes"
+    assert doc[abbrev["short_start"] : abbrev["short_end"]].text == abbrev["short_text"]
+    assert doc[abbrev["long_start"] : abbrev["long_end"]].text == abbrev["long_text"]
+
+def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
+    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
+    # Running the pipeline directly works; serializing the resulting doc should fail
+    doc = combined_all_model_fixture_non_serializable_abbrev(text)
+    with pytest.raises(TypeError):
+        doc.to_bytes()
+
+# Below is the test version to be used once we move to spacy v3.1.0 or higher
+# def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
+#     text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
+#     # This call requires the pipeline to be serializable (it uses 2 processes), so it should raise here
+#     with pytest.raises(TypeError):
+#         list(combined_all_model_fixture_non_serializable_abbrev.pipe([text, text], n_process=2))
\ No newline at end of file
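For reviewers unfamiliar with the failure mode under test: Doc.to_bytes() has to serialize the user data in doc._.*, and a Span keeps a reference back to its parent Doc, so it cannot be round-tripped; the flat dicts produced by make_short_form_serializable can. A minimal sketch of the failure (assumes en_core_sci_sm is installed; the TypeError is the one test_full_pipe_not_serializable asserts):

    import spacy
    from scispacy.abbreviation import AbbreviationDetector  # noqa: F401, registers the factory

    nlp = spacy.load("en_core_sci_sm")
    nlp.add_pipe("abbreviation_detector")  # default: make_serializable=False

    doc = nlp("Induction of cytokine expression in leukocytes (CEIL).")
    try:
        doc.to_bytes()  # Span values in doc._.abbreviations cannot be serialized
    except TypeError as err:
        print("doc is not serializable:", err)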