In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys

# Add the parent directory to the import path; presumably this is what makes
# the `src.*` project imports in the following cell resolvable.
sys.path.append("../")

In [5]:
import tqdm
import os
from sklearn.metrics import classification_report

from src.taskmodules.span_clf_with_gazetteer import Gazetteer, EfficientGazetteer
from pytorch_ie.data.document import LabeledSpan

from transformers import AutoTokenizer

from src.datamodules.datasets.multiconer import load_multiconer

In [6]:
def visualize_documents(
    documents,
    annotation_field: str = "entities",
    show_annotations: bool = True,
    show_predictions: bool = True,
):
    """Print every document's text together with its annotated/predicted spans.

    Args:
        documents: Iterable of documents exposing ``text``, ``annotations(field)``
            and ``predictions(field)``.
        annotation_field: Name of the annotation layer to display.
        show_annotations: Whether to list the spans returned by ``annotations``.
        show_predictions: Whether to list the spans returned by ``predictions``
            (each line also shows the span's ``score``).

    Returns:
        None. Everything is written to stdout.
    """

    def ordered_by_start(spans):
        # Spans are listed in order of their start offset.
        return sorted(spans, key=lambda span: span.start)

    for index, doc in enumerate(documents):
        text = doc.text

        # Header line: running index followed by a horizontal rule.
        print(f"[{index}]" + "=" * 100)
        print(text)

        if show_annotations:
            print("Annotations:")
            for span in ordered_by_start(doc.annotations(annotation_field)):
                print(f"{text[span.start : span.end]} -> {span.label}")

        if show_predictions:
            print("*" * 25)
            print("Predictions:")
            for span in ordered_by_start(doc.predictions(annotation_field)):
                print(f"{text[span.start : span.end]} -> {span.label} [{span.score}]")

        # Blank lines between consecutive documents.
        print("\n")

In [7]:
def predict_with_gazetteer(documents, tokenizer, gazetteer, min_span_length=1, max_span_length=8, max_length=128, score: float = 1.0):
    """Attach gazetteer-based entity predictions to ``documents`` in place.

    Every token span whose length lies in ``[min_span_length, max_span_length]``
    is looked up in ``gazetteer``; each non-``None`` label returned for a span
    is added to the document's ``"entities"`` predictions as a ``LabeledSpan``
    covering the span's character offsets.

    Args:
        documents: Iterable of documents exposing ``text`` and
            ``add_prediction(field, annotation)``.
        tokenizer: Callable tokenizer that returns ``input_ids`` and
            ``offset_mapping`` and provides ``decode``.
        gazetteer: Object with ``lookup(key)``. For an ``EfficientGazetteer``
            the key is the tuple of token ids, otherwise the decoded span text.
        min_span_length: Smallest span length (in tokens) to look up.
        max_span_length: Largest span length (in tokens) to look up.
        max_length: Maximum number of tokens per document; longer inputs are
            truncated. ``None`` disables truncation.
        score: Score stored on every created ``LabeledSpan``.

    Returns:
        None. Predictions are added to the documents as a side effect.
    """
    is_efficient_gazetteer = isinstance(gazetteer, EfficientGazetteer)

    for document in tqdm.tqdm(documents):
        inputs = tokenizer(
            document.text,
            padding=False,
            # Request truncation explicitly whenever a limit is given. Passing
            # `max_length` together with `truncation=False` made the tokenizer
            # warn and fall back to truncating implicitly.
            truncation=max_length is not None,
            max_length=max_length,
            is_split_into_words=False,
            return_offsets_mapping=True,
        )
        input_ids = inputs["input_ids"]
        offsets = inputs["offset_mapping"]
        seq_length = len(input_ids)

        # NOTE(review): special tokens are not disabled in the tokenizer call,
        # so spans touching them are looked up as well; presumably they never
        # match a gazetteer entry - confirm.
        for span_length in range(min_span_length, max_span_length + 1):
            for start_index in range(seq_length + 1 - span_length):
                end_index = start_index + span_length
                span_ids = input_ids[start_index:end_index]

                if is_efficient_gazetteer:
                    span_key = tuple(span_ids)
                else:
                    span_key = tokenizer.decode(span_ids)

                labels = [label for label in gazetteer.lookup(span_key) if label is not None]
                if not labels:
                    continue

                # Character offsets: start of the first token, end of the last.
                start = offsets[start_index][0]
                end = offsets[end_index - 1][1]

                for label in labels:
                    # Forward `score`; it was previously accepted but ignored.
                    document.add_prediction(
                        "entities",
                        LabeledSpan(start=start, end=end, label=label, score=score),
                    )

