## 使用 BERT 进行命名实体识别
***
***
Time: 2020-09-19
Author: dsy
***

模块库导入

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud # 自定义数据集 
import math
from collections import Counter
import pandas as pd
from model.bert import BERT

常量定义

In [2]:
# Fixed length of every encoded example, counting the [CLS] and [SEP] slots,
# and the number of sentences per training mini-batch.
max_seq_length, BATCH_SIZE = 202, 8

数据处理

In [3]:
# 创建vocab
def readVocab(filepath:str="data/vocab.txt", train_filepath:str="data/train.txt"):
    """Build the token vocabulary from a vocab file plus the training corpus.

    Tokens are the stripped, non-blank lines of ``filepath`` followed by the
    first space-separated column of every non-blank line of ``train_filepath``.
    Duplicates are removed and ids are assigned in first-occurrence order,
    starting at 1 (id 0 is left unused so it can serve as the padding id).

    Args:
        filepath: vocabulary file, one token per line.
        train_filepath: training corpus with ``token label`` lines.

    Returns:
        ``(vocab, vocab2id, id2vocab, vocab_size)``: the set of unique tokens,
        token -> id, id -> token, and the number of unique tokens (ids span
        ``1..vocab_size``).
    """
    tokens = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            token = line.strip()
            # A blank line would otherwise become an unusable "" token.
            if token:
                tokens.append(token)

    with open(train_filepath, "r", encoding="utf-8") as fp:
        for line in fp:
            stripped = line.strip()
            if stripped:
                tokens.append(stripped.split(" ")[0])

    # dict.fromkeys de-duplicates while keeping first-occurrence order. Iterating
    # a set of str is randomized per process (PYTHONHASHSEED), which made the
    # token -> id mapping differ from one run to the next.
    ordered = list(dict.fromkeys(tokens))
    vocab2id = {token: i for i, token in enumerate(ordered, start=1)}
    id2vocab = {i: token for token, i in vocab2id.items()}
    return (set(ordered), vocab2id, id2vocab, len(ordered))

In [4]:
def _read_sentences(filepath):
    """Split a ``token label`` file into parallel per-sentence token and label lists."""
    words, labels = [], []
    word, label = [], []
    with open(filepath, "r", encoding="utf-8") as fp:
        for line in fp:
            stripped = line.strip()
            if stripped:
                linesplit = stripped.split(" ")
                word.append(linesplit[0])
                label.append(linesplit[1])
            elif word:
                # Blank line closes a sentence; runs of blank lines add nothing.
                words.append(word)
                labels.append(label)
                word, label = [], []
    # The final sentence has no closing blank line when the file lacks one.
    if word:
        words.append(word)
        labels.append(label)
    return words, labels


def _pad_ids(ids, cls_id, sep_id, seq_len):
    """Frame ``ids`` as ``[CLS] ids 0 ... 0 [SEP]``, truncated/padded to ``seq_len``."""
    body = list(ids[:seq_len - 2])
    padding = [0] * (seq_len - 2 - len(body))
    return [cls_id] + body + padding + [sep_id]


def handle(filepath:str="data/train.txt", token2id=None, seq_len=None):
    """Read a blank-line-separated ``token label`` file and encode it for BERT NER.

    Each example is encoded as ``[CLS]`` + up to ``seq_len - 2`` ids + zero
    padding + ``[SEP]``, so every row has exactly ``seq_len`` entries. Label
    ids start at 1 (``[CLS]`` = 1, ``[SEP]`` = 2, then labels in first-seen
    order), so 0 is used only for padding.

    Args:
        filepath: corpus file, one ``token label`` pair per line.
        token2id: token -> id mapping; defaults to the module-level
            ``vocab2id``. Must contain ``[CLS]``, ``[SEP]`` and every token.
        seq_len: output row length; defaults to the module-level
            ``max_seq_length``.

    Returns:
        ``(words, labels, label_list, label2id, id2label, label_size,
        input_ids, label_ids)`` where ``input_ids``/``label_ids`` are lists of
        ``seq_len``-long int lists.

    Raises:
        KeyError: a token is missing from ``token2id``.
        IndexError: a non-blank line has no label column.
    """
    if token2id is None:
        token2id = vocab2id
    if seq_len is None:
        seq_len = max_seq_length

    words, labels = _read_sentences(filepath)

    # Ordered de-duplication keeps label ids stable across runs (set iteration
    # order of strings is randomized per process).
    label_order = list(dict.fromkeys(
        ["[CLS]", "[SEP]"] + [tag for sentence in labels for tag in sentence]
    ))
    label2id = {tag: i for i, tag in enumerate(label_order, start=1)}
    id2label = {i: tag for tag, i in label2id.items()}

    input_ids = [
        _pad_ids([token2id[token] for token in sentence],
                 token2id["[CLS]"], token2id["[SEP]"], seq_len)
        for sentence in words
    ]
    label_ids = [
        _pad_ids([label2id[tag] for tag in sentence],
                 label2id["[CLS]"], label2id["[SEP]"], seq_len)
        for sentence in labels
    ]

    return (words, labels, set(label_order), label2id, id2label,
            len(label_order), input_ids, label_ids)

In [5]:
# Load the vocabulary and its token <-> id lookup tables.
vocab, vocab2id, id2vocab, vocab_size = readVocab()

In [6]:
# Encode the training corpus: raw sentences/labels, label maps and padded id rows.
(words, labels, label_list, label2id, id2label,
 label_size, input_ids, label_ids) = handle()

