In [1]:
# coding: UTF-8
import torch
from importlib import import_module
import pickle as pkl

class CnnModel:
    """Inference wrapper around a trained text-classification network.

    Construction imports the module named by ``model_name``, builds its
    ``Config``, loads the pickled vocabulary and the saved weights, and puts
    the network in eval mode. ``predict`` then maps raw strings to category
    names through ``id_to_cate``.

    Attributes:
        id_to_cate: dict mapping integer class id -> category name, read from
            ``<dataset>/data/id_to_cate.txt``.
        vocab: token -> id mapping unpickled from ``config.vocab_path``.
        device: the device taken from ``config.device``; the model lives there
            and inputs are moved there before the forward pass.
        model: the network instance, in eval mode.
    """

    def __init__(self):
        dataset = 'dataset'  # dataset directory
        embedding = 'random'
        model_name = 'TextCNN'  # 'TextRCNN'  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer

        # Each mapping line is "<category name> +++$+++ <integer id>".
        self.id_to_cate = {}
        with open(dataset + '/data/id_to_cate.txt', encoding='utf-8') as mapping_file:
            for line in mapping_file:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                fields = line.split(" +++$+++ ")
                self.id_to_cate[int(fields[1].strip())] = fields[0]

        x = import_module(model_name)
        config = x.Config(dataset, embedding)
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # config.vocab_path at files you produced yourself.
        with open(config.vocab_path, 'rb') as vocab_file:
            self.vocab = pkl.load(vocab_file)
        config.n_vocab = len(self.vocab)
        self.device = config.device
        self.model = x.Model(config).to(self.device)

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model.load_state_dict(
            torch.load(config.save_path, map_location=self.device))
        self.model.eval()

    def load_dataset(self, message, pad_size=32):
        """Convert raw strings into a padded tensor of vocabulary ids.

        Args:
            message: iterable of strings; each string is tokenized per
                character.
            pad_size: target sequence length. Shorter sequences are padded
                with ``'<PAD>'``, longer ones truncated. A falsy value
                disables padding/truncation.

        Returns:
            A tuple ``(ids, "placeholder")`` where ``ids`` is a
            ``torch.LongTensor`` with one row per input string. The second
            element is a constant string kept so the tuple shape matches what
            ``predict`` hands to the model.
        """
        UNK, PAD = '<UNK>', '<PAD>'  # unknown-token and padding symbols
        unk_id = self.vocab.get(UNK)  # fallback id for out-of-vocabulary tokens
        contents = []
        for line in message:
            token = list(line)  # character-level tokenization
            if pad_size:
                if len(token) < pad_size:
                    token.extend([PAD] * (pad_size - len(token)))
                else:
                    token = token[:pad_size]
            # token -> id
            contents.append([self.vocab.get(word, unk_id) for word in token])
        return torch.LongTensor(contents), "placeholder"

    def predict(self, message):
        """Predict category names for one string or a collection of strings.

        Args:
            message: a single ``str`` or an iterable of ``str``.

        Returns:
            list of category names (values of ``id_to_cate``), one per input,
            in input order. An empty input yields an empty list.

        Raises:
            KeyError: if the model predicts an id missing from ``id_to_cate``.
        """
        if isinstance(message, str):
            message = [message]
        message = list(message)
        if not message:
            return []  # nothing to classify; avoid feeding an empty batch

        token_ids, seq_placeholder = self.load_dataset(message)
        # The model is called with the same (tensor, placeholder) tuple shape
        # that load_dataset produces; only the tensor is moved to the device.
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            outputs = self.model((token_ids.to(self.device), seq_placeholder))
        predict_int = torch.max(outputs, 1)[1].cpu().tolist()
        # id to category
        predict_class = [self.id_to_cate[i] for i in predict_int]

        return predict_class
    
if __name__ == '__main__':
    # Smoke test: build the classifier once and label two sample product titles.
    cnn_model = CnnModel()
    test_demo = [
        '美国Nordic Naturals 儿童草莓味DHA鳕鱼油口服液 119ml',
        'DERMACEPT C10铂金抗氧套装',
    ]
    print(cnn_model.predict(test_demo))

['宝宝食品', '个人洗护']


In [20]:
from sklearn import metrics

# Each line of test.txt is "<text> +++$+++ <integer category id>". Build
# [category name, text] pairs so the gold labels are directly comparable with
# the category names returned by cnn_model.predict().
# The file is read inside `with` so the handle is closed, and blank lines are
# skipped instead of raising IndexError.
with open('dataset/data/test.txt', encoding='utf-8') as test_fh:
    testfile = [
        [cnn_model.id_to_cate[int(label.strip())], text]
        for text, label in (
            row.split(" +++$+++ ")[:2] for row in test_fh if row.strip()
        )
    ]
x = [i[1] for i in testfile]  # input texts
y = [i[0] for i in testfile]  # gold category names
predict_class = cnn_model.predict(x)
acc = metrics.accuracy_score(y, predict_class)
report = metrics.classification_report(y, predict_class, digits=4)
confusion = metrics.confusion_matrix(y, predict_class)
print(report)
print(confusion)

In [36]:
# Re-display the classification report string computed in the evaluation cell;
# `report` must already exist in the kernel for this cell to run.
print(report)

             precision    recall  f1-score   support

       个人洗护     0.9310    0.8966    0.9135       783
        保健品     0.9479    0.9673    0.9575       979
       口腔护理     0.9545    0.9492    0.9518       177
    女装/女士内衣     0.9749    0.9749    0.9749       399
      婴幼儿奶粉     0.9718    0.9773    0.9745       176
      孕产妇用品     0.9540    0.9326    0.9432        89
    宝宝服饰/玩具     0.9818    0.9600    0.9708       225
       宝宝洗护     0.9646    0.8934    0.9277       244
  宝宝用品_含纸尿片     0.9890    0.9756    0.9823       369
       宝宝食品     0.9176    0.8830    0.9000       265
    宠物食品/用品     0.9909    0.9559    0.9731       227
       家用家电     0.9481    0.9481    0.9481       135
       居家日用     0.9185    0.9390    0.9286       672
      彩妆/香水     0.9381    0.9739    0.9556      1073
       护理护肤     0.9494    0.9574    0.9533      1665
       数码3C     0.9550    0.9725    0.9636       109
       汽车用品     1.0000    0.5556    0.7143         9
         油品     1.0000    0.9595    0.9793   

In [37]:
import time
from datetime import timedelta
def get_time_dif(start_time):
    """Return the wall-clock time elapsed since ``start_time``.

    Args:
        start_time: a timestamp in seconds as returned by ``time.time()``.

    Returns:
        datetime.timedelta holding the elapsed time rounded to whole seconds.
    """
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)

In [39]:
# Measure a cold start: the timer covers constructing CnnModel (which loads the
# config, vocabulary and saved weights) plus a single one-item prediction.
start_time = time.time()
cnn_model = CnnModel()  # rebinds the notebook-level cnn_model to a fresh instance
test_demo = ['美国Nordic Naturals 儿童草莓味DHA鳕鱼油口服液 119ml']
print(cnn_model.predict(test_demo))
time_dif = get_time_dif(start_time)  # timedelta rounded to whole seconds
print("Time usage:", time_dif)

Time usage: 0:00:04
