In [None]:
import spacy
from spacy.tokens import Doc
from spacy.matcher import Matcher
from zipfile import ZipFile
from pathlib import Path
from tqdm import autonotebook as tqdm
from itertools import combinations
import re


In [None]:
# Small English spaCy pipeline. The MATCHER patterns further down match on the
# LEMMA, POS, TAG and DEP token attributes, so the loaded pipeline must produce them.
nlp = spacy.load("en_core_web_sm")

In [None]:
data_dir = Path("./data/teaching-dataset")

# input.txt: sentences separated by a blank line, one token per line.
with ZipFile(data_dir / "relation_extraction_text_train.zip") as archive:
    raw_text = archive.read("input.txt").decode("utf-8")
sentences = [block.split("\n") for block in raw_text.split("\n\n")]

# references.txt: one line per sentence holding zero or more relations written as
# ((cause_start,cause_end),(effect_start,effect_end)) token offsets.
with ZipFile(data_dir / "relation_extraction_references_train.zip") as archive:
    raw_references = archive.read("references.txt").decode("utf-8")
labels = [
    [
        ((int(c_start), int(c_end)), (int(e_start), int(e_end)))
        for c_start, c_end, e_start, e_end in re.findall(
            r"\(\((\d+),(\d+)\),\((\d+),(\d+)\)\)", line
        )
    ]
    for line in raw_references.split("\n")
]

# NOTE(review): a trailing newline in either file would add an empty token / an
# extra empty label line — the length check below only catches a count mismatch.
assert len(sentences) == len(labels)

In [None]:
# Run the pipeline on the first training sentence. Building the Doc from the given
# token list (rather than raw text) keeps token indices aligned with the offsets in `labels`.
doc = nlp(Doc(nlp.vocab, words=sentences[0]))
doc

In [None]:
# Gold relations for the first sentence: ((cause_start, cause_end), (effect_start, effect_end)).
labels[0]

In [None]:
# Lemmas whose presence flags a sentence as possibly causal.
CAUSAL_CUES = ["because", "since", "due", "cause", "lead", "result", "effect"]
MATCHER = Matcher(nlp.vocab)

# Sentence classification: a single token whose lemma is one of the cues.
MATCHER.add("CAUSAL", [[{"LEMMA": lemma}] for lemma in CAUSAL_CUES])

# Span extraction: a run of one or more NOUN / PROPN / ADJ tokens. The "+" operator
# can report overlapping sub-runs; `match()` reduces them with filter_spans.
MATCHER.add("EVENT", [[{"POS": {"IN": ["NOUN", "PROPN", "ADJ"]}, "OP": "+"}]])

# Relation classification: passive subject, any number of `aux` tokens, the passive
# auxiliary, then a VBN, a VBZ, or an RB followed by a VBN.
# Adapted from https://stackoverflow.com/questions/74528441/detect-passive-or-active-sentence-from-text
MATCHER.add(
    "PASSIVE",
    [
        # Dict literals are rebuilt on every iteration, so no pattern shares token specs.
        [{"DEP": "nsubjpass"}, {"DEP": "aux", "OP": "*"}, {"DEP": "auxpass"}, *verb_tail]
        for verb_tail in (
            [{"TAG": "VBN"}],
            [{"TAG": "VBZ"}],
            [{"TAG": "RB"}, {"TAG": "VBN"}],
        )
    ],
)

def in_the_middle_or_overlap(ent_1, ent_2, match_span):
    """Return True if ``match_span`` shares at least one token with the gap between two spans.

    All three arguments expose token offsets ``start`` (inclusive) and ``end``
    (exclusive), as spaCy ``Span`` objects do. ``ent_1`` and ``ent_2`` may be
    passed in either order; the one that starts first is treated as the left span.

    The gap is the half-open token range ``[left.end, right.start)``. The match
    counts when it intersects that range — lying fully inside it, overlapping one
    of the two spans as well, or covering the whole gap.
    """
    left, right = (ent_2, ent_1) if ent_1.start > ent_2.start else (ent_1, ent_2)
    # Half-open interval intersection: [s, e) meets [left.end, right.start)
    # iff s < right.start and e > left.end. This includes a match starting exactly
    # at left.end or ending exactly at right.start.
    return match_span.start < right.start and match_span.end > left.end


