In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from contextlib import contextmanager

from rnndatasets import sequentialmnist as mnist

from itertools import chain

Can we make a network that starts small and grows by splitting its layers?

In [2]:
@contextmanager
def new_collection(name):
    old_tvars = tf.trainable_variables()
    yield
    for var in tf.trainable_variables():
        if var not in old_tvars:
            tf.add_to_collection(name, var)

In [3]:
def affine(input_var, new_size, weights_initialiser=None, bias_initialiser=None, return_weights=False):
    input_size = input_var.get_shape()[1].value
    
    if type(weights_initialiser) == np.ndarray:
        weight_shape = None
    else:
        weight_shape = [input_size, new_size]
    
    if type(bias_initialiser) == np.ndarray:
        bias_shape = None
    else:
        bias_shape = [new_size]
    
    weights = tf.get_variable('weights', weight_shape,
                             initializer=weights_initialiser)
    bias = tf.get_variable('bias', bias_shape,
                           initializer=bias_initialiser)
    results = tf.nn.bias_add(tf.matmul(input_var, weights), bias)
    
    if return_weights:
        return results, (weights, bias)
    return results

In [4]:
def initial_net(input_var, output_shape):
    return affine(input_var, output_shape)

In [5]:
def batch_iter(data, labels, batch_size):
    num_batches = data.shape[0] // batch_size
    
    for i in range(num_batches):
        yield data[i*batch_size:(i+1)*batch_size, ...], labels[i*batch_size:(i+1)*batch_size]

In [6]:
def lennox(activations, sess=None):
    num_features = activations.get_shape()[1].value
    a = tf.get_variable('a', [num_features], initializer=tf.constant_initializer(1.0))
    b = tf.get_variable('b', [num_features], initializer=tf.constant_initializer(1.0))
    c = tf.get_variable('c', [num_features], initializer=tf.constant_initializer(0.0))
    
    centered = activations - c
    
    return tf.select(centered > 0, a * centered, b * centered)

In [7]:
def get_metrics(net, targets):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(net, targets)
    loss = tf.reduce_mean(loss)
    acc = tf.contrib.metrics.accuracy(tf.cast(tf.argmax(net_out, 1), tf.int32),
                                      targets)
    return loss, acc

def split(session, split_layer, output_layer, layer_in, targets, scope):
    # figure out the new value in numpy to keep the graph clean
    weights_var, bias_var = split_layer
    weights_val, bias_val = session.run([weights_var, bias_var])
    u, s, vT = np.linalg.svd(weights_val, full_matrices=False)
    print(u.shape, s.shape, vT.shape)
    
    with tf.variable_scope('model'):
        with new_collection('weights'):
            with tf.variable_scope(scope + '_split_1'):
                hiddens, hh_vars = affine(layer_in, len(s), weights_initialiser=u, bias_initialiser=bias_val,
                                                 return_weights=True)
            with tf.variable_scope(scope + '_split_2'):
                net_out, split_vars = affine(lennox(hiddens), len(s), weights_initialiser=np.dot(np.diag(s), vT),
                                             bias_initialiser=tf.constant_initializer(0.0), return_weights=True)
        with tf.variable_scope('output', reuse=True):
            net_out = affine(lennox(net_out), 10)
        
    loss, acc = get_metrics(net_out, targets)
    
    opt = tf.train.AdamOptimizer(0.001)
    train_op = opt.minimize(loss)
    
    return train_op, split_vars, loss, acc
    

In [8]:
tf.reset_default_graph()

inputs = tf.placeholder(tf.float32, [None, 784])
targets = tf.placeholder(tf.int32, [None])

with tf.variable_scope('model'):
    with new_collection('weights'):
        net_out, split_layer = affine(inputs, 50, return_weights=True)
        with tf.variable_scope('output'):
            net_out, output_layer = affine(lennox(net_out), 10, return_weights=True)

loss, acc = get_metrics(net_out, targets)

opt = tf.train.AdamOptimizer(0.001)
train_op = opt.minimize(loss, var_list=tf.get_collection('weights'))

In [9]:
sess = tf.Session()
sess.run(tf.initialize_all_variables())

losses = []
valid_losses = []

for epoch in range(500):
    train, valid, test = mnist.get_iters(batch_size=64, shuffle=True)
    
    epoch_loss = 0
    epoch_steps = 0
    for dbatch, tbatch in train:
        dbatch = dbatch.transpose((1, 0, 2))
        batch_loss, _ = sess.run([loss, train_op],
                                 {inputs: dbatch.reshape((-1, 784)),
                                  targets: tbatch})
        epoch_loss += batch_loss
        epoch_steps += 1
        
    if epoch % 1 == 0:
        valid_loss = 0
        valid_steps = 0
        for dbatch, tbatch in valid:
            dbatch = dbatch.transpose((1, 0, 2))
            batch_loss = sess.run(acc,
                                 {inputs: dbatch.reshape((-1, 784)),
                                  targets: tbatch})
            valid_loss += batch_loss
            valid_steps += 1
            
        valid_losses.append(valid_loss/valid_steps)
    
    
    print('\r({}) ~~ {}  (valid: {})'.format(epoch, epoch_loss/epoch_steps, valid_loss/valid_steps), end='')
    losses.append(epoch_loss/epoch_steps)
    
    if len(valid_losses) > 6 and valid_loss/valid_steps <= valid_losses[-1]:
        print('\ntime to split')
        train_op, split_layer, loss, acc = split(sess, split_layer, output_layer, inputs, targets, 'split_{}'.format(epoch))
        
        uninits = [var for var in tf.all_variables() if not sess.run(tf.is_variable_initialized(var))]
        print([var.name for var in uninits])
        sess.run([var.initializer for var in uninits])

(6) ~~ 0.17306154111588656  (valid: 0.9530248397435898)
time to split
(784, 50) (50,) (50, 50)
['model/split_6_split_1/weights:0', 'model/split_6_split_1/bias:0', 'model/split_6_split_2/a:0', 'model/split_6_split_2/b:0', 'model/split_6_split_2/c:0', 'model/split_6_split_2/weights:0', 'model/split_6_split_2/bias:0', 'beta1_power_1:0', 'beta2_power_1:0', 'model/output/a/Adam_2:0', 'model/output/a/Adam_3:0', 'model/output/b/Adam_2:0', 'model/output/b/Adam_3:0', 'model/output/c/Adam_2:0', 'model/output/c/Adam_3:0', 'model/output/weights/Adam_2:0', 'model/output/weights/Adam_3:0', 'model/output/bias/Adam_2:0', 'model/output/bias/Adam_3:0', 'model/split_6_split_1/weights/Adam:0', 'model/split_6_split_1/weights/Adam_1:0', 'model/split_6_split_1/bias/Adam:0', 'model/split_6_split_1/bias/Adam_1:0', 'model/split_6_split_2/a/Adam:0', 'model/split_6_split_2/a/Adam_1:0', 'model/split_6_split_2/b/Adam:0', 'model/split_6_split_2/b/Adam_1:0', 'model/split_6_split_2/c/Adam:0', 'model/split_6_split_2/c/

ValueError: Dimensions 784 and 50 are not compatible

In [None]:
plt.plot(losses)
plt.show()
plt.plot(np.arange(len(valid_losses)) * 1, valid_losses)
plt.show()