In [15]:
import numpy as np
from sklearn.model_selection import ShuffleSplit
from data_utils import ENTITIES, Documents, Dataset, SentenceExtractor, make_predictions
from data_utils import Evaluator
from models import build_lstm_crf_model
from gensim.models import Word2Vec

In [16]:
data_dir = 'brat/'
# Entity label -> integer id, numbered from 1 in ENTITIES order (id 0 is left unused here).
ent2idx = {ent: idx for idx, ent in enumerate(ENTITIES, start=1)}
# Reverse lookup: integer id -> entity label.
idx2ent = {idx: ent for ent, idx in ent2idx.items()}

In [17]:
# Load every document found under data_dir.
docs = Documents(data_dir=data_dir)
# One seeded shuffle split. test_size=20 is an int, so it is an absolute
# number of held-out documents rather than a fraction.
rs = ShuffleSplit(n_splits=1, test_size=20, random_state=2018)
# split() is a generator of (train_indices, test_indices); take its only item.
train_doc_ids, test_doc_ids = next(rs.split(docs))
train_docs = docs[train_doc_ids]
test_docs = docs[test_doc_ids]

In [21]:
# Number of label classes: the largest entity id plus one, so ids 0..max all fit.
num_cates = max(ent2idx.values()) + 1
# Window size handed to SentenceExtractor.
sent_len = 64
# Upper bound for the training vocabulary; overwritten below with the actual size.
vocab_size = 3000
# Embedding vector size.
emb_size = 100
# Padding added on both ends of each window.
sent_pad = 10

# Cut train/test documents into padded windows.
sent_extrator = SentenceExtractor(window_size=sent_len, pad_size=sent_pad)
train_sents, test_sents = sent_extrator(train_docs), sent_extrator(test_docs)

# The vocabulary is built from training data only and reused for the test set.
train_data = Dataset(train_sents, cate2idx=ent2idx)
train_data.build_vocab_dict(vocab_size=vocab_size)
test_data = Dataset(test_sents, word2idx=train_data.word2idx, cate2idx=ent2idx)

# From here on vocab_size is the real vocabulary size, not the cap above.
vocab_size = len(train_data.word2idx)

In [22]:
# Character-level Word2Vec corpus: each document's text becomes one "sentence"
# of characters. NOTE(review): this uses all docs, including test documents.
w2v_train_sents = [list(doc.text) for doc in docs]
w2v_model = Word2Vec(w2v_train_sents, size=emb_size)

# Row k holds the vector of the character whose id is k; characters missing
# from the Word2Vec vocabulary keep an all-zero row.
w2v_embeddings = np.zeros((vocab_size, emb_size))
for char, char_idx in train_data.word2idx.items():
    if char not in w2v_model.wv:
        continue
    w2v_embeddings[char_idx] = w2v_model.wv[char]

In [23]:
# Each model input is the sentence window plus sent_pad positions on both ends.
seq_len = sent_len + 2 * sent_pad
# Guard against a stale w2v_embeddings left over from an out-of-order re-run.
assert w2v_embeddings.shape == (vocab_size, emb_size), w2v_embeddings.shape
model = build_lstm_crf_model(
    num_cates,
    seq_len=seq_len,
    vocab_size=vocab_size,
    # Frozen Word2Vec embeddings. emb_size comes from the config cell instead of
    # a duplicated literal, so the embedding width cannot drift from the matrix.
    model_opts={'emb_matrix': w2v_embeddings, 'emb_size': emb_size, 'emb_trainable': False},
)
model.summary()



Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 84)                0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 84, 100)           15100     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 84, 512)           731136    
_________________________________________________________________
crf_1 (CRF)                  (None, 84, 40)            22200     
Total params: 768,436
Trainable params: 753,336
Non-trainable params: 15,100
_________________________________________________________________


In [24]:
# Materialise the whole training set as arrays.
train_X, train_y = train_data[:]
for shape_label, arr in (('train_X.shape', train_X), ('train_y.shape', train_y)):
    print(shape_label, arr.shape)

train_X.shape (50146, 84)
train_y.shape (50146, 84, 1)


In [None]:
# Fixed-length training run; no validation data is passed to fit().
model.fit(train_X, train_y,
          batch_size=64,
          epochs=10)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
10816/50146 [=====>........................] - ETA: 4:45 - loss: 0.0537 - crf_viterbi_accuracy: 0.9580

In [None]:
# Predict on every padded test window, then turn the raw predictions into
# document-level predictions via make_predictions (decoding ids with idx2ent).
test_X, _ = test_data[:]
preds = model.predict(test_X, batch_size=64, verbose=True)
# NOTE(review): `docs` (all documents) is passed here, not `test_docs` — confirm
# make_predictions expects the full collection.
pred_docs = make_predictions(preds, test_data, sent_pad, docs, idx2ent)

In [None]:
# Compare predicted documents against the gold test documents.
f_score, precision, recall = Evaluator.f1_score(test_docs, pred_docs)
for metric_label, metric_value in (('f_score: ', f_score),
                                   ('precision: ', precision),
                                   ('recall: ', recall)):
    print(metric_label, metric_value)

In [None]:
# Collect the annotation (.ann) files whose entity labels are processed below.

import os
data_dir = 'brat/'
# Keep only files that really are .ann files. The previous approach rewrote
# every directory entry's name to "<stem>.ann", which produced paths for files
# that have no matching .ann (and split('.')[0] mangled names with extra dots),
# causing the later open() calls to fail.
sec_doc_ids = [data_dir + fname for fname in os.listdir(data_dir) if fname.endswith(".ann")]
# np.unique de-duplicates and sorts, so the file order is deterministic.
sec_doc_ids = np.unique(sec_doc_ids)

