In [1]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets

import numpy as np
import os

from PIL import Image

In [2]:
VGG_MEAN = [123.68, 116.78, 103.94]

In [3]:
x_train_path = './processed_data/x_train.npy'
y_train_path = './processed_data/y_train.npy'
x_val_path = './processed_data/x_val.npy'
y_val_path = './processed_data/y_val.npy'
model_path = 'vgg_16.ckpt'
batch_size = 32
num_workers = 4
num_epochs1 = 1500
const_learning_rate1 = 1e-4
dropout_keep_prob = 0.5
weight_decay = 5e-4

In [4]:
# load training and validation data
x_train = np.load(x_train_path)
y_train = np.load(y_train_path)
x_val = np.load(x_val_path)
y_val = np.load(y_val_path)

total_images = x_train.shape[0]

In [5]:
def check_accuracy(sess, error, is_training,
                  dataset_init_op):
    """
    Check the accuracy of the model on either train 
    or val (depending on dataset_init_op).
    """
    # Initialize the correct dataset
    sess.run(dataset_init_op)
    err = []
    while True:
        try:
            mse, op = sess.run(error, {is_training: False})
            err.append(float(mse))
        except tf.errors.OutOfRangeError:
            break
    
    return sum(err)/float(len(err))

def save_checkpoint(saver, init_err, val_err, sess, epoch, name):
    if val_err < init_err:
        init_err = val_err
        save_path = './checks/' + name
        saver.save(sess, 
                    save_path,
                    global_step=epoch,
                    write_meta_graph=True)
    return init_err    

def get_arrays(fold=None):
    if fold == "train":
        return x_train, y_train
    else:
        return x_val, y_val

