# In [2]:
import os 
import tensorflow as tf
import math
import numpy as np
from collections import OrderedDict
import re
import glob


# In [2]:
''' Supporting Functions for Data  '''

def get_files_in_folder(path, file_extensions, skip_folder, use_subfolder):
    """
    Collect the files in `path` whose extension is in `file_extensions`.

    Args:
        path: Directory to search.
        file_extensions: Extensions without the leading dot (e.g. ['jpeg']);
            each one is matched with the glob pattern '*.<ext>'.
        skip_folder: Names of sub-directories that are never descended into.
        use_subfolder: If true, recurse into sub-directories (at any depth).

    Returns:
        A list of matching file paths: the files directly inside `path`
        (grouped by extension) followed by those found in sub-directories.
    """
    found = [
        match
        for extension in file_extensions
        for match in glob.glob(os.path.join(path, '*.' + extension))
    ]

    if not use_subfolder:
        return found

    # descend into every sub-directory that is not explicitly skipped
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        if entry in skip_folder or not os.path.isdir(entry_path):
            continue
        found.extend(get_files_in_folder(entry_path, file_extensions, skip_folder, use_subfolder))
    return found

def load_images_by_subfolder(root_dir, validation_ratio, skip_folder=None, use_subfolder=False):
    """
    Load JPEG image paths from `root_dir`, one class per sub-directory, and
    split them into training and validation sets.

    Args:
        root_dir: Folder containing one sub-folder per class.
        validation_ratio: Integer stride of the validation split (every
            `validation_ratio`-th file per class goes to validation; <= 0
            puts everything into training).
        skip_folder: Folder names to ignore (default: none).
        use_subfolder: Also collect images from nested sub-directories.

    Returns:
        The dataset dictionary built by `load_by_subfolder`.
    """
    # None instead of a shared mutable [] default
    if skip_folder is None:
        skip_folder = []
    # Only files matching '*.jpeg' are collected; add 'jpg' to include *.jpg files.
    file_extensions = ['jpeg']
    return load_by_subfolder(root_dir, file_extensions, validation_ratio, skip_folder, use_subfolder)

def load_features_by_subfolder(root_dir, validation_ratio, skip_folder=None, use_subfolder=False, train_dirs=None):
    """
    Load feature files (*.txt) from `root_dir` (one class per sub-directory),
    optionally adding training-only data from further directories.

    Args:
        root_dir: Folder containing one sub-folder per class.
        validation_ratio: Integer stride of the validation split (see
            `load_by_subfolder`).
        skip_folder: Folder names to ignore (default: none).
        use_subfolder: Also collect files from nested sub-directories.
        train_dirs: Extra directories whose files are only used for training.
            A directory is used only if its label list (names and order)
            equals the one of `root_dir`; otherwise it is skipped with a message.

    Returns:
        The dataset dictionary of `load_by_subfolder` for `root_dir`, with the
        extra training entries appended.
    """
    # None instead of shared mutable [] defaults
    skip_folder = [] if skip_folder is None else skip_folder
    file_extensions = ['txt']
    shared = load_by_subfolder(root_dir, file_extensions, validation_ratio, skip_folder, use_subfolder)

    # ####################
    # Add extra training data when more training directories are provided.
    #
    # Filenames are expected to be *md5hash*.txt, so an equal basename means
    # the same feature. Such files are kept out of training to make sure that
    # no validation feature is used for training.
    # ###################
    if train_dirs:
        # set: O(1) membership test per candidate file
        validation_files = {os.path.basename(p) for p in shared['validation_paths']}
        for train_dir in train_dirs:
            print('=> Adding trainingdata from: {}'.format(train_dir))
            train_only = load_by_subfolder(train_dir, file_extensions, 0, skip_folder, use_subfolder)
            # Label indices are only comparable with the same labels in the same order
            if train_only['labels'] != shared['labels']:
                print('=> Skipping {}: labels do not match {}'.format(train_dir, root_dir))
                continue
            for path, label in zip(train_only['training_paths'], train_only['training_labels']):
                # Filter features that are already in the validation group
                if os.path.basename(path) not in validation_files:
                    shared['training_paths'].append(path)
                    shared['training_labels'].append(label)
                    shared['training_count'] += 1

        print('=> Final dataset: {} Training, {} Validation'.format(shared['training_count'], shared['validation_count']))
        print('')

    return shared

