In [28]:
import tensorflow as tf
import time
from maxout import max_out

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [29]:
# 5*240 neurons in each hidden layers
n_hidden_1 = 1200
n_hidden_2 = 1200

# input size is the size of a picture: 28*28
# output size
input_size = 784
output_size = 10

# Parameters
learning_rate = 0.0005
training_epochs = 1000
batch_size = 200
display_step = 1

In [30]:
def layer_maxout(x, weight_shape, bias_shape):
    """
    Defines the network layers
    input:
        - x: input vector of the layer
        - weight_shape: shape the the weight maxtrix
        - bias_shape: shape of the bias vector
    output:
        - output vector of the layer after the matrix multiplication and transformation
    """
    
    weight_init = tf.random_normal_initializer(stddev=(2.0/weight_shape[0])**0.5)
    W = tf.get_variable("W", weight_shape, initializer=weight_init,constraint = tf.keras.constraints.MaxNorm(4))
    
    bias_init = tf.constant_initializer(value=0)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    
    return max_out(tf.matmul(x, W) + b,240)

In [31]:
def layer(x, weight_shape, bias_shape):
    """
    Defines the network layers
    input:
        - x: input vector of the layer
        - weight_shape: shape the the weight maxtrix
        - bias_shape: shape of the bias vector
    output:
        - output vector of the layer after the matrix multiplication and transformation
    """
    
    weight_init = tf.random_normal_initializer(stddev=(2.0/weight_shape[0])**0.5)
    W = tf.get_variable("W", weight_shape, initializer=weight_init,constraint = tf.keras.constraints.MaxNorm(4))
    
    bias_init = tf.constant_initializer(value=0)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    
    return tf.nn.relu(tf.matmul(x, W) + b)

In [32]:
def inference(x, keep_prob):
    """
    define the structure of the whole network
    input:
        - x: a batch of pictures 
        (input shape = (batch_size*image_size))
        - keep_prob: The keep_prob of dropout layer
    output:
        - a batch vector corresponding to the logits predicted by the network
        (output shape = (batch_size*output_size)) 
    """
    x = tf.cond(keep_prob<1,lambda:tf.nn.dropout(x,0.8),lambda:x)
    x = tf.reshape(x,[-1,28,28,1])

    with tf.variable_scope("fully_connected1"):
        
        # pass the output of max-pooling into a Fully_Connected layer
        x = tf.reshape(x,[-1,28*28])
        # after reshaping, use fully-connected layer to compress
        fc_1 = layer_maxout(x, [28*28, n_hidden_1], [n_hidden_1])
        
        # apply dropout. You may try to add drop out after every pooling layer.
        # outputs the input element scaled up by 1/keep_prob
        # The scaling is so that the expected sum is unchanged
        fc_1_drop = tf.nn.dropout(fc_1, keep_prob)
    
    with tf.variable_scope("fully_connected2"):
        
        # pass the output of max-pooling into a Fully_Connected layer

        # after reshaping, use fully-connected layer to compress
        fc_2 = layer_maxout(fc_1_drop, [n_hidden_1/5, n_hidden_2], [n_hidden_2])
        
        # apply dropout. You may try to add drop out after every pooling layer.
        # outputs the input element scaled up by 1/keep_prob
        # The scaling is so that the expected sum is unchanged
        fc_2_drop = tf.nn.dropout(fc_2, keep_prob)

    with tf.variable_scope("output"):
        output = layer(fc_2_drop, [n_hidden_2/5, 10], [10])

    return output

In [33]:
def loss(output, y):
    """
    Computes softmax cross entropy between logits and labels and then the loss 
    
    intput:
        - output: the output of the inference function 
        - y: true value of the sample batch
        
        the two have the same shape (batch_size * num_of_classes)
    output:
        - loss: loss of the corresponding batch (scalar tensor)
    
    """
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y)    
    loss = tf.reduce_mean(xentropy)
    return loss

In [34]:
def training(cost, global_step):
    """
    defines the necessary elements to train the network
    
    intput:
        - cost: the cost is the loss of the corresponding batch
        - global_step: number of batch seen so far, it is incremented by one each time the .minimize() function is called
    """
    tf.summary.scalar("cost", cost)
    # using Adam Optimizer 
    optimizer = tf.train.AdamOptimizer(learning_rate)
    #grads = optimizer.compute_gradients(cost)
    #for i, (g,v) in enumerate(grads):
    #    grads[i] = (tf.clip_by_norm(g,3.5),v)
    #train_op = optimizer.apply_gradients(grads, global_step=global_step)
    train_op = optimizer.minimize(cost, global_step=global_step)
    return train_op

