In [0]:
import os

# Project root holding utils.py, executor.py, model/, ... (Google Drive path under Colab).
# Set the AFQMC_PROJECT_DIR environment variable to run the notebook from another checkout;
# when it is unset the original Drive location is used.
path = os.environ.get("AFQMC_PROJECT_DIR", "/content/drive/My Drive/NLP/nlp_AFQMC")
os.chdir(path)

In [0]:
from utils import * 
from torch.utils.data import DataLoader
from tokenizer import Tokenizer
from sentence_dataSet import SentenceDataSet
from data_processor import DataProcessor
from executor import Executor
from model.lstm_base import LSTMBase
from model.lstm_base_test import LSTMBaseTest
from config import Config
import pandas as pd
import torch

In [0]:
# Re-import edited project modules (utils, executor, model.*, ...) before each cell runs,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [0]:
def train(config, data_processor, executor, model_cls=None):
    """Train a model and checkpoint the epoch with the lowest dev loss.

    Reads ``config.train_data_path`` and ``config.dev_data_path`` as CSV files and
    builds datasets with ``data_processor.get_dataset``. It then runs
    ``config.epoch_num`` epochs. Each epoch calls ``executor.train_model`` and then
    ``executor.evaluate_model`` on the dev split. Whenever the dev loss improves, the
    model's ``state_dict`` is written to ``config.model_save_path``.

    Args:
        config: project Config; reads ``train_data_path``, ``dev_data_path``,
            ``batch_size``, ``epoch_num``, ``device`` and ``model_save_path``.
        data_processor: project DataProcessor; must provide ``get_dataset`` and
            ``emb_matrix``.
        executor: project Executor providing ``train_model`` and ``evaluate_model``
            (the latter returns ``(acc, loss, report, confusion)``).
        model_cls: model class, constructed as ``model_cls(config, emb_matrix)``.
            Defaults to ``LSTMBaseTest``. Any code that later loads the saved
            checkpoint must build the same class.
    """
    if model_cls is None:
        # Resolved at call time so the default does not need to exist at def time.
        model_cls = LSTMBaseTest

    # Load the raw data
    train_df = pd.read_csv(config.train_data_path)
    dev_df = pd.read_csv(config.dev_data_path)

    # Build datasets and loaders. Only the training loader is shuffled; the dev
    # loader is used purely for evaluation, so a fixed order is sufficient.
    train_data_set = data_processor.get_dataset(train_df)
    dev_data_set = data_processor.get_dataset(dev_df)
    train_loader = DataLoader(train_data_set, batch_size=config.batch_size, shuffle=True)
    dev_loader = DataLoader(dev_data_set, batch_size=config.batch_size, shuffle=False)

    # Build the model
    model = model_cls(config, data_processor.emb_matrix)
    print(model)
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(str(total_trainable_params), 'parameters is trainable.')
    model.to(config.device)

    dev_best_loss = float('inf')
    for epoch in range(config.epoch_num):
        print('Epoch:  ', epoch + 1)
        executor.train_model(train_loader, model)
        dev_acc, dev_loss, report, confusion = executor.evaluate_model(dev_loader, model)
        print_ans(dev_acc, dev_loss, report, confusion)
        if dev_loss < dev_best_loss:  # keep only the best (lowest dev loss) checkpoint
            dev_best_loss = dev_loss
            torch.save(model.state_dict(), config.model_save_path)
        if torch.cuda.is_available():
            # Release cached GPU blocks between epochs.
            torch.cuda.empty_cache()

