In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import tqdm
from nn_DataProcess import prepare_data, build_word2vec, Data_set
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import os
from model import LSTMModel, LSTM_attention
from nn_Config import Config
from nn_eval import val_accuary


def train(train_dataloader, model, device, epoches, lr, valid_dataloader=None):
    """
    Train ``model`` with Adam and cross-entropy loss, validating after each epoch.

    For every batch the running mean loss, running accuracy and the per-batch
    weighted F1 / micro recall are shown on a tqdm progress bar. After each
    epoch the model is scored with ``val_accuary``; whenever that score beats
    the best one seen so far, the whole model object is saved with
    ``torch.save`` under ``Config.model_state_dict_path``.

    :param train_dataloader: iterable of batches; element 0 of each batch is
        used as the model input and element 1 as the target.
    :param model: the ``torch.nn.Module`` to train (moved to ``device``).
    :param device: device on which the model and the batches are placed.
    :param epoches: number of passes over ``train_dataloader``.
    :param lr: learning rate handed to ``optim.Adam``.
    :param valid_dataloader: loader passed to ``val_accuary``. When ``None``
        (the default) the module-level ``val_dataloader`` is used, which keeps
        existing call sites working unchanged.
    :return: None
    """
    # Resolve the validation loader up front so a missing loader fails before
    # any training time is spent rather than after the first epoch.
    if valid_dataloader is None:
        valid_dataloader = val_dataloader

    # Put the model in training mode and move it to the target device.
    model.train()
    model = model.to(device)
    print(model)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Checkpoints are only written once the validation score exceeds this.
    # NOTE(review): the threshold must be in the same units as the value
    # returned by val_accuary (fraction vs. percent) -- confirm.
    best_acc = 0.88

    for epoch in range(epoches):
        train_loss = 0.0
        correct = 0
        total = 0

        # Wrap the loader in a fresh progress bar for this epoch only. The
        # parameter itself is left untouched so successive epochs do not
        # stack one tqdm wrapper on top of another.
        progress = tqdm.tqdm(train_dataloader)

        for i, data_ in enumerate(progress):
            optimizer.zero_grad()

            input_, target = data_[0], data_[1]
            # Cast to integer (long) tensors, then move them to the device.
            input_ = input_.type(torch.LongTensor).to(device)
            target = target.type(torch.LongTensor).to(device)

            # Forward pass.
            output = model(input_)
            # Drop dimension 1 of the target before computing the loss
            # (presumably targets arrive as (batch, 1) -- verify against
            # Data_set).
            target = target.squeeze(1)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Predicted class = index of the largest score along dimension 1.
            _, predicted = torch.max(output, 1)

            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Per-batch metrics; copy to CPU once and reuse for both scores.
            target_cpu = target.cpu()
            predicted_cpu = predicted.cpu()
            F1 = f1_score(target_cpu, predicted_cpu, average='weighted')
            Recall = recall_score(target_cpu, predicted_cpu, average='micro')

            # Kept as a one-element set so the progress-bar text stays
            # identical to the existing logs (log={'train_loss: ...'}).
            postfix = {'train_loss: {:.5f},train_acc:{:.3f}%'
                       ',F1: {:.3f}%,Recall:{:.3f}%'.format(train_loss / (i + 1),
                                                            100 * correct / total, 100 * F1, 100 * Recall)}
            progress.set_postfix(log=postfix)

        # Score the model on the validation data.
        acc = val_accuary(model, valid_dataloader, device, criterion)

        # Save a checkpoint whenever the validation score improves.
        if acc > best_acc:
            best_acc = acc
            # makedirs(exist_ok=True) creates missing parents and does not
            # fail if the directory already exists.
            os.makedirs(Config.model_state_dict_path, exist_ok=True)
            save_path = '{}_epoch_{}.pkl'.format("nn_model", epoch)
            checkpoint_path = os.path.join(Config.model_state_dict_path, save_path)
            print(checkpoint_path)
            torch.save(model, checkpoint_path)

        # Switch back to training mode for the next epoch (validation is
        # presumably run in eval mode inside val_accuary -- confirm).
        model.train()


