# 3. Third Step: CapsLayer Architecture

### 1) License to Huadong Liao

In [None]:
"""
License: Apache-2.0
Author: Huadong Liao
E-mail: naturomics.liao@gmail.com

[ I changed very little. The base code is Naturomics' code. ]
"""

### 2) Import Modules

In [None]:
import numpy as np
import tensorflow as tf
from configurations import cfg

### 3) Set Epsilon

In [None]:
# Small positive constant added to the squared norm before the square root in
# squash(), so the denominator is never exactly zero for an all-zero vector.
epsilon = 1e-9

### 4) CapsLayer (MAIN)

In [None]:
class CapsLayer(object):
    '''
    Capsule layer: either the convolutional PrimaryCaps layer or the fully
    connected DigitCaps layer that uses dynamic routing.

    < ARGUMENTS >
        num_outputs: the number of capsule in this layer.
        vec_len: integer, the length of the output vector of a capsule.
        with_routing: boolean, this capsule is routing with the
                      lower-level layer capsule.
        layer_type: string, one of 'FC' or "CONV", the type of this layer,
            fully connected or convolution, for the future expansion capability
    < SUPPORTED CONFIGURATIONS >
        layer_type='CONV' with with_routing=False  (PrimaryCaps)
        layer_type='FC'   with with_routing=True   (DigitCaps)
        Any other combination raises when the layer is called.
    < OUTPUTS >
        Calling the layer on a 4-D tensor returns a 4-D tensor.
    '''

    def __init__(self, num_outputs, vec_len, with_routing=True, layer_type='FC'):
        self.num_outputs = num_outputs
        self.vec_len = vec_len
        self.with_routing = with_routing
        self.layer_type = layer_type

    def __call__(self, input, kernel_size=None, stride=None):
        '''
        Build the layer on top of 'input' and return the capsule tensor.

        The parameters 'kernel_size' and 'stride' will be used while 'layer_type' equal 'CONV'

        < RAISES >
            NotImplementedError: 'CONV' with routing, or 'FC' without routing.
            ValueError: 'layer_type' is neither 'CONV' nor 'FC'.
        '''
        if self.layer_type == 'CONV':
            self.kernel_size = kernel_size
            self.stride = stride

            if self.with_routing:
                # Previously this combination fell through and returned None.
                raise NotImplementedError(
                    "CapsLayer: layer_type='CONV' with with_routing=True is not implemented")
            return self._primary_caps(input)

        if self.layer_type == 'FC':
            if not self.with_routing:
                # Previously this combination failed on an unbound local name.
                raise NotImplementedError(
                    "CapsLayer: layer_type='FC' with with_routing=False is not implemented")
            return self._digit_caps(input)

        raise ValueError(
            "CapsLayer: unknown layer_type %r, expected 'CONV' or 'FC'" % (self.layer_type,))

    def _primary_caps(self, input):
        '''PrimaryCaps: one convolution reshaped into squashed capsule vectors.'''
        # input: [batch_size, 20, 20, 256]
        assert input.get_shape() == [cfg.batch_size, 20, 20, 256]

        capsules = tf.contrib.layers.conv2d(input, self.num_outputs * self.vec_len,
                                            self.kernel_size, self.stride, padding="VALID",
                                            activation_fn=tf.nn.relu)

        # Flatten the spatial grid and channel groups into a list of capsules.
        capsules = tf.reshape(capsules, (cfg.batch_size, -1, self.vec_len, 1))

        # Shape: [batch_size, 1152, 8, 1]
        capsules = squash(capsules)

        assert capsules.get_shape() == [cfg.batch_size, 1152, 8, 1]

        return capsules

    def _digit_caps(self, input):
        '''DigitCaps: fully connected capsules computed by routing().'''
        # Reshape the input shape into [batch_size, 1152, 1, 8, 1]
        self.input = tf.reshape(input, shape=(cfg.batch_size, -1, 1, input.shape[-2].value, 1))

        with tf.variable_scope('routing'):
            # b_IJ: [batch_size, num_caps_l, num_caps_l_plus_1, 1, 1],
            #       about the reason of using 'batch_size', see issue #21
            b_IJ = tf.constant(np.zeros([cfg.batch_size, input.shape[1].value, self.num_outputs, 1, 1], dtype=np.float32))
            # NOTE: self.vec_len is not forwarded here; routing() chooses the
            # length of the output vectors itself.
            capsules = routing(self.input, b_IJ)
            # Drop the singleton axis left by the sum over lower-level capsules.
            capsules = tf.squeeze(capsules, axis=1)

        return capsules


### 5) Routing

