In [1]:
import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging

from preprocessing.preprocessing_factory import get_preprocessing
from nets import nets_factory

import os
import time
from tensorflow.contrib import slim

from datasets import dataset_utils
from checkpoints_downloader import ckpt_maker
from dataset_preparation import get_split, load_batch

In [2]:
# Name of the network architecture used throughout this notebook; it is passed
# to ckpt_maker, nets_factory.get_network_fn and load_batch, and is also used
# to build the per-model log directory.
MODEL = 'resnet_v2_50'

In [3]:
# Obtain the pre-trained checkpoint for MODEL. The captured output of this cell
# shows the helper downloading the archive and reporting the checkpoint file name.
# The returned value is later handed to saver.restore() as the checkpoint path.
checkpoint_file = ckpt_maker(MODEL)

>> Downloading resnet_v2_50_2017_04_14.tar.gz 100.0%
Successfully downloaded resnet_v2_50_2017_04_14.tar.gz 286441851 bytes.
Checkpoint for resnet_v2_50 is ready!
File name: checkpoints/resnet_v2_50.ckpt


In [4]:
# ---------------- Dataset configuration ----------------
# Directory in which the TFRecord files are located.
# NOTE(review): this is an absolute path at the filesystem root, while the
# commented-out labels path below is relative ('drivers_data/...') - confirm
# which location is intended.
dataset_dir = '/drivers_data'

# Per-model log directory; it is created when it does not exist yet.
log_dir = 'log/{}'.format(MODEL)
if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

# Number of target classes to predict.
num_classes = 10

# Optional (currently disabled): read the labels file and build a mapping from
# the integer label to its string name.
# labels_file = 'drivers_data/labels.txt'
# labels = open(labels_file, 'r')
# labels_to_name = {}
# for line in labels:
#     label, string_name = line.split(':')
#     string_name = string_name[:-1]  # remove the trailing newline
#     labels_to_name[int(label)] = string_name

# Pattern used to recognise the TFRecord files later on.
# The '{}' placeholder is presumably filled with the split name by get_split -
# TODO confirm against dataset_preparation.
# (previous %-style variant: 'drivers_%s_*.tfrecord')
file_pattern = 'drivers_{}_*.tfrecord'

# Human-readable description of each dataset item; required by the Dataset
# class later.
items_to_descriptions = dict(
    image='A 3-channel RGB coloured driver image.',
    label='A label from 0 to 9.',
)

In [5]:
# Build two network functions for the same architecture and class count:
# the first with is_training=True (training), the second with
# is_training=False (evaluation).
model_train, model_eval = [
    nets_factory.get_network_fn(MODEL, num_classes, is_training=training_mode)
    for training_mode in (True, False)
]

In [6]:
# Image size the inputs are resized to: taken from the network function's
# `default_image_size` attribute and later passed to load_batch as both
# height and width.
image_size = model_train.default_image_size

In [10]:
# ---------------- Training hyper-parameters ----------------
# Batch size and the number of epochs to train for.
batch_size = 32
num_epochs = 30

# Learning-rate settings (free to experiment with): the starting rate, how many
# epochs pass before each decay, and the factor applied at each decay.
initial_learning_rate = 1e-3
num_epochs_before_decay = 2
learning_rate_decay_factor = 0.7

