In [None]:
# NOTE(review): no cell visible in this notebook draws a figure, so this magic (and the
# matplotlib import) may be unnecessary; modern Jupyter renders plots inline by default anyway.
%matplotlib inline

In [None]:
import sys

# Make the project two directories up importable (the next cell imports `eventx` from it).
# Guarded so that re-running this cell does not append the same entry to sys.path again.
if "../../" not in sys.path:
    sys.path.append("../../")

In [None]:
import eventx
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import io

from itertools import chain, accumulate

from eventx.predictors.predictor_utils import load_predictor
from eventx.models.model_utils import batched_predict_json, batched_predict_instances
from eventx.predictors import snorkel_predictor, smartdata_predictor
from eventx.util import scorer
from eventx.util.utils import snorkel_to_ace_format
from eventx import SD4M_RELATION_TYPES, ROLE_LABELS

from allennlp.predictors import Predictor

In [None]:
from typing import List

In [None]:
def has_triggers(doc: dict) -> bool:
    """Return True if ``doc`` contains at least one trigger entity.

    The entity-type comparison is case-insensitive, like the trigger check in the cell
    that builds the prediction instances. The gold documents kept by this filter must be
    exactly the documents sent to the predictor, because gold and predicted events are
    later paired by position.

    Args:
        doc: A parsed JSONL document with an ``entities`` list whose items carry an
            ``entity_type`` string.

    Returns:
        True if any entity's type equals ``'trigger'`` ignoring case, else False.

    Raises:
        KeyError: if ``doc`` has no ``entities`` key or an entity lacks ``entity_type``.
    """
    return any(entity['entity_type'].lower() == 'trigger' for entity in doc['entities'])

In [None]:
# GPU index to run the model on (e.g. 0); -1 runs on CPU, use it if no GPU is available.
CUDA_DEVICE = -1

# Directory with trained model runs; passed to load_predictor together with ALLENNLP_MODEL
# (presumably the archive path is resolved relative to this directory — TODO confirm).
MODEL_DIR = "../../data/runs"

SNORKEL = True  # set to False to use smartdata-eventx model

In [None]:
# Pick the test set, predictor and model archive that belong together for the chosen setup.
if SNORKEL:
    # Other Snorkel archives that can be swapped in for ALLENNLP_MODEL:
    #   snorkel_bert_v6-first_trigger_check_gold/model.tar.gz
    #   snorkel_bert_v6-first_trigger_check_converted_abstains/model.tar.gz
    DATASET_PATH, PREDICTOR_NAME, ALLENNLP_MODEL = (
        "../../data/snorkel_new/test_with_events_and_defaults.jsonl",
        "snorkel-eventx-predictor",
        "snorkel_bert_v6-first-trigger_check_snorkeled_gold_conv_merge_with_abstains/model.tar.gz",
    )
else:
    DATASET_PATH, PREDICTOR_NAME, ALLENNLP_MODEL = (
        "../../data/snorkel_new/test_sd4m_with_events.jsonl",
        "smartdata-eventx-predictor",
        "plass_bert_gold/model.tar.gz",
    )

In [None]:
# Load the predictor registered as PREDICTOR_NAME from the ALLENNLP_MODEL archive under MODEL_DIR.
predictor = load_predictor(
    MODEL_DIR,
    PREDICTOR_NAME,
    CUDA_DEVICE,
    archive_filename=ALLENNLP_MODEL,
    weights_file=None,  # presumably: use the weights bundled in the archive — TODO confirm
)

In [None]:
# Build model inputs for every test document with at least one trigger entity; documents
# without triggers are counted and skipped because the predictor does not support them.
instances = []
docs_without_triggers = 0
docs_with_triggers = 0
# Explicit UTF-8 (as in the gold-loading cell) so decoding does not depend on the platform locale.
with io.open(DATASET_PATH, 'r', encoding='utf-8') as test_file:
    for line in test_file:
        if not line.strip():
            continue  # tolerate stray blank lines in the JSONL file
        example = json.loads(line)
        # Use the same predicate as the gold-document filter: predictions and gold documents
        # are presumably paired by position, so both must select exactly the same documents.
        if has_triggers(example):
            instances.append(predictor._json_to_instance(example))
            docs_with_triggers += 1
        else:
            docs_without_triggers += 1
print(f"Docs with triggers: {docs_with_triggers} \t Docs without triggers (not supported): {docs_without_triggers}")
prediction_instances = batched_predict_instances(predictor, instances)

