In [1]:
%matplotlib inline

In [2]:
import sys
sys.path.append("../../")

In [3]:
import eventx
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import io

from itertools import chain, accumulate

from eventx.predictors.predictor_utils import load_predictor
from eventx.models.model_utils import batched_predict_json, batched_predict_instances
from eventx.predictors import snorkel_predictor, smartdata_predictor
from eventx.util import scorer
from eventx.util import utils
from eventx import SD4M_RELATION_TYPES, ROLE_LABELS

from allennlp.predictors import Predictor

In [4]:
from typing import List

In [5]:
def make_result_dict(row_name: str, p: float, r: float, f1: float) -> dict:
    """Bundle one row of evaluation scores into a dict, echo it, and return it.

    Args:
        row_name: Label identifying the row (e.g. a metric or class name).
        p: Precision value, stored under the key ``'P'``.
        r: Recall value, stored under the key ``'R'``.
        f1: F1 value, stored under the key ``'F1'``.

    Returns:
        A dict with the keys ``'row_name'``, ``'P'``, ``'R'`` and ``'F1'``
        (in that order).
    """
    row = dict(row_name=row_name, P=p, R=r, F1=f1)
    # Printed so the scores show up directly in the cell output.
    print(row)
    return row

In [6]:
# Device index handed to load_predictor below.
CUDA_DEVICE = -1  # -1 if no GPU is available; presumably a GPU index otherwise — TODO confirm

# Base directory of the model runs; the results CSV at the end of the notebook is also written here.
MODEL_DIR = "../../data/runs"

In [7]:
# Test set in JSON-lines form: every line is parsed with json.loads in the loading cell.
DATASET_PATH = "../../data/daystream_corpus/test_sd4m_with_events.jsonl"
# Predictor name passed to load_predictor; the commented line is the alternative predictor.
PREDICTOR_NAME = "snorkel-eventx-predictor"
# PREDICTOR_NAME = "smartdata-eventx-predictor"

# Model archive passed to load_predictor as archive_filename
# (presumably resolved relative to MODEL_DIR — TODO confirm); the commented line is an alternative model.
ALLENNLP_MODEL = "snorkel_bert_v6-first_trigger_check_gold/model.tar.gz"
# ALLENNLP_MODEL = "plass_bert_gold/model.tar.gz"

In [8]:
# Build the predictor for the configured model archive on the configured device.
predictor = load_predictor(
    MODEL_DIR,
    PREDICTOR_NAME,
    CUDA_DEVICE,
    archive_filename=ALLENNLP_MODEL,
    weights_file=None,
)

In [9]:
# Read the JSON-lines test set. Every document that contains at least one entity of
# type "trigger" is kept as a gold document and turned into a model instance;
# documents without triggers are only counted.
instances = []
gold_doc_list = []
docs_without_triggers = 0
docs_with_triggers = 0
# Explicit UTF-8 so that reading does not depend on the platform's default encoding.
with io.open(DATASET_PATH, encoding="utf-8") as test_file:
    # Iterate lazily instead of materialising the whole file with readlines().
    for line in test_file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            continue
        example = json.loads(line)
        if any(e['entity_type'].lower() == 'trigger' for e in example['entities']):
            # The gold list receives the document before the optional conversion below.
            gold_doc_list.append(example)
            if 'event_triggers' not in example or 'event_roles' not in example:
                # Convert ACE events to Snorkel event triggers & roles
                example = utils.convert_events(example)
            instances.append(predictor._json_to_instance(example))
            docs_with_triggers += 1
        else:
            # print(f"Document {example['id']} does not contain triggers and is therefore not supported.")
            docs_without_triggers += 1
print(f"Docs with triggers: {docs_with_triggers} \t Docs without triggers (not supported): {docs_without_triggers}")

Docs with triggers: 68 	 Docs without triggers (not supported): 98


In [10]:
# Run the predictor over all collected instances; each result is later read via its 'events' field.
prediction_instances = batched_predict_instances(predictor, instances)

In [11]:
# One dict per result row (as returned by make_result_dict) is appended here by the cells below.
table_results = []  # save results in here to later export as csv

## DFKI spree REScorer adapted to Python

In [12]:
# Pull the 'events' field out of every prediction and out of every gold document.
pred_events_batch = [prediction['events'] for prediction in prediction_instances]
gold_events_batch = [gold_doc['events'] for gold_doc in gold_doc_list]

