In [1]:
import json
from pprint import pprint
from tqdm import tqdm
from IPython.core.debugger import set_trace
import re
import copy

In [2]:
ace05_ee_path = "../ori_data/ace2005"


def _load_ace05_split(file_name):
    """Read one JSON file from ``ace05_ee_path`` and return the parsed object.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as parsing finishes instead of being left for the garbage collector.
    """
    with open(ace05_ee_path + "/" + file_name, "r", encoding = "utf-8") as file_in:
        return json.load(file_in)


train_data = _load_ace05_split("train_data.json")
dev_data = _load_ace05_split("dev_data.json")
test_data = _load_ace05_split("test_data.json")

In [3]:
# Peek at the fields available on a single test article.
test_data["nw/timex2norm/AFP_ENG_20030417.0307"].keys()

dict_keys(['word_list', 'lemma_list', 'entities', 'events', 'entity_label_list', 'entity_sublabel_list', 'event_label_list', 'event_argument_list', 'pos_tag_list', 'dependencies', 'num_events', 'num_mw_events', 'sentence_extents', 'docid', 'domain'])

In [4]:
def tokenize(text):
    """Split ``text`` on single space characters and return the pieces."""
    return text.split(" ")


def get_tok2char_span_map(text):
    """Return the ``(start, end)`` character span of every token in ``text``.

    Tokens are produced by ``tokenize``; ``end`` is exclusive and consecutive
    tokens are assumed to be separated by exactly one character.
    """
    spans = []
    offset = 0
    for token in tokenize(text):
        token_end = offset + len(token)
        spans.append((offset, token_end))
        # Step over the single separator that follows the token.
        offset = token_end + 1
    return spans

In [5]:
def _collect_article_events(article):
    """Build article-level event dicts (word-index spans) from the BIO labels.

    Each maximal ``B I*`` run in ``article["event_label_list"]`` is one trigger.
    Argument ``start``/``end`` values are added to the trigger's first word
    index, exactly as the raw data is consumed here.
    """
    word_list = article["word_list"]
    label_list = article["event_label_list"]
    # One prefix per token is concatenated so that regex match offsets can be
    # used directly as word indices.
    # NOTE(review): this relies on every label prefix being one character long.
    bio_str = "".join(lab.split("_")[0] for lab in label_list)

    event_list = []
    for m in re.finditer("BI*", bio_str):
        trigger_start, trigger_end = m.span()
        trigger_type = label_list[trigger_start].split("_")[1]
        # Every boundary index touched by the event (trigger + arguments).
        event_word_idx_list = [trigger_start, trigger_end]
        arguments_new = []
        for arg in article["event_argument_list"][trigger_start]:
            arg_start = trigger_start + arg["start"]
            arg_end = trigger_start + arg["end"]
            event_word_idx_list.append(arg_start)
            event_word_idx_list.append(arg_end)
            arguments_new.append({
                "text": " ".join(word_list[arg_start:arg_end]),
                "span": [arg_start, arg_end],
                "type": arg["type"],
            })
        event_list.append({
            "event_span": [min(event_word_idx_list), max(event_word_idx_list)],
            "trigger": " ".join(word_list[trigger_start:trigger_end]),
            "trigger_type": trigger_type,
            "trigger_span": [trigger_start, trigger_end],
            "argument_list": arguments_new,
        })
    return event_list


def _word_span_to_char_span(word_span, sent_start, tok2char_span):
    """Map an article-level word span to a character span inside one sentence."""
    char_span_list = tok2char_span[word_span[0] - sent_start:word_span[1] - sent_start]
    return [char_span_list[0][0], char_span_list[-1][1]]


def _localize_event(event, sent, sent_start, tok2char_span):
    """Return a new event dict whose spans are character offsets into ``sent``.

    Nothing in ``event`` is modified: fresh argument dicts are created so the
    article-level event stays in word-index form and can safely be localized
    again for another sentence.
    """
    argument_list = []
    for arg in event["argument_list"]:
        arg_char_span = _word_span_to_char_span(arg["span"], sent_start, tok2char_span)
        assert sent[arg_char_span[0]:arg_char_span[1]] == arg["text"]
        argument_list.append({
            "text": arg["text"],
            "span": arg_char_span,
            "type": arg["type"],
        })

    trigger_char_span = _word_span_to_char_span(event["trigger_span"], sent_start, tok2char_span)
    assert sent[trigger_char_span[0]:trigger_char_span[1]] == event["trigger"]

    # "event_span" is intentionally not carried over to the sentence-level event.
    return {
        "trigger": event["trigger"],
        "trigger_type": event["trigger_type"],
        "trigger_span": trigger_char_span,
        "argument_list": argument_list,
    }


