# 数据准备+预处理

In [1]:
import torch
import pickle as pk
import numpy as np
from tqdm import tqdm
import random
import gensim
import os

if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.set_device(0)
else:
    device = torch.device('cpu')

[x_intent,y_intent,x_ner,y_ner] = pk.load(open('data/data-intent1-ner.pkl','rb'))
train_data = [(x_ner[i],y_ner[i]) for i in range(len(x_ner))]
vocab = set([j for i in x_ner for j in i])
vocab = dict(zip(vocab,range(1,len(vocab)+1)))
vocab['<PAD>'] = 0
vocab['<UNK>'] = len(vocab)
tag2label = ['O']+list(set([j for i in y_ner for j in i]) - set('O'))
tag2label = dict(zip(tag2label,range(len(tag2label))))

# 准备好模型的参数
parameter = {
    'batch_size':32,
    'epoch':10,
    'hid_dim':300,
    'dropout':0.5,
    'lr':0.001,
    'tag2label':tag2label,
    'num_tags':len(tag2label),
    'd_model':768,
    'shuffle':True,
    'vocab':None,
    'model_path':None,
    'n_layers':2,
    'device':device,
}
out_path = 'model/'
os.mkdir(out_path) if not os.path.exists(out_path) else 1
model_path = os.path.join(out_path, "ner/")
os.mkdir(model_path) if not os.path.exists(model_path) else 1
parameter['vocab'] = vocab
parameter['vocab_size'] = len(vocab)
parameter['model_path'] = model_path

def batch_yield(data,parameter,shuffle = True):
    def list2torch(ins):
        return torch.from_numpy(np.array(ins))
    def seq2id(seq, vocab):
        sentence_id = []
        for word in seq:
            if word not in vocab:
                word = '<UNK>'
            sentence_id.append(vocab[word])
        return sentence_id
    # 构建一个迭代器，获取相应的seqs（index型）和label，按照batch_size提取
    if shuffle:
        random.shuffle(data)
    seqs,labels = [],[]
    for (seq,label) in tqdm(data):
        seq = seq2id(seq,parameter['vocab'])
        label = [parameter['tag2label'][label_] for label_ in label]
        if len(seqs) == parameter['batch_size']:
            seq_len_list = [len(i) for i in seqs]
            max_len = max(seq_len_list)
            seqs = [i+[0]*(max_len-len(i)) for i in seqs]
            labels = [i+[0]*(max_len-len(i)) for i in labels]
            yield list2torch(seqs),list2torch(labels),False
            seqs,labels = [],[]
        seqs.append(seq)
        labels.append(label)
    if len(seqs) != 0:
        seq_len_list = [len(i) for i in seqs]
        max_len = max(seq_len_list)
        seqs = [i+[0]*(max_len-len(i)) for i in seqs]
        labels = [i+[0]*(max_len-len(i)) for i in labels]
        yield list2torch(seqs), list2torch(labels),True

# pk.dump(parameter,open(parameter['model_path']+'/parameter.pkl','wb'))
            
data = batch_yield(train_data,parameter)
seqs, labels,keys = next(data)
seqs.shape,labels.shape,seqs[:2],labels[:2]


  0%|                                                                                        | 0/42534 [00:00<?, ?it/s]

(torch.Size([32, 23]),
 torch.Size([32, 23]),
 tensor([[   6,  389, 1073, 2176,  564,  746, 2065,  866,  516,  518, 2091,  734,
           119, 1829,    0,    0,    0,    0,    0,    0,    0,    0,    0],
         [   6,  389,  990,  777, 1240, 1822, 1694,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0]],
        dtype=torch.int32),
 tensor([[ 0,  0, 11,  7,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0],
         [ 0,  0, 17, 16,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0]], dtype=torch.int32))

In [3]:
parameter['device']

device(type='cpu')

In [4]:
parameter['tag2label']