In [6]:
def main():
    x_train, y_train = get_arrays(fold="train")
    x_val, y_val = get_arrays(fold="val")
    
    num_classes = 30
    
    graph = tf.Graph()
    with graph.as_default():
        # Standard preprocessing for VGG on ImageNet
        def _parse_function(image, keypoints):
            image = tf.reshape(image, [96, 96, 1])
            
            rgb_image = tf.image.grayscale_to_rgb(image)
            image = tf.cast(rgb_image, tf.float32)
            
            # changing resolution according from 48x48 to 256x256
            resized_image = tf.image.resize_images(image, tf.constant([256, 256]))
            
            return resized_image, keypoints
        
        # Preprocessing for training
        def training_preprocess(image, keypoints):
            # Random Crop
            crop_image = tf.random_crop(image, [224, 224, 3])
            flip_image = tf.image.random_flip_left_right(crop_image)
            
            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = flip_image - means
            
            return centered_image, keypoints
        
        # Preprocessing for validation
        def val_preprocess(image, keypoints):
            # Central Crop
            crop_image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
            
            means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
            centered_image = crop_image - means
            
            return centered_image, keypoints
        
        # DATASET CREATION using tf.contrib.data.Dataset
        
        # Training dataset
        x_train = tf.constant(x_train)
        y_train = tf.constant(y_train)
        
        train_dataset = tf.contrib.data.Dataset.from_tensor_slices((x_train, y_train))
        train_dataset = train_dataset.map(_parse_function,
                        num_threads=num_workers,
                        output_buffer_size=batch_size)
        train_dataset = train_dataset.map(training_preprocess,
                        num_threads=num_workers, 
                        output_buffer_size=batch_size)
        train_dataset = train_dataset.shuffle(buffer_size=10000) # don't forget to shuffle
        batched_train_dataset = train_dataset.batch(batch_size)
        
        # Validation dataset
        x_val = tf.constant(x_val)
        y_val = tf.constant(y_val)
        
        val_dataset = tf.contrib.data.Dataset.from_tensor_slices((x_val, y_val))
        val_dataset = val_dataset.map(_parse_function,
                        num_threads=num_workers, 
                        output_buffer_size=batch_size)
        val_dataset = val_dataset.map(val_preprocess, 
                        num_threads=num_workers, 
                        output_buffer_size=batch_size)
        batched_val_dataset = val_dataset.batch(batch_size)
        
        # Now define the Iterator that can operate on
        # any of the datasets
        # Once this is done, we don't need to feed any value for images and labels
        # as they are automatically pulled out from the iterator queues.
        
        iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,
                                                          batched_train_dataset.output_shapes)
        
        images, keypoints = iterator.get_next()
        
        train_init_op = iterator.make_initializer(batched_train_dataset)
        val_init_op = iterator.make_initializer(batched_val_dataset)
        
        # Create a placeholder to indicate if we are in training or validation phase
        is_training = tf.placeholder(tf.bool)
        
        # Get the pretrained VGG16 ready
        vgg = tf.contrib.slim.nets.vgg
        with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=weight_decay)):
            predictions, _ = vgg.vgg_16(images, num_classes=num_classes, 
                                   is_training=is_training, 
                                  dropout_keep_prob=dropout_keep_prob)
            
        # specify model checkpoint path
        assert(os.path.isfile(model_path))
        
        # Restore only convolutional layers and not fully connected ones
        variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['vgg_16/fc6',
                                                                                    'vgg_16/fc7',
                                                                                      'vgg_16/fc8'])
        
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)
        
        # Initialization operation for all three fully connected layers
        fc6_variables = tf.contrib.framework.get_variables('vgg_16/fc6')
        fc6_init = tf.variables_initializer(fc6_variables)
        fc7_variables = tf.contrib.framework.get_variables('vgg_16/fc7')
        fc7_init = tf.variables_initializer(fc7_variables)
        fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
        fc8_init = tf.variables_initializer(fc8_variables)
        
        # Define loss function 
        tf.losses.mean_squared_error(labels=keypoints,
                                    predictions=predictions)
        loss = tf.losses.get_total_loss()
        
        # For tensorboard
        tf.summary.scalar('loss', loss)
        
        # We will train only for last three layers
        var_list = []
        var_list.append(fc6_variables)
        var_list.append(fc7_variables)
        var_list.append(fc8_variables)
        
        # training op to train complete network
        # with step decay for learning rate
        global_step = tf.Variable(0, trainable=False)
        boundaries = [500, 1000, 1500]
        values = [0.0001, 0.00005, 0.00001, 0.000005]
        learning_rate = tf.train.piecewise_constant(global_step,boundaries,values)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, var_list=var_list) 
    
        # Evaluation Metrics
        error = tf.metrics.mean_squared_error(labels=keypoints,
                                             predictions=predictions)
        
        # Save built checkpoint
        saver = tf.train.Saver(max_to_keep=3)
        
        # Variables Initializer
        global_init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        
        summary_op = tf.summary.merge_all()
        
        tf.get_default_graph().finalize()
        
    # Graph has been built, time to throw some computation
    with tf.Session(graph=graph) as sess:
        init_fn(sess) # load the pretrained weights
        sess.run(fc6_init) # initialize the new fc6 layer
        sess.run(fc7_init)
        sess.run(fc8_init)
        
        sess.run(global_init)
        sess.run(local_init)

        init_mse = 1000000.0 # take a big number to compare error
        
        writer = tf.summary.FileWriter("./logs", graph=None)
        
        # Update only last three layers
        for epoch in range(num_epochs1):
            # Track loss
            loss_history = []
            
            # number of batches in one epoch
            batch_count = int(total_images/batch_size)
            
            # Run an epoch over the training data
            print("Starting epoch %d / %d" % (epoch + 1,
                                             num_epochs1))
            # Initialize the iterator with training set
            sess.run(train_init_op)
            
            while True:
                try:
                    _, summary = sess.run([train_op, summary_op],
                                 {is_training: True})
                    loss_history.append(sess.run(loss, 
                                                 {is_training: True}))
                except tf.errors.OutOfRangeError:
                    break
            
            # print average loss per epoch
            epoch_loss = sum(loss_history) / float(len(loss_history))
            print("Loss: %f" %  epoch_loss, end="")
            
            # Check accuracy on the train and val sets every epoch
            #train_mse = check_accuracy(sess, error, is_training, train_init_op)
            val_mse = check_accuracy(sess, error, is_training, val_init_op)
            
            #print("Train MSE: %f" % train_mse)
            print("    Validation MSE: %f" % val_mse)
            
            # Call the function to save the checkpoint
            init_mse = save_checkpoint(saver, init_mse, val_mse, sess,
                           epoch, 'my_model1')
            #print("Checkpoint Saved")
            
            # tensorboard logging
            writer.add_summary(summary, epoch)
            #print("Tensorboard logging complete")
        
            

In [7]:
if __name__ == '__main__':
    main()

