In Chapter11_3, we loaded the parameters from a pre-trained model and continued training the entire model. In some cases we instead want to "freeze" those parameters and treat the lower layers as an already well-trained feature extractor.

In this example, we define our own DNN with four hidden layers and an output layer, then load the weights up to the second hidden layer from tmp/mnist_dnn_final.ckpt. We freeze the first two layers and train the rest of the model.
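To make "freezing" concrete, here is a minimal NumPy sketch that is independent of the TensorFlow code in this notebook. It assumes a toy model with one frozen ReLU layer and one trainable linear layer, trained on squared error; the names `W_frozen`, `W_new`, and `sgd_step` are hypothetical and chosen only for illustration. The point is that the update rule never touches the frozen weights.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: one frozen layer followed by one trainable linear layer.
W_frozen = rng.normal(size=(4, 3))  # stands in for pre-trained lower-layer weights
W_new = rng.normal(size=(3, 2))     # stands in for the new upper-layer weights


def sgd_step(X, Y, W_frozen, W_new, lr=0.1):
    """One gradient-descent step on squared error that updates only W_new."""
    H = np.maximum(X @ W_frozen, 0.0)            # features from the frozen layer
    pred = H @ W_new
    grad_new = 2.0 * H.T @ (pred - Y) / len(X)   # gradient w.r.t. W_new only
    # No gradient is computed for W_frozen, so it is returned unchanged.
    return W_frozen, W_new - lr * grad_new


X = rng.normal(size=(8, 4))
Y = rng.normal(size=(8, 2))
W_frozen_before = W_frozen.copy()
W_new_before = W_new.copy()

for _ in range(5):
    W_frozen, W_new = sgd_step(X, Y, W_frozen, W_new)

frozen_unchanged = np.array_equal(W_frozen, W_frozen_before)
new_changed = not np.allclose(W_new, W_new_before)
print(frozen_unchanged, new_changed)
```

Comparing the arrays with `np.array_equal` is also a more reliable check than reading two printed matrices by eye.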

Import NumPy and TensorFlow, then reset the default graph.

In [1]:
import numpy as np

In [2]:
import tensorflow as tf
tf.reset_default_graph()

Load the MNIST data and get X_train, y_train, X_val, and y_val.

In [3]:
from tensorflow.examples.tutorials.mnist import input_data

In [4]:
mnist = input_data.read_data_sets('datasets/minst')

Extracting datasets/minst/train-images-idx3-ubyte.gz
Extracting datasets/minst/train-labels-idx1-ubyte.gz
Extracting datasets/minst/t10k-images-idx3-ubyte.gz
Extracting datasets/minst/t10k-labels-idx1-ubyte.gz


In [5]:
X_train, y_train = mnist.train.images, mnist.train.labels
X_val, y_val = mnist.validation.images, mnist.validation.labels

Run the next cell to define our neural network.

In [6]:
# layer sizes
n_inputs = 28 * 28  # MNIST
n_hidden1 = 300  # reused, frozen
n_hidden2 = 100  # reused, frozen
n_hidden3 = 50  # new!
n_hidden4 = 20  # new!
n_outputs = 10  # new!

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1")       # reused, frozen
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") # reused, frozen
    hidden2_stop = tf.stop_gradient(hidden2)
    hidden3 = tf.layers.dense(hidden2_stop, n_hidden3, activation=tf.nn.relu, name="hidden3") # new
    hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") # new
    logits = tf.layers.dense(hidden4, n_outputs, name="outputs")                         # new

with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")

with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    training_op = optimizer.minimize(loss)

Since the frozen layers won't change, we can cache the output of the topmost frozen layer for each training instance. Because training goes through the whole dataset many times, this gives a large speed boost: each training instance passes through the frozen layers only once, instead of once per epoch.
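The caching idea can be sketched in plain NumPy, separately from the TensorFlow graph above. The data, the weight matrix `W_frozen`, and the helper `frozen_forward` are hypothetical stand-ins; the sketch only shows that rows looked up in the cache match what the frozen layer would compute again for the same instances.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-ins: 10 instances, 6 features, a fixed "frozen" layer with 4 units.
X_data = rng.normal(size=(10, 6))
W_frozen = rng.normal(size=(6, 4))


def frozen_forward(X):
    """Forward pass through the frozen layer (ReLU activation)."""
    return np.maximum(X @ W_frozen, 0.0)


# Run the frozen layer once over the whole dataset and keep the result.
h_cache = frozen_forward(X_data)

n_epochs, n_batches = 3, 5
for epoch in range(n_epochs):
    shuffle_idx = rng.permutation(len(X_data))
    for batch_idx in np.array_split(shuffle_idx, n_batches):
        cached_batch = h_cache[batch_idx]  # simple lookup, no recomputation
        # The cached rows match what the frozen layer would produce again.
        assert np.allclose(cached_batch, frozen_forward(X_data[batch_idx]))
        # ... the upper layers would be trained on cached_batch here ...
```

The saving grows with the number of epochs: the frozen layers run once over the dataset instead of once per epoch.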

Now train for another 20 epochs and print the validation accuracy after each one. Also print the weights of hidden1 before and after training to confirm that they are frozen.

In [7]:
reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden[12]")
reuse_vars_dict = dict([(var.op.name, var) for var in reuse_vars])
restore_saver = tf.train.Saver(reuse_vars_dict)
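The `scope` argument of `tf.get_collection` is used as a regular expression that is matched against the start of each variable name, which is why "hidden[12]" selects only the variables of the first two layers. The sketch below mimics that filter with `re.match` on a hand-written list of names; the list is an assumption about the variables this graph contains (one kernel and one bias per dense layer).

```python
import re

# Assumed variable names for the layers defined above (kernel and bias per layer).
var_names = [
    "hidden1/kernel:0", "hidden1/bias:0",
    "hidden2/kernel:0", "hidden2/bias:0",
    "hidden3/kernel:0", "hidden3/bias:0",
    "hidden4/kernel:0", "hidden4/bias:0",
    "outputs/kernel:0", "outputs/bias:0",
]

scope = "hidden[12]"
# re.match anchors at the start of the string, like the scope filter does.
reused = [name for name in var_names if re.match(scope, name)]
print(reused)
# ['hidden1/kernel:0', 'hidden1/bias:0', 'hidden2/kernel:0', 'hidden2/bias:0']
```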

In [8]:
init = tf.global_variables_initializer()

In [9]:
n_epoch = 20
batch_size = 50
n_step = mnist.train.num_examples // batch_size
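As a quick check of the batching arithmetic, the sketch below assumes the training set holds 55,000 instances (the usual size when 5,000 images are held out for validation; the real value comes from `mnist.train.num_examples`). It shows how `np.array_split` turns a shuffled index array into `n_step` mini-batches, and how it handles a length that does not divide evenly.

```python
import numpy as np

num_examples = 55000  # assumed size of the training set
batch_size = 50
n_step = num_examples // batch_size

shuffle_idx = np.random.permutation(num_examples)
batches = np.array_split(shuffle_idx, n_step)
print(n_step, len(batches), len(batches[0]))  # 1100 1100 50

# When the length is not divisible, the first chunks get one extra element.
uneven = [len(chunk) for chunk in np.array_split(np.arange(10), 3)]
print(uneven)  # [4, 3, 3]
```

Because `np.array_split` never drops elements, every training instance is used exactly once per epoch even if the batch size does not divide the dataset size.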

In [10]:
# the weights (kernel) of hidden1: inputs[0] of the Assign op is the variable, inputs[1] its initial value
graph = tf.get_default_graph()
assign_kernel = graph.get_operation_by_name("hidden1/kernel/Assign")
init_kernel = assign_kernel.inputs[1]

In [11]:
with tf.Session() as sess:
    sess.run(init)
    restore_saver.restore(sess, 'tmp/mnist_dnn_final.ckpt')
    
    # print the weight before training
    print(sess.run(assign_kernel.inputs[0])) 
    
    # cache the output of hidden layer 2
    h2_cache = sess.run(hidden2, feed_dict={X: mnist.train.images})
    h2_cache_val = sess.run(hidden2, feed_dict={X: mnist.validation.images})
    
    for epoch in range(n_epoch):
        shuffle_idx = np.random.permutation(mnist.train.num_examples)
        hidden2_batch_list = np.array_split(h2_cache[shuffle_idx], n_step)
        y_batch_list = np.array_split(y_train[shuffle_idx], n_step)
        for hidden2_batch, y_batch in zip(hidden2_batch_list, y_batch_list):
            sess.run(training_op, feed_dict={hidden2: hidden2_batch, y: y_batch})
            
        # evaluate on the cached validation features (equivalent to feeding X: mnist.validation.images)
        accuracy_val = sess.run(accuracy, feed_dict={hidden2: h2_cache_val, y: mnist.validation.labels})
        print("Epoch: {}, validation accuracy: {:.4f}".format(epoch, accuracy_val))
    
    # print the weight again. It should be unchanged since we froze the first 2 layers.
    print(sess.run(assign_kernel.inputs[0]))

INFO:tensorflow:Restoring parameters from tmp/mnist_dnn_final.ckpt
[[ 0.00741297  0.04399479 -0.07298389 ...  0.06380799 -0.01855266
   0.06487764]
 [-0.00944764  0.05897304  0.04332318 ... -0.06645274  0.02436784
  -0.01333787]
 [ 0.00074603  0.01887694 -0.00775586 ...  0.00511462 -0.01614702
  -0.05381356]
 ...
 [ 0.00572163 -0.01311519 -0.05707078 ... -0.05213419  0.05671324
  -0.07045254]
 [-0.04161249  0.05519348 -0.04053325 ...  0.01326139  0.06936294
   0.05498397]
 [-0.02907837 -0.01657688 -0.05163883 ... -0.05350867 -0.05564026
  -0.04848889]]
Epoch: 0, validation accuracy: 0.7596
Epoch: 1, validation accuracy: 0.8478
Epoch: 2, validation accuracy: 0.8814
Epoch: 3, validation accuracy: 0.8968
Epoch: 4, validation accuracy: 0.9038
Epoch: 5, validation accuracy: 0.9078
Epoch: 6, validation accuracy: 0.9116
Epoch: 7, validation accuracy: 0.9124
Epoch: 8, validation accuracy: 0.9134
Epoch: 9, validation accuracy: 0.9154
Epoch: 10, validation accuracy: 0.9178
Epoch: 11, validation 