if __name__ == '__main__':
    # Build the word -> id dictionary from the vocabulary file. Each
    # non-blank line must hold exactly two whitespace-separated fields:
    # the word and its integer index.
    word2id = {}
    with open(Config.word2id_path, encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing newline at end of file)
                # instead of failing on them.
                continue
            # A line with any other number of fields still raises ValueError.
            word, index = fields
            word2id[word] = int(index)

    # Reverse mapping: id -> word.
    id2word = {index: word for word, index in word2id.items()}

    # Use the GPU when one is available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sentences as arrays of word indices, together with their labels.
    train_array, train_lable, val_array, val_lable, test_array, test_lable = prepare_data(
        word2id,
        train_path=Config.train_path,
        val_path=Config.val_path,
        test_path=Config.test_path,
        seq_lenth=Config.max_sen_len)

    # Training Data_set and DataLoader.
    train_loader = Data_set(train_array, train_lable)
    train_dataloader = DataLoader(train_loader,
                                  batch_size=Config.batch_size,
                                  shuffle=True,
                                  num_workers=0)  # author's note: using workers made loading slower

    # Validation Data_set and DataLoader.
    val_loader = Data_set(val_array, val_lable)
    val_dataloader = DataLoader(val_loader,
                                batch_size=Config.batch_size,
                                shuffle=True,
                                num_workers=0)

    # Test Data_set and DataLoader.
    test_loader = Data_set(test_array, test_lable)
    test_dataloader = DataLoader(test_loader,
                                 batch_size=Config.batch_size,
                                 shuffle=True,
                                 num_workers=0)

    # Build the pretrained word2vec embedding matrix and convert it to a
    # float32 tensor (author's note: CUDA accepts float32, not float64).
    w2vec = build_word2vec(Config.pre_word2vec_path, word2id, None)
    w2vec = torch.from_numpy(w2vec)
    w2vec = w2vec.float()

    # LSTM-with-attention classifier.
    model = LSTM_attention(Config.vocab_size, Config.embedding_dim, w2vec, Config.update_w2v,
                           Config.hidden_dim, Config.num_layers, Config.drop_keep_prob, Config.n_class,
                           Config.bidirectional)

    # Train.
    train(train_dataloader, model=model, device=device, epoches=Config.n_epoch, lr=Config.lr)

LSTM_attention(
  (embedding): Embedding(54848, 50)
  (encoder): LSTM(50, 100, num_layers=2, batch_first=True, dropout=0.25, bidirectional=True)
  (decoder1): Linear(in_features=200, out_features=100, bias=True)
  (decoder2): Linear(in_features=100, out_features=2, bias=True)
)


100%|██████████| 313/313 [01:00<00:00,  5.18it/s, log={'train_loss: 0.57503,train_acc:68.532%,F1: 80.556%,Recall:80.000%'}]



Val accuracy : 77.065%,val_loss:42.529, F1_score：80.328%, Recall：80.328%
nn_models/nn_model_epoch_0.pkl


100%|██████████| 313/313 [01:04<00:00,  4.87it/s, log={'train_loss: 0.46501,train_acc:77.748%,F1: 73.092%,Recall:73.333%'}]



Val accuracy : 78.149%,val_loss:41.055, F1_score：73.770%, Recall：73.770%
nn_models/nn_model_epoch_1.pkl


100%|██████████| 313/313 [00:59<00:00,  5.24it/s, log={'train_loss: 0.43751,train_acc:79.243%,F1: 70.505%,Recall:70.000%'}]



Val accuracy : 78.966%,val_loss:39.505, F1_score：73.610%, Recall：73.770%
nn_models/nn_model_epoch_2.pkl


100%|██████████| 313/313 [01:02<00:00,  5.00it/s, log={'train_loss: 0.41431,train_acc:80.983%,F1: 73.333%,Recall:73.333%'}]



Val accuracy : 79.837%,val_loss:39.230, F1_score：86.885%, Recall：86.885%
nn_models/nn_model_epoch_3.pkl


100%|██████████| 313/313 [01:03<00:00,  4.95it/s, log={'train_loss: 0.38724,train_acc:82.688%,F1: 69.766%,Recall:70.000%'}]



Val accuracy : 81.098%,val_loss:37.692, F1_score：86.892%, Recall：86.885%
nn_models/nn_model_epoch_4.pkl


100%|██████████| 313/313 [01:01<00:00,  5.08it/s, log={'train_loss: 0.36450,train_acc:84.233%,F1: 69.486%,Recall:70.000%'}]



Val accuracy : 81.133%,val_loss:37.098, F1_score：73.470%, Recall：73.770%
nn_models/nn_model_epoch_5.pkl


100%|██████████| 313/313 [01:01<00:00,  5.10it/s, log={'train_loss: 0.33869,train_acc:85.454%,F1: 90.034%,Recall:90.000%'}]



Val accuracy : 81.489%,val_loss:36.460, F1_score：90.276%, Recall：90.164%
nn_models/nn_model_epoch_6.pkl


100%|██████████| 313/313 [00:58<00:00,  5.35it/s, log={'train_loss: 0.31849,train_acc:86.364%,F1: 86.481%,Recall:86.667%'}]



Val accuracy : 81.116%,val_loss:38.118, F1_score：81.770%, Recall：81.967%


100%|██████████| 313/313 [01:00<00:00,  5.18it/s, log={'train_loss: 0.23583,train_acc:90.919%,F1: 90.082%,Recall:90.000%'}]



