# CRF+LSTM

keras 2.2.4

tensorflow 1.13

pip install git+https://www.github.com/keras-team/keras-contrib.git

In [1]:
import re
import os
import pandas as pd

In [2]:
char_vocab_path = "CRF/data/char_vocabs.txt" # 字典文件
#train_data_path = 'data/train_data/train_data_000' # 训练数据
#train_data_path = './data/train_data' # 训练数据
#test_data_path = 'data/train_data/train_data_000' # 测试数据

special_words = ['<PAD>', '<UNK>'] # 特殊词表示

# "BIO"标记的标签
#label2idx = {"O": 0,
#             "B-PER": 1, "I-PER": 2,
#             "B-LOC": 3, "I-LOC": 4,
#             "B-ORG": 5, "I-ORG": 6
#            }
label2idx = {'O': 0,
             'B-DISEASE': 1, 'B-DISEASE_GROUP': 2,
             'B-DRUG_DOSAGE': 3, 'B-DRUG_EFFICACY': 4,
             'B-DRUG_INGREDIENT': 5, 'B-DRUG_TASTE': 6,
             'B-FOOD_GROUP':7, 'B-PERSON_GROUP':8,
             'B-SYMPTOM':9, 'B-SYNDROME':10,
             'I-DISEASE': 11, 'I-DISEASE_GROUP': 12,
             'I-DRUG_DOSAGE': 13, 'I-DRUG_EFFICACY': 14,
             'I-DRUG_INGREDIENT': 15, 'I-DRUG_TASTE': 16,
             'I-FOOD_GROUP':17, 'I-PERSON_GROUP':18,
             'I-SYMPTOM':19, 'I-SYNDROME':20
            }

# 索引和BIO标签对应
idx2label = {idx: label for label, idx in label2idx.items()}

# 读取字符词典文件
with open(char_vocab_path, "r", encoding="utf8") as fo:
    char_vocabs = [line.strip() for line in fo]
char_vocabs = special_words + char_vocabs

# 字符和索引编号对应
idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}
vocab2idx = {char: idx for idx, char in idx2vocab.items()}

In [3]:
# 读取训练语料
def read_corpus(corpus_path, vocab2idx, label2idx):
    with open(corpus_path, encoding='utf-8') as fr:
        lines = fr.readlines()

    sent_, tag_ = [], []
    for letter in lines:
        [char,label,_] = re.split('\t|\n',letter)
        char = re.sub(' |\*|<|>','_',char)
        sent_.append(char)
        tag_.append(label)

    sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
    tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
    return sent_ids, tag_ids

# 加载训练集
#train_datas, train_labels = read_corpus(train_data_path, vocab2idx, label2idx)
# 加载测试集
#test_datas, test_labels = read_corpus(test_data_path, vocab2idx, label2idx)


In [4]:
train_datas = []
train_labels = []
files = os.listdir('data/train_data')
for file in files:
    train_data_path_i = 'data/train_data/'+file
    train_datas_i, train_labels_i = read_corpus(train_data_path_i, vocab2idx, label2idx)
    train_datas.append(train_datas_i)
    train_labels.append(train_labels_i)
    #if i%10==0:
    #    print(i)

In [5]:
valid_datas = []
valid_labels = []
files = os.listdir('data/valid_data')
for file in files:
    valid_data_path_i = 'data/valid_data/'+file
    valid_datas_i, valid_labels_i = read_corpus(valid_data_path_i, vocab2idx, label2idx)
    valid_datas.append(valid_datas_i)
    valid_labels.append(valid_labels_i)

In [6]:
print(train_datas[50])
print([idx2vocab[idx] for idx in train_datas[50]])
print(train_labels[50])
print([idx2label[idx] for idx in train_labels[50]])

