In [10]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('.', one_hot=True, reshape=False)

import tensorflow as tf
import numpy as np

Extracting ./train-images-idx3-ubyte.gz
Extracting ./train-labels-idx1-ubyte.gz
Extracting ./t10k-images-idx3-ubyte.gz
Extracting ./t10k-labels-idx1-ubyte.gz


In [11]:
learning_rate = 0.00001
epochs = 2
batch_size = 32

test_valid_size = 256
n_classes = 10
dropout = 0.75

weights = {'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),
          'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64])),
          'wd1': tf.Variable(tf.random_normal([7*7*64, 1024])), 
          'out': tf.Variable(tf.random_normal([1024, n_classes]))}

biases = {'bc1': tf.Variable(tf.random_normal([32])),
         'bc2': tf.Variable(tf.random_normal([64])),
         'bd1': tf.Variable(tf.random_normal([1024])),
         'out': tf.Variable(tf.random_normal([n_classes]))}

In [12]:
with tf.Session() as sess:
    print(weights['wd1'].get_shape())

(3136, 1024)


In [13]:
def conv2d(x, W, b, strides=1):
    x = tf.nn.conv2d(x, W, [1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)

In [14]:
def maxpool2d(x, k=2):
    return tf.nn.max_pool(x, [1, k, k, 1], [1, k, k, 1], padding='SAME')    

In [15]:
def conv_net(x, weights, biases, dropout):
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    conv1 = maxpool2d(conv1, k=2)
    
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2 = maxpool2d(conv2, k=2)
    
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    fc1 = tf.nn.dropout(fc1, dropout)
    
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out

In [16]:
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)

logits = conv_net(x, weights, biases, dropout)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [17]:
with tf.Session() as sess:
    print(accuracy.get_shape())

()


In [None]:
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    
    for epoch in range(epochs):
        total_batch = int(np.ceil(mnist.train.num_examples/batch_size))
        for batch in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={x:batch_x, y:batch_y, keep_prob:dropout})
            
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
            valid_acc = sess.run(accuracy, feed_dict={
                x: mnist.validation.images[:test_valid_size],
                y: mnist.validation.labels[:test_valid_size],
                keep_prob: 1.})
            
            print('Epoch {:>2}, Batch {:>3} - Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(
                epoch + 1,
                batch + 1,
                loss,
                valid_acc))
            
    test_acc = sess.run(accuracy, feed_dict={
        x: mnist.test.images[:test_valid_size],
        y: mnist.test.labels[:test_valid_size],
        keep_prob: 1.})
    print('Testing Accuracy: {}'.format(test_acc))
            

Epoch  1, Batch   1 - Loss: 70357.4375 Validation Accuracy: 0.113281
Epoch  1, Batch   2 - Loss: 72749.7344 Validation Accuracy: 0.105469
Epoch  1, Batch   3 - Loss: 53049.8594 Validation Accuracy: 0.089844
Epoch  1, Batch   4 - Loss: 54644.0820 Validation Accuracy: 0.117188
Epoch  1, Batch   5 - Loss: 43548.2656 Validation Accuracy: 0.128906
Epoch  1, Batch   6 - Loss: 42010.3633 Validation Accuracy: 0.105469
Epoch  1, Batch   7 - Loss: 53173.7734 Validation Accuracy: 0.125000
Epoch  1, Batch   8 - Loss: 44550.2812 Validation Accuracy: 0.117188
Epoch  1, Batch   9 - Loss: 41672.6172 Validation Accuracy: 0.128906
Epoch  1, Batch  10 - Loss: 34617.9102 Validation Accuracy: 0.121094
Epoch  1, Batch  11 - Loss: 48254.4531 Validation Accuracy: 0.093750
Epoch  1, Batch  12 - Loss: 43771.9375 Validation Accuracy: 0.101562
Epoch  1, Batch  13 - Loss: 35973.0625 Validation Accuracy: 0.144531
Epoch  1, Batch  14 - Loss: 40766.9766 Validation Accuracy: 0.171875
Epoch  1, Batch  15 - Loss: 33327.

