<a href="https://colab.research.google.com/github/MoGomaa/DCaps/blob/main/ConvCapsLayer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from keras import initializers, layers
import keras.backend as K

In [None]:
# Display the installed TensorFlow version (the recorded cell output follows below).
tf.__version__

'2.4.1'

In [None]:
class ConvCapsuleLayer(layers.Layer):
  """Convolutional capsule layer with dynamic routing.

  Every input capsule type is convolved with one shared kernel to produce
  "votes" for each output capsule; the votes are then combined by iterative
  routing-by-agreement.

  Input shape:  [batch, input_height, input_width, input_num_capsule, input_capsule_dim]
  Output shape: [batch, conv_height, conv_width, num_capsule, capsule_dim]

  Args:
    num_capsule: number of output capsule types.
    capsule_dim: length of each output capsule vector.
    routings: number of routing iterations (must be >= 1).
    kernel_size: side length of the square convolution kernel.
    strides: stride applied along both spatial axes.
    padding: padding mode forwarded to `K.conv2d` ('same' or 'valid').
    kernel_initializer: initializer (name or instance) for the kernel.
    epsilon: small constant that keeps the squash normalisation finite.
    **kwargs: forwarded to `layers.Layer`.

  Raises:
    ValueError: if `routings` is smaller than 1.
  """

  def __init__(self, num_capsule, capsule_dim, routings=3,
               kernel_size=5, strides=1, padding='same', kernel_initializer='he_normal',
               epsilon=1e-7, **kwargs):
    super(ConvCapsuleLayer, self).__init__(**kwargs)
    # With zero iterations `call` would have no output capsule to return.
    if routings < 1:
      raise ValueError("routings must be >= 1, got {}".format(routings))
    self.num_capsule        = num_capsule
    self.capsule_dim        = capsule_dim
    self.routings           = routings
    self.kernel_size        = kernel_size
    self.strides            = strides
    self.padding            = padding
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.epsilon            = epsilon

  def build(self, input_shape):
    """Create the shared convolution kernel and the per-output-capsule bias.

    Raises:
      ValueError: if the input is not a 5D tensor.
    """
    if len(input_shape) != 5:
      raise ValueError(
          "input_shape=[None, input_height, input_width, input_num_capsule, "
          "input_capsule_dim] expected, got {}".format(input_shape))

    self.input_height      = input_shape[1]
    self.input_width       = input_shape[2]
    self.input_num_capsule = input_shape[3]
    self.input_capsule_dim = input_shape[4]

    # One kernel is shared by all input capsule types; its output channels
    # hold the num_capsule * capsule_dim vote components.
    self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size, self.input_capsule_dim, self.num_capsule*self.capsule_dim],
                             initializer=self.kernel_initializer,
                             name='W')
    self.bias = self.add_weight(shape=[1, 1, self.num_capsule, self.capsule_dim],
                                initializer=initializers.constant(0.1),
                                name='b')
    self.built = True

  def call(self, input_tensor, training=None):
    """Compute the output capsules for `input_tensor`.

    Args:
      input_tensor: 5D tensor
        [batch, input_height, input_width, input_num_capsule, input_capsule_dim].
      training: unused; accepted for Keras API compatibility.

    Returns:
      5D tensor [batch, conv_height, conv_width, num_capsule, capsule_dim].
    """
    batch = K.shape(input_tensor)[0]  # dynamic, the static batch size may be None
    _, input_height, input_width, input_num_capsule, input_capsule_dim = input_tensor.shape

    # Fold the input-capsule axis into the batch axis so that a single 2D
    # convolution handles every input capsule type with the shared kernel.
    input_transposed = tf.transpose(input_tensor, [0, 3, 1, 2, 4])
    input_reshaped   = K.reshape(input_transposed, [batch*input_num_capsule, input_height, input_width, input_capsule_dim])
    input_reshaped.set_shape([None, input_height, input_width, input_capsule_dim])

    conv = K.conv2d(input_reshaped, self.W, (self.strides, self.strides), padding=self.padding, data_format='channels_last')
    _, conv_height, conv_width, _ = conv.shape

    # Reshape back to 6D by splitting the first dimension into batch and input
    # capsules, and the last dimension into output capsules and capsule_dim.
    votes = K.reshape(conv, [batch, input_num_capsule, conv_height, conv_width, self.num_capsule, self.capsule_dim])
    votes.set_shape([None, input_num_capsule, conv_height, conv_width, self.num_capsule, self.capsule_dim])

    with tf.name_scope("DynamicRouting"):
      # Routing logits, one per (input capsule, position, output capsule);
      # derived from `votes` so that shape and dtype always match.
      logits = tf.zeros_like(votes[..., :1])
      for i in range(self.routings):
        # Coupling coefficients: softmax over the output-capsule axis.
        c = tf.nn.softmax(logits, axis=-2)
        # Weighted sum of the votes over the input-capsule axis.
        s = tf.reduce_sum(tf.multiply(c, votes), axis=1, keepdims=True) + self.bias
        v = self.__squash(s)
        # The agreement update is only needed if another iteration follows.
        if i < self.routings - 1:
          logits += tf.reduce_sum(tf.multiply(votes, v), axis=-1, keepdims=True)
    # Drop only the summed-out input-capsule axis; an unrestricted squeeze
    # would also remove a batch/spatial/capsule dimension of size 1.
    return tf.squeeze(v, axis=1)

  def __squash(self, s):
    """Squash vectors along the last axis to a length in [0, 1).

    The norm is computed as sqrt(sum(s^2) + epsilon) rather than with
    `tf.norm`, so the result and its gradient stay finite for zero vectors.
    """
    squared_norm = tf.reduce_sum(tf.square(s), axis=-1, keepdims=True)
    safe_norm = tf.sqrt(squared_norm + self.epsilon)
    return squared_norm / (1 + squared_norm) * s / safe_norm

  def _conv_output_length(self, input_length):
    """Return the convolved size of one spatial axis (None stays None)."""
    if input_length is None:
      return None
    if self.padding == 'valid':
      input_length = input_length - self.kernel_size + 1
    # Ceiling division by the stride.
    return (input_length + self.strides - 1) // self.strides

  def compute_output_shape(self, input_shape):
    """Return the output shape, accounting for kernel size, strides and padding."""
    shape = tf.TensorShape(input_shape).as_list()
    spatial = (self._conv_output_length(shape[1]), self._conv_output_length(shape[2]))
    return (shape[0],) + spatial + (self.num_capsule, self.capsule_dim)

  def get_config(self):
    """Return the constructor arguments needed to re-create this layer."""
    config = {'num_capsule'        : self.num_capsule,
              'capsule_dim'        : self.capsule_dim,
              'routings'           : self.routings,
              'kernel_size'        : self.kernel_size,
              'strides'            : self.strides,
              'padding'            : self.padding,
              'kernel_initializer' : initializers.serialize(self.kernel_initializer),
              "epsilon"            : self.epsilon}
    base_config = super(ConvCapsuleLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))