An example of applying transfer learning with TensorFlow to the traffic sign dataset.
The focus of this example is to explain how to do transfer learning; the trained model is far from optimal for the given problem.

This notebook should be run using **GPU support** unless you don't mind waiting a very long time!

Download and unpack the [pretrained VGG](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz) checkpoint so that it ends up at `../pretrained/vgg/vgg_16.ckpt`.

- based on the code of omoindrot: https://gist.github.com/omoindrot/dedc857cdc0e680dfb1be99762990c9c
- added model saving
- added tensorboard visualization

In [1]:
import os
import time
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg
import tensorflow.contrib.slim.nets
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Model training parameters
model_path = '../pretrained/vgg/vgg_16.ckpt'  # pretrained VGG-16 checkpoint (see download link above)
batch_size = 8                     # lower in case of memory problems
num_workers = os.cpu_count() or 4  # parallel map() threads; defaults to the number of CPU cores
num_epochs1 = 15                   # phase 1: epochs that train only the new fc8 layer
num_epochs2 = 15                   # phase 2: epochs that fine-tune the whole network
learning_rate1 = 1e-3              # SGD learning rate of phase 1
learning_rate2 = 1e-5              # SGD learning rate of phase 2 (much smaller: only slight adjustments)
dropout_keep_prob = 0.5            # keep probability for VGG's dropout while is_training is True
weight_decay = 5e-4                # L2 regularization strength passed to vgg_arg_scope

save_path = '../pretrained/vgg-finetuned'  # directory for the fine-tuned checkpoint
log_path = '../logs/vgg'                   # TensorBoard log directory
data_path = '../data/trafficsigns/train'   # root of the class-per-directory image tree

# Fail fast (with a readable message) when required inputs are missing. The output directory is
# created rather than required, because tf.train.Saver.save refuses to write into a missing one.
assert os.path.isfile(model_path), 'pretrained VGG checkpoint not found: %s' % model_path
assert os.path.isdir(data_path), 'training data directory not found: %s' % data_path
os.makedirs(save_path, exist_ok=True)

In [3]:
# Per-channel (R, G, B) mean pixel values subtracted from every input image before it enters VGG-16
# (the ImageNet channel means also used by slim's VGG preprocessing).
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94
VGG_MEAN = [_R_MEAN, _G_MEAN, _B_MEAN]

# Getting input data & labels

In [4]:
def load_data(data_directory):
    '''
    Collect the paths of all PNG images below `data_directory` together with an integer label.

    Expected layout: data_directory/<class>/<sub-directory>/<image>.png
    Every top-level <class> directory gets its own label (0, 1, 2, ...). Directories and files
    are visited in sorted name order, so the class -> label mapping and the sample order are
    the same on every machine (os.listdir order is arbitrary). Non-directory entries at the
    class and sub-directory level and non-PNG files are skipped.

    Returns (images, labels): two parallel lists of file paths and int labels.
    '''
    class_names = sorted(d for d in os.listdir(data_directory)
                         if os.path.isdir(os.path.join(data_directory, d)))
    images = []
    labels = []

    for label_tag, class_name in enumerate(class_names):
        class_directory = os.path.join(data_directory, class_name)
        # Only descend into real sub-directories; a stray file here would crash os.listdir.
        sub_directories = sorted(s for s in os.listdir(class_directory)
                                 if os.path.isdir(os.path.join(class_directory, s)))

        for s in sub_directories:
            label_directory = os.path.join(class_directory, s)
            for f in sorted(os.listdir(label_directory)):
                if f.endswith('.png'):
                    images.append(os.path.join(label_directory, f))
                    labels.append(label_tag)

    return images, labels

In [5]:
# Every PNG path (X) and its integer class label (y), as two parallel lists.
X, y = load_data(data_directory=data_path)

In [6]:
# Sanity check: the first image path and its label, one per line.
print(X[0], y[0], sep='\n')

../data/trafficsigns/train/stop/B5/00983_02446.png
0


In [7]:
# Randomly hold out 10% of the data as a validation set (a single fixed split, not cross-validation).
train_filenames, val_filenames, train_labels, val_labels = train_test_split(X, y, test_size=0.10, random_state=0)
# Size the classifier from ALL labels, not just the training split: load_data numbers the class
# directories 0..K-1, and a class that lands only in the validation split (or has no images) would
# otherwise yield labels >= num_classes, which sparse_softmax_cross_entropy cannot handle.
num_classes = max(y) + 1

