In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import argparse
from features import extract_features_omniglot, extract_features_mini_imagenet
from inference import infer_classifier
from utilities import sample_normal, multinoulli_log_density, print_and_log, get_log_files
from data import get_data

In [None]:
# Run configuration: values that used to be command-line arguments, hard-coded
# here once so the notebook can be run top to bottom.

# Dataset to use.
dataset = 'Omniglot'

# Whether to run training only ('train'), testing only ('test'),
# or both training and testing ('train_test').
mode = 'train_test'

# Size of the feature extractor output.
d_theta = 256

# Number of training examples.
shot = 5

# Number of classes.
way = 5

# Shot to be used at evaluation time. Leave as None to reuse 'shot'.
test_shot = None

# Way to be used at evaluation time. Leave as None to reuse 'way'.
test_way = None

# Apply the fallbacks promised above. Without this the values stay None, which
# later breaks both the '{:d}' log formatting and the test-time batch request.
if test_shot is None:
    test_shot = shot
if test_way is None:
    test_way = way

# Number of tasks per batch.
tasks_per_batch = 16

# Number of samples from q.
samples = 10

# Learning rate.
learning_rate = 1e-4

# Number of training iterations.
iterations = 800

# Directory to save trained models.
checkpoint_dir = './checkpoint'

# Dropout keep probability.
dropout = 0.9

# Model to load and test (only read when mode == 'test').
test_model_path = None

# Frequency of summary results (in iterations).
print_freq = 200

In [None]:

# get_log_files hands back the log handle (closed at the end of the notebook)
# and the two checkpoint save paths: best-on-validation and final.
logfile, checkpoint_path_validation, checkpoint_path_final = get_log_files(checkpoint_dir)

# Write the run configuration under the "Options:" header so that the log file
# on its own is enough to tell how the run was set up.
print_and_log(logfile, "Options:")
for option_name, option_value in [
        ('dataset', dataset),
        ('mode', mode),
        ('d_theta', d_theta),
        ('shot', shot),
        ('way', way),
        ('test_shot', test_shot),
        ('test_way', test_way),
        ('tasks_per_batch', tasks_per_batch),
        ('samples', samples),
        ('learning_rate', learning_rate),
        ('iterations', iterations),
        ('checkpoint_dir', checkpoint_dir),
        ('dropout', dropout),
        ('test_model_path', test_model_path),
        ('print_freq', print_freq)]:
    print_and_log(logfile, '  {}: {}'.format(option_name, option_value))

In [None]:
# Build the dataset object; it later serves the 'train', 'validation' and 'test'
# batches and reports the image dimensions used for the placeholders.
data = get_data(dataset)

# Choose the feature extractor for the dataset: Omniglot has a dedicated one,
# anything else uses the mini-ImageNet extractor.
feature_extractor_fn = (extract_features_omniglot if dataset == "Omniglot"
                        else extract_features_mini_imagenet)



In [None]:
# Testing parameters: number of passes through the test loop, and the number of
# tasks requested per pass (always use a batch size of 1 for testing).
test_iterations = 600
test_args_per_batch = 1

# Evaluation sample counts, passed as the last argument of data.get_batch:
# a fixed 15 for training batches, and `shot` for validation and test batches.
eval_samples_train = 15
eval_samples_test = shot

In [None]:
# tf placeholders


def _image_placeholder(name):
    """Create a float32 placeholder for the images of a batch of tasks."""
    shape = [None,  # tasks per batch
             None,  # images per task (shot for train, num test images for test)
             data.get_image_height(),
             data.get_image_width(),
             data.get_image_channels()]
    return tf.placeholder(tf.float32, shape, name=name)


def _label_placeholder(name):
    """Create a float32 placeholder for the labels of a batch of tasks."""
    shape = [None,  # tasks per batch
             None,  # labels per task (shot for train, num test images for test)
             way]   # one entry per class; evaluate_task takes the argmax over it
    return tf.placeholder(tf.float32, shape, name=name)


# Created in the same order as the feed dictionaries list them.
train_images = _image_placeholder('train_images')
test_images = _image_placeholder('test_images')
train_labels = _label_placeholder('train_labels')
test_labels = _label_placeholder('test_labels')

# Scalar keep probability: fed with `dropout` while training and 1.0 otherwise.
dropout_keep_prob = tf.placeholder(tf.float32, [], name='dropout_keep_prob')

# The sample count as a float constant, so its log can be subtracted from a
# logsumexp to turn a sum over samples into an average.
L = tf.constant(samples, dtype=tf.float32, name="num_samples")

In [None]:
# Relevant computations for a single task
def evaluate_task(inputs):
    """Build the loss and accuracy ops for one task of the batch.

    Args:
        inputs: 4-tuple (train_inputs, train_outputs, test_inputs, test_outputs)
            holding one task's slice of the batch, as supplied by tf.map_fn.

    Returns:
        A two-element list [task_loss, task_accuracy]. It must stay a list so
        that it matches the `dtype` structure given to tf.map_fn.
    """
    support_images, support_labels, query_images, query_labels = inputs

    def _extract(images):
        # Identical extractor settings for the train and the test images.
        return feature_extractor_fn(images=images,
                                    output_size=d_theta,
                                    use_batch_norm=True,
                                    dropout_keep_prob=dropout_keep_prob)

    # Both image sets go through the extractor under the same variable scope.
    with tf.variable_scope('shared_features'):
        support_features = _extract(support_images)
        query_features = _extract(query_images)

    # Infer the parameters of q over the classification layer from the train set.
    with tf.variable_scope('classifier'):
        q_params = infer_classifier(support_features, support_labels, d_theta, way)

    # Local reparameterization trick: rather than sampling weights, compute the
    # mean and log-variance of the test logits and sample the logits directly.
    logits_mean = tf.matmul(query_features, q_params['weight_mean']) + q_params['bias_mean']
    logits_variance = (tf.matmul(query_features ** 2, tf.exp(q_params['weight_log_variance']))
                       + tf.exp(q_params['bias_log_variance']))
    logits_log_variance = tf.log(logits_variance)
    # presumably the samples are stacked along a new leading axis of size
    # `samples`; the labels are tiled along axis 0 below to line up with them.
    logits_samples = sample_normal(logits_mean, logits_log_variance, samples)

    # Log-density of the test labels under every sampled set of logits.
    tiled_labels = tf.tile(tf.expand_dims(query_labels, 0), [samples, 1, 1])
    log_py_per_sample = multinoulli_log_density(inputs=tiled_labels, logits=logits_samples)

    # Accuracy: average the sampled logits over axis 0 (logsumexp minus log of
    # the sample count) and compare the predicted class with the labelled one.
    averaged_predictions = tf.reduce_logsumexp(logits_samples, axis=0) - tf.log(L)
    true_class = tf.argmax(query_labels, axis=-1)
    predicted_class = tf.argmax(averaged_predictions, axis=-1)
    is_correct = tf.cast(tf.equal(true_class, predicted_class), tf.float32)
    task_accuracy = tf.reduce_mean(is_correct)

    # Loss: negative mean over the test points of the sample-averaged log-density.
    task_score = tf.reduce_logsumexp(log_py_per_sample, axis=0) - tf.log(L)
    task_loss = -tf.reduce_mean(task_score, axis=0)

    return [task_loss, task_accuracy]

In [None]:
# Run evaluate_task once per task: tf.map_fn slices every placeholder along its
# leading (tasks-per-batch) axis and collects the per-task losses and accuracies.
task_batch = (train_images, train_labels, test_images, test_labels)
batch_output = tf.map_fn(evaluate_task,
                         task_batch,
                         dtype=[tf.float32, tf.float32],
                         parallel_iterations=tasks_per_batch)

# Reduce the per-task values to one scalar loss and one scalar accuracy.
batch_losses = batch_output[0]
batch_accuracies = batch_output[1]
loss = tf.reduce_mean(batch_losses)
accuracy = tf.reduce_mean(batch_accuracies)