In [None]:
# Load every gold document from the test set.
doc_as_json_list = []
with io.open(DATASET_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():  # skip blank lines, which json.loads cannot parse
            doc_as_json_list.append(json.loads(line))
# Keep only documents with triggers. This must select the same documents as the
# instance-building cell, since gold and predicted events are paired by position later.
filtered_doc_list = [doc for doc in doc_as_json_list if has_triggers(doc)]

In [None]:
# Snorkel gold documents are converted with snorkel_to_ace_format (presumably into the event
# format the scorer cells below expect). NOTE: this rebinds filtered_doc_list, so re-running
# this cell on its own would apply the conversion a second time.
filtered_doc_list = snorkel_to_ace_format(filtered_doc_list) if SNORKEL else filtered_doc_list

In [None]:
# Metric rows (dicts with row_name, P, R, F1) collected by the cells below and exported as CSV at the end.
table_results = list()

## DFKI spree REScorer adapted to Python
Grusdt, B., Nehring, J., & Thomas, P. (2018). Bootstrapping patterns for the detection of mobility related events.
> For a detected event to count as true positive the predicted event type must be equal to the gold standard event type and the predicted event span must at least be subsumed by the gold standard event span.

Schiersch, M., Mironova, V., Schmitt, M., Thomas, P., Gabryszak, A., & Hennig, L. (2018). A german corpus for fine-grained named entity recognition and relation extraction of traffic and industry events. arXiv preprint arXiv:2004.03283.
> \[W\]e chose a soft matching strategy that counts a predicted relation mention as correct if all predicted arguments also occur in the corresponding gold relation mention, and if all required arguments have been correctly predicted, based on their role, underlying entity, and character offsets / extent. Optional arguments from the gold relation mention that are not contained in the predicted relation mention do not count as errors. In other words, we count a predicted relation mention as correct if it contains all required arguments and is subsumed by or equal to the gold relation mention.

The scorer does the following:
- It goes through every gold event and looks for a matching or subsumed predicted event using an EventComparator (see below).
- It increments the true positive count if such a predicted event is found and increments the false negative count otherwise.
- It treats all remaining predicted events that were not matched to any gold event as false positives.

The EventComparator compares two events using the following criteria:
- Do the event types match?
- Do the spans match or is the predicted event subsumed by the gold event?
- Optionally: Does every predicted argument match one of the gold arguments (i.e., do the argument spans and argument roles match)?

In [None]:
# The scorer receives two parallel lists, so entry i of each must describe the same document.
# Fail loudly on a length mismatch (e.g. when the trigger filters of the two loading cells
# disagree) instead of silently scoring misaligned documents.
assert len(prediction_instances) == len(filtered_doc_list), (
    f"{len(prediction_instances)} predicted documents vs {len(filtered_doc_list)} gold documents; "
    "predictions and gold events would be misaligned"
)
pred_events_batch = [prediction_instance['events'] for prediction_instance in prediction_instances]
gold_events_batch = [doc['events'] for doc in filtered_doc_list]

In [None]:
# Grusdt et al 2018: event type must match and the predicted span may be subsumed by the gold
# span; ignore_args=True, i.e. arguments are not compared.
results = scorer.score_events_batch(
    pred_events_batch,
    gold_events_batch,
    allow_subsumption=True,
    keep_event_matches=True,
    ignore_span=False,
    ignore_args=True,
    ignore_optional_args=True,
)
acc_results = results['accumulated']
# Accumulated row first, then one row per event type present in the results
# ([:-1] drops the last entry of SD4M_RELATION_TYPES).
scored_rows = [("Event Extraction Acc.", acc_results)]
scored_rows.extend(
    (event_class, results[event_class])
    for event_class in SD4M_RELATION_TYPES[:-1]
    if event_class in results
)
for label, metrics in scored_rows:
    tmp_dict = {
        'row_name': f"{label} (Grusdt et al 2018)",
        'P': metrics.precision(),
        'R': metrics.recall(),
        'F1': metrics.f1(),
    }
    print(tmp_dict)
    table_results.append(tmp_dict)

In [None]:
# Schiersch et al 2018

# Same soft span matching as the Grusdt cell, but with ignore_args=False the predicted
# arguments are taken into account too (optional gold arguments are ignored).
results = scorer.score_events_batch(
    pred_events_batch,
    gold_events_batch,
    allow_subsumption=True,
    keep_event_matches=True,
    ignore_span=False,
    ignore_args=False,
    ignore_optional_args=True,
)
acc_results = results['accumulated']
# Accumulated row first, then one row per event type present in the results
# ([:-1] drops the last entry of SD4M_RELATION_TYPES).
scored_rows = [("Event Extraction Acc.", acc_results)]
scored_rows.extend(
    (event_class, results[event_class])
    for event_class in SD4M_RELATION_TYPES[:-1]
    if event_class in results
)
for label, metrics in scored_rows:
    tmp_dict = {
        'row_name': f"{label} (Schiersch et al 2018)",
        'P': metrics.precision(),
        'R': metrics.recall(),
        'F1': metrics.f1(),
    }
    print(tmp_dict)
    table_results.append(tmp_dict)

## Event extraction evaluation using the correctness criteria defined by Ji and Grishman (2008)

Ji, Heng and Grishman, Ralph (2008). Refining event extraction through cross-document inference.
> A trigger is correctly labeled if its event type and offsets match a reference trigger.

> An argument is correctly identified if its event type and offsets match any of the reference argument mentions.

> An argument is correctly identified and classified if its event type, offsets, and role match any of the reference argument mentions.

Caution:
Retrieving the triggers and arguments from the gold data with the methods below may produce duplicate gold triggers and arguments.
This happens because several events can share the same trigger.
The model cannot distinguish such events and instead merges them into one, which is expected to lower recall.
Removing the duplicates from the gold triggers and gold arguments should therefore raise recall and, consequently, F1.

In [None]:
# Drop duplicate gold triggers/arguments that arise when several events share one trigger.
# Set to False to keep them.
REMOVE_DUPLICATES = True

In [None]:
# Flatten gold documents and model predictions into trigger and argument collections for scoring.
gold_triggers, gold_arguments = scorer.get_triggers(filtered_doc_list), scorer.get_arguments(filtered_doc_list)
pred_triggers, pred_arguments = scorer.get_triggers(prediction_instances), scorer.get_arguments(prediction_instances)

In [None]:
if REMOVE_DUPLICATES:
    # dict.fromkeys keeps the first occurrence of each item, in original order. Unlike
    # list(set(...)), the resulting order does not depend on (randomized) string hashing,
    # so it is reproducible across kernel restarts. Items must be hashable, as with set().
    gold_triggers = list(dict.fromkeys(gold_triggers))
    gold_arguments = list(dict.fromkeys(gold_arguments))

In [None]:
# Trigger identification, accumulated over all event types.
precision, recall, f1 = scorer.get_trigger_identification_metrics(gold_triggers, pred_triggers)
tmp_dict = dict(row_name='Trigger Identification', P=precision, R=recall, F1=f1)
print(tmp_dict)
table_results.append(tmp_dict)

In [None]:
# Trigger classification, accumulated over all event types.
precision, recall, f1 = scorer.get_trigger_classification_metrics(gold_triggers, pred_triggers, accumulated=True)
tmp_dict = dict(row_name='Trigger Classification', P=precision, R=recall, F1=f1)
print(tmp_dict)
table_results.append(tmp_dict)

In [None]:
class_results = scorer.get_trigger_classification_metrics(gold_triggers, pred_triggers, accumulated=False)
# One row per event type ([:-1] drops the last entry of SD4M_RELATION_TYPES); types missing
# from class_results are reported as zeros rather than skipped.
for trigger_class in SD4M_RELATION_TYPES[:-1]:
    precision, recall, f1 = (
        class_results[trigger_class] if trigger_class in class_results else (0.0, 0.0, 0.0)
    )
    tmp_dict = dict(row_name=trigger_class, P=precision, R=recall, F1=f1)
    print(tmp_dict)
    table_results.append(tmp_dict)

In [None]:
# Argument identification, accumulated over all roles.
precision, recall, f1 = scorer.get_argument_identification_metrics(gold_arguments, pred_arguments)
tmp_dict = dict(row_name='Argument Identification', P=precision, R=recall, F1=f1)
print(tmp_dict)
table_results.append(tmp_dict)

In [None]:
# Argument classification, accumulated over all roles.
precision, recall, f1 = scorer.get_argument_classification_metrics(gold_arguments, pred_arguments, accumulated=True)
tmp_dict = dict(row_name='Argument Classification', P=precision, R=recall, F1=f1)
print(tmp_dict)
table_results.append(tmp_dict)

In [None]:
class_results = scorer.get_argument_classification_metrics(gold_arguments, pred_arguments, accumulated=False)
# One row per role ([:-1] drops the last entry of ROLE_LABELS); roles missing from
# class_results are reported as zeros rather than skipped.
for role_class in ROLE_LABELS[:-1]:
    precision, recall, f1 = (
        class_results[role_class] if role_class in class_results else (0.0, 0.0, 0.0)
    )
    tmp_dict = dict(row_name=role_class, P=precision, R=recall, F1=f1)
    print(tmp_dict)
    table_results.append(tmp_dict)

In [None]:
# One row per metric line collected above (columns row_name, P, R, F1).
table_results_df = pd.DataFrame.from_records(table_results)

In [None]:
# Write next to the notebook instead of a hard-coded, user-specific absolute path. The
# predictor name is part of the filename so Snorkel and smartdata runs don't overwrite each other.
RESULTS_CSV_PATH = f"table_results_{PREDICTOR_NAME}.csv"
table_results_df.to_csv(RESULTS_CSV_PATH)