In [None]:
import torch
from torch import nn
import os
import copy
import pandas as pd
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
# import pretraining_args as args
from transformers import BertConfig
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import BertModel
from transformers import BertTokenizer,BertConfig,BertForTokenClassification,BertModel
import time,datetime
from sklearn.metrics import precision_score,classification_report,f1_score,recall_score
import numpy as np
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import AlbertConfig, AlbertModel,AlbertForTokenClassification
# from pytorch_pretrained_bert import BertTokenizer
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from crf import CRF
import time

In [None]:
# Root directory of the dataset; load_data joins the split file names onto it.
base = '/root/yy/data/ResumeNER'
# Directory of the pretrained ALBERT checkpoint; Bert_CRF reads
# config.json and pytorch_model.bin from here.
base_path = '../berts/albert_base_zh'
# File names of the train / dev / test splits (relative to `base`).
train_path = 'train.char.bmes'
dev_path = 'dev.char.bmes'
test_path = 'test.char.bmes'

In [None]:
def load_data(base,train_path):
    """Read a two-column tagging file into parallel token/label sentences.

    Each data line is expected to be ``"<token> <label>"`` (single space).
    Any line that does not split into exactly two fields (typically a blank
    line) terminates the current sentence.

    Args:
        base: Directory containing the file.
        train_path: File name relative to ``base``.

    Returns:
        ``(tokens, labels)``: two lists of equal length; ``tokens[k]`` and
        ``labels[k]`` are the token list and label list of sentence ``k``.
        Empty sentences are never emitted.
    """
    full = os.path.join(base,train_path)
    tokens,labels = [],[]
    token,label = [],[]
    with open(full,'r',encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split(' ')
            if len(fields) == 2:
                token.append(fields[0])
                label.append(fields[1])
            elif token:
                # Separator line: close the sentence collected so far.
                # Guarding on `token` avoids emitting empty sentences for
                # consecutive blank lines.
                tokens.append(token)
                labels.append(label)
                token,label = [],[]
    # Flush the last sentence when the file does not end with a blank line;
    # without this the final sentence would be silently dropped.
    if token:
        tokens.append(token)
        labels.append(label)
    return tokens,labels

def trans2id(labels):
    """Build the tag <-> id vocabularies from labelled sentences.

    The special tags ``'[CLS]'`` and ``'[SEP]'`` are always included.
    Tags are sorted before ids are assigned so that the mapping is
    identical on every run (iterating a ``set`` of strings is not stable
    across interpreter sessions, which would make saved models and
    reported results irreproducible).

    Args:
        labels: Iterable of label sequences (one sequence per sentence).

    Returns:
        ``(tag2id, id2tag)``: a dict from tag string to integer id and its
        inverse.
    """
    tag_set = {label for line in labels for label in line}
    tag_set.update(('[CLS]', '[SEP]'))
    ordered_tags = sorted(tag_set)
    tag2id = {tag: idx for idx, tag in enumerate(ordered_tags)}
    id2tag = {idx: tag for idx, tag in enumerate(ordered_tags)}
    return tag2id,id2tag

def gen_features(tokens,labels,tokenizer,tag2id,max_len):
    """Convert token/label sentences into fixed-length model inputs.

    Every sentence is truncated to ``max_len - 2`` tokens, wrapped in
    ``[CLS]`` / ``[SEP]``, converted to ids with ``tokenizer`` and padded
    to ``max_len``.

    Args:
        tokens: List of token lists (one per sentence).
        labels: List of label lists aligned with ``tokens``.
        tokenizer: Object exposing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``.
        tag2id: Mapping from tag string to id; must contain ``'[CLS]'``,
            ``'[SEP]'`` and ``'O'`` (``'O'`` is used as the label padding).
        max_len: Length of every produced sequence.

    Returns:
        ``(input_ids, tags, masks, lengths)``: three lists of ``max_len``
        long integer lists, plus the number of real tokens kept per
        sentence (after truncation, excluding ``[CLS]`` / ``[SEP]``).

    Raises:
        AssertionError: If a produced sequence is not exactly ``max_len``
            long (e.g. the tokenizer split a token into several pieces).
    """
    input_ids,tags,masks,lengths = [],[],[],[]
    budget = max_len - 2  # room left once [CLS] and [SEP] are added
    for token,label in zip(tokens,labels):
        if len(token) >= budget:
            token = token[:budget]
            label = label[:budget]
        # Record the length AFTER truncation. Recording the raw length makes
        # downstream slicing (tmp[1:1 + length]) run into the [SEP] position
        # for truncated sentences, which then gets scored as an entity.
        lengths.append(len(token))

        # NOTE(review): the mask is 0 on [CLS]/[SEP] as well as on padding,
        # while the labels do carry [CLS]/[SEP] tags — confirm this is what
        # the CRF implementation expects.
        mask = [0] + [1] * len(token) + [0]

        text = '[CLS] ' + ' '.join(token) + ' [SEP]'
        tokenized_text = tokenizer.tokenize(text)
        input_id = tokenizer.convert_tokens_to_ids(tokenized_text)
        label = [tag2id['[CLS]']] + [tag2id[tag] for tag in label] + [tag2id['[SEP]']]

        # padding (input ids with 0, labels with 'O', mask with 0)
        if len(input_id) < max_len:
            input_id = input_id + [0] * (max_len - len(input_id))
            label = label + [tag2id['O']] * (max_len - len(label))
            mask = mask + [0] * (max_len - len(mask))

        assert len(input_id) == max_len
        assert len(label) == max_len
        assert len(mask) == max_len

        input_ids.append(input_id)
        tags.append(label)
        masks.append(mask)
    return input_ids,tags,masks,lengths

In [None]:
# Fixed sequence length (including [CLS]/[SEP]) produced by gen_features.
max_len = 128
# Batch size used by every DataLoader below.
bs = 32
# NOTE(review): the tokenizer is loaded from a different directory than the
# ALBERT checkpoint in `base_path` — confirm both use the same vocabulary.
tokenizer = BertTokenizer.from_pretrained('../berts/bert-base-transformers')

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Training split: read sentences, build the tag vocabulary from the training
# labels only, then encode to fixed-length id / tag / mask lists.
train_tokens,train_labels = load_data(base,train_path)
tag2id,id2tag = trans2id(train_labels)
train_ids,train_tags,train_masks,train_lengths = gen_features(train_tokens,train_labels,tokenizer,tag2id,max_len)

# Dev split: encoded with the tag vocabulary built from the training labels,
# so a tag that only occurs in the dev file would raise KeyError in gen_features.
dev_tokens,dev_labels = load_data(base,dev_path)
dev_ids,dev_tags,dev_masks,dev_lengths = gen_features(dev_tokens,dev_labels,tokenizer,tag2id,max_len)

In [None]:
# Convert the encoded training lists to tensors. The names are rebound in
# place, so this cell replaces the Python lists produced by gen_features.
train_ids = torch.tensor(train_ids)
train_tags = torch.tensor(train_tags)
train_masks = torch.tensor(train_masks)
# train_lengths = torch.tensor(train_lengths)

# Same conversion for the dev split; dev_lengths stays a plain list and is
# consumed as such by measure().
dev_ids = torch.tensor(dev_ids)
dev_tags = torch.tensor(dev_tags)
dev_masks = torch.tensor(dev_masks)
# dev_lengths = torch.tensor(dev_lengths)

In [None]:
# Training batches are (input_ids, masks, tags), drawn in random order.
train_data = TensorDataset(train_ids, train_masks, train_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

# Validation batches keep the dataset order so that predictions line up
# with dev_lengths when they are scored.
valid_data = TensorDataset(dev_ids, dev_masks, dev_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [None]:
class Bert_CRF(nn.Module):
    """ALBERT encoder -> BiLSTM -> (optional LayerNorm) -> linear -> CRF tagger.

    Args:
        base_path: Directory containing ``config.json`` and
            ``pytorch_model.bin`` of the pretrained ALBERT model.
        num_labels: Number of tags (size of ``tag2id``). The linear layer
            emits ``num_labels + 2`` scores per token.
            NOTE(review): the two extra scores are presumably the CRF's
            start/stop states — confirm against ``crf.CRF``.
        lstm_hidden_size: Hidden size of each LSTM direction.
        dropout: Probability used to build ``self.dropout``.
            NOTE(review): ``self.dropout`` is created but never applied in
            ``forward``; left as is to keep numerical behaviour unchanged.
        lm_flag: When true, apply LayerNorm to the BiLSTM output.
    """
    def __init__(self,base_path,num_labels,lstm_hidden_size = 128,dropout = 0.3,lm_flag = False):
        super(Bert_CRF,self).__init__()
        bert_config = AlbertConfig.from_json_file(os.path.join(base_path,'config.json'))
        bert_config.num_labels = num_labels
        # hidden_states (tuple(torch.FloatTensor), optional, returned when config.output_hidden_states=True)
        # Option 2 (used in forward) needs the two config flags below.
        bert_config.output_hidden_states=True
        bert_config.output_attentions=True
        self.bert = AlbertModel.from_pretrained(os.path.join(base_path,'pytorch_model.bin'), config=bert_config)
        self.dropout = nn.Dropout(dropout)
        self.lm_flag = lm_flag
        # Bidirectional LSTM: input size is the encoder hidden size, and the
        # output size (2 * lstm_hidden_size) must match the Linear input.
        # The former `dropout=0.3` argument was dropped: with num_layers=1 it
        # has no effect and only triggers a PyTorch warning.
        self.lstm = nn.LSTM(bert_config.hidden_size, lstm_hidden_size,
                            num_layers=1, bidirectional=True, batch_first=True)
        lstm_out_dim = lstm_hidden_size * 2
        # Derived from lstm_hidden_size instead of a hard-coded 256 so that
        # non-default hidden sizes work.
        self.clf = nn.Linear(lstm_out_dim,bert_config.num_labels + 2)
        self.layer_norm = nn.LayerNorm(lstm_out_dim)
        # Only ask the CRF for CUDA when CUDA is actually available, so the
        # model can also be built on CPU-only machines.
        self.crf = CRF(target_size=bert_config.num_labels, average_batch=True,
                       use_cuda=torch.cuda.is_available())

    def forward(self,input_ids,masks):
        """Compute per-token emission scores.

        Args:
            input_ids: Token ids; dimension 0 is read as batch size and
                dimension 1 as sequence length.
            masks: Attention mask passed to the encoder.

        Returns:
            Tensor reshaped to ``(batch_size, seq_length, -1)`` holding the
            linear layer's output for every token.
        """
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)

        outputs = self.bert(input_ids, token_type_ids=None,
                     attention_mask=masks)
        # Option 1 (disabled): use the last layer, i.e. outputs[0].
        # Option 2: use the second-to-last entry of the hidden-states tuple;
        # relies on the output_hidden_states / output_attentions flags set in
        # __init__, which make these the last two elements of `outputs`.
        all_hidden_states, all_attentions = outputs[-2:]
        embeds = all_hidden_states[-2]

        lstm_out,_ = self.lstm(embeds)
        # Flatten to (batch * seq, features) using the actual feature size
        # rather than a hard-coded 128*2.
        lstm_out = lstm_out.contiguous().view(-1, lstm_out.size(-1))
        if self.lm_flag:
            lstm_out = self.layer_norm(lstm_out)
        logits = self.clf(lstm_out)
        logits = logits.contiguous().view(batch_size, seq_length, -1)
        return logits

    def loss(self,logits,mask,tag):
        """Return the CRF negative log-likelihood divided by the batch size.

        NOTE(review): the CRF is built with ``average_batch=True``; if it
        already averages over the batch, this divides twice — confirm
        against ``crf.CRF``.
        """
        loss_value = self.crf.neg_log_likelihood_loss(logits,mask,tag)
        bs = logits.size(0)
        loss_value  /= float(bs)
        return loss_value

In [None]:
# Build the tagger with one label per entry of tag2id and LayerNorm enabled
# on the BiLSTM output (lm_flag=True).
model = Bert_CRF(base_path,num_labels = len(tag2id),lm_flag = True)

In [None]:
# Move all model parameters to the selected device. As the last expression of
# the cell, this also displays the module's repr.
model.to(device)

In [None]:
# A single AdamW optimizer over every model parameter (encoder, LSTM, linear
# and CRF all share the same learning rate).
optimizer = AdamW(model.parameters(),
                  lr = 5e-5, # default is 5e-5
                  eps = 1e-8 # default is 1e-8
                )

In [None]:
# Number of passes over the training set.
epochs = 20
# One scheduler step is taken per training batch, so the schedule spans
# batches-per-epoch * epochs steps.
total_steps = len(train_dataloader) * epochs
# Linear schedule with no warm-up steps.
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                           num_warmup_steps = 0,
                                           num_training_steps = total_steps)

In [None]:
def trans2label(id2tag,data,lengths):
    """Map id sequences back to tag strings and strip the special positions.

    For row ``k`` the whole id sequence is translated through ``id2tag``;
    the first position is then skipped and at most ``lengths[k]`` tags
    following it are kept.

    Args:
        id2tag: Mapping from integer id to tag string.
        data: Sequence of id sequences.
        lengths: ``lengths[k]`` is the number of tags to keep for row ``k``.

    Returns:
        List of tag-string lists, one per row of ``data``.
    """
    decoded = []
    for row_idx, id_row in enumerate(data):
        tag_names = list(map(id2tag.__getitem__, id_row))
        keep_until = lengths[row_idx] + 1
        decoded.append(tag_names[1:keep_until])
    return decoded

def get_entities(tags):
    """Extract entity spans from a tag sequence.

    A tag is either ``'O'`` (outside) or carries an entity type; for tags
    of the form ``'<prefix>-<type>'`` the type is the part after the first
    dash, otherwise the whole tag is the type. An entity is a maximal run
    of consecutive tokens with the same type. A ``'B'`` or ``'S'`` prefix
    always opens a new entity, so two adjacent entities of the same type
    are kept apart.

    Compared with the previous implementation this also reports
    single-token entities and entities that end on the last token, which
    used to be dropped (and could corrupt the state for later entities).

    Args:
        tags: Sequence of tag strings.

    Returns:
        List of ``(start, end)`` index pairs (both inclusive), in order of
        appearance.
    """
    entities = []
    start = -1          # index where the open entity began, -1 if none
    prev_type = 'O'     # type of the previous token
    for i, tag in enumerate(tags):
        if '-' in tag:
            pieces = tag.split('-')
            prefix, etype = pieces[0], pieces[1]
        else:
            prefix, etype = '', tag

        if etype == 'O':
            # Leaving an entity: close it at the previous token.
            if start >= 0:
                entities.append((start, i - 1))
                start = -1
            prev_type = 'O'
            continue

        if start < 0:
            start = i
        elif etype != prev_type or prefix in ('B', 'S'):
            # Type change or explicit begin marker: close and reopen.
            entities.append((start, i - 1))
            start = i
        prev_type = etype

    # Flush an entity that runs to the end of the sequence.
    if start >= 0:
        entities.append((start, len(tags) - 1))
    return entities

def measure(preds,trues,lengths,id2tag):
    """Compute entity-level F1, precision and recall.

    Predicted and gold id sequences are converted to tag strings with
    ``trans2label``; entity spans are extracted per sentence with
    ``get_entities`` and compared as sets.

    Args:
        preds: Predicted id sequences.
        trues: Gold id sequences.
        lengths: Per-sentence lengths forwarded to ``trans2label``.
        id2tag: Mapping from id to tag string.

    Returns:
        ``(f1, precision, recall)``; each value is 0 when its denominator
        is zero.
    """
    pred_tags = trans2label(id2tag,preds,lengths)
    true_tags = trans2label(id2tag,trues,lengths)
    assert len(pred_tags) == len(true_tags)

    n_hit, n_pred, n_true = 0, 0, 0
    for sent_pred, sent_true in zip(pred_tags, true_tags):
        pred_spans = set(get_entities(sent_pred))
        true_spans = set(get_entities(sent_true))
        n_hit += len(pred_spans & true_spans)
        n_pred += len(pred_spans)
        n_true += len(true_spans)

    if n_pred:
        precision = n_hit / n_pred
    else:
        precision = 0
    if n_true:
        recall = n_hit / n_true
    else:
        recall = 0
    denom = precision + recall
    if denom:
        f1 = 2 * precision * recall / denom
    else:
        f1 = 0
    return f1, precision, recall

In [None]:
# Gradient-norm clipping threshold.
max_grad_norm = 1.0

start = time.time()
for epoch in range(epochs):
    # ---------------- training ----------------
    model.train()
    # Running totals are reset every epoch so the printed values really are
    # per-epoch averages (they used to accumulate across all epochs).
    tra_loss,train_steps = 0.0,0
    for step,batch in enumerate(train_dataloader):
        input_ids,masks,labels = (t.to(device) for t in batch)

        # Clear gradients from the previous step; without this they keep
        # accumulating and every update uses the sum of all past gradients.
        optimizer.zero_grad()

        outputs = model(input_ids,masks)
        loss = model.loss(outputs,masks,labels)
        loss.backward()

        tra_loss += loss.item()
        train_steps += 1

        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # The optimizer must step before the scheduler; the reverse order
        # skips the first value of the learning-rate schedule.
        optimizer.step()
        scheduler.step()

        if step % 30 == 0:
            print("epoch :{},step :{} ,Train loss: {}".format(epoch,step,tra_loss/train_steps))

    print("Training Loss of epoch {}:{}".format(epoch,tra_loss / max(train_steps,1)))

    # ---------------- validation ----------------
    model.eval()
    dev_loss,dev_steps = 0.0,0
    predictions , true_labels = [], []

    for step,batch in enumerate(valid_dataloader):
        input_ids,masks,labels = (t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(input_ids,masks)
            loss = model.loss(logits,masks,labels)
            # NOTE(review): decoding masks on non-zero input ids (which
            # includes [CLS]/[SEP]) while the loss uses `masks` (which
            # excludes them) — confirm this asymmetry is intended.
            path_score, best_path = model.crf(logits, input_ids.bool())

            dev_loss += loss.item()
            dev_steps += 1

            if step % 10 == 0:
                print("epoch :{},step :{} ,Dev loss: {}".format(epoch,step,dev_loss/dev_steps))

        best_path = best_path.detach().cpu().numpy().tolist()
        predictions.extend(best_path)
        true_labels.extend(labels.to('cpu').numpy().tolist())
    f1, precision, recall = measure(predictions,true_labels,dev_lengths,id2tag)
    # The first value is entity-level precision, not accuracy.
    print('epoch {} : Precision : {},Recall : {},F1 :{}'.format(epoch,precision,recall,f1))
end = time.time()
print('Training Time:',end - start)

In [None]:
# Test split: read sentences and encode them with the tag vocabulary that was
# built from the training labels.
test_tokens,test_labels = load_data(base,test_path)
test_ids,test_tags,test_masks,test_lengths = gen_features(test_tokens,test_labels,tokenizer,tag2id,max_len)

In [None]:
# Convert the encoded test lists to tensors (names are rebound in place).
test_ids = torch.tensor(test_ids)
test_tags = torch.tensor(test_tags)
test_masks = torch.tensor(test_masks)

In [None]:
# Test batches are (input_ids, masks, tags). No sampler is passed, so the
# DataLoader iterates in dataset order and predictions stay aligned with
# test_lengths / test_tokens.
test_data = TensorDataset(test_ids, test_masks, test_tags)
# test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, batch_size=bs)

In [None]:
def measure_(preds,trues,lengths,id2tag,test_tokens):
    """Verbose variant of the entity-level scorer.

    Works like ``measure`` but additionally prints, for every sentence,
    the predicted spans, the gold spans and the sentence's tokens.

    Args:
        preds: Predicted id sequences.
        trues: Gold id sequences.
        lengths: Per-sentence lengths forwarded to ``trans2label``.
        id2tag: Mapping from id to tag string.
        test_tokens: Token lists, indexed in the same order as ``preds``.

    Returns:
        ``(f1, precision, recall)``; each value is 0 when its denominator
        is zero.
    """
    pred_tags = trans2label(id2tag,preds,lengths)
    true_tags = trans2label(id2tag,trues,lengths)
    assert len(pred_tags) == len(true_tags)

    n_hit, n_pred, n_true = 0, 0, 0
    for idx in range(len(pred_tags)):
        pred_en = get_entities(pred_tags[idx])
        true_en = get_entities(true_tags[idx])
        print('pred_en',pred_en)
        print('true_en',true_en)
        print(test_tokens[idx])
        print('***********************')
        pred_spans, true_spans = set(pred_en), set(true_en)
        n_hit += len(pred_spans & true_spans)
        n_pred += len(pred_spans)
        n_true += len(true_spans)

    precision = n_hit / n_pred if n_pred else 0
    recall = n_hit / n_true if n_true else 0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0
    return f1, precision, recall

In [None]:
# Evaluate the trained model on the test split.
model.eval()
test_pre,test_true = [],[]
for batch in test_dataloader:
    input_ids,masks,labels = (t.to(device) for t in batch)
    with torch.no_grad():
        logits = model(input_ids,masks)
        # Decode with a mask that is true wherever the input id is non-zero
        # (same mask as used for validation decoding in the training loop).
        path_score, test_best_path = model.crf(logits, input_ids.bool())

    test_best_path = test_best_path.detach().cpu().numpy().tolist()
    test_pre.extend(test_best_path)
    test_true.extend(labels.to('cpu').numpy().tolist())

# measure_ prints the predicted/gold spans and tokens of every sentence.
test_f1, test_precision, test_recall = measure_(test_pre,test_true,test_lengths,id2tag,test_tokens)
# The first value is entity-level precision, not accuracy.
print('Test Precision : {},Recall : {},F1 :{}'.format(test_precision,test_recall,test_f1))

In [None]:
# list(model.bert.named_parameters())
# list(model.lstm.named_parameters())