In [None]:
def routing(input, b_IJ, len_v_j=16):
    ''' The routing algorithm.
    < ARGUMENTS >
        input: A Tensor with [batch_size, num_caps_l, 1, length(u_i), 1]
               shape (e.g. [batch_size, 1152, 1, 8, 1]), num_caps_l meaning
               the number of capsule in the layer l.
        b_IJ: A Tensor with [batch_size, num_caps_l, num_caps_l_plus_1, 1, 1]
              shape holding the initial routing logits.
        len_v_j: integer, the length of each output vector `v_j`
                 (default 16).

    < OUTPUTS >
        A Tensor of shape [batch_size, 1, num_caps_l_plus_1, len_v_j, 1]
        representing the vector output `v_j` in the layer l+1.

    < RAISES >
        ValueError: a required dimension of 'input' or 'b_IJ' is not
                    statically known, or cfg.iter_routing is less than 1.

    Notes:
        u_i represents the vector output of capsule i in the layer l, and
        v_j the vector output of capsule j in the layer l+1.
        The capsule counts and the input vector length are read from the
        static shapes of 'input' and 'b_IJ'.
    '''
    num_caps_i = input.shape[1].value
    len_u_i = input.shape[3].value
    num_caps_j = b_IJ.shape[2].value
    if None in (num_caps_i, len_u_i, num_caps_j):
        raise ValueError('routing requires statically known capsule dimensions, got '
                         'num_caps_i=%r, len_u_i=%r, num_caps_j=%r'
                         % (num_caps_i, len_u_i, num_caps_j))
    if cfg.iter_routing < 1:
        # Without at least one iteration there is no v_J to return.
        raise ValueError('cfg.iter_routing must be >= 1, got %r' % (cfg.iter_routing,))

    # W: [1, num_caps_i, num_caps_j, len_u_i, len_v_j]
    with tf.name_scope('weight'):
        W = tf.get_variable('weight', shape=(1, num_caps_i, num_caps_j, len_u_i, len_v_j),
                            dtype=tf.float32,
                            initializer=tf.contrib.layers.xavier_initializer())

    # Calculate u_hat: tile input and W so that they can be multiplied.
    #   input ==> [batch_size, num_caps_i, num_caps_j, len_u_i, 1]
    #   W     ==> [batch_size, num_caps_i, num_caps_j, len_u_i, len_v_j]
    input = tf.tile(input, [1, 1, num_caps_j, 1, 1])
    W = tf.tile(W, [cfg.batch_size, 1, 1, 1, 1])

    assert input.get_shape() == [cfg.batch_size, num_caps_i, num_caps_j, len_u_i, 1]

    # In the last 2 dimensions:
    #   [len_u_i, len_v_j].T x [len_u_i, 1] => [len_v_j, 1]
    #   => [batch_size, num_caps_i, num_caps_j, len_v_j, 1]
    # Author's timing note (3 iter, 1080ti, 128 batch size): a tf.scan
    # variant took 10min/epoch, this tf.tile variant 6min/epoch.
    u_hat = tf.matmul(W, input, transpose_a=True)

    assert u_hat.get_shape() == [cfg.batch_size, num_caps_i, num_caps_j, len_v_j, 1]

    # u_hat_stopped = u_hat in the forward pass; in backward, no gradient is
    # passed back from u_hat_stopped to u_hat.
    u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

    last_iter = cfg.iter_routing - 1
    for r_iter in range(cfg.iter_routing):
        with tf.variable_scope('iter_' + str(r_iter)):
            is_last = r_iter == last_iter

            c_IJ = tf.nn.softmax(b_IJ, dim=2)

            # Only the last iteration uses the differentiable u_hat; the
            # inner iterations do not apply backpropagation.
            predictions = u_hat if is_last else u_hat_stopped

            # Weight the predictions with c_IJ (element-wise, broadcast over
            # the last two dims), then sum over the lower-level capsules:
            #   => [batch_size, 1, num_caps_j, len_v_j, 1]
            s_J = tf.multiply(c_IJ, predictions)
            s_J = tf.reduce_sum(s_J, axis=1, keep_dims=True)

            assert s_J.get_shape() == [cfg.batch_size, 1, num_caps_j, len_v_j, 1]

            # Squashing
            v_J = squash(s_J)

            assert v_J.get_shape() == [cfg.batch_size, 1, num_caps_j, len_v_j, 1]

            if not is_last:
                # Tile v_J from [batch_size, 1, num_caps_j, len_v_j, 1] to
                # [batch_size, num_caps_i, num_caps_j, len_v_j, 1], then matmul
                # in the last two dims: [len_v_j, 1].T x [len_v_j, 1] => [1, 1],
                # giving the agreement [batch_size, num_caps_i, num_caps_j, 1, 1].
                v_J_tiled = tf.tile(v_J, [1, num_caps_i, 1, 1, 1])
                u_produce_v = tf.matmul(u_hat_stopped, v_J_tiled, transpose_a=True)

                assert u_produce_v.get_shape() == [cfg.batch_size, num_caps_i, num_caps_j, 1, 1]

                b_IJ += u_produce_v

    return v_J


### 6) Squashing

In [None]:
def squash(vector):
    '''Apply the squashing non-linearity (Eq. 1) along the 'vec_len' axis.
    < ARGUMENTS >
        vector: A tensor whose second-to-last axis is 'vec_len', i.e. of shape
                [batch_size, 1, num_caps, vec_len, 1] or [batch_size, num_caps, vec_len, 1].
    < OUTPUTS >
        A tensor shaped like 'vector', with every capsule vector rescaled by
        |s|^2 / (1 + |s|^2) / sqrt(|s|^2 + epsilon).
    '''
    # Squared length of every capsule vector; the reduced axis is kept so the
    # result broadcasts against 'vector'.
    sq_norm = tf.reduce_sum(tf.square(vector), axis=-2, keep_dims=True)

    # Shrink factor in [0, 1): short vectors go towards 0, long ones towards 1.
    shrink = sq_norm / (1 + sq_norm)

    # Divide by the (epsilon-stabilised) length to normalise the direction.
    scale = shrink / tf.sqrt(sq_norm + epsilon)

    return scale * vector  # element-wise
