In [7]:
import os
import csv
import time
import datetime
import random
import json

from collections import Counter
from math import sqrt

import gensim
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score

In [8]:
# 配置参数

class TrainingConfig(object):
    epoches = 5
    evaluateEvery = 100
    checkpointEvery = 100
    learningRate = 0.001
    
class ModelConfig(object):
    embeddingSize = 200
    numFilters = 128

    filterSizes = [2, 3, 4, 5]
    dropoutKeepProb = 0.5
    l2RegLambda = 0.0
    
class Config(object):
    sequenceLength = 200  # 取了所有序列长度的均值
    batchSize = 128
    
    dataSource = "../data/preProcess/labeledTrain.csv"
    
    stopWordSource = "../data/english"
    
    numClasses = 1  # 二分类设置为1，多分类设置为类别的数目
    
    rate = 0.8  # 训练集的比例
    
    training = TrainingConfig()
    
    model = ModelConfig()

    
# 实例化配置参数对象
config = Config()

In [10]:
# 数据预处理的类，生成训练集和测试集

class Dataset(object):
    def __init__(self, config):
        self.config = config
        self._dataSource = config.dataSource
        self._stopWordSource = config.stopWordSource  
        
        self._sequenceLength = config.sequenceLength  # 每条输入的序列处理为定长
        self._embeddingSize = config.model.embeddingSize
        self._batchSize = config.batchSize
        self._rate = config.rate
        
        self._stopWordDict = {}
        
        self.trainReviews = []
        self.trainLabels = []
        
        self.evalReviews = []
        self.evalLabels = []
        
        self.wordEmbedding =None
        
        self.labelList = []
        
    def _readData(self, filePath):
        """
        从csv文件中读取数据集
        """
        
        df = pd.read_csv(filePath)
        
        if self.config.numClasses == 1:
            labels = df["sentiment"].tolist()
        elif self.config.numClasses > 1:
            labels = df["rate"].tolist()
            
        review = df["review"].tolist()
        reviews = [line.strip().split() for line in review]

        return reviews, labels
    
    def _labelToIndex(self, labels, label2idx):
        """
        将标签转换成索引表示
        """
        labelIds = [label2idx[label] for label in labels]
        return labelIds
    
    def _wordToIndex(self, reviews, word2idx):
        """
        将词转换成索引
        """
        reviewIds = [[word2idx.get(item, word2idx["UNK"]) for item in review] for review in reviews]
        return reviewIds
        
    def _genTrainEvalData(self, x, y, word2idx, rate):
        """
        生成训练集和验证集
        """
        reviews = []
        for review in x:
            if len(review) >= self._sequenceLength:
                reviews.append(review[:self._sequenceLength])
            else:
                reviews.append(review + [word2idx["PAD"]] * (self._sequenceLength - len(review)))
            
        trainIndex = int(len(x) * rate)
        
        trainReviews = np.asarray(reviews[:trainIndex], dtype="int64")
        trainLabels = np.array(y[:trainIndex], dtype="float32")
        
        evalReviews = np.asarray(reviews[trainIndex:], dtype="int64")
        evalLabels = np.array(y[trainIndex:], dtype="float32")

        return trainReviews, trainLabels, evalReviews, evalLabels
        
    def _genVocabulary(self, reviews, labels):
        """
        生成词向量和词汇-索引映射字典，可以用全数据集
        """
        
        allWords = [word for review in reviews for word in review]
        
        # 去掉停用词
        subWords = [word for word in allWords if word not in self.stopWordDict]
        
        wordCount = Counter(subWords)  # 统计词频
        sortWordCount = sorted(wordCount.items(), key=lambda x: x[1], reverse=True)
        
        # 去除低频词
        words = [item[0] for item in sortWordCount if item[1] >= 5]
        
        vocab, wordEmbedding = self._getWordEmbedding(words)
        self.wordEmbedding = wordEmbedding
        
        word2idx = dict(zip(vocab, list(range(len(vocab)))))
        
        uniqueLabel = list(set(labels))
        label2idx = dict(zip(uniqueLabel, list(range(len(uniqueLabel)))))
        self.labelList = list(range(len(uniqueLabel)))
        
        # 将词汇-索引映射表保存为json数据，之后做inference时直接加载来处理数据
        with open("../data/wordJson/word2idx.json", "w", encoding="utf-8") as f:
            json.dump(word2idx, f)
        
        with open("../data/wordJson/label2idx.json", "w", encoding="utf-8") as f:
            json.dump(label2idx, f)
        
        return word2idx, label2idx
            
    def _getWordEmbedding(self, words):
        """
        按照我们的数据集中的单词取出预训练好的word2vec中的词向量
        """
        
        wordVec = gensim.models.KeyedVectors.load_word2vec_format("../word2vec/word2Vec.bin", binary=True)
        vocab = []
        wordEmbedding = []
        
        # 添加 "pad" 和 "UNK", 
        vocab.append("PAD")
        vocab.append("UNK")
        wordEmbedding.append(np.zeros(self._embeddingSize))
        wordEmbedding.append(np.random.randn(self._embeddingSize))
        
        for word in words:
            try:
                vector = wordVec.wv[word]
                vocab.append(word)
                wordEmbedding.append(vector)
            except:
                print(word + "不存在于词向量中")
                
        return vocab, np.array(wordEmbedding)
    
    def _readStopWord(self, stopWordPath):
        """
        读取停用词
        """
        
        with open(stopWordPath, "r") as f:
            stopWords = f.read()
            stopWordList = stopWords.splitlines()
            # 将停用词用列表的形式生成，之后查找停用词时会比较快
            self.stopWordDict = dict(zip(stopWordList, list(range(len(stopWordList)))))
            
    def dataGen(self):
        """
        初始化训练集和验证集
        """
        
        # 初始化停用词
        self._readStopWord(self._stopWordSource)
        
        # 初始化数据集
        reviews, labels = self._readData(self._dataSource)
        
        # 初始化词汇-索引映射表和词向量矩阵
        word2idx, label2idx = self._genVocabulary(reviews, labels)
        
        # 将标签和句子数值化
        labelIds = self._labelToIndex(labels, label2idx)
        reviewIds = self._wordToIndex(reviews, word2idx)
        
        # 初始化训练集和测试集
        trainReviews, trainLabels, evalReviews, evalLabels = self._genTrainEvalData(reviewIds, labelIds, word2idx, self._rate)
        self.trainReviews = trainReviews
        self.trainLabels = trainLabels
        
        self.evalReviews = evalReviews
        self.evalLabels = evalLabels
        
        