def match(doc):
    """Run MATCHER over ``doc`` and group the matched spans by rule label.

    Returns a dict with the keys "EVENT", "CAUSAL", "PASSIVE" and "ACTIVE". No
    "ACTIVE" rule is added to MATCHER, so that list stays empty. EVENT spans are
    passed through ``spacy.util.filter_spans``, keeping non-overlapping spans and
    preferring longer ones.
    """
    grouped = {label: [] for label in ("EVENT", "CAUSAL", "PASSIVE", "ACTIVE")}
    for match_id, start, end in MATCHER(doc):
        label = nlp.vocab.strings[match_id]
        grouped[label].append(doc[start:end])
    grouped["EVENT"] = spacy.util.filter_spans(grouped["EVENT"])
    return grouped

def predict(doc):
    """Predict (cause, effect) relations between EVENT spans of ``doc``.

    If no CAUSAL cue matched, nothing is predicted. Otherwise every pair of EVENT
    spans (in the order ``match`` returns them) yields one relation:
    - Direction is (first -> second) by default.
    - It is flipped to (second -> first) when any PASSIVE match satisfies
      ``in_the_middle_or_overlap`` for the pair.

    Returns a list of ``((cause_start, cause_end), (effect_start, effect_end))`` offsets.
    """
    found = match(doc)
    if not found["CAUSAL"]:
        return []
    relations = []
    for first, second in combinations(found["EVENT"], 2):
        first_bounds = (first.start, first.end)
        second_bounds = (second.start, second.end)
        passive_between = any(
            in_the_middle_or_overlap(first, second, passive_span)
            for passive_span in found["PASSIVE"]
        )
        if passive_between:
            relations.append((second_bounds, first_bounds))
        else:
            relations.append((first_bounds, second_bounds))
    return relations

In [None]:
idx = -1
doc = nlp(Doc(nlp.vocab, words=sentences[idx]))
pred = predict(doc)
print(doc)
# Show gold and predicted relations side by side as "cause text -> effect text".
for heading, relations in (("Ground truth:", labels[idx]), ("Predictions:", pred)):
    print(heading)
    for (cause_start, cause_end), (effect_start, effect_end) in relations:
        print("\t{} -> {}".format(doc[cause_start:cause_end], doc[effect_start:effect_end]))

In [None]:
# Predict relations for every training sentence, keeping the gold tokenization.
predictions = [
    predict(nlp(Doc(nlp.vocab, words=sentence)))
    for sentence in tqdm.tqdm(sentences)
]

In [None]:
def overlap(ref_event, pred_event):
    """Return True if two token spans share at least one token.

    Spans are ``(start, end)`` pairs with an exclusive ``end``. This is the
    convention used when slicing ``doc[start:end]`` and when predictions are built
    from ``Span.start`` / ``Span.end``. Adjacent spans such as ``(0, 2)`` and
    ``(2, 4)`` therefore do NOT overlap.
    """
    # Exclusive ends: intersection is non-empty iff the later start is strictly
    # before the earlier end (`<=` would count merely touching spans).
    return max(ref_event[0], pred_event[0]) < min(ref_event[1], pred_event[1])


def evaluate_pair(reference, prediction):
    """Score one predicted (cause, effect) relation against one gold relation.

    Returns 1 when both spans match exactly. Returns 0.5 when the cause spans
    overlap and the effect spans overlap (per ``overlap``). Returns 0 otherwise.
    """
    (ref_cause, ref_effect), (pred_cause, pred_effect) = reference, prediction
    if ref_cause == pred_cause and ref_effect == pred_effect:
        return 1
    both_overlap = overlap(ref_cause, pred_cause) and overlap(ref_effect, pred_effect)
    return 0.5 if both_overlap else 0

def precision(tp, fp):
    """Fraction of predictions that are correct: tp / (tp + fp).

    Returns 0 when there are no true positives, which also avoids dividing by
    zero when nothing was predicted.
    """
    return tp / (tp + fp) if tp else 0

