In [1]:
# basic
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from tqdm import tqdm
import sklearn

# np/pd
import numpy as np
import pandas as pd

# torch
import torch
import torchtext
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# transformer
from datasets import load_dataset

# CRF
from torchcrf import CRF

In [2]:
torch.cuda.is_available()

True

In [3]:
Config = {
    'num_tags':9,
    'embedding_dim':200,
    'vocab_size':23623,
    'hidden_dim':100,
    'batch_size':32
}

## 探查conll2003数据

定义dataset，dataLoader

In [4]:
# data_udpos = torchtext.datasets.UDPOS(root='./torchtext_datasets_udpos/', split=('train','valid','test'))

dataset_conll2003 = load_dataset("conll2003")

Reusing dataset conll2003 (/root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

In [87]:
dataset_conll2003['test'][0]

{'id': '0',
 'tokens': ['SOCCER',
  '-',
  'JAPAN',
  'GET',
  'LUCKY',
  'WIN',
  ',',
  'CHINA',
  'IN',
  'SURPRISE',
  'DEFEAT',
  '.'],
 'pos_tags': [21, 8, 22, 37, 22, 22, 6, 22, 15, 12, 21, 7],
 'chunk_tags': [11, 0, 11, 21, 11, 12, 0, 11, 13, 11, 12, 0],
 'ner_tags': [0, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0, 0]}

In [5]:
ner_tag2id = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}

ner_id2tag = {}
for key in ner_tag2id.keys():
    ner_id2tag[ner_tag2id[key]] = key

In [6]:
ner_id2tag

{0: 'O',
 1: 'B-PER',
 2: 'I-PER',
 3: 'B-ORG',
 4: 'I-ORG',
 5: 'B-LOC',
 6: 'I-LOC',
 7: 'B-MISC',
 8: 'I-MISC'}

In [7]:
def ner_id_to_tags(ner_id_seq):
    res1, res2 = [], []
    for ner_id in ner_id_seq:
        res1.append(ner_id2tag.get(ner_id, ''))
        res2.append(ner_id2tag.get(ner_id, '-').split('-')[-1])
    return res1, res2

ner_id_to_tags(dataset_conll2003['train'][0]['ner_tags'])

(['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'],
 ['ORG', 'O', 'MISC', 'O', 'O', 'O', 'MISC', 'O', 'O'])

In [8]:
m = 0
for k in dataset_conll2003.keys():
    for i in dataset_conll2003[k]['tokens']:
        m = max(m, len(i))
print('max length = {}'.format(m))

max length = 124


In [9]:
word_to_ix = {}
for k in dataset_conll2003.keys():
    for tokens in dataset_conll2003[k]['tokens']:
        for token in tokens:
            if token not in word_to_ix:
                word_to_ix[token] = len(word_to_ix)
                
Config['vocab_size'] = len(word_to_ix)

print(Config['vocab_size'])

30289


In [10]:
def func_word2ix(word_to_ix, token_list):
    res = list()
    for token in token_list:
        res.append(word_to_ix.get(token, len(word_to_ix)+1))
    return {'token_ids':res}

In [11]:
dataset_conll2003_train=dataset_conll2003['train'].map(lambda x: func_word2ix(word_to_ix, x['tokens']))
dataset_conll2003_train.set_format(type="torch", columns=['token_ids','ner_tags'])

dataset_conll2003_test=dataset_conll2003['test'].map(lambda x: func_word2ix(word_to_ix, x['tokens']))
dataset_conll2003_test.set_format(type="torch", columns=['token_ids','ner_tags'])

dataset_conll2003_val=dataset_conll2003['validation'].map(lambda x: func_word2ix(word_to_ix, x['tokens']))
dataset_conll2003_val.set_format(type="torch", columns=['token_ids','ner_tags'])

Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-e993f22e7c8f1977.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-9da2d1c0e9528511.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-586edb6c12fcdba7.arrow


In [89]:
class MyDataset(Dataset):

    def __init__(self, data):
        self.data = data
        self.token_ids = self.data['token_ids'] # 在这变成torch.tensor，但长度不同
        self.ner_tags = self.data['ner_tags']
        self.tokens = self.data['tokens']
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
#         curr = dict()
#         curr['token_ids'] = self.token_ids
#         curr['ner_tags'] = self.ner_tags
        return self.token_ids[index], self.ner_tags[index], self.tokens[index]