data = Dataset(config)
data.dataGen()
data



youre不存在于词向量中
youll不存在于词向量中
theyre不存在于词向量中
youve不存在于词向量中
werent不存在于词向量中
youd不存在于词向量中
hasnt不存在于词向量中
shouldnt不存在于词向量中
weve不存在于词向量中
theyve不存在于词向量中
1010不存在于词向量中
wouldve不存在于词向量中
hed不存在于词向量中
andor不存在于词向量中
couldve不存在于词向量中
810不存在于词向量中
itthe不存在于词向量中
710不存在于词向量中
theyd不存在于词向量中
writerdirector不存在于词向量中
moviei不存在于词向量中
iti不存在于词向量中
theyll不存在于词向量中
310不存在于词向量中
410不存在于词向量中
910不存在于词向量中
itll不存在于词向量中
lees不存在于词向量中
familys不存在于词向量中
disneys不存在于词向量中
filmi不存在于词向量中
210不存在于词向量中
shouldve不存在于词向量中
*12不存在于词向量中
shakespeares不存在于词向量中
hitlers不存在于词向量中
brazil:不存在于词向量中
freddys不存在于词向量中
fords不存在于词向量中
storys不存在于词向量中
ohara不存在于词向量中
stewarts不存在于词向量中
kellys不存在于词向量中
scotts不存在于词向量中
tonys不存在于词向量中
itthis不存在于词向量中
keatons不存在于词向量中
rosemarys不存在于词向量中
ches不存在于词向量中
themthe不存在于词向量中
timethe不存在于词向量中
branaghs不存在于词向量中
mustve不存在于词向量中
kubricks不存在于词向量中
hitchcocks不存在于词向量中
lynchs不存在于词向量中
bakshis不存在于词向量中
whove不存在于词向量中
allthe不存在于词向量中
whod不存在于词向量中
itit不存在于词向量中
smiths不存在于词向量中
obrien不存在于词向量中
palmas不存在于词向量中
cravens不存在于词向量中
timei不存在于词向量中
movieit不存在于词向量中
hartley

spoilersi不存在于词向量中
alberts不存在于词向量中
doyles不存在于词向量中
jaffa不存在于词向量中
cassidys不存在于词向量中
worsethe不存在于词向量中
recap:不存在于词向量中
offi不存在于词向量中
greendale不存在于词向量中
todds不存在于词向量中
indias不存在于词向量中
caligari不存在于词向量中
crapthe不存在于词向量中
rev不存在于词向量中
yikes不存在于词向量中
doone不存在于词向量中
yentl不存在于词向量中
outi不存在于词向量中
presque不存在于词向量中
billys不存在于词向量中
10000不存在于词向量中
bosss不存在于词向量中
foxxs不存在于词向量中
rocknroll不存在于词向量中
robertsons不存在于词向量中
itim不存在于词向量中
ruths不存在于词向量中
booki不存在于词向量中
sauras不存在于词向量中
ummm不存在于词向量中
horrorcomedy不存在于词向量中
himi不存在于词向量中
vice-versa不存在于词向量中
marias不存在于词向量中
sabrinas不存在于词向量中
segals不存在于词向量中
comedythe不存在于词向量中
filmas不存在于词向量中
fatherthe不存在于词向量中
elliotts不存在于词向量中
roths不存在于词向量中
nelsons不存在于词向量中
bronsons不存在于词向量中
kidmans不存在于词向量中
onethis不存在于词向量中
wests不存在于词向量中
brians不存在于词向量中
filmone不存在于词向量中
knowthe不存在于词向量中
wellbut不存在于词向量中
warthe不存在于词向量中
demilles不存在于词向量中
esthers不存在于词向量中
screeni不存在于词向量中
leighs不存在于词向量中
timesbut不存在于词向量中
francos不存在于词向量中
umm不存在于词向量中
toi不存在于词向量中
experiencei不存在于词向量中
againit不存在于词向量中
colmans不存在于词向量中
terriblethe不存在于词向量中
daddys不存在于词向量中
ef

bernsens不存在于词向量中
chancethe不存在于词向量中
rothschild不存在于词向量中
pasadena不存在于词向量中
frankie-boy不存在于词向量中
theni不存在于词向量中
jesuss不存在于词向量中
filmoh不存在于词向量中
klaws不存在于词向量中
rodneys不存在于词向量中
flux不存在于词向量中
cassavettess不存在于词向量中
slugs:不存在于词向量中
milford不存在于词向量中
periodi不存在于词向量中
margarete不存在于词向量中
muchit不存在于词向量中
timedirector不存在于词向量中
playoffs不存在于词向量中
thesethe不存在于词向量中
malamud不存在于词向量中
who-dunnit不存在于词向量中
klute不存在于词向量中
wrights不存在于词向量中
bannings不存在于词向量中
ploti不存在于词向量中
hateable不存在于词向量中
tippi不存在于词向量中
vidors不存在于词向量中
rubys不存在于词向量中
enoughthe不存在于词向量中
glenrowan不存在于词向量中
buddys不存在于词向量中
amis不存在于词向量中
rozs不存在于词向量中
tysons不存在于词向量中
heorot不存在于词向量中
bregana不存在于词向量中
1894不存在于词向量中
fani不存在于词向量中
4-bad不存在于词向量中
wellyou不存在于词向量中
rc不存在于词向量中
movieshe不存在于词向量中
burgundians不存在于词向量中
merediths不存在于词向量中
asimovs不存在于词向量中
itwith不存在于词向量中
barbarella不存在于词向量中
!!!!!!!!不存在于词向量中
site!不存在于词向量中
amazingand不存在于词向量中
itnot不存在于词向量中
carrys不存在于词向量中
hickey不存在于词向量中
realismthe不存在于词向量中
tasuiev不存在于词向量中
forrests不存在于词向量中
donofrio不存在于词向量中
moreira不存在于词向量中
gadar不存在于词向量中
emigrant不存在于词向量中
thatt

