In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import tqdm
from nn_DataProcess import prepare_data, build_word2vec, Data_set
from sklearn.metrics import confusion_matrix, f1_score, recall_score
import os
from model import LSTMModel, LSTM_attention
from nn_Config import Config
from nn_eval import val_accuary


def train(train_dataloader, model, device, epoches, lr):
    """

    :param train_dataloader:
    :param model:
    :param device:
    :param epoches:
    :param lr:
    :return:
    """
    # 模型为训练模式
    model.train()
    # 将模型转化到gpu上
    model = model.to(device)
    print(model)
    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)  # 学习率调整
    best_acc = 0.9
    # 一个epoch可以认为是一次训练循环
    for epoch in range(epoches):
        train_loss = 0.0
        correct = 0
        total = 0

        # tqdm用在dataloader上其实是对每个batch和batch总数做的进度条
        train_dataloader = tqdm.tqdm(train_dataloader)
        # train_dataloader.set_description('[%s%04d/%04d %s%f]' % ('Epoch:', epoch + 1, epoches, 'lr:', scheduler.get_last_lr()[0]))
        # 遍历每个batch size数据
        for i, data_ in (enumerate(train_dataloader)):
            # 梯度清零
            optimizer.zero_grad()
            input_, target = data_[0], data_[1]
            # 将数据类型转化为整数
            input_ = input_.type(torch.LongTensor)
            target = target.type(torch.LongTensor)
            # 将数据转换到gpu上
            input_ = input_.to(device)
            target = target.to(device)
            # 前向传播
            output = model(input_)
            # 扩充维度
            target = target.squeeze(1)
            # 损失
            loss = criterion(output, target)
            # 反向传播
            loss.backward()
            # 梯度更新
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            # print(predicted.shape)
            # 计数
            total += target.size(0)  # 此处的size()类似numpy的shape: np.shape(train_images)[0]
            # print(target.shape)
            # 计算预测正确的个数
            correct += (predicted == target).sum().item()
            # 评价指标F1、Recall、混淆矩阵
            F1 = f1_score(target.cpu(), predicted.cpu(), average='weighted')
            Recall = recall_score(target.cpu(), predicted.cpu(), average='micro')
            # CM=confusion_matrix(target.cpu(),predicted.cpu())
            postfix = {'train_loss: {:.5f},train_acc:{:.3f}%'
                       ',F1: {:.3f}%,Recall:{:.3f}%'.format(train_loss / (i + 1),
                                                            100 * correct / total, 100 * F1, 100 * Recall)}
            # tqdm pbar.set_postfix：设置训练时的输出
            train_dataloader.set_postfix(log=postfix)

        # 计算验证集的准确率
        acc = val_accuary(model, val_dataloader, device, criterion)
        # 当准确率提升时，保存模型。
        if acc > best_acc:
            best_acc = acc
            if os.path.exists(Config.model_state_dict_path) == False:
                os.mkdir(Config.model_state_dict_path)
            save_path = '{}_epoch_{}.pkl'.format("nn_model", epoch)
            print(os.path.join(Config.model_state_dict_path, save_path))
            torch.save(model, os.path.join(Config.model_state_dict_path, save_path))
        # 恢复到训练模式
        model.train()


