In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
def squash(inputs, epsilon):
    """Apply the capsule "squash" non-linearity along axis -2.

    For every vector s laid out along axis -2 this computes

        ||s||^2 / (1 + ||s||^2) * s / sqrt(||s||^2 + epsilon)

    so the output has the same direction as s and a length strictly below 1.

    Args:
        inputs: tensor whose second-to-last axis holds the capsule vector
            components (callers in this notebook pass [..., vec_len, 1]).
        epsilon: small constant added under the square root so all-zero
            vectors do not cause a division by zero.

    Returns:
        Tensor with the same shape as `inputs`, created as 'squashed_output'
        inside a 'squash' name scope.
    """
    with tf.name_scope('squash'):
        # keep_dims so the per-vector norm broadcasts back against `inputs`.
        sq_norm = tf.reduce_sum(tf.square(inputs), -2, keep_dims=True, name='squared_norm')
        gain = sq_norm / (1 + sq_norm)
        per_vector_scale = gain / tf.sqrt(sq_norm + epsilon)
        return tf.multiply(per_vector_scale, inputs, name='squashed_output')

In [3]:
def add_capsule(inputs, batch_size, outputs, vector_length, strides, num_kernels, epsilon):
    """Build the PrimaryCaps layer: a convolution regrouped into squashed capsules.

    Args:
        inputs: 4-D feature map; assumes NHWC with static height/width
            (get_model passes the [?, 20, 20, 256] Conv1 output).
        batch_size: unused; kept so existing callers keep working.
        outputs: number of capsule channels per spatial position.
        vector_length: dimensionality of each capsule vector.
        strides: convolution stride.
        num_kernels: convolution kernel size (passed as conv2d's kernel_size).
        epsilon: numerical-stability constant forwarded to `squash`.

    Returns:
        Tensor [batch, num_capsules, vector_length, 1] with
        num_capsules = out_height * out_width * outputs
        (1152 = 6 * 6 * 32 for the configuration used in get_model).

    Raises:
        ValueError: if the conv output's spatial size is not statically known.
    """
    capsules = tf.contrib.layers.conv2d(inputs, outputs * vector_length, num_kernels, strides,
                                        padding="VALID", activation_fn=tf.nn.relu)
    # Derive the capsule count from the conv output instead of hard-coding 1152,
    # so other input sizes / capsule configurations also work.
    conv_shape = capsules.get_shape()
    if conv_shape[1].value is None or conv_shape[2].value is None:
        raise ValueError('PrimaryCaps needs a static spatial shape, got ' + str(conv_shape))
    num_capsules = conv_shape[1].value * conv_shape[2].value * outputs
    capsules = tf.reshape(capsules, [-1, num_capsules, vector_length, 1], name='capsules_reshape')
    capsules = squash(capsules, epsilon)
    assert capsules.get_shape()[1:] == [num_capsules, vector_length, 1]
    print('Primary Capsule Shape ' + str(capsules.get_shape()))
    return capsules

In [4]:
def get_weights(name, stddev, shape):
    """Create a trainable variable initialised from a truncated normal.

    Args:
        name: name given to the variable.
        stddev: standard deviation of the truncated-normal initialiser.
        shape: shape of the variable.

    Returns:
        A `tf.Variable` holding the sampled initial values.
    """
    initial_value = tf.truncated_normal(shape=shape, stddev=stddev)
    return tf.Variable(initial_value, name=name)