Epoch  1, Batch 120 - Loss:  5995.1821 Validation Accuracy: 0.460938
Epoch  1, Batch 121 - Loss: 11099.9707 Validation Accuracy: 0.406250
Epoch  1, Batch 122 - Loss: 11144.8320 Validation Accuracy: 0.425781
Epoch  1, Batch 123 - Loss: 10902.4521 Validation Accuracy: 0.425781
Epoch  1, Batch 124 - Loss: 10137.2852 Validation Accuracy: 0.507812
Epoch  1, Batch 125 - Loss:  7950.5723 Validation Accuracy: 0.417969
Epoch  1, Batch 126 - Loss:  9274.0332 Validation Accuracy: 0.386719
Epoch  1, Batch 127 - Loss:  6745.8818 Validation Accuracy: 0.445312
Epoch  1, Batch 128 - Loss:  9712.1172 Validation Accuracy: 0.445312
Epoch  1, Batch 129 - Loss:  6007.8218 Validation Accuracy: 0.457031
Epoch  1, Batch 130 - Loss: 10451.2617 Validation Accuracy: 0.464844
Epoch  1, Batch 131 - Loss:  8678.7617 Validation Accuracy: 0.390625
Epoch  1, Batch 132 - Loss: 15610.7959 Validation Accuracy: 0.386719
Epoch  1, Batch 133 - Loss:  4608.5996 Validation Accuracy: 0.390625
Epoch  1, Batch 134 - Loss:  7618.

Epoch  1, Batch 239 - Loss:  5482.6655 Validation Accuracy: 0.566406
Epoch  1, Batch 240 - Loss:  7892.3037 Validation Accuracy: 0.582031
Epoch  1, Batch 241 - Loss:  5769.8535 Validation Accuracy: 0.562500
Epoch  1, Batch 242 - Loss:  6840.5928 Validation Accuracy: 0.496094
Epoch  1, Batch 243 - Loss:  5899.7783 Validation Accuracy: 0.519531
Epoch  1, Batch 244 - Loss:  5156.7886 Validation Accuracy: 0.554688
Epoch  1, Batch 245 - Loss:  5454.1377 Validation Accuracy: 0.531250
Epoch  1, Batch 246 - Loss:  4238.6704 Validation Accuracy: 0.527344
Epoch  1, Batch 247 - Loss:  3948.4180 Validation Accuracy: 0.480469
Epoch  1, Batch 248 - Loss:  6073.8320 Validation Accuracy: 0.535156
Epoch  1, Batch 249 - Loss:  9563.6025 Validation Accuracy: 0.500000
Epoch  1, Batch 250 - Loss:  6355.2544 Validation Accuracy: 0.496094
Epoch  1, Batch 251 - Loss:  8065.4102 Validation Accuracy: 0.511719
Epoch  1, Batch 252 - Loss:  4762.4639 Validation Accuracy: 0.519531
Epoch  1, Batch 253 - Loss:  4507.

Epoch  1, Batch 358 - Loss:  3573.6387 Validation Accuracy: 0.507812
Epoch  1, Batch 359 - Loss:  3199.9224 Validation Accuracy: 0.554688
Epoch  1, Batch 360 - Loss:  3291.7583 Validation Accuracy: 0.546875
Epoch  1, Batch 361 - Loss:  3113.2068 Validation Accuracy: 0.605469
Epoch  1, Batch 362 - Loss:  5675.2119 Validation Accuracy: 0.570312
Epoch  1, Batch 363 - Loss:  1305.6846 Validation Accuracy: 0.550781
Epoch  1, Batch 364 - Loss:  3044.2336 Validation Accuracy: 0.570312
Epoch  1, Batch 365 - Loss:  4606.3218 Validation Accuracy: 0.570312
Epoch  1, Batch 366 - Loss:  3449.5198 Validation Accuracy: 0.570312
Epoch  1, Batch 367 - Loss:  1377.3899 Validation Accuracy: 0.585938
Epoch  1, Batch 368 - Loss:  6606.9116 Validation Accuracy: 0.566406
Epoch  1, Batch 369 - Loss:  6837.0366 Validation Accuracy: 0.523438
Epoch  1, Batch 370 - Loss:  4643.5464 Validation Accuracy: 0.558594
Epoch  1, Batch 371 - Loss:  4435.7441 Validation Accuracy: 0.554688
Epoch  1, Batch 372 - Loss:  4386.

Epoch  1, Batch 477 - Loss:  2968.1765 Validation Accuracy: 0.582031
Epoch  1, Batch 478 - Loss:  4094.3032 Validation Accuracy: 0.613281
Epoch  1, Batch 479 - Loss:  1810.0642 Validation Accuracy: 0.585938
Epoch  1, Batch 480 - Loss:  3768.9602 Validation Accuracy: 0.574219
Epoch  1, Batch 481 - Loss:  2793.3335 Validation Accuracy: 0.582031
Epoch  1, Batch 482 - Loss:  3480.7996 Validation Accuracy: 0.597656
Epoch  1, Batch 483 - Loss:  3066.3386 Validation Accuracy: 0.562500
Epoch  1, Batch 484 - Loss:  2544.8979 Validation Accuracy: 0.578125
Epoch  1, Batch 485 - Loss:  2806.2095 Validation Accuracy: 0.570312
Epoch  1, Batch 486 - Loss:  2737.8853 Validation Accuracy: 0.574219
Epoch  1, Batch 487 - Loss:  3843.7666 Validation Accuracy: 0.601562
Epoch  1, Batch 488 - Loss:  3205.2993 Validation Accuracy: 0.550781
Epoch  1, Batch 489 - Loss:  3278.8394 Validation Accuracy: 0.632812
Epoch  1, Batch 490 - Loss:  4297.0078 Validation Accuracy: 0.597656
Epoch  1, Batch 491 - Loss:  2439.