In [8]:
# Summary of the split, one statistic per line.
print('Training set size: ' + str(len(train_filenames)),
      'Validation set size: ' + str(len(val_filenames)),
      'Number of classes: ' + str(num_classes),
      sep='\n')

Training set size: 3731
Validation set size: 415
Number of classes: 12


# Data Preprocessing

In [9]:
def _parse_function(filename, label):
    '''
    Read and decode one image file, then rescale it (keeping its aspect ratio) so that its
    smallest side is 256 pixels.

    Returns (image, label): image is a float32 tensor [new_height, new_width, 3];
    the label is passed through unchanged.
    '''
    image_string = tf.read_file(filename)   # raw bytes of the image file
    # load_data only collects *.png files, so decode them as PNG (decode_jpeg also accepts PNG
    # data, but decode_png states the intent). channels=3 forces RGB output.
    image_decoded = tf.image.decode_png(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)

    smallest_side = 256.0
    height = tf.to_float(tf.shape(image)[0])
    width = tf.to_float(tf.shape(image)[1])

    # Scale factor that maps the shorter side onto `smallest_side`, keeping proportions.
    scale = smallest_side / tf.minimum(height, width)
    # Round instead of truncating, so float error (e.g. 255.9999) cannot produce a 255 px side.
    new_height = tf.to_int32(tf.round(height * scale))
    new_width = tf.to_int32(tf.round(width * scale))

    resized_image = tf.image.resize_images(image, [new_height, new_width])
    return resized_image, label


def training_preprocess(image, label):
    '''
    Training-time VGG preprocessing: take a random 224x224x3 crop, randomly mirror it
    left/right, and subtract the per-channel VGG_MEAN. The label is passed through.
    '''
    # VGG expects [224, 224, 3] inputs; random cropping and flipping act as data augmentation.
    augmented = tf.image.random_flip_left_right(tf.random_crop(image, [224, 224, 3]))
    # Shape [1, 1, 3] so the means broadcast over height and width.
    channel_means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
    return augmented - channel_means, label

def val_preprocess(image, label):
    '''
    Validation-time VGG preprocessing: a deterministic 224x224 center crop (or pad) followed by
    subtraction of the per-channel VGG_MEAN; no random augmentation. The label is passed through.
    '''
    cropped = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
    # Shape [1, 1, 3] so the means broadcast over height and width.
    channel_means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
    return cropped - channel_means, label

In [10]:
# Training pipeline: shuffle the cheap (filename, label) pairs first, then decode and augment.
# Shuffling AFTER the maps (as before) made the shuffle buffer hold decoded float32 images of
# 224x224x3 each; with buffer_size=10000 and only ~3.7k training images that meant buffering the
# whole preprocessed training set in memory before the first batch. Shuffling filenames first
# still gives a fresh random order per pass, and the random crop/flip is still applied per image.
train_filenames_tensor = tf.constant(train_filenames)
train_labels_tensor = tf.constant(train_labels)
train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames_tensor, train_labels_tensor))
# A buffer as large as the dataset gives a full uniform shuffle.
train_dataset = train_dataset.shuffle(buffer_size=len(train_filenames))
train_dataset = train_dataset.map(_parse_function, num_threads=num_workers, output_buffer_size=batch_size)
train_dataset = train_dataset.map(training_preprocess, num_threads=num_workers, output_buffer_size=batch_size)
batched_train_dataset = train_dataset.batch(batch_size)

In [11]:
# Validation pipeline: decode + resize, then the deterministic VGG preprocessing.
# No shuffling and no random augmentation.
val_filenames_tensor = tf.constant(val_filenames)
val_labels_tensor = tf.constant(val_labels)
val_dataset = (
    tf.contrib.data.Dataset.from_tensor_slices((val_filenames_tensor, val_labels_tensor))
    .map(_parse_function, num_threads=num_workers, output_buffer_size=batch_size)
    .map(val_preprocess, num_threads=num_workers, output_buffer_size=batch_size)
)
batched_val_dataset = val_dataset.batch(batch_size)

