## Load Libraries and Required Data

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from ./MNIST_data (the cell output below shows the four .gz
# archives being extracted). one_hot=True returns the labels as one-hot vectors.
# NOTE(review): tensorflow.examples.tutorials was removed in TensorFlow 2.x,
# so this notebook requires TF 1.x (it also relies on tf.Session / tf.contrib).
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


 - More routing iterations: Does more routing help classification? (The routing operation is promising; can we benefit from it more?)
 
 - Varying the number of capsules: Does the number (32) or size (8) of the PrimaryCaps capsules help classification? (This network is computationally very heavy.)
 
 - More initial convolutional layers: Do smaller filters (say, 3x3) combined with more standard conv layers help classification? (This network is computationally very heavy.)
 
 - No primary capsules: Can we get away with just convolutional layers, routing, DigiCaps and reconstruction?
 

---
## Matrix Multiplications Before Routing
Here we're working out the logic for the large number of matrix multiplications that need to happen before the routing procedure.

In [70]:
tf.reset_default_graph()

# --------------------------------------------

batch_size = 2      # typically: 64
num_of_caps_ops = 2 # typically: 32*6*6
caps_op_size = 3    # typically: 8
digi_caps_size = 2  # typically: 16

# --------------------------------------------

U = np.array([     # Notice U ~ (2, 2, 1, 3) ~ (batch_size, num_of_caps_ops, 1, caps_op_size)
              [
               [[1, 1, 1]], 
               [[2, 2, 2]]
              ],
              [
               [[3, 3, 3]], 
               [[4, 4, 4]]
              ]
             ])
W_j_single = np.array([   # Notice W_j ~ (2, 3, 2) ~ (num_of_caps_ops, caps_op_size, digi_caps_size)
                       [[1, 1],
                        [1, 1],
                        [1, 1]],
                       [[2, 2],
                        [2, 2],
                        [2, 2]]
                      ])

# Hand-computed result of U @ W_j_single for the toy values above:
# row-of-ones/twos/threes/fours times a matrix of ones/twos.
desired = np.array([[[3., 3.],[12., 12.]],[[9., 9.],[24., 24.]]])

# U = np.random.normal(size=(batch_size, num_of_caps_ops, 1, caps_op_size))
# W_j_single = np.random.normal(size=(num_of_caps_ops, caps_op_size, digi_caps_size))

# Wrap U in a graph variable and cast the *variable* to float32, exactly as is
# done for W_jSingle below (casting the numpy array U here would leave the
# variable U_ created on the previous line unused).
U_ = tf.get_variable('U_', initializer=tf.constant(U))
U_ = tf.cast(U_, tf.float32)

# --------------------------------------------

# basically U_ will be given to us as the output of the primary caps layer,
# from there we repeat the procedure below 10 times, for
# j in {1, 2, ..., 10}, namely we should have (distinct) tensors:
# W_1Single, W_2Single, ..., W_10Single.
W_jSingle = tf.get_variable('W_jSingle', initializer=tf.constant(W_j_single))
W_jSingle = tf.cast(W_jSingle, tf.float32)

# Reuse the same W_jSingle for every item (primary caps output) in the batch:
# add a leading batch axis and tile it batch_size times, giving
# W_j ~ (batch_size, num_of_caps_ops, caps_op_size, digi_caps_size).
W_j = tf.tile(tf.expand_dims(W_jSingle, 0), [batch_size, 1, 1, 1])

# Batched matmul: (b, n, 1, caps_op_size) @ (b, n, caps_op_size, digi_caps_size)
# -> (b, n, 1, digi_caps_size); drop the singleton row axis.
UHatj_ = tf.squeeze(tf.matmul(U_, W_j), axis=2)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Evaluate everything in a single run instead of one sess.run per print.
    U_val, W_jSingle_val, W_j_val, UHatj_val = sess.run([U_, W_jSingle, W_j, UHatj_])
    print('shape of U_: {}'.format(U_val.shape))
    print('shape of W_jSingle: {}'.format(W_jSingle_val.shape))
    print('shape of W_j: {}'.format(W_j_val.shape))
    print('shape of UHatj_: {}'.format(UHatj_val.shape), '\n')
    print('desired output: \n{0}'.format(desired))
    print('\nmatmul output: \n{}'.format(UHatj_val))

# Fail loudly (instead of relying on visual comparison) if the batched matmul
# ever diverges from the hand-computed result.
np.testing.assert_allclose(UHatj_val, desired)

shape of U_: (2, 2, 1, 3)
shape of W_jSingle: (2, 3, 2)
shape of W_j: (2, 2, 3, 2)
shape of UHatj_: (2, 2, 2) 

desired output: 
[[[  3.   3.]
  [ 12.  12.]]

 [[  9.   9.]
  [ 24.  24.]]]

matmul output: 
[[[  3.   3.]
  [ 12.  12.]]

 [[  9.   9.]
  [ 24.  24.]]]


OK, this approach exhibits the desired behaviour: the matmul output matches the hand-computed desired output. It should also handle batches of any size, as long as `batch_size` matches the leading dimension of `U`.

Next up: routing.

## Routing With Batches
Here we work out the routing procedure in the presence of batches.

Notice how `tf.slice` behaves: a `size` of `-1` along an axis takes all remaining elements of that axis, starting from `begin`:

In [91]:
tf.reset_default_graph()

# M has shape (3, 2, 2): three stacked 2x2 slabs.
M = tf.constant([[[1, 2], [3, 4]],
                 [[5, 6], [7, 8]],
                 [[9, 10], [11, 12]]])

begin = [1, 0, 0]                    # start at the second slab
s1 = tf.slice(M, begin, [1, 2, 2])   # exactly one slab
s2 = tf.slice(M, begin, [-1, 2, 2])  # -1 => every slab from `begin` to the end

with tf.Session() as sess:
    one_slab, remaining_slabs = sess.run([s1, s2])
    print(one_slab)
    print(remaining_slabs)

[[[5 6]
  [7 8]]]
[[[ 5  6]
  [ 7  8]]

 [[ 9 10]
  [11 12]]]


---
## Building the Capsule Network
Now we start building the Capsule Network. We begin with a helper function that shortens calls to `tf.layers.conv2d`.

In [2]:
def conv_layer(inputs, filters, kernel_size=9, strides=1, padding='valid', activation=None):
    """Thin wrapper around ``tf.layers.conv2d`` using a Xavier/Glorot kernel initializer.

    Args:
        inputs: tensor fed straight to ``tf.layers.conv2d`` (this notebook passes
            NHWC tensors such as (-1, 28, 28, 1)).
        filters: number of output channels.
        kernel_size: kernel size forwarded to ``tf.layers.conv2d`` (default 9).
        strides: convolution stride (default 1).
        padding: padding mode forwarded to ``tf.layers.conv2d`` (default 'valid').
        activation: optional activation function; ``None`` leaves the output linear.

    Returns:
        The output tensor of the convolution layer.
    """
    # uniform=False selects the normal (not uniform) variant of Xavier init.
    xavier_normal = tf.contrib.layers.xavier_initializer(uniform=False)
    layer_kwargs = dict(filters=filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        padding=padding,
                        activation=activation,
                        kernel_initializer=xavier_normal)
    return tf.layers.conv2d(inputs=inputs, **layer_kwargs)

In [None]:
# Training hyperparameters.
BATCH_SIZE = 64        # images per training batch
LEARNING_RATE = 1e-3   # optimizer learning rate
ROUTING_ITERS = 3      # number of routing iterations between capsule layers

In [42]:
# Sanity check: a 9x9 'valid' conv with stride 1 maps a 28x28 image to
# 28 - 9 + 1 = 20, so one image should come out as (1, 20, 20, 256).
#
# Reset the graph first so re-running this cell doesn't keep adding new conv
# layers (conv2d, conv2d_1, ...) that the global initializer would then touch.
tf.reset_default_graph()

expected_shape = (1, 20, 20, 256)

ones = tf.ones([1, 28, 28, 1])  # a single all-ones 28x28 "image" in NHWC
out = conv_layer(ones, 256, activation=tf.nn.relu)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output = sess.run(out)

data_shape = output.shape
assert data_shape == expected_shape, 'Incorrect shape, {}, after regular conv layer. Should be {}'.format(data_shape, list(expected_shape))

In [2]:
def caps_net_model_fn(features, labels, mode):
    
    # image layer
    assert features['x'].get_shape()[1:] == [784] 
    input_layer = tf.reshape(features['x'], shape=(-1, 28, 28, 1))
    
    # save the batch size for later (<= 64)
    batch_size = input_layer.get_shape()[0]
    n_primary_caps = 32
    primary_caps_size = 8
    digi_caps_size = 16
    
    # conv layer (regular convolutional layer)
    with tf.variable_scope('regular_conv_layer'):
        conv1 = conv_layer(input_layer, 256, activation=tf.nn.relu)
        data_shape = conv1.get_shape()
        e1 = 'Incorrect shape, {}, after regular conv layer. Should be {}'.format(data_shape, [batch_size, 20, 20, 256])
        assert data_shape == [batch_size, 20, 20, 256], e1
    
    # first capsule layer (PrimaryCaps)
    capsules = []
    for i in range(32):
        # naming convention: capsule_[capsule layer]_[capsule index]
        with tf.variable_scope('capsule_1_' + str(i + 1)):
            caps_i = conv_layer(conv1, primary_caps_size, strides=2)
            reshape = tf.reshape(caps_i, shape=(-1, 6*6, primary_caps_size))
            capsules.append(reshape)
    assert capsules[0].get_shape() == [batch_size, 6*6, primary_caps_size]
    
    # stack and reshape
    #
    # here we reshape the capsule outputs (i.e. the 8D vectors)
    # to be 1x8 matrices, this will enable us to use tf.matmul
    # to calculate the inputs (u_hat_ij's) to each DigiCaps
    # capsule in one shot
    capsules = tf.stack(capsules, axis=1)
    capsules = tf.reshape(capsules, shape=(-1, 6*6*n_primary_caps, 1, primary_caps_size))
    assert capsules.get_shape() == [batch_size, 6*6*n_primary_caps, 1, primary_caps_size]
    
    # second capsule layer (DigiCaps)
    u_hat = []
    for j in range(10):
        with tf.variable_scope('->capsule_{}'.format(j)):
            name = 'W_i{}'.format(j)
            weights_to_j = tf.get_variable(name=name, shape=(1, 6*6*n_primary_caps, primary_caps_size, digi_caps_size))
            weights_to_j = tf.tile(weights_to_j, (batch_size, 1, 1, 1))
            u_hat_ji = tf.matmul(capsules, weights_to_j, )
            
    
    
            