In [5]:
def add_fc(inputs, batch_size, outputs, stddev, routing_iters, epsilon, output_vector_length=16):
    """DigitCaps layer: fully connected capsules with routing-by-agreement.

    Args:
        inputs: primary capsules [batch, num_in_caps, in_vec_len, 1]
            (as returned by add_capsule); dims 1 and -2 must be static.
        batch_size: static batch size; the routing logits and tiled weights
            are built for exactly this many examples, so fed batches must match.
        outputs: number of output capsules (previously ignored; 10 was hard-coded).
        stddev: stddev of the truncated-normal initialiser for the weights W.
        routing_iters: number of routing iterations, >= 1. With 1 the coupling
            coefficients stay uniform (softmax of all-zero logits).
        epsilon: numerical-stability constant forwarded to `squash`.
        output_vector_length: dimensionality of each output capsule (default 16).

    Returns:
        Tensor [batch_size, 1, outputs, output_vector_length, 1].

    Raises:
        ValueError: if routing_iters < 1 (no output would be produced) or the
            input capsule dimensions are not statically known.
    """
    if routing_iters < 1:
        raise ValueError('routing_iters must be >= 1, got ' + str(routing_iters))
    num_in_caps = inputs.shape[1].value
    in_vec_len = inputs.shape[-2].value
    if num_in_caps is None or in_vec_len is None:
        raise ValueError('add_fc needs static capsule dimensions, got ' + str(inputs.get_shape()))

    # [B, num_in_caps, 1, in_vec_len, 1]: one slot to tile per output capsule.
    inputs = tf.reshape(inputs, [-1, num_in_caps, 1, in_vec_len, 1])
    print(inputs.get_shape())
    # Routing logits start at zero (uniform coupling after the softmax).
    b_ij = tf.constant(np.zeros([batch_size, num_in_caps, outputs, 1, 1], dtype=np.float32), name='b_ij')
    W = get_weights('Weight', shape=(1, num_in_caps, outputs, in_vec_len, output_vector_length), stddev=stddev)

    inputs = tf.tile(inputs, [1, 1, outputs, 1, 1])
    W = tf.tile(W, [batch_size, 1, 1, 1, 1])
    assert inputs.get_shape()[1:] == [num_in_caps, outputs, in_vec_len, 1]

    # Prediction vectors u_hat = W^T u: [B, num_in_caps, outputs, output_vector_length, 1].
    u_cap = tf.matmul(W, inputs, transpose_a=True)
    assert u_cap.get_shape()[1:] == [num_in_caps, outputs, output_vector_length, 1]

    # Intermediate routing iterations use this copy, so they do not
    # back-propagate into W through u_hat.
    u_cap_not_passed = tf.stop_gradient(u_cap, name='stop_gradient')

    for iter_i in range(routing_iters):
        with tf.variable_scope('iter_' + str(iter_i)):
            # Coupling coefficients: softmax over the output capsules.
            c_ij = tf.nn.softmax(b_ij, dim=2)

            if iter_i == routing_iters - 1:
                # Last iteration uses u_cap so gradients reach W.
                s_j = tf.reduce_sum(tf.multiply(c_ij, u_cap), axis=1, keep_dims=True)
                assert s_j.get_shape()[1:] == [1, outputs, output_vector_length, 1]
                v_j = squash(s_j, epsilon)
                assert v_j.get_shape()[1:] == [1, outputs, output_vector_length, 1]
            else:
                s_j = tf.reduce_sum(tf.multiply(c_ij, u_cap_not_passed), axis=1, keep_dims=True)
                v_j = squash(s_j, epsilon)
                v_j_replica = tf.tile(v_j, [1, num_in_caps, 1, 1, 1])
                # Agreement <u_hat, v_j> per (input, output) pair: [B, num_in_caps, outputs, 1, 1].
                u_v = tf.matmul(u_cap_not_passed, v_j_replica, transpose_a=True)
                assert u_v.get_shape() == b_ij.get_shape()
                b_ij += u_v

    print('Digit Caps Shape ' + str(v_j.get_shape()))
    return v_j