<__main__.Dataset at 0x7f60f1f69748>

In [27]:
print("train data shape: {}".format(data.trainReviews.shape))
print("train label shape: {}".format(data.trainLabels.shape))
print("eval data shape: {}".format(data.evalReviews.shape))
type(data)
data.trainReviews
data.trainLabels

train data shape: (20000, 200)
train label shape: (20000,)
eval data shape: (5000, 200)


array([1., 1., 0., ..., 0., 0., 0.], dtype=float32)

In [22]:
# 输出batch数据集

def nextBatch(x, y, batchSize):
        """
        生成batch数据集，用生成器的方式输出
        """
    
        perm = np.arange(len(x))
        np.random.shuffle(perm)
        x = x[perm]
        y = y[perm]
        
        numBatches = len(x) // batchSize

        for i in range(numBatches):
            start = i * batchSize
            end = start + batchSize
            batchX = np.array(x[start: end], dtype="int64")
            batchY = np.array(y[start: end], dtype="float32")
            
            yield batchX, batchY

In [23]:
# 构建模型
class TextCNN(object):
    """
    Text CNN 用于文本分类
    """
    def __init__(self, config, wordEmbedding):

        # 定义模型的输入
        self.inputX = tf.placeholder(tf.int32, [None, config.sequenceLength], name="inputX")
        self.inputY = tf.placeholder(tf.int32, [None], name="inputY")
        
        self.dropoutKeepProb = tf.placeholder(tf.float32, name="dropoutKeepProb")
        
        # 定义l2损失
        l2Loss = tf.constant(0.0)
        
        # 词嵌入层
        with tf.name_scope("embedding"):

            # 利用预训练的词向量初始化词嵌入矩阵
            self.W = tf.Variable(tf.cast(wordEmbedding, dtype=tf.float32, name="word2vec") ,name="W")
            # 利用词嵌入矩阵将输入的数据中的词转换成词向量，维度[batch_size, sequence_length, embedding_size]
            self.embeddedWords = tf.nn.embedding_lookup(self.W, self.inputX)
            # 卷积的输入是思维[batch_size, width, height, channel]，因此需要增加维度，用tf.expand_dims来增大维度
            self.embeddedWordsExpanded = tf.expand_dims(self.embeddedWords, -1)

        # 创建卷积和池化层
        pooledOutputs = []
        # 有三种size的filter，3， 4， 5，textCNN是个多通道单层卷积的模型，可以看作三个单层的卷积模型的融合
        for i, filterSize in enumerate(config.model.filterSizes):
            with tf.name_scope("conv-maxpool-%s" % filterSize):
                # 卷积层，卷积核尺寸为filterSize * embeddingSize，卷积核的个数为numFilters
                # 初始化权重矩阵和偏置
                filterShape = [filterSize, config.model.embeddingSize, 1, config.model.numFilters]
                W = tf.Variable(tf.truncated_normal(filterShape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[config.model.numFilters]), name="b")
                conv = tf.nn.conv2d(
                    self.embeddedWordsExpanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                
                # relu函数的非线性映射
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                
                # 池化层，最大池化，池化是对卷积后的序列取一个最大值
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, config.sequenceLength - filterSize + 1, 1, 1],  # ksize shape: [batch, height, width, channels]
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name="pool")
                pooledOutputs.append(pooled)  # 将三种size的filter的输出一起加入到列表中

        # 得到CNN网络的输出长度
        numFiltersTotal = config.model.numFilters * len(config.model.filterSizes)
        
        # 池化后的维度不变，按照最后的维度channel来concat
        self.hPool = tf.concat(pooledOutputs, 3)
        
        # 摊平成二维的数据输入到全连接层
        self.hPoolFlat = tf.reshape(self.hPool, [-1, numFiltersTotal])

        # dropout
        with tf.name_scope("dropout"):
            self.hDrop = tf.nn.dropout(self.hPoolFlat, self.dropoutKeepProb)
       
        # 全连接层的输出
        with tf.name_scope("output"):
            outputW = tf.get_variable(
                "outputW",
                shape=[numFiltersTotal, config.numClasses],
                initializer=tf.contrib.layers.xavier_initializer())
            outputB= tf.Variable(tf.constant(0.1, shape=[config.numClasses]), name="outputB")
            l2Loss += tf.nn.l2_loss(outputW)
            l2Loss += tf.nn.l2_loss(outputB)
            self.logits = tf.nn.xw_plus_b(self.hDrop, outputW, outputB, name="logits")
            if config.numClasses == 1:
                self.predictions = tf.cast(tf.greater_equal(self.logits, 0.0), tf.int32, name="predictions")
            elif config.numClasses > 1:
                self.predictions = tf.argmax(self.logits, axis=-1, name="predictions")
            
            print(self.predictions)
        
        # 计算二元交叉熵损失
        with tf.name_scope("loss"):
            
            if config.numClasses == 1:
                losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.cast(tf.reshape(self.inputY, [-1, 1]), 
                                                                                                    dtype=tf.float32))
            elif config.numClasses > 1:
                losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.inputY)
                
            self.loss = tf.reduce_mean(losses) + config.model.l2RegLambda * l2Loss
            

In [24]:
"""
定义各类性能指标
"""

def mean(item: list) -> float:
    """
    计算列表中元素的平均值
    :param item: 列表对象
    :return:
    """
    res = sum(item) / len(item) if len(item) > 0 else 0
    return res


def accuracy(pred_y, true_y):
    """
    计算二类和多类的准确率
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :return:
    """
    if isinstance(pred_y[0], list):
        pred_y = [item[0] for item in pred_y]
    corr = 0
    for i in range(len(pred_y)):
        if pred_y[i] == true_y[i]:
            corr += 1
    acc = corr / len(pred_y) if len(pred_y) > 0 else 0
    return acc


