In [1]:
import os
# Expose only GPU index 1 to this process. This runs in the first cell, before
# TensorFlow is imported in a later cell, so the restriction is already in place
# when the framework enumerates CUDA devices.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import tensorflow as tf

import utils

  return f(*args, **kwds)


In [4]:
# --- Unused leftovers (commented out; nothing in the visible cells reads them) ---
#SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
#WORK_DIRECTORY = 'data'
#IMAGE_SIZE = 28
#NUM_CHANNELS = 1
#PIXEL_DEPTH = 255
#NUM_LABELS = 10
#VALIDATION_SIZE = 5000  # Size of the validation set.
#SEED = 66478  # Set to None for random seed.

# --- Active configuration ---
BATCH_SIZE = 64  # Passed as train_batch_size; also scales the LR-decay step counter.
NUM_EPOCHS = 10  # Passed to utils.run_train as num_epochs.
EVAL_BATCH_SIZE = 64  # Passed to utils.run_train as eval_batch_size.
EVAL_FREQUENCY = 100  # Number of steps between evaluations.
DATASET = 'mnist'  # Dataset name handed to utils.get_dataset.

# Number of times apply() runs model_step, feeding each step's posteriors
# back in as the next step's prior.
NUM_UNROLL_STEPS = 5

### Should be like 1_stage/confnn_gated_mnist.ipynb, but without dropout
### Just copied 2_stage/confnn_gated_cifar10.ipynb.ipynb and changed DATASET variable

In [5]:
def model_step(input_images, prior, batch_size, training, num_labels, use_priors):
    """Build one pass of the (optionally prior-gated) convolutional classifier.

    Two conv(5x5, 'same', ReLU) + 2x2 max-pool stages are flattened, optionally
    multiplied element-wise by sigmoid gates computed from `prior`, and then fed
    through a 1024-unit ReLU layer and a linear `num_labels`-wide output layer.

    Args:
        input_images: image batch fed to the first convolution (dims 1-3 are
            flattened after pooling, so presumably NHWC -- confirm with caller).
        prior: tensor the gates are computed from; only read when
            `use_priors` is true.
        batch_size: accepted for interface compatibility; not used here.
        training: accepted for interface compatibility; not used here.
        num_labels: width of the output (logits) layer.
        use_priors: when false, the gating sub-network is not built at all.

    Returns:
        Tuple `(logits, posteriors)` with `posteriors = softmax(logits)`.
    """
    # NOTE(review): the layers are unnamed, so variable sharing under the
    # caller's reused scope presumably relies on auto-generated names; keep
    # the creation order below stable.
    features = input_images
    for num_filters in (32, 64):
        features = tf.layers.conv2d(
            inputs=features,
            filters=num_filters,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        features = tf.layers.max_pooling2d(inputs=features, pool_size=[2, 2], strides=2)

    # Flatten everything except the leading (batch) dimension.
    dims = features.get_shape()
    flat_width = dims[1] * dims[2] * dims[3]
    flat = tf.reshape(features, [-1, flat_width])

    if not use_priors:
        modulated = flat
    else:
        # prior -> 100 ReLU units -> one sigmoid gate per flattened feature.
        hidden = tf.layers.dense(inputs=prior, units=100, activation=tf.nn.relu)
        gates = tf.layers.dense(inputs=hidden, units=flat_width, activation=tf.nn.sigmoid)
        modulated = tf.multiply(flat, gates)

    fully_connected = tf.layers.dense(inputs=modulated, units=1024, activation=tf.nn.relu)
    logits = tf.layers.dense(inputs=fully_connected, units=num_labels)
    return logits, tf.nn.softmax(logits)

def apply(input_images, training, train_labels_node, num_labels, use_priors):
    """Unroll `model_step` NUM_UNROLL_STEPS times, feeding posteriors back as priors.

    The first step receives a uniform prior (1 / num_labels for every class);
    each later step receives the previous step's posteriors. All steps are
    built under the variable scope 'one_step', with `reuse` enabled from the
    second step onwards.

    Args:
        input_images: image batch; its static leading dimension is used as the
            batch size (NOTE(review): presumably fixed by the caller -- confirm
            in utils.run_train).
        training: forwarded unchanged to `model_step`.
        train_labels_node: integer class labels used for the cross-entropy.
        num_labels: number of classes.
        use_priors: forwarded unchanged to `model_step`.

    Returns:
        Tuple `(stacked_logits, loss)`: the per-step logits stacked along a new
        leading axis, and the sum over steps of the batch-mean sparse softmax
        cross-entropy.
    """
    batch_size = input_images.get_shape()[0]
    priors = tf.ones((batch_size, num_labels)) / num_labels

    total_loss = 0.0
    logits_per_step = []
    for step in range(NUM_UNROLL_STEPS):
        share_weights = step > 0
        with tf.variable_scope('one_step', reuse=share_weights):
            step_logits, step_posteriors = model_step(
                input_images, priors, batch_size,
                training=training, num_labels=num_labels,
                use_priors=use_priors)
        # This step's belief becomes the next step's prior.
        priors = step_posteriors
        logits_per_step.append(step_logits)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=train_labels_node, logits=step_logits)
        total_loss += tf.reduce_mean(cross_entropy)
    return tf.stack(logits_per_step), total_loss

## Use priors = False

In [6]:
use_priors = False

# Start from an empty default graph before building this experiment.
tf.reset_default_graph()

dataset = utils.get_dataset(DATASET)

# Step counter that drives the learning-rate schedule
# (presumably incremented once per minibatch inside utils.run_train).
batch = tf.Variable(0, dtype=tf.float32)

# Staircase exponential decay: starts at 1e-3 and is multiplied by 0.95 each
# time `batch * BATCH_SIZE` passes another multiple of dataset.train_size,
# i.e. roughly once per epoch.
learning_rate = tf.train.exponential_decay(
    learning_rate=1e-3,
    global_step=batch * BATCH_SIZE,
    decay_steps=dataset.train_size,
    decay_rate=0.95,
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)

train_config = {
    'optimizer': optimizer,
    'batch_var': batch,
    'learning_rate_var': learning_rate,
    'train_batch_size': BATCH_SIZE,
    'eval_batch_size': EVAL_BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'eval_frequency': EVAL_FREQUENCY,
}

history, stdout_lines = utils.run_train(
    apply, train_config, dataset,
    build_func_kwargs={'use_priors': use_priors})

Initialized!
Step 0 (epoch 0.00), 29.3 ms
Minibatch loss: 10.430, learning rate: 0.001000
Minibatch error: [75.0, 75.0, 75.0, 75.0, 75.0]
Validation error: [86.58, 86.58, 86.58, 86.58, 86.58]
Step 100 (epoch 0.12), 36.7 ms
Minibatch loss: 0.407, learning rate: 0.001000
Minibatch error: [1.5625, 1.5625, 1.5625, 1.5625, 1.5625]
Validation error: [5.640000000000001, 5.640000000000001, 5.640000000000001, 5.640000000000001, 5.640000000000001]
Step 200 (epoch 0.23), 36.6 ms
Minibatch loss: 0.460, learning rate: 0.001000
Minibatch error: [1.5625, 1.5625, 1.5625, 1.5625, 1.5625]
Validation error: [2.8400000000000034, 2.8400000000000034, 2.8400000000000034, 2.8400000000000034, 2.8400000000000034]
Step 300 (epoch 0.35), 36.8 ms
Minibatch loss: 0.467, learning rate: 0.001000
Minibatch error: [3.125, 3.125, 3.125, 3.125, 3.125]
Validation error: [2.519999999999996, 2.519999999999996, 2.519999999999996, 2.519999999999996, 2.519999999999996]
Step 400 (epoch 0.47), 36.7 ms
Minibatch loss: 0.669, lear

Step 3500 (epoch 4.07), 36.6 ms
Minibatch loss: 0.004, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.2000000000000028, 1.2000000000000028, 1.2000000000000028, 1.2000000000000028, 1.2000000000000028]
Step 3600 (epoch 4.19), 36.7 ms
Minibatch loss: 0.000, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.8400000000000034, 0.8400000000000034, 0.8400000000000034, 0.8400000000000034, 0.8400000000000034]
Step 3700 (epoch 4.31), 36.7 ms
Minibatch loss: 0.001, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.019999999999996, 1.019999999999996, 1.019999999999996, 1.019999999999996, 1.019999999999996]
Step 3800 (epoch 4.42), 36.8 ms
Minibatch loss: 0.008, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0999999999999943, 1.0999999999999943, 1.0999999999999943, 1.0999999999999943, 1.0999999999999943]
Step 3900 (epoch 4.54), 36.6 ms
Minibatch

Step 7000 (epoch 8.15), 36.7 ms
Minibatch loss: 0.002, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0799999999999983, 1.0799999999999983, 1.0799999999999983, 1.0799999999999983, 1.0799999999999983]
Step 7100 (epoch 8.26), 36.6 ms
Minibatch loss: 0.001, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.7600000000000051, 0.7600000000000051, 0.7600000000000051, 0.7600000000000051, 0.7600000000000051]
Step 7200 (epoch 8.38), 36.6 ms
Minibatch loss: 0.001, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.7800000000000011, 0.7800000000000011, 0.7800000000000011, 0.7800000000000011, 0.7800000000000011]
Step 7300 (epoch 8.49), 36.7 ms
Minibatch loss: 0.011, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.8799999999999955, 0.8799999999999955, 0.8799999999999955, 0.8799999999999955, 0.8799999999999955]
Step 7400 (epoch 8.61), 36.7 ms
Mini

## Use priors = True

In [7]:
use_priors = True

# Start from an empty default graph before building this experiment.
tf.reset_default_graph()

dataset = utils.get_dataset(DATASET)

# Step counter that drives the learning-rate schedule
# (presumably incremented once per minibatch inside utils.run_train).
batch = tf.Variable(0, dtype=tf.float32)

# Staircase exponential decay: starts at 1e-3 and is multiplied by 0.95 each
# time `batch * BATCH_SIZE` passes another multiple of dataset.train_size,
# i.e. roughly once per epoch.
learning_rate = tf.train.exponential_decay(
    learning_rate=1e-3,
    global_step=batch * BATCH_SIZE,
    decay_steps=dataset.train_size,
    decay_rate=0.95,
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)

train_config = {
    'optimizer': optimizer,
    'batch_var': batch,
    'learning_rate_var': learning_rate,
    'train_batch_size': BATCH_SIZE,
    'eval_batch_size': EVAL_BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'eval_frequency': EVAL_FREQUENCY,
}

history, stdout_lines = utils.run_train(
    apply, train_config, dataset,
    build_func_kwargs={'use_priors': use_priors})

Initialized!
Step 0 (epoch 0.00), 7.8 ms
Minibatch loss: 10.817, learning rate: 0.001000
Minibatch error: [85.9375, 85.9375, 85.9375, 85.9375, 85.9375]
Validation error: [90.52, 90.52, 90.52, 90.52, 90.52]
Step 100 (epoch 0.12), 47.7 ms
Minibatch loss: 0.399, learning rate: 0.001000
Minibatch error: [4.6875, 4.6875, 4.6875, 4.6875, 4.6875]
Validation error: [5.640000000000001, 5.599999999999994, 5.599999999999994, 5.599999999999994, 5.599999999999994]
Step 200 (epoch 0.23), 47.6 ms
Minibatch loss: 0.463, learning rate: 0.001000
Minibatch error: [4.6875, 4.6875, 4.6875, 4.6875, 4.6875]
Validation error: [3.0, 3.019999999999996, 3.019999999999996, 3.019999999999996, 3.019999999999996]
Step 300 (epoch 0.35), 47.6 ms
Minibatch loss: 0.558, learning rate: 0.001000
Minibatch error: [4.6875, 4.6875, 4.6875, 4.6875, 4.6875]
Validation error: [2.8400000000000034, 2.819999999999993, 2.819999999999993, 2.819999999999993, 2.819999999999993]
Step 400 (epoch 0.47), 47.5 ms
Minibatch loss: 0.808, lea

Step 3400 (epoch 3.96), 47.6 ms
Minibatch loss: 0.003, learning rate: 0.000857
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.3199999999999932, 1.2199999999999989, 1.2000000000000028, 1.2000000000000028, 1.2000000000000028]
Step 3500 (epoch 4.07), 47.7 ms
Minibatch loss: 0.025, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0400000000000063, 0.980000000000004, 0.980000000000004, 0.980000000000004, 0.980000000000004]
Step 3600 (epoch 4.19), 47.7 ms
Minibatch loss: 0.000, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.9000000000000057, 0.8400000000000034, 0.8199999999999932, 0.8199999999999932, 0.8199999999999932]
Step 3700 (epoch 4.31), 47.7 ms
Minibatch loss: 0.011, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.8799999999999955, 0.8799999999999955, 0.8799999999999955, 0.8799999999999955, 0.8799999999999955]
Step 3800 (epoch 4.42), 47.7 ms
Minibatc

Step 6900 (epoch 8.03), 47.6 ms
Minibatch loss: 0.001, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.9200000000000017, 0.9200000000000017, 0.9200000000000017, 0.9200000000000017, 0.9200000000000017]
Step 7000 (epoch 8.15), 47.6 ms
Minibatch loss: 0.004, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.8799999999999955, 0.8599999999999994, 0.8599999999999994, 0.8599999999999994, 0.8599999999999994]
Step 7100 (epoch 8.26), 47.7 ms
Minibatch loss: 0.006, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.9000000000000057, 0.8799999999999955, 0.8599999999999994, 0.8599999999999994, 0.8599999999999994]
Step 7200 (epoch 8.38), 47.7 ms
Minibatch loss: 0.003, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.7600000000000051, 0.7600000000000051, 0.7199999999999989, 0.7199999999999989, 0.7199999999999989]
Step 7300 (epoch 8.49), 47.6 ms
Mini

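### Introduced G: an additive bias computed from the prior (it shares the 100-unit hidden projection with the gate network F)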

In [8]:
def model_step(input_images, prior, batch_size, training, num_labels, use_priors):
    """Build one pass of the classifier with prior-driven gates and additive bias.

    Two conv(5x5, 'same', ReLU) + 2x2 max-pool stages are flattened. When
    `use_priors` is true, a single 100-unit ReLU projection of `prior` feeds
    two heads: sigmoid gates (multiplied element-wise into the features) and a
    linear bias (added afterwards). The result goes through a 1024-unit ReLU
    layer and a linear `num_labels`-wide output layer.

    Args:
        input_images: image batch fed to the first convolution (dims 1-3 are
            flattened after pooling, so presumably NHWC -- confirm with caller).
        prior: tensor the gates and bias are computed from; only read when
            `use_priors` is true.
        batch_size: accepted for interface compatibility; not used here.
        training: accepted for interface compatibility; not used here.
        num_labels: width of the output (logits) layer.
        use_priors: when false, the modulation sub-network is not built at all.

    Returns:
        Tuple `(logits, posteriors)` with `posteriors = softmax(logits)`.
    """
    # NOTE(review): the layers are unnamed, so variable sharing under the
    # caller's reused scope presumably relies on auto-generated names; keep
    # the creation order below stable.
    features = input_images
    for num_filters in (32, 64):
        features = tf.layers.conv2d(
            inputs=features,
            filters=num_filters,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        features = tf.layers.max_pooling2d(inputs=features, pool_size=[2, 2], strides=2)

    # Flatten everything except the leading (batch) dimension.
    dims = features.get_shape()
    flat_width = dims[1] * dims[2] * dims[3]
    flat = tf.reshape(features, [-1, flat_width])

    if not use_priors:
        modulated = flat
    else:
        # Shared hidden projection of the prior, then one head per modulation term.
        hidden = tf.layers.dense(inputs=prior, units=100, activation=tf.nn.relu)
        gates = tf.layers.dense(inputs=hidden, units=flat_width, activation=tf.nn.sigmoid)
        offset = tf.layers.dense(inputs=hidden, units=flat_width, activation=None)
        modulated = tf.multiply(flat, gates) + offset

    fully_connected = tf.layers.dense(inputs=modulated, units=1024, activation=tf.nn.relu)
    logits = tf.layers.dense(inputs=fully_connected, units=num_labels)
    return logits, tf.nn.softmax(logits)

def apply(input_images, training, train_labels_node, num_labels, use_priors):
    """Unroll `model_step` NUM_UNROLL_STEPS times, feeding posteriors back as priors.

    The first step receives a uniform prior (1 / num_labels for every class);
    each later step receives the previous step's posteriors. All steps are
    built under the variable scope 'one_step', with `reuse` enabled from the
    second step onwards.

    Args:
        input_images: image batch; its static leading dimension is used as the
            batch size (NOTE(review): presumably fixed by the caller -- confirm
            in utils.run_train).
        training: forwarded unchanged to `model_step`.
        train_labels_node: integer class labels used for the cross-entropy.
        num_labels: number of classes.
        use_priors: forwarded unchanged to `model_step`.

    Returns:
        Tuple `(stacked_logits, loss)`: the per-step logits stacked along a new
        leading axis, and the sum over steps of the batch-mean sparse softmax
        cross-entropy.
    """
    batch_size = input_images.get_shape()[0]
    priors = tf.ones((batch_size, num_labels)) / num_labels

    total_loss = 0.0
    logits_per_step = []
    for step in range(NUM_UNROLL_STEPS):
        share_weights = step > 0
        with tf.variable_scope('one_step', reuse=share_weights):
            step_logits, step_posteriors = model_step(
                input_images, priors, batch_size,
                training=training, num_labels=num_labels,
                use_priors=use_priors)
        # This step's belief becomes the next step's prior.
        priors = step_posteriors
        logits_per_step.append(step_logits)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=train_labels_node, logits=step_logits)
        total_loss += tf.reduce_mean(cross_entropy)
    return tf.stack(logits_per_step), total_loss

In [9]:
use_priors = True

# Start from an empty default graph before building this experiment.
tf.reset_default_graph()

dataset = utils.get_dataset(DATASET)

# Step counter that drives the learning-rate schedule
# (presumably incremented once per minibatch inside utils.run_train).
batch = tf.Variable(0, dtype=tf.float32)

# Staircase exponential decay: starts at 1e-3 and is multiplied by 0.95 each
# time `batch * BATCH_SIZE` passes another multiple of dataset.train_size,
# i.e. roughly once per epoch.
learning_rate = tf.train.exponential_decay(
    learning_rate=1e-3,
    global_step=batch * BATCH_SIZE,
    decay_steps=dataset.train_size,
    decay_rate=0.95,
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)

train_config = {
    'optimizer': optimizer,
    'batch_var': batch,
    'learning_rate_var': learning_rate,
    'train_batch_size': BATCH_SIZE,
    'eval_batch_size': EVAL_BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'eval_frequency': EVAL_FREQUENCY,
}

history, stdout_lines = utils.run_train(
    apply, train_config, dataset,
    build_func_kwargs={'use_priors': use_priors})

Initialized!
Step 0 (epoch 0.00), 8.8 ms
Minibatch loss: 10.920, learning rate: 0.001000
Minibatch error: [85.9375, 85.9375, 85.9375, 85.9375, 85.9375]
Validation error: [90.76, 90.76, 90.76, 90.76, 90.76]
Step 100 (epoch 0.12), 51.5 ms
Minibatch loss: 0.314, learning rate: 0.001000
Minibatch error: [1.5625, 1.5625, 1.5625, 1.5625, 1.5625]
Validation error: [6.040000000000006, 6.299999999999997, 6.359999999999999, 6.340000000000003, 6.3799999999999955]
Step 200 (epoch 0.23), 51.5 ms
Minibatch loss: 0.599, learning rate: 0.001000
Minibatch error: [3.125, 3.125, 3.125, 3.125, 3.125]
Validation error: [3.3799999999999955, 3.3799999999999955, 3.4200000000000017, 3.4399999999999977, 3.4399999999999977]
Step 300 (epoch 0.35), 51.4 ms
Minibatch loss: 0.427, learning rate: 0.001000
Minibatch error: [3.125, 3.125, 3.125, 3.125, 3.125]
Validation error: [2.8400000000000034, 2.799999999999997, 2.9000000000000057, 2.9000000000000057, 2.9200000000000017]
Step 400 (epoch 0.47), 51.4 ms
Minibatch los

Step 3400 (epoch 3.96), 51.4 ms
Minibatch loss: 0.035, learning rate: 0.000857
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.3199999999999932, 1.3199999999999932, 1.3199999999999932, 1.3199999999999932, 1.3199999999999932]
Step 3500 (epoch 4.07), 51.5 ms
Minibatch loss: 0.011, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.1400000000000006, 1.1599999999999966, 1.1800000000000068, 1.1800000000000068, 1.1800000000000068]
Step 3600 (epoch 4.19), 51.5 ms
Minibatch loss: 0.000, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.9000000000000057, 0.8599999999999994, 0.8599999999999994, 0.8599999999999994, 0.8599999999999994]
Step 3700 (epoch 4.31), 51.4 ms
Minibatch loss: 0.004, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.3199999999999932, 1.3599999999999994, 1.3599999999999994, 1.3599999999999994, 1.3599999999999994]
Step 3800 (epoch 4.42), 51.4 ms
Mini

Step 6900 (epoch 8.03), 51.4 ms
Minibatch loss: 0.006, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0, 1.0, 1.0, 1.0, 1.0]
Step 7000 (epoch 8.15), 51.4 ms
Minibatch loss: 0.002, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.9000000000000057, 0.9200000000000017, 0.9200000000000017, 0.9200000000000017, 0.9200000000000017]
Step 7100 (epoch 8.26), 51.5 ms
Minibatch loss: 0.000, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.7000000000000028, 0.7000000000000028, 0.7000000000000028, 0.7000000000000028, 0.7000000000000028]
Step 7200 (epoch 8.38), 51.5 ms
Minibatch loss: 0.000, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.9399999999999977, 0.9399999999999977, 0.9399999999999977, 0.9399999999999977, 0.9399999999999977]
Step 7300 (epoch 8.49), 51.4 ms
Minibatch loss: 0.000, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0,

## Separate 2-layer networks for F and G

In [10]:
def model_step(input_images, prior, batch_size, training, num_labels, use_priors):
    """Build one pass of the classifier with independent gate and bias networks.

    Two conv(5x5, 'same', ReLU) + 2x2 max-pool stages are flattened. When
    `use_priors` is true, two separate 2-layer networks read `prior`: one ends
    in sigmoid gates (multiplied element-wise into the features), the other in
    a linear bias (added afterwards). Each has its own 100-unit ReLU hidden
    layer. The result goes through a 1024-unit ReLU layer and a linear
    `num_labels`-wide output layer.

    Args:
        input_images: image batch fed to the first convolution (dims 1-3 are
            flattened after pooling, so presumably NHWC -- confirm with caller).
        prior: tensor the gates and bias are computed from; only read when
            `use_priors` is true.
        batch_size: accepted for interface compatibility; not used here.
        training: accepted for interface compatibility; not used here.
        num_labels: width of the output (logits) layer.
        use_priors: when false, the modulation sub-networks are not built at all.

    Returns:
        Tuple `(logits, posteriors)` with `posteriors = softmax(logits)`.
    """
    # NOTE(review): the layers are unnamed, so variable sharing under the
    # caller's reused scope presumably relies on auto-generated names; keep
    # the creation order below stable.
    features = input_images
    for num_filters in (32, 64):
        features = tf.layers.conv2d(
            inputs=features,
            filters=num_filters,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        features = tf.layers.max_pooling2d(inputs=features, pool_size=[2, 2], strides=2)

    # Flatten everything except the leading (batch) dimension.
    dims = features.get_shape()
    flat_width = dims[1] * dims[2] * dims[3]
    flat = tf.reshape(features, [-1, flat_width])

    if not use_priors:
        modulated = flat
    else:
        # Gate network: prior -> 100 ReLU -> sigmoid gate per feature.
        gate_hidden = tf.layers.dense(inputs=prior, units=100, activation=tf.nn.relu)
        gates = tf.layers.dense(inputs=gate_hidden, units=flat_width, activation=tf.nn.sigmoid)
        # Bias network: prior -> 100 ReLU -> linear offset per feature.
        offset_hidden = tf.layers.dense(inputs=prior, units=100, activation=tf.nn.relu)
        offset = tf.layers.dense(inputs=offset_hidden, units=flat_width, activation=None)
        modulated = tf.multiply(flat, gates) + offset

    fully_connected = tf.layers.dense(inputs=modulated, units=1024, activation=tf.nn.relu)
    logits = tf.layers.dense(inputs=fully_connected, units=num_labels)
    return logits, tf.nn.softmax(logits)

def apply(input_images, training, train_labels_node, num_labels, use_priors):
    """Unroll `model_step` NUM_UNROLL_STEPS times, feeding posteriors back as priors.

    The first step receives a uniform prior (1 / num_labels for every class);
    each later step receives the previous step's posteriors. All steps are
    built under the variable scope 'one_step', with `reuse` enabled from the
    second step onwards.

    Args:
        input_images: image batch; its static leading dimension is used as the
            batch size (NOTE(review): presumably fixed by the caller -- confirm
            in utils.run_train).
        training: forwarded unchanged to `model_step`.
        train_labels_node: integer class labels used for the cross-entropy.
        num_labels: number of classes.
        use_priors: forwarded unchanged to `model_step`.

    Returns:
        Tuple `(stacked_logits, loss)`: the per-step logits stacked along a new
        leading axis, and the sum over steps of the batch-mean sparse softmax
        cross-entropy.
    """
    batch_size = input_images.get_shape()[0]
    priors = tf.ones((batch_size, num_labels)) / num_labels

    total_loss = 0.0
    logits_per_step = []
    for step in range(NUM_UNROLL_STEPS):
        share_weights = step > 0
        with tf.variable_scope('one_step', reuse=share_weights):
            step_logits, step_posteriors = model_step(
                input_images, priors, batch_size,
                training=training, num_labels=num_labels,
                use_priors=use_priors)
        # This step's belief becomes the next step's prior.
        priors = step_posteriors
        logits_per_step.append(step_logits)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=train_labels_node, logits=step_logits)
        total_loss += tf.reduce_mean(cross_entropy)
    return tf.stack(logits_per_step), total_loss

In [11]:
use_priors = True

# Start from an empty default graph before building this experiment.
tf.reset_default_graph()

dataset = utils.get_dataset(DATASET)

# Step counter that drives the learning-rate schedule
# (presumably incremented once per minibatch inside utils.run_train).
batch = tf.Variable(0, dtype=tf.float32)

# Staircase exponential decay: starts at 1e-3 and is multiplied by 0.95 each
# time `batch * BATCH_SIZE` passes another multiple of dataset.train_size,
# i.e. roughly once per epoch.
learning_rate = tf.train.exponential_decay(
    learning_rate=1e-3,
    global_step=batch * BATCH_SIZE,
    decay_steps=dataset.train_size,
    decay_rate=0.95,
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)

train_config = {
    'optimizer': optimizer,
    'batch_var': batch,
    'learning_rate_var': learning_rate,
    'train_batch_size': BATCH_SIZE,
    'eval_batch_size': EVAL_BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'eval_frequency': EVAL_FREQUENCY,
}

history, stdout_lines = utils.run_train(
    apply, train_config, dataset,
    build_func_kwargs={'use_priors': use_priors})

Initialized!
Step 0 (epoch 0.00), 9.0 ms
Minibatch loss: 10.909, learning rate: 0.001000
Minibatch error: [85.9375, 85.9375, 85.9375, 85.9375, 85.9375]
Validation error: [90.76, 90.76, 90.76, 90.76, 90.76]
Step 100 (epoch 0.12), 52.3 ms
Minibatch loss: 0.234, learning rate: 0.001000
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [5.3799999999999955, 5.140000000000001, 5.099999999999994, 5.1200000000000045, 5.140000000000001]
Step 200 (epoch 0.23), 52.3 ms
Minibatch loss: 0.482, learning rate: 0.001000
Minibatch error: [4.6875, 6.25, 3.125, 3.125, 3.125]
Validation error: [3.5999999999999943, 3.819999999999993, 3.680000000000007, 3.6599999999999966, 3.6599999999999966]
Step 300 (epoch 0.35), 52.4 ms
Minibatch loss: 0.460, learning rate: 0.001000
Minibatch error: [4.6875, 3.125, 3.125, 3.125, 3.125]
Validation error: [2.9000000000000057, 2.780000000000001, 2.819999999999993, 2.819999999999993, 2.799999999999997]
Step 400 (epoch 0.47), 52.3 ms
Minibatch loss: 0.777, learning

Step 3500 (epoch 4.07), 52.3 ms
Minibatch loss: 0.008, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.980000000000004, 1.019999999999996, 1.0, 1.0, 1.0]
Step 3600 (epoch 4.19), 52.3 ms
Minibatch loss: 0.008, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0799999999999983, 1.1200000000000045, 1.1200000000000045, 1.1200000000000045, 1.1200000000000045]
Step 3700 (epoch 4.31), 52.3 ms
Minibatch loss: 0.000, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.1400000000000006, 1.1599999999999966, 1.1599999999999966, 1.1599999999999966, 1.1599999999999966]
Step 3800 (epoch 4.42), 52.3 ms
Minibatch loss: 0.017, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0999999999999943, 1.0600000000000023, 1.0799999999999983, 1.0799999999999983, 1.0799999999999983]
Step 3900 (epoch 4.54), 52.2 ms
Minibatch loss: 0.076, learning rate: 0.000815
Mini

Step 7100 (epoch 8.26), 52.3 ms
Minibatch loss: 0.000, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.8599999999999994, 0.7800000000000011, 0.8199999999999932, 0.7800000000000011, 0.7999999999999972]
Step 7200 (epoch 8.38), 52.3 ms
Minibatch loss: 0.011, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.019999999999996, 1.0600000000000023, 1.0400000000000063, 1.0400000000000063, 1.0400000000000063]
Step 7300 (epoch 8.49), 52.4 ms
Minibatch loss: 0.000, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.6800000000000068, 0.6599999999999966, 0.6599999999999966, 0.6599999999999966, 0.6599999999999966]
Step 7400 (epoch 8.61), 52.3 ms
Minibatch loss: 0.000, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.7399999999999949, 0.7800000000000011, 0.7800000000000011, 0.7600000000000051, 0.7800000000000011]
Step 7500 (epoch 8.73), 52.3 ms
Minib

In [12]:
def model_step(input_images, prior, batch_size, training, num_labels, use_priors):
    """Build one pass of the classifier with prior-driven modulation after every layer.

    When `use_priors` is true, the output of each convolution (per channel),
    the flattened pooled features and the 1024-unit dense layer (per unit) are
    each transformed as `x * gates + bias`. `gates` (sigmoid) and `bias`
    (linear) come from two parallel dense networks over `prior`, each with one
    100-unit ReLU hidden layer. When `use_priors` is false the tensors pass
    through unchanged and no modulation layers are built.

    Args:
        input_images: image batch fed to the first convolution.
        prior: tensor the gates and biases are computed from; only read when
            `use_priors` is true.
        batch_size: accepted for interface compatibility; not used here.
        training: accepted for interface compatibility; not used here.
        num_labels: width of the output (logits) layer.
        use_priors: enables or disables every modulation site.

    Returns:
        Tuple `(logits, posteriors)` with `posteriors = softmax(logits)`.
    """
    # NOTE(review): the layers are unnamed, so variable sharing under the
    # caller's reused scope presumably relies on auto-generated names; keep
    # the creation order below stable.

    def build_modulators(widths):
        # Two parallel dense stacks over the prior: ReLU hidden layers for all
        # widths but the last, then a sigmoid head (gates) and a linear head (bias).
        hidden_widths, out_width = widths[:-1], widths[-1]
        gate_net = prior
        bias_net = prior
        for width in hidden_widths:
            gate_net = tf.layers.dense(inputs=gate_net, units=width, activation=tf.nn.relu)
            bias_net = tf.layers.dense(inputs=bias_net, units=width, activation=tf.nn.relu)
        gate_net = tf.layers.dense(inputs=gate_net, units=out_width, activation=tf.nn.sigmoid)
        bias_net = tf.layers.dense(inputs=bias_net, units=out_width, activation=None)
        return gate_net, bias_net

    def add_spatial_axes(per_channel):
        # Insert two size-1 axes at position 1 so the per-channel values
        # broadcast against a rank-4 feature map.
        return tf.expand_dims(tf.expand_dims(per_channel, axis=1), axis=1)

    def modulate_feature_maps(feature_maps, hidden_sizes):
        # Per-channel gate and bias on a convolutional output.
        if use_priors:
            channels = feature_maps.get_shape()[-1]
            gates, bias = build_modulators(hidden_sizes + (channels, ))
            gates = add_spatial_axes(gates)
            bias = add_spatial_axes(bias)
            return tf.multiply(feature_maps, gates) + bias
        return feature_maps

    def modulate_units(activations, hidden_sizes):
        # Per-unit gate and bias on a rank-2 (batch, units) tensor.
        if use_priors:
            units = activations.get_shape()[1]
            gates, bias = build_modulators(hidden_sizes + (units, ))
            return tf.multiply(activations, gates) + bias
        return activations

    net = input_images
    for num_filters in (32, 64):
        net = tf.layers.conv2d(
            inputs=net,
            filters=num_filters,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        net = modulate_feature_maps(net, (100,))
        net = tf.layers.max_pooling2d(inputs=net, pool_size=[2, 2], strides=2)

    # Flatten using the static batch dimension and let reshape infer the rest.
    net = tf.reshape(net, [net.get_shape()[0], -1])
    net = modulate_units(net, (100,))

    net = tf.layers.dense(inputs=net, units=1024, activation=tf.nn.relu)
    net = modulate_units(net, (100,))

    logits = tf.layers.dense(inputs=net, units=num_labels)
    return logits, tf.nn.softmax(logits)

def apply(input_images, training, train_labels_node, num_labels, use_priors):
    """Unroll `model_step` NUM_UNROLL_STEPS times, feeding posteriors back as priors.

    The first step receives a uniform prior (1 / num_labels for every class);
    each later step receives the previous step's posteriors. All steps are
    built under the variable scope 'one_step', with `reuse` enabled from the
    second step onwards.

    Args:
        input_images: image batch; its static leading dimension is used as the
            batch size (NOTE(review): presumably fixed by the caller -- confirm
            in utils.run_train).
        training: forwarded unchanged to `model_step`.
        train_labels_node: integer class labels used for the cross-entropy.
        num_labels: number of classes.
        use_priors: forwarded unchanged to `model_step`.

    Returns:
        Tuple `(stacked_logits, loss)`: the per-step logits stacked along a new
        leading axis, and the sum over steps of the batch-mean sparse softmax
        cross-entropy.
    """
    batch_size = input_images.get_shape()[0]
    priors = tf.ones((batch_size, num_labels)) / num_labels

    total_loss = 0.0
    logits_per_step = []
    for step in range(NUM_UNROLL_STEPS):
        share_weights = step > 0
        with tf.variable_scope('one_step', reuse=share_weights):
            step_logits, step_posteriors = model_step(
                input_images, priors, batch_size,
                training=training, num_labels=num_labels,
                use_priors=use_priors)
        # This step's belief becomes the next step's prior.
        priors = step_posteriors
        logits_per_step.append(step_logits)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=train_labels_node, logits=step_logits)
        total_loss += tf.reduce_mean(cross_entropy)
    return tf.stack(logits_per_step), total_loss

In [13]:
use_priors = True

# Start from an empty default graph before building this experiment.
tf.reset_default_graph()

dataset = utils.get_dataset(DATASET)

# Step counter that drives the learning-rate schedule
# (presumably incremented once per minibatch inside utils.run_train).
batch = tf.Variable(0, dtype=tf.float32)

# Staircase exponential decay: starts at 1e-3 and is multiplied by 0.95 each
# time `batch * BATCH_SIZE` passes another multiple of dataset.train_size,
# i.e. roughly once per epoch.
learning_rate = tf.train.exponential_decay(
    learning_rate=1e-3,
    global_step=batch * BATCH_SIZE,
    decay_steps=dataset.train_size,
    decay_rate=0.95,
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)

train_config = {
    'optimizer': optimizer,
    'batch_var': batch,
    'learning_rate_var': learning_rate,
    'train_batch_size': BATCH_SIZE,
    'eval_batch_size': EVAL_BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'eval_frequency': EVAL_FREQUENCY,
}

history, stdout_lines = utils.run_train(
    apply, train_config, dataset,
    build_func_kwargs={'use_priors': use_priors})

Initialized!
Step 0 (epoch 0.00), 21.1 ms
Minibatch loss: 11.311, learning rate: 0.001000
Minibatch error: [85.9375, 85.9375, 85.9375, 85.9375, 85.9375]
Validation error: [90.76, 90.76, 90.76, 90.76, 90.76]
Step 100 (epoch 0.12), 85.5 ms
Minibatch loss: 0.616, learning rate: 0.001000
Minibatch error: [1.5625, 3.125, 3.125, 3.125, 3.125]
Validation error: [9.099999999999994, 8.939999999999998, 8.900000000000006, 9.019999999999996, 9.060000000000002]
Step 200 (epoch 0.23), 85.2 ms
Minibatch loss: 0.990, learning rate: 0.001000
Minibatch error: [7.8125, 6.25, 6.25, 6.25, 6.25]
Validation error: [4.739999999999995, 4.579999999999998, 4.439999999999998, 4.480000000000004, 4.480000000000004]
Step 300 (epoch 0.35), 85.3 ms
Minibatch loss: 0.759, learning rate: 0.001000
Minibatch error: [7.8125, 6.25, 6.25, 6.25, 6.25]
Validation error: [3.4599999999999937, 3.5, 3.519999999999996, 3.480000000000004, 3.4399999999999977]
Step 400 (epoch 0.47), 85.2 ms
Minibatch loss: 1.004, learning rate: 0.0010

Step 3500 (epoch 4.07), 85.3 ms
Minibatch loss: 0.026, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [2.8400000000000034, 2.9399999999999977, 2.8400000000000034, 2.8799999999999955, 2.8400000000000034]
Step 3600 (epoch 4.19), 85.5 ms
Minibatch loss: 0.000, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0799999999999983, 1.1599999999999966, 1.0799999999999983, 1.1599999999999966, 1.1200000000000045]
Step 3700 (epoch 4.31), 85.3 ms
Minibatch loss: 0.002, learning rate: 0.000815
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.1400000000000006, 1.1200000000000045, 1.1599999999999966, 1.1200000000000045, 1.1400000000000006]
Step 3800 (epoch 4.42), 85.3 ms
Minibatch loss: 0.037, learning rate: 0.000815
Minibatch error: [0.0, 1.5625, 0.0, 0.0, 0.0]
Validation error: [1.0999999999999943, 1.0400000000000063, 1.0600000000000023, 1.0400000000000063, 1.0600000000000023]
Step 3900 (epoch 4.54), 85.4 ms
M

Step 7000 (epoch 8.15), 85.5 ms
Minibatch loss: 0.006, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0400000000000063, 1.0799999999999983, 1.019999999999996, 1.0600000000000023, 1.0400000000000063]
Step 7100 (epoch 8.26), 85.6 ms
Minibatch loss: 0.012, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.8199999999999932, 0.7999999999999972, 0.7399999999999949, 0.7800000000000011, 0.7600000000000051]
Step 7200 (epoch 8.38), 85.4 ms
Minibatch loss: 0.040, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [1.0400000000000063, 0.980000000000004, 0.980000000000004, 0.9599999999999937, 0.980000000000004]
Step 7300 (epoch 8.49), 85.3 ms
Minibatch loss: 0.004, learning rate: 0.000663
Minibatch error: [0.0, 0.0, 0.0, 0.0, 0.0]
Validation error: [0.980000000000004, 0.9200000000000017, 0.8599999999999994, 0.9000000000000057, 0.9000000000000057]
Step 7400 (epoch 8.61), 85.3 ms
Minibatch