# 导入参数

In [4]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


# 模型建构

## 输入层

In [5]:
x = tf.placeholder(tf.float32, [None, 784], name='X')
y = tf.placeholder(tf.float32, [None, 10], name='Y')

## 隐藏层

In [6]:
# 隐藏层神经元数量
H1_NN = 256

with tf.name_scope('Hidden_Layer'):
    W1 = tf.Variable(tf.random_normal([784, H1_NN]))
    b1 = tf.Variable(tf.zeros([H1_NN]))
    
    Y1 = tf.nn.relu(tf.matmul(x, W1) + b1)

## 输出层

In [8]:
with tf.name_scope('Output_Layer'):
    W2 = tf.Variable(tf.random_normal([H1_NN, 10]))
    b2 = tf.Variable(tf.zeros([10]))
    
    forward = tf.matmul(Y1, W2) + b2
    pred = tf.nn.softmax(forward)

# 模型训练
## 参数及损失函数、优化器的定义

In [10]:
# 训练参数
train_epochs = 40
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size)
display_step = 1
learning_rate = 0.01

# 损失函数
# 运用TensorFlow提供的结合Softmax的交叉熵损失函数
# 避免因为log(0)值为NaN造成的数据不稳定
loss_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))

# 选择优化器
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_func)

## 定义准确率

In [12]:
correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

## 训练模型

In [15]:
from time import time
startTime = time()

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer, feed_dict={x: xs, y: ys})
        
    loss, acc = sess.run([loss_func, accuracy], 
                         feed_dict={x: mnist.validation.images,
                                    y: mnist.validation.labels})
    
    if (epoch + 1) % display_step == 0:
        print('Train Epoch: %02d' % (epoch + 1),
              'Loss:', '{:0.9f}'.format(loss),
              'Accuracy:', '{:0.4f}'.format(acc))

# 显示运行时间
duration = time() - startTime
print('Finished. Train takes:', '{:0.2f}'.format(duration))

Train Epoch: 01 Loss: 1.169759631 Accuracy: 0.9334
Train Epoch: 02 Loss: 0.721808016 Accuracy: 0.9442
Train Epoch: 03 Loss: 0.583913028 Accuracy: 0.9500
Train Epoch: 04 Loss: 0.455841362 Accuracy: 0.9504
Train Epoch: 05 Loss: 0.408018410 Accuracy: 0.9494
Train Epoch: 06 Loss: 0.402976066 Accuracy: 0.9546
Train Epoch: 07 Loss: 0.411095887 Accuracy: 0.9488
Train Epoch: 08 Loss: 0.412107706 Accuracy: 0.9532
Train Epoch: 09 Loss: 0.375777334 Accuracy: 0.9624
Train Epoch: 10 Loss: 0.347757310 Accuracy: 0.9606
Train Epoch: 11 Loss: 0.357870251 Accuracy: 0.9654
Train Epoch: 12 Loss: 0.388767958 Accuracy: 0.9632
Train Epoch: 13 Loss: 0.473384947 Accuracy: 0.9652
Train Epoch: 14 Loss: 0.430510193 Accuracy: 0.9640
Train Epoch: 15 Loss: 0.436181456 Accuracy: 0.9676
Train Epoch: 16 Loss: 0.433787107 Accuracy: 0.9666
Train Epoch: 17 Loss: 0.558526158 Accuracy: 0.9690
Train Epoch: 18 Loss: 0.524459422 Accuracy: 0.9672
Train Epoch: 19 Loss: 0.541688263 Accuracy: 0.9666
Train Epoch: 20 Loss: 0.4340189

# 模型评估

In [17]:
acc_test = sess.run(accuracy,
                    feed_dict={x: mnist.test.images,
                               y: mnist.test.labels})
print('Test Accuracy:', acc_test)

Test Accuracy: 0.9695


# 模型应用
## 预测

In [31]:
prediction_result = sess.run(tf.argmax(pred, 1),
                             feed_dict={x:mnist.test.images})

## 找出错误预测项

In [32]:
compare_list = prediction_result == np.argmax(mnist.test.labels, 1)
err_list = [i for i in range(len(compare_list)) if compare_list[i]==False]
print(err_list, len(err_list))

