In [None]:
from collections import defaultdict
from operator import itemgetter
from tqdm import tqdm
import numpy as np
import random
import torch 
import jieba
import json
import os

import pickle as pk

# Select the training device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.set_device(0)
else:
    device = torch.device('cpu')
# Default hyper-parameters and runtime settings. The dictionary is later extended
# with the vocabulary, label maps and data set, and is cached on disk.
parameter_copy = {
    # Embedding width (768); read as the input size of the classification layer.
    'd_model':768, 
    # Hidden size of a recurrent layer (300) according to the original note.
    # NOTE(review): not read by any code visible in this notebook section - confirm it is still needed.
    'hid_dim':300,
    # Number of passes over the training set (100).
    'epoch':100,
    # Number of examples per training batch (10).
    'batch_size':10,
    # NOTE(review): the original comment described a maximum sequence length of 100,
    # which matches neither this key nor its value; presumably this is the number of
    # stacked layers - confirm. Not read by any code visible here.
    'n_layers':2,
    # Dropout probability, intended to limit over-fitting.
    'dropout':0.1,
    # Device (CPU or GPU) chosen above.
    'device':device,
    # Learning rate.
    # NOTE(review): the optimizer built in the training cell hard-codes lr=3e-5 and does not read this value.
    'lr':0.001,
    # Optimizer momentum, mainly relevant to stochastic gradient descent.
    'momentum':0.99,
}

def build_dataSet(parameter, data_dir='data'):
    """Read the train/dev files and add vocabulary, label maps and data to ``parameter``.

    Each of ``<data_dir>/train.json`` and ``<data_dir>/dev.json`` is read line by
    line; every non-blank line must be a JSON object with a ``'text'`` string and
    a ``'label'`` mapping of the form
    ``{entity_type: {<any key>: [[start, end], ...]}}`` where ``start`` and
    ``end`` are inclusive character indices into ``text`` (the inner keys are
    ignored).

    Every character receives one tag: ``'O'`` by default, ``'S-<type>'`` for a
    one-character span, ``'B-<type>'``/``'E-<type>'`` for the first/last
    character of a longer span and ``'I-<type>'`` for the characters in between.

    Args:
        parameter: dict that is updated in place with the keys ``key2ind``,
            ``ind2key``, ``word2ind``, ``ind2word``, ``data_set``,
            ``output_size`` and ``word_size``.
        data_dir: directory containing ``train.json`` and ``dev.json``.

    Returns:
        The same ``parameter`` dict. ``data_set`` maps ``'train'``/``'dev'`` to
        lists of ``[characters, tags]`` pairs; indices in the maps follow
        first-seen order.
    """
    data_name = ['train','dev']
    data_set = {}
    key_table = defaultdict(int)
    vocab_table = defaultdict(int)
    # Inserted first so that '<PAD>' and '<UNK>' get indices 0 and 1 in word2ind.
    vocab_table['<PAD>'] = 0
    vocab_table['<UNK>'] = 0
    for split in data_name:
        data_set[split] = []
        # The context manager closes the file even if a line fails to parse.
        with open(os.path.join(data_dir, split + '.json'), 'r', encoding='utf-8') as src:
            for line in src:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                record = json.loads(line)
                text = list(record['text'])
                label_new = ['O'] * len(text)
                # Reading the defaultdict registers 'O' (count 0) so it gets a label index.
                key_table['O']
                for keys, mentions in record['label'].items():
                    for id_list in mentions.values():
                        for ind in id_list:
                            start, end = ind[0], ind[1]
                            span = end - start
                            if span < 0:
                                # Inverted span: nothing to tag and nothing to count.
                                continue
                            if span == 0:
                                keys_list = ['S-'+keys]
                                label_new[start] = keys_list[0]
                            elif span == 1:
                                keys_list = ['B-'+keys,'E-'+keys]
                                label_new[start] = keys_list[0]
                                label_new[end] = keys_list[1]
                            else:
                                keys_list = ['B-'+keys,'I-'+keys,'E-'+keys]
                                label_new[start] = keys_list[0]
                                label_new[start+1:end] = [keys_list[1]] * (span - 1)
                                label_new[end] = keys_list[2]
                            for key in keys_list:
                                key_table[key] += 1
                for char in text:
                    vocab_table[char] += 1
                data_set[split].append([text, label_new])
    key2ind = {key: idx for idx, key in enumerate(key_table)}
    ind2key = {idx: key for key, idx in key2ind.items()}
    word2ind = {word: idx for idx, word in enumerate(vocab_table)}
    ind2word = {idx: word for word, idx in word2ind.items()}
    parameter['key2ind'] = key2ind
    parameter['ind2key'] = ind2key
    parameter['word2ind'] = word2ind
    parameter['ind2word'] = ind2word
    parameter['data_set'] = data_set
    parameter['output_size'] = len(key2ind)
    parameter['word_size'] = len(word2ind)
    return parameter


