In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 数据集
import LeNet5_infernece
import os
import numpy as np

#### 1. 定义神经网络相关的参数

In [2]:
BATCH_SIZE = 100               # number of examples per training batch
LEARNING_RATE_BASE = 1e-2      # initial learning rate
LEARNING_RATE_DECAY = 0.99     # learning-rate decay factor
REGULARIZATION_RATE = 1e-4     # L2 regularization coefficient
TRAINING_STEPS = 6000          # number of training iterations
MOVING_AVERAGE_DECAY = 0.99    # decay rate of the moving average

#### 2. 定义训练过程

In [3]:
def train(mnist):
    """Build the LeNet-5 training graph and run TRAINING_STEPS optimisation steps.

    The graph is built inside a fresh ``tf.Graph`` rather than the process-wide
    default graph. This function reads the graph-global collections
    ``tf.trainable_variables()`` and ``tf.get_collection('losses')``; on the
    shared default graph a second call in the same kernel (e.g. re-running the
    notebook cell) would pick up variables and regularization terms left over
    from the previous call.

    Args:
        mnist: Dataset object exposing ``train.next_batch(batch_size)`` and
            ``train.num_examples``. Labels must be one-hot encoded, because
            they are converted to class indices with ``tf.argmax(y_, 1)``.

    Returns:
        None. The training-batch loss is printed every 1000 iterations.
        No checkpoint is written; the trained weights are discarded when the
        session closes.
    """
    # Shape shared by the input placeholder and the reshaped NumPy batches:
    # (batch size, image height, image width, channels).
    input_shape = [
        BATCH_SIZE,
        LeNet5_infernece.IMAGE_SIZE,
        LeNet5_infernece.IMAGE_SIZE,
        LeNet5_infernece.NUM_CHANNELS]

    with tf.Graph().as_default():
        # 4-D input placeholder and label placeholder.
        x = tf.placeholder(tf.float32, input_shape, name='x-input')
        y_ = tf.placeholder(
            tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input')

        # Forward pass with an L2 regularizer.
        # NOTE(review): the second argument is passed as False; presumably it
        # is a "train" flag of inference() -- confirm against LeNet5_infernece.
        regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
        y = LeNet5_infernece.inference(x, False, regularizer)
        global_step = tf.Variable(0, trainable=False)  # current training step

        # Moving average over all trainable variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())

        # Cross-entropy on class indices recovered from the one-hot labels.
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=y, labels=tf.argmax(y_, 1))
        cross_entropy_mean = tf.reduce_mean(cross_entropy)

        # Total loss = mean cross-entropy + terms in the 'losses' collection.
        # tf.add_n rejects an empty list, so only add when terms were registered.
        regularization_losses = tf.get_collection('losses')
        loss = cross_entropy_mean
        if regularization_losses:
            loss = cross_entropy_mean + tf.add_n(regularization_losses)

        # Staircase exponential decay of the learning rate.
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE,
            global_step,
            mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
            staircase=True)

        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, global_step=global_step)
        # Single op that runs both the gradient step and the moving-average update.
        with tf.control_dependencies([train_step, variables_averages_op]):
            train_op = tf.no_op(name='train')

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(TRAINING_STEPS):
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                # Reshape the batch to the 4-D shape of the placeholder.
                reshaped_xs = np.reshape(xs, input_shape)
                _, loss_value, step = sess.run(
                    [train_op, loss, global_step],
                    feed_dict={x: reshaped_xs, y_: ys})

                # Report the current step and loss every 1000 iterations.
                if i % 1000 == 0:
                    print("After %d training step(s), loss on training batch is %g." % (step, loss_value))

#### 3. 主程序入口

In [4]:
def main(argv=None):
    """Load the MNIST dataset with one-hot labels and train the network on it.

    Args:
        argv: Unused by this function; accepted and ignored.
    """
    train(input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True))


# Run training when this code is executed as the main module.
if __name__ == '__main__':
    main()

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../../datasets/MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
After 1 training step(s), loss on training batch is 4.61288.
After 1001 training step(s), loss on training batch is 0.689922.
After 2001 training step(s), loss on training batch is 0.660581.
After 3001 training step(s), loss on t