In [6]:
def get_model(batch_size, epsilon, routing_iters=1):
    """Build the CapsNet inference graph for flattened 28x28 images.

    Args:
        batch_size: static batch size used by the DigitCaps routing tensors.
        epsilon: numerical-stability constant for the squash non-linearities.
        routing_iters: DigitCaps routing iterations. Defaults to 1, matching the
            previous hard-coded value; with 1 the coupling coefficients stay
            uniform. Sabour et al. (2017) use 3.

    Returns:
        (x_i, y, digit_caps): image placeholder [None, 784], one-hot label
        placeholder [None, 10], and digit capsules [batch_size, 10, 16, 1].
    """
    with tf.name_scope('inputs'):
        x_i = tf.placeholder(tf.float32, shape=[None, 784], name='x')
        y = tf.placeholder(tf.float32, shape=[None, 10], name='y')

    with tf.name_scope('reshape_input'):
        x = tf.reshape(x_i, [-1, 28, 28, 1], name='reshape_x')

    with tf.name_scope('ReLUConv1'):
        # contrib conv2d applies ReLU by default; disable it so ReLU runs once (outer call).
        conv1 = tf.nn.relu(
            tf.contrib.layers.conv2d(x, num_outputs=256, kernel_size=9, stride=1,
                                     padding='VALID', activation_fn=None),
            name='ReLUConv1')
        print('Conv1 Shape: ' + str(conv1.get_shape()))
        assert conv1.get_shape()[1:] == [20, 20, 256]

    with tf.name_scope('PrimaryCaps'):
        primary_caps = add_capsule(conv1, batch_size=batch_size, outputs=32, vector_length=8,
                                   strides=2, num_kernels=9, epsilon=epsilon)
        assert primary_caps.get_shape()[1:] == [1152, 8, 1]

    with tf.name_scope('DigitCaps'):
        digit_caps = add_fc(primary_caps, batch_size=batch_size, outputs=10, stddev=0.01,
                            routing_iters=routing_iters, epsilon=epsilon)
        digit_caps = tf.reshape(digit_caps, [-1, 10, 16, 1])
        print('Reshaped Digit Caps: ' + str(digit_caps.get_shape()))
        assert digit_caps.get_shape()[1:] == [10, 16, 1]

    return x_i, y, digit_caps