In [None]:
def txt_strtonum_feed(filename):
    """Collect (type, type_token, text) triples from one brat-style .ann file.

    Prints ``filename``, then reads the file as UTF-8. The first line is always
    skipped. Every other line is split on TAB, and lines with fewer than three
    fields are ignored (brat/12585968.ann, for example, has non-conforming
    content at its end). For a kept line, ``type_token`` is the first
    whitespace-separated token of field 1, ``type`` is the part of
    ``type_token`` before the first ':', and ``text`` is field 2.

    Returns:
        list of (type, type_token, text) string tuples, in file order.
    """
    print(filename)
    triples = []
    with open(filename, 'r', encoding='UTF-8') as ann_file:
        for line_no, raw_line in enumerate(ann_file):
            # NOTE(review): the first line is dropped unconditionally; in a
            # standard brat .ann file that would be the first annotation —
            # confirm this skip is intended.
            if line_no == 0:
                continue
            fields = raw_line.split('\n')[0].split('\t')
            if len(fields) < 3:
                continue
            type_token = fields[1].split()[0]
            triples.append((type_token.split(":")[0], type_token, fields[2]))
    return triples

In [None]:
# Concatenate the triples from every annotation file, in sec_doc_ids order.
data_ = [record for ann_path in sec_doc_ids for record in txt_strtonum_feed(ann_path)]

In [None]:
# De-duplicate the (type, type_token, text) triples. dict.fromkeys keeps the
# first occurrence of each triple in encounter order, so the result — and the
# file and BERT matrix built from it — is identical on every run. Iterating a
# set of str tuples depends on the per-process hash seed, so the previous
# list(set(...)) order changed between kernel restarts.
uni_data_ = list(dict.fromkeys(data_))

In [None]:
# One TAB-separated line per unique triple: type_token, type, text.
with open("unidata_class.txt", 'w', encoding='UTF-8') as f:
    for first, second, third in uni_data_:
        f.write("\t".join((second, first, third)) + "\n")
        
        

In [None]:
#获取bert词向量
from bert_serving.client import BertClient

In [8]:
# Client for a bert-as-service (bert_serving) server; used below to encode load_data2.
# NOTE(review): constructed with library defaults — assumes a server is already
# running and reachable at the default address; confirm before running.
bc = BertClient()

In [None]:
def txt_strtonum_feed_load(filename):
    """Read a UTF-8, TAB-separated, three-column text file.

    Blank lines are skipped.

    Returns:
        (load_data1, load_data2): for each data line, load_data1 gets column 0
        and load_data2 gets ``column1 + " " + column2``.
    """
    load_data1 = []
    load_data2 = []
    with open(filename, 'r', encoding='UTF-8') as f:
        for line in f:
            # Strip only the line terminator. Slicing off the last character
            # (as before) removed a real character from a final line that has
            # no trailing newline.
            fields = line.rstrip('\n').split("\t")
            if fields == ['']:
                # A blank line used to raise IndexError.
                continue
            load_data1.append(fields[0])
            load_data2.append(fields[1] + " " + fields[2])
    return load_data1, load_data2

In [None]:
# load_data1: column 0 of each line; load_data2: "column1 column2" strings,
# which are the inputs encoded with BERT below.
load_data1, load_data2 = txt_strtonum_feed_load(
    "unidata_class.txt",
)

In [11]:
# Encode every string in load_data2 with BERT, batch by batch, and stack the
# results so that row i of `a` corresponds to load_data1[i] / load_data2[i].
#
# This replaces a hand-driven sequence that (1) appended to an `a` never created
# in this notebook, so it failed on a fresh kernel, (2) hard-coded the ranges
# 100..12900 and 12900:12971, and (3) encoded load_data2[12970] twice (once in
# the 12900:12971 slice and again on its own), adding a misaligned extra row.
BERT_BATCH_SIZE = 100
bert_batches = [
    bc.encode(load_data2[start:start + BERT_BATCH_SIZE])
    for start in range(0, len(load_data2), BERT_BATCH_SIZE)
]
# Single concatenation instead of growing the array inside the loop.
a = np.concatenate(bert_batches)
assert len(a) == len(load_data2), (len(a), len(load_data2))

In [None]:
import pickle

In [34]:
# Persist the BERT embedding matrix. Despite the .txt name this is a binary
# pickle; only unpickle it from a trusted source. `with` closes the file even
# if pickle.dump raises.
with open('dataFile.txt', 'wb') as fw:
    pickle.dump(a, fw)


In [None]:
# Persist the predicted documents (binary pickle despite the .txt name).
# `with` closes the file even if pickle.dump raises.
with open('dataFilepred_docs.txt', 'wb') as fw:
    pickle.dump(pred_docs, fw)

In [None]:
# Persist the gold test documents (binary pickle despite the .txt name).
# `with` closes the file even if pickle.dump raises.
with open('dataFiletest_docs.txt', 'wb') as fw:
    pickle.dump(test_docs, fw)

In [None]:
# Persist load_data1 (column 0 labels) as a binary pickle despite the .txt name.
# `with` closes the file even if pickle.dump raises.
with open('detaillabel.txt', 'wb') as fw:
    pickle.dump(load_data1, fw)

In [None]:
# Persist load_data2 (the strings that were BERT-encoded) as a binary pickle
# despite the .txt name. `with` closes the file even if pickle.dump raises.
with open('labellist.txt', 'wb') as fw:
    pickle.dump(load_data2, fw)