## Load Libraries and Required Data

In [45]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


**Some interesting questions:**

 - More routing iterations: Do more routing iterations help classification? (This routing operation is sweet; can we benefit from it more?)
 
 - Varying the number of capsules: Does changing the number (32) or size (8) of the capsules in the PrimaryCaps layer help classification? (This network is heavy AF.)
 
 - More initial convolutional layers: Do smaller filters (3x3, say) with more standard conv layers help classification? (This network is heavy AF.)
 
 - No primary capsules: Can we get away with just convolutional layers, routing, DigiCaps and reconstruction?
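To put a number on "heavy AF", here is a rough parameter count for the layer sizes used in `caps_net_model_fn` below. This is a sketch under stated assumptions: 28x28x1 input, 9x9 kernels, biases on every conv layer (the `tf.layers.conv2d` default), no biases on the `W_ij` matrices, and no reconstruction decoder (which the notebook has not built yet). The helper `conv_params` is hypothetical, written only for this count.

```python
# Hypothetical helper for a rough parameter count. Assumptions: 28x28x1 input,
# 9x9 kernels, conv layers with biases, W_ij without biases, no decoder.

def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of a 2-D convolution."""
    return kernel * kernel * in_ch * out_ch + out_ch

conv1 = conv_params(9, 1, 256)              # regular conv layer
primary_caps = 32 * conv_params(9, 256, 8)  # 32 capsule convs, 8 filters each
n_inputs = 6 * 6 * 32                       # 1152 primary capsules
digi_caps = 10 * n_inputs * 8 * 16          # one 8x16 W_ij per (i, j) pair

for name, count in [("conv1", conv1),
                    ("PrimaryCaps", primary_caps),
                    ("DigiCaps W_ij", digi_caps)]:
    print(f"{name:>14}: {count:>9,d}")
print(f"{'total':>14}: {conv1 + primary_caps + digi_caps:>9,d}")
```

Under these assumptions roughly 78% of the parameters sit in the PrimaryCaps convolutions (5,308,672 of 6,804,224), which is why the smaller-filter question above seems worth trying.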
 

---
## Matrix Multiplications Before Routing
Here we're working out the logic for the large number of matrix multiplications that need to happen before the routing procedure.

In [46]:
tf.reset_default_graph()

# --------------------------------------------

batch_size = 2        # typically: 64
num_of_caps_ops = 2   # typically: 32*6*6
caps_op_size = 3      # typically: 8
num_of_digi_caps = 2  # typically: 10 (b/c there are 10 digits/classes)
digi_caps_size = 2    # typically: 16

# --------------------------------------------

U = np.array([     # Notice U ~ (2, 2, 1, 3) ~ (batch_size, num_of_caps_ops, 1, caps_op_size)
              [
               [[1, 1, 1]], 
               [[2, 2, 2]]
              ],
              [
               [[7, 7, 7]], 
               [[11, 11, 11]]
              ]
             ])
W_single = [ np.array([   # Notice W_j ~ (2, 3, 2) ~ (num_of_caps_ops, caps_op_size, digi_caps_size)
                       [[1, 1],
                        [1, 1],
                        [1, 1]],
                       [[2, 2],
                        [2, 2],
                        [2, 2]]
                      ]),
              np.array([   # Notice W_j ~ (2, 3, 2) ~ (num_of_caps_ops, caps_op_size, digi_caps_size)
                       [[3, 3],
                        [3, 3],
                        [3, 3]],
                       [[5, 5],
                        [5, 5],
                        [5, 5]]
                      ])
           ]

# U = np.random.normal(size=(batch_size, num_of_caps_ops, 1, caps_op_size))
# W_single = [np.random.normal(size=(num_of_caps_ops, caps_op_size, digi_caps_size))
#             for _ in range(num_of_digi_caps)]

U_ = tf.get_variable('U_', initializer=tf.constant(U))
U_ = tf.cast(U_, tf.float32)  # cast the variable, not the raw NumPy array U
UHat = []

# --------------------------------------------

# basically U_ will be given to us as the output of the primary caps layer,
# from there we repeat the procedure below 10 times, for
# j in {1, 2, ..., 10}, namely we should have (distinct) tensors:
# W_1Single, W_2Single, ..., W_10Single.
for i in range(1, num_of_digi_caps + 1):
#     tf.reset_default_graph()
#     with tf.variable_scope('primary_caps_TO_{}'.format(i)):
    W_jSingle = tf.get_variable('W{}jSingle'.format(i), initializer=tf.constant(W_single[i - 1]))
    W_jSingle = tf.cast(W_jSingle, tf.float32)


    # here we want to reuse W_jSingle to carry out the matrix multiplication
    # for each item (primary caps output) in the batch
    temp = []
    for _ in range(1, batch_size + 1):
        temp.append(W_jSingle)
    W_j = tf.stack(temp)

    UHatj_ = tf.reshape(tf.matmul(U_, W_j), (batch_size, num_of_caps_ops, digi_caps_size))

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        print('---------------- _->{} ----------------\n'.format(i))
        print('shape of U_: {}'.format(sess.run(U_).shape))
        print('shape of W_jSingle: {}'.format(sess.run(W_jSingle).shape))
        print('shape of W_j: {}'.format(sess.run(W_j).shape))
        print('shape of UHatj_: {}'.format(sess.run(UHatj_).shape), '\n')
        print('\nmatmul output: \n{}\n'.format(sess.run(UHatj_)))
        UHat.append(sess.run(UHatj_))

---------------- _->1 ----------------

shape of U_: (2, 2, 1, 3)
shape of W_jSingle: (2, 3, 2)
shape of W_j: (2, 2, 3, 2)
shape of UHatj_: (2, 2, 2) 


matmul output: 
[[[  3.   3.]
  [ 12.  12.]]

 [[ 21.  21.]
  [ 66.  66.]]]

---------------- _->2 ----------------

shape of U_: (2, 2, 1, 3)
shape of W_jSingle: (2, 3, 2)
shape of W_j: (2, 2, 3, 2)
shape of UHatj_: (2, 2, 2) 


matmul output: 
[[[   9.    9.]
  [  30.   30.]]

 [[  63.   63.]
  [ 165.  165.]]]



In [50]:
with tf.Session() as sess:
    print(sess.run(tf.stack(UHat, axis=1)))

[[[[   3.    3.]
   [  12.   12.]]

  [[   9.    9.]
   [  30.   30.]]]


 [[[  21.   21.]
   [  66.   66.]]

  [[  63.   63.]
   [ 165.  165.]]]]


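Ok, so this way of doing things looks like it's exhibiting the desired behaviour. In particular, it should handle batches of any size, as long as `batch_size` is known when the graph is built (it is used both to tile `W_jSingle` and in the final reshape).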
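As a cross-check, and as a possible simplification, the same $\widehat{u}_{j \vert i}$ values can be computed in NumPy with a single `np.einsum` call once all the `W_j` are stacked into one array. The batch axis is simply broadcast, so nothing has to be tiled. This is a NumPy sketch, not the notebook's TensorFlow code; the stacked array `W` and its axis order are choices made for this illustration.

```python
import numpy as np

# Toy inputs from the cell above: U ~ (batch_size, num_of_caps_ops, caps_op_size)
# (the singleton "1" axis is dropped because einsum does not need it)
U = np.array([[[1, 1, 1], [2, 2, 2]],
              [[7, 7, 7], [11, 11, 11]]], dtype=np.float32)

# All W_j stacked into one array:
# W ~ (num_of_digi_caps, num_of_caps_ops, caps_op_size, digi_caps_size)
W = np.stack([
    np.stack([np.full((3, 2), 1.0), np.full((3, 2), 2.0)]),  # j = 1
    np.stack([np.full((3, 2), 3.0), np.full((3, 2), 5.0)]),  # j = 2
]).astype(np.float32)

# u_hat[b, j, i, :] = U[b, i, :] @ W[j, i, :, :] for every batch item b
u_hat = np.einsum('bid,jide->bjie', U, W)

print(u_hat.shape)  # (2, 2, 2, 2) ~ (batch, digi caps j, primary caps i, digi_caps_size)
print(u_hat[..., 0])
```

The result has the same layout as `tf.stack(UHat, axis=1)` in the cell above, so the printed values should match that output entry for entry.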

Next up: routing.

## Routing With Batches
Here we work out the routing procedure in the presence of batches.

We collect all of the $\widehat{u}_{j \vert *}$ into one tensor, stacking along `axis = 1`, so that the result has shape `(batch_size, num_of_digi_caps, num_of_caps_ops, digi_caps_size)`.

In [None]:
UHat = tf.stack(UHat, axis=1)

To perform the routing we'll need a modified softmax function which can take the softmax along a specified axis. Note that `tf.nn.softmax` does not give the desired behavior in this specific instance, since by default it performs the softmax over the last (right-most) axis.

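Also, notice that the softmax is unchanged when the same constant is subtracted from every input:

$$\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_k e^{x_k}} = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_k e^{x_k}} = \operatorname{softmax}(x - c)_i, \quad c \in \mathbb{R}$$

Thus taking $c$ to be the max along the specified axis means that the largest exponent we ever evaluate is 0 and every other exponent is non-positive, so the exponentials cannot overflow. This is done purely with numerical stability in mind.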
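A quick NumPy illustration of why the shift matters: with large logits the naive formula overflows, while the shifted version returns the same answer it would give for small logits. The function names `naive_softmax` and `shifted_softmax` are just labels for this sketch.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum()

def shifted_softmax(x):
    # subtracting the max leaves the result unchanged but keeps every exponent <= 0
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([1000.0, 1001.0])
with np.errstate(over='ignore', invalid='ignore'):
    print(naive_softmax(x))    # [nan nan]: exp(1000) overflows to inf, and inf / inf is nan
print(shifted_softmax(x))      # same as softmax([0, 1]), approximately [0.2689 0.7311]

# for moderate inputs both agree, as the identity softmax(x) = softmax(x - c) says
y = np.array([1.0, 2.0, 3.0])
print(np.allclose(naive_softmax(y), shifted_softmax(y)))  # True
```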

In [51]:
def softmax(inbound, axis=0, name=None):
    with tf.name_scope(name, 'softmax', [inbound]) as scope:
        # first we find the max entry along the specified axis
        max_along_axis= tf.reduce_max(inbound, axis, keep_dims=True)
        
        # then subtract the max from all entries to 
        # help with numerical stability
        exp = tf.exp(inbound - max_along_axis)                       
        
        # next we compute the term used for normalization
        normalizing_term = tf.reduce_sum(exp, axis, keep_dims=True)  
        
        # lastly, compute the softmax (along the specified axis)
        softmax = exp / normalizing_term               
        
        return tf.identity(softmax, name=scope)

In the example below we would like to take softmaxes over the "columns" `[1, 1]`, `[3, 5]`, `[7, 13]`, and `[2, 11]`, instead of over the "rows" `[1, 3]`, `[1, 5]`, `[7, 2]`, `[13, 11]` as would be done if we used `tf.nn.softmax`. 

Of course we could do this with `tf.transpose` and subsequently use the built-in `tf.nn.softmax`, but this would involve transposing, softmax-ing, and then transposing back (given the way we've set things up so far), so a small extra function for customized softmaxes seems reasonable.
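The equivalence between the two options can be checked in NumPy: a softmax taken directly along `axis=1` matches "move the axis to the end, softmax over the last axis, move it back". Both helpers below are hypothetical NumPy mirrors of the idea, not part of the notebook's TensorFlow code; `xx` is the same test tensor used in the cell that follows.

```python
import numpy as np

def softmax_along(x, axis):
    """Stable softmax along `axis` (NumPy mirror of the custom TF softmax above)."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_via_transpose(x, axis):
    """Move `axis` to the end, softmax over the last axis, then move it back."""
    moved = np.moveaxis(x, axis, -1)
    return np.moveaxis(softmax_along(moved, -1), -1, axis)

# test tensor of shape (2, 2, 1, 2)
xx = np.array([[[[1, 3]], [[1, 5]]],
               [[[7, 2]], [[13, 11]]]], dtype=np.float64)

direct = softmax_along(xx, axis=1)
print(np.allclose(direct, softmax_via_transpose(xx, axis=1)))  # True
print(direct.sum(axis=1))  # every "column" sums to 1
```

Both routes give the same numbers; the direct version just skips the two extra transposes.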

In [87]:
xx = tf.reshape(tf.constant([[[[1, 3]], 
                              [[1, 5]]], 
                             [[[7, 2]], 
                              [[13, 11]]]]), shape=(2, 2, 1, 2))
xx = tf.cast(xx, tf.float32)
with tf.Session() as sess:
    print('test tensor: \n')
    print(sess.run(xx))
    print('\ncustom softmax: (the desired behaviour) \n')
    print(sess.run(softmax(xx, axis=1)))
    print('\nbuilt-in softmax: \n')
    print(sess.run(tf.nn.softmax(xx)))

test tensor: 

[[[[  1.   3.]]

  [[  1.   5.]]]


 [[[  7.   2.]]

  [[ 13.  11.]]]]

custom softmax: (the desired behaviour) 

[[[[  5.00000000e-01   1.19202919e-01]]

  [[  5.00000000e-01   8.80797029e-01]]]


 [[[  2.47262325e-03   1.23394580e-04]]

  [[  9.97527421e-01   9.99876618e-01]]]]

built-in softmax: 

[[[[ 0.11920291  0.88079703]]

  [[ 0.01798621  0.98201376]]]


 [[[ 0.99330717  0.00669285]]

  [[ 0.88079703  0.11920291]]]]


With the softmax out of the way, we can move on to the actual routing algorithm.
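For reference, here is a minimal NumPy sketch of routing by agreement as described in the CapsNet paper (Sabour, Frosst and Hinton, "Dynamic Routing Between Capsules", 2017). It assumes the stacked `UHat` layout `(batch, num_of_digi_caps, num_of_caps_ops, digi_caps_size)`, so the coupling coefficients are a softmax along `axis=1`, exactly the kind of axis-wise softmax built above. This is an illustration of the technique, not the notebook's eventual TensorFlow implementation; `squash`, `softmax_along` and `route` are names chosen for this sketch.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-9):
    """Squashing non-linearity: v = (|s|^2 / (1 + |s|^2)) * (s / |s|)."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

def softmax_along(x, axis):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def route(u_hat, n_iters=3):
    """Routing by agreement.

    u_hat ~ (batch, n_digi_caps j, n_primary_caps i, digi_caps_size).
    Returns v ~ (batch, n_digi_caps, digi_caps_size) and the couplings c
    used in the last iteration, c ~ (batch, n_digi_caps, n_primary_caps, 1).
    """
    batch, n_j, n_i, _ = u_hat.shape
    b = np.zeros((batch, n_j, n_i, 1))   # routing logits b_ij, one set per example
    for _ in range(n_iters):
        c = softmax_along(b, axis=1)     # for each i, the c_ij sum to 1 over j
        s = np.sum(c * u_hat, axis=2)    # s_j = sum_i c_ij * u_hat_{j|i}
        v = squash(s)                    # v_j, with length strictly below 1
        # agreement update: b_ij += u_hat_{j|i} . v_j
        b = b + np.sum(u_hat * v[:, :, None, :], axis=-1, keepdims=True)
    return v, c

rng = np.random.default_rng(0)
u_hat = rng.normal(size=(2, 10, 1152, 16))
v, c = route(u_hat, n_iters=3)
print(v.shape, c.shape)  # (2, 10, 16) (2, 10, 1152, 1)
```

Note that the logits `b` start at zero for every example, so with a single iteration each primary capsule spreads its output uniformly over the 10 DigiCaps.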

---
## Building the Capsule Network
Now we start building the Capsule Network. We begin with a function that shortens the call to `tf.layers.conv2d`.

In [2]:
def conv_layer(inputs, filters, kernel_size=9, strides=1, padding='valid', activation=None):
    return tf.layers.conv2d(inputs=inputs, 
                            filters=filters, 
                            kernel_size=kernel_size, 
                            strides=strides, 
                            padding=padding, 
                            activation=activation, 
                            kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False))

In [None]:
BATCH_SIZE = 64
LEARNING_RATE = 0.001
ROUTING_ITERS = 3 

In [42]:
ones = tf.reshape(tf.ones([784]), [-1, 28, 28, 1])
out = conv_layer(ones, 256, activation=tf.nn.relu)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output = sess.run(out)
    data_shape = output.shape
    assert output.shape == (1, 20, 20, 256), 'Incorrect shape, {}, after regular conv layer. Should be {}'.format(data_shape, [1, 20, 20, 256])
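The 20x20 in the assert above (and the 6x6 grid used for PrimaryCaps below) follows from the output size of an unpadded (`'valid'`) convolution, $\lfloor (n - k)/s \rfloor + 1$ for input size $n$, kernel size $k$ and stride $s$. A small plain-Python check of that arithmetic; `valid_conv_out` is a hypothetical helper, not part of the notebook.

```python
def valid_conv_out(size, kernel=9, stride=1):
    """Spatial output size of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1

after_conv1 = valid_conv_out(28, kernel=9, stride=1)             # 28 -> 20
after_primary = valid_conv_out(after_conv1, kernel=9, stride=2)  # 20 -> 6
n_primary_capsules = after_primary * after_primary * 32          # 6*6*32

print(after_conv1, after_primary, n_primary_capsules)  # 20 6 1152
```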

In [2]:
def caps_net_model_fn(features, labels, mode):
    
    # image layer
    assert features['x'].get_shape()[1:] == [784] 
    input_layer = tf.reshape(features['x'], shape=(-1, 28, 28, 1))
    
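    # save the batch size for later (<= 64); it must be known statically
    # when the graph is built, since it is used in the shape asserts and tf.tile
    batch_size = input_layer.get_shape()[0].value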
    n_primary_caps = 32
    primary_caps_size = 8
    digi_caps_size = 16
    
    # conv layer (regular convolutional layer)
    with tf.variable_scope('regular_conv_layer'):
        conv1 = conv_layer(input_layer, 256, activation=tf.nn.relu)
        data_shape = conv1.get_shape()
        e1 = 'Incorrect shape, {}, after regular conv layer. Should be {}'.format(data_shape, [batch_size, 20, 20, 256])
        assert data_shape == [batch_size, 20, 20, 256], e1
    
    # first capsule layer (PrimaryCaps)
    capsules = []
    for i in range(n_primary_caps):
        # naming convention: capsule_[capsule layer]_[capsule index]
        with tf.variable_scope('capsule_1_' + str(i + 1)):
            caps_i = conv_layer(conv1, primary_caps_size, strides=2)
            reshape = tf.reshape(caps_i, shape=(-1, 6*6, primary_caps_size))
            capsules.append(reshape)
    assert capsules[0].get_shape() == [batch_size, 6*6, primary_caps_size]
    
    # stack and reshape
    #
    # here we reshape the capsule outputs (i.e. the 8D vectors)
    # to be 1x8 matrices, this will enable us to use tf.matmul
    # to calculate the inputs (u_hat_ij's) to each DigiCaps
    # capsule in one shot
    capsules = tf.stack(capsules, axis=1)
    capsules = tf.reshape(capsules, shape=(-1, 6*6*n_primary_caps, 1, primary_caps_size))
    assert capsules.get_shape() == [batch_size, 6*6*n_primary_caps, 1, primary_caps_size]
    
    # second capsule layer (DigiCaps)
    u_hat = []
    for j in range(10):
        with tf.variable_scope('->capsule_{}'.format(j)):
            name = 'W_i{}'.format(j)
            weights_to_j = tf.get_variable(name=name, shape=(1, 6*6*n_primary_caps, primary_caps_size, digi_caps_size))
            weights_to_j = tf.tile(weights_to_j, (batch_size, 1, 1, 1))
            u_hat_ji = tf.matmul(capsules, weights_to_j, )
            
    
    
            
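The "stack and reshape" step in `caps_net_model_fn` flattens 32 capsule maps of 6x6 positions into 1152 rows. NumPy and `tf.reshape` both use row-major order, so row `r` of the flattened tensor comes from capsule `r // 36` at grid position `r % 36`. The sketch below tracks that ordering with NumPy stand-in values (an assumption for illustration; the real entries are conv outputs). Since every row gets its own `W_ij`, any fixed ordering works, as long as it never changes between steps.

```python
import numpy as np

batch, n_caps, grid, caps_size = 2, 32, 6 * 6, 8

# Stand-in for the list of 32 primary-capsule outputs, each (batch, 36, 8).
# Each entry encodes (capsule index c, grid position p) as 1000 * c + p.
capsules = [np.full((batch, grid, caps_size), 1000 * c, dtype=np.int64)
            + np.arange(grid)[None, :, None]
            for c in range(n_caps)]

stacked = np.stack(capsules, axis=1)                     # (batch, 32, 36, 8)
flat = stacked.reshape(-1, grid * n_caps, 1, caps_size)  # (batch, 1152, 1, 8)

# Row r of `flat` should come from capsule c = r // 36 at grid position p = r % 36.
r = 100
c, p = divmod(r, grid)
print(flat.shape, flat[0, r, 0, 0], 1000 * c + p)  # (2, 1152, 1, 8) 2028 2028
```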