## 使用`Bert`进行命名实体识别
***
***
Time: 2020-09-19
Author: dsy
***

模块库导入

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud # 自定义数据集 
import math
from collections import Counter
import pandas as pd
from model.bert import BERT

常量定义

In [2]:
max_seq_length = 202
BATCH_SIZE = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

数据处理

In [3]:
# 创建vocab
def readVocab(filepath:str="data/vocab.txt"):
    vocab = []
    with open(filepath,"r",encoding="utf-8") as f:
        for line in f:
            vocab.append(line.strip())
            
    with open("data/train.txt","r",encoding="utf-8") as fp:
        for line in fp:
            if '' != line.strip():
                vocab.append(line.strip().split(" ")[0])
            
    vocab = set(vocab)        
    vocab2id = { data:1+i for i,data in enumerate(vocab)}
    id2vocab = {1+i:data for i,data in enumerate(vocab)}
    return (vocab,vocab2id,id2vocab,len(vocab))

In [4]:
def handle(filepath:str="data/train.txt"):
    words = []
    labels = []
    vocab = []
    label_list = set()
    label_list.add("[CLS]")
    label_list.add("[SEP]")
    with open(filepath,"r",encoding="utf-8") as fp:
        word = []
        label = []
        for index,line in enumerate( fp):
            
            if '' == line.strip():
                words.append(word)
                labels.append(label)
#                 print("words:\n",pd.DataFrame(words))
#                 print("labels:\n",pd.DataFrame(labels))
                word = []
                label = []
            else:
                linesplit = line.strip().split(" ")
                word.append(linesplit[0])
                label.append(linesplit[1])
                label_list.add(linesplit[1])
                
    label2id = {data:i+1 for i,data in enumerate(label_list)}
    id2label = {1+i:data for i,data in enumerate(label_list)}
    
    input_ids = []
    for i in words:
        input_id = []
    #     break
        for j in i:
            input_id.append(vocab2id[j])
    #         break

        if len(input_id) > ( max_seq_length - 2):
            input_id = input_id[:max_seq_length - 2]
        input_id.insert(0,vocab2id["[CLS]"])   
        while len(input_id) < (max_seq_length - 1):
            input_id.append(0)
        input_id.append(vocab2id["[SEP]"])

        input_ids.append(input_id)
        
    label_ids = []
    for i in labels:
        label_id = []
    #     break
        for j in i:
            label_id.append(label2id[j])
    #         break

        if len(label_id) > ( max_seq_length - 2):
            label_id = label_id[:max_seq_length - 2]
        label_id.insert(0,label2id["[CLS]"])   
        while len(label_id) < (max_seq_length - 1):
            label_id.append(0)
        label_id.append(label2id["[SEP]"])

        label_ids.append(label_id)

    return (words,labels,label_list,label2id,id2label,len(label_list),input_ids,label_ids)

In [5]:
(vocab,vocab2id,id2vocab,vocab_size)= readVocab()

In [6]:
words,labels,label_list,label2id,id2label,label_size,input_ids,label_ids = handle()

In [7]:
input_ids = torch.Tensor(input_ids).long()
label_ids = torch.Tensor(label_ids).long()

自定义数据集

In [8]:
class BertNerDataset(tud.Dataset): 
    def __init__(self,input_ids,label_ids): 
        super(BertNerDataset,self).__init__() 
        self.label_ids = label_ids
        self.input_ids = input_ids
       
    def __len__(self): 
        return len(self.input_ids) 
    def __getitem__(self, index): 
        return (self.input_ids[index,:],self.label_ids[index,:])

dataset = BertNerDataset(input_ids,label_ids) 
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
class BertNer(nn.Module):
    def __init__(self,vocab_size):
        
        super(BertNer,self).__init__()
        self.bert = BERT( vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)
        
    def forward(self,x,segment_ids):
        y = self.bert(x,segment_ids) # 8,202,768
        y = y.permute(0,2,1) 
        return y
    