def load_by_subfolder(root_dir, file_extensions, validation_ratio, skip_folder=None, use_subfolder=False):
    """
    Create a list of labeled data, separated into training and validation sets.
    Creates a new label/class for every non-empty sub-directory of 'root_dir'.

    Args:
        root_dir: String path to a folder containing one sub-folder per class.
        file_extensions: File extensions (without the dot) to collect, e.g. ['jpeg'].
        validation_ratio: Integer stride of the validation split, NOT a
            fraction: every `validation_ratio`-th file of each class (starting
            with the first one) goes into the validation set, e.g. 5 -> 20%.
            Values <= 0 put all files into the training set.
        skip_folder: Folder names to ignore (default: none).
        use_subfolder: Also collect files from nested sub-directories.

    Returns:
        A dictionary with the keys
            'labels': label names (the position in the list is the class index),
            'training_count', 'training_paths', 'training_labels',
            'validation_count', 'validation_paths', 'validation_labels'
        where the '*_labels' lists hold class indices into 'labels'.
    """
    # None instead of a shared mutable [] default
    skip_folder = [] if skip_folder is None else skip_folder

    labels = []
    training_paths = []
    training_labels = []
    validation_paths = []
    validation_labels = []

    # NOTE(review): os.listdir returns entries in arbitrary order, so class
    # indices may differ between machines/file systems; consider sorting.
    for folder in os.listdir(root_dir):
        folder_path = os.path.join(root_dir, folder)
        if not os.path.isdir(folder_path) or folder in skip_folder:
            continue # skip files and skipped folders

        paths = get_files_in_folder(folder_path, file_extensions, skip_folder, use_subfolder)
        if not paths:
            continue # skip empty directories

        total_count = len(paths)
        # label = folder name, lower-cased, runs of other characters -> one space
        label = re.sub(r'[^a-z0-9]+', ' ', folder.lower())
        # split the list into training and validation (every n-th entry -> validation)
        if validation_ratio > 0:
            paths_sub = paths[::validation_ratio]
            del paths[::validation_ratio]
        else:
            paths_sub = []

        # print infos
        print('=> Label: {} ({}) [Files: {} Total, {} Training, {} Validation]'.format(
            label,
            len(labels),
            total_count,
            len(paths),
            len(paths_sub)
        ))

        # add entries to the result
        labels.append(label)
        label_index = len(labels) - 1
        training_paths += paths
        training_labels += [label_index] * len(paths)
        validation_paths += paths_sub
        validation_labels += [label_index] * len(paths_sub)

    print('')
    return {
        'labels': labels,
        'training_count': len(training_paths),
        'training_paths': training_paths,
        'training_labels': training_labels,
        'validation_count': len(validation_paths),
        'validation_paths': validation_paths,
        'validation_labels': validation_labels
    }

def load_by_file(file, validation_ratio):
    """
    Create a list of labeled data, separated into training and validation sets.
    Reads the given 'file' line by line. Every non-empty line has the format
    ``*data_path* *label*`` (whitespace separated, label is an integer).
    Blank lines are ignored.

    Args:
        file: Path of a text file with one data path and its integer label per line.
        validation_ratio: Integer stride of the validation split, NOT a
            fraction: every `validation_ratio`-th entry of each label goes into
            the validation set (e.g. 5 -> 20%). Values <= 0 put all entries
            into the training set.

    Returns:
        A dictionary with the keys 'labels' (the labels read from the file, in
        order of first appearance on Python 3.7+), 'training_count',
        'training_paths', 'training_labels', 'validation_count',
        'validation_paths' and 'validation_labels'. The '*_labels' lists hold
        indices into 'labels'.
    """
    # Group the paths by label. The with-block closes the file even if a line
    # fails to parse.
    labeled_paths = {}
    with open(file) as f:
        for line in f:
            # split() (not split(' ')) tolerates repeated spaces, tabs and '\r\n'
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            path = fields[0]
            label = int(fields[1])
            labeled_paths.setdefault(label, []).append(path)

    # Separate them into training and validation sets
    labels = []
    training_paths = []
    training_labels = []
    validation_paths = []
    validation_labels = []
    for label, paths in labeled_paths.items():
        print('Data with label \'%s\'' %label)
        print('=> Found %i entries' %len(paths))

        # split the list into training and validation (every n-th entry -> validation)
        if validation_ratio > 0:
            paths_sub = paths[::validation_ratio]
            del paths[::validation_ratio]
        else:
            paths_sub = []

        # print infos
        print('  => Training: %i' %len(paths))
        print('  => Validation %i' %len(paths_sub))

        # add entries to the result
        labels.append(label)
        label_index = len(labels) - 1
        training_paths += paths
        training_labels += [label_index] * len(paths)
        validation_paths += paths_sub
        validation_labels += [label_index] * len(paths_sub)

    return {
        'labels': labels,
        'training_count': len(training_paths),
        'training_paths': training_paths,
        'training_labels': training_labels,
        'validation_count': len(validation_paths),
        'validation_paths': validation_paths,
        'validation_labels': validation_labels
    }


