In [1]:
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import os
import tensorflow as tf

# to make this notebook's output stable across runs
def reset_graph(seed=42): 
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

reset_graph()

In [2]:
# load data: digits 5 to 9, but still label with 0 to 4, 
# because TensorFlow expects label's integers from 0 to n_classes-1.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

X_train2_full = mnist.train.images[mnist.train.labels >= 5]
y_train2_full = mnist.train.labels[mnist.train.labels >= 5] - 5
X_valid2_full = mnist.validation.images[mnist.validation.labels >= 5]
y_valid2_full = mnist.validation.labels[mnist.validation.labels >= 5] - 5
X_test2 = mnist.test.images[mnist.test.labels >= 5]
y_test2 = mnist.test.labels[mnist.test.labels >= 5] - 5

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [3]:
# we want to keep only 100 instances per class in the training set 
# and let's keep only 30 instances per class in the validation set
# tesing set is already loaded above
def sample_n_instances_per_class(X, y, n=100):
    Xs, ys = [], []
    for label in np.unique(y):
        idx = (y == label)
        Xc = X[idx][:n]
        yc = y[idx][:n]
        Xs.append(Xc)
        ys.append(yc)
    return np.concatenate(Xs), np.concatenate(ys)

X_train2, y_train2 = sample_n_instances_per_class(X_train2_full, y_train2_full, n=100)
X_valid2, y_valid2 = sample_n_instances_per_class(X_valid2_full, y_valid2_full, n=30)

In [5]:
learning_rate = 0.001
batch_size = 256

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('/Team48_HW2.ckpt.meta', clear_devices=True)
    new_saver.restore(sess, '/Team48_HW2.ckpt')
    xs = tf.get_default_graph().get_tensor_by_name("xs:0")
    ys = tf.get_default_graph().get_tensor_by_name("ys:0")
    y = tf.get_default_graph().get_tensor_by_name("y:0")
    #logits = y.op.inputs[0]
    accuracy = tf.get_default_graph().get_tensor_by_name("accuracy:0")
    output_layer_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="logits")
    loss = tf.get_default_graph().get_tensor_by_name("loss:0")
    optimizer = tf.train.AdamOptimizer(learning_rate, name="optimizer")
    training_op = optimizer.minimize(loss, var_list=output_layer_vars)
    
    tf.global_variables_initializer().run() 
    
    for i in range(1000):
        sess.run(training_op, feed_dict={xs: X_train2, ys: y_train2})
        val_loss = sess.run(loss, feed_dict={xs: X_valid2, ys: y_valid2})
        accu = sess.run(accuracy, feed_dict={xs: X_test2, ys: y_test2})
        print(i,"Validation loss:", val_loss, "Accuracy:", accu)

INFO:tensorflow:Restoring parameters from /Team48_HW2.ckpt
0 Validation loss: 236.956 Accuracy: 0.277515
1 Validation loss: 235.41 Accuracy: 0.302407
2 Validation loss: 233.882 Accuracy: 0.326887
3 Validation loss: 232.39 Accuracy: 0.346637
4 Validation loss: 230.945 Accuracy: 0.36906
5 Validation loss: 229.55 Accuracy: 0.387575
6 Validation loss: 228.209 Accuracy: 0.402798
7 Validation loss: 226.922 Accuracy: 0.417198
8 Validation loss: 225.692 Accuracy: 0.429953
9 Validation loss: 224.516 Accuracy: 0.43777
10 Validation loss: 223.393 Accuracy: 0.446616
11 Validation loss: 222.318 Accuracy: 0.452993
12 Validation loss: 221.286 Accuracy: 0.462045
13 Validation loss: 220.29 Accuracy: 0.471302
14 Validation loss: 219.323 Accuracy: 0.480354
15 Validation loss: 218.378 Accuracy: 0.491874
16 Validation loss: 217.449 Accuracy: 0.504012
17 Validation loss: 216.531 Accuracy: 0.50936
18 Validation loss: 215.62 Accuracy: 0.519852
19 Validation loss: 214.714 Accuracy: 0.529315
20 Validation loss:

178 Validation loss: 173.192 Accuracy: 0.785641
179 Validation loss: 173.138 Accuracy: 0.786669
180 Validation loss: 173.085 Accuracy: 0.787287
181 Validation loss: 173.032 Accuracy: 0.787081
182 Validation loss: 172.98 Accuracy: 0.787287
183 Validation loss: 172.928 Accuracy: 0.787287
184 Validation loss: 172.877 Accuracy: 0.787287
185 Validation loss: 172.826 Accuracy: 0.787904
186 Validation loss: 172.776 Accuracy: 0.787904
187 Validation loss: 172.726 Accuracy: 0.787904
188 Validation loss: 172.677 Accuracy: 0.788315
189 Validation loss: 172.628 Accuracy: 0.788727
190 Validation loss: 172.58 Accuracy: 0.788727
191 Validation loss: 172.532 Accuracy: 0.788521
192 Validation loss: 172.485 Accuracy: 0.788727
193 Validation loss: 172.438 Accuracy: 0.788727
194 Validation loss: 172.391 Accuracy: 0.789549
195 Validation loss: 172.345 Accuracy: 0.790166
196 Validation loss: 172.3 Accuracy: 0.790578
197 Validation loss: 172.255 Accuracy: 0.790784
198 Validation loss: 172.21 Accuracy: 0.7907

350 Validation loss: 168.265 Accuracy: 0.810121
351 Validation loss: 168.249 Accuracy: 0.810121
352 Validation loss: 168.232 Accuracy: 0.810327
353 Validation loss: 168.216 Accuracy: 0.810121
354 Validation loss: 168.2 Accuracy: 0.809916
355 Validation loss: 168.184 Accuracy: 0.80971
356 Validation loss: 168.168 Accuracy: 0.810121
357 Validation loss: 168.152 Accuracy: 0.810121
358 Validation loss: 168.136 Accuracy: 0.810739
359 Validation loss: 168.12 Accuracy: 0.81115
360 Validation loss: 168.103 Accuracy: 0.81115
361 Validation loss: 168.087 Accuracy: 0.811561
362 Validation loss: 168.072 Accuracy: 0.811973
363 Validation loss: 168.056 Accuracy: 0.811973
364 Validation loss: 168.04 Accuracy: 0.812178
365 Validation loss: 168.024 Accuracy: 0.811973
366 Validation loss: 168.008 Accuracy: 0.811973
367 Validation loss: 167.992 Accuracy: 0.811973
368 Validation loss: 167.976 Accuracy: 0.811973
369 Validation loss: 167.96 Accuracy: 0.812178
370 Validation loss: 167.945 Accuracy: 0.812384


524 Validation loss: 165.916 Accuracy: 0.822053
525 Validation loss: 165.905 Accuracy: 0.822053
526 Validation loss: 165.895 Accuracy: 0.822053
527 Validation loss: 165.884 Accuracy: 0.822053
528 Validation loss: 165.874 Accuracy: 0.822053
529 Validation loss: 165.863 Accuracy: 0.822053
530 Validation loss: 165.853 Accuracy: 0.822053
531 Validation loss: 165.843 Accuracy: 0.822464
532 Validation loss: 165.832 Accuracy: 0.822259
533 Validation loss: 165.822 Accuracy: 0.822259
534 Validation loss: 165.812 Accuracy: 0.822259
535 Validation loss: 165.801 Accuracy: 0.822464
536 Validation loss: 165.791 Accuracy: 0.82267
537 Validation loss: 165.781 Accuracy: 0.822464
538 Validation loss: 165.771 Accuracy: 0.822464
539 Validation loss: 165.761 Accuracy: 0.822876
540 Validation loss: 165.75 Accuracy: 0.82267
541 Validation loss: 165.74 Accuracy: 0.82267
542 Validation loss: 165.73 Accuracy: 0.822876
543 Validation loss: 165.72 Accuracy: 0.82267
544 Validation loss: 165.71 Accuracy: 0.822876
5

702 Validation loss: 164.8 Accuracy: 0.825962
703 Validation loss: 164.802 Accuracy: 0.826373
704 Validation loss: 164.803 Accuracy: 0.826373
705 Validation loss: 164.804 Accuracy: 0.826167
706 Validation loss: 164.805 Accuracy: 0.826373
707 Validation loss: 164.806 Accuracy: 0.826579
708 Validation loss: 164.807 Accuracy: 0.826579
709 Validation loss: 164.808 Accuracy: 0.826784
710 Validation loss: 164.808 Accuracy: 0.826579
711 Validation loss: 164.809 Accuracy: 0.826167
712 Validation loss: 164.809 Accuracy: 0.826373
713 Validation loss: 164.809 Accuracy: 0.826167
714 Validation loss: 164.809 Accuracy: 0.826167
715 Validation loss: 164.809 Accuracy: 0.826373
716 Validation loss: 164.809 Accuracy: 0.826579
717 Validation loss: 164.809 Accuracy: 0.826373
718 Validation loss: 164.809 Accuracy: 0.826373
719 Validation loss: 164.808 Accuracy: 0.826373
720 Validation loss: 164.808 Accuracy: 0.826373
721 Validation loss: 164.807 Accuracy: 0.826579
722 Validation loss: 164.807 Accuracy: 0.8

879 Validation loss: 164.464 Accuracy: 0.82987
880 Validation loss: 164.46 Accuracy: 0.82987
881 Validation loss: 164.457 Accuracy: 0.82987
882 Validation loss: 164.454 Accuracy: 0.829665
883 Validation loss: 164.45 Accuracy: 0.829665
884 Validation loss: 164.447 Accuracy: 0.829665
885 Validation loss: 164.443 Accuracy: 0.829665
886 Validation loss: 164.44 Accuracy: 0.829665
887 Validation loss: 164.436 Accuracy: 0.829459
888 Validation loss: 164.433 Accuracy: 0.829459
889 Validation loss: 164.429 Accuracy: 0.829459
890 Validation loss: 164.426 Accuracy: 0.829459
891 Validation loss: 164.422 Accuracy: 0.829459
892 Validation loss: 164.419 Accuracy: 0.829253
893 Validation loss: 164.415 Accuracy: 0.829459
894 Validation loss: 164.412 Accuracy: 0.829459
895 Validation loss: 164.408 Accuracy: 0.829459
896 Validation loss: 164.405 Accuracy: 0.829459
897 Validation loss: 164.401 Accuracy: 0.829459
898 Validation loss: 164.397 Accuracy: 0.829665
899 Validation loss: 164.394 Accuracy: 0.82966