In [35]:
def evaluate(output, y):
    """
    evaluates the accuracy on the validation set 
    input:
        -output: prediction vector of the network for the validation set
        -y: true value for the validation set
    output:
        - accuracy: accuracy on the validation set (scalar between 0 and 1)
    """
    #correct prediction is a binary vector which equals one when the output and y match
    #otherwise the vector equals 0
    #tf.cast: change the type of a tensor into another one
    #then, by taking the mean of the tensor, we directly have the average score, so the accuracy
    
    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar("validation_error", (1.0 - accuracy))
    return accuracy

In [36]:
earlystop_cnt = 0
earlystop_threshold = 16
if __name__ == '__main__':
    
    #please, make sure you changed for your own path 
    log_files_path = 'C:/Users/WeiLiu/logs/CNNs/'
    start_time = time.time()
    
    with tf.Graph().as_default():

        with tf.variable_scope("MNIST_DropoutNNRelu_maxnorm_model"):
            #neural network definition
            
            #the input variables are first define as placeholder 
            # a placeholder is a variable/data which will be assigned later 
            # MNIST data image of shape 28*28=784
            x = tf.placeholder("float", [None, 784]) 
            # 0-9 digits recognition
            y = tf.placeholder("float", [None, 10])  
            
            # dropout probability
            keep_prob = tf.placeholder(tf.float32) 
            #the network is defined using the inference function defined above in the code
            output = inference(x, keep_prob)
            cost = loss(output, y)
            #initialize the value of the global_step variable 
            # recall: it is incremented by one each time the .minimise() is called
            global_step = tf.Variable(0, name='global_step', trainable=False)
            train_op = training(cost, global_step)
            #evaluate the accuracy of the network (done on a validation set)
            eval_op = evaluate(output, y)
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()
            sess = tf.Session()
            
            summary_writer = tf.summary.FileWriter(log_files_path, sess.graph)
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
            
            
            total_batch = int(mnist.train.num_examples/batch_size)
            max_val_acc = 0.0
            prev_cost = 0.0
            
            # Training cycle
            for epoch in range(training_epochs):

                avg_cost = 0.0
                
                # Loop over all batches
                for i in range(total_batch):
                    
                    minibatch_x, minibatch_y = mnist.train.next_batch(batch_size)
                    
                    # Fit training using batch data
                    sess.run(train_op, feed_dict={x: minibatch_x, y: minibatch_y, keep_prob: 0.5})
                    
                    # Compute average loss
                    avg_cost += sess.run(cost, feed_dict={x: minibatch_x, y: minibatch_y, keep_prob: 0.5})/total_batch
                    
                
                # Display logs per epoch step
                if epoch % display_step == 0:
                    
                    print("Epoch:", '%04d' % (epoch+1), "cost =", "{:0.9f}".format(avg_cost))
                    
                    #probability dropout of 1 during validation
                    accuracy_val = sess.run(eval_op, feed_dict={x: mnist.validation.images, y: mnist.validation.labels, keep_prob: 1})
                    print("Validation Error:", (1 - accuracy_val))
                    
                    if accuracy_val < max_val_acc:
                        if (avg_cost < prev_cost):
                            if earlystop_cnt == earlystop_threshold:
                                print("early stopped on" + str(epoch))
                                break
                            else:
                                print("overfitting warning:" + str(earlystop_cnt))
                                earlystop_cnt += 1
                        else:
                            earlystop_cnt = 0
                    else:
                        earlystop_cnt = 0
                        max_val_acc = accuracy_val
                        
                    prev_cost = avg_cost
                    
                    # probability dropout of 0.25 during training
                    summary_str = sess.run(summary_op, feed_dict={x: minibatch_x, y: minibatch_y, keep_prob: 0.5})
                    summary_writer.add_summary(summary_str, sess.run(global_step))
                    
                    saver.save(sess, log_files_path+'model-checkpoint', global_step=global_step)
                    
            print("Optimization Done")
                    
            accuracy = sess.run(eval_op, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1})
            print("Test Accuracy:", accuracy)
                    
        elapsed_time = time.time() - start_time
        print('Execution time was %0.3f' % elapsed_time)