# In [3]:
'''  Supporting operations for calculating accuracy and loss'''

from tensorflow.data import Iterator


def get_validation_ops(scores, true_classes):
    """Build the ops that compare predicted and true class indices.

    Args:
        scores: Score/logit tensor of the model; the class axis is axis 1.
        true_classes: Tensor with the true classes (class axis 1), e.g. the
            one-hot label placeholder.
    Returns:
        Tuple (accuracy_op, correct_pred, predicted_index, true_index):
        the mean accuracy of the batch, the per-entry boolean "prediction is
        correct" tensor, and the argmax class indices of prediction and truth.
    """
    with tf.name_scope("accuracy"):
        pred_idx = tf.argmax(scores, 1)
        gt_idx = tf.argmax(true_classes, 1)
        hits = tf.equal(pred_idx, gt_idx)
        # fraction of correct predictions in the batch
        mean_acc = tf.reduce_mean(tf.cast(hits, tf.float32))
    return mean_acc, hits, pred_idx, gt_idx


def get_train_op(loss, learning_rate, train_vars, use_adam_optimizer=False):
    """Create the training op that minimizes `loss` w.r.t. `train_vars` only.

    Args:
        loss: Scalar loss tensor to minimize (e.g. the mean cross entropy).
        learning_rate: Step size passed to the optimizer.
        train_vars: Variables to update; all other variables stay unchanged.
        use_adam_optimizer: Use tf.train.AdamOptimizer instead of
            tf.train.GradientDescentOptimizer.
    Returns:
        The training/optimizing operation.
    """
    with tf.name_scope("train"):
        optimizer_cls = tf.train.AdamOptimizer if use_adam_optimizer else tf.train.GradientDescentOptimizer
        optimizer = optimizer_cls(learning_rate)
        # minimize() combines compute_gradients() and apply_gradients()
        train_op = optimizer.minimize(loss, var_list=train_vars)
    return train_op


def get_loss_op(scores, true_classes):
    """Build the loss op: mean softmax cross entropy of a batch.

    Args:
        scores: Unscaled scores (logits) produced by the final layer.
        true_classes: Target class distribution per entry, e.g. one-hot labels
            (all "0" except one "1").
    Returns:
        Scalar loss operation.
    """
    with tf.name_scope("cross_entropy"):
        # cross entropy between softmax(scores) and the targets, one value per entry
        per_entry = tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=true_classes)
        # average over the batch
        loss_op = tf.reduce_mean(per_entry)
    return loss_op


def get_dataset_ops(data_train, data_val, batch_size, train_size, val_size, shuffle=True):
    """
    Batch (and optionally shuffle) the training and validation datasets and
    create one reinitializable iterator that can be switched between them.

    Args:
        data_train: tf.data Dataset with the training entries.
        data_val: tf.data Dataset with the validation entries. The iterator is
            built from `data_train`'s output types/shapes, so `data_val` has to
            be compatible with them.
        batch_size: Entries per batch (the last batch may be smaller).
        train_size: Number of training entries; used as shuffle buffer size.
        val_size: Number of validation entries; used as shuffle buffer size.
        shuffle: Whether to shuffle both datasets.

    Returns:
        (init_op_train, init_op_val, next_batch): ops that (re)start the
        iterator on the training / validation data, and the op returning the
        next batch of whichever dataset was initialized last.
    """
    # shuffle the dataset and create batches
    if shuffle:
        # The shuffle buffer size must be > 0; an empty split (e.g. no
        # validation data) would otherwise make the shuffle op fail.
        # NOTE(review): if the datasets are already mapped to decoded images,
        # the buffer holds that many decoded images in memory.
        data_train = data_train.shuffle(max(train_size, 1))
        data_val   = data_val.shuffle(max(val_size, 1))

    data_train = data_train.batch(batch_size)
    data_val   = data_val.batch(batch_size)

    # create a reinitializable iterator given the dataset structure
    iterator = Iterator.from_structure(data_train.output_types, data_train.output_shapes)
    next_batch = iterator.get_next()

    # Ops for initializing the iterator with one of the two datasets
    init_op_train = iterator.make_initializer(data_train)
    init_op_val   = iterator.make_initializer(data_val)

    return init_op_train, init_op_val, next_batch


