In [1]:
# coding: utf-8
from __future__ import print_function
import os
import tensorflow as tf
import tensorflow.contrib.keras as kr
from cnn_model import TCNNConfig, TextCNN
from data.cnews_loader import read_category, read_vocab
import pandas as pd

In [2]:
# try:
#     bool(type(unicode))
# except NameError:
#     unicode = str

In [3]:
# Root directories: input data and saved TextCNN checkpoints.
base_dir = 'data'
save_dir = 'checkpoints/textcnn'

# Files derived from the roots above.
vocab_dir = os.path.join(base_dir, 'vocab.txt')
save_path = os.path.join(save_dir, 'best_validation')  # save path of the best validation result

In [4]:
class CnnModel:
    """Trained TextCNN classifier restored from the checkpoint at ``save_path``.

    Construction builds the ``TextCNN`` graph, opens a ``tf.Session`` and
    restores the saved weights into it. The session holds resources until
    it is closed, so call :meth:`close` when done or use the instance as a
    context manager::

        with CnnModel() as model:
            label = model.predict(tokens)

    Attributes:
        config: ``TCNNConfig`` instance; ``vocab_size`` is overwritten with
            the size of the loaded vocabulary.
        categories, cat_to_id: the two values returned by ``read_category()``
            (presumably the label list and its reverse mapping; confirm
            against ``data.cnews_loader``).
        words, word_to_id: the two values returned by ``read_vocab(vocab_dir)``.
        model: the ``TextCNN`` graph wrapper.
        session: the open ``tf.Session``, or ``None`` after :meth:`close`.
    """

    def __init__(self):
        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        self.model = TextCNN(self.config)

        self.session = tf.Session()
        try:
            self.session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=self.session, save_path=save_path)  # load the saved model
        except Exception:
            # Do not leak the session when the checkpoint cannot be restored.
            self.close()
            raise

    def close(self):
        """Close the underlying session; safe to call more than once."""
        session = getattr(self, 'session', None)
        if session is not None:
            session.close()
            self.session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions raised inside the with-body

    def predict(self, message):
        """Return the predicted category for one document.

        Args:
            message: iterable of tokens (e.g. a list of words, or a string
                iterated character by character). Tokens missing from
                ``word_to_id`` are silently dropped.

        Returns:
            The entry of ``self.categories`` selected by the model's
            predicted class index.
        """
        # The original code converted the input with unicode(message) so that
        # models trained under Python 2 or 3 could run under either; that
        # conversion is disabled and the input is used as given.
        content = message
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]

        feed_dict = {
            # One-document batch, passed through pad_sequences with seq_length.
            self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
            # NOTE(review): presumably disables dropout at inference; confirm in TextCNN.
            self.model.keep_prob: 1.0
        }

        y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
        return self.categories[y_pred_cls[0]]

In [5]:
# Load the validation set. Each line is split on "," into four fields; every
# field stays a string (including 'id' and 'class'), which later cells rely on
# when comparing 'class' with the predicted label.
with open('../ensemble/data/val_data') as f:  # closed automatically, even on error
    rows = [line.replace("\n", "").split(",") for line in f]

file = pd.DataFrame(rows[1:])  # rows[0] is skipped (treated as the header line)
file.columns = ['id', 'article', 'word_seg', 'class']
# Whitespace-separated token string -> list of tokens, the form CnnModel.predict iterates over.
file['word_seg'] = file['word_seg'].str.split()

In [6]:
# Build the TextCNN graph, open a session and restore the weights from save_path
# (all done inside CnnModel.__init__).
cnn_model = CnnModel()

INFO:tensorflow:Restoring parameters from checkpoints/textcnn/best_validation


In [7]:
# Classify every document: one predict() call per row of token lists.
file['cnn_class'] = file['word_seg'].map(cnn_model.predict)

In [8]:
# Peek at the first ten predicted labels.
file['cnn_class'].head(10)

0    12
1     9
2     2
3    10
4    18
5    16
6    19
7     7
8    13
9     2
Name: cnn_class, dtype: object

In [9]:
# Number of predictions produced (one per row).
file['cnn_class'].shape[0]

10228

In [10]:
labels_right = []
labels_predict = []

# Per-class tallies, keyed by label exactly as stored in the frame.
A = dict.fromkeys(file['class'], 0)      # correctly predicted samples per true class
B = dict.fromkeys(file['class'], 0)      # samples per true class in the evaluated data
C = dict.fromkeys(file['cnn_class'], 0)  # samples per predicted class
# Walk both columns in lockstep instead of doing a label lookup per row.
for true_label, pred_label in zip(file['class'], file['cnn_class']):
    B[true_label] += 1
    C[pred_label] += 1
    if true_label == pred_label:
        A[true_label] += 1
print (A)
print (B)
print (C)
F1 = 0
# Precision, recall and F-score for every true class.
for key in B:
    try:
        r = float(A[key]) / float(B[key])
        p = float(A[key]) / float(C[key])
        f = p * r * 2 / (p + r)
        F1 += f
        print ("%s:\t precision:%f\t recall:%f\t f:%f" % (key,p,r,f))
    except (KeyError, ZeroDivisionError):
        # KeyError: the class was never predicted (absent from C).
        # ZeroDivisionError: precision + recall is 0 (no correct prediction).
        # Such a class adds nothing to F1 but still counts in the average below.
        print ("error:", key, "right:", A.get(key,0), "real:", B.get(key,0), "predict:",C.get(key,0))
# Macro-averaged F: divide by the number of classes actually present
# rather than a hard-coded class count.
print(F1 / len(B) if B else 0.0)

{'16': 254, '15': 665, '17': 205, '1': 265, '19': 272, '10': 326, '3': 689, '8': 487, '11': 235, '9': 720, '13': 667, '18': 549, '4': 317, '7': 206, '5': 160, '12': 328, '2': 226, '14': 527, '6': 589}
{'16': 330, '15': 769, '17': 341, '1': 544, '19': 534, '10': 506, '3': 819, '8': 683, '11': 376, '9': 763, '13': 812, '18': 648, '4': 392, '7': 321, '5': 219, '12': 518, '2': 294, '14': 660, '6': 699}
{'10': 479, '15': 743, '8': 758, '17': 297, '1': 442, '16': 340, '3': 800, '19': 446, '11': 354, '18': 656, '13': 969, '4': 377, '7': 330, '5': 204, '12': 539, '9': 776, '2': 334, '14': 748, '6': 636}
16:	 precision:0.747059	 recall:0.769697	 f:0.758209
15:	 precision:0.895020	 recall:0.864759	 f:0.879630
17:	 precision:0.690236	 recall:0.601173	 f:0.642633
1:	 precision:0.599548	 recall:0.487132	 f:0.537525
19:	 precision:0.609865	 recall:0.509363	 f:0.555102
10:	 precision:0.680585	 recall:0.644269	 f:0.661929
3:	 precision:0.861250	 recall:0.841270	 f:0.851143
8:	 precision:0.642480	 reca

# Write the predictions as a two-column CSV ("id,class").
# NOTE(review): the id written is the running row number starting at 0, not
# the frame's 'id' column; confirm this is what the consumer expects.
with open('../ensemble/result/result_test_cnn.csv', 'w') as fid0:  # closed even if a write fails
    fid0.write("id,class"+"\n")
    for i, item in enumerate(file['cnn_class']):
        fid0.write(str(i)+","+str(item)+"\n")