def binary_precision(pred_y, true_y, positive=1):
    """
    二类的精确率计算
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :param positive: 正例的索引表示
    :return:
    """
    corr = 0
    pred_corr = 0
    for i in range(len(pred_y)):
        if pred_y[i] == positive:
            pred_corr += 1
            if pred_y[i] == true_y[i]:
                corr += 1

    prec = corr / pred_corr if pred_corr > 0 else 0
    return prec


def binary_recall(pred_y, true_y, positive=1):
    """
    二类的召回率
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :param positive: 正例的索引表示
    :return:
    """
    corr = 0
    true_corr = 0
    for i in range(len(pred_y)):
        if true_y[i] == positive:
            true_corr += 1
            if pred_y[i] == true_y[i]:
                corr += 1

    rec = corr / true_corr if true_corr > 0 else 0
    return rec


def binary_f_beta(pred_y, true_y, beta=1.0, positive=1):
    """
    二类的f beta值
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :param beta: beta值
    :param positive: 正例的索引表示
    :return:
    """
    precision = binary_precision(pred_y, true_y, positive)
    recall = binary_recall(pred_y, true_y, positive)
    try:
        f_b = (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
    except:
        f_b = 0
    return f_b


def multi_precision(pred_y, true_y, labels):
    """
    多类的精确率
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :param labels: 标签列表
    :return:
    """
    if isinstance(pred_y[0], list):
        pred_y = [item[0] for item in pred_y]

    precisions = [binary_precision(pred_y, true_y, label) for label in labels]
    prec = mean(precisions)
    return prec


def multi_recall(pred_y, true_y, labels):
    """
    多类的召回率
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :param labels: 标签列表
    :return:
    """
    if isinstance(pred_y[0], list):
        pred_y = [item[0] for item in pred_y]

    recalls = [binary_recall(pred_y, true_y, label) for label in labels]
    rec = mean(recalls)
    return rec


def multi_f_beta(pred_y, true_y, labels, beta=1.0):
    """
    多类的f beta值
    :param pred_y: 预测结果
    :param true_y: 真实结果
    :param labels: 标签列表
    :param beta: beta值
    :return:
    """
    if isinstance(pred_y[0], list):
        pred_y = [item[0] for item in pred_y]

    f_betas = [binary_f_beta(pred_y, true_y, beta, label) for label in labels]
    f_beta = mean(f_betas)
    return f_beta


def get_binary_metrics(pred_y, true_y, f_beta=1.0):
    """
    得到二分类的性能指标
    :param pred_y:
    :param true_y:
    :param f_beta:
    :return:
    """
    acc = accuracy(pred_y, true_y)
    recall = binary_recall(pred_y, true_y)
    precision = binary_precision(pred_y, true_y)
    f_beta = binary_f_beta(pred_y, true_y, f_beta)
    return acc, recall, precision, f_beta


def get_multi_metrics(pred_y, true_y, labels, f_beta=1.0):
    """
    得到多分类的性能指标
    :param pred_y:
    :param true_y:
    :param labels:
    :param f_beta:
    :return:
    """
    acc = accuracy(pred_y, true_y)
    recall = multi_recall(pred_y, true_y, labels)
    precision = multi_precision(pred_y, true_y, labels)
    f_beta = multi_f_beta(pred_y, true_y, labels, f_beta)
    return acc, recall, precision, f_beta

In [28]:
# 训练模型

# 生成训练集和验证集
trainReviews = data.trainReviews
trainLabels = data.trainLabels
evalReviews = data.evalReviews
evalLabels = data.evalLabels

wordEmbedding = data.wordEmbedding
labelList = data.labelList

# 定义计算图
with tf.Graph().as_default():

    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    session_conf.gpu_options.allow_growth=True
    session_conf.gpu_options.per_process_gpu_memory_fraction = 0.6  # 配置gpu占用率  

    sess = tf.Session(config=session_conf)
    
    # 定义会话
    with sess.as_default():
        cnn = TextCNN(config, wordEmbedding)
        
        globalStep = tf.Variable(0, name="globalStep", trainable=False)
        # 定义优化函数，传入学习速率参数
        optimizer = tf.train.AdamOptimizer(config.training.learningRate)
        # 计算梯度,得到梯度和变量
        gradsAndVars = optimizer.compute_gradients(cnn.loss)
        # 将梯度应用到变量下，生成训练器
        trainOp = optimizer.apply_gradients(gradsAndVars, global_step=globalStep)
        
        # 用summary绘制tensorBoard
        gradSummaries = []
        for g, v in gradsAndVars:
            if g is not None:
                tf.summary.histogram("{}/grad/hist".format(v.name), g)
                tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
        
        outDir = os.path.abspath(os.path.join(os.path.curdir, "summarys"))
        print("Writing to {}\n".format(outDir))
        
        lossSummary = tf.summary.scalar("loss", cnn.loss)
        summaryOp = tf.summary.merge_all()
        
        trainSummaryDir = os.path.join(outDir, "train")
        trainSummaryWriter = tf.summary.FileWriter(trainSummaryDir, sess.graph)
        
        evalSummaryDir = os.path.join(outDir, "eval")
        evalSummaryWriter = tf.summary.FileWriter(evalSummaryDir, sess.graph)
        
        
        # 初始化所有变量
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
        
        # 保存模型的一种方式，保存为pb文件
        savedModelPath = "../model/textCNN/savedModel"
        if os.path.exists(savedModelPath):
            os.rmdir(savedModelPath)
        builder = tf.saved_model.builder.SavedModelBuilder(savedModelPath)
            
        sess.run(tf.global_variables_initializer())

        def trainStep(batchX, batchY):
            """
            训练函数
            """   
            feed_dict = {
              cnn.inputX: batchX,
              cnn.inputY: batchY,
              cnn.dropoutKeepProb: config.model.dropoutKeepProb
            }
            _, summary, step, loss, predictions = sess.run(
                [trainOp, summaryOp, globalStep, cnn.loss, cnn.predictions],
                feed_dict)
            timeStr = datetime.datetime.now().isoformat()
            
            if config.numClasses == 1:
                acc, recall, prec, f_beta = get_binary_metrics(pred_y=predictions, true_y=batchY)

                
            elif config.numClasses > 1:
                acc, recall, prec, f_beta = get_multi_metrics(pred_y=predictions, true_y=batchY,
                                                              labels=labelList)
                
            trainSummaryWriter.add_summary(summary, step)
            
            return loss, acc, prec, recall, f_beta

        def devStep(batchX, batchY):
            """
            验证函数
            """
            feed_dict = {
              cnn.inputX: batchX,
              cnn.inputY: batchY,
              cnn.dropoutKeepProb: 1.0
            }
            summary, step, loss, predictions = sess.run(
                [summaryOp, globalStep, cnn.loss, cnn.predictions],
                feed_dict)
            
            if config.numClasses == 1:
            
                acc, precision, recall, f_beta = get_binary_metrics(pred_y=predictions, true_y=batchY)
            elif config.numClasses > 1:
                acc, precision, recall, f_beta = get_multi_metrics(pred_y=predictions, true_y=batchY, labels=labelList)
            
            evalSummaryWriter.add_summary(summary, step)
            
            return loss, acc, precision, recall, f_beta
        
        for i in range(config.training.epoches):
            # 训练模型
            print("start training model")
            for batchTrain in nextBatch(trainReviews, trainLabels, config.batchSize):
                loss, acc, prec, recall, f_beta = trainStep(batchTrain[0], batchTrain[1])
                
                currentStep = tf.train.global_step(sess, globalStep) 
                print("train: step: {}, loss: {}, acc: {}, recall: {}, precision: {}, f_beta: {}".format(
                    currentStep, loss, acc, recall, prec, f_beta))
                if currentStep % config.training.evaluateEvery == 0:
                    print("\nEvaluation:")
                    
                    losses = []
                    accs = []
                    f_betas = []
                    precisions = []
                    recalls = []
                    
                    for batchEval in nextBatch(evalReviews, evalLabels, config.batchSize):
                        loss, acc, precision, recall, f_beta = devStep(batchEval[0], batchEval[1])
                        losses.append(loss)
                        accs.append(acc)
                        f_betas.append(f_beta)
                        precisions.append(precision)
                        recalls.append(recall)
                        
                    time_str = datetime.datetime.now().isoformat()
                    print("{}, step: {}, loss: {}, acc: {},precision: {}, recall: {}, f_beta: {}".format(time_str, currentStep, mean(losses), 
                                                                                                       mean(accs), mean(precisions),
                                                                                                       mean(recalls), mean(f_betas)))
                    
                if currentStep % config.training.checkpointEvery == 0:
                    # 保存模型的另一种方法，保存checkpoint文件
                    path = saver.save(sess, "../model/textCNN", global_step=currentStep)
                    print("Saved model checkpoint to {}\n".format(path))
                    
        inputs = {"inputX": tf.saved_model.utils.build_tensor_info(cnn.inputX),
                  "keepProb": tf.saved_model.utils.build_tensor_info(cnn.dropoutKeepProb)}

        outputs = {"predictions": tf.saved_model.utils.build_tensor_info(cnn.predictions)}

        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs, outputs=outputs,
                                                                                      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        legacy_init_op = tf.group(tf.tables_initializer(), name="legacy_init_op")
        builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING],
                                            signature_def_map={"predict": prediction_signature}, legacy_init_op=legacy_init_op)

        builder.save()