def recall(tp, fn):
    """Fraction of gold relations that were found: tp / (tp + fn).

    Returns 0 when there are no true positives, which also avoids dividing by
    zero when there are no gold relations.
    """
    return tp / (tp + fn) if tp else 0

def f1(tp, fp, fn):
    """F1 score computed directly from counts: 2*tp / (2*tp + fp + fn).

    Returns 0 when there are no true positives.
    """
    return 2 * tp / (2 * tp + fp + fn) if tp else 0

def evaluate(references, predictions):
    """Compute macro- and micro-averaged precision, recall and F1.

    ``references`` and ``predictions`` are parallel per-sentence collections of
    relations ``((cause_start, cause_end), (effect_start, effect_end))``.

    Matching within a sentence:
    - Each prediction is paired with the still-unmatched gold relation that has
      the highest ``evaluate_pair`` score. That score is added to the
      true-positive count and the gold relation is consumed.
    - A prediction with no positively scoring partner is a false positive.
    - Gold relations left unmatched are false negatives.
    - Duplicate gold relations within a sentence are counted once.

    Averaging:
    - Macro scores average the per-sentence metrics.
    - Micro scores are computed from the summed counts.
    - With no sentences at all, every score is 0.

    Raises:
        ValueError: if ``references`` and ``predictions`` differ in length.
    """
    references = list(references)
    predictions = list(predictions)
    if len(references) != len(predictions):
        # zip() would silently drop the surplus and report misleading scores.
        raise ValueError(
            "got {} references but {} predictions".format(len(references), len(predictions))
        )

    tps, fps, fns = [], [], []
    for reference, prediction in zip(references, predictions):
        tp, fp = 0, 0
        remaining_references = set(reference)
        for pred in prediction:
            # Take the best-scoring partner, not the first non-zero one. Set
            # iteration order is arbitrary, so "first hit" could credit only 0.5
            # (and consume the wrong gold relation) while an exact match is available.
            best_ref, best_score = None, 0
            for ref in remaining_references:
                score = evaluate_pair(ref, pred)
                if score > best_score:
                    best_ref, best_score = ref, score
                    if best_score == 1:  # exact match is the maximum score
                        break
            if best_ref is None:
                fp += 1
            else:
                tp += best_score
                remaining_references.remove(best_ref)
        tps.append(tp)
        fps.append(fp)
        fns.append(len(remaining_references))

    n_sentences = len(tps)

    def _mean(values):
        # Avoid ZeroDivisionError when there are no sentences.
        return sum(values) / n_sentences if n_sentences else 0

    macro_prec = _mean([precision(tp, fp) for tp, fp in zip(tps, fps)])
    macro_rec = _mean([recall(tp, fn) for tp, fn in zip(tps, fns)])
    macro_f1 = _mean([f1(tp, fp, fn) for tp, fp, fn in zip(tps, fps, fns)])
    micro_prec = precision(sum(tps), sum(fps))
    micro_rec = recall(sum(tps), sum(fns))
    micro_f1 = f1(sum(tps), sum(fps), sum(fns))
    return {
        "macro": {"precision": macro_prec, "recall": macro_rec, "f1": macro_f1},
        "micro": {"precision": micro_prec, "recall": micro_rec, "f1": micro_f1},
    }

In [None]:
# Training-set scores of the rule-based baseline: macro and micro precision / recall / F1.
evaluate(labels, predictions)

In [None]:
# data_dir is redefined so this cell can be run on its own.
data_dir = Path("./data/teaching-dataset")

# Same layout as the training input: blank-line-separated sentences, one token per line.
with ZipFile(data_dir / "relation_extraction_text_test.zip") as archive:
    raw_test_text = archive.read("input.txt").decode("utf-8")
test_sentences = [block.split("\n") for block in raw_test_text.split("\n\n")]

test_predictions = [
    predict(nlp(Doc(nlp.vocab, words=sentence)))
    for sentence in tqdm.tqdm(test_sentences)
]

# One line per sentence with comma-joined relations. All spaces are stripped so a
# relation is written as ((1,2),(4,5)), the same shape the reference regex parses.
serialized = "\n".join(
    ",".join(str(relation) for relation in prediction)
    for prediction in test_predictions
).replace(" ", "")
with open("predictions.txt", "w") as f:
    f.write(serialized)