# Capsule Networks
## Imports

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

## Squash Function

This is the custom activation function introduced in the paper, called the squash function. It rescales its input vector using the vector's L2 norm, so the direction is kept while the length is mapped into the range [0, 1).

\begin{equation*}
v_j = \frac{||s_j||^2}{1 + ||s_j||^2} \cdot \frac{s_j}{||s_j||}
\end{equation*}

$v_j$ is the output of capsule $j$ and $s_j$ is the total input to capsule $j$.


In [1]:
def squash(inputs, epsilon):
    with tf.name_scope('squash'): #Set name_scope for tensorboard
        input_squared_norm = tf.reduce_sum(tf.square(inputs), -2, keep_dims=True, name='squared_norm') #gets the squared L2 norm
        scalar_factor = input_squared_norm / (1 + input_squared_norm) 
        scalar_factor = scalar_factor / tf.sqrt(input_squared_norm + epsilon)
        squashed_input = tf.multiply(scalar_factor, inputs, name='squashed_output') #Return v_j
    return squashed_input
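
To see what the squash function does to vector lengths, here is a minimal NumPy sketch of the same formula. The name `squash_np` and the test vectors are assumptions made for illustration; the notebook's TensorFlow version reduces over axis -2 because its tensors carry a trailing singleton axis, while this sketch uses the last axis of a plain vector.

```python
import numpy as np

def squash_np(s, axis=-1, epsilon=1e-9):
    """NumPy sketch of the squash non-linearity applied along `axis`."""
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    # ||s||^2 / (1 + ||s||^2) maps the length into [0, 1)
    scale = squared_norm / (1.0 + squared_norm)
    # dividing by ||s|| keeps only the direction; epsilon avoids 0 / 0
    return scale * s / np.sqrt(squared_norm + epsilon)

short_vec = squash_np(np.array([0.1, 0.0, 0.0]))   # input length 0.1
long_vec = squash_np(np.array([30.0, 40.0, 0.0]))  # input length 50

print(np.linalg.norm(short_vec))  # close to 0
print(np.linalg.norm(long_vec))   # close to, but below, 1
```

Short vectors are shrunk towards zero length and long vectors approach unit length, while the direction of the input is unchanged.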

## Primary Capsule

First we create a conv2d layer with 32 * 8 = 256 output channels, a 9 x 9 kernel, a stride of 2 and no padding (in the code below, `num_kernels` is passed as the kernel size). We use ReLU activations (the activations for this layer were not mentioned in the paper, but I assume they were used). The output is then reshaped to [-1, 1152, 8, 1], i.e. 1152 capsules of length 8. We apply the squash activation to this layer and return the result.

In [3]:
def add_capsule(inputs, batch_size, outputs, vector_length, strides, num_kernels, epsilon):
    capsules = tf.contrib.layers.conv2d(inputs,  outputs * vector_length, num_kernels, strides, padding="VALID", activation_fn=tf.nn.relu) #Get the conv2d layer
    capsules = tf.reshape(capsules, [-1, 1152, vector_length, 1], name='capsules_reshape') #reshape them to required dims
    capsules = squash(capsules, epsilon) #Add activation
    assert(capsules.get_shape()[1:] == [1152, 8 , 1])
    print('Primary Capsule Shape '+str(capsules.get_shape()))
    return capsules
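
The reshape from convolution output to capsules can be checked in isolation. The sketch below uses NumPy with a dummy array standing in for the convolution output (an assumption for illustration; the real values come from `conv2d`), and relies on NumPy and `tf.reshape` both using row-major order.

```python
import numpy as np

batch, grid, maps, vector_length = 2, 6, 32, 8

# Stand-in for the PrimaryCaps convolution output: [batch, 6, 6, 32 * 8]
conv_out = np.arange(batch * grid * grid * maps * vector_length, dtype=np.float32)
conv_out = conv_out.reshape(batch, grid, grid, maps * vector_length)

# 6 * 6 * 32 = 1152 capsules, each a column vector of length 8
capsules = conv_out.reshape(-1, grid * grid * maps, vector_length, 1)
print(capsules.shape)  # (2, 1152, 8, 1)

# The first capsule holds the first 8 channels at grid position (0, 0)
first_capsule = capsules[0, 0, :, 0]
print(np.array_equal(first_capsule, conv_out[0, 0, 0, :vector_length]))
```

Each group of 8 consecutive channels at one grid position becomes one capsule, so a grid position contributes 32 capsules.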

In [4]:
def get_weights(name, stddev, shape): #Driver to get weights, straight forward
    return tf.Variable(tf.truncated_normal(shape=shape, stddev=stddev), name=name)

## Digit Caps Layer

Get the inputs and reshape them to [-1, 1152, 1, 8, 1]. We then initialize b_ij to all zeros with shape [batch_size, 1152, 10, 1, 1], and initialize the weights with shape [1, 1152, 10, 8, 16]. The reshaped inputs are replicated 10 times along axis = 2, so they now have the form [-1, 1152, 10, 8, 1]. The weights are replicated batch_size times along axis = 0 (this had to be done because there is no dynamic initialization).

We now multiply the transpose of the weights by the inputs, so we get u_cap with dims [-1, 1152, 10, 16, 1]. We also have u_cap_not_passed, a copy of u_cap wrapped in tf.stop_gradient, which is used wherever the gradient should not be passed back.
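
The shape bookkeeping for u_cap can be sketched in NumPy with small stand-in sizes (2, 5 and 3 in place of batch_size, 1152 and 10; these sizes and the random data are assumptions for illustration). NumPy's `matmul` broadcasts the leading axes, so the explicit tiling used in the TensorFlow code is not needed in this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_in, n_out, in_dim, out_dim = 2, 5, 3, 8, 16

inputs = rng.normal(size=(batch, n_in, 1, in_dim, 1))  # primary capsule outputs
W = rng.normal(scale=0.01, size=(1, n_in, n_out, in_dim, out_dim))

# Transpose the last two axes of W, then multiply: [16, 8] @ [8, 1] -> [16, 1]
u_cap = np.matmul(np.swapaxes(W, -1, -2), inputs)
print(u_cap.shape)  # (2, 5, 3, 16, 1)
```

Every pair of input capsule i and output capsule j gets its own 8 x 16 weight matrix, and u_cap holds the resulting prediction vector of length 16 for that pair.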

Now we need to perform dynamic routing between the primary caps and digit caps layers. This is controlled by the number of routing iterations between the 2 layers. The paper suggests use of 3 iterations.

c_ij is calculated as the softmax of b_ij over the output capsules (hence dim = 2).

If we are not in the last round of iteration:
* We multiply c_ij and u_cap_not_passed and sum the result along axis = 1 to get s_j, a matrix of dims [batch_size, 1, 10, 16, 1].
* We squash s_j to get v_j, passing a small value epsilon to prevent division by zero. v_j_replica replicates v_j to dims [batch_size, 1152, 10, 16, 1].
* We now find u_v as the product of the transpose of u_cap_not_passed and v_j_replica, which gives a matrix of dims [batch_size, 1152, 10, 1, 1]. This is added to b_ij, and c_ij takes the softmax of the updated b_ij along dim = 2 in the next iteration.

If we are in the last round of iteration:
* We compute s_j and v_j as above, but with u_cap instead of u_cap_not_passed so that gradients can flow back, and return the output of the squash function without updating b_ij.
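
The routing loop described above can be sketched in plain NumPy. This is only a forward-pass illustration under stated assumptions: the helper names `squash_np`, `softmax_np` and `route` are made up here, the sizes are small stand-ins, and since NumPy has no gradients there is no distinction between u_cap and u_cap_not_passed.

```python
import numpy as np

def squash_np(s, axis=-2, epsilon=1e-9):
    sq = np.sum(np.square(s), axis=axis, keepdims=True)
    return (sq / (1.0 + sq)) * s / np.sqrt(sq + epsilon)

def softmax_np(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def route(u_cap, routing_iters=3):
    """u_cap: prediction vectors of shape [batch, n_in, n_out, out_dim, 1]."""
    batch, n_in, n_out = u_cap.shape[:3]
    b_ij = np.zeros((batch, n_in, n_out, 1, 1))
    for it in range(routing_iters):
        c_ij = softmax_np(b_ij, axis=2)                    # coupling coefficients
        s_j = np.sum(c_ij * u_cap, axis=1, keepdims=True)  # [batch, 1, n_out, out_dim, 1]
        v_j = squash_np(s_j)
        if it < routing_iters - 1:
            # agreement u^T v, broadcast over n_in: [batch, n_in, n_out, 1, 1]
            b_ij = b_ij + np.matmul(np.swapaxes(u_cap, -1, -2), v_j)
    return v_j

u_cap = np.random.default_rng(0).normal(size=(2, 5, 3, 16, 1))
v_j = route(u_cap)
print(v_j.shape)  # (2, 1, 3, 16, 1)
```

Prediction vectors that agree with an output capsule's current v_j increase the corresponding b_ij, so that capsule receives a larger share of the input in the next iteration.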

In [5]:
def add_fc(inputs, batch_size, outputs, stddev, routing_iters, epsilon):
    inputs = tf.reshape(inputs, [-1, 1152, 1, inputs.shape[-2].value, 1])
    inputs_shape = inputs.get_shape()
    print(inputs_shape)
    b_ij = tf.constant(np.zeros([batch_size, inputs.shape[1].value, outputs, 1, 1], dtype=np.float32), name='b_ij')
#     b_ij = tf.zeros([batch_size, inputs.shape[1].value, outputs, 1, 1], dtype=np.float32), name='b_ij')
    W = get_weights('Weight', shape=(1, 1152, 10, 8, 16), stddev=stddev)

    inputs = tf.tile(inputs, [1, 1, 10, 1, 1])
    W = tf.tile(W, [batch_size, 1, 1, 1, 1])
    assert(inputs.get_shape()[1:] == [1152, 10, 8, 1])
    
    u_cap = tf.matmul(W, inputs, transpose_a=True)
    assert(u_cap.get_shape()[1:] == [1152, 10, 16, 1])

    u_cap_not_passed = tf.stop_gradient(u_cap, name='stop_gradient')

    for iter_i in range(routing_iters):
        with tf.variable_scope('iter_' + str(iter_i)):
            c_ij = tf.nn.softmax(b_ij, dim=2)
            
            if iter_i == routing_iters - 1:
                s_j = tf.multiply(c_ij, u_cap)
                s_j = tf.reduce_sum(s_j, axis=1, keep_dims=True)
                assert(s_j.get_shape()[1:] == [1, 10, 16, 1])
                v_j = squash(s_j, epsilon)
                assert(v_j.get_shape()[1:] == [1, 10, 16, 1])
            
            elif iter_i < routing_iters - 1:
                s_j = tf.multiply(c_ij, u_cap_not_passed)
                s_j = tf.reduce_sum(s_j, axis=1, keep_dims=True)
                v_j = squash(s_j, epsilon)
                v_j_replica = tf.tile(v_j, [1, 1152, 1, 1, 1])
                u_v = tf.matmul(u_cap_not_passed, v_j_replica, transpose_a=True)
                assert(u_v.get_shape() == b_ij.get_shape())
                b_ij += u_v
    
    print('Digit Caps Shape '+str(v_j.get_shape()))
    return v_j

## Setup Model

* Set up the placeholders x_i and y.
* Reshape x_i to [-1, 28, 28, 1].
* Add a ReLUConv with 256 outputs and a kernel of size 9 * 9.
* Add the PrimaryCaps layer.
* Add the DigitCaps layer; the stddev for weight initialization is taken as 0.01.
* Return the output of the DigitCaps layer, because from here we either train or check accuracy.

Name scopes have been added for ease of future use and for TensorBoard visualizations, which really help in understanding this complex model.
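
The shapes asserted in the model follow from the usual output-size formula for a convolution without padding. A small sketch of that arithmetic (the helper name `valid_conv_size` is an assumption for illustration):

```python
def valid_conv_size(size, kernel, stride):
    """Output width of a convolution with 'VALID' (no) padding."""
    return (size - kernel) // stride + 1

conv1 = valid_conv_size(28, kernel=9, stride=1)       # ReLUConv1 on a 28 x 28 image
primary = valid_conv_size(conv1, kernel=9, stride=2)  # PrimaryCaps convolution
num_capsules = primary * primary * 32                 # 32 capsule maps per grid cell

print(conv1, primary, num_capsules)  # 20 6 1152
```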

In [6]:
def get_model(batch_size, epsilon):
    with tf.name_scope('inputs'):
        x_i = tf.placeholder(tf.float32, shape=[None, 784], name = 'x')
        y = tf.placeholder(tf.float32, shape=[None, 10], name = 'y')
        
    with tf.name_scope('reshape_input'):
        x = tf.reshape(x_i, [-1,28,28,1], name='reshape_x')
    
    with tf.name_scope('ReLUConv1'):
        conv1 = tf.nn.relu(tf.contrib.layers.conv2d(x, num_outputs=256, kernel_size=9, stride=1, padding='VALID'), name = 'ReLUConv1')
        print('Conv1 Shape: '+str(conv1.get_shape()))
        assert(conv1.get_shape()[1:] == [20, 20, 256])

    with tf.name_scope('PrimaryCaps'):
        primary_caps = add_capsule(conv1, batch_size = batch_size, outputs=32, vector_length=8, strides=2, num_kernels=9, epsilon=epsilon)
        assert(primary_caps.get_shape()[1:] == [1152, 8, 1])

    with tf.name_scope('DigitCaps'):
        digit_caps = add_fc(primary_caps, batch_size = batch_size, outputs=10, stddev = 0.01, routing_iters= 3, epsilon = epsilon)
        digit_caps = tf.reshape(digit_caps, [-1, 10, 16, 1])
        print('Reshaped Digit Caps: '+str(digit_caps.get_shape()))
        assert(digit_caps.get_shape()[1:] == [10, 16, 1])
        
    return x_i, y, digit_caps

## Train

This is the driver for training. We take epsilon to be $10^{-9}$ in this case. The function is explained below:

* We get the placeholders and the digit caps output from the compute graph that we set up above.
* In the output name scope, we find the L2 norm of each of the vectors of length 16. This is the required output.
* We take the softmax of the L2 norms across the 10 classes.
* We reshape the softmax to the form [batch_size, 10] (this is used only in the accuracy calculation, not in the loss, but since it is the output it lives under the output name scope).
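
The output steps can be sketched in NumPy as follows. The random `digit_caps` array is a stand-in for the real DigitCaps output and is an assumption for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
digit_caps = rng.normal(size=(4, 10, 16, 1))  # stand-in: [batch, 10, 16, 1]

# L2 norm of each length-16 vector: [4, 10, 1, 1]
v_l2 = np.sqrt(np.sum(np.square(digit_caps), axis=2, keepdims=True))

# Softmax across the 10 classes, then reshape to [batch, 10]
e = np.exp(v_l2 - v_l2.max(axis=1, keepdims=True))
softmax_v = (e / e.sum(axis=1, keepdims=True)).reshape(-1, 10)

predictions = softmax_v.argmax(axis=1)
print(softmax_v.shape, predictions.shape)  # (4, 10) (4,)
```

Because softmax preserves ordering, the predicted class is the same as the class whose capsule vector has the largest norm.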

To calculate loss, we perform the following operations:
* Get the L2 norm output and reshape to [batch_size, 10]. Now we calculate the loss using the custom loss function suggested in the paper:

\begin{equation*}
loss_c = T_c \max(0, m^+ - ||v_c||)^2 + \lambda (1 - T_c) \max(0, ||v_c|| - m^-)^2
\end{equation*}

* $loss_c$ is the loss for class $c$, so in matrix form we simply take $T_c$ as $y$ and $||v_c||$ as the calculated L2 norm.
* We sum the loss over the classes and minimize its mean over the batch.
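
A minimal NumPy sketch of this loss, mirroring the TensorFlow expression in the code below. The function name `margin_loss` and the three-class example values are assumptions for illustration; the defaults match the hyperparameters used in this notebook.

```python
import numpy as np

def margin_loss(v_norm, y, m_plus=0.9, m_minus=0.1, lambda_val=0.5):
    # Penalize the true class when its capsule is shorter than m_plus
    present = y * np.square(np.maximum(0.0, m_plus - v_norm))
    # Penalize the other classes when their capsules are longer than m_minus
    absent = lambda_val * (1.0 - y) * np.square(np.maximum(0.0, v_norm - m_minus))
    # Sum over classes, mean over the batch
    return np.mean(np.sum(present + absent, axis=1))

y = np.array([[1.0, 0.0, 0.0]])                # true class is 0
confident = np.array([[0.95, 0.05, 0.05]])     # long correct capsule, short others
wrong = np.array([[0.2, 0.8, 0.05]])           # short correct capsule, long wrong one

print(margin_loss(confident, y))  # 0.0
print(margin_loss(wrong, y))      # about 0.735
```

The loss is zero once the correct capsule is longer than m_plus and all others are shorter than m_minus, so it does not keep pushing on examples that are already confidently classified.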

Accuracy calculations are quite straightforward and need no explanation.

In [7]:
def train(train_data, batch_size, iters, m_plus, m_minus, lambda_val):
    epsilon = 1e-9
    x, y, digit_caps = get_model(batch_size, epsilon)
    with tf.name_scope('output'):
        v_l2 = tf.sqrt(tf.reduce_sum(tf.square(digit_caps), axis=2, keep_dims=True), name='output')
        print('v_l2 shape: '+str(v_l2.get_shape()))
        softmax_v =tf.reshape(tf.nn.softmax(v_l2, dim=1), [-1, 10], name='softmax_output')
        print('Softmax shape: '+str(softmax_v.get_shape()))
        assert(softmax_v.get_shape() == [batch_size, 10])
        
    with tf.name_scope('loss'):
        v_l = tf.reshape(v_l2, [-1, 10])
        lc = y * tf.square(tf.maximum(0.0, m_plus - v_l)) + lambda_val * (1 - y) * tf.square(tf.maximum(0.0, v_l - m_minus))
        loss = tf.reduce_mean(tf.reduce_sum(lc, axis= 1), name='loss')
    
    with tf.name_scope('train'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
        tf.summary.scalar('loss', loss)
        
    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(softmax_v, 1), tf.argmax(y, 1))
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction, name = 'accuracy')
        tf.summary.scalar('accuracy', accuracy)
    
    summ = tf.summary.merge_all()
    
    builder = tf.saved_model.builder.SavedModelBuilder('models/')
    writer = tf.summary.FileWriter('logs_output_1')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for i in range(iters):
            batch = train_data.next_batch(batch_size)
            sess.run([train_step], feed_dict={x:batch[0], y:batch[1]})
            if i%10 == 0:
                acc, s = sess.run([accuracy, summ],feed_dict={x:batch[0], y:batch[1]})
                print("Accuracy after "+str(i)+" iterations is "+str(acc))
                writer.add_summary(s, i)
        builder.add_meta_graph_and_variables(sess,['model_'+str(iters)+'iters'])
        builder.save()
        writer.close()
        print("Complete!")
    return

## Test Function

This is a standard testing function. The only thing to note is that, since there is no dynamic creation in TensorFlow, we have to pass the test data in batches of the same size as in training. We report the average test accuracy over these batches.

In [13]:
def test(test_data, iters, batch_size):
    tf.reset_default_graph()
    with tf.Session() as sess:
        tf.saved_model.loader.load(sess,['model_'+str(iters)+'iters'],'models/')
        test_len = len(test_data.images) // batch_size
        mean_acc = 0.0
        for i in range(test_len):
            acc = sess.run(['accuracy/accuracy:0'], feed_dict={'inputs/x:0':test_data.images[i*batch_size:(i+1)*batch_size], 'inputs/y:0':test_data.labels[i*batch_size:(i+1)*batch_size]})[0]
            print('Accuracy in iter '+str(i)+' is '+str(acc))
            mean_acc += acc
        print("Test Accuracy is "+str(mean_acc/test_len))    
    return
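
One consequence of using fixed-size batches is that any leftover examples at the end of the test set are not evaluated. The sketch below illustrates this with the 10,000-image MNIST test set; the helper name `full_batches` is an assumption for illustration.

```python
def full_batches(num_examples, batch_size):
    """Yield (start, stop) index pairs for complete batches only."""
    for i in range(num_examples // batch_size):
        yield i * batch_size, (i + 1) * batch_size

slices = list(full_batches(10000, 128))
evaluated = slices[-1][1]

print(len(slices))        # 78
print(slices[-1])         # (9856, 9984)
print(10000 - evaluated)  # 16 examples are left out
```

With a batch size of 128 the reported average therefore covers 9,984 of the 10,000 test images.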

## Main Driver

Just run this function to train and see the outputs. The hyperparameters are set in the first few lines of this function and are chosen as per the paper's specifications.

In [9]:
def main():
    epochs = 10
    batch_size = 128
    m_plus = 0.9
    m_minus = 0.1
    lambda_val = 0.5
    
    if not tf.gfile.Exists('dataset/'):
        print('creating dataset dir')
        tf.gfile.MakeDirs('dataset')
    print('Loading dataset')
    mnist = input_data.read_data_sets('dataset/', one_hot=True)
    iters = epochs * len(mnist.train.images) // batch_size
    print('Loading Complete')
    print('Will be running for '+str(iters)+' iters')
    if tf.gfile.Exists('models/'):
        flag = input('Do you want to reset the model and train: ')
        if flag == 'y' or flag == 'Y':
            tf.gfile.DeleteRecursively('models/')
            train(mnist.train, batch_size, iters, m_plus, m_minus, lambda_val)
    else:
        train(mnist.train, batch_size, iters, m_plus, m_minus, lambda_val)
    
    test(mnist.test, iters, batch_size)
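
The iteration count printed in the output below follows from the hyperparameters. The sketch assumes the default split of `read_data_sets`, which leaves 55,000 training images after holding out a validation set.

```python
epochs = 10
batch_size = 128
num_train = 55000  # assumption: default training split of read_data_sets

iters = epochs * num_train // batch_size
print(iters)  # 4296
```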

In [14]:
main()

Loading dataset
Extracting dataset/train-images-idx3-ubyte.gz
Extracting dataset/train-labels-idx1-ubyte.gz
Extracting dataset/t10k-images-idx3-ubyte.gz
Extracting dataset/t10k-labels-idx1-ubyte.gz
Loading Complete
Will be running for 4296 iters
Do you want to reset the model and train: n
INFO:tensorflow:Restoring parameters from b'models/variables/variables'
Accuracy in iter 0 is 1.0
Accuracy in iter 1 is 0.992188
Accuracy in iter 2 is 0.976562
Accuracy in iter 3 is 0.984375
Accuracy in iter 4 is 0.992188
Accuracy in iter 5 is 0.960938
Accuracy in iter 6 is 1.0
Accuracy in iter 7 is 0.984375
Accuracy in iter 8 is 0.984375
Accuracy in iter 9 is 0.976562
Accuracy in iter 10 is 0.984375
Accuracy in iter 11 is 0.976562
Accuracy in iter 12 is 0.992188
Accuracy in iter 13 is 0.984375
Accuracy in iter 14 is 0.976562
Accuracy in iter 15 is 0.984375
Accuracy in iter 16 is 0.976562
Accuracy in iter 17 is 0.992188
Accuracy in iter 18 is 0.992188
Accuracy in iter 19 is 0.992188
Accuracy in iter 2

That wasn't the most elegant way of showing the output, but yes, we get an average test accuracy of 99.11%.