In [12]:
# One reinitializable iterator shared by the training and validation pipelines. Both produce
# the same element types/shapes, so the training dataset's structure serves as the template.
iterator = tf.contrib.data.Iterator.from_structure(
    output_types=batched_train_dataset.output_types,
    output_shapes=batched_train_dataset.output_shapes)
# Model inputs: each sess.run pulls the next batch of whichever dataset was initialized last.
images, labels = iterator.get_next()

In [13]:
# Running one of these ops points the shared iterator at the start of the corresponding dataset;
# run it before every pass over that dataset.
train_init_op, val_init_op = (iterator.make_initializer(dataset)
                              for dataset in (batched_train_dataset, batched_val_dataset))

In [14]:
# Boolean mode switch fed on every sess.run: True while training (VGG's dropout active),
# False for evaluation.
is_training = tf.placeholder(dtype=tf.bool)

In [15]:
# Build VGG-16 (via slim) on the iterator's `images`, with a `num_classes`-way output layer.
# vgg_arg_scope sets slim defaults for its conv / fully-connected layers: ReLU activations,
# an L2 weight regularizer of strength `weight_decay`, zero-initialized biases, and SAME
# padding for convolutions. The regularization terms later enter tf.losses.get_total_loss().
vgg_scope = vgg.vgg_arg_scope(weight_decay=weight_decay)
with slim.arg_scope(vgg_scope):
    logits, _ = vgg.vgg_16(images,
                           num_classes=num_classes,
                           is_training=is_training,
                           dropout_keep_prob=dropout_keep_prob)


In [16]:
# Prepare loading of the pretrained VGG weights.
with tf.name_scope('Model'):
    # All variables defined so far (the VGG graph) except the final layer fc8, whose size
    # (num_classes outputs) does not match the pretrained checkpoint.
    restore_exclusions = ['vgg_16/fc8']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=restore_exclusions)
    # Callable init_vgg(sess) that assigns the checkpoint values to those variables.
    init_vgg = tf.contrib.framework.assign_from_checkpoint_fn(model_path, variables_to_restore)

In [17]:
# Graph pieces for the two-phase fine-tuning: an initializer for the new fc8 layer, the loss,
# one SGD train op per phase, and the per-batch accuracy.
with tf.name_scope('FC8_new'):
    # fc8 is excluded from the checkpoint restore, so its variables need their own initializer
    fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')
    fc8_init = tf.variables_initializer(fc8_variables) # op that runs the (random) initializers of fc8's variables

with tf.name_scope('loss'):
    # Cross-entropy between the logits and the integer class labels. The call registers the loss
    # in the tf.losses collection (its return value is not needed); get_total_loss() then adds the
    # regularization losses, i.e. the L2 weight decay configured through vgg_arg_scope.
    tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    loss = tf.losses.get_total_loss()

with tf.name_scope('SGD_fc8'):
    # Phase 1: plain SGD with learning_rate1, restricted via var_list to the fc8 variables,
    # so the pretrained layers are not updated
    fc8_optimizer = tf.train.GradientDescentOptimizer(learning_rate1)
    fc8_train_op = fc8_optimizer.minimize(loss, var_list=fc8_variables)

with tf.name_scope('SGD_all'):
    # Phase 2: plain SGD over all trainable variables with the much smaller learning_rate2,
    # because the pretrained weights only need slight adjustment, not retraining from scratch
    full_optimizer = tf.train.GradientDescentOptimizer(learning_rate2)
    full_train_op = full_optimizer.minimize(loss)

with tf.name_scope('Accuracy'):
    # Fraction of the current batch whose arg-max logit equals the label
    prediction = tf.to_int32(tf.argmax(logits, 1))   # predicted class index per sample
    correct_prediction = tf.equal(prediction, labels)   # one boolean per sample
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [18]:
# TensorBoard logging: scalar summaries of the loss and of the batch accuracy, merged into a
# single op that can be evaluated together with a training step.
loss_summary = tf.summary.scalar("loss", loss)
accuracy_summary = tf.summary.scalar("accuracy", accuracy)
merged_summary_op = tf.summary.merge_all()

In [19]:
# Event file for TensorBoard in `log_path`; passing the graph makes it browsable in the Graphs tab.
summary_writer = tf.summary.FileWriter(log_path, graph=tf.get_default_graph())

