In Chapter11_2, we loaded both the graph and the parameter values. Another option is to define the network ourselves and load only the parameters from the pre-trained model. The reused layers must, of course, have the same definition in the new network as in the old one.

In this example, we define our own DNN with four hidden layers and an output layer, load the weights of the first two hidden layers from tmp/mnist_dnn_final.ckpt, and then train the new model.

Import TensorFlow and reset the default graph.

In [1]:
import tensorflow as tf
tf.reset_default_graph()

Load the MNIST data and get X_train, y_train, X_val, and y_val.

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
mnist = input_data.read_data_sets('datasets/minst')

Extracting datasets/minst/train-images-idx3-ubyte.gz
Extracting datasets/minst/train-labels-idx1-ubyte.gz
Extracting datasets/minst/t10k-images-idx3-ubyte.gz
Extracting datasets/minst/t10k-labels-idx1-ubyte.gz


In [4]:
X_train, y_train = mnist.train.images, mnist.train.labels
X_val, y_val = mnist.validation.images, mnist.validation.labels

Run the next cell to define our neural network.

In [5]:
# Layer sizes; only hidden1 and hidden2 are restored from the checkpoint
n_inputs = 28 * 28  # MNIST
n_hidden1 = 300  # reused
n_hidden2 = 100  # reused
n_hidden3 = 50  # new!
n_hidden4 = 20  # new!
n_outputs = 10  # new!

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1")       # reused
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") # reused
    hidden3 = tf.layers.dense(hidden2, n_hidden3, activation=tf.nn.relu, name="hidden3") # new
    hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") # new
    logits = tf.layers.dense(hidden4, n_outputs, name="outputs")                         # new

with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")

with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    training_op = optimizer.minimize(loss)
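
To make the "loss" and "eval" blocks above more concrete, here is a minimal pure-Python sketch of what they compute for a tiny batch. It is not TensorFlow's implementation; the helper names `sparse_softmax_cross_entropy` and `in_top_1` and the three-class batch are made up for illustration (the notebook itself uses ten classes).

```python
import math

def sparse_softmax_cross_entropy(logits, label):
    """Cross-entropy of one example, given raw logits and an integer class label."""
    # Subtract the largest logit before exponentiating, for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label]) equals log_sum_exp - logits[label].
    return log_sum_exp - logits[label]

def in_top_1(logits, label):
    """True when the labelled class has the highest logit (ties are not handled specially)."""
    return max(range(len(logits)), key=lambda i: logits[i]) == label

# Hypothetical mini-batch: 3 examples, 3 classes.
batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
batch_labels = [0, 2, 0]

losses = [sparse_softmax_cross_entropy(z, t) for z, t in zip(batch_logits, batch_labels)]
loss = sum(losses) / len(losses)          # analogue of tf.reduce_mean(xentropy)

correct = [in_top_1(z, t) for z, t in zip(batch_logits, batch_labels)]
accuracy = sum(correct) / len(correct)    # analogue of the "accuracy" node

print("loss: {:.4f}, accuracy: {:.4f}".format(loss, accuracy))
```

Note that the labels are plain integers, not one-hot vectors, which is why the notebook uses the "sparse" variant of the cross-entropy function.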

Now restore the parameters of hidden layers 1 and 2 from 'tmp/mnist_dnn_final.ckpt' and train the model for another 20 epochs, printing the validation accuracy after each epoch.

In [6]:
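# Collect the variables whose names match the regular expression "hidden[12]" (hidden1 and hidden2)
reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="hidden[12]")
# Map each variable name in the checkpoint to the variable that receives its value
reuse_vars_dict = dict([(var.op.name, var) for var in reuse_vars])
restore_saver = tf.train.Saver(reuse_vars_dict)
# Alternatively, pass the list of variables directly:
# restore_saver = tf.train.Saver(reuse_vars)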
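
The `scope="hidden[12]"` argument is treated as a regular expression matched against the start of each variable name. The sketch below imitates that filtering with Python's `re.match` on a list of names, without needing TensorFlow. The names are assumptions that follow the `<layer>/kernel` and `<layer>/bias` pattern of dense layers, and `filter_by_scope` is a hypothetical helper, not a TensorFlow function.

```python
import re

# Assumed variable names for the five dense layers defined earlier.
variable_names = [
    layer + "/" + part
    for layer in ["hidden1", "hidden2", "hidden3", "hidden4", "outputs"]
    for part in ["kernel", "bias"]
]

def filter_by_scope(names, scope):
    """Keep the names whose beginning matches the regular expression `scope`."""
    return [name for name in names if re.match(scope, name)]

reused = filter_by_scope(variable_names, "hidden[12]")
print(reused)
# ['hidden1/kernel', 'hidden1/bias', 'hidden2/kernel', 'hidden2/bias']
```

Because the match is anchored only at the start of the name, a pattern such as `"hidden[12]"` would also select a layer called `hidden10` if one existed, so a stricter pattern may be preferable in larger networks.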

In [7]:
init = tf.global_variables_initializer()

In [8]:
n_epoch = 20
batch_size = 50
n_step = mnist.train.num_examples // batch_size

In [9]:
with tf.Session() as sess:
    sess.run(init)
    restore_saver.restore(sess, 'tmp/mnist_dnn_final.ckpt')
    
    for epoch in range(n_epoch):
        for step in range(n_step):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_val = sess.run(accuracy, feed_dict={X: X_val, y: y_val})
        print("Epoch: {}, validation accuracy: {:.4f}".format(epoch, accuracy_val))

INFO:tensorflow:Restoring parameters from tmp/mnist_dnn_final.ckpt
Epoch: 0, validation accuracy: 0.8156
Epoch: 1, validation accuracy: 0.8696
Epoch: 2, validation accuracy: 0.8914
Epoch: 3, validation accuracy: 0.9034
Epoch: 4, validation accuracy: 0.9096
Epoch: 5, validation accuracy: 0.9142
Epoch: 6, validation accuracy: 0.9182
Epoch: 7, validation accuracy: 0.9206
Epoch: 8, validation accuracy: 0.9236
Epoch: 9, validation accuracy: 0.9252
Epoch: 10, validation accuracy: 0.9272
Epoch: 11, validation accuracy: 0.9296
Epoch: 12, validation accuracy: 0.9308
Epoch: 13, validation accuracy: 0.9324
Epoch: 14, validation accuracy: 0.9328
Epoch: 15, validation accuracy: 0.9362
Epoch: 16, validation accuracy: 0.9368
Epoch: 17, validation accuracy: 0.9364
Epoch: 18, validation accuracy: 0.9388
Epoch: 19, validation accuracy: 0.9396
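
The session above runs the initializer first and restores the checkpoint second. The order matters: running the initializer after the restore would overwrite the reused weights with fresh initial values. The following sketch mimics the two steps with plain dictionaries. The scalar "weights" and the checkpoint contents are hypothetical stand-ins, not values read from tmp/mnist_dnn_final.ckpt.

```python
import random

random.seed(0)

layer_names = ["hidden1", "hidden2", "hidden3", "hidden4", "outputs"]

# Hypothetical checkpoint: one scalar per layer stands in for the real tensors.
checkpoint = {"hidden1": 0.11, "hidden2": 0.22, "hidden3": 0.33}

# Step 1: initialize every variable (analogue of sess.run(init)).
values = {name: random.uniform(-1.0, 1.0) for name in layer_names}
initial = dict(values)

# Step 2: overwrite only the reused variables (analogue of restore_saver.restore).
reuse = ["hidden1", "hidden2"]
for name in reuse:
    values[name] = checkpoint[name]

for name in layer_names:
    source = "restored" if name in reuse else "freshly initialized"
    print("{}: {}".format(name, source))
```

In this sketch `hidden3` keeps its initial value even though the checkpoint contains an entry for it, because it is not in the list of variables handed to the restore step.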