# In [4]:
'''  Utils '''


from datetime import datetime
from random import randint  
from tensorflow.python.ops.nn_ops import softmax
    
def save_session_to_checkpoint_file(sess, saver, epoch, path):
    """
    Save `sess` with `saver` into a timestamped checkpoint inside `path`.

    The file is named ``MMDD_HHMMSS_model_epoch<N>.ckpt`` where N = epoch + 1
    (`epoch` is 0-based, the file name is 1-based). The final path is printed.
    """
    stamp = datetime.now().strftime("%m%d_%H%M%S_")
    file_name = stamp + 'model_epoch' + str(epoch + 1) + '.ckpt'
    checkpoint = os.path.join(path, file_name)

    saver.save(sess, checkpoint)
    print("Model checkpoint saved at {}".format(checkpoint))

def get_misclassified(corr_pred, paths, pred_index, true_index, scores):
    """
    Collect the entries of one batch that were not classified correctly.

    All arguments are per-entry sequences aligned by position.

    Returns:
        A list of tuples (path, predicted index, true index, scores), one for
        every position i where corr_pred[i] is falsy.
    """
    return [
        (paths[i], pred_index[i], true_index[i], scores[i])
        for i, ok in enumerate(corr_pred)
        if not ok
    ]

def print_misclassified(sess, misclassified, labels):
    """
    Print every misclassified entry with its predicted label, true label and
    softmax probabilities.

    Args:
        sess: Unused; kept so existing callers keep working.
        misclassified: Tuples (path, predicted index, true index, scores) as
            returned by `get_misclassified`.
        labels: Label names, indexed by class index.
    """
    print("----------------------------------------------------------------")
    print("")
    print("=> Misclassified: %i" %len(misclassified))
    print("================================================================")
    for (path, pred_index, true_index, score) in misclassified:
        # Softmax on the host with numpy. Building and running `softmax(score)`
        # TF ops here would add new nodes to the graph for every entry.
        logits = np.asarray(score)
        exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))  # shift for numerical stability
        smax = exp / np.sum(exp, axis=-1, keepdims=True)

        print("{} | {} ({}) | {}".format(
            path,
            labels[pred_index],
            labels[true_index],
            smax
        ))
    print("================================================================")

def print_output_header(train_count, val_count):
    """
    Print the dataset sizes followed by the column header of the per-epoch
    result table (see `print_output_epoch`).
    """
    header_lines = [
        "=> Getting loss and accuracy for:",
        "  => {} training entries".format(train_count),
        "  => {} validation entries".format(val_count),
        "",
        " Ep  |   Time   |   T Loss   |   V Loss   |  T Accu. |  V Accu.",
        "----------------------------------------------------------------",
    ]
    for text in header_lines:
        print(text)

def print_output_epoch(epoch, train_loss, train_acc, test_loss, test_acc):
    """
    Print one row of the result table: epoch, wall-clock time, training and
    validation loss, then training and validation accuracy.
    """
    row_format = "{:4.0f} | {} | {:.8f} | {:.8f} | {:.6f} | {:.6f}"
    clock = datetime.now().strftime("%H:%M:%S")
    # the columns are ordered loss (T, V) before accuracy (T, V)
    print(row_format.format(epoch, clock, train_loss, test_loss, train_acc, test_acc))

def run_training(sess, train_op, loss_op, accuracy_op, iterator_op, get_next_batch_op, ph_data, ph_labels, ph_keep_prob, keep_prob, batches):
    """
    Run one training epoch and return its mean loss and accuracy.

    Args:
        sess: Session used to run all ops.
        train_op: Op applying one optimization step.
        loss_op: Op returning the mean loss of the fed batch.
        accuracy_op: Op returning the mean accuracy of the fed batch.
        iterator_op: Op (re)initializing the iterator on the training data.
        get_next_batch_op: Op returning (data, labels, paths) of the next batch.
        ph_data: Placeholder fed with the batch data.
        ph_labels: Placeholder fed with the batch labels.
        ph_keep_prob: Placeholder fed with the dropout keep probability.
        keep_prob: Dropout keep probability used for training.
        batches: Number of batches to run (one epoch).

    Returns:
        (loss, acc): per-entry averages over the epoch. Each batch mean is
        weighted by its batch size, so a smaller last batch is not
        over-counted. Both are NaN when no entry was processed (batches == 0).
    """
    # Running sums over all entries of the epoch
    total_loss = 0.
    total_acc = 0.
    total_count = 0

    sess.run(iterator_op)
    for _ in range(batches):
        # Get next batch of data and run the training operation
        data_batch, label_batch, _paths = sess.run(get_next_batch_op)
        _, batch_loss, batch_acc = sess.run(
            [train_op, loss_op, accuracy_op],
            feed_dict={ph_data: data_batch, ph_labels: label_batch, ph_keep_prob: keep_prob}
        )
        # loss/accuracy ops return batch means -> convert back to sums
        batch_count = len(label_batch)
        total_loss += batch_loss * batch_count
        total_acc += batch_acc * batch_count
        total_count += batch_count

    if total_count == 0:
        return float('nan'), float('nan')
    return total_loss / total_count, total_acc / total_count