Epoch  1, Batch 596 - Loss:  3908.0186 Validation Accuracy: 0.589844
Epoch  1, Batch 597 - Loss:  2758.7820 Validation Accuracy: 0.613281
Epoch  1, Batch 598 - Loss:  1701.2026 Validation Accuracy: 0.617188
Epoch  1, Batch 599 - Loss:  1229.5245 Validation Accuracy: 0.640625
Epoch  1, Batch 600 - Loss:  2945.5386 Validation Accuracy: 0.628906
Epoch  1, Batch 601 - Loss:  2073.2747 Validation Accuracy: 0.589844
Epoch  1, Batch 602 - Loss:  3454.5566 Validation Accuracy: 0.578125
Epoch  1, Batch 603 - Loss:  2174.2646 Validation Accuracy: 0.605469
Epoch  1, Batch 604 - Loss:  2100.9053 Validation Accuracy: 0.593750
Epoch  1, Batch 605 - Loss:  2278.1655 Validation Accuracy: 0.613281
Epoch  1, Batch 606 - Loss:  2821.1575 Validation Accuracy: 0.640625
Epoch  1, Batch 607 - Loss:  2262.7505 Validation Accuracy: 0.656250
Epoch  1, Batch 608 - Loss:  2963.7778 Validation Accuracy: 0.617188
Epoch  1, Batch 609 - Loss:  4648.1582 Validation Accuracy: 0.640625
Epoch  1, Batch 610 - Loss:  2758.

Epoch  1, Batch 715 - Loss:  3989.6604 Validation Accuracy: 0.664062
Epoch  1, Batch 716 - Loss:  1967.9590 Validation Accuracy: 0.605469
Epoch  1, Batch 717 - Loss:  1620.0922 Validation Accuracy: 0.589844
Epoch  1, Batch 718 - Loss:  2455.7568 Validation Accuracy: 0.601562
Epoch  1, Batch 719 - Loss:  2185.5122 Validation Accuracy: 0.625000
Epoch  1, Batch 720 - Loss:  1761.3917 Validation Accuracy: 0.667969
Epoch  1, Batch 721 - Loss:  2224.8284 Validation Accuracy: 0.601562
Epoch  1, Batch 722 - Loss:  1323.5779 Validation Accuracy: 0.648438
Epoch  1, Batch 723 - Loss:  1452.6158 Validation Accuracy: 0.636719
Epoch  1, Batch 724 - Loss:  2671.6440 Validation Accuracy: 0.628906
Epoch  1, Batch 725 - Loss:  3090.8628 Validation Accuracy: 0.656250
Epoch  1, Batch 726 - Loss:  1685.1052 Validation Accuracy: 0.609375
Epoch  1, Batch 727 - Loss:  2055.5767 Validation Accuracy: 0.636719
Epoch  1, Batch 728 - Loss:  3056.0537 Validation Accuracy: 0.621094
Epoch  1, Batch 729 - Loss:  1739.

Epoch  1, Batch 834 - Loss:  3194.8906 Validation Accuracy: 0.613281
Epoch  1, Batch 835 - Loss:  3780.2231 Validation Accuracy: 0.625000
Epoch  1, Batch 836 - Loss:  2326.4448 Validation Accuracy: 0.625000
Epoch  1, Batch 837 - Loss:  2256.6597 Validation Accuracy: 0.660156
Epoch  1, Batch 838 - Loss:  2142.4614 Validation Accuracy: 0.656250
Epoch  1, Batch 839 - Loss:  1909.9790 Validation Accuracy: 0.656250
Epoch  1, Batch 840 - Loss:  1401.8104 Validation Accuracy: 0.660156
Epoch  1, Batch 841 - Loss:  1400.9575 Validation Accuracy: 0.628906
Epoch  1, Batch 842 - Loss:  1669.6250 Validation Accuracy: 0.652344
Epoch  1, Batch 843 - Loss:  2062.2712 Validation Accuracy: 0.671875
Epoch  1, Batch 844 - Loss:  2949.8047 Validation Accuracy: 0.644531
Epoch  1, Batch 845 - Loss:  3071.9956 Validation Accuracy: 0.660156
Epoch  1, Batch 846 - Loss:  2785.2830 Validation Accuracy: 0.656250
Epoch  1, Batch 847 - Loss:  2186.5039 Validation Accuracy: 0.671875
Epoch  1, Batch 848 - Loss:  2370.