Epoch: 0001 cost = 2.094841647
Validation Error: 0.256399989128
Epoch: 0002 cost = 0.585251879
Validation Error: 0.0469999909401
Epoch: 0003 cost = 0.298881165
Validation Error: 0.0374000072479
Epoch: 0004 cost = 0.239827930
Validation Error: 0.0314000248909
Epoch: 0005 cost = 0.197133813
Validation Error: 0.0278000235558
Epoch: 0006 cost = 0.172258625
Validation Error: 0.0256000161171
Epoch: 0007 cost = 0.153605135
Validation Error: 0.0221999883652
Epoch: 0008 cost = 0.138847135
Validation Error: 0.0206000208855
Epoch: 0009 cost = 0.131160971
Validation Error: 0.0188000202179
Epoch: 0010 cost = 0.120204206
Validation Error: 0.0163999795914
Epoch: 0011 cost = 0.113883973
Validation Error: 0.0180000066757
Epoch: 0012 cost = 0.104447685
Validation Error: 0.0162000060081
Epoch: 0013 cost = 0.097113721
Validation Error: 0.0157999992371
Epoch: 0014 cost = 0.093153260
Validation Error: 0.0157999992371
Epoch: 0015 cost = 0.088957570
Validation Error: 0.0156000256538
Epoch: 0016 cost = 0.08809

Epoch: 0111 cost = 0.022431423
Validation Error: 0.0109999775887
Epoch: 0112 cost = 0.022203612
Validation Error: 0.0102000236511
Epoch: 0113 cost = 0.021879743
Validation Error: 0.0105999708176
Epoch: 0114 cost = 0.019128852
Validation Error: 0.0109999775887
Epoch: 0115 cost = 0.020558573
Validation Error: 0.0109999775887
Epoch: 0116 cost = 0.020713641
Validation Error: 0.0103999972343
Epoch: 0117 cost = 0.020203454
Validation Error: 0.0108000040054
Epoch: 0118 cost = 0.020535336
Validation Error: 0.0109999775887
Epoch: 0119 cost = 0.019612452
Validation Error: 0.0102000236511
Epoch: 0120 cost = 0.019042258
Validation Error: 0.0113999843597
Epoch: 0121 cost = 0.021296648
Validation Error: 0.0105999708176
Epoch: 0122 cost = 0.020344120
Validation Error: 0.0117999911308
Epoch: 0123 cost = 0.019024308
Validation Error: 0.0112000107765
Epoch: 0124 cost = 0.021277872
Validation Error: 0.0112000107765
Epoch: 0125 cost = 0.021373748
Validation Error: 0.0116000175476
Epoch: 0126 cost = 0.0204

Epoch: 0218 cost = 0.015840618
Validation Error: 0.0116000175476
Epoch: 0219 cost = 0.015965277
Validation Error: 0.012600004673
Epoch: 0220 cost = 0.015761924
Validation Error: 0.0116000175476
Epoch: 0221 cost = 0.015347287
Validation Error: 0.0108000040054
Epoch: 0222 cost = 0.015728038
Validation Error: 0.0102000236511
Epoch: 0223 cost = 0.014945505
Validation Error: 0.0103999972343
Epoch: 0224 cost = 0.014465749
Validation Error: 0.0108000040054
Epoch: 0225 cost = 0.013236294
Validation Error: 0.0109999775887
Epoch: 0226 cost = 0.014826219
Validation Error: 0.0117999911308
Epoch: 0227 cost = 0.015552122
Validation Error: 0.0105999708176
Epoch: 0228 cost = 0.013900751
Validation Error: 0.0127999782562
Epoch: 0229 cost = 0.014744480
Validation Error: 0.0113999843597
Epoch: 0230 cost = 0.015209809
Validation Error: 0.0116000175476
Epoch: 0231 cost = 0.014726652
Validation Error: 0.0123999714851
Epoch: 0232 cost = 0.016629169
Validation Error: 0.0108000040054
Epoch: 0233 cost = 0.01505

