In [1]:
%matplotlib inline

In [2]:
import sys
sys.path.append("../../")

In [3]:
import eventx
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import io

from itertools import chain, accumulate

from eventx.predictors.predictor_utils import load_predictor
from eventx.models.model_utils import batched_predict_json, batched_predict_instances
from eventx.predictors import snorkel_predictor, smartdata_predictor
from eventx.util import scorer
from eventx.util.utils import snorkel_to_ace_format

from allennlp.predictors import Predictor

In [4]:
from typing import List

In [5]:
def has_triggers(doc):
    """Tell whether a document contains at least one trigger entity.

    Args:
        doc: Mapping with an 'entities' sequence; every entity is a mapping
            that carries an 'entity_type' key.

    Returns:
        True if any entity has 'entity_type' equal to 'trigger', else False
        (also False when the 'entities' sequence is empty).
    """
    for candidate in doc['entities']:
        if candidate['entity_type'] == 'trigger':
            return True
    return False

In [6]:
# Device argument handed to load_predictor below. Use -1 if no GPU is available;
# otherwise presumably the index of the CUDA device to run on (e.g. 0) — TODO confirm.
CUDA_DEVICE = -1

# Directory passed to load_predictor together with the model archive filename.
MODEL_DIR = "../../data/runs"

# Selects which dataset / predictor / model archive triple is used in the next cell.
SNORKEL = True  # set to False to use smartdata-eventx model

In [7]:
# Choose the test set, the predictor name and the model archive as one unit so
# the three settings can never get out of sync with each other.
DATASET_PATH, PREDICTOR_NAME, ALLENNLP_MODEL = (
    (
        "../../data/snorkel_new/test_with_events_and_defaults.jsonl",
        "snorkel-eventx-predictor",
        "snorkel_bert_v6-first_trigger_check_gold/model.tar.gz",
    )
    if SNORKEL
    else (
        "../../data/snorkel_new/test_sd4m_with_events.jsonl",
        "smartdata-eventx-predictor",
        "plass_bert_gold/model.tar.gz",
    )
)

In [8]:
# Build the predictor from the configuration selected in the cells above.
predictor = load_predictor(
    MODEL_DIR,
    PREDICTOR_NAME,
    CUDA_DEVICE,
    archive_filename=ALLENNLP_MODEL,
    weights_file=None,
)

In [9]:
# Read the test set with the predictor's own dataset reader, then run the
# model over the resulting instances in batches.
instances = predictor._dataset_reader.read(DATASET_PATH) # beware that the read method automatically filters out documents without triggers
prediction_instances = batched_predict_instances(predictor, instances)

68it [00:00, 211.13it/s]
Encountered the triggers_loss key in the model's return dictionary which couldn't be split by the batch size. Key will be ignored.
Encountered the loss key in the model's return dictionary which couldn't be split by the batch size. Key will be ignored.
Encountered the role_loss key in the model's return dictionary which couldn't be split by the batch size. Key will be ignored.


In [10]:
# Load the raw JSONL test set (one JSON document per line) to obtain the gold
# annotations that the predictions are scored against.
doc_as_json_list = []
with io.open(DATASET_PATH, 'r', encoding='utf-8') as f:
    # Iterate the file object directly: lines are streamed instead of the
    # whole file being materialised up front by readlines().
    for line in f:
        # A blank line (e.g. a stray empty line at the end of the file) is not
        # valid JSON and would make json.loads raise, so skip it.
        if not line.strip():
            continue
        example = json.loads(line)
        doc_as_json_list.append(example)
# Keep only documents with at least one trigger. The dataset reader used for
# the predictions filters out trigger-less documents as well, so this is
# presumably what keeps gold documents aligned with prediction_instances.
filtered_doc_list = [doc for doc in doc_as_json_list if has_triggers(doc)]

In [11]:
# Documents of the snorkel dataset are passed through snorkel_to_ace_format
# before scoring; the smartdata documents are used as loaded.
# NOTE(review): this rebinds filtered_doc_list, so re-running this cell feeds
# already-converted documents into the converter again — re-run the loading
# cell first. Whether a second conversion is harmless is not visible here.
filtered_doc_list = snorkel_to_ace_format(filtered_doc_list) if SNORKEL else filtered_doc_list

