Add Linker Ensemble #53

Merged: 25 commits, Mar 30, 2023
Commits
3d51b2b
:sparkles: Add regen wikification (#44)
GabrielePicco Dec 28, 2022
df6b2ca
:bookmark: v0.0.7
GabrielePicco Jan 4, 2023
311e1cf
:bug: Fix bug in displacy serve method (#46)
marmg Jan 4, 2023
2853a26
:bug: Fix bug in zshot_evaluate import (#49)
marmg Feb 21, 2023
d61625e
linkers and descriptions ensemble
Feb 9, 2023
fb4b4d9
Refactor ensembler
marmg Feb 13, 2023
5ddd547
Remove ensembling notebook from gitignore
marmg Feb 13, 2023
a5645b7
:bug: Fix travis
marmg Feb 13, 2023
66a9437
:sparkles: Added Linker ensembler
marmg Mar 10, 2023
4c84399
:sparkles: Added Linker Ensemble
marmg Mar 24, 2023
4cf93e0
Delete .travis.yml
marmg Mar 24, 2023
302b756
:lock: Fix setuptools vulnerability
marmg Mar 24, 2023
0c94c4b
:white_check_mark: Fixed tests
marmg Mar 24, 2023
cd9c1e5
:white_check_mark: Fixed tests
marmg Mar 24, 2023
a80a780
:bug: Fix deprecation in load_ontonotes. Fix bug in run_evaluation. U…
marmg Mar 13, 2023
9946dd1
:white_check_mark: Added tests to improve coverage
marmg Mar 28, 2023
0486ec5
Update python-tests.yml
marmg Mar 28, 2023
e73a9d5
:white_check_mark: Added tests teardown and improve memory usage
marmg Mar 28, 2023
46e6565
:white_check_mark: Added tests teardown and improve memory usage
marmg Mar 28, 2023
93308c0
:white_check_mark: Added tests teardown and improve memory usage
marmg Mar 28, 2023
a3c719d
Update python-tests.yml
marmg Mar 30, 2023
1239642
:white_check_mark: Replace linkers in test_ensemble_linker with dummy…
marmg Mar 30, 2023
98784db
:white_check_mark: Improve coverage
marmg Mar 30, 2023
b2cceb4
:white_check_mark: Fix merge main
marmg Mar 30, 2023
1745b5b
:art: Fix flake
marmg Mar 30, 2023
3 changes: 2 additions & 1 deletion .github/workflows/python-tests.yml
@@ -44,7 +44,8 @@ jobs:
           python -m spacy download en_core_web_sm
       - name: Test with pytest
         run: |
-          python -m pytest --cov -v --cov-report xml:/home/runner/coverage.xml
+          python -m pytest --cov -v --cov-report xml:/home/runner/coverage.xml
+        timeout-minutes: 30
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3.1.1
         with:
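For context: the added timeout-minutes: 30 caps the pytest step, so a hung test run fails after half an hour instead of holding the job until GitHub Actions' default 6-hour job timeout. This matches the "Added tests teardown and improve memory usage" commits above.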
2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -1,6 +1,6 @@
 pytest>=7.0
 pytest-cov>=3.0.0
-setuptools~=60.0.0
+setuptools>=65.5.1
 flair==0.11.3
 flake8>=4.0.1
 coverage>=6.4.1
4 changes: 2 additions & 2 deletions setup.py
@@ -22,9 +22,9 @@
"spacy>=3.4.1",
"requests>=2.28",
"tqdm>=4.62.3",
"setuptools~=60.0.0", # Needed to install dynamic packages from source (e.g. Blink)
"setuptools>=65.5.1", # Needed to install dynamic packages from source (e.g. Blink)
"prettytable>=3.4",
"torch>=1",
"torch>=1,<2",
"transformers>=4.20",
"datasets>=2.9.1",
"evaluate>=0.3.0",
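Two dependency pins change here: the setuptools floor moves to 65.5.1, which the ":lock: Fix setuptools vulnerability" commit above suggests is meant to pull in the fix for the known ReDoS advisory in earlier setuptools releases (CVE-2022-40897, an assumption based on the commit message), and torch gains an upper bound (<2) to keep the package off the then-new torch 2.x line.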
1 change: 1 addition & 0 deletions zshot/linker/__init__.py
@@ -3,3 +3,4 @@
 from zshot.linker.linker import Linker  # noqa: F401
 from zshot.linker.linker_smxm import LinkerSMXM  # noqa: F401
 from zshot.linker.linker_tars import LinkerTARS  # noqa: F401
+from zshot.linker.linker_ensemble import LinkerEnsemble  # noqa: F401
1 change: 1 addition & 0 deletions zshot/linker/linker_ensemble/__init__.py
@@ -0,0 +1 @@
+from zshot.linker.linker_ensemble.linker_ensemble import LinkerEnsemble  # noqa: F401
92 changes: 92 additions & 0 deletions zshot/linker/linker_ensemble/linker_ensemble.py
@@ -0,0 +1,92 @@
from typing import Iterator, Optional, List

from spacy.tokens import Doc

from zshot.linker import Linker
from zshot.linker import LinkerSMXM
from zshot.utils.ensembler import Ensembler
from zshot.utils.data_models import Entity
from zshot.linker.linker_ensemble.utils import sub_span_scoring_per_description, get_enhance_entities


class LinkerEnsemble(Linker):
    def __init__(self,
                 linkers: Optional[List[Linker]] = None,
                 strategy: Optional[str] = 'max',
                 threshold: Optional[float] = 0.5):
        """ Ensemble of linkers and entities to improve performance.
        Each combination of linker and entity description acts as a voter.

        :param linkers: Linkers to use in the ensemble
        :param strategy: Strategy to use. Options: max; count
            When `max`, choose the label with the maximum total vote score
            When `count`, choose the label with the maximum total vote count
        :param threshold: Threshold to use. Proportion of voters that must vote for the entity
        """
        super(LinkerEnsemble, self).__init__()
        if linkers is not None:
            self.linkers = linkers
        else:
            # default options
            self.linkers = [
                LinkerSMXM()
            ]
        self.enhance_entities = []
        self.strategy = strategy
        self.threshold = threshold
        self.ensembler = None

    def set_smxm_model(self, smxm_model):
        for linker in self.linkers:
            if isinstance(linker, LinkerSMXM):
                linker.model_name = smxm_model

    def set_kg(self, entities: Iterator[Entity]):
        """
        Set entities that the linker can use
        :param entities: The list of entities
        """
        super().set_kg(entities)
        self.enhance_entities = get_enhance_entities(self.entities)
        self.ensembler = Ensembler(len(self.linkers),
                                   len(self.enhance_entities) if self.enhance_entities is not None else -1,
                                   threshold=self.threshold)
        for linker in self.linkers:
            linker.set_kg(entities)

    def predict(self, docs: Iterator[Doc], batch_size=None):
        """
        Perform the entity prediction
        :param docs: A list of spacy Documents
        :param batch_size: The batch size
        :return: List of Spans for each Document in docs
        """
        spans = []
        for entities in self.enhance_entities:
            self.set_kg(entities)
            for linker in self.linkers:
                span_prediction = linker.predict(docs, batch_size)
                spans.append(span_prediction)

        return self.prediction_ensemble(spans)

    def prediction_ensemble(self, spans):
        doc_ensemble_spans = []
        num_doc = len(spans[0])
        for doc_idx in range(num_doc):
            union_spans = {}
            span_per_descriptions = []
            for span in spans:
                span_per_descriptions.append(span[doc_idx])
                for s in span[doc_idx]:
                    span_pos = (s.start, s.end)
                    if span_pos not in union_spans:
                        union_spans[span_pos] = [s]
                    else:
                        union_spans[span_pos].append(s)

            sub_span_scoring_per_description(union_spans, span_per_descriptions)
            all_union_spans = self.ensembler.ensemble(union_spans)
            doc_ensemble_spans.append(all_union_spans)

        return doc_ensemble_spans
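For context, here is a minimal usage sketch of the new linker, based on the constructor above and the pipeline setup used in the tests later in this PR. The SMXM/TARS pairing and the example sentence are illustrative assumptions, not part of this diff; any zshot linkers can be combined.

import spacy

from zshot import PipelineConfig
from zshot.linker import LinkerEnsemble, LinkerSMXM, LinkerTARS
from zshot.utils.data_models import Entity

nlp = spacy.blank("en")
nlp.add_pipe("zshot", config=PipelineConfig(
    # Two descriptions sharing the entity name "fruits": each
    # (linker, description) combination acts as one voter.
    entities=[
        Entity(name="fruits", description="The sweet and fleshy product of a tree or other plant."),
        Entity(name="fruits", description="Names of fruits such as banana, oranges"),
    ],
    linker=LinkerEnsemble(
        linkers=[LinkerSMXM(), LinkerTARS()],  # 2 linkers x 2 descriptions = 4 voters
        strategy='max',    # or 'count'
        threshold=0.5,     # proportion of voters that must agree on a span
    ),
), last=True)

doc = nlp("Apple is a company, not a fruit like apples or oranges")
print(doc.ents)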
31 changes: 31 additions & 0 deletions zshot/linker/linker_ensemble/utils.py
@@ -0,0 +1,31 @@
import random

from zshot.utils.data_models import Span


def sub_span_scoring_per_description(union_spans, spans):
    for k in union_spans.keys():
        for span in spans:
            labels = {}
            for p in span:
                if k[0] <= p.start and k[1] >= p.end:
                    if k[0] < p.start or k[1] > p.end:
                        if p.label not in labels:
                            labels[p.label] = p
                        elif labels[p.label].score < p.score:
                            labels[p.label] = p
            for p in labels.values():
                union_spans[k].append(Span(label=p.label, score=p.score, start=k[0], end=k[1]))


def normalize_group(group, require_length):
    group.extend(random.choices(group, k=require_length - len(group)))


def get_enhance_entities(entities):
    entities_groups = [[ent for ent in entities if ent.name == name] for name in set([ent.name for ent in entities])]
    max_length = max([len(group) for group in entities_groups])
    for group in entities_groups:
        normalize_group(group, max_length)

    return [list(g) for g in zip(*entities_groups)]
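To make the helpers above concrete: get_enhance_entities groups the configured entities by name, pads every group to the size of the largest group by randomly repeating members (normalize_group), and then transposes the groups so each returned entity set holds exactly one description per entity name. A small illustrative sketch; the entity names and descriptions here are invented for the example:

import random

from zshot.linker.linker_ensemble.utils import get_enhance_entities
from zshot.utils.data_models import Entity

random.seed(0)  # normalize_group pads groups by random repetition
entities = [
    Entity(name="fruits", description="The sweet and fleshy product of a tree."),
    Entity(name="fruits", description="Names of fruits such as banana, oranges"),
    Entity(name="company", description="A commercial organization"),
]
# "company" has one description, so it is repeated to match the largest
# group; the transpose then yields two entity sets, each with exactly one
# "fruits" description and one "company" description.
for entity_set in get_enhance_entities(entities):
    print([(ent.name, ent.description[:25]) for ent in entity_set])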
3 changes: 0 additions & 3 deletions zshot/linker/linker_regen/utils.py
@@ -2,7 +2,6 @@
 import pickle
 from typing import Dict, List

-import pytest
 from huggingface_hub import hf_hub_download

 from zshot.linker.linker_regen.trie import Trie
@@ -25,7 +24,6 @@ def create_input(sentence, max_length, start_delimiter, end_delimiter):
     right_index = min(len(sent_list), end_delimiter_index + half_context + (
         half_context - (start_delimiter_index - left_index)))
     left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index)))
-    print(len(sent_list[left_index:right_index]))
     return " ".join(sent_list[left_index:right_index])


@@ -42,7 +40,6 @@ def load_wikipedia_trie() -> Trie:
     return wikipedia_trie


-@pytest.mark.skip(reason="Too expensive to run on every commit")
 def load_wikipedia_mapping() -> Dict[str, str]:
     """
     Load the wikipedia trie from the HB hub
22 changes: 11 additions & 11 deletions zshot/linker/linker_tars.py
@@ -35,23 +35,23 @@ def set_kg(self, entities: Iterator[Entity]):

         :param entities: New entities to use
         """
-        old_entities = self._entities
+        old_entities = self.entities
         super().set_kg(entities)
-        self.flat_entities()
         if old_entities != entities:
+            self.flat_entities()
-            self.task = f'zshot.ner.{hash(tuple(self._entities))}'
+            self.task = f'zshot.ner.{hash(tuple(self.entities))}'
             if not self.model:
                 self.load_models()
             self.model.add_and_switch_to_new_task(self.task,
-                                                  self._entities, label_type='ner')
+                                                  self.entities, label_type='ner')

     def flat_entities(self):
         """ As TARS use only the labels, take just the name of the entities and not the description """
-        if isinstance(self._entities, dict):
-            self._entities = list(self._entities.keys())
-        if isinstance(self._entities, list):
-            self._entities = [e.name if type(e) == Entity else e for e in self._entities]
-        if self._entities is None:
+        if isinstance(self.entities, dict):
+            self._entities = list(self.entities.keys())
+        if isinstance(self.entities, list):
+            self._entities = [e.name if type(e) == Entity else e for e in self.entities]
+        if self.entities is None:
             self._entities = []

     def load_models(self):
@@ -65,9 +65,9 @@ def load_models(self):
             self.task = self.default_entities
         else:
             self.flat_entities()
-            self.task = f'zshot.ner.{hash(tuple(self._entities))}'
+            self.task = f'zshot.ner.{hash(tuple(self.entities))}'
         self.model.add_and_switch_to_new_task(self.task,
-                                              self._entities, label_type='ner')
+                                              self.entities, label_type='ner')

     def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
         """
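A note on the linker_tars.py changes: the diff consistently switches reads from the private attribute self._entities to the public self.entities while writes still go through self._entities, which suggests the base Linker exposes entities as a read-only property. That base-class code is not shown in this PR, so the following is only a sketch of the assumed accessor pattern:

# Assumed accessor pattern on the base Linker class (not part of this diff):
class Linker:
    def __init__(self):
        self._entities = None

    @property
    def entities(self):
        # Reads go through the public property ...
        return self._entities

    def set_kg(self, entities):
        # ... while updates go through set_kg, storing to the private
        # attribute that subclasses may overwrite (e.g. flat_entities).
        self._entities = entities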
52 changes: 52 additions & 0 deletions zshot/tests/linker/test_ensemble_linker.py
@@ -0,0 +1,52 @@
import spacy

from zshot import PipelineConfig
from zshot.linker.linker_ensemble import LinkerEnsemble
from zshot.tests.linker.test_linker import DummyLinkerEnd2End
from zshot.utils.data_models import Entity


def test_ensemble_linker_max():
    nlp = spacy.blank("en")
    nlp.add_pipe("zshot", config=PipelineConfig(
        entities=[
            Entity(name="fruits", description="The sweet and fleshy product of a tree or other plant."),
            Entity(name="fruits", description="Names of fruits such as banana, oranges")
        ],
        linker=LinkerEnsemble(
            linkers=[
                DummyLinkerEnd2End(),
                DummyLinkerEnd2End(),
            ]
        )
    ), last=True)
    doc = nlp('Apple is a company name not a fruits like apples or orange')
    assert "zshot" in nlp.pipe_names
    assert len(doc.ents) > 0
    assert len(doc._.spans) > 0
    assert all([bool(ent.label_) for ent in doc.ents])
    del doc, nlp


def test_ensemble_linker_count():
    nlp = spacy.blank("en")
    nlp.add_pipe("zshot", config=PipelineConfig(
        entities=[
            Entity(name="fruits", description="The sweet and fleshy product of a tree or other plant."),
            Entity(name="fruits", description="Names of fruits such as banana, oranges")
        ],
        linker=LinkerEnsemble(
            linkers=[
                DummyLinkerEnd2End(),
                DummyLinkerEnd2End(),
            ],
            strategy='count'
        )
    ), last=True)

    doc = nlp('Apple is a company name not a fruits like apples or orange')
    assert "zshot" in nlp.pipe_names
    assert len(doc.ents) > 0
    assert len(doc._.spans) > 0
    assert all([bool(ent.label_) for ent in doc.ents])
    del doc, nlp
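For reference, each test above wires 2 linkers to 2 descriptions of the same entity name, giving 2 × 2 = 4 voters; with the default threshold of 0.5, the Ensembler (whose implementation is not part of this diff) should require agreement from at least two voters before a span is kept.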
2 changes: 1 addition & 1 deletion zshot/tests/linker/test_linker.py
@@ -27,7 +27,7 @@ def is_end2end(self) -> bool:
         return True

     def predict(self, docs: Iterator[Doc], batch_size=None):
-        return [[Span(0, len(doc.text) - 1, label='label')] for doc in docs]
+        return [[Span(0, len(doc.text) - 1, label='label', score=0.9)] for doc in docs]


 class DummyLinkerWithEntities(Linker):
41 changes: 21 additions & 20 deletions zshot/tests/linker/test_regen_linker.py
@@ -9,8 +9,7 @@
 from zshot import PipelineConfig
 from zshot.linker.linker_regen.linker_regen import LinkerRegen
 from zshot.linker.linker_regen.trie import Trie
-from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
-from zshot.mentions_extractor import MentionsExtractorSpacy
+from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia, create_input
 from zshot.tests.config import EX_DOCS, EX_ENTITIES
 from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
 from zshot.utils.data_models import Span
@@ -40,30 +39,15 @@ def test_regen_linker():

     doc = nlp(EX_DOCS[1])
     assert len(doc.ents) > 0
+    del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp
+    del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \
+        nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
+    nlp.remove_pipe('zshot')
     del doc, nlp, config


-def test_regen_linker_pipeline():
-    nlp = spacy.load("en_core_web_sm")
-    config = PipelineConfig(
-        mentions_extractor=MentionsExtractorSpacy(),
-        linker=LinkerRegen(),
-        entities=EX_ENTITIES
-    )
-    nlp.add_pipe("zshot", config=config, last=True)
-    assert "zshot" in nlp.pipe_names
-
-    doc = nlp("")
-    assert len(doc.ents) == 0
-    docs = [doc for doc in nlp.pipe(EX_DOCS)]
-    assert all(len(doc.ents) > 0 for doc in docs)
-    del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp
-    del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \
-        nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
-    nlp.remove_pipe('zshot')
-    del docs, nlp, config
-    del doc, nlp, config


def test_regen_linker_wikification():
@@ -87,13 +71,30 @@ def test_regen_linker_wikification():
     del doc, nlp, config


+@pytest.mark.skip(reason="Too expensive to run on every commit")
 def test_load_wikipedia_trie():
     trie = load_wikipedia_trie()
     assert len(list(trie.trie_dict.keys())) == 6952


+@pytest.mark.skip(reason="Too expensive to run on every commit")
 def test_span_to_wiki():
     s = Span(label="Surfing", start=0, end=10)
     wiki_links = spans_to_wikipedia([s])
     assert len(wiki_links) > 0
     assert wiki_links[0].startswith("https://en.wikipedia.org/wiki?curid=")


+def test_create_input():
+    start_delimiter = "[START]"
+    end_delimiter = "[END]"
+    max_length = 10
+
+    times_rep = 6
+    sentence = "[START]" + " test" * times_rep + " [END]"
+    input_sentence = create_input(sentence, max_length, start_delimiter, end_delimiter)
+    assert input_sentence == sentence
+    times_rep = 12
+    sentence = "[START]" + " test" * times_rep + " [END]"
+    input_sentence = create_input(sentence, max_length, start_delimiter, end_delimiter)
+    assert input_sentence == " ".join(["test" for i in range(9)])