Sample capsule neural network architecture from [Sabour, Frosst, Hinton 2017](https://arxiv.org/pdf/1710.09829.pdf)
<img src="https://cdn-images-1.medium.com/max/1500/1*AjRyyzttFIoMRb73Jzycog.png"></img>
[image source](https://pechyonkin.me/capsules-1/)

In this notebook we'll do the following
* build a capsule network (CapsNet) with dynamic routing
* compare its performance in a 'small data' domain against a traditional convolutional neural network (CNN)
* visualize its weights to understand its learned 'representation'
* visualize how a given sample gets transformed by the capsule network
* visualize the learned latent representation of the digit capsules ('digitcaps') 

Be sure to change your [Google colab](https://colab.research.google.com) runtime type by clicking `Runtime`, `Change runtime type`, and then selecting `GPU`. We'll try our luck again to see if Google lets us all get free GPUs.

This notebook's Capsule Network implementations were adapted from the following sources
* https://github.com/XifengGuo/CapsNet-Keras
* https://keras.io/examples/cifar10_cnn_capsule/ (note that this implementation does not exactly match the original paper)

<br><br><br>
Python / notebook cheatsheet
* `Shift+Enter` to execute a given cell and move to / make a new next cell
* `#` is for single line comments
* each cell in this notebook can be either of type `code` (the cell below) or type `markdown` (this cell)
* indentation (tabs or spaces) is required for the body of any flow control statement (`for`, `if..else`, `while` etc.)
* `=` is for assignment
* `()` is to call a function or method
* `[]` is for indexing into a variable
* to put parentheses or brackets around something after the fact, just select that something and type `(` or `[`
* get help for any function by typing `?` after a given function (and execute the cell), or pressing `Shift Tab` when your line cursor is on the function

In [None]:
# keras-specific modules we need
from keras import backend as K
from keras import activations, initializers # `initializers` is used by CapsuleLayer below
from keras.callbacks import LambdaCallback
from keras.datasets import mnist, fashion_mnist, cifar10
from keras.initializers import RandomUniform, Zeros, Ones, glorot_uniform
from keras.layers import *
from keras.models import Sequential, Model
from keras.utils import to_categorical
import tensorflow as tf

# more packages / functions that we'll need
from IPython.core.pylabtools import figsize
from IPython.display import HTML
from matplotlib.animation import ArtistAnimation
import math
import matplotlib.pyplot as plt
import numpy as np

### Let's first load and visualize our sample data

In [None]:
# built-in keras datasets: https://keras.io/datasets/
# load a dataset of choice
# recall each of these has 10 classes
(x_train,y_train),(x_test,y_test) = mnist.load_data()
#(x_train,y_train),(x_test,y_test) = fashion_mnist.load_data()
#(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# shuffle the training dataset
indices = np.arange(0,x_train.shape[0])
np.random.shuffle(indices)
x_train = x_train[indices,:,:]
y_train = y_train[indices]

# one-hot encode labels
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# add a dummy 4th dim for channel
if len(x_train.shape)<4: 
    x_train = np.reshape(x_train,(x_train.shape[0],x_train.shape[1],x_train.shape[2],1))
    x_test = np.reshape(x_test,(x_test.shape[0],x_test.shape[1],x_test.shape[2],1))    
    
# rescale and normalize
x_train = (x_train/255.-0.5)*2
x_test = (x_test/255.-0.5)*2
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

In [None]:
# check data dimensions: n samples, by n pix, by n pix, by n channels
x_train.shape

In [None]:
# visualize sample images from input dataset
figsize(5,5)
for a in range(0,36):
    plt.subplot(6,6,a+1)
    plt.imshow(np.squeeze(x_train[a,:,:,:]),cmap='gray')
    plt.axis('off')

We're interested in comparing the performance of Capsule networks and convolutional neural networks in 'small' data domains, so let's pull out a small, class-balanced subset of the training (and test) data.

In [None]:
# get balanced subset of training and test data (to simulate small data scenario)
nSamp = 10
nClass = y_train.shape[1]
x_train_subset = np.zeros((nSamp*nClass,x_train.shape[1],x_train.shape[2],x_train.shape[3]),dtype='float32')
y_train_subset = np.zeros((nSamp*nClass,nClass),dtype='float32')
x_test_subset = np.zeros((nSamp*nClass,x_train.shape[1],x_train.shape[2],x_train.shape[3]),dtype='float32')
y_test_subset = np.zeros((nSamp*nClass,nClass),dtype='float32')
for a in range(nClass):
    CurrClassIdxs = np.nonzero(y_train[:,a])
    CurrClassIdxs = CurrClassIdxs[0][:nSamp] # take first nSamp examples
    CurrClassIdxs_test = np.nonzero(y_test[:,a])
    CurrClassIdxs_test = CurrClassIdxs_test[0][:nSamp] # take first nSamp examples    
    x_train_subset[nSamp*a:nSamp*a+nSamp,:,:,:] = x_train[CurrClassIdxs,:,:,:]
    y_train_subset[nSamp*a:nSamp*a+nSamp,a] = y_train[CurrClassIdxs,a]
    x_test_subset[nSamp*a:nSamp*a+nSamp,:,:,:] = x_test[CurrClassIdxs_test,:,:,:]
    y_test_subset[nSamp*a:nSamp*a+nSamp,a] = y_test[CurrClassIdxs_test,a]
    
# should all equal nSamp
np.sum(y_train_subset,axis=0)

In [None]:
batch_size = 128
num_classes = y_train.shape[1] # =10 classes
epochs = 100

### Hand-code support functions for our Capsule Network

Recall that Capsule Networks are similar to normal neural networks in that they 
1. take some input, 
2. transform it by multiplying it by some learned parameters (weights) and 
3. pass the resulting value(s) through another function, most often a nonlinear 'activation' function

`Input ---> Input*Weights ---> f(Input*Weights)`

where `f` could be anything from, e.g., a linear function `mx+b` (rarely used as a hidden-layer activation) to `max(0,x)` (ReLU) to `1/(1+e^-x)` (logistic function)
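To make the `Input ---> Input*Weights ---> f(Input*Weights)` pipeline concrete, here is a minimal NumPy sketch of one fully connected step followed by each of the three example activation functions. The input vector and weights are made up purely for illustration; nothing here is trained.

```python
import numpy as np

# toy input (3 values) and toy weights (3 inputs -> 2 outputs); values are made up
x_in = np.array([1.0, -2.0, 0.5])
W_toy = np.array([[0.2, -0.1],
                  [-0.4, 0.3],
                  [-0.5, 0.8]])
z = x_in @ W_toy  # Input*Weights -> [0.75, -0.3]

def linear(z, m=1.0, b=0.0):
    return m * z + b               # a line: mx+b

def relu(z):
    return np.maximum(0.0, z)      # max(0, x)

def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))  # 1/(1+e^-x), squashes into (0, 1)

for f in (linear, relu, logistic):
    print(f.__name__, f(z))        # f(Input*Weights)
```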

Recall also that in order to train our neural network, we need to additionally 
4. define a performance metric, i.e. a `loss` function, with which we can compute how well our network has performed and backpropagate an error signal to update our network parameters

And so we need to define for our Capsule network the following
* a capsule 'layer' (to satisfy #2 above)
* the capsule 'squashing' activation function (to satisfy #3 above)
* the capsule `loss` function (to satisfy #4 above)

In [None]:
'''Loss function'''
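# the margin loss (Sabour et al. 2017), a squared hinge-style loss on capsule lengths:
# present classes are penalized below m+ = 1 - margin = 0.9,
# absent classes are penalized above m- = margin = 0.1 (down-weighted by lamb = 0.5)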
def margin_loss(y_true, y_pred):
    lamb, margin = 0.5, 0.1
    return K.sum(y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
        1 - y_true) * K.square(K.relu(y_pred - margin)), axis=-1)

'''Activation function'''
# the squashing function.
# Sabour et al. use v = ||s||^2/(1+||s||^2) * s/||s||; here the 1 is replaced by 0.5,
# so the output length is ||s||^2/(0.5+||s||^2).
# The output length still lies in [0, 1), but it approaches 1 faster than with 1
# (e.g. an input of length 1 gives an output of length 0.67 instead of 0.5).
def squash(x, axis=-1):
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
    return scale * x

# define our own softmax function instead of K.softmax
# because K.softmax in older Keras versions cannot specify an axis.
def softmax(x, axis=-1):
    ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return ex / K.sum(ex, axis=axis, keepdims=True)
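As a sanity check on the Keras functions above, the following NumPy sketch re-implements them outside the graph so you can inspect actual numbers. `squash_np`, `squash_paper`, and `margin_loss_np` are hypothetical helper names introduced here; `squash_paper` assumes the original squash from Sabour et al. (denominator `1 + ||s||^2`), for comparison with the 0.5 variant used in this notebook.

```python
import numpy as np

def squash_np(s, axis=-1, eps=1e-7):
    # the 0.5 variant used above: output length = ||s||^2 / (0.5 + ||s||^2)
    sq = np.sum(np.square(s), axis=axis, keepdims=True) + eps
    return np.sqrt(sq) / (0.5 + sq) * s

def squash_paper(s, axis=-1, eps=1e-7):
    # Sabour et al.: v = ||s||^2 / (1 + ||s||^2) * s / ||s||
    sq = np.sum(np.square(s), axis=axis, keepdims=True) + eps
    return sq / (1.0 + sq) * s / np.sqrt(sq)

def margin_loss_np(y_true, y_pred, m_plus=0.9, m_minus=0.1, lamb=0.5):
    # same formula as `margin_loss` above, on plain arrays of capsule lengths
    present = y_true * np.square(np.maximum(0.0, m_plus - y_pred))
    absent = lamb * (1 - y_true) * np.square(np.maximum(0.0, y_pred - m_minus))
    return np.sum(present + absent, axis=-1)

s = np.array([[3.0, 4.0]])  # a single 2-D capsule vector of length 5
for f in (squash_np, squash_paper):
    print(f.__name__, np.linalg.norm(f(s)))  # ~25/25.5 vs ~25/26, both < 1

y_true = np.array([[0.0, 1.0, 0.0]])   # class 1 is present
y_pred = np.array([[0.2, 0.95, 0.05]]) # capsule lengths
print(margin_loss_np(y_true, y_pred))  # only class 0 (length 0.2 > 0.1) contributes
```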

In [None]:
'''Capsule layer'''
class CapsuleLayer(Layer):
    """
    The capsule layer. It is similar to a Dense layer. A Dense layer has `in_num` inputs, each a scalar (the output of a
    neuron in the previous layer), and `out_num` output neurons. CapsuleLayer simply expands each neuron's output
    from a scalar to a vector. So its input shape = [None, input_num_capsule, input_dim_capsule] and its output shape =
    [None, num_capsule, dim_capsule]. For a Dense layer, input_dim_capsule = dim_capsule = 1.
    
    :param num_capsule: number of capsules in this layer
    :param dim_capsule: dimension of the output vectors of the capsules in this layer
    :param routings: number of iterations for the routing algorithm
    """
    def __init__(self, num_capsule, dim_capsule, routings=3,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]

        # Transform matrix
        self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                        self.dim_capsule, self.input_dim_capsule],
                                 initializer=self.kernel_initializer,
                                 name='W')

        self.built = True

    def call(self, inputs, training=None):
        # inputs.shape=[None, input_num_capsule, input_dim_capsule]
        # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
        inputs_expand = K.expand_dims(inputs, 1)

        # Replicate num_capsule dimension to prepare being multiplied by W
        # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])

        # Compute `inputs * W` by scanning inputs_tiled on dimension 0.
        # x.shape=[num_capsule, input_num_capsule, input_dim_capsule]
        # W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
        # Regard the first two dimensions as `batch` dimension,
        # then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
        # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)

        # Begin: Routing algorithm ---------------------------------------------------------------------#
        # The prior for coupling coefficient, initialized as zeros.
        # b.shape = [None, self.num_capsule, self.input_num_capsule].
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])

        assert self.routings > 0, 'The routings should be > 0.'
        for i in range(self.routings):
            # c.shape=[batch_size, num_capsule, input_num_capsule]
            #c = tf.nn.softmax(b, dim=1)
            c = softmax(b, 1)
            # c.shape =  [batch_size, num_capsule, input_num_capsule]
            # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
            # The first two dimensions as `batch` dimension,
            # then matmul: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
            # outputs.shape=[None, num_capsule, dim_capsule]
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]

            if i < self.routings - 1:
                # outputs.shape =  [None, num_capsule, dim_capsule]
                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions as `batch` dimension,
                # then matmul: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
                # b.shape=[batch_size, num_capsule, input_num_capsule]
                b += K.batch_dot(outputs, inputs_hat, [2, 3])
        # End: Routing algorithm -----------------------------------------------------------------------#

        return outputs

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

    def get_config(self):
        config = {
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings
        }
        base_config = super(CapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
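To see what the routing loop in `call` is doing without the batch dimension and the `K.batch_dot` axis bookkeeping, here is a minimal single-sample NumPy sketch of the same routing-by-agreement procedure. It assumes the same shapes as this notebook (72 eight-dimensional primary capsules, 10 sixteen-dimensional digit capsules) and random weights; `capsule_forward`, `squash_np`, and `softmax_np` are hypothetical helpers, not part of Keras or of Xifeng Guo's code.

```python
import numpy as np

def squash_np(s, axis=-1, eps=1e-7):
    # same 0.5 variant as the Keras `squash` above
    sq = np.sum(s ** 2, axis=axis, keepdims=True) + eps
    return np.sqrt(sq) / (0.5 + sq) * s

def softmax_np(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def capsule_forward(u, W, routings=3):
    """u: (in_caps, in_dim); W: (out_caps, in_caps, out_dim, in_dim) -> v: (out_caps, out_dim)"""
    # prediction vectors u_hat[j, i] = W[j, i] @ u[i]  (inputs_hat in the layer)
    u_hat = np.einsum('jiok,ik->jio', W, u)        # (out_caps, in_caps, out_dim)
    b = np.zeros(u_hat.shape[:2])                  # routing logits (out_caps, in_caps)
    for r in range(routings):
        c = softmax_np(b, axis=0)                  # each input capsule's couplings sum to 1 over outputs
        s = np.einsum('ji,jio->jo', c, u_hat)      # coupling-weighted sum over input capsules
        v = squash_np(s)                           # (out_caps, out_dim)
        if r < routings - 1:
            b = b + np.einsum('jo,jio->ji', v, u_hat)  # agreement: dot(v_j, u_hat_j|i)
    return v

rng = np.random.default_rng(0)
u = squash_np(rng.normal(size=(72, 8)))                 # 72 primary capsules, 8-D
W_demo = rng.normal(scale=0.1, size=(10, 72, 16, 8))    # random transform matrices
v = capsule_forward(u, W_demo)
print(v.shape)                       # (10, 16)
print(np.linalg.norm(v, axis=-1))    # digit capsule lengths, each in [0, 1)
```

Each `np.einsum` string spells out one of the `K.batch_dot` contractions in the layer: `'jiok,ik->jio'` computes the prediction vectors `inputs_hat`, `'ji,jio->jo'` is the coupling-weighted sum, and `'jo,jio->ji'` is the agreement update added to `b`.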

### Define our Capsule Network

Now that we've defined the ingredients for capsule layers, let's build a Capsule Network similar to that in [Sabour, Frosst, Hinton 2017](https://arxiv.org/pdf/1710.09829.pdf)

In [None]:
# adapted from Xifeng Guo's implementation (https://github.com/XifengGuo/CapsNet-Keras)
def makeCapsNN():
    '''
    In contrast to our previous tutorials, we'll use Keras' `Model` API as opposed to its `Sequential` API
    Recall with the `Sequential` API, we used successive `.add()` calls to add layers to our model
    With the (poorly named) `Model` or 'functional' API, you instead call each layer on the output of the
    previous layer, and finally wrap the input and output tensors in a `Model`
    '''
    input_image = Input(shape=(x_train.shape[1],x_train.shape[2],x_train.shape[3]),name='capsNN-input-image')
    x = Conv2D(filters=32, kernel_size=(9, 9), strides=(1,1),activation='relu',name='capsNN-conv1')(input_image)
    x = Conv2D(filters=8*2, kernel_size=(9, 9), strides=(2,2),name='primarycaps')(x) # two (9x9) convolutional capsules with 8 dims
    x = Reshape((-1,8),name='primarycaps-reshape')(x) 
    x = Lambda(squash,name='primarycaps-squash')(x)
    x = Dropout(0.3,name='capsNN-dropout')(x)
    x = CapsuleLayer(num_capsule=num_classes, dim_capsule=16, routings=3,name='digitcaps')(x)
    output = Lambda(lambda z: K.sqrt(K.sum(K.square(z), 2)),name='capsNN-pred-out')(x)
    model = Model(inputs=input_image, outputs=output)

    # we use the margin loss function defined earlier
    model.compile(loss=margin_loss, optimizer='adam', metrics=['accuracy'])

    return model

In [None]:
# callback functions to store weights from our convolutional layers after each epoch
layer1weights_capsNN = [] # empty list to hold our weights
layer2weights_capsNN = []
get_weights1_capsNN = LambdaCallback(on_epoch_end=lambda epoch,logs: layer1weights_capsNN.append(myCapsNN.layers[1].get_weights())) 
get_weights2_capsNN = LambdaCallback(on_epoch_end=lambda epoch,logs: layer2weights_capsNN.append(myCapsNN.layers[2].get_weights()))

myCapsNN = makeCapsNN() # make our model
myCapsNN.summary() # print a summary of model layers, outputs shapes, and trainable parameters

Do the numbers of parameters make sense? 
* `capsNN-conv1`: `filters=32` of `(9,9)` over the single MNIST input channel, so that's `32*9*9*1+32 = 2624` (weights plus one bias per filter)
* `primarycaps`: `filters=8*2` of `(9,9)`, and for each of the input `32` channels from the prev layer, so that's `32*8*2*9*9+16 = 41488`
* Now for the digitcaps, it's a bit more opaque
    * each `primarycaps` is 8-dimensional
    * in aggregate, our `primarycaps` layer outputs a `(6,6,16)` filtered 'image', i.e. `(6,6,2,8)`: 2 capsule types of 8 dimensions at each of the 6x6 positions
    * `primarycaps-reshape` yields shape `(72,8)`, i.e. we 'flatten' our `primarycaps` output into 72 vectors
    * so our two 8-dimensional convolutional capsule types (built from `(9x9)` kernels) end up as `6*6*2 = 72` capsules with 8 dimensions each
    * for each of the 72 capsules, we need to learn a weight matrix, `W` of size `(i,j)` where `i` is the number of `primarycaps` dimensions (8) and `j` is the number of `digitcaps` dimensions (16)
        * and we need to do this for each of the digit classes (10 in total)
    * so the parameters then are: `72*8*16*10 = 92160`
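The shape and parameter arithmetic above can be checked with a few lines of Python. This sketch assumes 28x28 single-channel MNIST inputs and Keras' default `'valid'` padding; `conv_out` and `conv_params` are hypothetical helpers.

```python
def conv_out(size, kernel, stride):
    # 'valid' padding (the Keras Conv2D default): floor((size - kernel) / stride) + 1
    return (size - kernel) // stride + 1

def conv_params(kernel, in_ch, filters):
    return kernel * kernel * in_ch * filters + filters  # weights + one bias per filter

h = 28                                  # MNIST height/width, 1 channel
h1 = conv_out(h, 9, 1)                  # capsNN-conv1 -> 20x20
h2 = conv_out(h1, 9, 2)                 # primarycaps  -> 6x6
n_primary = h2 * h2 * 16 // 8           # (6, 6, 16) reshaped to (-1, 8) -> 72 capsules

p_conv1 = conv_params(9, 1, 32)         # 2624
p_primary = conv_params(9, 32, 16)      # 41488
p_digit = 10 * n_primary * 16 * 8       # W has shape (10, 72, 16, 8), no bias -> 92160
print(h1, h2, n_primary, p_conv1, p_primary, p_digit, p_conv1 + p_primary + p_digit)
# expected: 20 6 72 2624 41488 92160 136272
```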

### Define our convolutional neural network (CNN)
We'll try our best to match the architecture and number of parameters to make a fair comparison

In [None]:
def makeCNN():
    myCNN = Sequential() # instantiate a `Sequential` model
    myCNN.add(Conv2D(32,(9,9),activation='relu', input_shape=(x_train.shape[1],x_train.shape[2],x_train.shape[3]),name='CNN-conv1'))
    myCNN.add(Conv2D(8*2,(9,9),strides=(2,2),activation='relu',name='CNN-conv2'))
    myCNN.add(Flatten(name='CNN-flatten'))
    myCNN.add(Dropout(0.3,name='CNN-dropout'))    
    myCNN.add(Dense(num_classes*16))
    myCNN.add(Dense(num_classes,activation='softmax',name='CNN-pred-out'))
    myCNN.compile(loss='categorical_crossentropy', metrics=['accuracy'],optimizer='adam')
    return myCNN

In [None]:
# callback functions to store weights from our convolutional layers after each epoch
layer1weights_cnn = [] # empty list to hold our weights
layer2weights_cnn = []
get_weights1_cnn = LambdaCallback(on_epoch_end=lambda epoch,logs: layer1weights_cnn.append(myCNN.layers[0].get_weights())) 
get_weights2_cnn = LambdaCallback(on_epoch_end=lambda epoch,logs: layer2weights_cnn.append(myCNN.layers[1].get_weights()))

myCNN = makeCNN() # make our model
myCNN.summary() # print a summary of model layers, outputs shapes, and trainable parameters
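For comparison with the CapsNet counts, here is the same kind of arithmetic for `makeCNN()`, again assuming MNIST inputs; `conv_params` and `dense_params` are hypothetical helpers, and the unnamed 160-unit `Dense` layer is labeled by hand. That layer stands in for `digitcaps`, which is why its parameter count lands close to the `92160` of the capsule transform matrices. Note also that it has no activation, so it and the output layer compose into a single linear map before the softmax.

```python
def conv_params(kernel, in_ch, filters):
    return kernel * kernel * in_ch * filters + filters  # weights + one bias per filter

def dense_params(n_in, n_out):
    return n_in * n_out + n_out                         # weight matrix + biases

flat = 6 * 6 * 16                                # CNN-conv2 output (6, 6, 16) flattened -> 576
cnn = {
    'CNN-conv1': conv_params(9, 1, 32),          # 2624
    'CNN-conv2': conv_params(9, 32, 16),         # 41488
    'dense (160 units)': dense_params(flat, 10 * 16),  # 92320
    'CNN-pred-out': dense_params(10 * 16, 10),   # 1610
}
capsnet_total = 2624 + 41488 + 10 * 72 * 16 * 8  # 136272, from the CapsNet count above
print(cnn)
print('CNN total:', sum(cnn.values()), 'CapsNet total:', capsnet_total)
```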

### Train the Capsule network and the CNN

In [None]:
# First train the capsule network
history_caps = myCapsNN.fit(x_train_subset,y_train_subset,batch_size=batch_size,
                          epochs=epochs*2,validation_data=(x_test_subset, y_test_subset),
                          callbacks=[get_weights1_capsNN,get_weights2_capsNN])

In [None]:
history_cnn = myCNN.fit(x_train_subset,y_train_subset,batch_size=batch_size,
                        epochs=epochs*2,validation_data=(x_test_subset, y_test_subset),
                        callbacks=[get_weights1_cnn,get_weights2_cnn])

Now that we've trained our networks, let's compare their performance over the course of training.

In [None]:
fig,ax = plt.subplots(2,2,figsize=(8,8))
plt.subplots_adjust(hspace=0.4)

ax[0,0].plot(history_caps.history['loss'],label='CapsNet')
ax[0,0].plot(history_cnn.history['loss'],label='CNN')
ax[0,0].set_title('Training data')
ax[0,0].set_xlabel('Training epoch')
ax[0,0].set_ylabel('Loss')
ax[0,0].legend()

ax[1,0].plot(history_caps.history['val_loss'],label='CapsNet')
ax[1,0].plot(history_cnn.history['val_loss'],label='CNN')
ax[1,0].set_title('Test data')
ax[1,0].set_xlabel('Training epoch')
ax[1,0].set_ylabel('Loss')
ax[1,0].legend()

# the accuracy history key is 'acc' in older Keras versions and 'accuracy' in newer ones
acc_key = 'acc' if 'acc' in history_caps.history else 'accuracy'
ax[0,1].plot(history_caps.history[acc_key],label='CapsNet')
ax[0,1].plot(history_cnn.history[acc_key],label='CNN')
ax[0,1].set_title('Training data')
ax[0,1].set_xlabel('Training epoch')
ax[0,1].set_ylabel('Accuracy')
ax[0,1].legend()

ax[1,1].plot(history_caps.history['val_' + acc_key],label='CapsNet')
ax[1,1].plot(history_cnn.history['val_' + acc_key],label='CNN')
ax[1,1].set_title('Test data')
ax[1,1].set_xlabel('Training epoch')
ax[1,1].set_ylabel('Accuracy')
ax[1,1].legend()

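Both networks can fit the training data well, but note how the CNN tends to overfit (its test data loss gradually increases). Keep in mind that the two loss curves are not directly comparable in absolute value, since the CapsNet is trained with the margin loss and the CNN with categorical cross-entropy.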

Note that we validated on only 100 test examples. Let's see how well the models perform across the entire test dataset (10,000 samples).

In [None]:
myCapsNN.evaluate(x_test,y_test)

In [None]:
myCNN.evaluate(x_test,y_test)

We see that the Capsule Network performs better than the CNN in this 'small data' example.

On your own, try changing `nSamp` to see how it affects performance. You can also see how the models perform on the other two datasets (Fashion-MNIST and CIFAR-10).

### Visualize what our networks have learned
Let's first explore the CapsNet

In [None]:
# helper function to plot a movie of our changing weights as a function of training
def VizLayerWeights(layerweight_by_epoch_list,figtitle,subplotshape=(8,4),frameinterval=20):
    '''
    Inputs 
        layerweight_by_epoch_list: list, layer weights, as collected by a lambda callback
        figtitle: string, figure title
        subplotshape: tuple, m by n subplots to plot (which will also dictate figure size)        
        frameinterval: int, milliseconds between movie frames in the animated figure
    
    Outputs
        a jshtml animation of your filters changing with time
    '''
    fig, ax = plt.subplots(subplotshape[1],subplotshape[0],figsize=subplotshape)
    fig.suptitle(figtitle)
    weightplots = []
    for c in range(0,len(layerweight_by_epoch_list)):
        currweights = layerweight_by_epoch_list[c][0] # get weights from current epoch
        currweights = np.reshape(currweights,(currweights.shape[0],currweights.shape[1],-1)) # shove all filters into one dimension
        Counter = 0
        weightsubplots = []
        for a in range(0,subplotshape[1]):
            for b in range(0,subplotshape[0]):
                ax[a,b].axis('off')
                img = ax[a,b].imshow(np.squeeze(currweights[:,:,Counter]))
                Counter += 1
                weightsubplots.append(img) # append to list of subplots
        weightplots.append(weightsubplots) # append to list of figure frames

    anim = ArtistAnimation(fig, weightplots, interval=frameinterval) # animate with prerendered plots in `weightplots`
    plt.close() # close the actual plotted figure
    return HTML(anim.to_jshtml()) # ...and instead display a jshtml animation

In [None]:
# First convolutional layer of CapsNet
VizLayerWeights(layer1weights_capsNN,'CapsNet: 1st convolutional layer filters by epoch')

In [None]:
# Primary capsule layers (convolutional capsules)
VizLayerWeights(layer2weights_capsNN,'CapsNet: Primary (convolutional) capsule filters by epoch (first 32)')

Now let's see what the CNN weights look like

In [None]:
# First convolutional layer of CNN
VizLayerWeights(layer1weights_cnn,'CNN: 1st convolutional layer filters by epoch')

In [None]:
# Second convolutional layer of CNN
VizLayerWeights(layer2weights_cnn,'CNN: 2nd convolutional layer filters by epoch (first 32)')

We see that both the CapsNet and the CNN convolutional layers are learning more or less what we expect: some sort of edge-detecting filters.

### Next, let's visualize the data representation in each layer
For this, we'll output a transformed sample at a few key layers in our networks

Let's see the CapsNet first

In [None]:
# Choose a sample to visualize
ChooseIdx = 8

# note how we reuse our previous model, and thus don't need to train this new model
capsNN_conv1 = Model(inputs=myCapsNN.input,outputs=myCapsNN.get_layer('capsNN-conv1').output)
capsNN_primarycaps = Model(inputs=myCapsNN.input,outputs=myCapsNN.get_layer('primarycaps').output)
capsNN_digitcaps = Model(inputs=myCapsNN.input,outputs=myCapsNN.get_layer('digitcaps').output)
capsNN_predout = Model(inputs=myCapsNN.input,outputs=myCapsNN.get_layer('capsNN-pred-out').output)

# get a prediction from the new model
capsNN_conv1_out = capsNN_conv1.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))
capsNN_primarycaps_out = capsNN_primarycaps.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))
capsNN_digitcaps_out = capsNN_digitcaps.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))
capsNN_predout_out = capsNN_predout.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))

print(capsNN_conv1_out.shape)
print(capsNN_primarycaps_out.shape)
print(capsNN_digitcaps_out.shape)
print(capsNN_predout_out.shape)

In [None]:
# first let's see the example we pushed through the network
figsize(3,3)
plt.imshow(np.squeeze(x_test[ChooseIdx,:,:,:]))
plt.axis('off')

In [None]:
# plot sample transformed 'images' from the first conv layer
figsize(8,4)
plt.suptitle('Data representation at layer: capsNN-conv1')
for a in range(32):
    plt.subplot(4,8,a+1)
    plt.imshow(np.squeeze(capsNN_conv1_out[:,:,:,a]))
    plt.axis('off')

In [None]:
# plot sample transformed 'images' from the primary capsule layers
figsize(8,2)
plt.suptitle('Data representation at layer: capsNN_primarycaps')
for a in range(16):
    plt.subplot(2,8,a+1)
    plt.imshow(np.squeeze(capsNN_primarycaps_out[:,:,:,a]))
    plt.axis('off')

In [None]:
# plot sample transformed 'images' from the digit capsule layer
figsize(6,6)
plt.title('Data representation at layer: capsNN_digitcaps')
plt.imshow(np.squeeze(capsNN_digitcaps_out))
plt.yticks([0,1,2,3,4,5,6,7,8,9])
plt.ylabel('number')
plt.xlabel('capsule embedding dimension')

In [None]:
# plot sample transformed 'images' from the output layer of the CapsNet
figsize(6,6)
plt.title('Data representation at layer: capsNN_predout')
plt.imshow(capsNN_predout_out.T)
plt.yticks([0,1,2,3,4,5,6,7,8,9])
plt.ylabel('number')
plt.xticks([])

Now let's compare against the CNN

In [None]:
# Choose a sample to visualize
ChooseIdx = 8

# note how we reuse our previous model, and thus don't need to train this new model
CNN_conv1 = Model(inputs=myCNN.input,outputs=myCNN.get_layer('CNN-conv1').output)
CNN_conv2 = Model(inputs=myCNN.input,outputs=myCNN.get_layer('CNN-conv2').output)
CNN_predout = Model(inputs=myCNN.input,outputs=myCNN.get_layer('CNN-pred-out').output)

# get a prediction from the new model
CNN_conv1_out = CNN_conv1.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))
CNN_conv2_out = CNN_conv2.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))
CNN_predout_out = CNN_predout.predict(np.reshape(x_test[ChooseIdx,:,:,:],(1,x_train.shape[1],x_train.shape[2],x_train.shape[3])))

print(CNN_conv1_out.shape)
print(CNN_conv2_out.shape)
print(CNN_predout_out.shape)

In [None]:
# first let's see the example we pushed through the network
figsize(3,3)
plt.imshow(np.squeeze(x_test[ChooseIdx,:,:,:]))
plt.axis('off')

In [None]:
# plot sample transformed 'images' from the first CNN convolutional layer
figsize(8,4)
plt.suptitle('Data representation at layer: CNN-conv1')
for a in range(32):
    plt.subplot(4,8,a+1)
    plt.imshow(np.squeeze(CNN_conv1_out[:,:,:,a]))
    plt.axis('off')

In [None]:
# plot sample transformed 'images' from the second CNN convolutional layer
figsize(8,2)
plt.suptitle('Data representation at layer: CNN-conv2')
for a in range(16):
    plt.subplot(2,8,a+1)
    plt.imshow(np.squeeze(CNN_conv2_out[:,:,:,a]))
    plt.axis('off')

As with the learned weights, we see that the CNN represents the data differently than the CapsNet does.

In [None]:
# plot sample transformed 'images' from the CNN output layer
figsize(6,6)
plt.title('Data representation at layer: CNN_predout')
plt.imshow(CNN_predout_out.T)
plt.yticks([0,1,2,3,4,5,6,7,8,9])
plt.ylabel('number')
plt.xticks([])
#plt.axis('off')

### Visualize learned latent representation of `digitcaps`
First we need to train up a decoder network to learn how to decode the `digitcaps` layer representation `(10,16)` (10 digits, 16 'features' for each digit) into a digit image.

We accomplish this by making an MLP (multilayer perceptron) and then attaching it to the end of our CapsNet.

In [None]:
'''Helper function from Xifeng Guo to mask digitcaps that are not the current digit'''
class Mask(Layer):
    """
    Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an additional 
    input mask. Except the max-length capsule (or specified capsule), all vectors are masked to zeros. Then flatten the
    masked Tensor.
    For example:
        ```
        x = Input(shape=[3, 2])  # each sample contains 3 capsules with dim_vector=2
        y = Input(shape=[3])  # true labels: 3 classes, one-hot coding
        out = Mask()(x)  # out.shape=[None, 6]
        # or
        out2 = Mask()([x, y])  # out2.shape=[None, 6]. Masked with true labels y. Of course y can also be manipulated.
        ```
    """
    def call(self, inputs, **kwargs):
        if type(inputs) is list:  # true label is provided with shape = [None, n_classes], i.e. one-hot code.
            assert len(inputs) == 2
            inputs, mask = inputs
        else:  # if no true label, mask by the max length of capsules. Mainly used for prediction
            # compute lengths of capsules
            x = K.sqrt(K.sum(K.square(inputs), -1))
            # generate the mask which is a one-hot code.
            # mask.shape=[None, n_classes]=[None, num_capsule]
            mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])

        # inputs.shape=[None, num_capsule, dim_capsule]
        # mask.shape=[None, num_capsule]
        # masked.shape=[None, num_capsule * dim_capsule]
        masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        if type(input_shape[0]) is tuple:  # true label provided
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:  # no true label provided
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        config = super(Mask, self).get_config()
        return config
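The masking step is easier to see on plain arrays. The following NumPy sketch, with the hypothetical helper `mask_np`, mimics what the `Mask` layer does for a toy batch of one sample with 3 two-dimensional capsules (mirroring the docstring example): without labels it keeps the longest capsule, with labels it keeps the labeled one, and in both cases it flattens the result.

```python
import numpy as np

def mask_np(caps, labels=None):
    """caps: (batch, num_capsule, dim) -> (batch, num_capsule * dim); mimics the Mask layer above."""
    if labels is None:
        # no labels: keep the capsule with the largest length (prediction mode)
        lengths = np.linalg.norm(caps, axis=-1)                       # (batch, num_capsule)
        labels = np.eye(caps.shape[1])[np.argmax(lengths, axis=1)]    # one-hot mask
    masked = caps * labels[:, :, None]      # zero out every other capsule
    return masked.reshape(caps.shape[0], -1)  # flatten, like K.batch_flatten

caps = np.zeros((1, 3, 2))
caps[0, 0] = [0.1, 0.2]
caps[0, 1] = [0.6, 0.5]   # the longest capsule
caps[0, 2] = [0.0, 0.3]
print(mask_np(caps))                                   # [[0.  0.  0.6 0.5 0.  0. ]]
print(mask_np(caps, labels=np.array([[0., 0., 1.]])))  # keeps capsule 2, the "true label"
```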

In [None]:
myCapsNN = makeCapsNN() # re-initialize / reset our model

# Decoder network MLP
y = Input(shape=(num_classes,))
masked_by_y = Mask()([myCapsNN.get_layer('digitcaps').output, y])  # The true label is used to mask the output of capsule layer. For training
decoder_only = Sequential(name='decoder_only')
decoder_only.add(Dense(512, activation='relu',input_dim=16*num_classes,name='capsNN-decode-dense-1'))
decoder_only.add(Dense(1024, activation='relu',name='capsNN-decode-dense-2'))
#decoder_only.add(Dense(np.prod((x_train.shape[1],x_train.shape[2],x_train.shape[3])), activation='sigmoid')) # Xifeng's code used a sigmoid here, which suits images scaled to [0,1]; we rescaled our images to [-1,1] when loading them, so we use a linear output layer instead
decoder_only.add(Dense(np.prod((x_train.shape[1],x_train.shape[2],x_train.shape[3])),name='capsNN-decode-dense-3'))
decoder_only.add(Reshape(target_shape=(x_train.shape[1],x_train.shape[2],x_train.shape[3]), name='capsNN-decode-recon'))

# Make a multi-input, multi-output model (yes, you can do that, and the inputs and outputs can attach to different layers!)
# Inputs: the input image to our CapsNet and the training label
# Outputs: the output of our CapsNet and the output of the `decoder_only` network (fed the digitcaps masked by the training label)
capsnet_decoder = Model([myCapsNN.get_layer('capsNN-input-image').output, y], 
                    [myCapsNN.get_layer('capsNN-pred-out').output, decoder_only(masked_by_y)])

In [None]:
# compile the model
capsnet_decoder.compile(optimizer='adam',loss=[margin_loss, 'mse'],loss_weights=[1., 0.392],metrics=['accuracy'])

# model architecture summary
capsnet_decoder.summary()
decoder_only.summary()

Note how the layers up through `digitcaps` are exactly the same as those of our original CapsNet. This is by design: `capsnet_decoder` reuses those layers (and their weights) directly.

Note how the layers that follow belong to our decoder (plus some intermediate housekeeping layers, such as the masking layer). In `capsnet_decoder.summary()` the decoder appears as a single `decoder_only` entry; its individual layers can be seen when we invoke `decoder_only.summary()`.

Let's now fit the new model. Note that we are training our CapsNet and the decoder network simultaneously. Although we could reuse our already trained CapsNet, it'll be better for the decoder to learn if we retrain everything from scratch (which is why we called `myCapsNN = makeCapsNN()` again).
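Training both parts together means minimizing `margin_loss + 0.392 * mse` (the `loss_weights` passed to `compile` above). Sabour et al. scale the sum of squared reconstruction errors by 0.0005; since Keras' `'mse'` effectively averages over the 784 pixels of an MNIST image instead of summing them, `0.0005 * 784 = 0.392` presumably makes the two weightings equivalent. The NumPy sketch below checks that arithmetic on fake data (the arrays are random, not real images or reconstructions).

```python
import numpy as np

rng = np.random.default_rng(0)
x_fake = rng.uniform(-1, 1, size=(4, 28, 28, 1))      # fake batch of images scaled to [-1, 1]
x_recon_fake = rng.uniform(-1, 1, size=x_fake.shape)  # fake batch of "reconstructions"

n_pix = 28 * 28 * 1
sq_err = (x_fake - x_recon_fake) ** 2
sse = sq_err.sum(axis=(1, 2, 3))   # sum of squared errors per sample (as in Sabour et al.)
mse = sq_err.mean(axis=(1, 2, 3))  # mean over pixels, effectively what Keras' 'mse' gives per sample

print(0.0005 * n_pix)                          # the reconstruction loss weight used in compile()
print(np.allclose(0.0005 * sse, 0.392 * mse))  # True: the two weightings agree
```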

In [None]:
capsnet_decoder.fit([x_train_subset, y_train_subset], [y_train_subset, x_train_subset], batch_size=batch_size, epochs=epochs*4)
#capsnet_decoder.fit([x_train, y_train], [y_train, x_train], batch_size=batch_size, epochs=epochs)

Finally (I know, it's getting more and more complicated!), we make yet another model, `explore_latent`, similar to our `capsnet_decoder`. This time we add an extra input with the same `(10,16)` shape as `digitcaps`; its values are added to the `digitcaps` output, so we can systematically explore the digitcaps latent space by passing in our own perturbations rather than relying only on the values the network computes.
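Here is a minimal NumPy sketch of what `explore_latent` computes for each perturbation, using a random stand-in for the learned `digitcaps` output (an assumption; in the notebook these values come from the trained CapsNet). It also shows why the cell below can safely write the offset `r` into a latent dimension of all 10 capsules: the `Mask` step zeroes every capsule except the labeled one, so only that digit's 16 values reach the decoder.

```python
import numpy as np

n_classes, n_dim = 10, 16
offsets_to_try = [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]
rng = np.random.default_rng(0)
digitcaps = rng.normal(scale=0.1, size=(1, n_classes, n_dim))  # stand-in for the learned digitcaps output
label = np.eye(n_classes)[[5]]                                   # one-hot label for digit 5, shape (1, 10)

decoder_inputs = []
for d in range(n_dim):
    for r in offsets_to_try:
        offset = np.zeros((1, n_classes, n_dim))
        offset[:, :, d] = r                                      # same offset in every capsule
        perturbed = digitcaps + offset                           # the Add() layer
        masked = (perturbed * label[:, :, None]).reshape(1, -1)  # the Mask() layer with the true label
        decoder_inputs.append(masked)

decoder_inputs = np.concatenate(decoder_inputs)                  # (16 * 11, 160)
print(decoder_inputs.shape)
# only digit 5's 16 entries (columns 80..95) are ever non-zero
print(np.count_nonzero(decoder_inputs[:, :80]), np.count_nonzero(decoder_inputs[:, 96:]))
```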

In [None]:
# model to explore learned latent digitcaps space
masked = Mask()(myCapsNN.get_layer('digitcaps').output)  # Mask using the capsule with maximal length. For prediction
sampled_vals = Input(shape=(num_classes, 16))
sampled_vals_digitcaps = Add()([myCapsNN.get_layer('digitcaps').output, sampled_vals])
masked_sampled_vals_y = Mask()([sampled_vals_digitcaps, y])
explore_latent = Model([myCapsNN.get_layer('capsNN-input-image').output, y, sampled_vals], decoder_only(masked_sampled_vals_y))

In [None]:
# sample vals_to_explore for each dimension of a given digit capsule
DigitToExplore = 5
index = np.argmax(y_test, 1) == DigitToExplore
number = np.random.randint(low=0, high=np.sum(index)) # `high` is exclusive, so this can pick any of the matching test images
xx, yy = x_test[index][number], y_test[index][number]
xx, yy = np.expand_dims(xx, 0), np.expand_dims(yy, 0)
vals_to_explore=[-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]
x_recons = []
vals_to_explore_y = np.zeros((1,10,16))
for dim in range(16):
    for r in vals_to_explore:
        tmp = np.copy(vals_to_explore_y)
        tmp[:,:,dim] = r
        x_recon = explore_latent.predict([xx,yy,tmp])
        x_recons.append(x_recon)
        
# plot the resulting explored latent space for a given digit and its capsule
num = len(x_recons)
width = len(vals_to_explore)
height = 16
shape = x_recons[0].shape[1:3]
image = np.zeros((height*shape[0], width*shape[1]),
                 dtype=x_recons[0].dtype)
for index, img in enumerate(x_recons):
    i = int(index/width)
    j = index % width
    image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = np.squeeze(img[0,:,:,0])

# each row is a latent dim, each column is stepping through latent dim space
figsize(8,12)
plt.imshow(image,vmin=-1,vmax=1,cmap='gray')    
plt.axis('off')        