Epoch  1, Batch 953 - Loss:  1273.0081 Validation Accuracy: 0.636719
Epoch  1, Batch 954 - Loss:   734.6198 Validation Accuracy: 0.691406
Epoch  1, Batch 955 - Loss:  2334.9312 Validation Accuracy: 0.679688
Epoch  1, Batch 956 - Loss:  1555.0774 Validation Accuracy: 0.656250
Epoch  1, Batch 957 - Loss:  3167.9946 Validation Accuracy: 0.695312
Epoch  1, Batch 958 - Loss:  2216.8604 Validation Accuracy: 0.652344
Epoch  1, Batch 959 - Loss:  1657.8423 Validation Accuracy: 0.679688
Epoch  1, Batch 960 - Loss:  1486.4958 Validation Accuracy: 0.664062
Epoch  1, Batch 961 - Loss:  2236.4570 Validation Accuracy: 0.691406
Epoch  1, Batch 962 - Loss:  1329.6149 Validation Accuracy: 0.675781
Epoch  1, Batch 963 - Loss:  1314.5593 Validation Accuracy: 0.636719
Epoch  1, Batch 964 - Loss:  1452.6979 Validation Accuracy: 0.667969
Epoch  1, Batch 965 - Loss:   671.2794 Validation Accuracy: 0.652344
Epoch  1, Batch 966 - Loss:  1332.7349 Validation Accuracy: 0.691406
Epoch  1, Batch 967 - Loss:  2857.

Epoch  1, Batch 1071 - Loss:  1370.0466 Validation Accuracy: 0.687500
Epoch  1, Batch 1072 - Loss:  1985.7285 Validation Accuracy: 0.656250
Epoch  1, Batch 1073 - Loss:  2259.6970 Validation Accuracy: 0.687500
Epoch  1, Batch 1074 - Loss:  1449.2759 Validation Accuracy: 0.667969
Epoch  1, Batch 1075 - Loss:  1027.5853 Validation Accuracy: 0.675781
Epoch  1, Batch 1076 - Loss:  2233.5381 Validation Accuracy: 0.699219
Epoch  1, Batch 1077 - Loss:  1965.3839 Validation Accuracy: 0.687500
Epoch  1, Batch 1078 - Loss:  1685.1379 Validation Accuracy: 0.703125
Epoch  1, Batch 1079 - Loss:  1898.0171 Validation Accuracy: 0.679688
Epoch  1, Batch 1080 - Loss:  1566.4343 Validation Accuracy: 0.691406
Epoch  1, Batch 1081 - Loss:  2255.1909 Validation Accuracy: 0.683594
Epoch  1, Batch 1082 - Loss:  1542.3884 Validation Accuracy: 0.687500
Epoch  1, Batch 1083 - Loss:  2448.4697 Validation Accuracy: 0.683594
Epoch  1, Batch 1084 - Loss:  1362.4171 Validation Accuracy: 0.687500
Epoch  1, Batch 1085

Epoch  1, Batch 1189 - Loss:  1461.0414 Validation Accuracy: 0.664062
Epoch  1, Batch 1190 - Loss:  1082.5221 Validation Accuracy: 0.687500
Epoch  1, Batch 1191 - Loss:  3531.1523 Validation Accuracy: 0.679688
Epoch  1, Batch 1192 - Loss:  1498.2539 Validation Accuracy: 0.726562
Epoch  1, Batch 1193 - Loss:  1040.4634 Validation Accuracy: 0.675781
Epoch  1, Batch 1194 - Loss:  2052.5476 Validation Accuracy: 0.648438
Epoch  1, Batch 1195 - Loss:  2429.9724 Validation Accuracy: 0.691406
Epoch  1, Batch 1196 - Loss:  1147.6735 Validation Accuracy: 0.703125
Epoch  1, Batch 1197 - Loss:  1026.2040 Validation Accuracy: 0.703125
Epoch  1, Batch 1198 - Loss:   683.1639 Validation Accuracy: 0.660156
Epoch  1, Batch 1199 - Loss:   772.7838 Validation Accuracy: 0.679688
Epoch  1, Batch 1200 - Loss:  1308.6077 Validation Accuracy: 0.648438
Epoch  1, Batch 1201 - Loss:   936.4854 Validation Accuracy: 0.664062
Epoch  1, Batch 1202 - Loss:  1709.3707 Validation Accuracy: 0.683594
Epoch  1, Batch 1203