def batch_yield_bert(parameter,shuffle = True,isTrain = True):
    """Generate padded mini-batches of token ids and label ids.

    Relies on the module-level ``tokenizer`` (its ``convert_tokens_to_ids``
    method) and on ``list2torch``; both must be defined before the generator
    is first advanced.

    Args:
        parameter: dict providing ``data_set`` (``'train'``/``'dev'`` lists of
            ``[tokens, tags]`` pairs), ``key2ind``, ``batch_size`` and ``epoch``.
        shuffle: when truthy, the selected data list is shuffled IN PLACE at
            the start of every epoch.
        isTrain: use the ``'train'`` split for ``parameter['epoch']`` epochs
            when true, otherwise a single pass over ``'dev'``.

    Yields:
        ``(inputs, targets, epoch, finished)`` tuples. Within a batch, inputs
        are right-padded with 0 and targets with -1 up to the longest sequence
        of that batch. ``epoch`` is ``None`` except on the last batch of an
        epoch, where it is the zero-based epoch index. After the last epoch a
        single ``(None, None, None, True)`` sentinel is yielded.
    """
    data_set = parameter['data_set']['train'] if isTrain else parameter['data_set']['dev']
    n_epochs = parameter['epoch'] if isTrain else 1
    batch_size = parameter['batch_size']
    key2ind = parameter['key2ind']

    def _collate(token_rows, label_rows):
        # Pad every row to the longest row of the batch; targets use -1 so the
        # padded positions can be excluded from the loss.
        width = max(len(row) for row in token_rows)
        padded_tokens = [row + [0] * (width - len(row)) for row in token_rows]
        padded_labels = [row + [-1] * (width - len(row)) for row in label_rows]
        return list2torch(padded_tokens), list2torch(padded_labels)

    for epoch in range(n_epochs):
        # Re-shuffle the raw data at the start of every epoch.
        if shuffle:
            random.shuffle(data_set)
        token_rows, label_rows = [], []
        # The most recent complete batch is held back by one step so that the
        # epoch marker is always attached to a non-empty batch, even when the
        # data size is an exact multiple of the batch size.
        held_back = None
        for items in tqdm(data_set):
            token_rows.append(list(tokenizer.convert_tokens_to_ids(items[0])))
            # A plain lookup per tag keeps targets the same length as the
            # inputs for sequences of any length, including 0 and 1.
            label_rows.append([key2ind[tag] for tag in items[1]])
            if len(token_rows) >= batch_size:
                if held_back is not None:
                    yield held_back[0], held_back[1], None, False
                held_back = _collate(token_rows, label_rows)
                token_rows, label_rows = [], []
        if token_rows:
            if held_back is not None:
                yield held_back[0], held_back[1], None, False
            held_back = _collate(token_rows, label_rows)
        if held_back is not None:
            yield held_back[0], held_back[1], epoch, False
    yield None,None,None,True
            

def list2torch(ins):
    """Turn a (nested) list of numbers into a torch tensor by way of a NumPy array."""
    as_array = np.array(ins)
    return torch.from_numpy(as_array)

# The vocabulary, label maps and parsed data set are prepared once and cached in
# parameter.pkl so that they are not rebuilt on every run.
if not os.path.exists('parameter.pkl'):
    parameter = parameter_copy
    # Build the dictionaries and the corresponding data set.
    parameter = build_dataSet(parameter)
    with open('parameter.pkl','wb') as cache_file:
        pk.dump(parameter,cache_file)
    del cache_file
else:
    # Load the already-processed parameter dict.
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only load a parameter.pkl that was produced locally by this notebook.
    with open('parameter.pkl','rb') as cache_file:
        parameter = pk.load(cache_file)
    # Training settings may have been edited since the cache was written, so the
    # values declared in parameter_copy always override the cached ones.
    parameter.update(parameter_copy)
    for name in parameter_copy:
        print(name,':',parameter[name])
    with open('parameter.pkl','wb') as cache_file:
        pk.dump(parameter,cache_file)
    del parameter_copy,name,cache_file

