
Commit

Merge branch 'allenai:main' into main
MichalMalyska committed Jan 11, 2022
2 parents 53a0ecc + 3d153dd commit b14e8f0
Showing 8 changed files with 98 additions and 21 deletions.
10 changes: 5 additions & 5 deletions Dockerfile
@@ -17,6 +17,11 @@ WORKDIR /work
# install python packages
COPY requirements.in .

RUN pip install -r requirements.in
RUN pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_sm-0.4.0.tar.gz
RUN python -m spacy download en_core_web_sm
RUN python -m spacy download en_core_web_md

# add the code as the final step so that when we modify the code
# we don't bust the cached layers holding the dependencies and
# system packages.
@@ -25,9 +30,4 @@ COPY scripts/ scripts/
COPY tests/ tests/
COPY .flake8 .flake8

RUN pip install -r requirements.in
RUN pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_sm-0.4.0.tar.gz
RUN python -m spacy download en_core_web_sm
RUN python -m spacy download en_core_web_md

CMD [ "/bin/bash" ]
1 change: 1 addition & 0 deletions requirements.in
@@ -16,6 +16,7 @@ pytest-cov
flake8
black
mypy
types-requests

# Required for releases.
twine
37 changes: 36 additions & 1 deletion scispacy/abbreviation.py
@@ -151,14 +151,22 @@ class AbbreviationDetector:
    nlp: `Language`, a required argument for spacy to use this as a factory
    name: `str`, a required argument for spacy to use this as a factory
    make_serializable: `bool`, an optional argument for whether we want to use the serializable
        or non-serializable version.
    """

    def __init__(self, nlp: Language, name: str = "abbreviation_detector") -> None:
    def __init__(
        self,
        nlp: Language,
        name: str = "abbreviation_detector",
        make_serializable: bool = False,
    ) -> None:
        Doc.set_extension("abbreviations", default=[], force=True)
        Span.set_extension("long_form", default=None, force=True)

        self.matcher = Matcher(nlp.vocab)
        self.matcher.add("parenthesis", [[{"ORTH": "("}, {"OP": "+"}, {"ORTH": ")"}]])
        self.make_serializable = make_serializable
        self.global_matcher = Matcher(nlp.vocab)

    def find(self, span: Span, doc: Doc) -> Tuple[Span, Set[Span]]:
@@ -186,6 +194,12 @@ def __call__(self, doc: Doc) -> Doc:
            for short in short_forms:
                short._.long_form = long_form
                doc._.abbreviations.append(short)
        if self.make_serializable:
            abbreviations = doc._.abbreviations
            doc._.abbreviations = [
                self.make_short_form_serializable(abbreviation)
                for abbreviation in abbreviations
            ]
        return doc

    def find_matches_for(
@@ -223,3 +237,24 @@ def find_matches_for(
            self.global_matcher.remove(key)

        return list((k, v) for k, v in all_occurences.items())

    def make_short_form_serializable(self, abbreviation: Span):
        """
        Converts the abbreviation span into a serializable dictionary,
        so that docs can be passed between processes.

        Parameters
        ----------
        abbreviation: Span
            The abbreviation span identified by the detector
        """
        long_form = abbreviation._.long_form
        abbreviation._.long_form = long_form.text
        serializable_abbr = {
            "short_text": abbreviation.text,
            "short_start": abbreviation.start,
            "short_end": abbreviation.end,
            "long_text": long_form.text,
            "long_start": long_form.start,
            "long_end": long_form.end,
        }
        return serializable_abbr
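For illustration, a minimal usage sketch of the new flag (assuming en_core_sci_sm is installed; importing scispacy.abbreviation registers the factory). With make_serializable=True, doc._.abbreviations holds plain dicts instead of Span objects:

import spacy
from scispacy.abbreviation import AbbreviationDetector  # noqa: F401, registers the factory

nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe("abbreviation_detector", config={"make_serializable": True})

doc = nlp("Induction of cytokine expression in leukocytes (CEIL) by thrombin.")
for abrv in doc._.abbreviations:
    # Each entry is a plain dict, e.g.
    # {"short_text": "CEIL", "long_text": "cytokine expression in leukocytes", ...}
    print(abrv["short_text"], "->", abrv["long_text"])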
6 changes: 5 additions & 1 deletion scispacy/candidate_generation.py
@@ -384,7 +384,11 @@ def create_tfidf_ann_index(
    # Default values resulted in very low recall.

    # set to the maximum recommended value. Improves recall at the expense of longer indexing time.
    # TODO: This variable name is so hot because I don't actually know what this parameter does.
    # We use the HNSW (Hierarchical Navigable Small World Graph) representation which is constructed
    # by consecutive insertion of elements in a random order by connecting them to M closest neighbours
    # from the previously inserted elements. These later become bridges between the network hubs that
    # improve overall graph connectivity. (bigger M -> higher recall, slower creation)
    # For more details see: https://arxiv.org/pdf/1603.09320.pdf?
    m_parameter = 100
    # `C` for Construction. Set to the maximum recommended value
    # Improves recall at the expense of longer indexing time
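As a point of reference, a sketch of how these knobs are typically passed to nmslib when building an HNSW index over sparse TF-IDF vectors; the tiny matrix is a hypothetical stand-in, and the real call site lives elsewhere in this function:

import nmslib
from scipy.sparse import csr_matrix

# Hypothetical tiny TF-IDF matrix standing in for the real vectors.
tfidf_vectors = csr_matrix([[0.1, 0.9, 0.0], [0.8, 0.2, 0.0], [0.0, 0.5, 0.5]])

index = nmslib.init(
    method="hnsw",
    space="cosinesimil_sparse",
    data_type=nmslib.DataType.SPARSE_VECTOR,
)
index.addDataPointBatch(tfidf_vectors)
# Bigger M -> higher recall, slower index creation (see the comment above).
index.createIndex({"M": 100, "efConstruction": 2000, "post": 0})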
2 changes: 1 addition & 1 deletion scispacy/custom_tokenizer.py
@@ -131,6 +131,6 @@ def combined_rule_tokenizer(nlp: Language) -> Tokenizer:
        prefix_search=prefix_re.search,
        suffix_search=suffix_re.search,
        infix_finditer=infix_re.finditer,
        token_match=nlp.tokenizer.token_match,
        token_match=nlp.tokenizer.token_match,  # type: ignore
    )
    return tokenizer
21 changes: 12 additions & 9 deletions scispacy/linking.py
@@ -96,20 +96,23 @@ def __init__(
        self.umls = self.kb

    def __call__(self, doc: Doc) -> Doc:
        mentions = []
        mention_strings = []
        if self.resolve_abbreviations and Doc.has_extension("abbreviations"):

            # TODO: This is possibly sub-optimal - we might
            # prefer to look up both the long and short forms.
            for ent in doc.ents:
                # TODO: This is possibly sub-optimal - we might
                # prefer to look up both the long and short forms.
                if ent._.long_form is not None:
                    mentions.append(ent._.long_form)
                if isinstance(ent._.long_form, Span):
                    # Long form
                    mention_strings.append(ent._.long_form.text)
                elif isinstance(ent._.long_form, str):
                    # Long form
                    mention_strings.append(ent._.long_form)
                else:
                    mentions.append(ent)
                    # no abbreviations case
                    mention_strings.append(ent.text)
        else:
            mentions = doc.ents
            mention_strings = [x.text for x in doc.ents]

        mention_strings = [x.text for x in mentions]
        batch_candidates = self.candidate_generator(mention_strings, self.k)

        for mention, candidates in zip(doc.ents, batch_candidates):
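A sketch of the behavior this change supports (assuming en_core_sci_sm is installed and the UMLS linker artifacts can be downloaded; they are large). With the serializable detector, ent._.long_form is a plain string, and the linker now resolves candidates from the long-form text in either representation:

import spacy
from scispacy.abbreviation import AbbreviationDetector  # noqa: F401
from scispacy.linking import EntityLinker  # noqa: F401

nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe("abbreviation_detector", config={"make_serializable": True})
nlp.add_pipe("scispacy_linker", config={"resolve_abbreviations": True, "linker_name": "umls"})

doc = nlp("Spinal and bulbar muscular atrophy (SBMA) is an inherited motor neuron disease.")
for ent in doc.ents:
    # Candidates are looked up via the long form when the entity is an abbreviation.
    print(ent.text, ent._.kb_ents[:1])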
14 changes: 11 additions & 3 deletions tests/conftest.py
@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Tuple, Optional
import os

import pytest
@@ -8,6 +8,7 @@

from scispacy.custom_sentence_segmenter import pysbd_sentencizer
from scispacy.custom_tokenizer import combined_rule_tokenizer, combined_rule_prefixes, remove_new_lines
from scispacy.abbreviation import AbbreviationDetector

LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool], SpacyModelType] = {}

@@ -19,14 +20,15 @@ def get_spacy_model(
    ner: bool,
    with_custom_tokenizer: bool = False,
    with_sentence_segmenter: bool = False,
    with_serializable_abbreviation_detector: Optional[bool] = None,
) -> SpacyModelType:
    """
    In order to avoid loading spacy models repeatedly,
    we'll save references to them, keyed by the options
    we used to create the spacy model, so any particular
    configuration only gets loaded once.
    """
    options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter)
    options = (spacy_model_name, pos_tags, parse, ner, with_custom_tokenizer, with_sentence_segmenter, with_serializable_abbreviation_detector)
    if options not in LOADED_SPACY_MODELS:
        disable = ["vectors", "textcat"]
        if not pos_tags:
@@ -46,6 +48,8 @@
            spacy_model.tokenizer = combined_rule_tokenizer(spacy_model)
        if with_sentence_segmenter:
            spacy_model.add_pipe("pysbd_sentencizer", first=True)
        if with_serializable_abbreviation_detector is not None:
            spacy_model.add_pipe("abbreviation_detector", config={"make_serializable": with_serializable_abbreviation_detector})

        LOADED_SPACY_MODELS[options] = spacy_model
    return LOADED_SPACY_MODELS[options]
@@ -97,9 +101,13 @@ def test_model_dir():

@pytest.fixture()
def combined_all_model_fixture():
    nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False)
    nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False, with_serializable_abbreviation_detector=True)
    return nlp

@pytest.fixture()
def combined_all_model_fixture_non_serializable_abbrev():
    nlp = get_spacy_model("en_core_sci_sm", True, True, True, with_custom_tokenizer=True, with_sentence_segmenter=False, with_serializable_abbreviation_detector=False)
    return nlp

@pytest.fixture()
def combined_rule_prefixes_fixture():
28 changes: 27 additions & 1 deletion tests/custom_tests/test_all_model.py
@@ -3,6 +3,7 @@
import spacy
from spacy.vocab import Vocab
import shutil
import pytest


def test_custom_segmentation(combined_all_model_fixture):
@@ -36,6 +37,31 @@ def test_custom_segmentation(combined_all_model_fixture):
    ]
    actual_tokens = [t.text for t in doc]
    assert expected_tokens == actual_tokens
    assert doc.is_parsed
    assert doc.has_annotation("DEP")
    assert doc[0].dep_ == "ROOT"
    assert doc[0].tag_ == "NN"

def test_full_pipe_serializable(combined_all_model_fixture):
    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
    doc = [doc for doc in combined_all_model_fixture.pipe([text, text], n_process = 2)][0]
    # Getting here means both that the pipeline is serializable and that the text
    # contains an abbreviation that would have broken serialization if it weren't
    assert len(doc._.abbreviations) > 0
    abbrev = doc._.abbreviations[0]
    assert abbrev["short_text"] == "CEIL"
    assert abbrev["long_text"] == "cytokine expression in leukocytes"
    assert doc[abbrev["short_start"] : abbrev["short_end"]].text == abbrev["short_text"]
    assert doc[abbrev["long_start"] : abbrev["long_end"]].text == abbrev["long_text"]

def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
    text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
    doc = combined_all_model_fixture_non_serializable_abbrev(text)
    # The Span values in doc._.abbreviations are not serializable, so this should fail
    with pytest.raises(TypeError):
        doc.to_bytes()

# Below is the test version to be used once we move to spacy v3.1.0 or higher
# def test_full_pipe_not_serializable(combined_all_model_fixture_non_serializable_abbrev):
#     text = "Induction of cytokine expression in leukocytes (CEIL) by binding of thrombin-stimulated platelets. BACKGROUND: Activated platelets tether and activate myeloid leukocytes."
#     # This requires the pipeline to be serializable (because it uses 2 processes), so the test should fail here
#     with pytest.raises(TypeError):
#         list(combined_all_model_fixture_non_serializable_abbrev.pipe([text, text], n_process = 2))
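For completeness, a sketch (under the same assumptions as the tests above) of why the serializable form matters downstream: plain-dict abbreviations survive a DocBin round trip, which is the same kind of serialization multiprocessing relies on:

import spacy
from spacy.tokens import DocBin
from scispacy.abbreviation import AbbreviationDetector  # noqa: F401

nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe("abbreviation_detector", config={"make_serializable": True})

doc = nlp("Induction of cytokine expression in leukocytes (CEIL) by thrombin.")
doc_bin = DocBin(store_user_data=True)  # user_data carries the extension values
doc_bin.add(doc)
restored = list(doc_bin.get_docs(nlp.vocab))[0]
assert restored._.abbreviations[0]["short_text"] == "CEIL"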