Grusdt, B., Nehring, J., & Thomas, P. (2018). Bootstrapping patterns for the detection of mobility related events.
> For a detected event to count as true positive the predicted event type must be equal to the gold standard event type and the predicted event span must at least be subsumed by the gold standard event span.

In [13]:
# Event-level scoring with subsumption allowed and arguments ignored (ignore_args=True).
results = scorer.score_events_batch(
    pred_events_batch,
    gold_events_batch,
    allow_subsumption=True,
    keep_event_matches=True,
    ignore_span=False,
    ignore_args=True,
    ignore_optional_args=True,
)

# Overall row first ...
acc_results = results['accumulated']
tmp_dict = make_result_dict(
    "Event Extraction Acc. (Grusdt et al 2018)",
    acc_results.precision(),
    acc_results.recall(),
    acc_results.f1(),
)
table_results.append(tmp_dict)

# ... then one row per event class that appears in the scorer result.
# The last entry of SD4M_RELATION_TYPES is skipped (presumably the negative label — TODO confirm).
for event_class in SD4M_RELATION_TYPES[:-1]:
    if event_class not in results:
        continue
    class_results = results[event_class]
    tmp_dict = make_result_dict(
        f"{event_class} (Grusdt et al 2018)",
        class_results.precision(),
        class_results.recall(),
        class_results.f1(),
    )
    table_results.append(tmp_dict)

{'row_name': 'Event Extraction Acc. (Grusdt et al 2018)', 'P': 0.7368421052631579, 'R': 0.6774193548387096, 'F1': 0.7058823529411764}
{'row_name': 'Accident (Grusdt et al 2018)', 'P': 0.6666666666666666, 'R': 0.8888888888888888, 'F1': 0.761904761904762}
{'row_name': 'CanceledRoute (Grusdt et al 2018)', 'P': 0.5, 'R': 0.36363636363636365, 'F1': 0.4210526315789474}
{'row_name': 'CanceledStop (Grusdt et al 2018)', 'P': 0.6666666666666666, 'R': 0.6666666666666666, 'F1': 0.6666666666666666}
{'row_name': 'Delay (Grusdt et al 2018)', 'P': 1.0, 'R': 0.3333333333333333, 'F1': 0.5}
{'row_name': 'Obstruction (Grusdt et al 2018)', 'P': 0.7, 'R': 0.5833333333333334, 'F1': 0.6363636363636365}
{'row_name': 'RailReplacementService (Grusdt et al 2018)', 'P': 0.8, 'R': 0.8, 'F1': 0.8000000000000002}
{'row_name': 'TrafficJam (Grusdt et al 2018)', 'P': 0.8823529411764706, 'R': 0.9375, 'F1': 0.9090909090909091}


Schiersch, M., Mironova, V., Schmitt, M., Thomas, P., Gabryszak, A., & Hennig, L. (2018). A German corpus for fine-grained named entity recognition and relation extraction of traffic and industry events. arXiv preprint arXiv:2004.03283.
> \[W\]e chose a soft matching strategy that counts a predicted relation mention as correct if all predicted arguments also occur in the corresponding gold relation mention, and if all required arguments have been correctly predicted, based on their role, underlying entity, and character offsets / extent. Optional arguments from the gold relation mention that are not contained in the predicted relation mention do not count as errors. In other words, we count a predicted relation mention as correct if it contains all required arguments and is subsumed by or equal to the gold relation mention.

In [14]:
# Event-level scoring with subsumption allowed; arguments are taken into account
# (ignore_args=False) while optional arguments are ignored (ignore_optional_args=True).
results = scorer.score_events_batch(
    pred_events_batch,
    gold_events_batch,
    allow_subsumption=True,
    keep_event_matches=True,
    ignore_span=False,
    ignore_args=False,
    ignore_optional_args=True,
)

# Overall row first ...
acc_results = results['accumulated']
tmp_dict = make_result_dict(
    "Event Extraction Acc. (Schiersch et al 2018)",
    acc_results.precision(),
    acc_results.recall(),
    acc_results.f1(),
)
table_results.append(tmp_dict)

# ... then one row per event class that appears in the scorer result.
# The last entry of SD4M_RELATION_TYPES is skipped (presumably the negative label — TODO confirm).
for event_class in SD4M_RELATION_TYPES[:-1]:
    if event_class not in results:
        continue
    class_results = results[event_class]
    tmp_dict = make_result_dict(
        f"{event_class} (Schiersch et al 2018)",
        class_results.precision(),
        class_results.recall(),
        class_results.f1(),
    )
    table_results.append(tmp_dict)