In [8]:
def _labels_by_span(entities):
    """Group entity labels by their ``(start, end)`` span."""
    grouped = {}
    for entity in entities:
        grouped.setdefault((entity.start, entity.end), set()).add(entity.label)
    return grouped


def classification_report_from_documents(documents, annotation_field: str = "entities", labels=None):
    """Build a span-level, multi-label classification report.

    Every distinct ``(start, end)`` span that is annotated or predicted in a
    document contributes one row. The row's true/predicted vectors are
    multi-hot encodings of the labels attached to that exact span, so a
    predicted span that has no annotation counts against precision and an
    annotated span without a prediction counts against recall.

    Args:
        documents: Iterable of documents exposing ``annotations(field)`` and
            ``predictions(field)``, whose items have ``start``, ``end`` and
            ``label``.
        annotation_field: Name of the annotation layer to evaluate.
        labels: Optional sequence of label names defining the report columns.
            Defaults to ``["CORP", "PROD", "GRP", "CW", "LOC", "PER"]``.

    Returns:
        The value returned by ``sklearn.metrics.classification_report``.

    Raises:
        KeyError: If an annotation or prediction carries a label that is not
            contained in ``labels``.
    """
    if labels is None:
        labels = ["CORP", "PROD", "GRP", "CW", "LOC", "PER"]
    else:
        labels = list(labels)
    label_to_id = {label: i for i, label in enumerate(labels)}

    def encode(span_labels):
        # Multi-hot vector over `labels` for the labels of a single span.
        row = [0] * len(labels)
        for label in span_labels:
            row[label_to_id[label]] = 1
        return row

    y_true = []
    y_pred = []
    for document in documents:
        span_to_annotations = _labels_by_span(document.annotations(annotation_field))
        span_to_predictions = _labels_by_span(document.predictions(annotation_field))

        # Annotated spans, paired with whatever was predicted for the same span.
        for span, annotations in span_to_annotations.items():
            y_true.append(encode(annotations))
            y_pred.append(encode(span_to_predictions.get(span, ())))

        # Predicted spans without any annotation: all-zero truth row.
        for span, predictions in span_to_predictions.items():
            if span in span_to_annotations:
                continue
            y_true.append(encode(()))
            y_pred.append(encode(predictions))

    # zero_division=0 yields the same numbers as the default ("warn") but does
    # not emit a warning for rows/labels without true or predicted samples.
    return classification_report(y_true, y_pred, target_names=labels, zero_division=0)

In [9]:
# Machine-specific locations can be overridden through environment variables of
# the same name; the defaults are the previously hard-coded values.
MULTI_CONER_DIR = os.environ.get("MULTI_CONER_DIR", "/home/christoph/Downloads/public_data/")
MULTI_CONER_SPLIT = "train"

TOKENIZER_NAME_OR_PATH = "google/electra-large-discriminator"

GAZETTEERS_DIR = os.environ.get("GAZETTEERS_DIR", "../data/gazetteers/")
WIKIDATA_ENTITY_ALIASES_PATH = os.environ.get(
    "WIKIDATA_ENTITY_ALIASES_PATH",
    "/home/christoph/Downloads/wikidata5m_alias/wikidata5m_entity.txt",
)
WIKIDATA_GRAPH_PATH = os.environ.get(
    "WIKIDATA_GRAPH_PATH",
    "/home/christoph/Downloads/wikidata5m_all_triplet.txt",
)