## DFKI spree REScorer adapted to Python
Grusdt, B., Nehring, J., & Thomas, P. (2018). Bootstrapping patterns for the detection of mobility related events.
> For a detected event to count as true positive the predicted event type must be equal to the gold standard event type and the predicted event span must at least be subsumed by the gold standard event span.

Schiersch, M., Mironova, V., Schmitt, M., Thomas, P., Gabryszak, A., & Hennig, L. (2020). A German corpus for fine-grained named entity recognition and relation extraction of traffic and industry events. arXiv preprint arXiv:2004.03283.
> \[W\]e chose a soft matching strategy that counts a predicted relation mention as correct if all predicted arguments also occur in the corresponding gold relation mention, and if all required arguments have been correctly predicted, based on their role, underlying entity, and character offsets / extent. Optional arguments from the gold relation mention that are not contained in the predicted relation mention do not count as errors. In other words, we count a predicted relation mention as correct if it contains all required arguments and is subsumed by or equal to the gold relation mention.

The scorer does the following:
- It goes through every gold event and looks for a matching or subsumed predicted event using an EventComparator (see below).
- It increments the true positive count if such a predicted event is found and increments the false negative count otherwise.
- It treats all remaining predicted events, i.e. those that were not matched with any of the gold events, as false positives.

The EventComparator compares two events using the following criteria:
- Do the event types match?
- Do the spans match or is the predicted event subsumed by the gold event?
- Do all the predicted arguments match any of the gold arguments? (I.e. do the argument spans and argument roles match?)

In [12]:
# One list of events per document: first for the predictions, then for the
# gold data. NOTE(review): the scorer presumably pairs the two batches by
# position, so both must be in the same document order — confirm.
pred_events_batch = [instance['events'] for instance in prediction_instances]
gold_events_batch = [gold_doc['events'] for gold_doc in filtered_doc_list]

In [13]:
# Strict run: spans and all arguments, including optional ones, are compared
# (none of the ignore_* switches is enabled).
scorer.score_events_batch(
    pred_events_batch,
    gold_events_batch,
    allow_subsumption=True,
    keep_event_matches=True,
    ignore_span=False,
    ignore_args=False,
    ignore_optional_args=False,
)

Obstruction: 	P=0.400	R=0.333	F1=0.364

Delay: 	P=1.000	R=1.000	F1=1.000

Accident: 	P=0.333	R=0.444	F1=0.381

RailReplacementService: 	P=0.400	R=0.400	F1=0.400

TrafficJam: 	P=0.824	R=0.875	F1=0.848

CanceledRoute: 	P=0.500	R=0.400	F1=0.444

CanceledStop: 	P=0.667	R=0.667	F1=0.667



In [14]:
# Same scoring run as the strict one, except that ignore_optional_args is
# switched on; every other setting is unchanged.
scorer.score_events_batch(
    pred_events_batch,
    gold_events_batch,
    allow_subsumption=True,
    keep_event_matches=True,
    ignore_span=False,
    ignore_args=False,
    ignore_optional_args=True,
)

Obstruction: 	P=0.600	R=0.500	F1=0.545

Delay: 	P=1.000	R=1.000	F1=1.000

Accident: 	P=0.500	R=0.667	F1=0.571

RailReplacementService: 	P=0.800	R=0.800	F1=0.800

TrafficJam: 	P=0.824	R=0.875	F1=0.848

CanceledRoute: 	P=0.500	R=0.400	F1=0.444

CanceledStop: 	P=0.667	R=0.667	F1=0.667



## Event extraction evaluation according to Chen et al. 2015

Chen, Y., Xu, L., Liu, K., Zeng, D., & Zhao, J. (2015). Event extraction via dynamic multi-pooling convolutional neural networks.
> A trigger is correct if its event subtype and offsets match those of a reference trigger.

> An argument is correctly identified if its event subtype and offsets match those of any of the reference argument mentions.

> An argument is correctly classified if its event subtype, offsets and argument role match those of any of the reference argument mentions.

Caution:
Using the following methods to retrieve the triggers and arguments from the gold data might result in duplicate gold triggers and arguments.
This is because different events may share the same trigger.
The model is not able to distinguish such events and instead fuses them all together, which should result in lower recall.
If we remove duplicates from the gold triggers and gold arguments, recall and consequently F1 should be higher.