In [None]:
# Build the fine-tuning graph for MODEL and run the training loop.
# All ops are created inside a fresh tf.Graph; a tf.train.Supervisor then
# manages the session, summaries and checkpoints under log_dir.
with tf.Graph().as_default() as graph:
    tf.logging.set_verbosity(tf.logging.INFO)

    # Create the training split and load batches from it.
    dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
    images, _, labels = load_batch(dataset,
                                   batch_size=batch_size,
                                   MODEL=MODEL,
                                   height=image_size,
                                   width=image_size,
                                   is_training=True)

    # Steps per epoch is set equal to the number of whole batches per epoch;
    # the learning rate decays every `num_epochs_before_decay` epochs.
    num_batches_per_epoch = int(dataset.num_samples / batch_size)
    num_steps_per_epoch = num_batches_per_epoch
    decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

    # Create the model.
    logits, end_points = model_train(images)

    # Scopes excluded from checkpoint restoration. The list is computed here,
    # before the global step and the optimizer are created, so only variables
    # that exist at this point are considered for restoring.
    # NOTE: the scope name embeds the model name and must be updated if MODEL changes.
    exclude = ['resnet_v2_50/logits']
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)

    # One-hot encoding of the labels.
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

    # Cross-entropy loss, then the total loss as reported by tf.losses.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)
    total_loss = tf.losses.get_total_loss()

    # Create the global step used for monitoring and for the decay schedule.
    global_step = get_or_create_global_step()

    # Staircase exponentially decaying learning rate.
    learning_rate = tf.train.exponential_decay(learning_rate=initial_learning_rate,
                                               global_step=global_step,
                                               decay_steps=decay_steps,
                                               decay_rate=learning_rate_decay_factor,
                                               staircase=True)

    # Optimizer.
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    # Training op; evaluating it yields the loss value (see train_step).
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Predictions and streaming accuracy. `metrics_op` groups the accuracy
    # update with the probabilities so a single run call triggers both.
    probabilities = end_points['predictions']
    predictions = tf.argmax(end_points['predictions'], 1)
    accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
    metrics_op = tf.group(accuracy_update, probabilities)

    # Summaries.
    tf.summary.scalar('losses/Total_Loss', total_loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('learning_rate', learning_rate)
    my_summary_op = tf.summary.merge_all()

    def train_step(sess, train_op, global_step):
        """Run a single training step together with the metric updates.

        Args:
            sess: session used to run the ops.
            train_op: training op; its evaluated value is returned as the loss.
            global_step: global-step tensor to read in the same run call.

        Returns:
            A `(total_loss, global_step_count)` tuple holding the values
            produced by `sess.run` for `train_op` and `global_step`.
        """
        # `metrics_op` comes from the enclosing scope; its run result is discarded.
        total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op])
        return total_loss, global_step_count

    # Saver that restores the pre-trained variables from the checkpoint file.
    saver = tf.train.Saver(variables_to_restore)

    def restore_fn(sess):
        """Restore `variables_to_restore` from `checkpoint_file` into `sess`."""
        return saver.restore(sess, checkpoint_file)

    # summary_op=None: summaries are computed and written manually in the loop.
    # NOTE: the captured output of this cell shows parameters being restored
    # from a checkpoint inside log_dir when a previous run left one there.
    sv = tf.train.Supervisor(logdir=log_dir, summary_op=None, init_fn=restore_fn)

    with sv.managed_session() as sess:
        # Remains None only when the loop below executes zero steps, so the
        # final log line cannot fail with a NameError.
        loss_value = None

        for step in range(num_steps_per_epoch * num_epochs):
            # At the start of every epoch, report learning rate and accuracy.
            if step % num_batches_per_epoch == 0:
                # Floor division keeps the epoch number an integer
                # ("Epoch 1/30" instead of "Epoch 1.0/30").
                logging.info('Epoch {}/{}'.format(step // num_batches_per_epoch + 1, num_epochs))
                learning_rate_value, accuracy_value = sess.run([learning_rate, accuracy])
                logging.info('Current Learning Rate: {}'.format(learning_rate_value))
                logging.info('Current Streaming Accuracy: {}'.format(accuracy_value))

            if step % 10 == 0:
                # Every 10th step: also write summaries and log the loss.
                # The global step is kept in its own name so the loop
                # variable `step` is not overwritten.
                loss_value, global_step_value = train_step(sess, train_op, sv.global_step)
                summaries = sess.run(my_summary_op)
                sv.summary_computed(sess, summaries)
                logging.info('global step {}: loss: {}'.format(global_step_value, loss_value))
            else:
                loss_value, _ = train_step(sess, train_op, sv.global_step)

        logging.info('Final Loss: {}'.format(loss_value))
        logging.info('Final Accuracy: {}'.format(sess.run(accuracy)))

        logging.info('Training finished! Saving model to disk.')
        sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

INFO:tensorflow:Scale of 0 disables regularizer.
INFO:tensorflow:Restoring parameters from log/resnet_v2_50/model.ckpt-4565
INFO:tensorflow:Epoch 1.0/30
INFO:tensorflow:Current Learning Rate: 0.0002401000092504546
INFO:tensorflow:Current Streaming Accuracy: 0.0
INFO:tensorflow:global step 4567: loss: 0.07189057022333145
INFO:tensorflow:global step 4577: loss: 0.28046271204948425
INFO:tensorflow:global step 4587: loss: 0.13434793055057526
INFO:tensorflow:global step 4597: loss: 0.07863301783800125
INFO:tensorflow:global step 4607: loss: 0.36539894342422485
INFO:tensorflow:global step 4617: loss: 0.26720091700553894
INFO:tensorflow:global step 4627: loss: 0.11038193106651306
INFO:tensorflow:global step 4637: loss: 0.2415679693222046
INFO:tensorflow:global step 4647: loss: 0.206980362534523
INFO:tensorflow:global step 4657: loss: 0.16673707962036133
INFO:tensorflow:global step 4667: loss: 0.16701248288154602
INFO:tensorflow:global step 4677: loss: 0.21359717845916748
INFO:tensorflow:globa