def parse_data(data):
    """Convert article dicts into a flat list of sentence-level event examples.

    Args:
        data: mapping of article key -> article dict. The fields read are
            "word_list", "event_label_list", "event_argument_list",
            "sentence_extents" and "docid".

    Returns:
        A list with one dict per sentence extent, each holding
        "id" ("<docid>_<sentence index>"), "text" (the sentence words joined by
        single spaces) and "event_list". Every event has "trigger",
        "trigger_type", "trigger_span" and "argument_list"; all spans are
        character offsets into "text".

    Raises:
        AssertionError: if a computed character span does not reproduce the
            trigger or argument text.

    Events whose trigger and arguments do not all fall inside a single sentence
    extent are not emitted for any sentence.
    """
    example_list = []
    for _, article in tqdm(data.items()):
        word_list = article["word_list"]
        event_list = _collect_article_events(article)

        for sent_idx, sen_ext in enumerate(article["sentence_extents"]):
            sent_start, sent_end = sen_ext[0], sen_ext[1]
            sent = " ".join(word_list[sent_start:sent_end])
            tok2char_span = get_tok2char_span_map(sent)

            # Keep only the events fully contained in this sentence and convert
            # them without touching the shared article-level dicts.
            event_list_sent = [
                _localize_event(event, sent, sent_start, tok2char_span)
                for event in event_list
                if event["event_span"][0] >= sent_start and event["event_span"][1] <= sent_end
            ]
            example_list.append({
                "id": "{}_{}".format(article["docid"], sent_idx),
                "text": sent,
                "event_list": event_list_sent,
            })
    return example_list

In [6]:
# Flatten each split into sentence-level examples (one tqdm bar per split).
train_data_ = parse_data(train_data)
dev_data_ = parse_data(dev_data)
test_data_ = parse_data(test_data)

100%|██████████| 529/529 [00:00<00:00, 1109.21it/s]
100%|██████████| 30/30 [00:00<00:00, 2250.73it/s]
100%|██████████| 40/40 [00:00<00:00, 1407.53it/s]


In [7]:
# transform to tplinker style
def trans_data(data):
    """Attach "entity_list" and "relation_list" to every example, in place.

    Each trigger and each argument becomes an entity. Each argument also yields
    one relation pointing from the argument (subject) to its trigger (object),
    with predicate "<argument type>_<trigger type>". The span lists are shared
    with the source event dicts rather than copied.

    Returns:
        The same ``data`` list that was passed in.
    """
    for example in data:
        entities = []
        relations = []
        for evt in example["event_list"]:
            trigger_text = evt["trigger"]
            trigger_type = evt["trigger_type"]
            trigger_span = evt["trigger_span"]

            trigger_entity = {
                "text": trigger_text,
                "type": trigger_type,
                "char_span": trigger_span,
            }
            entities.append(trigger_entity)

            for argument in evt["argument_list"]:
                arg_text = argument["text"]
                arg_span = argument["span"]
                arg_type = argument["type"]

                arg_entity = {
                    "text": arg_text,
                    "type": arg_type,
                    "char_span": arg_span,
                }
                relation = {
                    "subject": arg_text,
                    "subj_char_span": arg_span,
                    "object": trigger_text,
                    "obj_char_span": trigger_span,
                    "predicate": "{}_{}".format(arg_type, trigger_type),
                }
                entities.append(arg_entity)
                relations.append(relation)

        example["entity_list"] = entities
        example["relation_list"] = relations
    return data

In [8]:
# trans_data adds the new keys to the given examples and returns the same list,
# so each normal_* name refers to the same object as its *_data_ input.
normal_train_data = trans_data(train_data_)
normal_dev_data = trans_data(dev_data_)
normal_test_data = trans_data(test_data_)

In [9]:
# Inspect one converted example (second to last test sentence).
normal_test_data[-2]

{'id': 'APW_ENG_20030407.0030_45',
 'text': 'a correspondent for state - run russian television said the convoy was caught in a crossfire and three diplomats were hurt , one with a serious stomach wound .',
 'event_list': [{'trigger': 'crossfire',
   'trigger_type': 'Attack',
   'trigger_span': [83, 92],
   'argument_list': [{'text': 'convoy', 'span': [60, 66], 'type': 'Target'},
    {'text': 'diplomats', 'span': [103, 112], 'type': 'Target'},
    {'text': 'one', 'span': [125, 128], 'type': 'Target'}]},
  {'trigger': 'hurt',
   'trigger_type': 'Injure',
   'trigger_span': [118, 122],
   'argument_list': [{'text': 'diplomats',
     'span': [103, 112],
     'type': 'Victim'},
    {'text': 'one', 'span': [125, 128], 'type': 'Victim'}]}],
 'entity_list': [{'text': 'crossfire',
   'type': 'Attack',
   'char_span': [83, 92]},
  {'text': 'convoy', 'type': 'Target', 'char_span': [60, 66]},
  {'text': 'diplomats', 'type': 'Target', 'char_span': [103, 112]},
  {'text': 'one', 'type': 'Target', '