<a href="https://colab.research.google.com/github/MoGomaa/DCaps/blob/main/D_Caps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import keras.backend as K

from keras import layers

In [2]:
# Display the installed TensorFlow version (this notebook's recorded output is '2.4.1').
tf.__version__

'2.4.1'

**Note**<br />If you are using Google Colab with Google Drive, you have to mount your drive first and then use the `import_ipynb` library to import `ConvCapsuleLayer`.
<br />This might be useful: [stackoverflow link](https://stackoverflow.com/questions/59020008/how-to-import-functions-of-a-jupyter-notebook-into-another-jupyter-notebook-in-g) 

In [None]:
from ConvCapsLayer import ConvCapsuleLayer

In [3]:
class ExpandDim(layers.Layer):
    """Insert a length-1 axis just before the last axis of the input.

    A tensor of shape ``(..., C)`` becomes ``(..., 1, C)``. In this notebook it
    turns the Conv2D feature map into a single-capsule tensor that is then fed
    to ``ConvCapsuleLayer``. The layer creates no weights and takes no extra
    constructor arguments.
    """

    def call(self, inputs, **kwargs):
        return K.expand_dims(inputs, axis=-2)

    def compute_output_shape(self, input_shape):
        # Keras may hand the shape over as a list; ``list + tuple`` raises
        # TypeError, so normalise lists to tuples. Tuples and TensorShapes are
        # left as-is (TensorShape also keeps unknown-rank shapes working).
        if isinstance(input_shape, list):
            input_shape = tuple(input_shape)
        return input_shape[:-1] + (1,) + input_shape[-1:]

    def get_config(self):
        # No constructor arguments of its own, so the base config is complete.
        return dict(super().get_config())

class RemoveDim(layers.Layer):
    """Remove the second-to-last axis of the input (the reverse of ExpandDim).

    A tensor of shape ``(..., 1, C)`` becomes ``(..., C)``. ``K.squeeze``
    only accepts an axis of size 1, so the input's axis -2 must be length 1.
    The layer creates no weights and takes no extra constructor arguments.
    """

    def call(self, inputs, **kwargs):
        squeeze_axis = -2
        return K.squeeze(inputs, axis=squeeze_axis)

    def compute_output_shape(self, input_shape):
        # Keep every dimension except the one at position -2.
        leading_dims = input_shape[:-2]
        last_dim = input_shape[-1:]
        return leading_dims + last_dim

    def get_config(self):
        # Nothing beyond the base Layer configuration needs serialising.
        base_config = super().get_config()
        return dict(base_config)

In [4]:
def DiagnosisCapsules(input_shape, n_class=2, k_size=5, output_capsule_dim=16):
  if n_class == 2:
    n_class = 1 # binary output
  
  inputs         = layers.Input(shape=input_shape)

  conv1          = layers.Conv2D(filters=16, kernel_size=k_size, strides=2, padding='same', activation='relu', name='conv1')(inputs)
  conv1_reshaped = ExpandDim(name='expand_dim')(conv1)  # conv1_reshaped.shape: (None, 256, 320, 1, 16)

  primary_caps   = ConvCapsuleLayer(num_capsule=2, capsule_dim=16, routings=1,
                                    kernel_size=k_size, strides=2, padding='same', name='primary_caps')(conv1_reshaped)
                                    # primary_caps.shape: (None, 128, 160, 2, 16)

  conv_cap_2_1   = ConvCapsuleLayer(num_capsule=4, capsule_dim=16, routings=3,
                                    kernel_size=k_size, strides=1, padding='same', name='conv_cap_2_1')(primary_caps)
                                    # conv_cap_2_1.shape: (None, 128, 160, 4, 16)

  conv_cap_2_2   = ConvCapsuleLayer(num_capsule=4, capsule_dim=32, routings=3,
                                    kernel_size=k_size, strides=2, padding='same', name='conv_cap_2_2')(conv_cap_2_1)
                                    # conv_cap_2_2.shape: (None, 64, 80, 4, 32)

  conv_cap_3_1   = ConvCapsuleLayer(num_capsule=8, capsule_dim=32, routings=3,
                                    kernel_size=k_size, strides=1, padding='same', name='conv_cap_3_1')(conv_cap_2_2)
                                    # conv_cap_3_1.shape: (None, 64, 80, 8, 32)

  conv_cap_3_2   = ConvCapsuleLayer(num_capsule=8, capsule_dim=64, routings=3,
                                    kernel_size=k_size, strides=2, padding='same', name='conv_cap_3_2')(conv_cap_3_1)
                                    # conv_cap_3_2.shape: (None, 32, 40, 8, 64)

  conv_cap_4_1   = ConvCapsuleLayer(num_capsule=8, capsule_dim=32, routings=3,
                                    kernel_size=k_size, strides=1, padding='same', name='conv_cap_4_1')(conv_cap_3_2)
                                    # conv_cap_4_1.shape: (None, 32, 40, 8, 32)

  conv_cap_4_2   = ConvCapsuleLayer(num_capsule=n_class, capsule_dim=output_capsule_dim, routings=3,
                                    kernel_size=k_size, strides=2, padding='same', name='conv_cap_4_2')(conv_cap_4_1) 
                                    # conv_cap_4_2.shape: (None, 16, 20, 16)     -- Binary     (2 clsses)
                                    #                     (None, 16, 20, n, 16)  -- Multiclass (n clsses)