Epoch: 0326 cost = 0.016556223
Validation Error: 0.0117999911308
Epoch: 0327 cost = 0.015615248
Validation Error: 0.0113999843597
Epoch: 0328 cost = 0.017350073
Validation Error: 0.0116000175476
Epoch: 0329 cost = 0.013098737
Validation Error: 0.0116000175476
Epoch: 0330 cost = 0.013439658
Validation Error: 0.0105999708176
Epoch: 0331 cost = 0.013321464
Validation Error: 0.0113999843597
Epoch: 0332 cost = 0.015720371
Validation Error: 0.00999999046326
Epoch: 0333 cost = 0.013947315
Validation Error: 0.00940001010895
Epoch: 0334 cost = 0.012546747
Validation Error: 0.0117999911308
Epoch: 0335 cost = 0.014249928
Validation Error: 0.0116000175476
Epoch: 0336 cost = 0.014182706
Validation Error: 0.0112000107765
Epoch: 0337 cost = 0.015174383
Validation Error: 0.0112000107765
Epoch: 0338 cost = 0.012959558
Validation Error: 0.0123999714851
Epoch: 0339 cost = 0.013792320
Validation Error: 0.0113999843597
Epoch: 0340 cost = 0.013277969
Validation Error: 0.0108000040054
Epoch: 0341 cost = 0.01

Epoch: 0435 cost = 0.013361530
Validation Error: 0.0113999843597
Epoch: 0436 cost = 0.013646459
Validation Error: 0.0108000040054
Epoch: 0437 cost = 0.014863402
Validation Error: 0.0108000040054
Epoch: 0438 cost = 0.013828544
Validation Error: 0.0105999708176
Epoch: 0439 cost = 0.015793146
Validation Error: 0.0112000107765
Epoch: 0440 cost = 0.013409037
Validation Error: 0.0108000040054
Epoch: 0441 cost = 0.013907721
Validation Error: 0.0116000175476
Epoch: 0442 cost = 0.013318044
Validation Error: 0.0108000040054
Epoch: 0443 cost = 0.015156335
Validation Error: 0.0113999843597
Epoch: 0444 cost = 0.012735100
Validation Error: 0.0113999843597
Epoch: 0445 cost = 0.012542948
Validation Error: 0.0117999911308
Epoch: 0446 cost = 0.013628310
Validation Error: 0.0116000175476
Epoch: 0447 cost = 0.015038017
Validation Error: 0.012600004673
Epoch: 0448 cost = 0.013566487
Validation Error: 0.0120000243187
Epoch: 0449 cost = 0.014419866
Validation Error: 0.0116000175476
Epoch: 0450 cost = 0.01277

Epoch: 0544 cost = 0.011640715
Validation Error: 0.0113999843597
Epoch: 0545 cost = 0.013226510
Validation Error: 0.0117999911308
Epoch: 0546 cost = 0.015556396
Validation Error: 0.012600004673
Epoch: 0547 cost = 0.013673201
Validation Error: 0.0121999979019
Epoch: 0548 cost = 0.011124363
Validation Error: 0.0130000114441
Epoch: 0549 cost = 0.012441617
Validation Error: 0.0113999843597
Epoch: 0550 cost = 0.012805251
Validation Error: 0.0121999979019
Epoch: 0551 cost = 0.013395240
Validation Error: 0.0105999708176
Epoch: 0552 cost = 0.012511858
Validation Error: 0.0113999843597
Epoch: 0553 cost = 0.012738377
Validation Error: 0.0112000107765
Epoch: 0554 cost = 0.014572694
Validation Error: 0.0127999782562
Epoch: 0555 cost = 0.014135923
Validation Error: 0.0112000107765
Epoch: 0556 cost = 0.013486711
Validation Error: 0.0130000114441
Epoch: 0557 cost = 0.013787862
Validation Error: 0.0116000175476
Epoch: 0558 cost = 0.014871474
Validation Error: 0.0108000040054
Epoch: 0559 cost = 0.01316

Epoch: 0651 cost = 0.012157006
Validation Error: 0.0121999979019
Epoch: 0652 cost = 0.011637452
Validation Error: 0.012600004673
Epoch: 0653 cost = 0.014093658
Validation Error: 0.0103999972343
Epoch: 0654 cost = 0.012049198
Validation Error: 0.0117999911308
Epoch: 0655 cost = 0.011430306
Validation Error: 0.012600004673
Epoch: 0656 cost = 0.013328100
Validation Error: 0.0123999714851
Epoch: 0657 cost = 0.012260877
Validation Error: 0.0120000243187
Epoch: 0658 cost = 0.013214921
Validation Error: 0.012600004673
Epoch: 0659 cost = 0.011967076
Validation Error: 0.0131999850273
Epoch: 0660 cost = 0.013597342
Validation Error: 0.0127999782562
Epoch: 0661 cost = 0.012634737
Validation Error: 0.0116000175476
Epoch: 0662 cost = 0.011494225
Validation Error: 0.0116000175476
Epoch: 0663 cost = 0.011465071
Validation Error: 0.0121999979019
Epoch: 0664 cost = 0.011667239
Validation Error: 0.0113999843597
Epoch: 0665 cost = 0.012784623
Validation Error: 0.0113999843597
Epoch: 0666 cost = 0.0135161