In [15]:
# True: drop duplicate gold triggers/arguments that are caused by events sharing the same trigger.
# False: keep those duplicates in the gold data.
REMOVE_DUPLICATES = True

In [16]:
# Triggers and arguments extracted from the gold documents ...
gold_triggers, gold_arguments = (
    scorer.get_triggers(filtered_doc_list),
    scorer.get_arguments(filtered_doc_list),
)
# ... and, with the same helpers, from the model predictions.
pred_triggers, pred_arguments = (
    scorer.get_triggers(prediction_instances),
    scorer.get_arguments(prediction_instances),
)

In [17]:
# Optionally collapse duplicate gold triggers/arguments.
# dict.fromkeys keeps the first occurrence of every item in its original
# position, whereas list(set(...)) returns the items in an arbitrary order
# that can change between kernel sessions. A stable order keeps anything
# derived from these lists reproducible across runs.
# NOTE: switching REMOVE_DUPLICATES to False and re-running only this cell does
# not bring the duplicates back — re-run the extraction cell before it.
if REMOVE_DUPLICATES:
    gold_triggers = list(dict.fromkeys(gold_triggers))
    gold_arguments = list(dict.fromkeys(gold_arguments))

In [18]:
# Trigger identification scores. The result is a 3-tuple, presumably
# (precision, recall, F1) — TODO confirm against the scorer.
scorer.get_trigger_identification_metrics(gold_triggers, pred_triggers)

(0.8947368421052632, 0.8947368421052632, 0.8947368421052632)

In [19]:
# Trigger classification; accumulated=True yields a single 3-tuple over all
# event types, presumably (precision, recall, F1) — TODO confirm.
scorer.get_trigger_classification_metrics(gold_triggers, pred_triggers, accumulated=True)

(0.8947368421052632, 0.8947368421052632, 0.8947368421052632)

In [20]:
# Trigger classification; accumulated=False yields a dict with one 3-tuple
# per event type, presumably (precision, recall, F1) — TODO confirm.
scorer.get_trigger_classification_metrics(gold_triggers, pred_triggers, accumulated=False)

{'CanceledStop': (1.0, 1.0, 1.0),
 'Delay': (1.0, 1.0, 1.0),
 'TrafficJam': (0.8823529411764706, 0.9375, 0.9090909090909091),
 'Obstruction': (1.0, 0.8333333333333334, 0.9090909090909091),
 'CanceledRoute': (1.0, 0.8, 0.888888888888889),
 'RailReplacementService': (0.8, 0.8, 0.8000000000000002),
 'Accident': (0.75, 1.0, 0.8571428571428571)}

In [21]:
# Argument identification scores. The result is a 3-tuple, presumably
# (precision, recall, F1) — TODO confirm against the scorer.
scorer.get_argument_identification_metrics(gold_arguments, pred_arguments)

(0.6237623762376238, 0.78099173553719, 0.6935779816513761)

In [22]:
# Argument classification; accumulated=True yields a single 3-tuple over all
# argument roles, presumably (precision, recall, F1) — TODO confirm.
scorer.get_argument_classification_metrics(gold_arguments, pred_arguments, accumulated=True)

(0.5973597359735974, 0.7479338842975206, 0.6642201834862386)

In [23]:
# Argument classification; accumulated=False yields a dict with one 3-tuple
# per argument role, presumably (precision, recall, F1) — TODO confirm.
scorer.get_argument_classification_metrics(gold_arguments, pred_arguments, accumulated=False)

{'start_date': (1.0, 0.25, 0.4),
 'end_date': (0.5, 0.3333333333333333, 0.4),
 'direction': (0.5454545454545454, 0.7741935483870968, 0.64),
 'jam_length': (0.9285714285714286, 1.0, 0.962962962962963),
 'location': (0.47107438016528924, 0.6785714285714286, 0.5560975609756097),
 'route': (0.75, 1.0, 0.8571428571428571),
 'cause': (0.7333333333333333, 0.7333333333333333, 0.7333333333333333),
 'end_loc': (0.66, 0.8461538461538461, 0.7415730337078651),
 'start_loc': (0.717391304347826, 0.868421052631579, 0.7857142857142858),
 'delay': (0.8, 0.5, 0.6153846153846154)}