def run_validation(sess, loss_op, accuracy_op, correct_prediction_op, predicted_index_op, true_index_op, final_op,
                   iterator_op, get_next_batch_op, ph_data, ph_labels, ph_keep_prob, batches, return_misclassified = False):
    """
    Evaluate the model once over the validation data (no training step).

    Args:
        sess: Session used to run all ops.
        loss_op: Op returning the mean loss of the fed batch.
        accuracy_op: Op returning the mean accuracy of the fed batch.
        correct_prediction_op: Op returning per-entry booleans "prediction correct".
        predicted_index_op: Op returning the predicted class index per entry.
        true_index_op: Op returning the true class index per entry.
        final_op: Op returning the model scores per entry.
        iterator_op: Op (re)initializing the iterator on the validation data.
        get_next_batch_op: Op returning (data, labels, paths) of the next batch.
        ph_data: Placeholder fed with the batch data.
        ph_labels: Placeholder fed with the batch labels.
        ph_keep_prob: Placeholder for the dropout keep probability (fed 1.0,
            i.e. no dropout).
        batches: Number of batches covering the validation data.
        return_misclassified: Also collect the misclassified entries.

    Returns:
        (loss, acc, misclassified): per-entry mean loss and accuracy over all
        validation entries (batch means weighted by batch size; NaN for both
        when no entry was processed) and the list built with
        `get_misclassified` (empty unless `return_misclassified` is set).
    """
    # Running sums over all validation entries
    total_loss = 0.
    total_acc = 0.
    total_count = 0
    misclassified = []

    sess.run(iterator_op)
    for _ in range(batches):
        img_batch, label_batch, paths = sess.run(get_next_batch_op)
        scores, batch_loss, batch_acc, corr_pred, pred_index, true_index = sess.run(
            [final_op, loss_op, accuracy_op, correct_prediction_op, predicted_index_op, true_index_op],
            feed_dict={ph_data: img_batch, ph_labels: label_batch, ph_keep_prob: 1.}
        )
        # loss/accuracy ops return batch means -> convert back to sums
        batch_count = len(label_batch)
        total_loss += batch_loss * batch_count
        total_acc += batch_acc * batch_count
        total_count += batch_count

        if return_misclassified:
            misclassified += get_misclassified(corr_pred, paths, pred_index, true_index, scores)

    if total_count == 0:
        return float('nan'), float('nan'), misclassified
    return total_loss / total_count, total_acc / total_count, misclassified


# In [9]:
''' Retrainer Function '''

from tensorflow.data import Dataset