W0726 14:04:09.657838 140058595800832 deprecation.py:506] From <ipython-input-23-32f6413f910e>:67: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
W0726 14:04:10.461930 140058595800832 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0726 14:04:10.484984 140058595800832 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be

Tensor("output/predictions:0", shape=(?, 1), dtype=int32)
Writing to /home/jqxx/jupyternotebook/textClassifier-master/textCNN/summarys

start training model
train: step: 1, loss: 1.6475231647491455, acc: 0.5546875, recall: 0.47368421052631576, precision: 0.5, f_beta: 0.4864864864864865
train: step: 2, loss: 1.501197338104248, acc: 0.5, recall: 0.5806451612903226, precision: 0.4864864864864865, f_beta: 0.5294117647058824
train: step: 3, loss: 1.5962328910827637, acc: 0.515625, recall: 0.24242424242424243, precision: 0.5714285714285714, f_beta: 0.3404255319148936
train: step: 4, loss: 1.1230379343032837, acc: 0.6171875, recall: 0.6506024096385542, precision: 0.7297297297297297, f_beta: 0.6878980891719745
train: step: 5, loss: 1.8734443187713623, acc: 0.5390625, recall: 0.9178082191780822, precision: 0.5583333333333333, f_beta: 0.694300518134715
train: step: 6, loss: 1.9405497312545776, acc: 0.5078125, recall: 0.8688524590163934, precision: 0.49074074074074076, f_beta: 0.6272189349112426


train: step: 59, loss: 0.8931604623794556, acc: 0.6171875, recall: 0.6875, precision: 0.6027397260273972, f_beta: 0.6423357664233577
train: step: 60, loss: 0.8508222103118896, acc: 0.5859375, recall: 0.6307692307692307, precision: 0.5857142857142857, f_beta: 0.6074074074074074
train: step: 61, loss: 0.7002382278442383, acc: 0.65625, recall: 0.7681159420289855, precision: 0.654320987654321, f_beta: 0.7066666666666667
train: step: 62, loss: 0.6984703540802002, acc: 0.6015625, recall: 0.6212121212121212, precision: 0.6119402985074627, f_beta: 0.6165413533834586
train: step: 63, loss: 0.7978183031082153, acc: 0.59375, recall: 0.639344262295082, precision: 0.5652173913043478, f_beta: 0.6000000000000001
train: step: 64, loss: 0.743649959564209, acc: 0.5859375, recall: 0.5757575757575758, precision: 0.6031746031746031, f_beta: 0.5891472868217055
train: step: 65, loss: 0.7123292684555054, acc: 0.640625, recall: 0.6065573770491803, precision: 0.6271186440677966, f_beta: 0.6166666666666666
train