[58, 61, 77, 1, 58, 17, 181, 3093, 3817, 2654, 6214, 1959, 2177, 286, 6802, 5965, 519, 1408, 2644, 2102, 2732, 1842, 889, 2545, 3093, 3817]
['_', 'b', 'r', '<UNK>', '_', '3', '、', '治', '疗', '期', '间', '忌', '房', '事', '，', '配', '偶', '如', '有', '感', '染', '应', '同', '时', '治', '疗']
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [7]:
import numpy as np
import keras
from keras.models import Sequential
from keras.models import Model
from keras.layers import Masking, Embedding, Bidirectional, LSTM, Dense, Input, TimeDistributed, Activation
from keras.preprocessing import sequence
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy
from keras import backend as K
K.clear_session()

EPOCHS = 30
BATCH_SIZE = 64
EMBED_DIM = 48
HIDDEN_SIZE = 8
MAX_LEN = 100
VOCAB_SIZE = len(vocab2idx)
CLASS_NUMS = len(label2idx)
print(VOCAB_SIZE, CLASS_NUMS)

print('padding sequences')
train_datas = sequence.pad_sequences(train_datas, maxlen=MAX_LEN)
train_labels = sequence.pad_sequences(train_labels, maxlen=MAX_LEN)
valid_datas = sequence.pad_sequences(valid_datas, maxlen=MAX_LEN)
valid_labels = sequence.pad_sequences(valid_labels, maxlen=MAX_LEN)
print('x_train shape:', train_datas.shape)
print('x_test shape:', valid_datas.shape)

train_labels = keras.utils.to_categorical(train_labels, CLASS_NUMS)
valid_labels = keras.utils.to_categorical(valid_labels, CLASS_NUMS)
print('trainlabels shape:', train_labels.shape)
print('testlabels shape:', valid_labels.shape)

## BiLSTM+CRF模型构建
inputs = Input(shape=(MAX_LEN,), dtype='int32')
x = Masking(mask_value=0)(inputs)
x = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True)(x)
x = Bidirectional(LSTM(HIDDEN_SIZE, return_sequences=True))(x)
x = TimeDistributed(Dense(CLASS_NUMS))(x)#TimeDistributed层的作用就是把Dense层应用到这10个具体的向量上，对每一个向量进行了一个Dense操作
outputs = CRF(CLASS_NUMS)(x)
model = Model(inputs=inputs, outputs=outputs)
model.summary()

model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_viterbi_accuracy])
model.fit(train_datas, train_labels, epochs=EPOCHS, verbose=1, validation_split=0.1)

score = model.evaluate(valid_datas, valid_labels, batch_size=BATCH_SIZE)
print(model.metrics_names)
print(score)

Using TensorFlow backend.