{'O': 0,
 'S-author': 1,
 'E-tag': 2,
 'S-kg': 3,
 'E-km1': 4,
 'I-kg': 5,
 'B-author': 6,
 'I-km1': 7,
 'I-class': 8,
 'B-km1': 9,
 'B-tag': 10,
 'B-km2': 11,
 'E-kg': 12,
 'B-kg': 13,
 'E-km2': 14,
 'I-km2': 15,
 'S-title': 16,
 'B-class': 17,
 'S-class': 18,
 'E-title': 19,
 'I-tag': 20,
 'E-author': 21,
 'I-author': 22,
 'B-title': 23,
 'I-title': 24,
 'E-class': 25}

# 模型构建及模型训练

In [None]:
import torch.nn.functional as F # pytorch 激活函数的类
from torch import nn,optim # 构建模型和优化器
from torchcrf import CRF


# 构建基于bilstm+crf实现ner
class bilstm_crf(nn.Module):
    def __init__(self, parameter):
        super(bilstm_crf, self).__init__()
        vocab_size = parameter['vocab_size']
        embedding_dim = parameter['d_model']
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        hidden_size = parameter['hid_dim']
        num_layers = parameter['n_layers']
        dropout = parameter['dropout']
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, bidirectional=True, batch_first=True, dropout=dropout)

        output_size = parameter['num_tags']
        self.fc = nn.Linear(hidden_size*2, output_size)
        
        self.crf = CRF(output_size,batch_first=True)
        
    def forward(self, x):
        out = self.embedding(x)
        out,(h, c)= self.lstm(out)
        out = self.fc(out)
        return out
    
import os
import shutil
import pickle as pk
from torch.utils.tensorboard import SummaryWriter


# 构建模型
model = bilstm_crf(parameter).to(parameter['device'])

# 确定训练模式
model.train()

# 确定优化器和损失
optimizer = torch.optim.SGD(model.parameters(),lr=0.00005, momentum=0.95, nesterov=True)
# optimizer = torch.optim.Adam(model.parameters(),lr = parameter['lr'], \
#                              weight_decay = 0.01)

# 准备学习率策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)


