In [None]:
from collections import Counter
import numpy as np

In [None]:
# Location of the ACE2005 training split, relative to this notebook.
train_file = "../data/ace2005/ace2005.train"

In [None]:
# Path of the GloVe vector file (6B / 100d, going by the file name).
# NOTE: a later cell rebinds `embedding` to the lines read from this file.
embedding = "../data/glove/glove.6B.100d.txt"

In [None]:
from transformers import AutoTokenizer
# Pretrained `roberta-base` tokenizer; `simple_read` picks it up as a
# module-level name, so this cell must run before any file is parsed.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [None]:
def simple_read(fpath):
    """Read a span-annotated corpus and align its entity spans to subwords.

    The file is consumed in groups of three lines: the first line of each
    group is a whitespace-separated sentence, the second holds its
    annotations, and the third is ignored. Annotations are ``|``-separated
    items of the form ``"<l>,<r> <tag>"`` where ``l`` and ``r`` are word
    indices delimiting ``words[l:r]``; an empty annotation line means the
    sentence has no spans.

    Sentences are tokenized with the module-level ``tokenizer``. Its
    ``tokenize`` output must contain exactly one character per character of
    the sentence, with each space rendered as the byte-level BPE marker
    U+0120 (the convention of GPT-2 / RoBERTa style tokenizers).

    Args:
        fpath: Path of the annotated text file (read as UTF-8).

    Returns:
        A list with one dict per sentence, holding ``"id"`` (0-based position
        in the output), ``"snt"`` (list of subword tokens) and ``"labels"``
        (list of ``(start, end, tag)`` tuples in subword indices, ``end``
        exclusive).

    Raises:
        AssertionError: If the subwords do not cover the sentence character
            for character, or if a span rebuilt from its subwords differs
            from the annotated surface form.
    """
    # Byte-level BPE writes a leading space as U+0120. It is spelled as an
    # escape so the literal cannot be corrupted when the file is re-encoded.
    space_marker = "\u0120"

    with open(fpath, encoding="utf-8") as f:
        data = f.readlines()

    output = []
    for raw_snt, raw_labels in zip(data[::3], data[1::3]):
        raw_snt, raw_labels = raw_snt.strip(), raw_labels.strip()

        words = raw_snt.split()
        # start_position[i] is the character offset at which word i begins,
        # assuming consecutive words are separated by exactly one space.
        start_position = np.cumsum([0] + [len(w) + 1 for w in words])
        subtoks = tokenizer.tokenize(raw_snt)
        assert sum(len(s) for s in subtoks) == len(raw_snt), (
            f"subwords do not cover the sentence 1:1: {raw_snt!r}"
        )

        # Character -> subword lookup tables: char_left[c] is the index of
        # the subword containing character c, char_right[c] is that index + 1
        # (i.e. an exclusive right boundary).
        char_left, char_right = [], []
        for i, tok in enumerate(subtoks):
            char_left.extend([i] * len(tok))
            char_right.extend([i + 1] * len(tok))

        labels = []
        if raw_labels:
            for item in raw_labels.split("|"):
                position, tag = item.split()
                left, right = map(int, position.split(","))

                # Character span of words[left:right]; the -1 drops the
                # separator space that precedes word `right`.
                char_start = start_position[left]
                char_end = start_position[right] - 1

                # Align the character span to subword indices.
                tok_start = char_left[char_start]
                tok_end = char_right[char_end - 1]

                # Sanity check: the subword span must spell the entity.
                entity_surface = " ".join(words[left:right])
                recovery = "".join(subtoks[tok_start:tok_end])
                recovery = recovery.replace(space_marker, " ").lstrip()
                assert recovery == entity_surface, (
                    f"span mismatch: {recovery!r} != {entity_surface!r}"
                )

                labels.append((tok_start, tok_end, tag))

        instance = {
            "id": len(output),
            "snt": subtoks,
            "labels": labels,
        }

        output.append(instance)
    return output

In [None]:
# Parse the training file into a list of per-sentence dicts with the keys
# "id", "snt" (subword tokens) and "labels" (subword-indexed spans).
data = simple_read(train_file)

In [None]:
# Corpus statistics over the parsed sentences: the widest labelled span and
# the longest sentence (both counted in subword tokens), plus tag frequencies.
labelset = Counter()
max_len = 0
max_snt_len = 0
for inst in data:
    spans = inst["labels"]
    max_snt_len = max(max_snt_len, len(inst["snt"]))
    # A span is (start, end, tag); its width is end - start.
    max_len = max([max_len] + [span[1] - span[0] for span in spans])
    labelset.update(span[2] for span in spans)
print("max_span_width:", max_len)
print("max_snt_length:", max_snt_len)

In [None]:
# Show the tag frequencies (most frequent first) and persist them as a label
# vocabulary file with one "<tag> <count>" entry per line.
label_counts = labelset.most_common()
print(label_counts)
with open("../data/resources/ace2005.label_vocab.txt", "w") as vocab_file:
    vocab_file.writelines(f"{tag} {count}\n" for tag, count in label_counts)


# Static word embedding

In [None]:
# `embedding` starts out as the path of the vector file and is rebound below
# to the list of lines read from it. Remember the path under its own name on
# the first run so that re-running this cell does not pass a list to open().
if isinstance(embedding, str):
    embedding_path = embedding

# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# locale's preferred encoding, which can fail or garble the text if the
# vocabulary holds non-ASCII tokens (assumed UTF-8 here; confirm for other files).
with open(embedding_path, encoding="utf-8") as f:
    embedding = f.readlines()