def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.
    '''
    x, y, z = zip(*batch)
    x_lens = [len(x_i) for x_i in x]
    y_lens = [len(y_i) for y_i in y]
    x_pad = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    y_pad = torch.nn.utils.rnn.pad_sequence(y, batch_first=True)
    return x_pad, torch.tensor(x_lens), y_pad, torch.tensor(y_lens), z
    
dataset_conll2003_train_loader = DataLoader(
    MyDataset(dataset_conll2003_train),
    batch_size=32,
    shuffle=True, 
    collate_fn=lambda x: collate_fn_padd(x))


dataset_conll2003_test_loader = DataLoader(
    MyDataset(dataset_conll2003_test),
    batch_size=32,
    shuffle=True, 
    collate_fn=lambda x: collate_fn_padd(x))

In [97]:
next(iter(dataset_conll2003_train_loader))[4][0]

['Katarina',
 'Studenikova',
 '(',
 'Slovakia',
 ')',
 'beat',
 '6',
 '-',
 'Karina',
 'Habsudova']

In [90]:
print(dataset_conll2003_train_loader.dataset.__len__(), dataset_conll2003_test_loader.dataset.__len__())

14041 3453


## 模型定义

In [14]:
class BiLSTM_CRF(nn.Module):

    def __init__(self, config=None):
        super(BiLSTM_CRF, self).__init__()
        self.config = config

        # BiLSTM-model 给 emission 层定义参数
        self.embedding_dim = self.config.get('embedding_dim', 200)
        self.hidden_dim = self.config.get('hidden_dim', 200)
        self.vocab_size = self.config.get('vocab_size', 30289)

        self.word_embeds = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.target_size = self.config.get('num_tags', 9)
        self.num_layers = 1
        self.batch_size = self.config.get('batch_size',16)
        self.bidirectional = True

        # lstm
        self.lstm = nn.LSTM(input_size=self.embedding_dim, hidden_size=self.hidden_dim//2,
                            num_layers=self.num_layers, bidirectional=self.bidirectional)
        self.hidden2tag = nn.Linear(self.hidden_dim, self.target_size)
#         self.hidden_init = self.init_hidden()

        # CRF-model
        self.crf = CRF(self.config.get('num_tags', 9), batch_first=True)

    def init_hidden(self):
        hidden = torch.zeros(self.num_layers*2 if self.bidirectional else self.num_layers,
                             self.batch_size, self.hidden_dim//2)
        cell = torch.zeros(self.num_layers*2 if self.bidirectional else self.num_layers,
                             self.batch_size, self.hidden_dim//2)
        return hidden, cell

    def forward(self, sent, sent_len):
        """

        :param sent: 输入的已转换为token_id的句子，(batch_len * sent_len * token_emb_len)
        :param sent_len: tensor(list(int))
        :return:
        """
        embeds = self.word_embeds(sent)
        embed_packed = pack_padded_sequence(embeds, lengths=sent_len.to('cpu'),
                                            batch_first=True,
                                            enforce_sorted=False)
        lstm_out, (hidden, cell) = self.lstm(embed_packed) #, self.hidden_init)
        lstm_out, lens = pad_packed_sequence(lstm_out, batch_first=True)
        tag_score = self.hidden2tag(lstm_out)
#         tag_score = nn.functional.softmax(tag_score, dim=-1)
        return tag_score

In [45]:
model_name = 'V0-Embrand200-bilstm1Layer200Hidden16Batch1e-3Learn'
model_lstm = BiLSTM_CRF(config=Config)

In [46]:
model_lstm = model_lstm.cuda()

In [47]:
model_lstm.parameters

<bound method Module.parameters of BiLSTM_CRF(
  (word_embeds): Embedding(30289, 200)
  (lstm): LSTM(200, 50, bidirectional=True)
  (hidden2tag): Linear(in_features=100, out_features=9, bias=True)
  (crf): CRF(num_tags=9)
)>

In [48]:
# 手动计算验证权重
import collections
ner_tags_all = torch.cat(dataset_conll2003_train['ner_tags'])
t=collections.Counter(ner_tags_all.numpy())
res = []
for k in t:
    res.append((k, len(ner_tags_all)/9/t[k]))
print(res)

[(3, 3.5792683998664065), (0, 0.1334168085220698), (7, 6.580731691551936), (1, 3.4279629629629627), (2, 4.996589124460149), (5, 3.1687052598817305), (4, 6.108141348692104), (8, 19.58835978835979), (6, 19.554499183712664)]


In [49]:
class_weights=sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced',classes=np.unique(ner_tags_all),y=ner_tags_all.numpy())
class_weights=torch.tensor(class_weights,dtype=torch.float)
class_weights

tensor([ 0.1334,  3.4280,  4.9966,  3.5793,  6.1081,  3.1687, 19.5545,  6.5807,
        19.5884])

In [50]:
optimizer = torch.optim.Adam(model_lstm.parameters(), lr=1e-1)
loss_fn = nn.CrossEntropyLoss(reduction='mean', weight=class_weights.cuda()) 
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [51]:
def train_eval_single_epoch(model, data_iter=None, optimizer=None, loss_fn=None, is_train=False):
    if is_train:
        model.train()
    else:
        model.eval()   
    correct_curr, correct_sum, loss_sum, loss_curr = 0, 0, 0, 0
    loss_list, accuracy_list = [], []
    print('Total (training) batch = {}'.format(len(data_iter)))
    batch_i = 0
    data_iter_len = data_iter.dataset.__len__() # total sample num.
    batch_num = len(data_iter)
    
    batch_loss_list = list()
    logits_list = list()
    y_list, y_len_list = [], []
    for batch_data in data_iter:
        torch.cuda.empty_cache()
        batch_i += 1
        if is_train:
            optimizer.zero_grad()
#             model.zero_grad()
        x, x_len, y, y_len = batch_data
        x = x.cuda()
        x_len = x_len.cuda()
        y = y.cuda()
        # model predict 
        logits = model(x, x_len)
        assert logits.shape[0:2]==x.shape[0:2]
        # compute loss 
        batch_loss = 0
        for i in range(logits.size(0)): # num of samples in one batch
            loss = loss_fn(logits[i], y[i])
            batch_loss += loss 
        
        batch_loss /= logits.size(0)
        if is_train:
            batch_loss.backward()
            optimizer.step()
        
        # 记录
        batch_loss_list.append(batch_loss.item())
        logits_list.append(logits)
        y_list.append(y)
        y_len_list.append(y_len)
#         if batch_i%100==0:
#             print(batch_loss.item())
        
        x = x.cpu()
        x_len = x_len.cpu()
        y = y.cpu()
        
#     print(sum(batch_loss_list)/batch_num)
    return batch_loss_list, logits_list, y_list, y_len_list

In [52]:
def func_cal_accu_recall(logits_list=None, y_list=None, y_len_list=None):
    eval_dict = {'tp':0,'tn':0,'fp':0,'fn':0,'others':0,'n_total':0}
    n_total = 0
    for logit_i, logit in enumerate(logits_list):
        assert logit.shape[0:2] == y_list[logit_i].shape
        batch_len = len(y_len_list[logit_i])
        batch_max_seqlen = max(y_len_list[logit_i])
        y_curr = y_list[logit_i]
        logit_argmax = torch.argmax(logit, dim=2)

        # get mask matrix
        mask = torch.zeros((batch_len, batch_max_seqlen))
        for mask_i in range(mask.shape[0]):
            mask[mask_i][0:y_len_list[logit_i][mask_i]] = 1
        assert sum(mask.sum(axis=1)==y_len_list[logit_i])//batch_len==1

        # cal the tp,tn,fp,fn in this batch
        N = mask.sum()
        TP = ((logit_argmax>0)*(y_curr>0)*(logit_argmax==y_curr)*mask.cuda()).sum()
        TN = ((logit_argmax==0)*(y_curr==0)*(logit_argmax==y_curr)*mask.cuda()).sum()
        FP = ((logit_argmax>0)*(y_curr==0)*(logit_argmax!=y_curr)*mask.cuda()).sum()
        FN = ((logit_argmax==0)*(y_curr>0)*(logit_argmax!=y_curr)*mask.cuda()).sum()
        others = ((logit_argmax>0)*(y_curr>0)*(logit_argmax!=y_curr)*mask.cuda()).sum()

        eval_dict['tp'] += TP.item()
        eval_dict['tn'] += TN.item()
        eval_dict['fp'] += FP.item()
        eval_dict['fn'] += FN.item()
        eval_dict['others'] += others.item()
        eval_dict['n_total'] += N.item()

        accu = (TP+TN) / N
        recall = TP / (TP + FN)
#         print(accu, recall)
    
    if False:
        print('Total accu = {:.2f}% recall = {:.2f}%'.format(
            (eval_dict['tp'] + eval_dict['tn'])/eval_dict['n_total']*100, 
            eval_dict['tp']/(eval_dict['tp'] + eval_dict['fn'])*100))
    return eval_dict

In [53]:
def func_eval(model, data_iter=None, loss_fn=None):
    model.eval()
    correct_curr, correct_sum, loss_sum, loss_curr = 0, 0, 0, 0
    loss_list, accuracy_list = [], []
    print('Total (training) batch = {}'.format(len(data_iter)))
    batch_i = 0
    data_iter_len = data_iter.dataset.__len__() # total sample num.
    batch_num = len(data_iter)
    
    batch_loss_list = list()
    logits_list = list()
    y_list, y_len_list = [], []
    for batch_data in data_iter:
        batch_i += 1
        x, x_len, y, y_len = batch_data
        x = x.cuda()
        x_len = x_len.cuda()
        y = y.cuda()
        # model predict 
        logits = model(x, x_len)
        assert logits.shape[0:2]==x.shape[0:2]
        # compute loss 
        batch_loss = 0
        for i in range(logits.size(0)): # num of samples in one batch
            loss = loss_fn(logits[i], y[i])
            batch_loss += loss 
        
        batch_loss /= logits.size(0)
        
        # 记录
        batch_loss_list.append(batch_loss.item())
        logits_list.append(logits)
        y_list.append(y)
        y_len_list.append(y_len)
#         if batch_i%100==0:
#             print(batch_loss.item())
        
        x = x.cpu()
        x_len = x_len.cpu()
        y = y.cpu()
        
#     print(sum(batch_loss_list)/batch_num)
    return batch_loss_list, logits_list, y_list, y_len_list

## 训练开始

In [54]:
F1_test_max, F1_test_curr = 0, 0

for epoch in tqdm(range(50)):
    print('='*50)
    print('Epoch = {}'.format(epoch))
    print('='*50)
    batch_loss_list, logits_list, y_list, y_len_list = train_eval_single_epoch(model_lstm, 
                                                                               dataset_conll2003_train_loader, 
                                                                               optimizer=optimizer,
                                                                               loss_fn=loss_fn,
                                                                               is_train=True)
    scheduler.step() # 加上后好一些
    print('training loss = {:.8f}'.format(sum(batch_loss_list)/len(batch_loss_list)))
    test_batch_loss_list, test_logits_list, test_y_list, test_y_len_list = func_eval(model_lstm, 
                                                                               dataset_conll2003_test_loader, 
                                                                               loss_fn=loss_fn)
    print('testing loss = {:.8f}'.format(sum(test_batch_loss_list)/len(test_batch_loss_list)))
    
    # 评估
    train_dict = func_cal_accu_recall(logits_list=logits_list, y_list=y_list, y_len_list=y_len_list)
    test_dict = func_cal_accu_recall(logits_list=test_logits_list, y_list=test_y_list, y_len_list=test_y_len_list)
    
    print('Train :Total accu = {:.2f}% recall = {:.2f}%'.format(
            (train_dict['tp'] + train_dict['tn'])/train_dict['n_total']*100, 
            train_dict['tp']/(train_dict['tp'] + train_dict['fn'])*100))
    
    print('Test :Total accu = {:.2f}% recall = {:.2f}%'.format(
            (test_dict['tp'] + test_dict['tn'])/test_dict['n_total']*100, 
            test_dict['tp']/(test_dict['tp'] + test_dict['fn'])*100))
    
    # save model if current f1 rate is better than previous ones 
    accu_test = (test_dict['tp'] + test_dict['tn'])/test_dict['n_total']
    recall_test = test_dict['tp']/(test_dict['tp'] + test_dict['fn'])
    F1_test_curr = 1/(1/accu_test+1/recall_test)
    if F1_test_curr>F1_test_max:
        torch.save(model_lstm.state_dict(), './models/model_v0_'+model_name+'epoch='+str(epoch)+
                   'accu='+str(round(accu_test,4))+
                   'recall='+str(round(recall_test,4))+
                   'F1='+str(round(F1_test_curr,4)))
    F1_test_max = max(F1_test_curr, F1_test_max)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch = 0
Total (training) batch = 439
training loss = 1.11865732
Total (training) batch = 108
testing loss = 1.18347435


  from ipykernel import kernelapp as app
  2%|▏         | 1/50 [00:14<11:36, 14.22s/it]

Train :Total accu = 68.88% recall = 73.88%
Test :Total accu = 66.67% recall = 73.10%
Epoch = 1
Total (training) batch = 439
training loss = 0.95371711
Total (training) batch = 108
testing loss = 1.13154795


  4%|▍         | 2/50 [00:28<11:33, 14.44s/it]

Train :Total accu = 73.31% recall = 81.14%
Test :Total accu = 68.02% recall = 76.90%
Epoch = 2
Total (training) batch = 439
training loss = 0.86974234
Total (training) batch = 108
testing loss = 1.13118567


  6%|▌         | 3/50 [00:43<11:13, 14.34s/it]

Train :Total accu = 75.42% recall = 84.35%
Test :Total accu = 70.21% recall = 78.77%
Epoch = 3
Total (training) batch = 439
training loss = 0.78000184
Total (training) batch = 108
testing loss = 1.05750550


  8%|▊         | 4/50 [00:57<10:55, 14.24s/it]

Train :Total accu = 78.71% recall = 87.11%
Test :Total accu = 73.04% recall = 81.12%
Epoch = 4
Total (training) batch = 439
training loss = 0.69043518
Total (training) batch = 108
testing loss = 1.07245177


 10%|█         | 5/50 [01:11<10:41, 14.26s/it]

Train :Total accu = 81.42% recall = 89.64%
Test :Total accu = 75.56% recall = 80.01%
Epoch = 5
Total (training) batch = 439
training loss = 0.63456507
Total (training) batch = 108
testing loss = 1.02980752


 12%|█▏        | 6/50 [01:25<10:28, 14.27s/it]

Train :Total accu = 83.23% recall = 91.27%
Test :Total accu = 75.62% recall = 83.76%
Epoch = 6
Total (training) batch = 439
training loss = 0.57226409
Total (training) batch = 108
testing loss = 0.97377480


 14%|█▍        | 7/50 [01:40<10:14, 14.30s/it]

Train :Total accu = 84.95% recall = 92.35%
Test :Total accu = 76.86% recall = 85.50%
Epoch = 7
Total (training) batch = 439
training loss = 0.52990232
Total (training) batch = 108
testing loss = 1.00981725


 16%|█▌        | 8/50 [01:54<10:00, 14.30s/it]

Train :Total accu = 85.97% recall = 93.43%
Test :Total accu = 78.02% recall = 84.69%
Epoch = 8
Total (training) batch = 439
training loss = 0.50117510
Total (training) batch = 108
testing loss = 0.97807315


 18%|█▊        | 9/50 [02:08<09:47, 14.33s/it]

Train :Total accu = 86.64% recall = 93.91%
Test :Total accu = 78.53% recall = 85.33%
Epoch = 9
Total (training) batch = 439
training loss = 0.46029196
Total (training) batch = 108
testing loss = 0.99282025


 20%|██        | 10/50 [02:23<09:35, 14.38s/it]

Train :Total accu = 87.79% recall = 94.57%
Test :Total accu = 78.81% recall = 86.32%
Epoch = 10
Total (training) batch = 439
training loss = 0.42851140
Total (training) batch = 108
testing loss = 0.96449747


 22%|██▏       | 11/50 [02:37<09:21, 14.40s/it]

Train :Total accu = 88.65% recall = 95.18%
Test :Total accu = 79.97% recall = 86.82%
Epoch = 11
Total (training) batch = 439
training loss = 0.39430666
Total (training) batch = 108
testing loss = 0.93988831


 24%|██▍       | 12/50 [02:51<09:03, 14.30s/it]

Train :Total accu = 89.75% recall = 95.71%
Test :Total accu = 80.65% recall = 88.27%
Epoch = 12
Total (training) batch = 439
training loss = 0.36089305
Total (training) batch = 108
testing loss = 0.94421653


 26%|██▌       | 13/50 [03:05<08:44, 14.19s/it]

Train :Total accu = 90.65% recall = 96.35%
Test :Total accu = 80.71% recall = 89.49%
Epoch = 13
Total (training) batch = 439
training loss = 0.33112882
Total (training) batch = 108
testing loss = 0.96129725


 28%|██▊       | 14/50 [03:19<08:31, 14.21s/it]

Train :Total accu = 91.41% recall = 96.85%
Test :Total accu = 81.68% recall = 89.39%
Epoch = 14
Total (training) batch = 439
training loss = 0.30692836
Total (training) batch = 108
testing loss = 0.96005372


 30%|███       | 15/50 [03:34<08:19, 14.28s/it]

Train :Total accu = 92.00% recall = 97.07%
Test :Total accu = 80.93% recall = 90.28%
Epoch = 15
Total (training) batch = 439
training loss = 0.28493097
Total (training) batch = 108
testing loss = 0.94898864


 32%|███▏      | 16/50 [03:48<08:06, 14.31s/it]

Train :Total accu = 92.47% recall = 97.43%
Test :Total accu = 82.12% recall = 90.44%
Epoch = 16
Total (training) batch = 439
training loss = 0.26307265
Total (training) batch = 108
testing loss = 0.96223267


 34%|███▍      | 17/50 [04:03<07:54, 14.37s/it]

Train :Total accu = 93.00% recall = 97.69%
Test :Total accu = 82.69% recall = 89.96%
Epoch = 17
Total (training) batch = 439
training loss = 0.24401608
Total (training) batch = 108
testing loss = 0.97229819


 36%|███▌      | 18/50 [04:17<07:40, 14.38s/it]

Train :Total accu = 93.49% recall = 98.01%
Test :Total accu = 82.26% recall = 90.49%
Epoch = 18
Total (training) batch = 439
training loss = 0.22807203
Total (training) batch = 108
testing loss = 0.97947279


 38%|███▊      | 19/50 [04:32<07:29, 14.51s/it]

Train :Total accu = 93.96% recall = 98.15%
Test :Total accu = 82.79% recall = 89.94%
Epoch = 19
Total (training) batch = 439
training loss = 0.21176266
Total (training) batch = 108
testing loss = 0.98192497


 40%|████      | 20/50 [04:47<07:15, 14.51s/it]

Train :Total accu = 94.31% recall = 98.30%
Test :Total accu = 82.67% recall = 90.77%
Epoch = 20
Total (training) batch = 439
training loss = 0.19795583
Total (training) batch = 108
testing loss = 0.99596207


 42%|████▏     | 21/50 [05:01<06:59, 14.45s/it]

Train :Total accu = 94.68% recall = 98.46%
Test :Total accu = 82.89% recall = 90.96%
Epoch = 21
Total (training) batch = 439
training loss = 0.18655092
Total (training) batch = 108
testing loss = 1.03285700


 44%|████▍     | 22/50 [05:15<06:45, 14.50s/it]

Train :Total accu = 94.92% recall = 98.48%
Test :Total accu = 83.26% recall = 90.43%
Epoch = 22
Total (training) batch = 439
training loss = 0.17737473
Total (training) batch = 108
testing loss = 1.01728134


 46%|████▌     | 23/50 [05:30<06:30, 14.45s/it]

Train :Total accu = 95.22% recall = 98.65%
Test :Total accu = 82.95% recall = 90.79%
Epoch = 23
Total (training) batch = 439
training loss = 0.17035055
Total (training) batch = 108
testing loss = 1.03948058


 48%|████▊     | 24/50 [05:44<06:16, 14.47s/it]

Train :Total accu = 95.41% recall = 98.78%
Test :Total accu = 83.20% recall = 90.35%
Epoch = 24
Total (training) batch = 439
training loss = 0.16412671
Total (training) batch = 108
testing loss = 1.05153486


 50%|█████     | 25/50 [05:59<06:01, 14.46s/it]

Train :Total accu = 95.66% recall = 98.83%
Test :Total accu = 83.17% recall = 90.80%
Epoch = 25
Total (training) batch = 439
training loss = 0.15962023
Total (training) batch = 108
testing loss = 1.05268786


 52%|█████▏    | 26/50 [06:13<05:45, 14.38s/it]

Train :Total accu = 95.70% recall = 98.89%
Test :Total accu = 83.26% recall = 90.58%
Epoch = 26
Total (training) batch = 439
training loss = 0.15487291
Total (training) batch = 108
testing loss = 1.06320943


 54%|█████▍    | 27/50 [06:27<05:31, 14.40s/it]

Train :Total accu = 95.86% recall = 98.95%
Test :Total accu = 83.67% recall = 90.35%
Epoch = 27
Total (training) batch = 439
training loss = 0.15208751
Total (training) batch = 108
testing loss = 1.06092695


 56%|█████▌    | 28/50 [06:42<05:15, 14.36s/it]

Train :Total accu = 95.94% recall = 98.93%
Test :Total accu = 83.52% recall = 90.78%
Epoch = 28
Total (training) batch = 439
training loss = 0.14894454
Total (training) batch = 108
testing loss = 1.07232034


 58%|█████▊    | 29/50 [06:56<05:01, 14.37s/it]

Train :Total accu = 95.99% recall = 99.01%
Test :Total accu = 83.66% recall = 90.36%
Epoch = 29
Total (training) batch = 439
training loss = 0.14702654
Total (training) batch = 108
testing loss = 1.07421950


 60%|██████    | 30/50 [07:10<04:46, 14.31s/it]

Train :Total accu = 96.08% recall = 98.99%
Test :Total accu = 83.59% recall = 90.45%
Epoch = 30
Total (training) batch = 439
training loss = 0.14549607
Total (training) batch = 108
testing loss = 1.07082562


 62%|██████▏   | 31/50 [07:24<04:31, 14.27s/it]

Train :Total accu = 96.10% recall = 99.01%
Test :Total accu = 83.65% recall = 90.50%
Epoch = 31
Total (training) batch = 439
training loss = 0.14425158
Total (training) batch = 108
testing loss = 1.07491624


 64%|██████▍   | 32/50 [07:39<04:16, 14.25s/it]

Train :Total accu = 96.16% recall = 99.06%
Test :Total accu = 83.72% recall = 90.24%
Epoch = 32
Total (training) batch = 439
training loss = 0.14316513
Total (training) batch = 108
testing loss = 1.07943118


 66%|██████▌   | 33/50 [07:53<04:01, 14.23s/it]

Train :Total accu = 96.20% recall = 99.02%
Test :Total accu = 83.64% recall = 90.43%
Epoch = 33
Total (training) batch = 439
training loss = 0.14217226
Total (training) batch = 108
testing loss = 1.07911226


 68%|██████▊   | 34/50 [08:07<03:48, 14.26s/it]

Train :Total accu = 96.18% recall = 99.05%
Test :Total accu = 83.65% recall = 90.35%
Epoch = 34
Total (training) batch = 439
training loss = 0.14122511
Total (training) batch = 108
testing loss = 1.07895740


 70%|███████   | 35/50 [08:22<03:36, 14.40s/it]

Train :Total accu = 96.24% recall = 99.01%
Test :Total accu = 83.57% recall = 90.51%
Epoch = 35
Total (training) batch = 439
training loss = 0.14091379
Total (training) batch = 108
testing loss = 1.09017754


 72%|███████▏  | 36/50 [08:36<03:22, 14.46s/it]

Train :Total accu = 96.18% recall = 99.06%
Test :Total accu = 83.58% recall = 90.66%
Epoch = 36
Total (training) batch = 439
training loss = 0.14014432
Total (training) batch = 108
testing loss = 1.08698147


 74%|███████▍  | 37/50 [08:51<03:07, 14.44s/it]

Train :Total accu = 96.22% recall = 99.08%
Test :Total accu = 83.76% recall = 90.42%
Epoch = 37
Total (training) batch = 439
training loss = 0.13936851
Total (training) batch = 108
testing loss = 1.08578591


 76%|███████▌  | 38/50 [09:05<02:53, 14.46s/it]

Train :Total accu = 96.28% recall = 99.07%
Test :Total accu = 83.77% recall = 90.41%
Epoch = 38
Total (training) batch = 439
training loss = 0.13897065
Total (training) batch = 108
testing loss = 1.09148828


 78%|███████▊  | 39/50 [09:19<02:38, 14.37s/it]

Train :Total accu = 96.28% recall = 99.06%
Test :Total accu = 83.68% recall = 90.47%
Epoch = 39
Total (training) batch = 439
training loss = 0.13838445
Total (training) batch = 108
testing loss = 1.08890588


 80%|████████  | 40/50 [09:34<02:23, 14.32s/it]

Train :Total accu = 96.27% recall = 99.07%
Test :Total accu = 83.71% recall = 90.55%
Epoch = 40
Total (training) batch = 439
training loss = 0.13786344
Total (training) batch = 108
testing loss = 1.08369221


 82%|████████▏ | 41/50 [09:48<02:09, 14.34s/it]

Train :Total accu = 96.28% recall = 99.06%
Test :Total accu = 83.70% recall = 90.46%
Epoch = 41
Total (training) batch = 439
training loss = 0.13741606
Total (training) batch = 108
testing loss = 1.09134561


 84%|████████▍ | 42/50 [10:03<01:55, 14.43s/it]

Train :Total accu = 96.30% recall = 99.07%
Test :Total accu = 83.76% recall = 90.55%
Epoch = 42
Total (training) batch = 439
training loss = 0.13719902
Total (training) batch = 108
testing loss = 1.08847481


 86%|████████▌ | 43/50 [10:17<01:40, 14.42s/it]

Train :Total accu = 96.28% recall = 99.09%
Test :Total accu = 83.77% recall = 90.54%
Epoch = 43
Total (training) batch = 439
training loss = 0.13674707
Total (training) batch = 108
testing loss = 1.08606759


 88%|████████▊ | 44/50 [10:32<01:26, 14.46s/it]

Train :Total accu = 96.30% recall = 99.09%
Test :Total accu = 83.78% recall = 90.50%
Epoch = 44
Total (training) batch = 439
training loss = 0.13626052
Total (training) batch = 108
testing loss = 1.08623647


 90%|█████████ | 45/50 [10:46<01:12, 14.46s/it]

Train :Total accu = 96.33% recall = 99.09%
Test :Total accu = 83.81% recall = 90.50%
Epoch = 45
Total (training) batch = 439
training loss = 0.13587303
Total (training) batch = 108
testing loss = 1.08735964


 92%|█████████▏| 46/50 [11:01<00:57, 14.49s/it]

Train :Total accu = 96.34% recall = 99.08%
Test :Total accu = 83.83% recall = 90.49%
Epoch = 46
Total (training) batch = 439
training loss = 0.13579427
Total (training) batch = 108
testing loss = 1.09117054


 94%|█████████▍| 47/50 [11:15<00:43, 14.44s/it]

Train :Total accu = 96.36% recall = 99.09%
Test :Total accu = 83.81% recall = 90.57%
Epoch = 47
Total (training) batch = 439
training loss = 0.13563762
Total (training) batch = 108
testing loss = 1.08889939


 96%|█████████▌| 48/50 [11:29<00:28, 14.43s/it]

Train :Total accu = 96.36% recall = 99.09%
Test :Total accu = 83.81% recall = 90.51%
Epoch = 48
Total (training) batch = 439
training loss = 0.13522672
Total (training) batch = 108
testing loss = 1.09315020


 98%|█████████▊| 49/50 [11:44<00:14, 14.37s/it]

Train :Total accu = 96.37% recall = 99.08%
Test :Total accu = 83.80% recall = 90.47%
Epoch = 49
Total (training) batch = 439
training loss = 0.13509511
Total (training) batch = 108
testing loss = 1.09197418


100%|██████████| 50/50 [11:58<00:00, 14.37s/it]

Train :Total accu = 96.36% recall = 99.08%
Test :Total accu = 83.81% recall = 90.54%





## 训练总结

能发现验证集在 21 epoch 时已达到最好状态，之后train开始过拟合，eval 也没办法继续提升准确率。

epoch=0
- Train :Total accu = 68.88% recall = 73.88%
- Test :Total accu = 66.67% recall = 73.10%

epoch=10
- Train :Total accu = 88.65% recall = 95.18%
- Test :Total accu = 79.97% recall = 86.82%

epoch=21
- Train :Total accu = 94.92% recall = 98.48%
- Test :Total accu = 83.26% recall = 90.43%

epoch=50
- Train :Total accu = 96.36% recall = 99.08%
- Test :Total accu = 83.81% recall = 90.54%

## 模型加载及结果探查

In [57]:
model_lstm_load = BiLSTM_CRF(config=Config)
model_lstm_load.load_state_dict(torch.load('./models/model_v0_V0-Embrand200-bilstm1Layer200Hidden16Batch1e-3Learnepoch=27accu=0.835210509314095recall=0.9078498293515358F1=0.43500830208429403'))

<All keys matched successfully>

In [58]:
model_lstm_load

BiLSTM_CRF(
  (word_embeds): Embedding(30289, 200)
  (lstm): LSTM(200, 50, bidirectional=True)
  (hidden2tag): Linear(in_features=100, out_features=9, bias=True)
  (crf): CRF(num_tags=9)
)

In [98]:
t=next(iter(dataset_conll2003_test_loader))
print(t)
model_lstm_load.cuda()
model_lstm_load.eval()
y = t[2]
res = model_lstm_load(t[0].cuda(), t[1].cuda())
res_arg = torch.argmax(res, dim=2)

(tensor([[   81,    16,    88,  ...,     0,     0,     0],
        [ 1835,  2823,  3733,  ...,     0,     0,     0],
        [ 1921, 23318,     0,  ...,     0,     0,     0],
        ...,
        [10569,  1592,  2087,  ...,     0,     0,     0],
        [30177, 30178, 27834,  ...,     0,     0,     0],
        [27575,  2169,   693,  ...,     0,     0,     0]]), tensor([33, 25,  2, 10, 13,  4,  7,  4,  8, 14, 16, 24, 11,  8, 22,  2, 33, 26,
         7,  8,  2,  1, 19,  6, 36,  2,  3, 20,  6,  8,  3,  8]), tensor([[0, 0, 0,  ..., 0, 0, 0],
        [5, 6, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [3, 0, 0,  ..., 0, 0, 0],
        [5, 6, 0,  ..., 0, 0, 0],
        [3, 0, 0,  ..., 0, 0, 0]]), tensor([33, 25,  2, 10, 13,  4,  7,  4,  8, 14, 16, 24, 11,  8, 22,  2, 33, 26,
         7,  8,  2,  1, 19,  6, 36,  2,  3, 20,  6,  8,  3,  8]), (['He', 'said', 'that', 'the', 'procedure', 'to', 'insert', 'a', 'device', 'into', 'Havel', "'s", 'throat', ',', 'done', 'aft

### 案例探查总结

一些误判如下：
- 一些容易修正的误判（加规则或者加CRF）
's, y: O, predict: I-ORG

- 地点误判为机构
United, y: B-LOC, predict: I-ORG Arab, y: I-LOC, predict: I-LOC Emirates, y: I-LOC, predict: I-ORG

In [110]:
# true value
y_ner_list = []
for sent in y:
    tmp = []
    for token in sent:
        tmp.append(ner_id2tag[token.item()])
    y_ner_list.append(tmp)

# predict
res_ner_list = []
for sent in res_arg:
    tmp = []
    for token in sent:
        tmp.append(ner_id2tag[token.item()])
    res_ner_list.append(tmp)

for i, sent in enumerate(t[4]):
    print('='*50)
    for j, token in enumerate(sent):
        print('{}, y: {}, predict: {}'.format(token, y_ner_list[i][j], res_ner_list[i][j]))

He, y: O, predict: O
said, y: O, predict: O
that, y: O, predict: O
the, y: O, predict: O
procedure, y: O, predict: O
to, y: O, predict: O
insert, y: O, predict: O
a, y: O, predict: O
device, y: O, predict: O
into, y: O, predict: O
Havel, y: B-PER, predict: B-PER
's, y: O, predict: I-ORG
throat, y: O, predict: B-MISC
,, y: O, predict: O
done, y: O, predict: O
after, y: O, predict: O
his, y: O, predict: O
breathing, y: O, predict: B-PER
worsened, y: O, predict: B-PER
on, y: O, predict: O
Thursday, y: O, predict: O
,, y: O, predict: O
had, y: O, predict: O
helped, y: O, predict: O
,, y: O, predict: O
and, y: O, predict: O
the, y: O, predict: O
president, y: O, predict: O
's, y: O, predict: I-LOC
condition, y: O, predict: O
significantly, y: O, predict: O
improved, y: O, predict: O
., y: O, predict: O
South, y: B-LOC, predict: B-MISC
Korea, y: I-LOC, predict: I-LOC
made, y: O, predict: O
virtually, y: O, predict: O
certain, y: O, predict: O
of, y: O, predict: O
an, y: O, predict: O
Asian, 