[92, 121, 125, 199, 241, 247, 259, 321, 340, 359, 403, 404, 447, 456, 469, 495, 522, 582, 591, 659, 674, 684, 691, 707, 720, 726, 797, 839, 857, 877, 900, 924, 947, 951, 956, 965, 990, 1002, 1014, 1032, 1039, 1082, 1107, 1112, 1173, 1178, 1181, 1192, 1226, 1228, 1232, 1242, 1247, 1256, 1289, 1319, 1331, 1349, 1378, 1393, 1394, 1395, 1425, 1494, 1500, 1522, 1530, 1549, 1551, 1553, 1642, 1669, 1670, 1678, 1681, 1687, 1717, 1730, 1754, 1828, 1850, 1871, 1901, 1941, 1952, 2004, 2024, 2035, 2043, 2044, 2053, 2058, 2063, 2070, 2093, 2107, 2109, 2129, 2130, 2135, 2182, 2189, 2237, 2272, 2293, 2325, 2369, 2387, 2395, 2406, 2414, 2426, 2433, 2462, 2488, 2526, 2573, 2577, 2597, 2631, 2648, 2654, 2730, 2742, 2743, 2758, 2877, 2896, 2921, 2927, 2939, 2953, 2979, 3005, 3030, 3073, 3102, 3115, 3183, 3225, 3267, 3280, 3289, 3339, 3503, 3520, 3558, 3559, 3567, 3574, 3575, 3604, 3626, 3629, 3751, 3757, 3762, 3776, 3780, 3853, 3906, 3926, 3941, 3967, 3968, 3976, 3985, 4007, 4027, 4075, 4078, 4131, 4140,

## 定义输出错误的函数

In [45]:
def print_predict_err(labels, # 标签列表
                      prediction): # 预测值列表
    count = 0
    compare_list = prediction_result == np.argmax(mnist.test.labels, 1)
    err_list = [i for i in range(len(compare_list)) if compare_list[i]==False]
    for x in err_list:
        print('index: %03d' % (x),
              '标签值:', np.argmax(labels[x]),
              '预测值:', prediction[x])
        count += 1
    print('Total:', count)

In [46]:
print_predict_err(labels=mnist.test.labels, 
                  prediction=prediction_result)

index: 092 标签值: 9 预测值: 4
index: 121 标签值: 4 预测值: 6
index: 125 标签值: 9 预测值: 4
index: 199 标签值: 2 预测值: 3
index: 241 标签值: 9 预测值: 3
index: 247 标签值: 4 预测值: 6
index: 259 标签值: 6 预测值: 0
index: 321 标签值: 2 预测值: 7
index: 340 标签值: 5 预测值: 3
index: 359 标签值: 9 预测值: 4
index: 403 标签值: 8 预测值: 5
index: 404 标签值: 2 预测值: 7
index: 447 标签值: 4 预测值: 7
index: 456 标签值: 2 预测值: 1
index: 469 标签值: 5 预测值: 3
index: 495 标签值: 8 预测值: 0
index: 522 标签值: 7 预测值: 3
index: 582 标签值: 8 预测值: 2
index: 591 标签值: 8 预测值: 0
index: 659 标签值: 2 预测值: 3
index: 674 标签值: 5 预测值: 3
index: 684 标签值: 7 预测值: 3
index: 691 标签值: 8 预测值: 4
index: 707 标签值: 4 预测值: 9
index: 720 标签值: 5 预测值: 8
index: 726 标签值: 7 预测值: 3
index: 797 标签值: 5 预测值: 8
index: 839 标签值: 8 预测值: 3
index: 857 标签值: 5 预测值: 3
index: 877 标签值: 8 预测值: 6
index: 900 标签值: 1 预测值: 3
index: 924 标签值: 2 预测值: 3
index: 947 标签值: 8 预测值: 9
index: 951 标签值: 5 预测值: 4
index: 956 标签值: 1 预测值: 2
index: 965 标签值: 6 预测值: 0
index: 990 标签值: 2 预测值: 4
index: 1002 标签值: 2 预测值: 3
index: 1014 标签值: 6 预测值: 0
index: 1032 标签值: 5 预测值: