## Download the data

In [1]:
import numpy as np
import os
import tensorflow as tf

###### Do not modify here ###### 

# to make this notebook's output stable across runs
def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

reset_graph()

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

# training on MNIST but only on digits 0 to 4
X_train1 = mnist.train.images[mnist.train.labels < 5]
y_train1 = mnist.train.labels[mnist.train.labels < 5]
X_valid1 = mnist.validation.images[mnist.validation.labels < 5]
y_valid1 = mnist.validation.labels[mnist.validation.labels < 5]
X_test1 = mnist.test.images[mnist.test.labels < 5]
y_test1 = mnist.test.labels[mnist.test.labels < 5]
###### Do not modify here ###### 


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


## Get the data shape

In [2]:
print('Train', mnist.train.num_examples, 
      'Validation', mnist.validation.num_examples,
      'Test', mnist.test.num_examples)

print('Train images :', X_train1.shape,
      'Labels :', y_train1.shape)

print('Validation images :', X_valid1.shape,
      'Labels :', y_valid1.shape)

print('Test images :', X_test1.shape,
      'Labels :', y_test1.shape)


Train 55000 Validation 5000 Test 10000
Train images : (28038, 784) Labels : (28038,)
Validation images : (2558, 784) Labels : (2558,)
Test images : (5139, 784) Labels : (5139,)


## One-hot encoded y values

In [3]:
y_train_one_hot = np.zeros((len(y_train1), 5))
y_train_one_hot[np.arange(len(y_train1)), y_train1] = 1

y_valid_one_hot = np.zeros((len(y_valid1), 5))
y_valid_one_hot[np.arange(len(y_valid1)), y_valid1] = 1

y_test_one_hot = np.zeros((len(y_test1), 5))
y_test_one_hot[np.arange(len(y_test1)), y_test1] = 1

print('Check some examples of one-hot encoding:')
for i in range(10):
    print(y_train1[i], y_train_one_hot[i], '\n')

print ("Train Label: ", y_train_one_hot.shape)
print ("Validation Label: ", y_valid_one_hot.shape)
print ("Test Label: ", y_test_one_hot.shape)


Check some examples of one-hot encoding:
3 [ 0.  0.  0.  1.  0.] 

4 [ 0.  0.  0.  0.  1.] 

1 [ 0.  1.  0.  0.  0.] 

1 [ 0.  1.  0.  0.  0.] 

0 [ 1.  0.  0.  0.  0.] 

0 [ 1.  0.  0.  0.  0.] 

3 [ 0.  0.  0.  1.  0.] 

1 [ 0.  1.  0.  0.  0.] 

2 [ 0.  0.  1.  0.  0.] 

0 [ 1.  0.  0.  0.  0.] 

Train Label:  (28038, 5)
Validation Label:  (2558, 5)
Test Label:  (5139, 5)


### Construct the layer function

In [4]:
def layer(output_dim, input_dim, inputs, layer, activation=None):
#     W = tf.Variable(tf.random_normal([input_dim, output_dim]))
#     b = tf.get_variable(tf.random_normal([1, output_dim]))
#     W = tf.Variable(tf.truncated_normal(shape = (input_dim, output_dim),mean = 0, stddev = 0.1))
#     b = tf.Variable(tf.zeros(output_dim))

    W = tf.get_variable("W" + layer, shape=[input_dim, output_dim], initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable("b" + layer, shape=[1, output_dim], initializer=tf.contrib.layers.xavier_initializer())

    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs


## contruct the input layer

In [5]:
x = tf.placeholder("float", [None, 784])


## contruct the hidden layers

In [6]:
h1 = layer(output_dim=128, input_dim=784, inputs=x, layer="h1", activation=tf.nn.elu)
h2 = layer(output_dim=128, input_dim=128, inputs=h1, layer="h2", activation=tf.nn.elu)
h3 = layer(output_dim=128, input_dim=128, inputs=h2, layer="h3", activation=tf.nn.elu)
h4 = layer(output_dim=128, input_dim=128, inputs=h3, layer="h4", activation=tf.nn.elu)
h5 = layer(output_dim=128, input_dim=128, inputs=h4, layer="h5", activation=tf.nn.elu)


## contruct the output layer

In [7]:
y_predict = layer(output_dim=5, input_dim=128, inputs=h5, layer="output", activation=None)
y_label = tf.placeholder("float", [None, 5])


### start training

In [23]:
# hyper parameters
lr = 0.005
batch_size = 256
epochs = 350
saturate_limit = 20  # for applying early stopping

# parameters
iterations = int(X_train1.shape[0] / batch_size)
saturate_count = 0
best_acc = 0.
best_epoch = -1

# set the loss function; tf.nn.sparse_softmax_cross_entropy_with_logits required by the homework spec
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=y_label))

# set the optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=lr) \
                    .minimize(loss_function)

correct_prediction = tf.equal(tf.argmax(y_label, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    for e in range(epochs):
        for i in range(iterations):
            if(i==0):
                batch_x = X_train1[:batch_size]
                batch_y = y_train_one_hot[:batch_size]
            else:
                batch_x = X_train1[i * batch_size : (i + 1) * batch_size]
                batch_y = y_train_one_hot[i * batch_size : (i + 1) * batch_size]

            sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})
        
        # validate
        loss, acc = sess.run([loss_function, accuracy], feed_dict={x: X_valid1, y_label: y_valid_one_hot})
        
#         print('-' * 40)
#         print('Epoch:', e)
#         print('Loss:', loss)
#         print('Accurancy:', acc)
        
        if best_acc < acc:
#             print('best_acc occurs!')
            best_acc = acc
            best_epoch = e
            saturate_count = 0
        else:
            saturate_count += 1
#             print('increment saturate_count:', saturate_count)
#             print('current best_acc:', best_acc)
            
            if saturate_count >= saturate_limit:  # stop if saturate
#                 print('SATURATE!')
                break
                
#         print('-' * 40, '\n')
        
        
    print('*' * 60)
    print('Best epoch:', best_epoch)
    print('Best accurancy:', best_acc)
    print('*' * 60, '\n')
    
    print('=' * 60)
    print("Accurancy", sess.run(accuracy,
                               feed_dict={x: X_test1,
                                          y_label: y_test_one_hot}))
    print('=' * 60, '\n')
    
    #get prediction results
    predict_result = sess.run(tf.argmax(y_predict, 1), 
                                 feed_dict={x: X_test1})
    
    # count the true or false num of label
    recall_matrix = np.zeros((5, 2))
    precision_matrix = np.zeros((5, 2))
    compare_result = sess.run(tf.equal(predict_result, y_test1))
   
    for i in range(len(predict_result)):
        if(compare_result[i]==True):
            recall_matrix[y_test1[i]][0] += 1
            precision_matrix[predict_result[i]][0] += 1
        else:
            recall_matrix[y_test1[i]][1] += 1
            precision_matrix[predict_result[i]][1] += 1
    
    print('#' * 60)
    print('label', '\t', 'Precision', '\t', 'Recall')
    for i in range(5):
        pre =  precision_matrix[i][0]/(precision_matrix[i][0] + precision_matrix[i][1])
        rec = recall_matrix[i][0]/(recall_matrix[i][0] + recall_matrix[i][1])
        print(i, '\t', pre, '\t', rec)
    print('#' * 60, '\n')
        
    

************************************************************
Best epoch: 29
Best accurancy: 0.994918
************************************************************ 

Accurancy 0.992411

############################################################
label 	 Precision 	 Recall
0 	 0.994903160041 	 0.995918367347
1 	 0.999114260407 	 0.993832599119
2 	 0.984511132623 	 0.985465116279
3 	 0.989151873767 	 0.993069306931
4 	 0.993890020367 	 0.993890020367
############################################################ 

