In [25]:
import os
from time import time

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Clear the default graph so that re-running this notebook from this cell
# starts from an empty graph instead of adding nodes to the previous one.
tf.reset_default_graph()

# Load the MNIST data sets from MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [26]:
# Model construction
# This example is a network with two hidden layers
H1_NN = 256  # number of units in hidden layer 1
H2_NN = 64   # number of units in hidden layer 2

# 定义全连接层函数
def fcn_layer(inputs, input_dim, output_dim, activation=None):
    """Build one fully connected layer and return its output tensor.

    Args:
        inputs: Tensor fed into the layer; it is matrix-multiplied with a
            weight matrix of shape [input_dim, output_dim].
        input_dim: Number of input units (rows of the weight matrix).
        output_dim: Number of output units (columns of the weight matrix and
            length of the bias vector).
        activation: Optional callable applied to the affine result. When it is
            None the affine result is returned unchanged.

    Returns:
        activation(inputs @ W + b), or inputs @ W + b when no activation is
        given.
    """
    # Weights start from a truncated normal distribution (stddev 0.1);
    # the bias starts at zero.
    weights = tf.Variable(tf.truncated_normal([input_dim, output_dim], stddev=0.1))
    bias = tf.Variable(tf.zeros([output_dim]))

    # Affine transform of the inputs.
    pre_activation = tf.matmul(inputs, weights) + bias

    if activation is not None:
        return activation(pre_activation)
    return pre_activation

# Assemble the two-hidden-layer model
# Input placeholders: 784-dimensional image vectors and 10-dimensional labels
x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='X')
y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='Y')
# Hidden layer 1 (ReLU)
h1 = fcn_layer(inputs=x, input_dim=784, output_dim=H1_NN, activation=tf.nn.relu)
# Hidden layer 2 (ReLU)
h2 = fcn_layer(inputs=h1, input_dim=H1_NN, output_dim=H2_NN, activation=tf.nn.relu)
# Output layer: `forward` holds the raw logits, `pred` their softmax
forward = fcn_layer(inputs=h2, input_dim=H2_NN, output_dim=10, activation=None)
pred = tf.nn.softmax(logits=forward)

# Training and checkpointing

# Checkpoint settings
# Write a checkpoint every `save_step` epochs
save_step = 5
# Directory that receives the checkpoint files (created when missing)
ckpt_dir = './ckpt_dir'
os.makedirs(ckpt_dir, exist_ok=True)

# Basic hyper-parameters
train_epoch = 40
report_step = 1
learning_step = 0.01
batch_size = 50
# Number of full mini-batches per epoch
total_batch = mnist.train.num_examples // batch_size

# Saver used to write the checkpoints.
# tf.train.Saver keeps only the 5 most recent checkpoints by default and
# deletes older ones; size it to the schedule (periodic saves plus the final
# one) so that none of the per-epoch checkpoints is silently removed.
# It is created before the optimizer is built, so with the default var_list
# it covers the variables that exist at this point (the layer weights and
# biases) and not the optimizer's own state.
saver = tf.train.Saver(max_to_keep=train_epoch // save_step + 1)

# Loss: mean softmax cross-entropy between the logits and the one-hot labels
loss_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))

# Optimizer
optimizer = tf.train.AdamOptimizer(learning_step).minimize(loss_func)

# Accuracy: fraction of samples whose predicted class equals the label class
correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train the model
startTime = time()

# The session is deliberately left open: later cells reuse `sess` for
# evaluation and close it explicitly.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

for epoch in range(train_epoch):
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)
        sess.run(optimizer, feed_dict={x: xs, y: ys})

    # Evaluate on the validation split only when the result is reported
    if (epoch + 1) % report_step == 0:
        loss, acc = sess.run([loss_func, accuracy],
                             feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
        print('Epoch: %02d' % (epoch+1), 'Loss: %.9f' % (loss), 'Accuracy: %.4f' % (acc))

    if (epoch + 1) % save_step == 0:
        # Periodic checkpoint.
        # NOTE(review): "265" in the file name looks like a typo for the
        # 256-unit first hidden layer; it is left unchanged so that existing
        # checkpoint paths keep working.
        ckpt_name = 'mnist_265_64_m{:06d}.ckpt'.format(epoch+1)
        saver.save(sess, os.path.join(ckpt_dir, ckpt_name))
        print('{} saver'.format(ckpt_name))

# Final checkpoint after the last epoch
saver.save(sess, os.path.join(ckpt_dir, 'mnist_265_64_m.ckpt'))
print('Model saved!')

# Report the elapsed wall-clock time (training plus checkpoint writes)
duration = time() - startTime
print('Finished. Train takes %.2f duration' % (duration))
        
        

Epoch: 01 Loss: 0.172248676 Accuracy: 0.9520
Epoch: 02 Loss: 0.151537746 Accuracy: 0.9578
Epoch: 03 Loss: 0.126806095 Accuracy: 0.9706
Epoch: 04 Loss: 0.125573605 Accuracy: 0.9678
Epoch: 05 Loss: 0.138096437 Accuracy: 0.9658
mnist_265_64_m000005.ckpt saver
Epoch: 06 Loss: 0.191718429 Accuracy: 0.9608
Epoch: 07 Loss: 0.162698314 Accuracy: 0.9712
Epoch: 08 Loss: 0.155213013 Accuracy: 0.9620
Epoch: 09 Loss: 0.154603943 Accuracy: 0.9732
Epoch: 10 Loss: 0.123715870 Accuracy: 0.9734
mnist_265_64_m000010.ckpt saver
Epoch: 11 Loss: 0.165490240 Accuracy: 0.9712
Epoch: 12 Loss: 0.177991107 Accuracy: 0.9722
Epoch: 13 Loss: 0.149508908 Accuracy: 0.9748
Epoch: 14 Loss: 0.181552991 Accuracy: 0.9706
Epoch: 15 Loss: 0.177471086 Accuracy: 0.9688
mnist_265_64_m000015.ckpt saver
Epoch: 16 Loss: 0.197503045 Accuracy: 0.9686
Epoch: 17 Loss: 0.175129667 Accuracy: 0.9738
Epoch: 18 Loss: 0.171653599 Accuracy: 0.9714
Epoch: 19 Loss: 0.171961501 Accuracy: 0.9756
Epoch: 20 Loss: 0.224270657 Accuracy: 0.9722
mnis

In [27]:
# Evaluate the accuracy tensor on the MNIST test split with the open session
test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})

In [28]:
# Show the accuracy measured on the test split
print(test_acc)

0.9676


In [31]:
# Close the session to release the resources it holds
sess.close()