In [1]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(".", one_hot=True, reshape=False)

import tensorflow as tf

Extracting ./train-images-idx3-ubyte.gz
Extracting ./train-labels-idx1-ubyte.gz
Extracting ./t10k-images-idx3-ubyte.gz
Extracting ./t10k-labels-idx1-ubyte.gz


In [2]:
# Parameters
# 参数
learning_rate = 0.00001
epochs = 10
batch_size = 128

# Number of samples to calculate validation and accuracy
# Decrease this if you're running out of memory to calculate accuracy
# 用来验证和计算准确率的样本数
# 如果内存不够，可以调小这个数字
test_valid_size = 256

# Network Parameters
# 神经网络参数
n_classes = 10  # MNIST total classes (0-9 digits)
dropout = 0.75  # Dropout, probability to keep units

### Weights and Biases

In [3]:
# Store layers weight & bias
weights = {
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),
    'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64])),
    'wd1': tf.Variable(tf.random_normal([7*7*64, 1024])),
    'out': tf.Variable(tf.random_normal([1024, n_classes]))}

biases = {
    'bc1': tf.Variable(tf.random_normal([32])),
    'bc2': tf.Variable(tf.random_normal([64])),
    'bd1': tf.Variable(tf.random_normal([1024])),
    'out': tf.Variable(tf.random_normal([n_classes]))}

In [4]:
def conv2d(x, W, b, strides=1):
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)

In [5]:
def maxpool2d(x, k=2):
    return tf.nn.max_pool(
        x,
        ksize=[1, k, k, 1],
        strides=[1, k, k, 1],
        padding='SAME')

In [10]:
def conv_net(x, weights, biases, dropout):
    # Layer 1 - 28*28*1 to 14*14*32
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    conv1 = maxpool2d(conv1, k=2)

    # Layer 2 - 14*14*32 to 7*7*64
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2 = maxpool2d(conv2, k=2)

    # Fully connected layer - 7*7*64 to 1024
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    fc1 = tf.nn.dropout(fc1, dropout)

    # Output Layer - class prediction - 1024 to 10
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out

### Session

In [11]:
# tf Graph input
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)

# Model
logits = conv_net(x, weights, biases, keep_prob)

# Define loss and optimizer
cost = tf.reduce_mean(\
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\
    .minimize(cost)

# Accuracy
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf. global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    for epoch in range(epochs):
        for batch in range(mnist.train.num_examples//batch_size):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={
                x: batch_x,
                y: batch_y,
                keep_prob: dropout})

            # Calculate batch loss and accuracy
            loss = sess.run(cost, feed_dict={
                x: batch_x,
                y: batch_y,
                keep_prob: 1.})
            valid_acc = sess.run(accuracy, feed_dict={
                x: mnist.validation.images[:test_valid_size],
                y: mnist.validation.labels[:test_valid_size],
                keep_prob: 1.})

            print('Epoch {:>2}, Batch {:>3} -'
                  'Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(
                epoch + 1,
                batch + 1,
                loss,
                valid_acc))

    # Calculate Test Accuracy
    test_acc = sess.run(accuracy, feed_dict={
        x: mnist.test.images[:test_valid_size],
        y: mnist.test.labels[:test_valid_size],
        keep_prob: 1.})
    print('Testing Accuracy: {}'.format(test_acc))

Epoch  1, Batch   1 -Loss: 59823.7500 Validation Accuracy: 0.117188
Epoch  1, Batch   2 -Loss: 49509.4766 Validation Accuracy: 0.105469
Epoch  1, Batch   3 -Loss: 37248.6094 Validation Accuracy: 0.117188
Epoch  1, Batch   4 -Loss: 32630.1074 Validation Accuracy: 0.125000
Epoch  1, Batch   5 -Loss: 28355.1016 Validation Accuracy: 0.128906
Epoch  1, Batch   6 -Loss: 25600.2715 Validation Accuracy: 0.144531
Epoch  1, Batch   7 -Loss: 23424.0078 Validation Accuracy: 0.175781
Epoch  1, Batch   8 -Loss: 20961.8164 Validation Accuracy: 0.218750
Epoch  1, Batch   9 -Loss: 16617.2207 Validation Accuracy: 0.242188
Epoch  1, Batch  10 -Loss: 20034.5957 Validation Accuracy: 0.261719
Epoch  1, Batch  11 -Loss: 16036.7197 Validation Accuracy: 0.277344
Epoch  1, Batch  12 -Loss: 15199.7344 Validation Accuracy: 0.285156
Epoch  1, Batch  13 -Loss: 13412.7461 Validation Accuracy: 0.304688
Epoch  1, Batch  14 -Loss: 16831.9531 Validation Accuracy: 0.308594
Epoch  1, Batch  15 -Loss: 16249.4922 Validation

