# 可视化计算图

本例可视化mnist手写字识别的计算图

In [9]:
# 加载库文件
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [10]:
tf.reset_default_graph()

# 定义参数
INPUT_NODE = 784
LAYER1_NODE = 500
OUTPUT_NODE = 10

BATCH_SIZE = 32              # batch的大小
LEARNING_RATE_BASE = 0.8     # 基础学习率
LEARNING_RATE_DECAY = 0.99   # 学习率的衰减系数
LEARNING_DECAY_STEPS = 50    # 循环一次数据集的轮数
REGULARIZATION_RATE = 0.0001 # 正则化系数
TRAINING_STEPS = 3000        # 训练的步数
MOVING_AVERAGE_DECAY = 0.99  # 动量系数

LOG_PATH = "../../../../other/test.log"

global_step = tf.Variable(0, trainable=False, name="global_step")

In [11]:
# 获取数据集
mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)

# 计算LEARING_DECAY_STEP
LEARNING_DECAY_STEPS = mnist.train.num_examples/BATCH_SIZE

# 1.输入层
with tf.name_scope("Input"):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name="y-input")

Extracting ../../../datasets/MNIST_data\train-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data\train-labels-idx1-ubyte.gz
Extracting ../../../datasets/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data\t10k-labels-idx1-ubyte.gz


In [12]:
# 2.隐含层
with tf.name_scope("Layer1"):
    w1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE]), name="w1")
    b1 = tf.Variable(tf.truncated_normal([LAYER1_NODE]), name="b1")
    before_activate1 = tf.add(tf.matmul(x, w1), b1, name="layer1_add")
    layer1 = tf.nn.leaky_relu(before_activate1, name="layer1_result")
    regularizer_w1 = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)(w1)
    tf.add_to_collection("regularizer", regularizer_w1)
    
    # 把相关参数添加进直方图
    tf.summary.histogram("w1", w1)
    tf.summary.histogram("b1", b1)
    tf.summary.histogram("before_activate1", before_activate1)
    tf.summary.histogram("layer1", layer1)

In [13]:
# 3.输出层
with tf.name_scope("Output"):
    w2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE]), name="w2")
    b2 = tf.Variable(tf.truncated_normal([OUTPUT_NODE]), name="b2")
    before_activate2 = tf.add(tf.matmul(layer1, w2), b2, name="output_add")
    regularizer_w2 = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)(w2)
    tf.add_to_collection("regularizer", regularizer_w2)
    
    # 把相关参数添加进直方图
    tf.summary.histogram("w2", w2)
    tf.summary.histogram("b2", b2)
    tf.summary.histogram("before_activate2", before_activate2)

In [14]:
# 4.预测值
with tf.name_scope("prediction"):
    output_result = tf.nn.softmax(before_activate2, name="prediction")
    prediction = tf.equal(tf.argmax(output_result), tf.argmax(y_))
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

In [15]:
# 5. 申请动量操作
with tf.name_scope("movingaverage"):
    # 申请动量
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    # 把一切可以进行训练的变量加入动量
    maintain_average_op = ema.apply(tf.trainable_variables())

In [16]:
# 5.损失函数
with tf.name_scope("Loss"):
    # 计算交叉熵
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=before_activate2, name="cross_entropy")
    # 计算交叉熵的平均值
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # 计算交叉熵
    regularizer_loss = tf.add_n(tf.get_collection("regularizer"), name="add_regularizer_loss")
    # 计算总的算是函数
    sum_loss = tf.add(cross_entropy, regularizer_loss, name="add_all_loss")
    
    # 把损失函数添加进图示
    tf.summary.scalar("sum_loss", sum_loss)

In [17]:
# 6.优化空间
with tf.name_scope("optimize"):
    # 定义可变学习率
    update_learning = tf.train.exponential_decay(learning_rate=LEARNING_RATE_BASE,
                                                decay_steps=50,
                                                decay_rate=LEARNING_RATE_DECAY,
                                                global_step=global_step,
                                                name="learning")
    # 定义优化操作
    optimize_op = tf.train.GradientDescentOptimizer(update_learning).minimize(sum_loss, global_step=global_step)
    
# 7. 定义依赖操作
with tf.control_dependencies([optimize_op, maintain_average_op]):
    main_op = tf.no_op(name="main_op")

In [19]:
# 配置会话参数
config = tf.ConfigProto(allow_soft__placement=True, 
                        log_device_placement=True)
config.gpu_options.allow_growth = True

# 定义 summary 的综合操作
summary_merge = tf.summary.merge_all()

# 7. 进行训练和预测
with tf.Session(config=config) as sess:
    # 定义写日志FileWriter
    FileWriter = tf.summary.FileWriter(LOG_PATH, sess.graph)
    
    # 变量初始化
    tf.global_variables_initializer().run()
    
    # 迭代过程
    for i in range(TRAINING_STEPS):
        # 获取batch数据
        xs, ys = mnist.train_next_batch(batch_size=batch_size)
        tf.summary.image("image", tensor=xs, max_outputs=BATCH_SIZE)
        
        # 每100次保存运行时的信息
        if i % 100 == 0:
            # 配置需要记录的信息种类
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            
            # 建立记录信息的protocol
            run_metadata = tf.RunMetadata()
            
            # 将配置信息和记录运行信息的protocol传入运行的过程，
            # 从而记录运行时每一个节点的时间、空间开销信息
            _, loss, step, merge_all = sess.run([main_op, loss, global_step, summary_merge], 
                                     feed_dict={x:xs, y_:ys},
                                     options=run_options,
                                     run_metadata=run_metadata)
            # 将节点在运行时的信息写入日志文件
            FileWriter.add_run_metadata(run_metadata=run_metadata,
                                        tag=("step%3d" % i),
                                        global_step=i)
            # 将 summary 数据写入日志中
            FileWriter.add_summary(merge_all, i)
            
            # 打印运行的信息
            printf("After steps(%d) loss%g." % (step, loss))
            
        else:
            # 运行计算图
            _, loss, step, merge_all= sess.run([main_op, sum_loss, global_step, summary_merge], feed_dict={x: xs, y_:ys})
            # 将 summary 数据写入日志中
            FileWriter.add_summary(merge_all, i)
        
    # 关闭日志
    FileWriter.close()

SyntaxError: keyword can't be an expression (<ipython-input-19-3acb969b9623>, line 4)