Sometimes we want to reuse only the lower layers of a pretrained model and build our own layers on top of them.

In chapter10_DNN we built a network with two hidden layers. In this example we reuse hidden layer 1 and build a new hidden layer 2 on top of it.

Import TensorFlow and reset the default graph.

In [1]:
import tensorflow as tf
tf.reset_default_graph()

Import the MNIST data and define the training and validation sets.

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
mnist = input_data.read_data_sets('datasets/mnist')

Extracting datasets/mnist/train-images-idx3-ubyte.gz
Extracting datasets/mnist/train-labels-idx1-ubyte.gz
Extracting datasets/mnist/t10k-images-idx3-ubyte.gz
Extracting datasets/mnist/t10k-labels-idx1-ubyte.gz


In [4]:
X_train, y_train = mnist.train.images, mnist.train.labels
X_val, y_val = mnist.validation.images, mnist.validation.labels

Load the network structure from 'tmp/mnist_dnn_final.ckpt.meta' and list all operations in the graph.

In [5]:
saver = tf.train.import_meta_graph('tmp/mnist_dnn_final.ckpt.meta')

In [6]:
for op in tf.get_default_graph().get_operations():
    print(op.name)

X
y
hidden1/kernel/Initializer/random_uniform/shape
hidden1/kernel/Initializer/random_uniform/min
hidden1/kernel/Initializer/random_uniform/max
hidden1/kernel/Initializer/random_uniform/RandomUniform
hidden1/kernel/Initializer/random_uniform/sub
hidden1/kernel/Initializer/random_uniform/mul
hidden1/kernel/Initializer/random_uniform
hidden1/kernel
hidden1/kernel/Assign
hidden1/kernel/read
hidden1/bias/Initializer/zeros
hidden1/bias
hidden1/bias/Assign
hidden1/bias/read
DNN/hidden1/MatMul
DNN/hidden1/BiasAdd
DNN/hidden1/Relu
hidden2/kernel/Initializer/random_uniform/shape
hidden2/kernel/Initializer/random_uniform/min
hidden2/kernel/Initializer/random_uniform/max
hidden2/kernel/Initializer/random_uniform/RandomUniform
hidden2/kernel/Initializer/random_uniform/sub
hidden2/kernel/Initializer/random_uniform/mul
hidden2/kernel/Initializer/random_uniform
hidden2/kernel
hidden2/kernel/Assign
hidden2/kernel/read
hidden2/bias/Initializer/zeros
hidden2/bias
hidden2/bias/Assign
hidden2/bias/read
DN

Get the tensors for X, y, and the output of hidden1 by name.

In [7]:
X = tf.get_default_graph().get_tensor_by_name("X:0")
y = tf.get_default_graph().get_tensor_by_name("y:0")
hidden1 = tf.get_default_graph().get_tensor_by_name('DNN/hidden1/Relu:0')
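The names passed to get_tensor_by_name above end in ":0" because a tensor name is the name of the operation that produces it plus the index of the output. The sketch below shows this on a small stand-in graph. It is an illustration only: the op named "DNN/Relu" is hypothetical, and the code goes through `tf.compat.v1` on the assumption that TensorFlow 1.14+ or 2.x is installed (under TensorFlow 1.x the same calls are available directly as `tf.placeholder`, `tf.name_scope`, and so on).

```python
import tensorflow as tf

# Assumption: tf.compat.v1 is available (TensorFlow 1.14+ or 2.x).
tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    X = tf1.placeholder(tf.float32, shape=(None, 4), name="X")
    with tf1.name_scope("DNN"):
        # Hypothetical stand-in for the ReLU op at the end of a hidden layer.
        relu_out = tf.nn.relu(X, name="Relu")

# An operation is looked up by its plain name ...
relu_op = graph.get_operation_by_name("DNN/Relu")
# ... and its first output tensor by "<operation name>:0".
relu_tensor = graph.get_tensor_by_name("DNN/Relu:0")

print(relu_op.name)
print(relu_tensor.name)
print(relu_op.outputs[0].name == relu_tensor.name)
```

This is why the operation list prints "DNN/hidden1/Relu" while the tensor is fetched as 'DNN/hidden1/Relu:0'.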

Define a new hidden layer 2 with 50 neurons, a new output layer, and the loss, evaluation, and training operations.

In [8]:
n_hidden2 = 50
n_output = 10

In [9]:
with tf.name_scope('new_DNN'):
    new_hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name='new_hidden2')
    new_logits = tf.layers.dense(new_hidden2, n_output, activation=None, name='new_logits')
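When a training op is built with optimizer.minimize(loss) and no further arguments, it updates every trainable variable the loss depends on, which in this notebook would include the reused hidden1 weights. If the goal were instead to train only the new layers and keep hidden1 fixed, one option is to pass a `var_list` to `minimize`. The sketch below is a toy illustration of that variant, not what this notebook does: the two small variables are stand-ins, plain gradient descent replaces Adam for simplicity, and selecting variables by the prefix "new_" assumes the new layers' variable names start that way (as the names new_hidden2 and new_logits suggest). It uses `tf.compat.v1`, assuming TensorFlow 1.14+ or 2.x.

```python
import numpy as np
import tensorflow as tf

# Assumption: tf.compat.v1 is available (TensorFlow 1.14+ or 2.x).
tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    X = tf1.placeholder(tf.float32, shape=(None, 3), name="X")
    y = tf1.placeholder(tf.float32, shape=(None, 1), name="y")

    # Stand-in for the reused layer.
    with tf1.variable_scope("hidden1"):
        w_old = tf1.get_variable("kernel", shape=(3, 2),
                                 initializer=tf1.ones_initializer())
    # Stand-in for the newly added layer.
    with tf1.variable_scope("new_hidden2"):
        w_new = tf1.get_variable("kernel", shape=(2, 1),
                                 initializer=tf1.ones_initializer())

    hidden1 = tf.nn.relu(tf.matmul(X, w_old))
    output = tf.matmul(hidden1, w_new)
    loss = tf.reduce_mean(tf.square(output - y))

    # Collect only the trainable variables whose names start with "new_".
    new_vars = tf1.get_collection(tf1.GraphKeys.TRAINABLE_VARIABLES,
                                  scope="new_")
    optimizer = tf1.train.GradientDescentOptimizer(learning_rate=0.1)
    # Only the variables in var_list are updated by this training op.
    training_op = optimizer.minimize(loss, var_list=new_vars)
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    old_before, new_before = sess.run([w_old, w_new])
    feed = {X: np.ones((4, 3), np.float32), y: np.zeros((4, 1), np.float32)}
    sess.run(training_op, feed_dict=feed)
    old_after, new_after = sess.run([w_old, w_new])

print("reused layer unchanged:", np.array_equal(old_before, old_after))
print("new layer updated:", not np.array_equal(new_before, new_after))
```

Whether to freeze the reused layer or fine-tune it along with the new ones depends on the task; the notebook takes the fine-tuning route.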

In [10]:
with tf.name_scope('new_loss'):
    entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=new_logits)
    loss = tf.reduce_mean(entropy, name='new_loss')

In [11]:
with tf.name_scope('new_eval'):
    correct = tf.nn.in_top_k(new_logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')

In [12]:
with tf.name_scope('new_train'):
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss, name='training_op')

Load the weights from 'tmp/mnist_dnn_final.ckpt', initialize the new variables, and train for another 20 epochs with a batch size of 50. Print the validation accuracy after each epoch.

In [13]:
init = tf.global_variables_initializer()
n_epoch = 20
batch_size = 50
n_step = mnist.train.num_examples // batch_size

In [14]:
with tf.Session() as sess:
    sess.run(init) # initialize all variables, including those of the new layers
    saver.restore(sess, 'tmp/mnist_dnn_final.ckpt') # then overwrite the variables stored in the checkpoint with their saved values
    
    for epoch in range(n_epoch):
        for step in range(n_step):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
        accuracy_val = sess.run(accuracy, feed_dict={X:X_val, y:y_val})
        print("Epoch: {}, validation accuracy: {:.4f}".format(epoch, accuracy_val))

INFO:tensorflow:Restoring parameters from tmp/mnist_dnn_final.ckpt
Epoch: 0, validation accuracy: 0.9690
Epoch: 1, validation accuracy: 0.9754
Epoch: 2, validation accuracy: 0.9784
Epoch: 3, validation accuracy: 0.9798
Epoch: 4, validation accuracy: 0.9776
Epoch: 5, validation accuracy: 0.9772
Epoch: 6, validation accuracy: 0.9776
Epoch: 7, validation accuracy: 0.9766
Epoch: 8, validation accuracy: 0.9818
Epoch: 9, validation accuracy: 0.9808
Epoch: 10, validation accuracy: 0.9836
Epoch: 11, validation accuracy: 0.9784
Epoch: 12, validation accuracy: 0.9784
Epoch: 13, validation accuracy: 0.9826
Epoch: 14, validation accuracy: 0.9746
Epoch: 15, validation accuracy: 0.9836
Epoch: 16, validation accuracy: 0.9830
Epoch: 17, validation accuracy: 0.9804
Epoch: 18, validation accuracy: 0.9826
Epoch: 19, validation accuracy: 0.9836
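The order of the two calls at the start of the session matters: the initializer runs first and sets every variable, then the saver restores only the variables stored in the checkpoint, so the new layers keep their fresh initial values. The sketch below replays this on a toy checkpoint written to a temporary directory. It is an illustration under stated assumptions: the single-variable "model", the variable names, and the temporary path are all hypothetical, and the code uses `tf.compat.v1` on the assumption that TensorFlow 1.14+ or 2.x is installed.

```python
import os
import tempfile
import tensorflow as tf

# Assumption: tf.compat.v1 is available (TensorFlow 1.14+ or 2.x).
tf1 = tf.compat.v1

ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")

# 1) Build a toy "pretrained" graph and save it.
with tf.Graph().as_default():
    with tf1.variable_scope("hidden1"):
        kernel = tf1.get_variable("kernel", shape=(2,),
                                  initializer=tf1.zeros_initializer())
    pretend_training = kernel.assign([1.0, 2.0])  # stands in for real training
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(pretend_training)
        saver.save(sess, ckpt_path)

# 2) Import the saved graph, add a new variable, then initialize and restore.
with tf.Graph().as_default():
    saver = tf1.train.import_meta_graph(ckpt_path + ".meta")
    old_kernel = tf1.global_variables()[0]  # the only imported variable
    with tf1.variable_scope("new_hidden2"):
        new_kernel = tf1.get_variable("kernel", shape=(2,),
                                      initializer=tf1.ones_initializer())
    init = tf1.global_variables_initializer()  # covers old and new variables

    with tf1.Session() as sess:
        sess.run(init)
        after_init = sess.run(old_kernel)   # value set by the initializer
        saver.restore(sess, ckpt_path)      # touches only the saved variable
        restored = sess.run(old_kernel)     # value stored in the checkpoint
        new_value = sess.run(new_kernel)    # still the fresh initial value

print("after init:   ", after_init)
print("after restore:", restored)
print("new variable: ", new_value)
```

Running the two calls in the opposite order would reset the restored weights to their initial values, because the initializer created after the new layers were added also covers the imported variables.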
