# Keras implementation of CapsNet
##### `Dynamic Routing Between Capsules` by Sabour, Frosst and Hinton, in NIPS 2017

### PLAN:
* implement reshaping of the 3D tensor output by a regular Conv2D into 32 6x6x8 capsules (or preferably 32 36x8 capsules, i.e. a single (32x36)x8 tensor) **DONE**
* implement a capsule layer with forward pass = dynamic routing
    - implement squash function **DONE**
    - implement weight sharing between capsules in each of 32 groups **LOOKS LIKE DONE?** **NEEDS IMPROVEMENT**
    - implement dynamic routing **SEEMS DONE, BUT TWICE THERE SEEMS TO BE UNNECESSARY DIMS EXPANSION AND COPYING OF TENSORS. TRY FIXING THIS**
* implement margin loss for digit existence
* implement reconstruction loss with masking that depends on the learning phase
* put it all together into a network
* train on MNIST with small data augmentation as described in the paper, use `.fit_generator()`
* reproduce paper results
* visualize some reconstructions
* visualize how reconstructions change with a continuous change of one of the DigitCaps dimensions

### 1. Reshape

In [1]:
import numpy as np

In [2]:
# Feature-map shape fed to the reshape: height x width x channels.
input_shape = (6, 6, 256)
# The channels are cut into vectors of length 8, so the map holds
# 6 * 6 * 256 / 8 = 1152 capsules of dimension 8.
output_shape = (6 * 6 * 256 // 8, 8)

In [3]:
# Uniform random values in [0, 1) standing in for one convolutional feature map.
dummy_X = np.random.random(size=input_shape)

In [4]:
from keras.models import Model
from keras.layers import Reshape, Permute
from keras import Input

Using TensorFlow backend.


In [5]:
# NumPy reference for the capsule reshape: split the 256 channels into
# 32 groups of 8, flatten the 6x6 grid into 36 positions, then read the
# result out in column-major order so that the rows come out grouped by
# channel group (row = group * 36 + position).
reshape_intermediate = np.reshape(dummy_X, (6, 6, 32, 8))
reshape_intermediate_2 = np.reshape(reshape_intermediate, (36, 32, 8))
reshaped_3 = np.reshape(reshape_intermediate_2, (1152, 8), order='F')

In [6]:
# Keras pipeline for the capsule reshape:
# (6, 6, 256) -> (36, 32, 8) -> swap the first two axes -> (1152, 8).

input_tensor = Input(shape=input_shape)
reshaped = Reshape(target_shape=(36, 32, 8))(input_tensor)
permuted = Permute(dims=(2, 1, 3))(reshaped)
output_tensor = Reshape(target_shape=output_shape)(permuted)

model = Model(inputs=input_tensor, outputs=output_tensor)
# predict() works on batches: wrap the single sample, then unwrap the result.
output = model.predict(np.reshape(dummy_X, (1, 6, 6, 256)))[0]

In [7]:
# Second-to-last row of the array returned by the Keras reshape model.
output[-2]

array([ 0.23075144,  0.43397957,  0.06308483,  0.90659434,  0.177563  ,
        0.55920148,  0.64529717,  0.37560156], dtype=float32)

In [8]:
# Last 8 channels of the raw input at index (5, 4), for comparison with output[-2].
dummy_X[5, 4, -8:]

array([ 0.23075145,  0.43397957,  0.06308483,  0.90659436,  0.17756299,
        0.55920148,  0.64529719,  0.37560156])

## 2. CapsNet Layer

In [9]:
from keras import backend as K
from keras.engine.topology import Layer

In [296]:
class CapsuleLayer(Layer):
    """Capsule layer whose forward pass is dynamic routing.

    Takes input capsules of shape `(batch, input_capsules, input_dim)` and
    returns output capsules of shape `(batch, output_capsules, capsule_dim)`.

    # Arguments
        output_capsules: number of capsules produced by the layer.
        capsule_dim: length of each output capsule vector.
        num_groups: optional number of weight-sharing groups. When set, the
            layer only holds `num_groups` sets of transformation matrices
            and each set is reused for `input_capsules // num_groups`
            consecutive input capsules. When falsy, every input capsule
            gets its own matrices.
        routing_iterations: number of routing iterations (at least 1).

    # Raises
        ValueError: if `routing_iterations` is smaller than 1, or (in
            `build`) if `num_groups` does not divide the number of input
            capsules.
    """

    def __init__(self, output_capsules, capsule_dim, num_groups=None, routing_iterations=3, **kwargs):
        if routing_iterations < 1:
            # With zero iterations there would be no output to return.
            raise ValueError('routing_iterations should be at least 1')
        self.output_capsules = output_capsules
        self.capsule_dim = capsule_dim
        self.num_groups = num_groups
        self.routing_iterations = routing_iterations
        super(CapsuleLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the transformation kernel.

        The kernel has shape
        `(n, output_capsules, input_dim, capsule_dim)` where `n` is
        `num_groups` when weight sharing is enabled and the number of
        input capsules otherwise.
        """
        if self.num_groups:
            if input_shape[1] % self.num_groups:
                raise ValueError('num_groups should divide input_shape[1] without remainder')
            leading_dim = self.num_groups
        else:
            leading_dim = input_shape[1]

        self.kernel = self.add_weight(
            name='kernel',
            shape=(
                leading_dim,
                self.output_capsules,
                input_shape[-1],
                self.capsule_dim,
            ),
            initializer='uniform',
            trainable=True
        )

        super(CapsuleLayer, self).build(input_shape)

    def call(self, x):
        """Compute the output capsules for a batch of input capsules."""
        input_shape = K.shape(x)

        kernel = self.kernel
        if self.num_groups:
            # Expand the shared kernel to one entry per input capsule: each
            # group's matrices are repeated for a run of consecutive input
            # capsules.
            # NOTE(review): this assumes the input capsules are ordered
            # group by group along axis 1 - confirm against the caller.
            group_size = K.int_shape(x)[1] // self.num_groups
            kernel = K.repeat_elements(kernel, rep=group_size, axis=0)

        # Routing logits b_ij, one per (input capsule, output capsule) pair.
        B = K.zeros(shape=(input_shape[0], input_shape[1], self.output_capsules))

        # Replicate every input capsule once per output capsule and apply
        # the transformation matrices sample by sample.
        x = K.expand_dims(x, axis=2)
        x = K.repeat_elements(x, rep=self.output_capsules, axis=2)
        U = K.map_fn(lambda sample: K.batch_dot(sample, kernel, axes=[2, 2]), x)

        for i in range(self.routing_iterations):
            V, B = self._routing_single_iter(B, U, i, input_shape)

        return V

    def compute_output_shape(self, input_shape):
        """Return `(batch, output_capsules, capsule_dim)`."""
        return (input_shape[0], self.output_capsules, self.capsule_dim)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'output_capsules': self.output_capsules,
            'capsule_dim': self.capsule_dim,
            'num_groups': self.num_groups,
            'routing_iterations': self.routing_iterations,
        }
        base_config = super(CapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _routing_single_iter(self, B, U, i, input_shape):
        """Run one routing iteration.

        # Arguments
            B: routing logits, `(batch, input_capsules, output_capsules)`.
            U: transformed input capsules,
                `(batch, input_capsules, output_capsules, capsule_dim)`.
            i: zero-based index of the current iteration.
            input_shape: dynamic shape of the layer input. No longer needed
                (broadcasting replaces the explicit tiling) but kept so the
                signature stays unchanged.

        # Returns
            Tuple `(V, B)`: the squashed output capsules and the logits to
            use in the next iteration.
        """
        # Softmax over the output-capsule axis; the trailing axis of size 1
        # broadcasts against capsule_dim in the product below.
        C = K.expand_dims(K.softmax(B, axis=-1), axis=-1)
        S = K.sum(C * U, axis=1)
        V = self._squash(S)
        # The logits computed after the last iteration are never read, so
        # skip that update.
        if i < self.routing_iterations - 1:
            # V gains an axis of size 1 that broadcasts over input capsules.
            B = B + K.sum(U * K.expand_dims(V, axis=1), axis=-1)
        return V, B

    @staticmethod
    def _squash(x):
        """Rescale vectors along the last axis to length n^2 / (1 + n^2),
        where n is the original length; the direction is kept."""
        squared_norm = K.sum(K.square(x), axis=-1, keepdims=True)
        # epsilon guards against division by zero for all-zero vectors.
        unit = x / (K.sqrt(squared_norm) + K.epsilon())
        return squared_norm / (1 + squared_norm) * unit


In [268]:
# Scratch check of the routing tensor shapes on constant inputs:
# 12 samples, 1152 input vectors of length 8, 10 outputs of length 16.
B = K.zeros(shape=(12, 1152, 10))
y_batch = K.ones(shape=(1152, 10, 8, 16))

# Replicate every input vector 10 times along a new axis, then apply the
# matrices in y_batch sample by sample.
x_batch = K.repeat_elements(
    K.expand_dims(K.ones(shape=(12, 1152, 8)), axis=2),
    rep=10,
    axis=2
)
xy_batch_dot = K.map_fn(
    lambda sample: K.batch_dot(sample, y_batch, axes=[2, 2]),
    x_batch
)

# Softmax of the logits, copied 16 times along a new trailing axis.
C = K.map_fn(lambda logits: K.softmax(logits, axis=-1), B)
C = K.repeat_elements(K.expand_dims(C, axis=-1), rep=16, axis=-1)

print(K.int_shape(C))
print(K.int_shape(xy_batch_dot))

# Weighted sum over axis 1, then tiled back across all 1152 entries.
S = K.tile(
    K.expand_dims(K.sum(C * xy_batch_dot, axis=1), axis=1),
    [1, 1152, 1, 1]
)

print(K.int_shape(S))
new_B = K.sum(xy_batch_dot * S, axis=-1)
print(K.int_shape(new_B))

(12, 1152, 10, 16)
(12, 1152, 10, 16)
(12, 1152, 10, 16)
(12, 1152, 10)


In [297]:
# End-to-end smoke test: a random batch of 32 feature maps is reshaped
# into capsules and passed through a CapsuleLayer.
dummy_X = np.random.random(size=(32, 6, 6, 256))

input_tensor = Input(shape=input_shape)
reshaped = Reshape(target_shape=(36, 32, 8))(input_tensor)
permuted = Permute(dims=(2, 1, 3))(reshaped)
primary_caps = Reshape(target_shape=output_shape)(permuted)
digit_caps = CapsuleLayer(output_capsules=10, capsule_dim=16)(primary_caps)

model = Model(inputs=input_tensor, outputs=digit_caps)
output = model.predict(dummy_X)

In [277]:
# Shape of the array returned by model.predict.
output.shape

(32, 10, 16)

In [12]:
# Create a layer and build its weights directly for an input of shape (1, 1152, 8).
c = CapsuleLayer(output_capsules=10, capsule_dim=16)
c.build(input_shape=(1, 1152, 8))

In [190]:
# Scratch check of the per-sample batch_dot output shape on constant inputs.
y_batch = K.ones(shape=(1152, 10, 8, 16))
# Replicate every input vector 10 times along a new axis 2.
x_batch = K.repeat_elements(
    K.expand_dims(K.ones(shape=(12, 1152, 8)), axis=2),
    rep=10,
    axis=2
)
xy_batch_dot = K.map_fn(
    lambda sample: K.batch_dot(sample, y_batch, axes=[2, 2]),
    x_batch
)
K.int_shape(xy_batch_dot)

(12, 1152, 10, 16)