-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_bert_emb.py
61 lines (47 loc) · 2.21 KB
/
gen_bert_emb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.embeddings import CharLMEmbeddings, StackedEmbeddings, BertEmbeddings
import os
import pickle
import numpy as np
from io_utils import read_yaml, read_lines, read_json_lines
# Dataset locations come from the YAML config; the three ACE05 splits are
# read eagerly at import time, in train/dev/test order.
data_config = read_yaml('data_config.yaml')
# NOTE(review): data_dir and train_sent_file are not used elsewhere in the
# visible code; kept as module-level names in case other modules import them.
data_dir = data_config['data_dir']
ace05_event_dir = data_config['ace05_event_dir']

train_list, dev_list, test_list = [
    read_json_lines(os.path.join(ace05_event_dir, split_fname))
    for split_fname in ('train_nlp_ner.json', 'dev_nlp_ner.json', 'test_nlp_ner.json')
]
train_sent_file = data_config['train_sent_file']

# Shared Flair BERT embedder (layers='-1', bert-base-uncased) placed on
# device 'cuda:0'; used by save_bert below.
bert = BertEmbeddings(layers='-1', bert_model_or_path='bert-base-uncased').to('cuda:0')
def save_bert(inst_list, filter_tri=True, name='train'):
    """Embed the kept sentences of one split with BERT and save a flat token table.

    Each instance must provide ``'nlp_words'`` (its word list), ``'Triggers'``
    and ``'Entities'``. The per-token vectors of all kept sentences are stacked
    row-wise, in sentence order, into one array of shape
    ``(total_word_nums, bert.embedding_length)`` and written with ``np.save``
    to the path configured for the split (``train_sent_file``,
    ``dev_sent_file`` or ``test_sent_file`` in ``data_config``).

    Args:
        inst_list: parsed instances of one split (as read by ``read_json_lines``).
        filter_tri: if True, skip sentences that have no trigger AND fewer
            than 3 entities.
        name: ``'train'``, ``'dev'`` or ``'test'``; selects the output file.

    Raises:
        ValueError: if ``name`` is not a known split, or if Flair yields a
            different number of tokens than a sentence has words (which would
            shift every following row of the saved table).
    """
    config_keys = {'train': 'train_sent_file', 'dev': 'dev_sent_file', 'test': 'test_sent_file'}
    if name not in config_keys:
        # Previously any unknown name silently overwrote the test embeddings.
        raise ValueError("name must be 'train', 'dev' or 'test', got %r" % (name,))
    # Resolve the output path before the slow embedding pass so a bad config fails fast.
    bert_fname = data_config[config_keys[name]]

    sents = []
    for inst in inst_list:
        words, trigger_list, ent_list = inst['nlp_words'], inst['Triggers'], inst['Entities']
        # Empirical filter (training only): drop sentences that have no event
        # trigger AND fewer than 3 entities.
        if filter_tri and not trigger_list and len(ent_list) < 3:
            continue
        sents.append(words)
    sent_lens = [len(words) for words in sents]
    total_word_nums = sum(sent_lens)

    # Row width is taken from the embedder instead of a hard-coded 768, so a
    # different `layers` setting still fits. zeros (not empty) guarantees no
    # row ever contains uninitialised memory.
    input_table = np.zeros((total_word_nums, bert.embedding_length))
    acc_len = 0  # index of the first row belonging to the current sentence
    for i, (words, sent_len) in enumerate(zip(sents, sent_lens)):
        if i % 100 == 0:
            print('progress: %d, %d' % (i, len(sents)))
        if sent_len == 0:
            continue  # no rows to fill; don't hand Flair an empty sentence
        # use_tokenizer=False: split only on the spaces used to join the words,
        # so each Flair token should map to exactly one input word.
        flair_sent = Sentence(' '.join(words), use_tokenizer=False)
        if len(flair_sent) != sent_len:
            raise ValueError(
                'sentence %d: Flair produced %d tokens for %d words '
                '(e.g. an empty word or a word containing a space)'
                % (i, len(flair_sent), sent_len))
        bert.embed(flair_sent)
        for j, token in enumerate(flair_sent):
            input_table[acc_len + j, :] = token.embedding.cpu().detach().numpy()
        acc_len += sent_len

    np.save(bert_fname, input_table)
    print('total_word_nums:', total_word_nums)
if __name__ == "__main__":
    # The training split is filtered (filter_tri=True); dev and test are
    # saved with every sentence kept.
    for split_name, split_insts, do_filter in (
        ('train', train_list, True),
        ('dev', dev_list, False),
        ('test', test_list, False),
    ):
        save_bert(split_insts, filter_tri=do_filter, name=split_name)