train: step: 117, loss: 0.4765554666519165, acc: 0.7890625, recall: 0.7424242424242424, precision: 0.8305084745762712, f_beta: 0.784
train: step: 118, loss: 0.5172358751296997, acc: 0.7734375, recall: 0.7, precision: 0.8596491228070176, f_beta: 0.7716535433070866
train: step: 119, loss: 0.4521632790565491, acc: 0.7890625, recall: 0.7580645161290323, precision: 0.7966101694915254, f_beta: 0.7768595041322315
train: step: 120, loss: 0.6275324821472168, acc: 0.734375, recall: 0.7647058823529411, precision: 0.7428571428571429, f_beta: 0.7536231884057971
train: step: 121, loss: 0.4490429759025574, acc: 0.765625, recall: 0.7741935483870968, precision: 0.75, f_beta: 0.7619047619047619
train: step: 122, loss: 0.48795533180236816, acc: 0.765625, recall: 0.7910447761194029, precision: 0.7681159420289855, f_beta: 0.7794117647058824
train: step: 123, loss: 0.5513334274291992, acc: 0.7265625, recall: 0.7213114754098361, precision: 0.7096774193548387, f_beta: 0.7154471544715446
train: step: 124, loss

train: step: 176, loss: 0.37308183312416077, acc: 0.84375, recall: 0.8714285714285714, precision: 0.8472222222222222, f_beta: 0.8591549295774648
train: step: 177, loss: 0.4030953645706177, acc: 0.8203125, recall: 0.8666666666666667, precision: 0.7761194029850746, f_beta: 0.8188976377952756
train: step: 178, loss: 0.45506933331489563, acc: 0.8125, recall: 0.8142857142857143, precision: 0.8382352941176471, f_beta: 0.8260869565217392
train: step: 179, loss: 0.4809369444847107, acc: 0.796875, recall: 0.8591549295774648, precision: 0.7922077922077922, f_beta: 0.8243243243243243
train: step: 180, loss: 0.4173513650894165, acc: 0.84375, recall: 0.896551724137931, precision: 0.7878787878787878, f_beta: 0.8387096774193549
train: step: 181, loss: 0.3967845141887665, acc: 0.8203125, recall: 0.8676470588235294, precision: 0.8082191780821918, f_beta: 0.8368794326241135
train: step: 182, loss: 0.40107765793800354, acc: 0.8515625, recall: 0.8923076923076924, precision: 0.8285714285714286, f_beta: 0.8

train: step: 233, loss: 0.48875099420547485, acc: 0.765625, recall: 0.7333333333333333, precision: 0.8461538461538461, f_beta: 0.7857142857142856
train: step: 234, loss: 0.38135233521461487, acc: 0.8515625, recall: 0.8970588235294118, precision: 0.8356164383561644, f_beta: 0.8652482269503545
train: step: 235, loss: 0.41840097308158875, acc: 0.8046875, recall: 0.8181818181818182, precision: 0.75, f_beta: 0.7826086956521738
train: step: 236, loss: 0.44924530386924744, acc: 0.84375, recall: 0.9130434782608695, precision: 0.8181818181818182, f_beta: 0.863013698630137
train: step: 237, loss: 0.27242574095726013, acc: 0.8984375, recall: 0.9285714285714286, precision: 0.8904109589041096, f_beta: 0.9090909090909091
train: step: 238, loss: 0.47711288928985596, acc: 0.75, recall: 0.8103448275862069, precision: 0.6911764705882353, f_beta: 0.746031746031746
train: step: 239, loss: 0.44424647092819214, acc: 0.765625, recall: 0.746031746031746, precision: 0.7704918032786885, f_beta: 0.75806451612903

train: step: 291, loss: 0.3752177357673645, acc: 0.8515625, recall: 0.75, precision: 0.8936170212765957, f_beta: 0.8155339805825244
train: step: 292, loss: 0.3830568194389343, acc: 0.828125, recall: 0.8405797101449275, precision: 0.8405797101449275, f_beta: 0.8405797101449275
train: step: 293, loss: 0.44956162571907043, acc: 0.8125, recall: 0.8852459016393442, precision: 0.7605633802816901, f_beta: 0.8181818181818182
train: step: 294, loss: 0.2633017301559448, acc: 0.875, recall: 0.9047619047619048, precision: 0.8507462686567164, f_beta: 0.8769230769230769
train: step: 295, loss: 0.29562124609947205, acc: 0.875, recall: 0.8852459016393442, precision: 0.8571428571428571, f_beta: 0.8709677419354839
train: step: 296, loss: 0.3987821936607361, acc: 0.8125, recall: 0.9122807017543859, precision: 0.7323943661971831, f_beta: 0.8124999999999999
train: step: 297, loss: 0.40164846181869507, acc: 0.8203125, recall: 0.7910447761194029, precision: 0.8548387096774194, f_beta: 0.8217054263565892
trai

train: step: 347, loss: 0.3147275149822235, acc: 0.8671875, recall: 0.8305084745762712, precision: 0.875, f_beta: 0.8521739130434782
train: step: 348, loss: 0.34190019965171814, acc: 0.84375, recall: 0.84375, precision: 0.84375, f_beta: 0.84375
train: step: 349, loss: 0.23805439472198486, acc: 0.8828125, recall: 0.9230769230769231, precision: 0.8571428571428571, f_beta: 0.888888888888889
train: step: 350, loss: 0.3030197024345398, acc: 0.859375, recall: 0.8571428571428571, precision: 0.8823529411764706, f_beta: 0.8695652173913043
train: step: 351, loss: 0.31672924757003784, acc: 0.8515625, recall: 0.8888888888888888, precision: 0.8533333333333334, f_beta: 0.8707482993197277
train: step: 352, loss: 0.24094133079051971, acc: 0.9140625, recall: 0.8805970149253731, precision: 0.9516129032258065, f_beta: 0.9147286821705426
train: step: 353, loss: 0.2772025167942047, acc: 0.859375, recall: 0.8888888888888888, precision: 0.8, f_beta: 0.8421052631578948
train: step: 354, loss: 0.29071122407913