Instructions for updating:
Use `tf.data.Dataset.from_tensor_slices()`.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
INFO:tensorflow:Restoring parameters from vgg_16.ckpt
Starting epoch 1 / 1500
Loss: 1240.696931    Validation MSE: 223.989182
Starting epoch 2 / 1500
Loss: 248.109277    Validation MSE: 221.568116
Starting epoch 3 / 1500
Loss: 222.460108    Validation MSE: 212.833600
Starting epoch 4 / 1500
Loss: 201.370837    Validation MSE: 204.258968
Starting epoch 5 / 1500
Loss: 186.292660    Validation MSE: 195.961145
Starting epoch 6 / 1500
Loss: 162.913122    Validation MSE: 188.632214
Starting epoch 7 / 1500
Loss: 142.170686    Validation MSE: 182.828742
Starting epoch 8 / 1500
Loss: 135.393735    Validation MSE: 1

Starting epoch 102 / 1500
Loss: 40.208584    Validation MSE: 75.094445
Starting epoch 103 / 1500
Loss: 36.123011    Validation MSE: 74.576109
Starting epoch 104 / 1500
Loss: 38.392805    Validation MSE: 74.132091
Starting epoch 105 / 1500
Loss: 34.882440    Validation MSE: 73.718072
Starting epoch 106 / 1500
Loss: 37.890871    Validation MSE: 73.390243
Starting epoch 107 / 1500
Loss: 33.277965    Validation MSE: 73.031572
Starting epoch 108 / 1500
Loss: 35.007320    Validation MSE: 72.559057
Starting epoch 109 / 1500
Loss: 34.320242    Validation MSE: 72.022668
Starting epoch 110 / 1500
Loss: 36.559828    Validation MSE: 71.628873
Starting epoch 111 / 1500
Loss: 33.058951    Validation MSE: 71.279397
Starting epoch 112 / 1500
Loss: 32.342930    Validation MSE: 71.028926
Starting epoch 113 / 1500
Loss: 36.035895    Validation MSE: 70.929394
Starting epoch 114 / 1500
Loss: 33.017566    Validation MSE: 70.645208
Starting epoch 115 / 1500
Loss: 28.080312    Validation MSE: 70.212596
Starti

Loss: 21.830717    Validation MSE: 46.051981
Starting epoch 218 / 1500
Loss: 22.024616    Validation MSE: 45.881426
Starting epoch 219 / 1500
Loss: 22.647059    Validation MSE: 45.816144
Starting epoch 220 / 1500
Loss: 20.894992    Validation MSE: 45.750188
Starting epoch 221 / 1500
Loss: 22.462881    Validation MSE: 45.578103
Starting epoch 222 / 1500
Loss: 20.926817    Validation MSE: 45.406229
Starting epoch 223 / 1500
Loss: 20.648262    Validation MSE: 45.237411
Starting epoch 224 / 1500
Loss: 21.117520    Validation MSE: 45.077491
Starting epoch 225 / 1500
Loss: 22.278626    Validation MSE: 44.948382
Starting epoch 226 / 1500
Loss: 23.893043    Validation MSE: 44.811432
Starting epoch 227 / 1500
Loss: 21.058442    Validation MSE: 44.644471
Starting epoch 228 / 1500
Loss: 21.347570    Validation MSE: 44.482873
Starting epoch 229 / 1500
Loss: 20.727605    Validation MSE: 44.376827
Starting epoch 230 / 1500
Loss: 22.350950    Validation MSE: 44.284715
Starting epoch 231 / 1500
Loss: 

Starting epoch 333 / 1500
Loss: 19.029555    Validation MSE: 34.638176
Starting epoch 334 / 1500
Loss: 19.028015    Validation MSE: 34.566953
Starting epoch 335 / 1500
Loss: 18.322993    Validation MSE: 34.500494
Starting epoch 336 / 1500
Loss: 18.651580    Validation MSE: 34.424781
Starting epoch 337 / 1500
Loss: 17.401853    Validation MSE: 34.342760
Starting epoch 338 / 1500
Loss: 17.853953    Validation MSE: 34.267099
Starting epoch 339 / 1500
Loss: 18.887344    Validation MSE: 34.194934
Starting epoch 340 / 1500
Loss: 17.945736    Validation MSE: 34.125967
Starting epoch 341 / 1500
Loss: 18.784068    Validation MSE: 34.052813
Starting epoch 342 / 1500
Loss: 18.072335    Validation MSE: 33.975552
Starting epoch 343 / 1500
Loss: 19.217933    Validation MSE: 33.903087
Starting epoch 344 / 1500
Loss: 18.564590    Validation MSE: 33.855263
Starting epoch 345 / 1500
Loss: 19.458305    Validation MSE: 33.803484
Starting epoch 346 / 1500
Loss: 18.336636    Validation MSE: 33.724133
Starti