In [20]:
# Interactive session on the default graph (it also installs itself as the default session).
sess = tf.InteractiveSession(graph=tf.get_default_graph())

In [21]:
# Give every model variable a value: the checkpoint restore covers all layers except fc8,
# which then gets its fresh random initialization (plain SGD creates no extra slot variables).
init_vgg(sess)
fc8_init.run(session=sess)

INFO:tensorflow:Restoring parameters from ../pretrained/vgg/vgg_16.ckpt


In [22]:
def check_accuracy(sess, correct_prediction, is_training, dataset_init_op):
    '''
    Run the model over one full pass of a dataset and return the fraction of correctly
    classified samples.

    sess               -- session holding the model variables
    correct_prediction -- boolean tensor with one entry per sample of the current batch
    is_training        -- the mode placeholder; fed False so the model runs in evaluation mode
    dataset_init_op    -- initializer that points the shared iterator at the dataset to evaluate
                          (e.g. train_init_op or val_init_op)

    Raises ValueError if the dataset yields no samples (instead of a bare ZeroDivisionError).
    '''
    sess.run(dataset_init_op)
    num_correct, num_samples = 0, 0
    while True:
        try:
            correct_pred = sess.run(correct_prediction, {is_training: False})   # get the predictions
        except tf.errors.OutOfRangeError:
            break   # dataset exhausted: one full pass done
        num_correct += correct_pred.sum()
        num_samples += correct_pred.shape[0]
    if num_samples == 0:
        raise ValueError('check_accuracy: the dataset produced no samples')
    return float(num_correct) / num_samples

In [23]:
# Phase 1: train only the new fc8 layer (fc8_train_op updates only fc8_variables).
for epoch in range(num_epochs1):
    # Run an epoch over the training data.
    print('Starting epoch %d / %d' % (epoch + 1, num_epochs1))
    start = time.perf_counter()
    sess.run(train_init_op)
    # Collect per-batch loss/accuracy and log their epoch mean as ONE TensorBoard point at step
    # `epoch`. Writing every batch summary with global_step=epoch (as before) piled hundreds of
    # points onto the same x value.
    batch_losses, batch_accuracies = [], []
    while True:
        try:
            _, batch_loss, batch_accuracy = sess.run([fc8_train_op, loss, accuracy], {is_training: True})
        except tf.errors.OutOfRangeError:
            break
        batch_losses.append(batch_loss)
        batch_accuracies.append(batch_accuracy)
    if batch_losses:
        summary_writer.add_summary(tf.Summary(value=[
            tf.Summary.Value(tag='loss', simple_value=float(sum(batch_losses) / len(batch_losses))),
            tf.Summary.Value(tag='accuracy', simple_value=float(sum(batch_accuracies) / len(batch_accuracies))),
        ]), epoch)
        summary_writer.flush()  # make the new point visible in TensorBoard right away

    # Check accuracy on the train and val sets every epoch.
    # Note: the train pass uses the training pipeline, i.e. randomly cropped/flipped images.
    train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
    val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
    print('Train accuracy: %f' % train_acc)
    print('Val accuracy: %f' % val_acc)
    print('Time for this epoch: %.0f seconds\n' % (time.perf_counter() - start))

    # Possible extension: early stopping once validation accuracy starts decreasing while
    # training accuracy keeps increasing (a sign of overfitting).

Starting epoch 1 / 15
Train accuracy: 0.940231
Val accuracy: 0.918072
Time for this epoch: 77 seconds

Starting epoch 2 / 15
Train accuracy: 0.951756
Val accuracy: 0.946988
Time for this epoch: 76 seconds

Starting epoch 3 / 15
Train accuracy: 0.966497
Val accuracy: 0.954217
Time for this epoch: 76 seconds

Starting epoch 4 / 15
Train accuracy: 0.966229
Val accuracy: 0.966265
Time for this epoch: 76 seconds

Starting epoch 5 / 15
Train accuracy: 0.972930
Val accuracy: 0.966265
Time for this epoch: 77 seconds

Starting epoch 6 / 15
Train accuracy: 0.976146
Val accuracy: 0.966265
Time for this epoch: 76 seconds