def test(config,  data_processor, executor, model_cls=None, data_path=None):
    """Load the checkpoint at ``config.model_save_path``, evaluate it, and print metrics.

    Args:
        config: project Config; reads ``dev_data_path`` (default data source),
            ``batch_size``, ``device`` and ``model_save_path``.
        data_processor: project DataProcessor; must provide ``get_dataset`` and
            ``emb_matrix``.
        executor: project Executor whose ``evaluate_model`` returns
            ``(acc, loss, report, confusion)``.
        model_cls: class to instantiate before loading the weights. Defaults to
            ``LSTMBaseTest``, the class ``train`` builds and saves. A mismatched
            architecture is rejected, because ``load_state_dict`` is strict by default.
        data_path: CSV file to evaluate. Defaults to ``config.dev_data_path``, a
            held-out labelled split, rather than the training file.
            NOTE(review): pass a labelled test split here if one exists.
    """
    if model_cls is None:
        model_cls = LSTMBaseTest
    if data_path is None:
        data_path = config.dev_data_path

    # Load the data
    test_df = pd.read_csv(data_path)
    # Build the evaluation dataset; evaluation does not need shuffling
    test_data_set = data_processor.get_dataset(test_df, is_train=False)
    test_loader = DataLoader(test_data_set, batch_size=config.batch_size, shuffle=False)

    # Rebuild the model and load the saved weights. map_location lets a checkpoint
    # saved on GPU be loaded on a CPU-only machine (and vice versa).
    model = model_cls(config, data_processor.emb_matrix)
    state_dict = torch.load(config.model_save_path, map_location=config.device)
    model.load_state_dict(state_dict)
    # Put the model on the configured device and disable dropout for evaluation.
    model.to(config.device)
    model.eval()

    test_acc, test_loss, report, confusion = executor.evaluate_model(test_loader, model)
    print_ans(test_acc, test_loss, report, confusion)

In [6]:
# Shared objects used by train()/test(). DataProcessor and Executor both receive the
# same Config instance; the tuple assignment evaluates left to right, so DataProcessor
# is still constructed before Executor.
config = Config()
data_processor, executor = DataProcessor(config), Executor(config)

dict_len:  259753
emb dim: 300


In [33]:
# Run the full training loop; the best checkpoint is written to config.model_save_path.
train(config=config, data_processor=data_processor, executor=executor)

LSTMBaseTest(
  (emb): Embedding(259755, 300)
  (encoder_layer): LSTM(300, 200, num_layers=2, dropout=0.2, bidirectional=True)
  (predict_fc): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1200, out_features=200, bias=True)
    (2): Tanh()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=200, out_features=2, bias=True)
  )
)
2007002 parameters is trainable.
Epoch:   1
Iter:   500 Train loss: 0.659 Train acc:62.50% Time:0:00:10
Iter:  1000 Train loss: 0.553 Train acc:71.88% Time:0:00:21


  _warn_prf(average, modifier, msg_start, len(result))


Dev Loss: 0.62, Dev Acc:69.00%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

           0     0.6900    1.0000    0.8166      2978
           1     0.0000    0.0000    0.0000      1338

    accuracy                         0.6900      4316
   macro avg     0.3450    0.5000    0.4083      4316
weighted avg     0.4761    0.6900    0.5634      4316

Confusion Matrix...
[[2978    0]
 [1338    0]]
Epoch:   2
Iter:   500 Train loss: 0.566 Train acc:75.00% Time:0:00:10
Iter:  1000 Train loss: 0.724 Train acc:59.38% Time:0:00:21
Dev Loss: 0.62, Dev Acc:69.00%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

           0     0.6900    1.0000    0.8166      2978
           1     0.0000    0.0000    0.0000      1338

    accuracy                         0.6900      4316
   macro avg     0.3450    0.5000    0.4083      4316
weighted avg     0.4761    0.6900    0.5634      4316

Confusion Matrix...
[[2978    0]
 [133

KeyboardInterrupt: ignored

In [70]:
import gc

# Memory cleanup, presumably run after interrupting train(); TODO confirm intent.
# gc.collect() frees unreachable Python objects and returns how many it collected.
freed = gc.collect()
# PyTorch's CUDA caching allocator keeps freed blocks reserved until empty_cache()
# releases them, so gc.collect() alone does not give GPU memory back.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
freed

137729

In [52]:
# Display the configured number of training epochs.
getattr(config, "epoch_num")

10