Val accuracy : 81.347%,val_loss:40.877, F1_score：85.278%, Recall：85.246%


100%|██████████| 313/313 [01:01<00:00,  5.05it/s, log={'train_loss: 0.21789,train_acc:91.584%,F1: 96.686%,Recall:96.667%'}]  



Val accuracy : 82.874%,val_loss:38.405, F1_score：80.328%, Recall：80.328%
nn_models/nn_model_epoch_12.pkl


100%|██████████| 313/313 [00:59<00:00,  5.30it/s, log={'train_loss: 0.20222,train_acc:92.304%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 82.768%,val_loss:38.622, F1_score：81.938%, Recall：81.967%


100%|██████████| 313/313 [00:59<00:00,  5.30it/s, log={'train_loss: 0.18537,train_acc:93.254%,F1: 96.670%,Recall:96.667%'}]  



Val accuracy : 82.270%,val_loss:40.098, F1_score：82.512%, Recall：81.967%


100%|██████████| 313/313 [00:59<00:00,  5.26it/s, log={'train_loss: 0.16650,train_acc:94.124%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 82.715%,val_loss:43.885, F1_score：80.328%, Recall：80.328%


100%|██████████| 313/313 [01:00<00:00,  5.17it/s, log={'train_loss: 0.15302,train_acc:94.504%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 82.324%,val_loss:44.276, F1_score：85.246%, Recall：85.246%


100%|██████████| 313/313 [00:58<00:00,  5.32it/s, log={'train_loss: 0.13646,train_acc:95.465%,F1: 93.394%,Recall:93.333%'}]  



Val accuracy : 81.791%,val_loss:47.780, F1_score：80.317%, Recall：80.328%


100%|██████████| 313/313 [00:58<00:00,  5.34it/s, log={'train_loss: 0.12142,train_acc:96.110%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 81.986%,val_loss:50.444, F1_score：88.562%, Recall：88.525%


100%|██████████| 313/313 [01:00<00:00,  5.13it/s, log={'train_loss: 0.11296,train_acc:96.370%,F1: 96.670%,Recall:96.667%'}]  



Val accuracy : 81.720%,val_loss:51.524, F1_score：86.885%, Recall：86.885%


100%|██████████| 313/313 [01:01<00:00,  5.10it/s, log={'train_loss: 0.09789,train_acc:96.980%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 81.951%,val_loss:53.070, F1_score：86.773%, Recall：86.885%


100%|██████████| 313/313 [00:57<00:00,  5.45it/s, log={'train_loss: 0.08822,train_acc:97.350%,F1: 96.663%,Recall:96.667%'}]  



Val accuracy : 82.039%,val_loss:59.408, F1_score：88.449%, Recall：88.525%


100%|██████████| 313/313 [00:58<00:00,  5.40it/s, log={'train_loss: 0.07638,train_acc:97.900%,F1: 96.620%,Recall:96.667%'}]  



Val accuracy : 81.364%,val_loss:63.394, F1_score：85.295%, Recall：85.246%


100%|██████████| 313/313 [00:59<00:00,  5.27it/s, log={'train_loss: 0.06747,train_acc:98.140%,F1: 96.678%,Recall:96.667%'}]  



Val accuracy : 81.489%,val_loss:62.711, F1_score：88.506%, Recall：88.525%


100%|██████████| 313/313 [00:57<00:00,  5.43it/s, log={'train_loss: 0.06194,train_acc:98.375%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 80.796%,val_loss:72.232, F1_score：78.631%, Recall：78.689%


100%|██████████| 313/313 [00:55<00:00,  5.62it/s, log={'train_loss: 0.05397,train_acc:98.570%,F1: 96.670%,Recall:96.667%'}]  



Val accuracy : 80.814%,val_loss:72.696, F1_score：78.677%, Recall：78.689%


100%|██████████| 313/313 [00:57<00:00,  5.41it/s, log={'train_loss: 0.04741,train_acc:98.820%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 81.151%,val_loss:75.984, F1_score：85.238%, Recall：85.246%


100%|██████████| 313/313 [00:58<00:00,  5.36it/s, log={'train_loss: 0.04864,train_acc:98.705%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 80.760%,val_loss:86.008, F1_score：77.137%, Recall：77.049%


100%|██████████| 313/313 [00:57<00:00,  5.49it/s, log={'train_loss: 0.04512,train_acc:98.830%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 80.778%,val_loss:80.108, F1_score：72.267%, Recall：72.131%


100%|██████████| 313/313 [01:03<00:00,  4.96it/s, log={'train_loss: 0.03840,train_acc:99.035%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 80.867%,val_loss:79.024, F1_score：83.807%, Recall：83.607%