In [10]:
# Tokenizer that is passed to predict_with_gazetteer to split document text
# into tokens for the gazetteer span lookup.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME_OR_PATH)

In [11]:
# Load the split to evaluate on. The gazetteer predictions produced below are
# attached to these documents in place.
eval_docs = load_multiconer(data_dir=MULTI_CONER_DIR, name="en", split=MULTI_CONER_SPLIT)

# Gazetteer files inside GAZETTEERS_DIR, grouped by entity type. They are
# applied one after another in exactly this order.
gazetteers = [
    # CORP
    "eng-wikidata-CORP.txt",
    "eng-hltcoe-CORP.txt",
    # PROD
    "eng-hltcoe-PROD.txt",
    # GRP
    "eng-wikidata-GRP.txt",
    "eng-hltcoe-GRP.txt",
    # CW
    "eng-wikidata-CW.txt",
    # LOC
    "eng-wikidata-LOC.txt",
    "eng-hltcoe-LOC.txt",
    # PER
    "eng-wikidata-PER.txt",
    "eng-hltcoe-PER.txt",
]

for gazetteer_name in gazetteers:
    gazetteer_path = os.path.join(GAZETTEERS_DIR, gazetteer_name)
    gazetteer = Gazetteer(path=gazetteer_path, lowercase=True)
    predict_with_gazetteer(
        eval_docs,
        tokenizer=tokenizer,
        gazetteer=gazetteer,
        min_span_length=2,
        max_span_length=8,
    )

# Score the accumulated predictions of all gazetteers against the annotations.
report = classification_report_from_documents(eval_docs)
print(report)

Using custom data configuration en-b2b539f3793511b0
Reusing dataset multi_co_ner (/home/christoph/.cache/huggingface/datasets/multi_co_ner/en-b2b539f3793511b0/1.0.0/afa61df806aafde79b4bd38aef1a3db19216190e1aa77a223a2d70d1eea327c9)
  0%|                                                                                                                                                                                      | 0/15300 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15300/1

              precision    recall  f1-score   support

        CORP       0.26      0.70      0.38      3111
        PROD       0.13      0.01      0.02      2923
         GRP       0.10      0.47      0.17      3571
          CW       0.14      0.60      0.22      3752
         LOC       0.22      0.46      0.30      4799
         PER       0.49      0.91      0.64      5397

   micro avg       0.21      0.57      0.31     23553
   macro avg       0.22      0.53      0.29     23553
weighted avg       0.24      0.57      0.32     23553
 samples avg       0.20      0.24      0.21     23553



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [12]:
# Print annotations next to the gazetteer predictions for the first ten documents.
visualize_documents(eval_docs[:10])

his playlist includes sonny sharrock , gza , country teasers and the notorious b.i.g.
Annotations:
sonny sharrock -> PER
gza -> PER
country teasers -> GRP
the notorious b.i.g. -> PER
*************************
Predictions:
sonny sharrock -> PER [1.0]
gza -> LOC [1.0]
gza -> PER [1.0]
gza -> PER [1.0]
the notorious -> PER [1.0]
the notorious b.i.g. -> PER [1.0]
the notorious -> PER [1.0]
b. -> CW [1.0]
b. -> PER [1.0]
i.g. -> CW [1.0]


it is a series of badminton tournaments , sanctioned by badminton world federation ( bwf ) since 2007 .
Annotations:
badminton world federation -> GRP
*************************
Predictions:
badminton world federation -> GRP [1.0]
world federation -> GRP [1.0]
since 2007 -> CW [1.0]


all songs written by m.o.d. , unless otherwise stated
Annotations:
m.o.d. -> GRP
*************************
Predictions:
written by -> CW [1.0]
m.o.d. -> GRP [1.0]
m. -> PER [1.0]
o. -> PER [1.0]


he worked in a bookstore before becoming a journalist , first for le devoir , a