class Retrainer(object):
    """
    Retrain (fine-tune) a given model on a new set of categories.

    `data` is a dataset dictionary as built by the `load_*` helpers of this
    file ('labels', 'training_paths', 'training_labels', 'training_count',
    'validation_paths', 'validation_labels', 'validation_count').
    """

    def __init__(self, model_def, data, image_prep=None, write_checkpoints = False):
        """
        Args:
            model_def: Model class. Provides `image_size` and a default
                `image_prep`, and is instantiated on the input placeholder in `run`.
            data: Dataset dictionary (see class docstring).
            image_prep: Optional preprocessing object with a `preprocess_image`
                method; `model_def.image_prep` is used when it is falsy.
            write_checkpoints: Save a checkpoint after every epoch instead of
                only after the last one.
        """
        self.model_def = model_def
        self.data = data
        self.num_classes = len(data['labels'])
        self.image_prep = image_prep if image_prep else model_def.image_prep # overwrite the default model image prep?
        self.write_checkpoints = write_checkpoints

    @staticmethod
    def print_infos(train_vars, restore_vars, learning_rate, batch_size, keep_prob, use_adam):
        """Print infos about the current run.

        Args:
            train_vars: Variables that will be trained.
            restore_vars: Variables that will be restored.
            learning_rate: Optimizer learning rate.
            batch_size: Entries per batch.
            keep_prob: Dropout keep probability (printed as dropout = 1 - keep_prob).
            use_adam: Whether the Adam optimizer is used.
        """
        print("=> Will Restore:")
        for var in restore_vars:
            print("  => {}".format(var))
        print("=> Will train:")
        for var in train_vars:
            print("  => {}".format(var))
        print("")
        print("=> Learningrate: %.4f" %learning_rate)
        print("=> Batchsize: %i" %batch_size)
        print("=> Dropout: %.4f" %(1.0 - keep_prob))
        print("=> Using Adam Optimizer: %r" %use_adam)
        print("")

    def parse_data(self, path, label, is_training):
        """
        Load and preprocess one image and one-hot encode its label.

        Args:
            path: Scalar string tensor with the image file path (decoded as
                JPEG with 3 channels).
            label: Scalar int32 tensor with the class index.
            is_training: Passed on to the preprocessing (training vs. evaluation).

        Returns:
            image: image loaded and preprocessed to model_def.image_size squared
            one_hot: label converted into a one-hot vector of length num_classes
            path: the unchanged path (used to report misclassified files)
        """
        # convert label number into one-hot-encoding
        one_hot = tf.one_hot(label, self.num_classes)

        # load the image
        img_file      = tf.read_file(path)
        img_decoded   = tf.image.decode_jpeg(img_file, channels=3)
        img_processed = self.image_prep.preprocess_image(
            image=img_decoded,
            output_height=self.model_def.image_size,
            output_width=self.model_def.image_size,
            is_training=is_training
        )
        return img_processed, one_hot, path

    def parse_train_data(self, path, label):
        """`parse_data` with training preprocessing (for Dataset.map)."""
        return self.parse_data(path, label, True)

    def parse_validation_data(self, path, label):
        """`parse_data` with evaluation preprocessing (for Dataset.map)."""
        return self.parse_data(path, label, False)

    def create_dataset(self, is_training=True):
        """
        Args:
            is_training: Define what kind of Dataset should be returned (training or validation)

        Returns: A Tensorflow Dataset of (image, one-hot label, path) entries,
            either for training or for validation.
        """
        paths = self.data['training_paths'] if is_training else self.data['validation_paths']
        labels = self.data['training_labels'] if is_training else self.data['validation_labels']
        dataset = Dataset.from_tensor_slices((
            tf.convert_to_tensor(paths, dtype=tf.string),
            tf.convert_to_tensor(labels, dtype=tf.int32)
        ))
        # load and preprocess the images
        if is_training:
            dataset = dataset.map(self.parse_train_data)
        else:
            dataset = dataset.map(self.parse_validation_data)

        return dataset

    ############################################################################
    def run(self, finetune_layers, epochs, learning_rate = 0.01, batch_size = 128, keep_prob = 1.0, memory_usage = 1.0,
            device = '/gpu:0', save_ckpt_dir = '', init_ckpt_file = '', use_adam_optimizer=False):
        """
        Run a training on part of the model (retrain/finetune), validating
        after every epoch.

        Args:
            finetune_layers: Layers to retrain (passed to the model as `retrain_layer`).
            epochs: Number of passes over the training data.
            learning_rate: Optimizer learning rate.
            batch_size: Entries per batch.
            keep_prob: Dropout keep probability during training (validation uses 1.0).
            memory_usage: Fraction of GPU memory the session may allocate.
            device: TensorFlow device for the model and the training ops.
            save_ckpt_dir: Directory for checkpoint files.
            init_ckpt_file: If given, restore the variables from this checkpoint
                instead of loading the model's initial (pretrained) weights.
            use_adam_optimizer: Use Adam instead of plain gradient descent.
        """
        # The helper functions (dataset, loss/train/validation ops, training
        # loop and printing) are module-level functions of this file.

        # create datasets
        data_train = self.create_dataset(is_training=True)
        data_val = self.create_dataset(is_training=False)

        # Get ops to init the dataset iterators and get a next batch
        init_train_iterator_op, init_val_iterator_op, get_next_batch_op = get_dataset_ops(
            data_train,
            data_val,
            batch_size,
            self.data['training_count'],
            self.data['validation_count'],
            shuffle=True
        )

        # Initialize model and create input placeholders
        with tf.device(device):
            ph_images = tf.placeholder(tf.float32, [None, self.model_def.image_size, self.model_def.image_size, 3])
            ph_labels = tf.placeholder(tf.float32, [None, self.num_classes])
            # The first placeholder dimension stays None (not batch_size) so the smaller last batch fits
            ph_keep_prob = tf.placeholder(tf.float32)

            model = self.model_def(ph_images, keep_prob=ph_keep_prob, num_classes=self.num_classes, retrain_layer=finetune_layers)
            final_op = model.get_final_op()

        # Get a list with all trainable variables and print infos for the current run
        retrain_vars = model.get_retrain_vars()
        restore_vars = model.get_restore_vars()
        self.print_infos(retrain_vars, restore_vars, learning_rate, batch_size, keep_prob, use_adam_optimizer)

        # Add/Get the different operations to optimize (loss, train and validate)
        with tf.device(device):
            loss_op = get_loss_op(final_op, ph_labels)
            train_op = get_train_op(loss_op, learning_rate, retrain_vars, use_adam_optimizer)
            accuracy_op, correct_prediction_op, predicted_index_op, true_index_op = get_validation_ops(final_op, ph_labels)

        # Get the number of training/validation steps per epoch to get through all images
        batches_per_epoch_train = int(math.ceil(self.data['training_count'] / (batch_size + 0.0)))
        batches_per_epoch_val   = int(math.ceil(self.data['validation_count'] / (batch_size + 0.0)))

        # Initialize a saver, create a session config and start a session
        saver = tf.train.Saver()
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = memory_usage
        with tf.Session(config=config) as sess:
            # Init all variables
            sess.run(tf.global_variables_initializer())

            # Load the pretrained variables or a saved checkpoint
            if init_ckpt_file:
                saver.restore(sess, init_ckpt_file)
            else:
                model.load_initial_weights(sess)

            print_output_header(self.data['training_count'], self.data['validation_count'])
            for epoch in range(epochs):
                is_last_epoch = (epoch + 1) == epochs

                train_loss, train_acc = run_training(
                    sess=sess,
                    train_op=train_op,
                    loss_op=loss_op,
                    accuracy_op=accuracy_op,
                    iterator_op=init_train_iterator_op,
                    get_next_batch_op=get_next_batch_op,
                    ph_data=ph_images,
                    ph_labels=ph_labels,
                    ph_keep_prob=ph_keep_prob,
                    keep_prob=keep_prob,
                    batches=batches_per_epoch_train
                )

                # misclassified entries are only collected in the last epoch
                test_loss, test_acc, misclassified = run_validation(
                    sess=sess,
                    loss_op=loss_op,
                    accuracy_op=accuracy_op,
                    correct_prediction_op=correct_prediction_op,
                    predicted_index_op=predicted_index_op,
                    true_index_op=true_index_op,
                    final_op=final_op,
                    iterator_op=init_val_iterator_op,
                    get_next_batch_op=get_next_batch_op,
                    ph_data=ph_images,
                    ph_labels=ph_labels,
                    ph_keep_prob=ph_keep_prob,
                    batches=batches_per_epoch_val,
                    return_misclassified=is_last_epoch
                )

                print_output_epoch(epoch + 1, train_loss, train_acc, test_loss, test_acc)

                # show misclassified list on last epoch
                if is_last_epoch:
                    print_misclassified(sess, misclassified, self.data['labels'])

                # save session in a checkpoint file
                if self.write_checkpoints or is_last_epoch:
                    save_session_to_checkpoint_file(sess, saver, epoch, save_ckpt_dir)


# In [14]:
from alexnet.alexnet import AlexNet
from vgg.vgg import VGG
from inception_v3.inception_v3 import InceptionV3

# In [11]:
''' FINTUNNING ... '''


# Input params
# Integer stride of the validation split; values <= 0 put every file into training.
VALIDATION_RATIO = 5 # e.g. 5 -> every 5th element = 1/5 = 0.2 = 20%
# Also collect images from nested sub-directories of each class folder.
USE_SUBFOLDER = True
# Folder names that are ignored while scanning the image directory.
SKIP_FOLDER = []

# Learning params
LEARNING_RATE = 0.005   # optimizer step size
NUM_EPOCHS = 20         # passes over the training data
BATCH_SIZE = 32         # entries per batch

# Network params
# Dropout keep probability during training (1.0 = no dropout); validation always feeds 1.0.
KEEP_PROB = 1.0 # [0.5]
# Layers to retrain; passed to the model as `retrain_layer`.
FINETUNE_LAYERS = ['fc6', 'fc7', 'fc8']
# Checkpoints are written to CHECKPOINT_DIR/<model name>.
CHECKPOINT_DIR = '../checkpoints'

# HARDWARE USAGE
# TensorFlow device for the model and the training ops.
DEVICE = '/cpu:0'
# Fraction of GPU memory the session may allocate.
MEMORY_USAGE = 1.0

