In [1]:
# coding: UTF-8
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module


if __name__ == '__main__':
    # ---- Run configuration -------------------------------------------------
    dataset = 'dataset'  # dataset identifier handed to the model's Config
    embedding = 'random'  # embedding setting handed to the model's Config
    # Name of the module that provides Config and Model. Alternatives listed
    # by the original author:
    # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
    model_name = 'TextCNN'
    SEED = 1  # single seed shared by every RNG seeded below
    from utils import build_dataset, build_iterator, get_time_dif

    # The model module is resolved by name at runtime, so switching
    # architectures only requires changing `model_name` above.
    x = import_module(model_name)
    config = x.Config(dataset, embedding)

    # Seed NumPy and PyTorch (CPU and all CUDA devices) and ask cuDNN for
    # deterministic kernels so that repeated runs give the same results.
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True

    # ---- Data loading ------------------------------------------------------
    start_time = time.time()
    print("Loading data...")
    # NOTE(review): the second positional argument is presumably a
    # word-level/char-level tokenisation switch; confirm against
    # utils.build_dataset.
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # ---- Training ----------------------------------------------------------
    # The vocabulary size is only known after the dataset has been built, so
    # it is written into the config before the model is constructed.
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    # Every architecture except the Transformer goes through init_network.
    # NOTE(review): presumably the Transformer keeps its own default
    # initialisation; confirm against train_eval.init_network.
    if model_name != 'Transformer':
        init_network(model)
    # Print the module itself to show the architecture. Printing
    # `model.parameters` without calling it only shows the bound-method
    # wrapper ("<bound method Module.parameters of ...>").
    print(model)
    train(config, model, train_iter, dev_iter, test_iter)


9209it [00:00, 46216.77it/s]

Loading data...
Vocab size: 3454


96456it [00:02, 44202.77it/s]
12051it [00:00, 45878.64it/s]
12060it [00:00, 45472.68it/s]


Time usage: 0:00:03
<bound method Module.parameters of Model(
  (embedding): Embedding(3454, 300, padding_idx=3453)
  (convs): ModuleList(
    (0): Conv2d(1, 256, kernel_size=(2, 300), stride=(1, 1))
    (1): Conv2d(1, 256, kernel_size=(3, 300), stride=(1, 1))
    (2): Conv2d(1, 256, kernel_size=(4, 300), stride=(1, 1))
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=768, out_features=26, bias=True)
)>
Epoch [1/20]
Iter:      0,  Train Loss:   3.3,  Train Acc:  5.47%,  Val Loss:   3.0,  Val Acc: 22.56%,  Time: 0:00:04 *
Iter:    100,  Train Loss:  0.82,  Train Acc: 74.22%,  Val Loss:  0.53,  Val Acc: 85.02%,  Time: 0:00:27 *
Iter:    200,  Train Loss:  0.47,  Train Acc: 87.50%,  Val Loss:  0.38,  Val Acc: 89.44%,  Time: 0:00:49 *
Iter:    300,  Train Loss:  0.44,  Train Acc: 85.94%,  Val Loss:  0.31,  Val Acc: 91.62%,  Time: 0:01:12 *
Iter:    400,  Train Loss:  0.33,  Train Acc: 92.97%,  Val Loss:  0.29,  Val Acc: 92.05%,  Time: 0:01:33 *
Iter:    500,  Train

In [None]:
            precision    recall  f1-score   support

       个人洗护     0.9310    0.8966    0.9135       783
        保健品     0.9479    0.9673    0.9575       979
       口腔护理     0.9545    0.9492    0.9518       177
    女装/女士内衣     0.9749    0.9749    0.9749       399
      婴幼儿奶粉     0.9718    0.9773    0.9745       176
      孕产妇用品     0.9540    0.9326    0.9432        89
    宝宝服饰/玩具     0.9818    0.9600    0.9708       225
       宝宝洗护     0.9646    0.8934    0.9277       244
  宝宝用品_含纸尿片     0.9890    0.9756    0.9823       369
       宝宝食品     0.9176    0.8830    0.9000       265
    宠物食品/用品     0.9909    0.9559    0.9731       227
       家用家电     0.9481    0.9481    0.9481       135
       居家日用     0.9185    0.9390    0.9286       672
      彩妆/香水     0.9381    0.9739    0.9556      1073
       护理护肤     0.9494    0.9574    0.9533      1665
       数码3C     0.9550    0.9725    0.9636       109
       汽车用品     1.0000    0.5556    0.7143         9
         油品     1.0000    0.9595    0.9793        74
    男装/男士内衣     0.9912    0.9869    0.9890       457
      箱包/鞋靴     0.9914    0.9961    0.9937      1266
       美容工具     0.9655    0.8750    0.9180        96
       运动户外     1.0000    0.9741    0.9869       309
       进口食品     0.9289    0.9007    0.9146       725
       进口饮料     0.9179    0.8883    0.9029       403
      餐厨/清洁     0.9200    0.9758    0.9471       495
      饰品/手表     0.9692    0.9844    0.9767       639

avg / total     0.9542    0.9541    0.9539     12060