train: step: 404, loss: 0.3899232745170593, acc: 0.8203125, recall: 0.704225352112676, precision: 0.9615384615384616, f_beta: 0.8130081300813008
train: step: 405, loss: 0.34054034948349, acc: 0.84375, recall: 0.8245614035087719, precision: 0.8245614035087719, f_beta: 0.8245614035087719
train: step: 406, loss: 0.3089768886566162, acc: 0.859375, recall: 0.8493150684931506, precision: 0.8985507246376812, f_beta: 0.8732394366197183
train: step: 407, loss: 0.3226710557937622, acc: 0.8515625, recall: 0.8688524590163934, precision: 0.828125, f_beta: 0.8480000000000001
train: step: 408, loss: 0.2831882834434509, acc: 0.8828125, recall: 0.9104477611940298, precision: 0.8714285714285714, f_beta: 0.8905109489051095
train: step: 409, loss: 0.29307836294174194, acc: 0.875, recall: 0.9393939393939394, precision: 0.8378378378378378, f_beta: 0.8857142857142858
train: step: 410, loss: 0.27137312293052673, acc: 0.875, recall: 0.8955223880597015, precision: 0.8695652173913043, f_beta: 0.8823529411764706


train: step: 462, loss: 0.22211503982543945, acc: 0.921875, recall: 0.9411764705882353, precision: 0.9142857142857143, f_beta: 0.9275362318840579
train: step: 463, loss: 0.27166748046875, acc: 0.859375, recall: 0.9130434782608695, precision: 0.84, f_beta: 0.8749999999999999
train: step: 464, loss: 0.28467291593551636, acc: 0.8671875, recall: 0.9545454545454546, precision: 0.8181818181818182, f_beta: 0.881118881118881
train: step: 465, loss: 0.2690189480781555, acc: 0.890625, recall: 0.8870967741935484, precision: 0.8870967741935484, f_beta: 0.8870967741935484
train: step: 466, loss: 0.3424748480319977, acc: 0.8359375, recall: 0.9833333333333333, precision: 0.7468354430379747, f_beta: 0.8489208633093525
train: step: 467, loss: 0.1822190284729004, acc: 0.9296875, recall: 0.9516129032258065, precision: 0.9076923076923077, f_beta: 0.9291338582677167
train: step: 468, loss: 0.2718096971511841, acc: 0.9140625, recall: 0.8524590163934426, precision: 0.9629629629629629, f_beta: 0.9043478260869

train: step: 519, loss: 0.2860107719898224, acc: 0.875, recall: 0.8840579710144928, precision: 0.8840579710144928, f_beta: 0.8840579710144928
train: step: 520, loss: 0.20289252698421478, acc: 0.9296875, recall: 0.9726027397260274, precision: 0.9102564102564102, f_beta: 0.9403973509933774
train: step: 521, loss: 0.221758633852005, acc: 0.9140625, recall: 0.9538461538461539, precision: 0.8857142857142857, f_beta: 0.9185185185185185
train: step: 522, loss: 0.20125478506088257, acc: 0.90625, recall: 0.9122807017543859, precision: 0.8813559322033898, f_beta: 0.8965517241379309
train: step: 523, loss: 0.17691631615161896, acc: 0.9375, recall: 0.9333333333333333, precision: 0.9333333333333333, f_beta: 0.9333333333333333
train: step: 524, loss: 0.2191617786884308, acc: 0.9296875, recall: 0.9074074074074074, precision: 0.9245283018867925, f_beta: 0.9158878504672898
train: step: 525, loss: 0.23411208391189575, acc: 0.90625, recall: 0.8939393939393939, precision: 0.921875, f_beta: 0.9076923076923

train: step: 577, loss: 0.2653692364692688, acc: 0.8984375, recall: 0.9253731343283582, precision: 0.8857142857142857, f_beta: 0.9051094890510949
train: step: 578, loss: 0.2126491665840149, acc: 0.9140625, recall: 0.9538461538461539, precision: 0.8857142857142857, f_beta: 0.9185185185185185
train: step: 579, loss: 0.20462453365325928, acc: 0.921875, recall: 0.9154929577464789, precision: 0.9420289855072463, f_beta: 0.9285714285714286
train: step: 580, loss: 0.230033740401268, acc: 0.921875, recall: 0.9436619718309859, precision: 0.9178082191780822, f_beta: 0.9305555555555556
train: step: 581, loss: 0.24571287631988525, acc: 0.8671875, recall: 0.8615384615384616, precision: 0.875, f_beta: 0.8682170542635659
train: step: 582, loss: 0.2663164436817169, acc: 0.8984375, recall: 0.95, precision: 0.8507462686567164, f_beta: 0.8976377952755905
train: step: 583, loss: 0.1354064643383026, acc: 0.96875, recall: 0.9692307692307692, precision: 0.9692307692307692, f_beta: 0.9692307692307692
train: s

W0726 14:07:26.579973 140058595800832 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:960: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.


2019-07-26T14:07:26.509093, step: 600, loss: 0.33671695414261943, acc: 0.8555689102564102,precision: 0.8028892306420004, recall: 0.9013923886161445, f_beta: 0.8480999110608156
Saved model checkpoint to ../model/textCNN-600

train: step: 601, loss: 0.24794968962669373, acc: 0.90625, recall: 0.8529411764705882, precision: 0.9666666666666667, f_beta: 0.90625
train: step: 602, loss: 0.26436248421669006, acc: 0.9375, recall: 0.8939393939393939, precision: 0.9833333333333333, f_beta: 0.9365079365079364
train: step: 603, loss: 0.19413676857948303, acc: 0.9140625, recall: 0.8787878787878788, precision: 0.9508196721311475, f_beta: 0.9133858267716536
train: step: 604, loss: 0.16696089506149292, acc: 0.953125, recall: 0.9655172413793104, precision: 0.9333333333333333, f_beta: 0.9491525423728815
train: step: 605, loss: 0.20650047063827515, acc: 0.9375, recall: 0.967741935483871, precision: 0.9090909090909091, f_beta: 0.9374999999999999
train: step: 606, loss: 0.2494959682226181, acc: 0.890625, rec