def finetune(model_def, data, ckpt_dir, write_checkpoint_on_each_epoch, init_from_ckpt, use_adam_optimizer):
    """
    Fine-tune `model_def` on `data` using the module-level learning, network
    and hardware settings.

    Args:
        model_def: Model class to fine-tune (e.g. AlexNet or VGG).
        data: Dataset dictionary as built by the `load_*` helpers.
        ckpt_dir: Directory for the checkpoint files.
        write_checkpoint_on_each_epoch: Save a checkpoint after every epoch
            (otherwise only after the last one).
        init_from_ckpt: Checkpoint file to continue training from ('' to start
            from the model's initial weights).
        use_adam_optimizer: Use Adam instead of plain gradient descent.
    """
    # Must be passed by keyword: Retrainer's third positional parameter is
    # `image_prep`, which would otherwise receive the checkpoint flag.
    trainer = Retrainer(model_def, data, write_checkpoints=write_checkpoint_on_each_epoch)
    trainer.run(
        finetune_layers=FINETUNE_LAYERS,
        epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        keep_prob=KEEP_PROB,
        memory_usage=MEMORY_USAGE,
        device=DEVICE,
        save_ckpt_dir=ckpt_dir,
        init_ckpt_file=init_from_ckpt,
        use_adam_optimizer=use_adam_optimizer
    )

def main():
    """
    Main: parse the command line options, load the image data and start
    fine-tuning the selected model.

    Returns:
        None. Returns early (after printing a message) when neither
        '-image_dir' nor '-image_file' points to existing data.
    """
    # argparse is not imported at the top of this file
    import argparse

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-image_dir',
        type=str,
        default='',
        help='Folder with trainings/validation images'
    )
    parser.add_argument(
        '-image_file',
        type=str,
        default='',
        help='File with a list of trainings/validation images and their labels'
    )
    parser.add_argument(
        '-write_checkpoint_on_each_epoch',
        default=False,
        help='Write a checkpoint file on each epoch (default is just once at the end',
        action='store_true' # whenever this option is set, the arg is set to true
    )
    parser.add_argument(
        '-init_from_ckpt',
        type=str,
        default='',
        help='Load this checkpoint file to continue training from this point on'
    )
    parser.add_argument(
        '-model',
        type=str,
        choices=['vgg', 'alex'],
        default='alex',
        help='Model to be validated. Default is AlexNet (alex)'
    )
    parser.add_argument(
        '-use_adam_optimizer',
        default=False,
        help='Use Adam optimizer instead of GradientDecent',
        action='store_true' # whenever this option is set, the arg is set to true
    )
    args = parser.parse_args()

    # Use the parsed values as they are: '' / False already mean "not given".
    # Replacing them with non-empty placeholder strings would make every option
    # look set (always restore a checkpoint, always use Adam, ...).
    image_dir = args.image_dir
    image_file = args.image_file
    write_checkpoint_on_each_epoch = args.write_checkpoint_on_each_epoch
    init_from_ckpt = args.init_from_ckpt
    model_str = args.model
    use_adam_optimizer = args.use_adam_optimizer

    # Load images
    if not image_dir and not image_file:
        print('Provide one of the following options to load images \'-image_file\' or \'-image_dir\'')
        return None
    if image_dir:
        if not os.path.exists(image_dir):
            print('Image root directory \'%s\' not found' %image_dir)
            return None
        data = load_images_by_subfolder(image_dir, VALIDATION_RATIO, SKIP_FOLDER, use_subfolder=USE_SUBFOLDER)
    else:
        if not os.path.exists(image_file):
            print('Image file \'%s\' not found' %image_file)
            return None
        data = load_by_file(image_file, VALIDATION_RATIO)

    # Set a CNN model definition; the keys match the '-model' choices above
    model_defs = {'vgg': VGG, 'alex': AlexNet}
    model_def = model_defs[model_str]

    # Make sure the checkpoint dir exists
    ckpt_dir = os.path.join(CHECKPOINT_DIR, model_str)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    # Start retraining/finetuning
    finetune(model_def, data, ckpt_dir, write_checkpoint_on_each_epoch, init_from_ckpt, use_adam_optimizer)

# Entry point: parse the command line and start the fine-tuning run.
if __name__ == '__main__':
    main()