In [7]:
def train(train_data, batch_size, iters, m_plus, m_minus, lambda_val,
          model_dir='models/', log_dir='logs_output_1'):
    """Train the CapsNet with the margin loss and export it as a SavedModel.

    The graph is built inside a fresh `tf.Graph`, so calling this again in the
    same kernel (e.g. re-running main()) starts clean. Without that,
    `tf.summary.merge_all()` would also collect the previous call's summaries,
    which depend on that call's placeholders and are never fed.

    Args:
        train_data: object whose `next_batch(batch_size)` returns
            (images [batch, 784], one-hot labels [batch, 10]).
        batch_size: examples per step; must equal the static batch size baked
            into the DigitCaps layer.
        iters: number of training steps.
        m_plus: margin-loss target length for the present class.
        m_minus: margin-loss upper bound on absent-class lengths.
        lambda_val: down-weighting of the absent-class loss term.
        model_dir: SavedModel export directory; SavedModelBuilder expects it
            not to exist yet (main() deletes it before calling).
        log_dir: TensorBoard summary directory.
    """
    epsilon = 1e-9
    with tf.Graph().as_default():
        x, y, digit_caps = get_model(batch_size, epsilon)

        with tf.name_scope('output'):
            # Capsule lengths ||v_j||; epsilon keeps sqrt's gradient finite at 0.
            v_l2 = tf.sqrt(tf.reduce_sum(tf.square(digit_caps), axis=2, keep_dims=True) + epsilon,
                           name='output')
            print('v_l2 shape: ' + str(v_l2.get_shape()))
            softmax_v = tf.reshape(tf.nn.softmax(v_l2, dim=1), [-1, 10], name='softmax_output')
            print('Softmax shape: ' + str(softmax_v.get_shape()))
            assert softmax_v.get_shape() == [batch_size, 10]

        with tf.name_scope('loss'):
            v_l = tf.reshape(v_l2, [-1, 10])
            # Margin loss: penalise short present-class and long absent-class capsules.
            lc = (y * tf.square(tf.maximum(0.0, m_plus - v_l))
                  + lambda_val * (1 - y) * tf.square(tf.maximum(0.0, v_l - m_minus)))
            loss = tf.reduce_mean(tf.reduce_sum(lc, axis=1), name='loss')

        with tf.name_scope('train'):
            train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
            tf.summary.scalar('loss', loss)

        with tf.name_scope('accuracy'):
            correct_prediction = tf.equal(tf.argmax(softmax_v, 1), tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
            tf.summary.scalar('accuracy', accuracy)

        summ = tf.summary.merge_all()

        builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
        writer = tf.summary.FileWriter(log_dir)
        try:
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                writer.add_graph(sess.graph)
                for i in range(iters):
                    batch = train_data.next_batch(batch_size)
                    feed = {x: batch[0], y: batch[1]}
                    sess.run(train_step, feed_dict=feed)
                    if i % 10 == 0:
                        # Accuracy on the batch just trained on (training accuracy).
                        acc, s = sess.run([accuracy, summ], feed_dict=feed)
                        print("Accuracy after " + str(i) + " iterations is " + str(acc))
                        writer.add_summary(s, i)
                builder.add_meta_graph_and_variables(sess, ['model_' + str(iters) + 'iters'])
                builder.save()
        finally:
            # Close the event file even if training fails part-way.
            writer.close()
    print("Complete!")

In [8]:
def main(reset=None):
    """Load MNIST and train the CapsNet, optionally replacing a saved model.

    Args:
        reset: whether to delete an existing 'models/' export and retrain.
            None (default) asks interactively as before; pass True/False to
            run without a prompt (needed for Restart & Run All).
    """
    iters = 1000
    batch_size = 128
    m_plus = 0.9
    m_minus = 0.1
    lambda_val = 0.5

    if not tf.gfile.Exists('dataset/'):
        print('creating dataset dir')
        tf.gfile.MakeDirs('dataset')
    print('Loading dataset')
    mnist = input_data.read_data_sets('dataset/', one_hot=True)
    print('Loading Complete')

    if tf.gfile.Exists('models/'):
        if reset is None:
            flag = input('Do you want to reset the model and train: ')
            reset = flag.strip().lower() == 'y'
        if not reset:
            # Previously this case silently did nothing.
            print('Keeping existing models/ export; training skipped')
            return
        tf.gfile.DeleteRecursively('models/')
    train(mnist.train, batch_size, iters, m_plus, m_minus, lambda_val)

In [9]:
# Run the pipeline: load MNIST into dataset/ and train; main() prompts via input()
# before deleting an existing models/ export.
main()

Loading dataset
Extracting dataset/train-images-idx3-ubyte.gz
Extracting dataset/train-labels-idx1-ubyte.gz
Extracting dataset/t10k-images-idx3-ubyte.gz
Extracting dataset/t10k-labels-idx1-ubyte.gz
Loading Complete
Do you want to reset the model and train: y
Conv1 Shape: (?, 20, 20, 256)
Primary Capsule Shape (?, 1152, 8, 1)
(?, 1152, 1, 8, 1)
Digit Caps Shape (128, 1, 10, 16, 1)
Reshaped Digit Caps: (128, 10, 16, 1)
v_l2 shape: (128, 10, 1, 1)
Softmax shape: (128, 10)
Accuracy after 0 iterations is 0.132812
Accuracy after 10 iterations is 0.132812
Accuracy after 20 iterations is 0.0859375
Accuracy after 30 iterations is 0.109375
Accuracy after 40 iterations is 0.28125
Accuracy after 50 iterations is 0.421875
Accuracy after 60 iterations is 0.539062
Accuracy after 70 iterations is 0.75
Accuracy after 80 iterations is 0.828125
Accuracy after 90 iterations is 0.703125
Accuracy after 100 iterations is 0.835938
Accuracy after 110 iterations is 0.835938
Accuracy after 120 iterations is 0.80