# 开始训练
loss_cal = []
min_loss = float('inf')
for epoch in range(parameter['epoch']):
    # 迭代器重置
    train_yield = batch_yield(train_data,parameter)
    while 1:
        inputs,targets,keys = next(train_yield)
        out = model(inputs.long().to(parameter['device']))
        loss = -model.crf(out,targets.long().to(parameter['device']))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_cal.append(loss.item())
        if keys:
            break
    loss_cal = sum(loss_cal)/len(loss_cal)
    if loss_cal < min_loss:
        min_loss = loss_cal
        torch.save(model.state_dict(), parameter['model_path']+'/bilstm_crf.h5')
        print('epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, \
                                                       parameter['epoch'],loss_cal))
    loss_cal = [loss.item()]
    scheduler.step()


  score = torch.where(mask[i].unsqueeze(1), next_score, score)

  0%|                                                                               | 33/42534 [00:00<10:15, 69.08it/s][A
  0%|                                                                               | 65/42534 [00:00<09:51, 71.75it/s][A
  0%|▏                                                                              | 97/42534 [00:01<09:47, 72.26it/s][A
  0%|▏                                                                             | 129/42534 [00:01<09:32, 74.09it/s][A
  0%|▎                                                                             | 161/42534 [00:02<09:20, 75.54it/s][A
  0%|▎                                                                             | 193/42534 [00:02<09:04, 77.69it/s][A
  1%|▍                                                                             | 225/42534 [00:02<09:08, 77.14it/s][A
  1%|▍                                                                    

  5%|███▋                                                                         | 2017/42534 [00:29<09:33, 70.71it/s][A
  5%|███▋                                                                         | 2049/42534 [00:30<09:14, 72.96it/s][A
  5%|███▊                                                                         | 2081/42534 [00:30<08:54, 75.68it/s][A
  5%|███▊                                                                         | 2113/42534 [00:30<08:46, 76.84it/s][A
  5%|███▉                                                                         | 2145/42534 [00:31<09:24, 71.55it/s][A
  5%|███▉                                                                         | 2177/42534 [00:31<09:23, 71.62it/s][A
  5%|███▉                                                                         | 2209/42534 [00:32<09:02, 74.36it/s][A
  5%|████                                                                         | 2241/42534 [00:32<08:44, 76.84it/s][A
  5%|████       

 10%|███████▍                                                                     | 4129/42534 [00:55<07:46, 82.40it/s][A
 10%|███████▌                                                                     | 4161/42534 [00:56<07:59, 79.95it/s][A
 10%|███████▌                                                                     | 4193/42534 [00:56<07:43, 82.76it/s][A
 10%|███████▋                                                                     | 4225/42534 [00:56<07:40, 83.13it/s][A
 10%|███████▋                                                                     | 4257/42534 [00:57<07:33, 84.37it/s][A
 10%|███████▊                                                                     | 4289/42534 [00:57<07:54, 80.63it/s][A
 10%|███████▊                                                                     | 4321/42534 [00:57<07:53, 80.66it/s][A
 10%|███████▉                                                                     | 4353/42534 [00:58<07:35, 83.88it/s][A
 10%|███████▉   

 15%|███████████▎                                                                 | 6241/42534 [01:22<07:49, 77.34it/s][A
 15%|███████████▎                                                                 | 6273/42534 [01:23<08:10, 74.00it/s][A
 15%|███████████▍                                                                 | 6305/42534 [01:23<08:11, 73.66it/s][A
 15%|███████████▍                                                                 | 6337/42534 [01:23<07:34, 79.63it/s][A
 15%|███████████▌                                                                 | 6369/42534 [01:24<07:29, 80.46it/s][A
 15%|███████████▌                                                                 | 6401/42534 [01:24<07:32, 79.90it/s][A
 15%|███████████▋                                                                 | 6433/42534 [01:25<07:38, 78.82it/s][A
 15%|███████████▋                                                                 | 6465/42534 [01:25<07:53, 76.17it/s][A
 15%|███████████

 20%|███████████████                                                              | 8353/42534 [01:54<08:37, 66.10it/s][A
 20%|███████████████▏                                                             | 8385/42534 [01:55<08:29, 66.98it/s][A
 20%|███████████████▏                                                             | 8417/42534 [01:55<08:25, 67.45it/s][A
 20%|███████████████▎                                                             | 8449/42534 [01:55<08:26, 67.27it/s][A
 20%|███████████████▎                                                             | 8481/42534 [01:56<08:30, 66.73it/s][A
 20%|███████████████▍                                                             | 8513/42534 [01:56<08:40, 65.37it/s][A
 20%|███████████████▍                                                             | 8545/42534 [01:57<08:02, 70.39it/s][A
 20%|███████████████▌                                                             | 8577/42534 [01:57<07:58, 70.98it/s][A
 20%|███████████

 25%|██████████████████▋                                                         | 10465/42534 [02:27<06:53, 77.62it/s][A
 25%|██████████████████▊                                                         | 10497/42534 [02:28<07:11, 74.32it/s][A
 25%|██████████████████▊                                                         | 10529/42534 [02:28<07:22, 72.29it/s][A
 25%|██████████████████▊                                                         | 10561/42534 [02:29<07:07, 74.76it/s][A
 25%|██████████████████▉                                                         | 10593/42534 [02:29<07:12, 73.78it/s][A
 25%|██████████████████▉                                                         | 10625/42534 [02:30<06:45, 78.78it/s][A
 25%|███████████████████                                                         | 10657/42534 [02:30<07:03, 75.20it/s][A
 25%|███████████████████                                                         | 10689/42534 [02:30<06:49, 77.81it/s][A
 25%|███████████