class BertBiLSTMNer(nn.Module):
    def __init__(self,vocab_size):
        super(BertBiLSTMNer,self).__init__()
        self.bert = BERT( vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)
        self.bilstm = nn.LSTM(input_size = 768,hidden_size =768 //2, bidirectional =True)
        
    def forward(self,x,segment_ids):
        y =  self.bert(x,segment_ids)
        output,_ = self.bilstm(y)
        output = output.permute(0,2,1) 
        return output
    
class BertBiGRUNer(nn.Module):
    def __init__(self,vocab_size):
        super(BertBiGRUNer,self).__init__()
        self.bert = BERT( vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)
        self.bigru = nn.GRU(input_size=768,hidden_size=768//2,bidirectional=True)

    def forward(self,x,segment_ids):
        y = self.bert(x,segment_ids)
        output,_ = self.bigru(y)        
        output = output.permute(0,2,1) 
        return output
        
 

In [10]:
# 自注意模型
from  torch.nn.parameter import Parameter
class SelfAttention(nn.Module):
    def __init__(self,embed_dim):
        super(SelfAttention,self).__init__()
        self.embed_dim = embed_dim
        self.selfattention = nn.MultiheadAttention(embed_dim, num_heads=1, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
        
    def forward(self,x):
        L,N,E = x.shape
        W1 = Parameter(torch.empty((L,N,N)))
        W2 = Parameter(torch.empty((L,N,N)))
        W3 = Parameter(torch.empty((L,N,N)))
        
        std = 1./ math.sqrt(self.embed_dim)
        
        nn.init.uniform_(W1,-std,std)
        nn.init.uniform_(W2,-std,std)
        nn.init.uniform_(W3,-std,std)
        
        query = W1.matmul(x) # (L, N, E)
        key = W2.matmul(x) # (S,N,E)
        value = W3.matmul(x) # (S,N,E)
        attn_output,_ = self.selfattention(query, key, value) # (L,N,E)
        return attn_output
    
class BertBiGRUSelfAttentionNer(nn.Module):
    def __init__(self,vocab_size,embed_dim=768):
        super(BertBiGRUSelfAttentionNer,self).__init__()
        self.bert = BERT( vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)
        self.bigru = nn.GRU(input_size=768,hidden_size=768//2,bidirectional=True)
        self.selfattention = SelfAttention(embed_dim)

    def forward(self,x,segment_ids):
        y = self.bert(x,segment_ids)
        output,_ = self.bigru(y)
        output = self.selfattention(output)
        output = output.permute(0,2,1) 
        return output        

模型训练

In [11]:
# bertner = BertBiLSTMNer( vocab_size)
# bertner = BertNer( vocab_size)
# bertner = BertBiGRUNer(vocab_size)
bertner = BertBiGRUSelfAttentionNer(vocab_size,embed_dim=768)
optimizer = torch.optim.Adam(bertner.parameters(), lr=0.05)   # optimize all parameters
loss_func = nn.CrossEntropyLoss()   # the target label is not one-hotted

In [12]:
EPOCH = 1

In [None]:
for epoch in range(EPOCH):
    for step,(input_ids,label_ids) in enumerate(dataloader):
        y_pred = bertner(input_ids,torch.zeros(input_ids.size()).long())
#         print("y_pred",y_pred.shape)
#         print("label_ids",label_ids.shape)
        loss = loss_func(y_pred,label_ids)
        if 0 == (step+1) % 10 :
            print("epoch",epoch,"step:",step+1,",loss:",loss.item())
        
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradient


epoch 0 step: 10 ,loss: 70.48233795166016
epoch 0 step: 20 ,loss: 64.64908599853516
epoch 0 step: 30 ,loss: 37.792694091796875
epoch 0 step: 40 ,loss: 22.66753387451172
epoch 0 step: 50 ,loss: 18.804529190063477
epoch 0 step: 60 ,loss: 31.839941024780273
epoch 0 step: 70 ,loss: 19.860679626464844
epoch 0 step: 80 ,loss: 16.839797973632812
epoch 0 step: 90 ,loss: 11.174281120300293
epoch 0 step: 100 ,loss: 5.3015546798706055
epoch 0 step: 110 ,loss: 17.39335823059082
epoch 0 step: 120 ,loss: 9.268795013427734
epoch 0 step: 130 ,loss: 13.026122093200684
epoch 0 step: 140 ,loss: 13.76448917388916
epoch 0 step: 150 ,loss: 15.671184539794922
epoch 0 step: 160 ,loss: 20.138240814208984
epoch 0 step: 170 ,loss: 213.03688049316406