Epoch  1, Batch 122 -Loss:  4113.9341 Validation Accuracy: 0.515625
Epoch  1, Batch 123 -Loss:  4263.0596 Validation Accuracy: 0.503906
Epoch  1, Batch 124 -Loss:  4052.3042 Validation Accuracy: 0.515625
Epoch  1, Batch 125 -Loss:  5285.2666 Validation Accuracy: 0.519531
Epoch  1, Batch 126 -Loss:  3559.9976 Validation Accuracy: 0.519531
Epoch  1, Batch 127 -Loss:  2470.4985 Validation Accuracy: 0.515625
Epoch  1, Batch 128 -Loss:  5168.1709 Validation Accuracy: 0.519531
Epoch  1, Batch 129 -Loss:  5030.4014 Validation Accuracy: 0.523438
Epoch  1, Batch 130 -Loss:  3870.2556 Validation Accuracy: 0.523438
Epoch  1, Batch 131 -Loss:  3034.5449 Validation Accuracy: 0.519531
Epoch  1, Batch 132 -Loss:  2331.9197 Validation Accuracy: 0.527344
Epoch  1, Batch 133 -Loss:  3162.2712 Validation Accuracy: 0.527344
Epoch  1, Batch 134 -Loss:  4153.3311 Validation Accuracy: 0.523438
Epoch  1, Batch 135 -Loss:  3935.5654 Validation Accuracy: 0.519531
Epoch  1, Batch 136 -Loss:  3019.2041 Validation

Epoch  1, Batch 243 -Loss:  2767.2944 Validation Accuracy: 0.539062
Epoch  1, Batch 244 -Loss:  2067.7231 Validation Accuracy: 0.542969
Epoch  1, Batch 245 -Loss:  1392.7354 Validation Accuracy: 0.531250
Epoch  1, Batch 246 -Loss:  2675.9482 Validation Accuracy: 0.527344
Epoch  1, Batch 247 -Loss:  2373.1816 Validation Accuracy: 0.539062
Epoch  1, Batch 248 -Loss:  2348.9639 Validation Accuracy: 0.539062
Epoch  1, Batch 249 -Loss:  2242.2256 Validation Accuracy: 0.539062
Epoch  1, Batch 250 -Loss:  1581.3671 Validation Accuracy: 0.542969
Epoch  1, Batch 251 -Loss:  2335.9155 Validation Accuracy: 0.542969
Epoch  1, Batch 252 -Loss:  1665.4009 Validation Accuracy: 0.546875
Epoch  1, Batch 253 -Loss:  2156.9771 Validation Accuracy: 0.539062
Epoch  1, Batch 254 -Loss:  2202.1423 Validation Accuracy: 0.535156
Epoch  1, Batch 255 -Loss:  2376.2031 Validation Accuracy: 0.539062
Epoch  1, Batch 256 -Loss:  1789.4523 Validation Accuracy: 0.527344
Epoch  1, Batch 257 -Loss:  2397.8315 Validation

Epoch  1, Batch 364 -Loss:  1314.9768 Validation Accuracy: 0.507812
Epoch  1, Batch 365 -Loss:  1506.2972 Validation Accuracy: 0.503906
Epoch  1, Batch 366 -Loss:  1417.6766 Validation Accuracy: 0.496094
Epoch  1, Batch 367 -Loss:  1373.5408 Validation Accuracy: 0.492188
Epoch  1, Batch 368 -Loss:  1600.2961 Validation Accuracy: 0.488281
Epoch  1, Batch 369 -Loss:  2169.4934 Validation Accuracy: 0.492188
Epoch  1, Batch 370 -Loss:  1351.3622 Validation Accuracy: 0.492188
Epoch  1, Batch 371 -Loss:  1412.5446 Validation Accuracy: 0.496094
Epoch  1, Batch 372 -Loss:   927.8082 Validation Accuracy: 0.503906
Epoch  1, Batch 373 -Loss:  1739.9696 Validation Accuracy: 0.496094
Epoch  1, Batch 374 -Loss:  1524.0974 Validation Accuracy: 0.507812
Epoch  1, Batch 375 -Loss:  1700.0250 Validation Accuracy: 0.519531
Epoch  1, Batch 376 -Loss:  1875.7798 Validation Accuracy: 0.519531
Epoch  1, Batch 377 -Loss:  1404.0858 Validation Accuracy: 0.515625
Epoch  1, Batch 378 -Loss:  1187.4604 Validation

Epoch  2, Batch  56 -Loss:  1238.2882 Validation Accuracy: 0.511719
Epoch  2, Batch  57 -Loss:  1300.4543 Validation Accuracy: 0.507812
Epoch  2, Batch  58 -Loss:  1287.6328 Validation Accuracy: 0.515625
Epoch  2, Batch  59 -Loss:   943.1017 Validation Accuracy: 0.511719
Epoch  2, Batch  60 -Loss:  1328.4392 Validation Accuracy: 0.511719
Epoch  2, Batch  61 -Loss:  1322.9138 Validation Accuracy: 0.515625
Epoch  2, Batch  62 -Loss:  1225.1838 Validation Accuracy: 0.511719
Epoch  2, Batch  63 -Loss:  1493.8657 Validation Accuracy: 0.511719
Epoch  2, Batch  64 -Loss:  1166.8772 Validation Accuracy: 0.503906
Epoch  2, Batch  65 -Loss:  1415.4694 Validation Accuracy: 0.500000
Epoch  2, Batch  66 -Loss:  1616.0618 Validation Accuracy: 0.503906
Epoch  2, Batch  67 -Loss:  1673.8201 Validation Accuracy: 0.496094
Epoch  2, Batch  68 -Loss:  1015.1953 Validation Accuracy: 0.496094
Epoch  2, Batch  69 -Loss:  1206.8005 Validation Accuracy: 0.496094
Epoch  2, Batch  70 -Loss:  1247.6226 Validation

KeyboardInterrupt: 