In [1]:
import spacy
import numpy as np

from transformers import AutoTokenizer, AutoModel

num_dict = {0: "before", 1: "after", 2: "equal", 3: "vague"}
rel_2_num_dict = {"before":0, "after":1, "equal":2, "vague":3}

nlp = spacy.load("en_core_web_sm")

tokenizer = AutoTokenizer.from_pretrained('google/bigbird-roberta-large')   

before = np.load("../data/flan_t5/before.npy", allow_pickle=True)
after = np.load("../data/flan_t5/after.npy", allow_pickle=True)
equal = np.load("../data/flan_t5/equal.npy", allow_pickle=True)

In [2]:
from tqdm import tqdm

def get_lemma(t):
    parsed_text = nlp(tokenizer.decode([t]))
    if len(parsed_text) == 0:
        return ''
    else:
        return parsed_text[0].lemma_

# find event_pos and event_pos_end for h and t
def get_event_pos(token_ids, h, t):
    parsed_text = [get_lemma(t) for t in token_ids]
    event_pos_h = [i for i, token in enumerate(parsed_text) if token == h]
    event_pos_t = [i for i, token in enumerate(parsed_text) if token == t]
    if len(event_pos_h) > 0 and len(event_pos_t) > 0:
        return event_pos_h[0], event_pos_t[0]
    return None

equal_features = []
relations = [rel_2_num_dict["equal"]]
for i in tqdm(range(len(equal))):
    h, t, _ = equal[i][0]
    for j in range(len(equal[i][1])):
        text = equal[i][1][j]
        text = "[CLS] " + text.replace(". ", ".[SEP] ") + "[SEP]"
        TokenIDs = tokenizer.encode(text, add_special_tokens=False)
        event_pos = get_event_pos(TokenIDs, h, t)
        if not event_pos:
            continue
        event_pos = list(event_pos)
        event_pos_end = [event_pos[0]+1, event_pos[1]+1]
        
        feature = {'input_ids': TokenIDs,
                   'event_pos': event_pos,
                   'event_pos_end': event_pos_end,
                   'event_pair': [[1, 2]],
                   'labels': relations,
                  }
        equal_features.append(feature)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 271/271 [22:14<00:00,  4.92s/it]


In [85]:
tokenizer.decode([321])

''

In [3]:
import json
with open("../data/flan_t5/equal_features.json", "w") as writer:
    json.dump(equal_features, writer) 