<a href="https://colab.research.google.com/github/MoGomaa/CapsuleNetworks/blob/main/CapsLayer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras import initializers, layers

In [2]:
# Notebook cell expression: evaluates to the TensorFlow version string
# (the recorded output for this run is '2.4.1').
tf.__version__

'2.4.1'

In [3]:
class CapsuleLayer(layers.Layer):
  """Capsule layer operating in one of two modes, selected by ``layer_type``.

  - ``"CapsPrimary"``: reshapes a 4-D feature map
    ``[batch, width, height, channels]`` into capsules of shape
    ``[batch, num_capsule, dim_capsule]``. This mode owns no weights and
    only reshapes its input.
  - ``"CapsFC"`` (``"FC"`` is accepted as an alias and is the default):
    maps ``[batch, input_num_capsule, input_dim_capsule]`` to
    ``[batch, num_capsule, dim_capsule]`` using one transform matrix per
    (input capsule, output capsule) pair followed by iterative routing.

  - param num_capsule       : number of capsules in this layer.                              ( int )
  - param dim_capsule       : dimension of the output vectors of the capsules in this layer. ( int )
  - param routings          : number of iterations for the routing algorithm (CapsFC only).  ( int )
  - param layer_type        : 'CapsFC' (alias 'FC') or 'CapsPrimary'.                        ( string )
  - param epsilon           : small constant keeping the squash normalisation finite.        ( float )
  - param kernel_initializer: initializer for the CapsFC transform matrices.                 ( string / initializer )

  Raises ValueError for an unknown ``layer_type`` or, for a CapsFC layer,
  a ``routings`` value smaller than 1.
  """

  # Accepted spellings of ``layer_type`` mapped to their canonical name.
  # "FC" is the constructor default, so it must resolve to a working mode.
  _LAYER_TYPES = {"FC": "CapsFC", "CapsFC": "CapsFC", "CapsPrimary": "CapsPrimary"}

  def __init__(self, num_capsule, dim_capsule, routings=3, layer_type="FC", epsilon=1e-7, kernel_initializer='glorot_uniform', **kwargs):
    super(CapsuleLayer, self).__init__(**kwargs)

    if layer_type not in self._LAYER_TYPES:
      raise ValueError(
          f"layer_type must be one of {sorted(self._LAYER_TYPES)}, got {layer_type!r}")

    self.num_capsule        = num_capsule
    self.dim_capsule        = dim_capsule
    self.routings           = routings
    self.layer_type         = self._LAYER_TYPES[layer_type]
    self.epsilon            = epsilon
    self.kernel_initializer = initializers.get(kernel_initializer)

    # Routing only happens in CapsFC mode, so only validate it there.
    if self.layer_type == "CapsFC" and self.routings < 1:
      raise ValueError("The number of routings must be greater than 0.")

  def build(self, input_shape):
    """Create the transform matrices (CapsFC mode only)."""
    if self.layer_type == "CapsFC":
      if len(input_shape) != 3:
        raise ValueError(
            "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule], "
            f"got {input_shape}")
      self.input_num_capsule = input_shape[1]
      self.input_dim_capsule = input_shape[2]
      if self.input_num_capsule is None or self.input_dim_capsule is None:
        raise ValueError(
            "input_num_capsule and input_dim_capsule must be statically known, "
            f"got input shape {input_shape}")
      # Transform matrix, from each input capsule to each output capsule, there's a unique weight as in Dense layer.
      # The leading axis of size 1 broadcasts over the batch in ``call``.
      self.W = self.add_weight(shape=[1, self.input_num_capsule, self.num_capsule, self.dim_capsule, self.input_dim_capsule],
                               name='W', initializer=self.kernel_initializer, dtype=tf.float32)
    super(CapsuleLayer, self).build(input_shape)

  def call(self, inputs, training=None):
    """Return capsules of shape (batch_size, num_capsule, dim_capsule)."""
    if self.layer_type == "CapsPrimary":
      if len(inputs.shape) != 4:
        raise ValueError("The input Tensor should have shape=[batch_size, input_width, input_height, input_channels]")
      feature_dims = [inputs.shape[axis] for axis in (1, 2, 3)]
      # The size check needs static dimensions; when one is unknown,
      # tf.reshape below still validates the element count at run time.
      if None not in feature_dims:
        if feature_dims[0] * feature_dims[1] * feature_dims[2] != self.num_capsule * self.dim_capsule:
          raise ValueError("inputs.shape[1]*inputs.shape[2]*inputs.shape[3] != self.num_capsule*self.dim_capsule")
      return tf.reshape(inputs, (-1, self.num_capsule, self.dim_capsule))

    # CapsFC mode.
    with tf.name_scope("CapsuleFormation"):
      u     = tf.expand_dims(tf.expand_dims(inputs, axis=-2), axis=-1)      # u.shape:     (batch_size, input_num_capsule, 1, input_dim_capsule, 1)
      # Squeeze only the trailing matmul axis: an axis-less squeeze would
      # also drop the batch axis when batch_size == 1 (or any other axis of size 1).
      u_hat = tf.squeeze(tf.matmul(self.W, u), axis=-1)                     # u_hat.shape: (batch_size, input_num_capsule, num_capsule, dim_capsule)

    with tf.name_scope("DynamicRouting"):
      # Derive the logits' shape from u_hat so a dynamic (None) batch size works.
      b = tf.zeros_like(u_hat[..., :1])                                     # b.shape: (batch_size, input_num_capsule, num_capsule, 1)
      for i in range(self.routings):
        c = tf.nn.softmax(b, axis=-2)                                       # c.shape: (batch_size, input_num_capsule, num_capsule, 1)
        s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)     # s.shape: (batch_size, 1, num_capsule, dim_capsule)
        v = self.__squash(s)                                                # v.shape: (batch_size, 1, num_capsule, dim_capsule)
        # The agreement update feeds the next iteration only, so skip it on the last one.
        if i < self.routings - 1:
          # Dot product of every prediction vector with its output capsule.
          b += tf.reduce_sum(tf.multiply(u_hat, v), axis=-1, keepdims=True)

    # Drop only the summed-out input-capsule axis.
    return tf.squeeze(v, axis=1)

  def __squash(self, s):
    """Scale each vector along the last axis to a length in [0, 1), keeping its direction.

    The norm is computed as sqrt(sum(s**2) + epsilon), which avoids both the
    division by zero and the NaN gradient that the square root in a plain
    norm produces for an all-zero vector.
    """
    squared_norm = tf.reduce_sum(tf.square(s), axis=-1, keepdims=True)
    safe_norm    = tf.sqrt(squared_norm + self.epsilon)
    return squared_norm / (1 + squared_norm) * s / safe_norm

  def compute_output_shape(self, input_shape):
    """Return (batch, num_capsule, dim_capsule), keeping the input's batch dimension."""
    return tuple([input_shape[0], self.num_capsule, self.dim_capsule])

  def get_config(self):
    """Return every constructor argument so the layer can be re-created from its config."""
    config = {'num_capsule': self.num_capsule,
              'dim_capsule': self.dim_capsule,
              'routings': self.routings,
              'layer_type': self.layer_type,
              'epsilon': self.epsilon,
              'kernel_initializer': initializers.serialize(self.kernel_initializer)}
    base_config = super(CapsuleLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))