## 使用`AlBert`进行命名实体识别
***
***
Time: 2020-09-21
Author: dsy
***

模块库导入

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud # 自定义数据集 
import math
import pandas as pd
from model.torchcrf import CRF
import os
from model.albert_pytorch.modeling_albert import AlbertConfig, AlbertForPreTraining,AlbertModel

常量定义

In [2]:
max_seq_length = 202
BATCH_SIZE = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
config_file = "model/albert_pytorch/prev_trained_model/albert_base_v2/config.json"
config_file

'model/albert_pytorch/prev_trained_model/albert_base_v2/config.json'

数据处理

In [3]:
# 创建vocab
def readVocab(filepath:str="data/vocab.txt"):
    vocab = []
    with open(filepath,"r",encoding="utf-8") as f:
        for line in f:
            vocab.append(line.strip())
            
    with open("data/train.txt","r",encoding="utf-8") as fp:
        for line in fp:
            if '' != line.strip():
                vocab.append(line.strip().split(" ")[0])
            
    vocab = set(vocab)        
    vocab2id = { data:1+i for i,data in enumerate(vocab)}
    id2vocab = {1+i:data for i,data in enumerate(vocab)}
    return (vocab,vocab2id,id2vocab,len(vocab))

In [4]:
def handle(filepath:str="data/train.txt"):
    words = []
    labels = []
    vocab = []
    label_list = set()
    label_list.add("[CLS]")
    label_list.add("[SEP]")
    with open(filepath,"r",encoding="utf-8") as fp:
        word = []
        label = []
        for index,line in enumerate( fp):
            
            if '' == line.strip():
                words.append(word)
                labels.append(label)
#                 print("words:\n",pd.DataFrame(words))
#                 print("labels:\n",pd.DataFrame(labels))
                word = []
                label = []
            else:
                linesplit = line.strip().split(" ")
                word.append(linesplit[0])
                label.append(linesplit[1])
                label_list.add(linesplit[1])
                
    label2id = {data:i+1 for i,data in enumerate(label_list)}
    id2label = {1+i:data for i,data in enumerate(label_list)}
    
    input_ids = []
    for i in words:
        input_id = []
    #     break
        for j in i:
            input_id.append(vocab2id[j])
    #         break

        if len(input_id) > ( max_seq_length - 2):
            input_id = input_id[:max_seq_length - 2]
        input_id.insert(0,vocab2id["[CLS]"])   
        while len(input_id) < (max_seq_length - 1):
            input_id.append(0)
        input_id.append(vocab2id["[SEP]"])

        input_ids.append(input_id)
        
    label_ids = []
    for i in labels:
        label_id = []
    #     break
        for j in i:
            label_id.append(label2id[j])
    #         break

        if len(label_id) > ( max_seq_length - 2):
            label_id = label_id[:max_seq_length - 2]
        label_id.insert(0,label2id["[CLS]"])   
        while len(label_id) < (max_seq_length - 1):
            label_id.append(0)
        label_id.append(label2id["[SEP]"])

        label_ids.append(label_id)

    return (words,labels,label_list,label2id,id2label,len(label_list),input_ids,label_ids)

In [5]:
(vocab,vocab2id,id2vocab,vocab_size)= readVocab()

In [6]:
words,labels,label_list,label2id,id2label,label_size,input_ids,label_ids = handle()

In [7]:
input_ids = torch.Tensor(input_ids).long()
label_ids = torch.Tensor(label_ids).long()

自定义数据集

In [8]:
class BertNerDataset(tud.Dataset): 
    def __init__(self,input_ids,label_ids): 
        super(BertNerDataset,self).__init__() 
        self.label_ids = label_ids
        self.input_ids = input_ids
       
    def __len__(self): 
        return len(self.input_ids) 
    def __getitem__(self, index): 
        return (self.input_ids[index,:],self.label_ids[index,:])

dataset = BertNerDataset(input_ids[:8000],label_ids[:8000]) 
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,num_workers=0)