if __name__ == '__main__':
    splist = []
    # 构建word2id词典
    word2id = {}
    with open(Config.word2id_path, encoding='utf-8') as f:
        for line in f.readlines():
            sp = line.strip().split()  # 去掉\n \t 等
            splist.append(sp)
        word2id = dict(splist)  # 转成字典

    # 转换索引的数据类型为整数
    for key in word2id:
        word2id[key] = int(word2id[key])

    # 构建id2word
    id2word = {}
    for key, val in word2id.items():
        id2word[val] = key

    # 设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 得到数字索引表示的句子和标签
    train_array, train_lable, val_array, val_lable, test_array, test_lable = prepare_data(word2id,
                                                                                          train_path=Config.train_path,
                                                                                          val_path=Config.val_path,
                                                                                          test_path=Config.test_path,
                                                                                          seq_lenth=Config.max_sen_len)
    # 构建训练Data_set与DataLoader
    train_loader = Data_set(train_array, train_lable)
    train_dataloader = DataLoader(train_loader,
                                  batch_size=Config.batch_size,
                                  shuffle=True,
                                  num_workers=0)  # 用了workers反而变慢了
    # 构建验证Data_set与DataLoader
    val_loader = Data_set(val_array, val_lable)
    val_dataloader = DataLoader(val_loader,
                                batch_size=Config.batch_size,
                                shuffle=True,
                                num_workers=0)

    # 构建测试Data_set与DataLoader
    test_loader = Data_set(test_array, test_lable)
    test_dataloader = DataLoader(test_loader,
                                 batch_size=Config.batch_size,
                                 shuffle=True,
                                 num_workers=0)
    # 构建word2vec词向量
    w2vec = build_word2vec(Config.pre_word2vec_path, word2id, None)
    # 将词向量转化为Tensor
    w2vec = torch.from_numpy(w2vec)
    # CUDA接受float32，不接受float64
    w2vec = w2vec.float()
    # LSTM_attention
    model = LSTM_attention(Config.vocab_size, Config.embedding_dim, w2vec, Config.update_w2v,
                           Config.hidden_dim, Config.num_layers, Config.drop_keep_prob, Config.n_class,
                           Config.bidirectional)

    # 训练
    train(train_dataloader, model=model, device=device, epoches=Config.n_epoch, lr=Config.lr)

LSTM_attention(
  (embedding): Embedding(54848, 50)
  (encoder): LSTM(50, 100, num_layers=6, batch_first=True, dropout=0.2, bidirectional=True)
  (decoder1): Linear(in_features=200, out_features=100, bias=True)
  (decoder2): Linear(in_features=100, out_features=2, bias=True)
)


100%|██████████| 313/313 [00:03<00:00, 96.46it/s, log={'train_loss: 0.64068,train_acc:58.671%,F1: 66.667%,Recall:66.667%'}] 



Val accuracy : 74.862%,val_loss:44.957, F1_score：78.773%, Recall：78.689%
nn_models/nn_model_epoch_0.pkl


100%|██████████| 313/313 [00:03<00:00, 102.29it/s, log={'train_loss: 0.48681,train_acc:76.383%,F1: 83.165%,Recall:83.333%'}]



Val accuracy : 76.372%,val_loss:43.962, F1_score：82.017%, Recall：81.967%
nn_models/nn_model_epoch_1.pkl


100%|██████████| 313/313 [00:03<00:00, 102.89it/s, log={'train_loss: 0.46915,train_acc:77.648%,F1: 86.667%,Recall:86.667%'}]



Val accuracy : 77.136%,val_loss:41.864, F1_score：78.689%, Recall：78.689%
nn_models/nn_model_epoch_2.pkl


100%|██████████| 313/313 [00:03<00:00, 103.11it/s, log={'train_loss: 0.45302,train_acc:78.533%,F1: 76.588%,Recall:76.667%'}]



Val accuracy : 77.829%,val_loss:40.947, F1_score：86.885%, Recall：86.885%
nn_models/nn_model_epoch_3.pkl


100%|██████████| 313/313 [00:03<00:00, 102.55it/s, log={'train_loss: 0.43978,train_acc:79.573%,F1: 90.082%,Recall:90.000%'}]



Val accuracy : 78.700%,val_loss:40.357, F1_score：76.962%, Recall：77.049%
nn_models/nn_model_epoch_4.pkl


100%|██████████| 313/313 [00:03<00:00, 103.47it/s, log={'train_loss: 0.43100,train_acc:79.928%,F1: 86.787%,Recall:86.667%'}]



Val accuracy : 77.918%,val_loss:41.056, F1_score：72.131%, Recall：72.131%


100%|██████████| 313/313 [00:03<00:00, 103.20it/s, log={'train_loss: 0.41692,train_acc:80.903%,F1: 90.011%,Recall:90.000%'}]



Val accuracy : 79.019%,val_loss:40.451, F1_score：81.967%, Recall：81.967%
nn_models/nn_model_epoch_6.pkl


100%|██████████| 313/313 [00:03<00:00, 103.23it/s, log={'train_loss: 0.40579,train_acc:81.593%,F1: 83.277%,Recall:83.333%'}]