train: step: 658, loss: 0.19598627090454102, acc: 0.9453125, recall: 0.927536231884058, precision: 0.9696969696969697, f_beta: 0.9481481481481481
train: step: 659, loss: 0.199200838804245, acc: 0.9375, recall: 0.9206349206349206, precision: 0.9508196721311475, f_beta: 0.9354838709677418
train: step: 660, loss: 0.13629686832427979, acc: 0.9609375, recall: 0.9824561403508771, precision: 0.9333333333333333, f_beta: 0.9572649572649572
train: step: 661, loss: 0.16777881979942322, acc: 0.9375, recall: 0.9344262295081968, precision: 0.9344262295081968, f_beta: 0.9344262295081968
train: step: 662, loss: 0.11557561904191971, acc: 0.984375, recall: 0.9836065573770492, precision: 0.9836065573770492, f_beta: 0.9836065573770492
train: step: 663, loss: 0.1638304740190506, acc: 0.9296875, recall: 0.8904109589041096, precision: 0.9848484848484849, f_beta: 0.935251798561151
train: step: 664, loss: 0.16220177710056305, acc: 0.9375, recall: 0.9666666666666667, precision: 0.90625, f_beta: 0.93548387096774

train: step: 715, loss: 0.12893709540367126, acc: 0.953125, recall: 0.9545454545454546, precision: 0.9545454545454546, f_beta: 0.9545454545454546
train: step: 716, loss: 0.16032053530216217, acc: 0.9296875, recall: 0.9230769230769231, precision: 0.9375, f_beta: 0.9302325581395349
train: step: 717, loss: 0.1466818004846573, acc: 0.9453125, recall: 0.9846153846153847, precision: 0.9142857142857143, f_beta: 0.9481481481481482
train: step: 718, loss: 0.15627935528755188, acc: 0.9453125, recall: 0.9393939393939394, precision: 0.9538461538461539, f_beta: 0.9465648854961831
train: step: 719, loss: 0.19274261593818665, acc: 0.9140625, recall: 0.8970588235294118, precision: 0.9384615384615385, f_beta: 0.9172932330827067
train: step: 720, loss: 0.1222994327545166, acc: 0.9765625, recall: 0.9846153846153847, precision: 0.9696969696969697, f_beta: 0.9770992366412214
train: step: 721, loss: 0.2129334807395935, acc: 0.9453125, recall: 0.9827586206896551, precision: 0.9047619047619048, f_beta: 0.9421

train: step: 773, loss: 0.22417190670967102, acc: 0.90625, recall: 0.8113207547169812, precision: 0.9555555555555556, f_beta: 0.8775510204081634
train: step: 774, loss: 0.1333397924900055, acc: 0.96875, recall: 0.9491525423728814, precision: 0.9824561403508771, f_beta: 0.9655172413793103
train: step: 775, loss: 0.11234234273433685, acc: 0.96875, recall: 0.9692307692307692, precision: 0.9692307692307692, f_beta: 0.9692307692307692
train: step: 776, loss: 0.09409839659929276, acc: 0.9765625, recall: 0.9852941176470589, precision: 0.9710144927536232, f_beta: 0.9781021897810219
train: step: 777, loss: 0.16091269254684448, acc: 0.953125, recall: 0.967741935483871, precision: 0.9375, f_beta: 0.9523809523809523
train: step: 778, loss: 0.18320634961128235, acc: 0.953125, recall: 0.971830985915493, precision: 0.9452054794520548, f_beta: 0.9583333333333334
train: step: 779, loss: 0.17399711906909943, acc: 0.9140625, recall: 0.9230769230769231, precision: 0.9090909090909091, f_beta: 0.91603053435

W0726 14:08:16.673212 140058595800832 deprecation.py:323] From <ipython-input-28-c2e122e5ff6c>:149: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0726 14:08:16.676818 140058595800832 deprecation.py:506] From <ipython-input-28-c2e122e5ff6c>:158: calling SavedModelBuilder.add_meta_graph_and_variables (from tensorflow.python.saved_model.builder_impl) with legacy_init_op is deprecated and will be removed in a future version.
Instructions for updating:
Pass your op to the equivalent parameter main_op instead.


train: step: 780, loss: 0.14672808349132538, acc: 0.9609375, recall: 0.9733333333333334, precision: 0.9605263157894737, f_beta: 0.9668874172185431


In [31]:
x = "this movie is full of references like mad max ii the wild one and many others the ladybug´s face it´s a clear reference or tribute to peter lorre this movie is a masterpiece we´ll talk much more about in the future"

# 注：下面两个词典要保证和当前加载的模型对应的词典是一致的
with open("../data/wordJson/word2idx.json", "r", encoding="utf-8") as f:
    word2idx = json.load(f)
        
with open("../data/wordJson/label2idx.json", "r", encoding="utf-8") as f:
    label2idx = json.load(f)
idx2label = {value: key for key, value in label2idx.items()}
    
xIds = [word2idx.get(item, word2idx["UNK"]) for item in x.split(" ")]
if len(xIds) >= config.sequenceLength:
    xIds = xIds[:config.sequenceLength]
else:
    xIds = xIds + [word2idx["PAD"]] * (config.sequenceLength - len(xIds))

graph = tf.Graph()
with graph.as_default():
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options)
    sess = tf.Session(config=session_conf)

    with sess.as_default():
        checkpoint_file = tf.train.latest_checkpoint("../model")
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # 获得需要喂给模型的参数，输出的结果依赖的输入值
        inputX = graph.get_operation_by_name("inputX").outputs[0]
        dropoutKeepProb = graph.get_operation_by_name("dropoutKeepProb").outputs[0]

        # 获得输出的结果
        predictions = graph.get_tensor_by_name("output/predictions:0")

        pred = sess.run(predictions, feed_dict={inputX: [xIds], dropoutKeepProb: 1.0})[0]
        
pred = [idx2label[item] for item in pred]     
print(pred)

W0726 14:13:23.752521 140058595800832 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


['1']