In [9]:
class AlBertCRFNer(nn.Module):
    '''
    AlBert + CRF
    '''
    def __init__(self,vocab_size,config_file,num_tags=768):
        
        super(AlBertCRFNer,self).__init__()
        config = AlbertConfig.from_pretrained(config_file)
        self.albert = AlbertModel(config)
        self.crf = CRF(num_tags)

    def forward(self,x,tags=None):
        y,_= self.albert(x) # 8,202,768
        if tags is None:
            output = y.permute(1,0,2)
            return self.crf.decode(output) # 8 202
        else:
            return -self.crf(y,tags,reduction='mean') # (seq_length, batch_size, num_tags)   (seq_length, batch_size)--> (batch_size,).sum() 

In [10]:
class AlBertBiLSTMCRFNer(nn.Module):
    '''
    AlBert + BiLSTM + CRF
    '''
    def __init__(self,vocab_size,config_file,num_tags=768):
        super(AlBertBiLSTMCRFNer,self).__init__()
        config = AlbertConfig.from_pretrained(config_file)
        self.albert = AlbertModel(config)
        self.bilstm = nn.LSTM(input_size = 768,hidden_size =768 //2, bidirectional =True)
        self.crf = CRF(num_tags)
        
    def forward(self,x,tags=None):
        y,_=  self.albert(x)
        output,_ = self.bilstm(y) # 8 202 768
        
        if tags is None:
            output = output.permute(1,0,2)
            return self.crf.decode(output) # 8 202
        else:
            return -self.crf(output,tags,reduction='mean') # (seq_length, batch_size, num_tags)   (seq_length, batch_size)--> (batch_size,).sum()


In [11]:
class AlBertBiGRUCRFNer(nn.Module):
    '''
    AlBert + BiGRU + CRF
    '''
    def __init__(self,vocab_size,config_file,num_tags=768):
        super(AlBertBiGRUCRFNer,self).__init__()
        config = AlbertConfig.from_pretrained(config_file)
        self.albert = AlbertModel(config)
        self.bigru = nn.GRU(input_size=768,hidden_size=768//2,bidirectional=True)
        self.crf = CRF(num_tags)

    def forward(self,x,tags=None):
            
        y,_ = self.albert(x)
        output,_ = self.bigru(y)  
        
        if tags is None:
            output = output.permute(1,0,2) 
            return self.crf.decode(output)
        else:
            return -self.crf(output,tags,reduction='mean')


In [12]:
model = AlBertBiGRUCRFNer(vocab_size,config_file)
model.load_state_dict(torch.load("outputs/AlBert-BiGRU-CRF.pt"))
model.eval()

AlBertBiGRUCRFNer(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertEncoder(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (transformer): AlbertTransformer(
        (group): ModuleList(
          (0): AlbertGroup(
            (inner_group): ModuleList(
              (0): AlbertLayer(
                (attention): AlbertAttention(
                  (self): AlbertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
           

In [13]:
import numpy as np
from sklearn.metrics import recall_score,precision_score,f1_score

In [14]:
def precision_recall_f1(true,pred):
    m,n = true.shape
    precision = np.zeros((m,))
    recall = np.zeros((m,))
    f1 = np.zeros((m,))
    
    for i in range(m):
        precision[i] = precision_score(true[i],pred[i],average="macro")
        recall[i] = recall_score(true[i],pred[i],average="macro")
        f1[i] = f1_score(true[i],pred[i],average="macro")
        
    return (precision.mean(),recall.mean(),f1.mean())

In [15]:
for step,(input_ids,label_ids) in enumerate(dataloader):
    y_pred =  model(input_ids,None)
    y_pred = np.array(y_pred)
    y_true = label_ids.numpy()
    precision,recall,f1 = precision_recall_f1(y_true,y_pred) 
    print("precision:",precision,"recall:",recall,"f1:",f1)
    break

precision: 0.0007838283828382838 recall: 0.15833333333333333 f1: 0.0015599343185550081


  _warn_prf(average, modifier, msg_start, len(result))