{'row_name': 'Event Extraction Acc. (Schiersch et al 2018)', 'P': 0.7017543859649122, 'R': 0.6451612903225806, 'F1': 0.6722689075630253}
{'row_name': 'Accident (Schiersch et al 2018)', 'P': 0.6666666666666666, 'R': 0.8888888888888888, 'F1': 0.761904761904762}
{'row_name': 'CanceledRoute (Schiersch et al 2018)', 'P': 0.375, 'R': 0.2727272727272727, 'F1': 0.3157894736842105}
{'row_name': 'CanceledStop (Schiersch et al 2018)', 'P': 0.6666666666666666, 'R': 0.6666666666666666, 'F1': 0.6666666666666666}
{'row_name': 'Delay (Schiersch et al 2018)', 'P': 1.0, 'R': 0.3333333333333333, 'F1': 0.5}
{'row_name': 'Obstruction (Schiersch et al 2018)', 'P': 0.6, 'R': 0.5, 'F1': 0.5454545454545454}
{'row_name': 'RailReplacementService (Schiersch et al 2018)', 'P': 0.8, 'R': 0.8, 'F1': 0.8000000000000002}
{'row_name': 'TrafficJam (Schiersch et al 2018)', 'P': 0.8823529411764706, 'R': 0.9375, 'F1': 0.9090909090909091}


## Event extraction evaluation using the correctness criteria defined by Ji and Grishman (2008)

Ji, Heng and Grishman, Ralph (2008). Refining event extraction through cross-document inference.
> A trigger is correctly labeled if its event type and offsets match a reference trigger.

> An argument is correctly identified if its event type and offsets match any of the reference argument mentions.

> An argument is correctly identified and classified if its event type, offsets, and role match any of the reference argument mentions.

Caution:
Retrieving the triggers and arguments from the gold data with the following methods can produce duplicate gold triggers and arguments,
because different events may share the same trigger.
The model cannot distinguish such events and instead fuses them into one, which should result in lower recall.
If we remove the duplicates from the gold triggers and gold arguments, recall and consequently F1 should be higher.

In [15]:
# Switch read by the de-duplication cell below: when True, duplicate gold triggers/arguments are removed.
REMOVE_DUPLICATES = True  # change to False if you want to keep duplicate triggers/ arguments from the gold data caused by events sharing the same trigger

In [16]:
# Extract triggers and arguments from the gold documents and from the predictions,
# using the same scorer helpers for both sides.
gold_triggers = scorer.get_triggers(gold_doc_list)
gold_arguments = scorer.get_arguments(gold_doc_list)
pred_triggers = scorer.get_triggers(prediction_instances)
pred_arguments = scorer.get_arguments(prediction_instances)

In [17]:
if REMOVE_DUPLICATES:
    # dict.fromkeys drops duplicates exactly like set() would (same hash/equality rules),
    # but keeps first-occurrence order, so the resulting lists are identical from run
    # to run; the iteration order of a set is not guaranteed to be.
    gold_triggers = list(dict.fromkeys(gold_triggers))
    gold_arguments = list(dict.fromkeys(gold_arguments))

In [18]:
# Trigger identification: compare predicted with gold triggers and record the scores as one result row.
precision, recall, f1 = scorer.get_trigger_identification_metrics(gold_triggers, pred_triggers)
tmp_dict = make_result_dict('Trigger Identification', precision, recall, f1)
table_results.append(tmp_dict)

{'row_name': 'Trigger Identification', 'P': 0.8947368421052632, 'R': 0.8947368421052632, 'F1': 0.8947368421052632}


In [19]:
# Trigger classification; with accumulated=True the scorer result unpacks into a single P/R/F1 triple.
precision, recall, f1 = scorer.get_trigger_classification_metrics(gold_triggers, pred_triggers, accumulated=True)
tmp_dict = make_result_dict('Trigger Classification', precision, recall, f1)
table_results.append(tmp_dict)

{'row_name': 'Trigger Classification', 'P': 0.8947368421052632, 'R': 0.8947368421052632, 'F1': 0.8947368421052632}


In [20]:
# With accumulated=False the scorer result is looked up per trigger class below.
class_results = scorer.get_trigger_classification_metrics(
    gold_triggers,
    pred_triggers,
    accumulated=False,
)
# The last entry of SD4M_RELATION_TYPES is skipped (presumably the negative label — TODO confirm).
for trigger_class in SD4M_RELATION_TYPES[:-1]:
    # Classes missing from the scorer result are reported with all-zero scores.
    precision, recall, f1 = (
        class_results[trigger_class] if trigger_class in class_results else (0.0, 0.0, 0.0)
    )
    tmp_dict = make_result_dict(trigger_class, precision, recall, f1)
    table_results.append(tmp_dict)

