In [1]:
import cPickle as pickle
import matplotlib.pylab as pl
%matplotlib inline
import numpy as np
import mxnet as mx
import mx_layers as layers

In [2]:
# Depth of the residual tower: total number of residual blocks, all of which
# apply the single shared 3x3 convolution weight built in the network cell.
n_residual_layers = 4

In [3]:
# Load the saved training summary for the configured depth. The file name suffix
# is derived from n_residual_layers instead of being hard-coded, and the file
# handle is closed deterministically by the `with` block.
# NOTE(review): pickle.load can execute arbitrary code - only open checkpoint
# files you produced yourself.
with open('info/residual-network-on-shrinked-mnist-%d' % n_residual_layers, 'rb') as info_file:
  accuracy, progress = pickle.load(info_file)
accuracy

0.5215963375796179

In [4]:
# Load the trained weights for the configured depth. `parameters` and `states`
# are dict-like (their .items() are converted to NDArrays in the binding cell);
# presumably learnable arguments and auxiliary (batch-norm) states - confirm.
# The file is closed by the `with` block instead of leaking the handle.
# NOTE(review): pickle.load can execute arbitrary code - only open checkpoint
# files you produced yourself.
with open('parameters/residual-network-on-shrinked-mnist-%d' % n_residual_layers, 'rb') as parameter_file:
  parameters, states = pickle.load(parameter_file)

In [5]:
def _normalized_convolution(**conv_kwargs):
  """Bias-free convolution followed by batch normalization and ReLU.

  All keyword arguments (input `X`, `n_filters`, `kernel_shape`, ...) are
  forwarded unchanged to `layers.convolution`, which is always called with
  `no_bias=True`. Batch normalization is applied with `fix_gamma=False`, so
  its scale is learned.
  """
  convolved = layers.convolution(no_bias=True, **conv_kwargs)
  normalized = layers.batch_normalization(convolved, fix_gamma=False)
  return layers.ReLU(normalized)

In [6]:
network = layers.variable('data')
# Stem: three stages of 5x5 conv (16 filters) -> BN -> ReLU, each followed by
# 2x2 stride-2 max pooling.
for index in range(3):
  network = _normalized_convolution(X=network, n_filters=16, kernel_shape=(5, 5), stride=(1, 1), pad=(2, 2))
  network = layers.pooling(X=network, mode='maximum', kernel_shape=(2, 2), stride=(2, 2), pad=(0, 0))

# One convolution weight and one batch-norm gamma/beta pair are shared by every
# residual block below.
shared_weight = layers.variable('shared_weight')
shared_gamma = layers.variable('shared_gamma')
shared_beta = layers.variable('shared_beta')
kwargs = {'n_filters' : 16, 'kernel_shape' : (3, 3), 'stride' : (1, 1), 'pad' : (1, 1)}

# First residual block: shared 3x3 conv on the stem output plus identity skip
# (no BN/ReLU in front of it).
identity = network
residual = layers.convolution(X=network, weight=shared_weight, no_bias=True, **kwargs)
network = identity + residual

# Remaining blocks: BN (shared gamma/beta) -> ReLU -> shared conv, plus identity
# skip. The depth comes from the notebook-level `n_residual_layers` setting;
# there is no `args` namespace object in this notebook (`args` is later a dict
# of NDArrays), so it must not be read as `args.n_residual_layers`.
for index in range(n_residual_layers - 1):
  network = layers.batch_normalization(network, beta=shared_beta, gamma=shared_gamma, fix_gamma=False)
  network = layers.ReLU(network)
  identity = network
  residual = layers.convolution(X=network, weight=shared_weight, no_bias=True, **kwargs)
  network = identity + residual

# Head: global average pooling -> flatten -> fully connected layer with 10 outputs.
network = layers.pooling(X=network, mode='average', global_pool=True, kernel_shape=(1, 1), stride=(1, 1), pad=(0, 0))
network = layers.flatten(network)
network = layers.fully_connected(X=network, n_hidden_units=10)

In [7]:
def score_variance(scores, axis=1):
    """Average per-sample variance of output scores.

    Takes the population variance (mean squared deviation from the mean) of
    `scores` along `axis`, then averages those variances over all remaining
    entries.

    Args:
        scores: array of output scores. With the default axis=1 each row is one
            sample and each column one output unit (presumably the
            (batch, n_outputs) NDArray returned by `executor.forward()`).
        axis: axis along which the variance is computed (default 1, matching
            the previous hard-coded behaviour).

    Returns:
        The mean variance as a single-element array; call `.asscalar()` on it
        to obtain a Python float.
    """
    # keepdims=True so the mean broadcasts back against `scores`.
    mean = mx.nd.mean(scores, axis=axis, keepdims=True)
    variance = mx.nd.mean((scores - mean) ** 2, axis=axis, keepdims=True)
    return mx.nd.mean(variance)

In [8]:
from data_utilities import load_mnist

# Both MNIST variants are loaded with identical loader settings; only `path`
# differs. (Presumably `shape` is the per-sample (channels, height, width) -
# confirm in data_utilities.)
loader_options = {'scale': 1, 'shape': (1, 56, 56)}
stretched_mnist = load_mnist(path='stretched_mnist', **loader_options)
stretched_canvas_mnist = load_mnist(path='stretched_canvas_mnist', **loader_options)
del loader_options

In [9]:
context = mx.cpu()
# Wrap every unpickled array as an NDArray on `context`. Learnable arguments go
# into `args` (which also receives the 'data' input before binding);
# auxiliary states go into `aux_states`.
args = dict((name, mx.nd.array(value, context)) for name, value in parameters.items())
aux_states = dict((name, mx.nd.array(value, context)) for name, value in states.items())

In [10]:
# Attach element 4 of the stretched-MNIST tuple as the network input
# (presumably an evaluation split - confirm against data_utilities.load_mnist),
# then bind the network once on `context`. Later cells write new batches into
# this same args['data'] array in place instead of re-binding.
args['data'] = mx.nd.array(stretched_mnist[4], context)
executor = network.bind(context, args, aux_states=aux_states)

In [11]:
# Inference-mode forward pass; take the executor's first output (the
# fully-connected scores) and report their mean per-sample variance.
scores = executor.forward(is_train=False)[0]
score_variance(scores).asscalar()

79.215363

In [12]:
# Copy element 2 of the stretched-MNIST tuple into the array the executor was
# bound to (in place), then score it with an inference-mode forward pass.
args['data'][:] = mx.nd.array(stretched_mnist[2], context)
scores = executor.forward(is_train=False)[0]
score_variance(scores).asscalar()

80.613609

In [13]:
# Same measurement on element 4 of the stretched-canvas variant, written in
# place into the bound input array.
args['data'][:] = mx.nd.array(stretched_canvas_mnist[4], context)
scores = executor.forward(is_train=False)[0]
score_variance(scores).asscalar()

22.737444

In [14]:
# Same measurement on element 2 of the stretched-canvas variant, written in
# place into the bound input array.
args['data'][:] = mx.nd.array(stretched_canvas_mnist[2], context)
scores = executor.forward(is_train=False)[0]
score_variance(scores).asscalar()

21.999304