In [None]:
from transformers import WEIGHTS_NAME, BertConfig,get_linear_schedule_with_warmup,AdamW, BertTokenizer
from transformers import BertModel,BertPreTrainedModel
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch

import torch.nn.functional as F # pytorch 激活函数的类
from torch import nn,optim # 构建模型和优化器
from torchcrf import CRF

class bert(BertPreTrainedModel):
    """BERT encoder followed by a linear layer applied to every token position.

    ``parameter`` must provide ``'d_model'`` (input size of the linear layer)
    and ``'output_size'`` (number of label classes). A CRF layer is also
    created, but ``forward`` does not use it.
    """

    def __init__(self, config,parameter):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.fc = nn.Linear(parameter['d_model'], parameter['output_size'])
        self.init_weights()

        # Created after init_weights(); kept as an attribute although forward()
        # never calls it.
        self.crf = CRF(parameter['output_size'],batch_first=True)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,labels=None):
        """Return the linear layer's output for the encoder's first output.

        ``labels`` is accepted but not used.
        """
        encoded = self.bert(
            input_ids = input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # encoded[0] is presumably the per-token hidden states - confirm against
        # the installed transformers version.
        return self.fc(encoded[0])
    
# Classes used to load the configuration, the model and the tokenizer from the
# local "prev_trained_model" directory.
config_class = BertConfig
model_class = bert
tokenizer_class = BertTokenizer

config = config_class.from_pretrained("prev_trained_model")
tokenizer = tokenizer_class.from_pretrained("prev_trained_model")
# `parameter` is forwarded as a keyword argument to bert.__init__.
model = model_class.from_pretrained(
    "prev_trained_model",
    config=config,
    parameter = parameter,
)

In [None]:
import os
import shutil
import pickle as pk
from torch.utils.tensorboard import SummaryWriter


# Build the model from the pretrained checkpoint. Calling bert(config, parameter)
# directly would create randomly initialised weights and discard the checkpoint.
model = bert.from_pretrained("prev_trained_model", config=config, parameter=parameter)
model = model.to(parameter['device'])

# Switch to training mode.
model.train()

# Optimizer and loss. Target positions equal to -1 (padding) are ignored by the loss.
# optimizer = torch.optim.SGD(model.parameters(),lr=3*10**-5, momentum=0.95, nesterov=True)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, betas=(0.9, 0.999), eps=1e-6,
                             weight_decay=0)
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# Learning-rate schedule: multiply the rate by 0.9 every 10 scheduler steps
# (the scheduler is stepped once per epoch below).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

# Optional: freeze the encoder and train only the classification head.
# for param in model.bert.parameters():
#     param.requires_grad = False

# Batch generator. The tokenizer is read by the generator as a global; it must
# not be passed positionally because the second parameter is `shuffle`.
train_yield = batch_yield_bert(parameter)

# Training
loss_cal = []           # batch losses of the current epoch
min_loss = float('inf')
logging_steps = 0
for inputs, targets, epoch, finished in train_yield:
    if finished:
        break
    if inputs.numel() > 0:
        inputs = inputs.long().to(parameter['device'])
        targets = targets.view(-1).long().to(parameter['device'])
        # Inputs are padded with id 0 by the batch generator; mask those
        # positions so real tokens do not attend to padding.
        out = model(inputs, attention_mask=(inputs != 0).long())
        # Flatten (batch, seq, classes) -> (batch*seq, classes) to match the
        # flattened targets expected by CrossEntropyLoss.
        loss = criterion(out.view(-1, out.size(-1)), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_cal.append(loss.item())
        logging_steps += 1
        if logging_steps%100 == 0:
            print(sum(loss_cal)/len(loss_cal))
    if epoch is not None:
        # Last batch of an epoch: report the mean loss and keep the best weights.
        if loss_cal:
            epoch_loss = sum(loss_cal)/len(loss_cal)
            if epoch_loss < min_loss:
                min_loss = epoch_loss
                torch.save(model.state_dict(), 'bert.h5')
            print('epoch [{}/{}], Loss: {:.4f}'.format(epoch+1,
                                                       parameter['epoch'],epoch_loss))
        # Start the next epoch with an empty list so no batch is counted twice.
        loss_cal = []
        scheduler.step()