In [7]:
# Build int64 id matrices directly; torch.Tensor(...) first materializes a
# float32 tensor (exact only up to 2**24) before the .long() conversion.
input_ids = torch.tensor(input_ids, dtype=torch.long)
label_ids = torch.tensor(label_ids, dtype=torch.long)

自定义数据集

In [8]:
class BertNerDataset(tud.Dataset):
    """Map-style dataset pairing each row of token ids with its row of label ids.

    Both arguments are indexed row-wise (``[index, :]``), so they are expected
    to be at least 2-D tensors with the same number of rows.
    """

    def __init__(self, input_ids, label_ids):
        super().__init__()
        self.input_ids = input_ids
        self.label_ids = label_ids

    def __len__(self):
        """Number of examples (rows of ``input_ids``)."""
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return the ``(input_ids_row, label_ids_row)`` pair at ``index``."""
        token_row = self.input_ids[index, :]
        label_row = self.label_ids[index, :]
        return (token_row, label_row)

# Wrap the encoded tensors and serve shuffled mini-batches of BATCH_SIZE rows.
dataset = BertNerDataset(input_ids, label_ids)
dataloader = tud.DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

In [9]:
class BertNer(nn.Module):
    """BERT encoder whose per-token outputs are used for token classification.

    Args:
        vocab_size: passed straight through to ``BERT``.
        num_labels: if given, a linear layer maps each 768-d token vector to
            ``num_labels`` logits. If None (default), the raw 768-d hidden
            states are returned, matching the original behavior.
    """

    def __init__(self, vocab_size, num_labels=None):
        super(BertNer, self).__init__()
        self.bert = BERT(vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)
        # Without a head the hidden size (768) would be treated as the class count.
        self.classifier = nn.Linear(768, num_labels) if num_labels is not None else None

    def forward(self, x, segment_ids):
        """Return per-token scores shaped (batch, channels, seq_len).

        ``channels`` is ``num_labels`` when a head was requested, else 768.
        """
        y = self.bert(x, segment_ids)  # (batch, seq_len, 768), e.g. 8,202,768
        if self.classifier is not None:
            y = self.classifier(y)
        # CrossEntropyLoss expects the class dimension second: (N, C, L).
        return y.permute(0, 2, 1)
    

class BertBiLSTMNer(nn.Module):
    """BERT encoder followed by a bidirectional LSTM over the token sequence.

    Args:
        vocab_size: passed straight through to ``BERT``.
        num_labels: if given, a linear layer maps each 768-d BiLSTM output to
            ``num_labels`` logits. If None (default), the raw 768-d BiLSTM
            features are returned, matching the original behavior.
    """

    def __init__(self, vocab_size, num_labels=None):
        super(BertBiLSTMNer, self).__init__()
        self.bert = BERT(vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)
        # BERT returns (batch, seq_len, hidden) but nn.LSTM defaults to
        # (seq_len, batch, feature); without batch_first=True the LSTM would
        # recur across the examples of a batch instead of across tokens.
        # Two directions of 768 // 2 units concatenate back to 768 features.
        self.bilstm = nn.LSTM(input_size=768, hidden_size=768 // 2,
                              bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(768, num_labels) if num_labels is not None else None

    def forward(self, x, segment_ids):
        """Return per-token scores shaped (batch, channels, seq_len).

        ``channels`` is ``num_labels`` when a head was requested, else 768.
        """
        y = self.bert(x, segment_ids)
        output, _ = self.bilstm(y)  # (batch, seq_len, 768)
        if self.classifier is not None:
            output = self.classifier(output)
        # CrossEntropyLoss expects the class dimension second: (N, C, L).
        return output.permute(0, 2, 1)
 

模型训练

In [10]:
# Token ids from readVocab run 1..vocab_size and 0 is the padding id, so the
# token embedding needs vocab_size + 1 rows or the largest id is out of range.
# NOTE(review): assumes BERT sizes its token table as nn.Embedding(vocab_size, ...) — confirm.
bertner = BertBiLSTMNer(vocab_size + 1)
# bertner = BertNer(vocab_size + 1)
# lr=0.05 makes Adam diverge on a 12-layer transformer. Nothing here loads
# pretrained weights, so 1e-4 is a conventional from-scratch starting point.
optimizer = torch.optim.Adam(bertner.parameters(), lr=1e-4)   # optimize all parameters
# Label sequences are padded with id 0, so padding positions are excluded from the loss.
# NOTE(review): as constructed above, the model's channel dimension is the 768-d
# hidden size, not the number of labels, and CrossEntropyLoss treats those 768
# channels as classes. A projection to label_size + 1 classes is still needed.
loss_func = nn.CrossEntropyLoss(ignore_index=0)   # the target label is not one-hotted

In [11]:
EPOCH: int = 1  # number of full passes the training loop makes over the dataloader

In [None]:
# Train for EPOCH passes, logging the loss every 10 steps.
# Batch tensors get their own names so the loop no longer overwrites the
# module-level input_ids / label_ids (the full dataset) with the last batch.
for epoch in range(EPOCH):
    for step, (batch_input_ids, batch_label_ids) in enumerate(dataloader, start=1):
        # Single-sentence task: every position gets segment id 0.
        # NOTE(review): if BERT's segment embedding reserves 0 for padding,
        # sentence A should use 1 instead — confirm.
        segment_ids = torch.zeros_like(batch_input_ids)
        y_pred = bertner(batch_input_ids, segment_ids)  # (batch, channels, seq_len)
        loss = loss_func(y_pred, batch_label_ids)
        if step % 10 == 0:
            print("epoch", epoch, "step:", step, ",loss:", loss.item())

        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradient
