In [1]:
import os

# Expose only GPU index 1 to this process. This cell runs before TensorFlow is
# imported (next cells), so the device mask is already set when it loads.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
# Automatically reload edited local modules (such as `utils`, imported below)
# before each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import tensorflow as tf

import utils

  return f(*args, **kwds)


In [4]:
# Training / evaluation configuration.
BATCH_SIZE = 64
NUM_EPOCHS = 10
EVAL_BATCH_SIZE = 64
EVAL_FREQUENCY = 100  # Number of steps between evaluations.
DATASET = 'mnist'  # Name passed to utils.get_dataset.

# Number of refinement steps unrolled by `apply`: each step's posteriors are
# fed back in as the next step's class prior.
NUM_UNROLL_STEPS = 5

In [5]:
def feature_extractor(input_images, training):
    """Shared convolutional trunk applied to the input batch.

    Four stride-1, VALID-padded stages of conv -> batch-norm -> ReLU with
    16, 32, 64 and 128 filters (a 5x5 kernel first, then 3x3 three times).

    Args:
        input_images: image batch tensor fed to the first conv layer.
        training: forwarded to every batch-normalization layer's `training`
            argument.

    Returns:
        The activations of the last stage (128 channels).
    """
    # (filters, kernel size, batch-norm `scale`) for each stage, in creation
    # order. Only the final stage keeps the learned BN scale.
    # NOTE(review): confirm that asymmetry is intentional.
    stage_specs = (
        (16, [5, 5], False),
        (32, [3, 3], False),
        (64, [3, 3], False),
        (128, [3, 3], True),
    )
    features = input_images
    for num_filters, kernel, bn_scale in stage_specs:
        features = tf.layers.conv2d(features, filters=num_filters, kernel_size=kernel,
                                    strides=1, padding="VALID")
        features = tf.layers.batch_normalization(features, momentum=0.9, scale=bn_scale,
                                                 fused=True, training=training)
        features = tf.nn.relu(features)
    return features

def model_step(input_images, prior, batch_size, training, num_labels,
               embedding_size=10, num_internal_convs=5):
    """One prior-conditioned hypernetwork classification step.

    Every example's class prior is embedded (a learned linear projection plus
    the prior's entropy as one extra feature) and refined by four gated dense
    layers. Dense layers then turn that embedding into per-example convolution
    weights. Each example's feature map is classified with its own weights:
    1x1 conv -> `num_internal_convs` residual 2x2 convs -> 1x1 conv to
    `num_labels` channels -> global average pool + per-example bias.

    Args:
        input_images: feature-map batch (e.g. feature_extractor's output). Its
            channel (last) dimension must be statically known.
        prior: [batch_size, num_labels] class prior per example.
        batch_size: static number of examples; the graph is unrolled per example.
        training: not used by this function; kept for interface compatibility.
        num_labels: number of output classes.
        embedding_size: width of the prior embedding and of the dynamic conv
            channels (default 10, the previously hard-coded value).
        num_internal_convs: number of residual 2x2 convs (default 5).

    Returns:
        (logits, posteriors), each of shape [batch_size, num_labels], with
        posteriors = softmax(logits).

    Raises:
        ValueError: if the channel dimension of `input_images` is unknown.
    """
    # Derive the first dynamic conv's input depth from the features instead of
    # hard-coding it to feature_extractor's current 128 filters.
    in_channels = input_images.get_shape().as_list()[-1]
    if in_channels is None:
        raise ValueError("model_step needs a statically known channel dimension")

    # The prior's entropy is appended as the last embedding feature, so the
    # learned projection only needs embedding_size - 1 columns.
    prior_embeddings = tf.get_variable(
        "prior_embeddings",
        shape=[num_labels, embedding_size - 1],
        initializer=tf.random_uniform_initializer(
            minval=-1.0 / np.sqrt(num_labels), maxval=1.0 / np.sqrt(num_labels)))
    raw_embedding_features = tf.matmul(prior, prior_embeddings)
    # 1e-4 keeps log() finite when a prior entry is exactly zero.
    prior_entropy = -tf.reduce_sum(prior * tf.log(1e-4 + prior), axis=-1, keep_dims=True)
    embedding_features = tf.concat([raw_embedding_features, prior_entropy], axis=-1)
    # Four gated residual refinements of the embedding (gate layer created first,
    # then the transform layer, as before).
    for _ in range(4):
        gates = tf.layers.dense(
            embedding_features, embedding_size, activation=tf.nn.sigmoid)
        embedding_features = gates * (embedding_features + tf.layers.dense(
            embedding_features, embedding_size, use_bias=False, activation=tf.nn.relu))

    def get_dynamic_weights(weights_shape):
        # One dense head per weight tensor: maps each example's embedding to a
        # flat vector that is reshaped to weights_shape ([batch, ...kernel]).
        num_weights = np.prod(weights_shape[1:])
        dynamic_weights_flat = tf.layers.dense(embedding_features, num_weights)
        dynamic_weights = tf.reshape(dynamic_weights_flat, weights_shape)
        dynamic_weights.set_shape(weights_shape)
        return dynamic_weights

    conv1_weights = get_dynamic_weights([batch_size, 1, 1, in_channels, embedding_size])
    conv2_weights = [get_dynamic_weights([batch_size, 2, 2, embedding_size, embedding_size])
                     for _ in range(num_internal_convs)]
    conv3_weights = get_dynamic_weights([batch_size, 1, 1, embedding_size, num_labels])
    conv3_biases = get_dynamic_weights([batch_size, num_labels])

    logits_lst, posteriors_lst = [], []
    for elm in range(batch_size):
        # Each example is convolved with its own generated weights.
        conv = tf.nn.conv2d([input_images[elm]],
                            conv1_weights[elm],
                            strides=[1, 1, 1, 1],
                            padding='VALID')
        relu = tf.nn.relu(conv)
        for i in range(num_internal_convs):
            conv = tf.nn.conv2d(relu,
                                conv2_weights[i][elm],
                                strides=[1, 1, 1, 1],
                                padding='SAME')
            relu = tf.nn.relu(conv) + relu  # residual connection
        conv = tf.nn.conv2d(relu,
                            conv3_weights[elm],
                            strides=[1, 1, 1, 1],
                            padding='VALID')
        conv_shape = conv.get_shape()
        # Average over the full spatial extent, then add the per-example bias.
        logits = tf.nn.avg_pool(conv, ksize=[1, conv_shape[1], conv_shape[2], 1],
                                strides=[1, 1, 1, 1], padding="VALID") + conv3_biases[elm]
        logits = tf.squeeze(logits, axis=[1, 2])
        posteriors = tf.nn.softmax(logits)
        logits_lst.append(logits)
        posteriors_lst.append(posteriors)
    return tf.concat(logits_lst, axis=0), tf.concat(posteriors_lst, axis=0)

def apply(input_images, training, train_labels_node, num_labels, num_unroll_steps=None):
    """Build the unrolled prior-refinement classifier and its training loss.

    Features are extracted once. model_step is then applied `num_unroll_steps`
    times; the first step gets a uniform class prior, and every later step gets
    the previous step's posteriors as its prior.

    Args:
        input_images: input image batch. Its batch dimension must be static,
            because model_step unrolls the graph per example.
        training: forwarded to feature_extractor and model_step.
        train_labels_node: class-index labels passed to
            sparse_softmax_cross_entropy_with_logits.
        num_labels: number of classes.
        num_unroll_steps: number of refinement steps. None (the default) reads
            the notebook-level NUM_UNROLL_STEPS at call time.

    Returns:
        (stacked_logits, loss):
        - stacked_logits: the logits of every step, stacked along a new
          leading axis.
        - loss: the sum over steps of the mean cross-entropy.

    Raises:
        ValueError: if the batch dimension of the extracted features is unknown.
    """
    if num_unroll_steps is None:
        num_unroll_steps = NUM_UNROLL_STEPS
    conv_features = feature_extractor(input_images, training=training)
    # model_step builds separate ops for every example, so the batch size must
    # be a concrete Python int rather than an unknown dimension.
    batch_size = conv_features.get_shape().as_list()[0]
    if batch_size is None:
        raise ValueError("apply needs a statically known batch size")
    # Uniform class prior for every example on the first step.
    priors = np.full((batch_size, num_labels), 1.0 / num_labels, dtype=np.float32)
    step_logits = []
    loss = 0.0
    for _ in range(num_unroll_steps):
        logits, posteriors = model_step(conv_features, priors, batch_size, training, num_labels)
        # Feed this step's posteriors back in as the next step's prior.
        # NOTE(review): gradients flow through this feedback (no stop_gradient)
        # — confirm that is intended.
        priors = posteriors
        step_logits.append(logits)
        loss += tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=train_labels_node, logits=logits))
    return tf.stack(step_logits), loss

In [None]:
use_priors = True  # NOTE(review): not referenced by any visible cell — confirm it is still needed.

tf.reset_default_graph()

dataset = utils.get_dataset(DATASET)

# Optimizer: a step counter, incremented once per batch, that drives the
# learning-rate decay. It is bookkeeping, not a model parameter, so it is kept
# out of the trainable variables.
batch = tf.Variable(0, dtype=tf.float32, trainable=False)
# Staircase exponential decay: starts at 1e-3 and is multiplied by 0.95 once
# per epoch (every dataset.train_size examples seen).
learning_rate = tf.train.exponential_decay(
    1e-3,                # Base learning rate.
    batch * BATCH_SIZE,  # Current index into the dataset.
    dataset.train_size,  # Decay step.
    0.95,                # Decay rate.
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)

train_config = dict(
    optimizer=optimizer,
    batch_var=batch,
    learning_rate_var=learning_rate,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    eval_frequency=EVAL_FREQUENCY,
)

stdout_lines = utils.run_train(apply, train_config, dataset,
                               build_func_kwargs=dict())

Initialized!
Step 0 (epoch 0.00), 277.6 ms
Minibatch loss: 11.555, learning rate: 0.001000
Minibatch error: [89.0625, 92.1875, 89.0625, 87.5, 89.0625]
Validation error: [90.42, 90.14, 90.14, 89.0, 90.42]
Step 100 (epoch 0.12), 1135.3 ms
Minibatch loss: 6.088, learning rate: 0.001000
Minibatch error: [40.625, 29.6875, 56.25, 56.25, 25.0]
Validation error: [52.88, 50.98, 50.0, 63.5, 46.4]
Step 200 (epoch 0.23), 1106.7 ms
Minibatch loss: 4.684, learning rate: 0.001000
Minibatch error: [15.625, 17.1875, 37.5, 43.75, 21.875]
Validation error: [64.38, 69.76, 73.2, 75.22, 72.7]
Step 300 (epoch 0.35), 1108.5 ms
Minibatch loss: 3.859, learning rate: 0.001000
Minibatch error: [23.4375, 14.0625, 23.4375, 21.875, 10.9375]
Validation error: [36.14, 32.14, 52.86, 53.54, 44.22]
Step 400 (epoch 0.47), 1111.3 ms
Minibatch loss: 3.754, learning rate: 0.001000
Minibatch error: [23.4375, 18.75, 18.75, 20.3125, 17.1875]
Validation error: [24.099999999999994, 15.540000000000006, 14.019999999999996, 12.14, 1

In [10]:
# Re-emit the training log captured by utils.run_train as one contiguous block.
print(*stdout_lines, sep="")

Step 0 (epoch 0.00), 277.6 ms
Minibatch loss: 11.555, learning rate: 0.001000
Minibatch error: [89.0625, 92.1875, 89.0625, 87.5, 89.0625]
Validation error: [90.42, 90.14, 90.14, 89.0, 90.42]
Step 100 (epoch 0.12), 1135.3 ms
Minibatch loss: 6.088, learning rate: 0.001000
Minibatch error: [40.625, 29.6875, 56.25, 56.25, 25.0]
Validation error: [52.88, 50.98, 50.0, 63.5, 46.4]
Step 200 (epoch 0.23), 1106.7 ms
Minibatch loss: 4.684, learning rate: 0.001000
Minibatch error: [15.625, 17.1875, 37.5, 43.75, 21.875]
Validation error: [64.38, 69.76, 73.2, 75.22, 72.7]
Step 300 (epoch 0.35), 1108.5 ms
Minibatch loss: 3.859, learning rate: 0.001000
Minibatch error: [23.4375, 14.0625, 23.4375, 21.875, 10.9375]
Validation error: [36.14, 32.14, 52.86, 53.54, 44.22]
Step 400 (epoch 0.47), 1111.3 ms
Minibatch loss: 3.754, learning rate: 0.001000
Minibatch error: [23.4375, 18.75, 18.75, 20.3125, 17.1875]
Validation error: [24.099999999999994, 15.540000000000006, 14.019999999999996, 12.14, 11.5]
Step 500