Epoch: 0759 cost = 0.014749960
Validation Error: 0.0120000243187
Epoch: 0760 cost = 0.012809103
Validation Error: 0.0113999843597
Epoch: 0761 cost = 0.011969168
Validation Error: 0.0103999972343
Epoch: 0762 cost = 0.012245870
Validation Error: 0.0112000107765
Epoch: 0763 cost = 0.013898097
Validation Error: 0.0105999708176
Epoch: 0764 cost = 0.013403948
Validation Error: 0.0112000107765
Epoch: 0765 cost = 0.011872674
Validation Error: 0.0121999979019
Epoch: 0766 cost = 0.012440374
Validation Error: 0.0109999775887
Epoch: 0767 cost = 0.011670919
Validation Error: 0.0109999775887
Epoch: 0768 cost = 0.011676555
Validation Error: 0.0113999843597
Epoch: 0769 cost = 0.014216710
Validation Error: 0.0109999775887
Epoch: 0770 cost = 0.013278597
Validation Error: 0.0112000107765
Epoch: 0771 cost = 0.012030036
Validation Error: 0.0117999911308
Epoch: 0772 cost = 0.013246021
Validation Error: 0.0112000107765
Epoch: 0773 cost = 0.011213051
Validation Error: 0.0120000243187
Epoch: 0774 cost = 0.0132

Epoch: 0867 cost = 0.010501639
Validation Error: 0.0105999708176
Epoch: 0868 cost = 0.013732254
Validation Error: 0.0121999979019
Epoch: 0869 cost = 0.014914639
Validation Error: 0.0112000107765
Epoch: 0870 cost = 0.010861436
Validation Error: 0.0117999911308
Epoch: 0871 cost = 0.013434576
Validation Error: 0.0102000236511
Epoch: 0872 cost = 0.014254285
Validation Error: 0.0109999775887
Epoch: 0873 cost = 0.013645894
Validation Error: 0.0113999843597
Epoch: 0874 cost = 0.013315948
Validation Error: 0.0109999775887
Epoch: 0875 cost = 0.013751811
Validation Error: 0.0109999775887
Epoch: 0876 cost = 0.011283204
Validation Error: 0.0105999708176
Epoch: 0877 cost = 0.011193516
Validation Error: 0.0116000175476
Epoch: 0878 cost = 0.011555335
Validation Error: 0.0108000040054
Epoch: 0879 cost = 0.014087021
Validation Error: 0.00999999046326
Epoch: 0880 cost = 0.012839961
Validation Error: 0.0105999708176
Epoch: 0881 cost = 0.010708626
Validation Error: 0.0103999972343
Epoch: 0882 cost = 0.010

Epoch: 0975 cost = 0.011676517
Validation Error: 0.0131999850273
Epoch: 0976 cost = 0.011368075
Validation Error: 0.0109999775887
Epoch: 0977 cost = 0.011873511
Validation Error: 0.00980001688004
Epoch: 0978 cost = 0.012487256
Validation Error: 0.0105999708176
Epoch: 0979 cost = 0.012795239
Validation Error: 0.0131999850273
Epoch: 0980 cost = 0.011812410
Validation Error: 0.00999999046326
Epoch: 0981 cost = 0.015797976
Validation Error: 0.0113999843597
Epoch: 0982 cost = 0.012673940
Validation Error: 0.0123999714851
Epoch: 0983 cost = 0.013988890
Validation Error: 0.0105999708176
Epoch: 0984 cost = 0.012551885
Validation Error: 0.0103999972343
Epoch: 0985 cost = 0.012256627
Validation Error: 0.0103999972343
Epoch: 0986 cost = 0.012487372
Validation Error: 0.0109999775887
Epoch: 0987 cost = 0.011656959
Validation Error: 0.0102000236511
Epoch: 0988 cost = 0.012944050
Validation Error: 0.0103999972343
Epoch: 0989 cost = 0.013028558
Validation Error: 0.0109999775887
Epoch: 0990 cost = 0.01