In [15]:
import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging
from preprocessing import inception_preprocessing
from nets.inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
from nets import inception
import os
import time
from tensorflow.contrib import slim

from datasets import dataset_utils

In [2]:
#================ DATASET INFORMATION ======================
# Directory holding the TFRecord shards and labels.txt.
dataset_dir = 'drivers_data'

# Directory used by the training Supervisor for logs, summaries and checkpoints; created if missing.
log_dir = 'log'

if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

# Pre-trained Inception-ResNet-v2 checkpoint. If the .ckpt file is not present yet,
# the tarball at `url` is downloaded and unpacked into checkpoints_dir.
url = "http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz"
checkpoints_dir = 'checkpoints'
checkpoint_file = os.path.join(checkpoints_dir, 'inception_resnet_v2_2016_08_30.ckpt')

if not tf.gfile.Exists(checkpoints_dir):
    tf.gfile.MakeDirs(checkpoints_dir)

if not tf.gfile.Exists(checkpoint_file):
    dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir)

# Input resolution images are resized to: the network's default_image_size.
image_size = inception.inception_resnet_v2.default_image_size

# Number of classes to predict (c0 ... c9).
num_classes = 10

# labels.txt maps class ids to names, one "<id>:<name>" entry per line.
labels_file = os.path.join(dataset_dir, 'labels.txt')

# Build {class id (int): class name (str)}.
# - The `with` block guarantees the file handle is closed.
# - Split on the first ':' only, so a name may itself contain ':'.
# - Strip the line ending instead of chopping the last character, which would
#   truncate a name on a final line without a trailing newline (and leave '\r'
#   behind on CRLF files).
labels_to_name = {}
with open(labels_file, 'r') as labels:
    for line in labels:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line
        label, string_name = line.split(':', 1)
        labels_to_name[int(label)] = string_name.rstrip('\r\n')

# Shard file-name template; '{}' is filled with the split name ('train' or 'validation').
file_pattern = 'drivers_{}_*.tfrecord'

# Human-readable descriptions of the decoded items, passed to slim's Dataset class in get_split.
items_to_descriptions = {
    'image': 'A 3-channel RGB coloured driver image.',
    'label': 'A label of status of driver -- c0, c1, c2, c3, c4, c5, c6, c7, c8, c9'
}

In [3]:
#================= TRAINING INFORMATION ==================
num_epochs = 10   # full passes over the training split
batch_size = 32   # examples consumed per optimisation step

# Learning-rate schedule (tune freely): training starts at initial_learning_rate,
# and the rate is multiplied by learning_rate_decay_factor once every
# num_epochs_before_decay epochs (staircase exponential decay in the training cell).
initial_learning_rate = 1e-3
learning_rate_decay_factor = 0.7
num_epochs_before_decay = 2