In [None]:
# Train and/or evaluate the model inside a single session. The try/finally makes
# sure the log file is closed even when the run raises or is interrupted.
try:
    with tf.Session() as sess:
        # NOTE(review): the Saver is built before the optimizer exists, which
        # presumably keeps the optimizer's own variables out of the checkpoints.
        # Confirm against tf.train.Saver's default var_list before reordering.
        saver = tf.train.Saver()

        # test_shot / test_way may be None, meaning "same as the training value".
        # Resolve that here: None cannot be formatted with '{:d}' and is not a
        # usable shot/way for the test batches.
        eval_shot = shot if test_shot is None else test_shot
        eval_way = way if test_way is None else test_way

        # Becomes True only once a best-validation checkpoint has been written.
        # A run that never validates (iterations <= print_freq), or never beats
        # an accuracy of 0.0, leaves nothing on disk to restore.
        validation_checkpoint_saved = False

        if mode == 'train' or mode == 'train_test':
            # train the model
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            train_step = optimizer.minimize(loss)

            validation_batches = 200
            best_validation_accuracy = 0.0
            train_iteration_accuracy = []
            # Runs after minimize() so the optimizer's variables are initialised as well.
            sess.run(tf.global_variables_initializer())

            # Main training loop
            for iteration in range(iterations):
                train_inputs, test_inputs, train_outputs, test_outputs = \
                    data.get_batch('train', tasks_per_batch, shot, way, eval_samples_train)
                feed_dict = {train_images: train_inputs, test_images: test_inputs,
                             train_labels: train_outputs, test_labels: test_outputs,
                             dropout_keep_prob: dropout}

                # Train for one iteration
                _, iteration_loss, iteration_accuracy = sess.run([train_step, loss, accuracy], feed_dict)
                train_iteration_accuracy.append(iteration_accuracy)

                # Every print_freq iterations (but not at iteration 0): run the
                # validation pass, checkpoint if it improved, and log a summary.
                if (iteration > 0) and (iteration % print_freq == 0):
                    # compute accuracy on validation set (dropout disabled)
                    validation_iteration_accuracy = []
                    for validation_iteration in range(validation_batches):
                        train_inputs, test_inputs, train_outputs, test_outputs = \
                            data.get_batch('validation', tasks_per_batch, shot, way, eval_samples_test)
                        feed_dict = {train_images: train_inputs, test_images: test_inputs,
                                     train_labels: train_outputs, test_labels: test_outputs,
                                     dropout_keep_prob: 1.0}
                        validation_iteration_accuracy.append(sess.run(accuracy, feed_dict))

                    validation_accuracy = np.array(validation_iteration_accuracy).mean()
                    # Training accuracy is averaged over the iterations since the last summary.
                    train_accuracy = np.array(train_iteration_accuracy).mean()

                    # save checkpoint if validation is the best so far
                    if validation_accuracy > best_validation_accuracy:
                        best_validation_accuracy = validation_accuracy
                        saver.save(sess=sess, save_path=checkpoint_path_validation)
                        validation_checkpoint_saved = True

                    print_and_log(logfile, 'Iteration: {}, Loss: {:5.3f}, Train-Acc: {:5.3f}, Val-Acc: {:5.3f}'
                                  .format(iteration, iteration_loss, train_accuracy, validation_accuracy))
                    train_iteration_accuracy = []

            # save the checkpoint from the final iteration
            saver.save(sess, save_path=checkpoint_path_final)
            print_and_log(logfile, 'Fully-trained model saved to: {}'.format(checkpoint_path_final))
            print_and_log(logfile, 'Best validation accuracy: {:5.3f}'.format(best_validation_accuracy))
            if validation_checkpoint_saved:
                print_and_log(logfile, 'Best validation model saved to: {}'.format(checkpoint_path_validation))

        def test_model(model_path, load=True):
            """Evaluate on held-out test tasks and log the mean accuracy with a 95% interval.

            Args:
                model_path: checkpoint path. Restored into the session when `load`
                    is True, and always included in the logged message.
                load: pass False to evaluate the weights currently in the session.
            """
            if load:
                saver.restore(sess, save_path=model_path)
            test_iteration_accuracy = []

            # Main test loop
            for _ in range(test_iterations):
                train_inputs, test_inputs, train_outputs, test_outputs = \
                    data.get_batch('test', test_args_per_batch, eval_shot, eval_way,
                                   eval_samples_test)
                feed_dict = {train_images: train_inputs, test_images: test_inputs,
                             train_labels: train_outputs, test_labels: test_outputs,
                             dropout_keep_prob: 1.0}
                test_iteration_accuracy.append(sess.run(accuracy, feed_dict))

            accuracies = np.array(test_iteration_accuracy)
            test_accuracy = accuracies.mean() * 100.0
            # 196.0 = 1.96 (normal 95% quantile) * 100, so the interval is in
            # percent, like test_accuracy.
            confidence_interval_95 = (196.0 * accuracies.std()) / np.sqrt(len(test_iteration_accuracy))

            print_and_log(logfile, 'Held out accuracy: {0:5.3f} +/- {1:5.3f} on {2:}'
                          .format(test_accuracy, confidence_interval_95, model_path))

        if mode == 'train_test':
            print_and_log(logfile, 'Train Shot: {0:d}, Train Way: {1:d}, Test Shot {2:d}, Test Way {3:d}'
                          .format(shot, way, eval_shot, eval_way))
            # test the final trained model
            # no need to load the model, it was just trained
            test_model(checkpoint_path_final, load=False)

            # test the best validation checkpoint, if one was written
            if validation_checkpoint_saved:
                test_model(checkpoint_path_validation)
            else:
                print_and_log(logfile, 'No best-validation checkpoint was saved; skipping its evaluation.')

        if mode == 'test':
            test_model(test_model_path)
finally:
    logfile.close()