{'row_name': 'Accident', 'P': 0.75, 'R': 1.0, 'F1': 0.8571428571428571}
{'row_name': 'CanceledRoute', 'P': 1.0, 'R': 0.8, 'F1': 0.888888888888889}
{'row_name': 'CanceledStop', 'P': 1.0, 'R': 1.0, 'F1': 1.0}
{'row_name': 'Delay', 'P': 1.0, 'R': 1.0, 'F1': 1.0}
{'row_name': 'Obstruction', 'P': 1.0, 'R': 0.8333333333333334, 'F1': 0.9090909090909091}
{'row_name': 'RailReplacementService', 'P': 0.8, 'R': 0.8, 'F1': 0.8000000000000002}
{'row_name': 'TrafficJam', 'P': 0.8823529411764706, 'R': 0.9375, 'F1': 0.9090909090909091}


In [21]:
# Argument identification: compare predicted with gold arguments and record the scores as one result row.
precision, recall, f1 = scorer.get_argument_identification_metrics(gold_arguments, pred_arguments)
tmp_dict = make_result_dict('Argument Identification', precision, recall, f1)
table_results.append(tmp_dict)

{'row_name': 'Argument Identification', 'P': 0.7695473251028807, 'R': 0.7663934426229508, 'F1': 0.7679671457905545}


In [22]:
# Argument classification; with accumulated=True the scorer result unpacks into a single P/R/F1 triple.
precision, recall, f1 = scorer.get_argument_classification_metrics(gold_arguments, pred_arguments, accumulated=True)
tmp_dict = make_result_dict('Argument Classification', precision, recall, f1)
table_results.append(tmp_dict)

{'row_name': 'Argument Classification', 'P': 0.7407407407407407, 'R': 0.7377049180327869, 'F1': 0.7392197125256674}


In [23]:
# With accumulated=False the scorer result is looked up per argument role below.
class_results = scorer.get_argument_classification_metrics(
    gold_arguments,
    pred_arguments,
    accumulated=False,
)
# The last entry of ROLE_LABELS is skipped (presumably the negative label — TODO confirm).
for role_class in ROLE_LABELS[:-1]:
    # Roles missing from the scorer result are reported with all-zero scores.
    precision, recall, f1 = (
        class_results[role_class] if role_class in class_results else (0.0, 0.0, 0.0)
    )
    tmp_dict = make_result_dict(role_class, precision, recall, f1)
    table_results.append(tmp_dict)

{'row_name': 'location', 'P': 0.7671232876712328, 'R': 0.6666666666666666, 'F1': 0.713375796178344}
{'row_name': 'delay', 'P': 0.8, 'R': 0.5, 'F1': 0.6153846153846154}
{'row_name': 'direction', 'P': 0.7058823529411765, 'R': 0.7741935483870968, 'F1': 0.7384615384615385}
{'row_name': 'start_loc', 'P': 0.7391304347826086, 'R': 0.8717948717948718, 'F1': 0.7999999999999999}
{'row_name': 'end_loc', 'P': 0.66, 'R': 0.8461538461538461, 'F1': 0.7415730337078651}
{'row_name': 'start_date', 'P': 1.0, 'R': 0.2222222222222222, 'F1': 0.3636363636363636}
{'row_name': 'end_date', 'P': 0.5, 'R': 0.3333333333333333, 'F1': 0.4}
{'row_name': 'cause', 'P': 0.7692307692307693, 'R': 0.6666666666666666, 'F1': 0.7142857142857142}
{'row_name': 'jam_length', 'P': 0.9285714285714286, 'R': 1.0, 'F1': 0.962962962962963}
{'row_name': 'route', 'P': 0.75, 'R': 1.0, 'F1': 0.8571428571428571}


In [24]:
# One DataFrame row per collected result dict (keys 'row_name', 'P', 'R', 'F1' from make_result_dict).
table_results_df = pd.DataFrame(table_results)

In [25]:
# Export the result table as CSV into MODEL_DIR.
# NOTE(review): the file name is fixed and does not include the model name, so
# evaluating a different model overwrites the same file.
table_results_df.to_csv(f"{MODEL_DIR}/table_results.csv")