6874 21
padding sequences
x_train shape: (6899, 100)
x_test shape: (3974, 100)
trainlabels shape: (6899, 100, 21)
testlabels shape: (3974, 100, 21)


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
masking_1 (Masking)          (None, 100)               0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 100, 48)           329952    
_________________________________________________________________
bidirectional_1 (Bidirection (None, 100, 16)           3648      
_________________________________________________________________
time_distributed_1 (TimeDist (None, 100, 21)           357       
___________________________

Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
['loss', 'crf_viterbi_accuracy']
[7.449981505389663, 0.8925154830381147]


In [8]:
# save model
model.save("model/ch_ner_model.h5")

In [11]:
def get_valid_nertag(input_data, result_tags):
    result_words = []
    start, end =0, 1 # 实体开始结束位置标识
    tag_label = "O" # 实体类型标识
    number = 0
    for i, tag in enumerate(result_tags):
        if tag.startswith("B"):
            number += 1
            if tag_label != "O": # 当前实体tag之前有其他实体     
                result_words.append(('T'+str(number), tag_label, start, end,input_data[start: end]))
                #result_words.append(('T'+str(number), tag_label+' '+str(start)+' '+str(end),input_data[start: end])) # 获取实体
            tag_label = tag.split("-")[1] # 获取当前实体类型
            start, end = i, i+1 # 开始和结束位置变更
        elif tag.startswith("I"):
            temp_label = tag.split("-")[1]
            if temp_label == tag_label: # 当前实体tag是之前实体的一部分
                end += 1 # 结束位置end扩展
        elif tag == "O":
            if tag_label != "O": # 当前位置非实体 但是之前有实体
                #result_words.append(('T'+str(number), tag_label+' '+str(start)+' '+str(end),input_data[start: end])) # 获取实体
                result_words.append(('T'+str(number), tag_label, start, end,input_data[start: end]))
                tag_label = "O"  # 实体类型置"O"
            start, end = i, i+1 # 开始和结束位置变更
    if tag_label != "O": # 最后结尾还有实体
        number += 1
        result_words.append(('T'+str(number), tag_label, start, end,input_data[start: end]))
        #result_words.append(('T'+str(number),tag_label+' '+str(start)+' '+str(end),input_data[start: end])) # 获取结尾的实体
    return result_words

In [None]:
maxlen = 100
result = {}
test_data_path = 'data/chusai_xuanshou/'
#test_data_path = 'data/valid/'
for i in range(500):
#for i in range(7):
    test_file = test_data_path+str(i+1000)+'.txt'
    with open(test_file, "r", encoding="utf8") as test:
        sentence = test.read()
    sentences = sentence.split('。')
    y_ner = []

    for sent in sentences:
        sent = sent.replace(' ','_')
        sent_chars = list(sent+'。')
        sent2id = [vocab2idx[word] if word in vocab2idx else vocab2idx['<UNK>'] for word in sent_chars]

        sent2id_new = np.array([[0] * (maxlen-len(sent2id)) + sent2id[:maxlen]])
        y_pred = model.predict(sent2id_new)
        y_label = np.argmax(y_pred, axis=2)
        y_label = y_label.reshape(1, -1)[0]
        y_ner_ = [idx2label[i] for i in y_label][-len(sent_chars):]
        y_ner.extend(y_ner_)
    result_words = get_valid_nertag(sentence, y_ner)
    ans = []
    #print(i+993)
    for res in result_words:
        number = res[0]
        #tag_start_end = res[1]
        tag = res[1]
        start = res[2]
        end = res[3]
        word = res[4].replace(' ','_')
        ans.append('{}\t{} {} {}\t{}'.format('T'+str(len(ans)+1), tag,start, end, word))
        #ans.append(, tag, start,end, "".join(word))
        print('{}\t{} {} {}\t{}'.format('T'+str(len(ans)), tag,start, end, word))
    #print('='*100)
    
    result[i+1000] = ans

T1	SYMPTOM 80 83	胸卧位
T2	DRUG_EFFICACY 148 152	祛瘀止痛
T3	DRUG_EFFICACY 153 156	散结消
T4	SYMPTOM 167 171	小腹疼痛
T5	SYMPTOM 172 176	腰骶酸痛
T6	SYMPTOM 177 181	带下量多
T1	PERSON_GROUP 31 33	孕妇
T2	DRUG_EFFICACY 76 78	补血
T3	DRUG_EFFICACY 79 81	活血
T4	SYMPTOM 87 91	月经量少
T5	SYMPTOM 95 101	血虚萎黄后错
T6	SYMPTOM 102 106	血虚萎黄
T7	SYMPTOM 107 111	风湿痹痛
T8	SYMPTOM 112 116	肢体麻木
T9	DISEASE 116 119	糖尿病
T1	FOOD_GROUP 25 27	辛辣
T2	FOOD_GROUP 30 32	油腻
T3	DISEASE 36 38	感冒
T4	SYMPTOM 38 40	发热
T5	DISEASE 51 54	高血压
T6	DISEASE_GROUP 55 58	心脏病
T7	DISEASE_GROUP 59 61	肝病
T8	DISEASE 62 65	糖尿病
T9	DISEASE_GROUP 66 68	肾病
T10	DISEASE_GROUP 69 72	慢性病
T11	SYMPTOM 90 94	月经紊乱
T12	SYMPTOM 109 111	眩晕
T13	PERSON_GROUP 158 163	过敏体质者
T14	PERSON_GROUP 192 194	儿童
T15	DRUG_DOSAGE 242 244	颗粒
T16	DRUG_TASTE 245 248	气微香
T17	DRUG_TASTE 249 252	味微苦
T18	SYNDROME 274 276	阴虚
T19	DISEASE_GROUP 276 279	肝旺症
T20	SYMPTOM 282 286	烘热汗出
T21	SYMPTOM 287 291	头晕耳鸣
T22	SYMPTOM 292 296	失眠多梦
T23	SYMPTOM 297 301	五心烦热
T24	SYMPTOM 302 306	腰背酸痛
T25	SYMPTOM 307 311	大便干燥
T26	

T1	PERSON_GROUP 90 92	孕妇
T2	DRUG_EFFICACY 108 110	活血
T3	DRUG_EFFICACY 111 113	化瘀
T4	SYMPTOM 121 125	宿有癥块
T5	SYMPTOM 127 131	血瘀经闭
T6	SYMPTOM 132 136	行经腹痛
T7	SYMPTOM 137 143	产后恶露不尽
T8	DRUG_DOSAGE 194 197	大蜜丸
T9	DRUG_TASTE 198 200	味甜
T10	DRUG_DOSAGE 203 205	丸剂
T11	SYMPTOM 227 231	行经腹痛
T12	SYNDROME 238 240	血瘀
T13	SYMPTOM 240 242	经闭
T14	SYMPTOM 243 247	行经腹痛
T15	SYMPTOM 248 254	产后恶露不尽
T1	PERSON_GROUP 8 10	孕妇
T2	DRUG_DOSAGE 22 24	颗粒
T3	DRUG_TASTE 25 27	气微
T4	DRUG_TASTE 28 30	味甜
T5	DRUG_TASTE 31 33	微涩
T6	DRUG_TASTE 34 36	微苦
T7	DRUG_EFFICACY 40 43	清湿热
T8	SYNDROME 54 56	湿热
T9	DRUG_EFFICACY 56 59	清湿热
T10	DRUG_EFFICACY 60 65	止带下头晕
T11	PERSON_GROUP 100 102	孕妇
T1	DRUG_DOSAGE 24 27	水蜜丸
T2	DRUG_DOSAGE 74 77	大蜜丸
T3	DRUG_TASTE 78 80	味甘
T4	DRUG_TASTE 81 84	苦味甘
T5	PERSON_GROUP 89 91	孕妇
T6	FOOD_GROUP 111 113	生冷
T7	SYMPTOM 131 135	月经不调
T8	SYMPTOM 136 144	经期不准行经腹痛
T9	DRUG_EFFICACY 146 150	养血调经
T10	SYMPTOM 151 155	舒郁化滞
T11	SYNDROME 158 162	气虚血寒
T12	SYNDROME 163 167	肝郁不舒
T13	SYMPTOM 170 174	经期不准
T14	SYMPTOM 17

T1	DRUG_EFFICACY 0 4	活血调经
T2	SYMPTOM 7 11	月经量少
T3	SYMPTOM 12 16	产后腹痛
T4	PERSON_GROUP 16 18	儿童
T5	PERSON_GROUP 30 32	孕妇
T6	FOOD_GROUP 39 41	生冷
T7	SYNDROME 46 50	气血两虚
T8	SYNDROME 57 61	气血两虚
T9	SYMPTOM 64 68	月经量少
T10	SYMPTOM 76 80	头晕心悸
T11	SYMPTOM 81 85	疲乏无力
T12	DISEASE 96 99	高血压
T13	DISEASE_GROUP 100 103	心脏病
T14	DISEASE_GROUP 104 106	肾病
T15	DISEASE 107 110	糖尿病
T16	SYMPTOM 147 150	经量少
T17	PERSON_GROUP 160 165	青春期少女
T18	PERSON_GROUP 166 171	更年期妇女
T19	SYMPTOM 188 190	腹痛
T20	SYMPTOM 192 196	阴道出血
T21	PERSON_GROUP 249 252	过敏者
T22	PERSON_GROUP 255 260	过敏体质者
T23	PERSON_GROUP 290 292	儿童
T24	DRUG_DOSAGE 332 335	软胶囊
T25	DRUG_DOSAGE 350 352	液体
T26	DRUG_TASTE 353 356	气微香
T27	DRUG_TASTE 357 359	味苦
T28	PERSON_GROUP 379 381	孕妇
T1	DRUG_EFFICACY 6 10	养血安神
T2	SYNDROME 13 17	脾胃虚弱
T3	SYNDROME 18 22	心脾两虚
T4	SYNDROME 25 28	血虚证
T5	SYMPTOM 31 39	面色萎黄或白光白
T6	SYMPTOM 40 44	食少纳呆
T7	SYMPTOM 45 49	脘腹胀闷
T8	SYMPTOM 50 54	大便不调
T9	SYMPTOM 55 59	烦躁多汗
T10	SYMPTOM 60 64	倦怠乏力
T11	SYMPTOM 65 69	舌胖色淡
T12	SYMPTOM 70 73	苔薄白
T13	

T1	PERSON_GROUP 18 23	月经过多者
T2	DRUG_EFFICACY 28 32	活血调经
T3	DRUG_EFFICACY 33 37	行气止痛
T4	SYNDROME 40 42	气滞
T5	SYNDROME 42 44	血瘀
T6	SYMPTOM 46 50	月经不调
T7	SYMPTOM 51 53	痛经
T8	SYMPTOM 73 75	痛经
T9	SYMPTOM 105 109	重度痛经
T10	PERSON_GROUP 186 189	过敏者
T11	PERSON_GROUP 192 197	过敏体质者
T12	PERSON_GROUP 224 226	儿童
T13	DRUG_EFFICACY 262 266	活血调经
T14	SYMPTOM 269 273	月经不调
T15	SYMPTOM 274 276	痛经
T16	SYMPTOM 277 281	产后腹痛
T1	DRUG_EFFICACY 6 10	通络下乳
T2	SYNDROME 15 19	气血虚弱
T3	SYMPTOM 25 27	少乳
T4	PERSON_GROUP 87 92	哺乳疾病者
T5	SYNDROME 101 105	气血虚弱
T6	DRUG_DOSAGE 160 162	颗粒
T7	DRUG_TASTE 163 165	味甜
T8	DRUG_TASTE 166 168	微苦
T1	DRUG_EFFICACY 54 58	抗炎镇痛
T2	DRUG_EFFICACY 64 73	促进生殖系统微循环
T3	DRUG_EFFICACY 74 80	化解体内肿块
T4	DRUG_EFFICACY 83 88	增强免疫力
T5	DRUG_DOSAGE 198 202	薄膜衣片
T6	DRUG_TASTE 219 222	气微香
T7	DRUG_TASTE 223 227	味苦微咸
T8	SYMPTOM 233 236	瘕积聚
T9	SYMPTOM 237 241	痛经闭经
T10	SYMPTOM 242 246	赤白带下
T11	DISEASE 247 252	慢性盆腔炎
T12	DRUG_EFFICACY 261 265	活血调经
T13	DRUG_EFFICACY 268 270	止痛
T14	DRUG_EFFICACY 271 275	软坚散结
T15	SYM

T1	SYNDROME 3 7	气血两虚
T2	SYMPTOM 10 14	月经不调
T3	SYMPTOM 17 23	月经周期错后
T4	SYNDROME 23 27	气血两虚
T5	SYMPTOM 30 34	月经不调
T6	SYMPTOM 37 43	月经周期错后
T7	SYMPTOM 44 48	行经量少
T8	SYMPTOM 49 53	精神不振
T9	DRUG_EFFICACY 107 111	益气养血
T10	DRUG_EFFICACY 112 116	活血调经
T11	SYNDROME 119 123	气血两虚
T12	SYNDROME 125 127	血瘀
T13	SYMPTOM 130 134	月经不调
T14	SYMPTOM 137 143	月经周期错后
T15	SYMPTOM 144 148	行经量少
T16	SYMPTOM 149 153	精神不振
T17	SYMPTOM 154 158	肢体乏力
T18	SYMPTOM 158 162	月经不调
T19	DRUG_EFFICACY 163 166	补气血
T1	DRUG_EFFICACY 53 57	滋阴清热
T2	DRUG_EFFICACY 58 62	除烦安神
T3	SYMPTOM 71 75	潮热汗出
T4	SYMPTOM 76 78	眩晕
T5	SYMPTOM 79 81	耳鸣
T6	SYMPTOM 151 155	潮热汗出
T7	SYMPTOM 156 158	眩晕
T8	SYMPTOM 159 161	耳鸣
T9	SYMPTOM 162 164	失眠
T10	SYMPTOM 165 169	烦燥不安
T1	SYMPTOM 78 82	潮热汗出
T2	SYMPTOM 83 85	眩晕
T3	SYMPTOM 86 88	耳鸣
T4	FOOD_GROUP 115 117	辛辣
T5	FOOD_GROUP 120 122	油腻
T6	DISEASE 125 127	感冒
T7	SYMPTOM 127 129	发热
T8	DISEASE 139 142	高血压
T9	DISEASE_GROUP 143 146	心脏病
T10	DISEASE_GROUP 147 149	肝病
T11	DISEASE 150 153	糖尿病
T12	DISEASE_GROUP 154 156	肾病
T13	

T1	DRUG_DOSAGE 1 3	丸剂
T2	DRUG_DOSAGE 4 7	水蜜丸
T3	PERSON_GROUP 10 12	孕妇
T4	DRUG_EFFICACY 86 90	活血化瘀
T5	DRUG_EFFICACY 91 95	缓消瘀块
T6	SYMPTOM 100 104	宿有血块
T7	SYMPTOM 108 112	漏下不止
T8	SYMPTOM 113 117	胎动不安
T9	SYMPTOM 119 123	血瘀经闭
T10	SYMPTOM 124 128	行经腹痛
T11	SYMPTOM 168 172	宿有瘕块
T12	SYMPTOM 174 178	血瘀经闭
T13	SYMPTOM 179 183	行经腹痛
T14	SYMPTOM 184 190	产后恶露不尽
T1	PERSON_GROUP 91 93	孕妇
T2	DRUG_INGREDIENT 148 150	人参
T3	DRUG_INGREDIENT 151 153	白芍
T4	DRUG_INGREDIENT 154 157	反藜芦
T5	DRUG_INGREDIENT 161 163	藜芦
T6	DRUG_INGREDIENT 175 177	甘草
T7	DRUG_INGREDIENT 178 181	反甘遂
T8	DRUG_INGREDIENT 182 184	大戟
T9	DRUG_INGREDIENT 185 187	海藻
T10	DRUG_INGREDIENT 188 190	芫花
T11	DRUG_INGREDIENT 194 196	甘遂
T12	DRUG_INGREDIENT 197 199	大戟
T13	DRUG_INGREDIENT 200 202	海藻
T14	DRUG_INGREDIENT 203 205	芫花
T15	FOOD_GROUP 218 220	生冷
T16	FOOD_GROUP 221 223	辛辣
T17	FOOD_GROUP 224 226	荤腥
T18	FOOD_GROUP 226 228	油腻
T19	DRUG_INGREDIENT 277 279	皂荚
T20	PERSON_GROUP 334 336	孕妇
T21	DRUG_EFFICACY 340 344	补气养血
T22	DRUG_EFFICACY 345 349	调经止带
T23	

In [None]:
#“实体类别”、“起始位置”、“结束位置”以空格分隔

In [17]:
for i in range(1000,1500):
    with open('data/submit/%d.ann'%i,'w', encoding='utf-8') as wr:
        wr.write('\n'.join(result[i]))

KeyError: 1000

In [None]:
#for i in range(1000,1500):
#    pd.DataFrame(result[i]).to_csv('data/submit/%d.ann'%i,
#                                      sep='\t',
#                                      header = None,
#                                      index = 0,
#                                      encoding = 'utf-8')