Val accuracy : 79.250%,val_loss:39.191, F1_score：64.347%, Recall：63.934%
nn_models/nn_model_epoch_7.pkl


100%|██████████| 313/313 [00:03<00:00, 103.36it/s, log={'train_loss: 0.39309,train_acc:82.203%,F1: 83.196%,Recall:83.333%'}]



Val accuracy : 79.925%,val_loss:38.287, F1_score：76.925%, Recall：77.049%
nn_models/nn_model_epoch_8.pkl


100%|██████████| 313/313 [00:03<00:00, 102.65it/s, log={'train_loss: 0.38396,train_acc:83.013%,F1: 83.389%,Recall:83.333%'}]



Val accuracy : 80.139%,val_loss:38.327, F1_score：90.217%, Recall：90.164%
nn_models/nn_model_epoch_9.pkl


100%|██████████| 313/313 [00:03<00:00, 103.34it/s, log={'train_loss: 0.37151,train_acc:83.583%,F1: 86.667%,Recall:86.667%'}]



Val accuracy : 80.121%,val_loss:38.515, F1_score：74.926%, Recall：75.410%


100%|██████████| 313/313 [00:03<00:00, 103.23it/s, log={'train_loss: 0.36180,train_acc:84.158%,F1: 93.273%,Recall:93.333%'}]



Val accuracy : 80.174%,val_loss:38.596, F1_score：72.342%, Recall：72.131%
nn_models/nn_model_epoch_11.pkl


100%|██████████| 313/313 [00:03<00:00, 104.19it/s, log={'train_loss: 0.35133,train_acc:84.893%,F1: 78.314%,Recall:76.667%'}]



Val accuracy : 80.867%,val_loss:36.799, F1_score：85.262%, Recall：85.246%
nn_models/nn_model_epoch_12.pkl


100%|██████████| 313/313 [00:03<00:00, 103.04it/s, log={'train_loss: 0.34024,train_acc:85.574%,F1: 90.101%,Recall:90.000%'}]



Val accuracy : 81.400%,val_loss:36.893, F1_score：80.338%, Recall：80.328%
nn_models/nn_model_epoch_13.pkl


100%|██████████| 313/313 [00:03<00:00, 103.40it/s, log={'train_loss: 0.32962,train_acc:86.154%,F1: 83.048%,Recall:83.333%'}]



Val accuracy : 81.240%,val_loss:36.735, F1_score：78.758%, Recall：78.689%


100%|██████████| 313/313 [00:03<00:00, 103.80it/s, log={'train_loss: 0.31887,train_acc:86.704%,F1: 86.607%,Recall:86.667%'}]



Val accuracy : 81.631%,val_loss:37.260, F1_score：80.317%, Recall：80.328%
nn_models/nn_model_epoch_15.pkl


100%|██████████| 313/313 [00:03<00:00, 103.70it/s, log={'train_loss: 0.31060,train_acc:87.194%,F1: 76.693%,Recall:76.667%'}]



Val accuracy : 81.791%,val_loss:37.769, F1_score：75.206%, Recall：75.410%
nn_models/nn_model_epoch_16.pkl


100%|██████████| 313/313 [00:03<00:00, 103.12it/s, log={'train_loss: 0.30025,train_acc:87.669%,F1: 93.241%,Recall:93.333%'}]



Val accuracy : 81.897%,val_loss:36.377, F1_score：83.607%, Recall：83.607%
nn_models/nn_model_epoch_17.pkl


100%|██████████| 313/313 [00:03<00:00, 104.00it/s, log={'train_loss: 0.28975,train_acc:88.269%,F1: 76.431%,Recall:76.667%'}]



Val accuracy : 82.235%,val_loss:37.204, F1_score：85.368%, Recall：85.246%
nn_models/nn_model_epoch_18.pkl


100%|██████████| 313/313 [00:03<00:00, 102.85it/s, log={'train_loss: 0.28106,train_acc:88.704%,F1: 96.648%,Recall:96.667%'}]



Val accuracy : 82.324%,val_loss:36.887, F1_score：85.246%, Recall：85.246%
nn_models/nn_model_epoch_19.pkl


100%|██████████| 313/313 [00:03<00:00, 103.81it/s, log={'train_loss: 0.27248,train_acc:89.169%,F1: 90.057%,Recall:90.000%'}]



