# In [None]:
"""Trains and evaluates the MNIST network using a feed dictionary."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

# Basic model parameters as external flags. Stays None until the
# `__main__` section assigns the parsed command-line arguments to it.
FLAGS = None


def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.

  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop below.

  Args:
    batch_size: The batch size, baked into the first dimension of both
      placeholders.

  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(
      tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
  # The trailing comma makes this a real 1-tuple; `(batch_size)` is only a
  # parenthesised int that relied on TensorFlow coercing a scalar to a shape.
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
  return images_placeholder, labels_placeholder


def fill_feed_dict(data_set, images_pl, labels_pl):
  """Build the feed_dict for one step from the next batch of `data_set`.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from
      input_data.read_data_sets().
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Fetch the next FLAGS.batch_size examples and map each placeholder to
  # its slice of the batch.
  batch_images, batch_labels = data_set.next_batch(
      FLAGS.batch_size, FLAGS.fake_data)
  return {images_pl: batch_images, labels_pl: batch_labels}


def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Run one evaluation against a full epoch of data and print the result.

  Only whole batches are evaluated: any examples left over after dividing
  the data set into FLAGS.batch_size batches are not counted.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # Run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for _ in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  # When the data set is smaller than one batch nothing is evaluated and
  # num_examples is 0; report 0 precision instead of dividing by zero.
  precision = float(true_count) / num_examples if num_examples else 0.0
  print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))


def run_training():
  """Train MNIST for a number of steps.

  Reads the data sets, builds the inference/loss/training/evaluation graph,
  then runs FLAGS.max_steps training steps. Every 100 steps the loss is
  printed and summaries are written; every 1000 steps and on the final step
  a checkpoint is saved and the model is evaluated on the training,
  validation and test sets.

  The session and the summary writer are closed when training ends, even if
  it ends with an exception.
  """
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Merge all the summaries into a single Op.
    summary = tf.summary.merge_all()

    # Add the variable initializer Op.
    init = tf.global_variables_initializer()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph. The `with` block
    # closes the session on exit instead of leaking it.
    with tf.Session() as sess:
      # Instantiate a SummaryWriter to output summaries and the Graph.
      summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
      try:
        # And then after everything is built:

        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
          start_time = time.time()

          # Fill a feed dictionary with the actual set of images and labels
          # for this particular training step.
          feed_dict = fill_feed_dict(data_sets.train,
                                     images_placeholder,
                                     labels_placeholder)

          # Run one step of the model. The return values are the
          # activations from the `train_op` (which is discarded) and the
          # `loss` Op. To inspect the values of your Ops or variables, you
          # may include them in the list passed to sess.run() and the value
          # tensors will be returned in the tuple from the call.
          _, loss_value = sess.run([train_op, loss],
                                   feed_dict=feed_dict)

          duration = time.time() - start_time

          # Write the summaries and print an overview fairly often.
          if step % 100 == 0:
            # Print status to stdout.
            print('Step %d: loss = %.2f (%.3f sec)' %
                  (step, loss_value, duration))
            # Update the events file.
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

          # Save a checkpoint and evaluate the model periodically.
          if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
            saver.save(sess, checkpoint_file, global_step=step)
            # Evaluate against the training set.
            print('Training Data Eval:')
            do_eval(sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.train)
            # Evaluate against the validation set.
            print('Validation Data Eval:')
            do_eval(sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.validation)
            # Evaluate against the test set.
            print('Test Data Eval:')
            do_eval(sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.test)
      finally:
        # Close the events file so the writer is not leaked, including
        # when training is interrupted by an exception.
        summary_writer.close()


def main(_):
  """Reset the log directory, then run training.

  Any existing FLAGS.log_dir is deleted recursively and recreated empty
  before run_training() is called. The positional argument is ignored.
  """
  log_dir = FLAGS.log_dir
  if tf.gfile.Exists(log_dir):
    tf.gfile.DeleteRecursively(log_dir)
  tf.gfile.MakeDirs(log_dir)
  run_training()


if __name__ == '__main__':
  # Default data and log directories live under TEST_TMPDIR when it is set,
  # otherwise under /tmp.
  tmp_root = os.getenv('TEST_TMPDIR', '/tmp')

  # Command-line flags, registered in order: (flag name, add_argument kwargs).
  flag_specs = (
      ('--learning_rate',
       dict(type=float, default=0.01, help='Initial learning rate.')),
      ('--max_steps',
       dict(type=int, default=2000, help='Number of steps to run trainer.')),
      ('--hidden1',
       dict(type=int, default=128,
            help='Number of units in hidden layer 1.')),
      ('--hidden2',
       dict(type=int, default=32,
            help='Number of units in hidden layer 2.')),
      ('--batch_size',
       dict(type=int, default=100,
            help='Batch size.  Must divide evenly into the dataset sizes.')),
      ('--input_data_dir',
       dict(type=str,
            default=os.path.join(tmp_root, 'tensorflow/mnist/input_data'),
            help='Directory to put the input data.')),
      ('--log_dir',
       dict(type=str,
            default=os.path.join(
                tmp_root, 'tensorflow/mnist/logs/fully_connected_feed'),
            help='Directory to put the log data.')),
      ('--fake_data',
       dict(default=False, action='store_true',
            help='If true, uses fake data for unit testing.')),
  )

  parser = argparse.ArgumentParser()
  for flag_name, flag_kwargs in flag_specs:
    parser.add_argument(flag_name, **flag_kwargs)

  # Recognised flags go to FLAGS; everything else is forwarded to
  # tf.app.run together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