Starting epoch 7 / 15
Train accuracy: 0.980702
Val accuracy: 0.966265
Time for this epoch: 76 seconds

Starting epoch 8 / 15
Train accuracy: 0.981238
Val accuracy: 0.973494
Time for this epoch: 76 seconds

Starting epoch 9 / 15
Train accuracy: 0.974002
Val accuracy: 0.975904
Time for this epoch: 76 seconds

Starting epoch 10 / 15
Train accuracy: 0.983919
Val accuracy: 0.975904
Ti

In [24]:
# Phase 2: fine-tune the entire model (full_train_op updates all trainable variables) for a few
# more epochs, continuing with the *same* weights reached in phase 1.
for epoch in range(num_epochs2):
    print('Starting epoch %d / %d' % (epoch + 1, num_epochs2))
    start = time.perf_counter()
    sess.run(train_init_op)
    # One TensorBoard point per epoch (mean of the per-batch values) at step epoch + num_epochs1,
    # i.e. right after the epoch steps of phase 1, instead of hundreds of points per step.
    batch_losses, batch_accuracies = [], []
    while True:
        try:
            _, batch_loss, batch_accuracy = sess.run([full_train_op, loss, accuracy], {is_training: True})
        except tf.errors.OutOfRangeError:
            break
        batch_losses.append(batch_loss)
        batch_accuracies.append(batch_accuracy)
    if batch_losses:
        summary_writer.add_summary(tf.Summary(value=[
            tf.Summary.Value(tag='loss', simple_value=float(sum(batch_losses) / len(batch_losses))),
            tf.Summary.Value(tag='accuracy', simple_value=float(sum(batch_accuracies) / len(batch_accuracies))),
        ]), epoch + num_epochs1)
        summary_writer.flush()  # make the new point visible in TensorBoard right away

    # Check accuracy on the train and val sets every epoch
    train_acc = check_accuracy(sess, correct_prediction, is_training, train_init_op)
    val_acc = check_accuracy(sess, correct_prediction, is_training, val_init_op)
    print('Train accuracy: %f' % train_acc)
    print('Val accuracy: %f' % val_acc)
    print('Time for this epoch: %.0f seconds\n' % (time.perf_counter() - start))

    # Possible extension: early stopping once validation accuracy starts decreasing while
    # training accuracy keeps increasing (a sign of overfitting).


Starting epoch 1 / 15
Train accuracy: 0.989547
Val accuracy: 0.978313
Time for this epoch: 169 seconds

Starting epoch 2 / 15
Train accuracy: 0.989547
Val accuracy: 0.985542
Time for this epoch: 168 seconds

Starting epoch 3 / 15
Train accuracy: 0.989547
Val accuracy: 0.985542
Time for this epoch: 168 seconds

Starting epoch 4 / 15
Train accuracy: 0.992227
Val accuracy: 0.985542
Time for this epoch: 168 seconds

Starting epoch 5 / 15
Train accuracy: 0.991423
Val accuracy: 0.987952
Time for this epoch: 168 seconds

Starting epoch 6 / 15
Train accuracy: 0.991691
Val accuracy: 0.987952
Time for this epoch: 168 seconds

Starting epoch 7 / 15
Train accuracy: 0.990351
Val accuracy: 0.987952
Time for this epoch: 168 seconds

Starting epoch 8 / 15
Train accuracy: 0.990887
Val accuracy: 0.990361
Time for this epoch: 168 seconds

Starting epoch 9 / 15
Train accuracy: 0.990619
Val accuracy: 0.990361
Time for this epoch: 168 seconds

Starting epoch 10 / 15
Train accuracy: 0.991423
Val accuracy: 0.

In [25]:
# Persist the fine-tuned weights. The checkpoint prefix encodes the epoch counts of both phases;
# the prefix path returned by saver.save is shown as the cell output.
model_name = 'vgg-finetuned-%s-%s' % (num_epochs1, num_epochs2)
checkpoint_prefix = os.path.join(save_path, model_name)
saver = tf.train.Saver()  # covers all variables of the graph
saver.save(sess, checkpoint_prefix)

'../pretrained/vgg-finetuned/vgg-finetuned-15-15'

In [26]:
# Release resources: flush and close the TensorBoard event file (buffered events could otherwise
# be lost), then close the session.
summary_writer.close()
sess.close()