Loss: 17.409867    Validation MSE: 28.275222
Starting epoch 449 / 1500
Loss: 18.148749    Validation MSE: 28.233470
Starting epoch 450 / 1500
Loss: 17.296636    Validation MSE: 28.200103
Starting epoch 451 / 1500
Loss: 18.346378    Validation MSE: 28.160419
Starting epoch 452 / 1500
Loss: 18.233296    Validation MSE: 28.117617
Starting epoch 453 / 1500
Loss: 17.189277    Validation MSE: 28.080442
Starting epoch 454 / 1500
Loss: 17.099630    Validation MSE: 28.040156
Starting epoch 455 / 1500
Loss: 16.896554    Validation MSE: 27.992963
Starting epoch 456 / 1500
Loss: 17.019022    Validation MSE: 27.950190
Starting epoch 457 / 1500
Loss: 17.456893    Validation MSE: 27.907630
Starting epoch 458 / 1500
Loss: 18.000463    Validation MSE: 27.862353
Starting epoch 459 / 1500
Loss: 17.928741    Validation MSE: 27.818056
Starting epoch 460 / 1500
Loss: 18.121597    Validation MSE: 27.774198
Starting epoch 461 / 1500
Loss: 18.447798    Validation MSE: 27.733744
Starting epoch 462 / 1500
Loss: 

Starting epoch 564 / 1500
Loss: 17.064076    Validation MSE: 24.257551
Starting epoch 565 / 1500
Loss: 16.605916    Validation MSE: 24.259332
Starting epoch 566 / 1500
Loss: 17.351237    Validation MSE: 24.261518
Starting epoch 567 / 1500
Loss: 16.959295    Validation MSE: 24.230742
Starting epoch 568 / 1500
Loss: 16.834874    Validation MSE: 24.200137
Starting epoch 569 / 1500
Loss: 16.648020    Validation MSE: 24.168344
Starting epoch 570 / 1500
Loss: 17.256706    Validation MSE: 24.136885
Starting epoch 571 / 1500
Loss: 16.513007    Validation MSE: 24.105847
Starting epoch 572 / 1500
Loss: 16.645478    Validation MSE: 24.075796
Starting epoch 573 / 1500
Loss: 16.823976    Validation MSE: 24.054264
Starting epoch 574 / 1500
Loss: 16.415410    Validation MSE: 24.035209
Starting epoch 575 / 1500
Loss: 16.555843    Validation MSE: 24.006647
Starting epoch 576 / 1500
Loss: 16.542278    Validation MSE: 23.976559
Starting epoch 577 / 1500
Loss: 16.886779    Validation MSE: 23.946918
Starti

Loss: 15.870625    Validation MSE: 21.670079
Starting epoch 680 / 1500
Loss: 15.727760    Validation MSE: 21.649070
Starting epoch 681 / 1500
Loss: 16.467580    Validation MSE: 21.632553
Starting epoch 682 / 1500
Loss: 15.715431    Validation MSE: 21.614838
Starting epoch 683 / 1500
Loss: 16.364642    Validation MSE: 21.595432
Starting epoch 684 / 1500
Loss: 16.080138    Validation MSE: 21.575681
Starting epoch 685 / 1500
Loss: 15.668466    Validation MSE: 21.552800
Starting epoch 686 / 1500
Loss: 15.733352    Validation MSE: 21.529757
Starting epoch 687 / 1500
Loss: 16.546255    Validation MSE: 21.507101
Starting epoch 688 / 1500
Loss: 16.204382    Validation MSE: 21.484722
Starting epoch 689 / 1500
Loss: 16.693741    Validation MSE: 21.462291
Starting epoch 690 / 1500
Loss: 15.797784    Validation MSE: 21.441483
Starting epoch 691 / 1500
Loss: 15.846866    Validation MSE: 21.420185
Starting epoch 692 / 1500
Loss: 15.396517    Validation MSE: 21.398586
Starting epoch 693 / 1500
Loss: 

KeyboardInterrupt: 