Val accuracy : 82.466%,val_loss:36.694, F1_score：85.230%, Recall：85.246%
nn_models/nn_model_epoch_20.pkl


100%|██████████| 313/313 [00:03<00:00, 102.61it/s, log={'train_loss: 0.26354,train_acc:89.424%,F1: 89.967%,Recall:90.000%'}]  



Val accuracy : 82.537%,val_loss:37.802, F1_score：86.885%, Recall：86.885%
nn_models/nn_model_epoch_21.pkl


100%|██████████| 313/313 [00:03<00:00, 103.48it/s, log={'train_loss: 0.25255,train_acc:90.134%,F1: 83.352%,Recall:83.333%'}]



Val accuracy : 82.359%,val_loss:38.015, F1_score：86.822%, Recall：86.885%


100%|██████████| 313/313 [00:03<00:00, 104.19it/s, log={'train_loss: 0.24424,train_acc:90.474%,F1: 96.655%,Recall:96.667%'}] 



Val accuracy : 82.306%,val_loss:37.927, F1_score：80.488%, Recall：80.328%


100%|██████████| 313/313 [00:03<00:00, 103.72it/s, log={'train_loss: 0.23686,train_acc:90.799%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 82.555%,val_loss:39.369, F1_score：88.587%, Recall：88.525%
nn_models/nn_model_epoch_24.pkl


100%|██████████| 313/313 [00:03<00:00, 103.06it/s, log={'train_loss: 0.22518,train_acc:91.399%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 82.643%,val_loss:38.271, F1_score：82.007%, Recall：81.967%
nn_models/nn_model_epoch_25.pkl


100%|██████████| 313/313 [00:03<00:00, 103.78it/s, log={'train_loss: 0.21924,train_acc:91.724%,F1: 89.943%,Recall:90.000%'}]  



Val accuracy : 82.661%,val_loss:39.365, F1_score：86.906%, Recall：86.885%
nn_models/nn_model_epoch_26.pkl


100%|██████████| 313/313 [00:03<00:00, 102.21it/s, log={'train_loss: 0.20980,train_acc:92.209%,F1: 86.546%,Recall:86.667%'}]



Val accuracy : 82.537%,val_loss:39.425, F1_score：86.885%, Recall：86.885%


100%|██████████| 313/313 [00:03<00:00, 103.26it/s, log={'train_loss: 0.20185,train_acc:92.639%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 82.572%,val_loss:39.943, F1_score：85.078%, Recall：85.246%


100%|██████████| 313/313 [00:03<00:00, 103.06it/s, log={'train_loss: 0.19492,train_acc:92.844%,F1: 83.048%,Recall:83.333%'}]  



Val accuracy : 82.768%,val_loss:40.968, F1_score：72.208%, Recall：72.131%
nn_models/nn_model_epoch_29.pkl


100%|██████████| 313/313 [00:03<00:00, 103.92it/s, log={'train_loss: 0.18568,train_acc:93.284%,F1: 89.753%,Recall:90.000%'}]  



Val accuracy : 82.199%,val_loss:40.493, F1_score：84.136%, Recall：83.607%


100%|██████████| 313/313 [00:03<00:00, 103.12it/s, log={'train_loss: 0.18221,train_acc:93.409%,F1: 96.670%,Recall:96.667%'}]  



Val accuracy : 81.951%,val_loss:40.840, F1_score：88.499%, Recall：88.525%


100%|██████████| 313/313 [00:03<00:00, 103.89it/s, log={'train_loss: 0.17241,train_acc:94.104%,F1: 96.655%,Recall:96.667%'}]  



Val accuracy : 82.501%,val_loss:41.480, F1_score：76.709%, Recall：77.049%


100%|██████████| 313/313 [00:03<00:00, 103.22it/s, log={'train_loss: 0.16419,train_acc:94.389%,F1: 96.678%,Recall:96.667%'}]  



Val accuracy : 82.022%,val_loss:45.662, F1_score：78.584%, Recall：78.689%


100%|██████████| 313/313 [00:03<00:00, 97.77it/s, log={'train_loss: 0.15874,train_acc:94.564%,F1: 100.000%,Recall:100.000%'}] 



Val accuracy : 82.093%,val_loss:43.837, F1_score：78.584%, Recall：78.689%


100%|██████████| 313/313 [00:03<00:00, 92.97it/s, log={'train_loss: 0.15184,train_acc:94.939%,F1: 90.057%,Recall:90.000%'}]  



Val accuracy : 82.537%,val_loss:46.059, F1_score：90.115%, Recall：90.164%


100%|██████████| 313/313 [00:04<00:00, 75.81it/s, log={'train_loss: 0.14437,train_acc:95.070%,F1: 93.304%,Recall:93.333%'}]  



Val accuracy : 82.182%,val_loss:45.343, F1_score：90.169%, Recall：90.164%


100%|██████████| 313/313 [00:04<00:00, 75.90it/s, log={'train_loss: 0.13575,train_acc:95.570%,F1: 96.639%,Recall:96.667%'}]  



Val accuracy : 81.542%,val_loss:48.222, F1_score：79.353%, Recall：80.328%


100%|██████████| 313/313 [00:04<00:00, 75.56it/s, log={'train_loss: 0.13127,train_acc:95.815%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 81.649%,val_loss:50.197, F1_score：81.770%, Recall：81.967%


100%|██████████| 313/313 [00:04<00:00, 75.31it/s, log={'train_loss: 0.12365,train_acc:96.195%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 81.862%,val_loss:48.148, F1_score：86.820%, Recall：86.885%


100%|██████████| 313/313 [00:04<00:00, 77.29it/s, log={'train_loss: 0.11728,train_acc:96.380%,F1: 96.678%,Recall:96.667%'}]  



Val accuracy : 81.613%,val_loss:53.404, F1_score：81.938%, Recall：81.967%


100%|██████████| 313/313 [00:04<00:00, 77.05it/s, log={'train_loss: 0.11590,train_acc:96.515%,F1: 90.330%,Recall:90.000%'}]  



Val accuracy : 82.110%,val_loss:49.910, F1_score：81.977%, Recall：81.967%


100%|██████████| 313/313 [00:04<00:00, 76.93it/s, log={'train_loss: 0.10575,train_acc:96.915%,F1: 96.648%,Recall:96.667%'}]  



Val accuracy : 81.720%,val_loss:54.653, F1_score：83.607%, Recall：83.607%


100%|██████████| 313/313 [00:03<00:00, 78.51it/s, log={'train_loss: 0.10201,train_acc:96.970%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 81.808%,val_loss:53.107, F1_score：78.793%, Recall：78.689%


100%|██████████| 313/313 [00:04<00:00, 76.05it/s, log={'train_loss: 0.09393,train_acc:97.330%,F1: 93.333%,Recall:93.333%'}]  



Val accuracy : 81.524%,val_loss:53.631, F1_score：81.909%, Recall：81.967%


100%|██████████| 313/313 [00:04<00:00, 68.13it/s, log={'train_loss: 0.09084,train_acc:97.440%,F1: 96.686%,Recall:96.667%'}]  



Val accuracy : 81.578%,val_loss:54.303, F1_score：86.885%, Recall：86.885%


100%|██████████| 313/313 [00:03<00:00, 102.56it/s, log={'train_loss: 0.08855,train_acc:97.500%,F1: 93.394%,Recall:93.333%'}]  



Val accuracy : 81.329%,val_loss:55.537, F1_score：76.987%, Recall：77.049%


100%|██████████| 313/313 [00:03<00:00, 101.29it/s, log={'train_loss: 0.08061,train_acc:97.810%,F1: 96.655%,Recall:96.667%'}]  



Val accuracy : 81.418%,val_loss:57.575, F1_score：80.087%, Recall：80.328%


100%|██████████| 313/313 [00:03<00:00, 100.51it/s, log={'train_loss: 0.08086,train_acc:97.825%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 81.364%,val_loss:64.518, F1_score：80.360%, Recall：80.328%


100%|██████████| 313/313 [00:03<00:00, 100.18it/s, log={'train_loss: 0.07745,train_acc:97.945%,F1: 100.000%,Recall:100.000%'}]



Val accuracy : 81.293%,val_loss:64.607, F1_score：85.222%, Recall：85.246%