In [None]:
# Describe one split of the drivers TFRecord shards as a slim Dataset, which the
# input pipeline (load_batch) later turns into queued batches.
def get_split(split_name, dataset_dir, file_pattern=file_pattern, file_pattern_for_counting='drivers'):
    '''
    Describe the 'train' or 'validation' split of the TFRecord data as a slim Dataset.

    The returned object bundles the shard glob, a TFExample decoder, the example count
    and the label metadata, so batching code only needs this one object.

    INPUTS:
    - split_name(str): 'train' or 'validation'; anything else raises ValueError
    - dataset_dir(str): directory holding the .tfrecord shards
    - file_pattern(str): shard file-name template with one '{}' placeholder for the split name
    - file_pattern_for_counting(str): shard name prefix; every file in dataset_dir whose name
      starts with '<prefix>_<split_name>' is read once to count the examples
    OUTPUTS:
    - dataset (Dataset): slim Dataset describing the requested split
    '''
    if split_name not in ('train', 'validation'):
        raise ValueError('split name {} was not recognized.'.format(split_name))

    # Glob handed to the reader, e.g. <dataset_dir>/drivers_train_*.tfrecord
    shard_glob = os.path.join(dataset_dir, file_pattern.format(split_name))

    # num_samples is not stored anywhere, so count records by iterating every matching shard once.
    shard_prefix = file_pattern_for_counting + '_' + split_name
    shard_paths = [os.path.join(dataset_dir, name)
                   for name in os.listdir(dataset_dir)
                   if name.startswith(shard_prefix)]
    example_count = sum(1
                        for path in shard_paths
                        for _ in tf.python_io.tf_record_iterator(path))

    # Decoder: parse the serialized tf.train.Example features, then map them to
    # the 'image' and 'label' items.
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        {
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
            'image/class/label': tf.FixedLenFeature(
                [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        },
        {
            'image': slim.tfexample_decoder.Image(),
            'label': slim.tfexample_decoder.Tensor('image/class/label'),
        })

    return slim.dataset.Dataset(
        data_sources=shard_glob,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=example_count,
        num_classes=num_classes,
        labels_to_name=labels_to_name,
        items_to_descriptions=items_to_descriptions)

In [None]:
def load_batch(dataset, batch_size, height=image_size, width=image_size, is_training=True):
    '''
    Build the queue-based input pipeline that yields batches of images and labels.
    INPUTS:
    - dataset(Dataset): a Dataset class object that is created from the get_split function
    - batch_size(int): number of examples per batch
    - height(int): the height of the image to resize to during preprocessing
    - width(int): the width of the image to resize to during preprocessing
    - is_training(bool): selects training vs. evaluation preprocessing in inception_preprocessing
    OUTPUTS:
    - images(Tensor): preprocessed images of shape (batch_size, height, width, channels);
      the final batch may be smaller (allow_smaller_final_batch=True)
    - raw_images(Tensor): the same images without preprocessing, only resized to (height, width)
      with nearest-neighbour sampling so they can be batched
    - labels(Tensor): integer class ids of shape (batch_size,); one-hot encode them before
      using a softmax cross-entropy loss
    '''
    # Small queues keep memory low. With shuffling (DatasetDataProvider's default) the
    # common queue is a RandomShuffleQueue whose min_after_dequeue must stay below its
    # capacity, so cap the minimum for small batch sizes (2 * batch_size <= 24).
    common_queue_capacity = 2 * batch_size
    common_queue_min = min(24, common_queue_capacity - 1)

    # First create the data_provider object
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity=common_queue_capacity,
        common_queue_min=common_queue_min)

    # Obtain the raw image and its label using the get method
    raw_image, label = data_provider.get(['image', 'label'])

    # Perform the correct preprocessing for this image depending on whether it is training or evaluating
    image = inception_preprocessing.preprocess_image(raw_image, height, width, is_training)

    # Keep an un-preprocessed copy, resized only so that it has a static size for batching.
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    # Remove only the batch axis added above; a bare squeeze would also drop any
    # other size-1 axis (e.g. a single colour channel).
    raw_image = tf.squeeze(raw_image, axis=[0])

    # Batch up the tensors: they are enqueued into an internal FIFO queue and
    # dequeued batch_size at a time by tf.train.batch.
    images, raw_images, labels = tf.train.batch(
        [image, raw_image, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=2 * batch_size,
        allow_smaller_final_batch=True)

    return images, raw_images, labels

Build the graph and train the model

In [None]:
with tf.Graph().as_default() as graph:
    tf.logging.set_verbosity(tf.logging.INFO)

    # Create the dataset and the batched input pipeline.
    dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
    images, _, labels = load_batch(dataset, batch_size=batch_size)

    # Each training step consumes one batch, so an epoch is num_samples // batch_size steps.
    num_batches_per_epoch = int(dataset.num_samples / batch_size)
    if num_batches_per_epoch == 0:
        # Otherwise the epoch bookkeeping below would divide by zero.
        raise ValueError(
            'Found {} training examples, fewer than one batch of {}; '
            'check dataset_dir and file_pattern.'.format(dataset.num_samples, batch_size))
    num_steps_per_epoch = num_batches_per_epoch
    decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

    # Create the model with a classifier head sized for dataset.num_classes.
    with slim.arg_scope(inception_resnet_v2_arg_scope()):
        logits, end_points = inception_resnet_v2(images, num_classes=dataset.num_classes, is_training=True)

    # Do not restore the classifier heads from the checkpoint (presumably because their
    # shapes depend on num_classes); they are trained from their fresh initialisation.
    # Collected before the global step and optimizer exist, so their variables are not
    # part of the restore list either.
    exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)

    # One-hot encoding of the integer labels for the softmax cross-entropy loss.
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

    # Calculate the loss; total_loss (from the tf.losses collection) is what the optimizer minimizes.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)
    total_loss = tf.losses.get_total_loss()

    # Create the global step used for monitoring and for the learning-rate schedule.
    global_step = get_or_create_global_step()

    # Staircase decay: the rate is multiplied by learning_rate_decay_factor every decay_steps steps.
    learning_rate = tf.train.exponential_decay(learning_rate=initial_learning_rate,
                                               global_step=global_step,
                                               decay_steps=decay_steps,
                                               decay_rate=learning_rate_decay_factor,
                                               staircase=True)

    # Optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    # Create the training operator; evaluating it runs one update and returns total_loss.
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Predictions and streaming accuracy; metrics_op updates the running accuracy.
    probabilities = end_points['Predictions']
    predictions = tf.argmax(probabilities, 1)
    accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
    metrics_op = tf.group(accuracy_update, probabilities)

    # Summaries
    tf.summary.scalar('losses/Total_Loss', total_loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('learning_rate', learning_rate)
    my_summary_op = tf.summary.merge_all()

    def train_step(sess, train_op, global_step, summary_op=None):
        '''
        Run one optimisation step together with the streaming-accuracy update.

        If summary_op is given, it is fetched in the SAME sess.run call. The summary then
        describes the batch that was just trained on; a separate run would dequeue
        (and skip training on) an extra batch.

        Returns (loss_value, global_step_count), or
        (loss_value, global_step_count, summaries) when summary_op is given.
        '''
        fetches = [train_op, global_step, metrics_op]
        if summary_op is not None:
            fetches.append(summary_op)

        start_time = time.time()
        results = sess.run(fetches)
        time_elapsed = time.time() - start_time

        loss_value, global_step_count = results[0], results[1]
        logging.debug('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, loss_value, time_elapsed)

        if summary_op is None:
            return loss_value, global_step_count
        return loss_value, global_step_count, results[3]

    # Saver that restores only the pre-trained variables from the checkpoint file.
    saver = tf.train.Saver(variables_to_restore)

    def restore_fn(sess):
        '''Initialise the restorable variables from checkpoint_file (passed as the Supervisor's init_fn).'''
        return saver.restore(sess, checkpoint_file)

    # summary_op=None: summaries are written manually via sv.summary_computed below.
    sv = tf.train.Supervisor(logdir=log_dir, summary_op=None, init_fn=restore_fn)

    with sv.managed_session() as sess:
        loss_value = None
        for step in range(num_steps_per_epoch * num_epochs):
            if step % num_batches_per_epoch == 0:
                logging.info('Epoch {}/{}'.format(step // num_batches_per_epoch + 1, num_epochs))
                learning_rate_value, accuracy_value = sess.run([learning_rate, accuracy])
                logging.info('Current Learning Rate: {}'.format(learning_rate_value))
                logging.info('Current Streaming Accuracy: {}'.format(accuracy_value))

            if step % 10 == 0:
                loss_value, global_step_count, summaries = train_step(
                    sess, train_op, sv.global_step, summary_op=my_summary_op)
                sv.summary_computed(sess, summaries)
                logging.info('global step {}: loss: {}'.format(global_step_count, loss_value))
            else:
                loss_value, _ = train_step(sess, train_op, sv.global_step)

        logging.info('Final Loss: {}'.format(loss_value))
        logging.info('Final Accuracy: {}'.format(sess.run(accuracy)))

        logging.info('Training finished! Saving model to disk.')
        sv.saver.save(sess, sv.save_path, global_step=sv.global_step)