<a href="https://colab.research.google.com/github/MoGomaa/CapsuleNetworks/blob/main/CapsNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

In [2]:
tf.__version__

'2.4.1'

In [None]:
# If you are using Google Colab with Google Drive, mount your drive first, then use the import_ipynb library to import CapsuleLayer.
# This might be useful: https://stackoverflow.com/questions/59020008/how-to-import-functions-of-a-jupyter-notebook-into-another-jupyter-notebook-in-g

from CapsLayer import CapsuleLayer

In [None]:
class CapsuleNetwork(tf.keras.Model):
  """Capsule network with a fully connected reconstruction decoder.

  Pipeline: two 9x9 convolutions -> primary capsule layer -> fully connected
  ("digit") capsule layer. The capsule outputs are then masked with the labels
  and passed through three dense layers that reconstruct a flattened image of
  784 values.

  The shape comments throughout assume a 28x28x1 input, conv_kernels=256,
  1152 primary capsules of dimension 8 and 10 digit capsules of dimension 16
  -- TODO confirm against the values actually passed by the caller.

  Args:
    conv_kernels: Number of filters used by both convolution layers.
    primary_num_capsules: First argument forwarded to the "CapsPrimary"
      CapsuleLayer.
    primary_capsule_dimension: Second argument forwarded to the "CapsPrimary"
      CapsuleLayer.
    digit_num_capsules: Number of output capsules; also used to flatten the
      masked capsule outputs for the decoder.
    digit_capsule_dimension: Dimension of each output capsule; also used to
      flatten the masked capsule outputs for the decoder.
    routings: Stored on the instance as `self.routings`.
  """

  def __init__(self, conv_kernels, primary_num_capsules, primary_capsule_dimension, digit_num_capsules, digit_capsule_dimension, routings):
    super().__init__()
    self.conv_kernels              = conv_kernels
    self.primary_num_capsules      = primary_num_capsules
    self.primary_capsule_dimension = primary_capsule_dimension
    self.digit_num_capsules        = digit_num_capsules
    self.digit_capsule_dimension   = digit_capsule_dimension
    # NOTE(review): routings is only stored here; it is not forwarded to either
    # CapsuleLayer below. Confirm whether CapsuleLayer should receive it.
    self.routings                  = routings

    with tf.name_scope("Variables"):
      # Feature extraction. Only conv1 applies a ReLU; conv2 is linear.
      self.conv1       = tf.keras.layers.Conv2D(self.conv_kernels, kernel_size=[9,9], strides=[1,1], name='conv1', activation='relu')
      self.conv2       = tf.keras.layers.Conv2D(self.conv_kernels, kernel_size=[9,9], strides=[2,2], name="conv2")

      # Capsule layers (implementation lives in the CapsLayer module).
      self.CapsPrimary = CapsuleLayer(primary_num_capsules, primary_capsule_dimension, layer_type="CapsPrimary")
      self.CapsFC      = CapsuleLayer(digit_num_capsules, digit_capsule_dimension, layer_type="CapsFC")

      # Reconstruction decoder. The last layer is pinned to float32 so the
      # reconstruction stays in full precision regardless of the global policy.
      self.dense_1     = tf.keras.layers.Dense(units = 512, activation='relu')
      self.dense_2     = tf.keras.layers.Dense(units = 1024, activation='relu')
      self.dense_3     = tf.keras.layers.Dense(units = 784, activation='sigmoid', dtype='float32')

  @tf.function
  def call(self, inputs):
    """Run the capsule encoder and reconstruct the image from masked capsules.

    Args:
      inputs: A pair `(input_x, y)`. `input_x` is the image batch and `y` the
        per-capsule mask (presumably one-hot labels -- verify against caller).

    Returns:
      A tuple `(outs, reconstructed_image)`: the unmasked capsule outputs and
      the decoder output.
    """
    input_x, y = inputs                                                                   # input_x.shape             : (None, 28, 28, 1)
                                                                                          # y.shape                   : (None, 10)

    outs = self.predict_capsule_output(input_x)                                           # outs.shape                : (None, 10, 16)

    with tf.name_scope("Masking"):
      y = tf.expand_dims(y, axis=-1)                                                      # y.shape                   : (None, 10, 1)
      # Cast to the capsule output dtype (not a hard-coded float32) so the
      # multiply also works when the capsule layers compute in float16.
      mask = tf.cast(y, dtype=outs.dtype)                                                 # mask.shape                : (None, 10, 1)
      # The mask broadcasts over the capsule dimension, zeroing every capsule
      # whose mask entry is 0.
      outs_masked = tf.multiply(mask, outs)                                               # outs_masked.shape         : (None, 10, 16)

    reconstructed_image = self.regenerate_image(outs_masked)                              # reconstructed_image.shape : (None, 784)

    return outs, reconstructed_image

  @tf.function
  def predict_capsule_output(self, inputs):
    """Map an image batch to the output of the final capsule layer.

    Args:
      inputs: Image batch fed to the first convolution.

    Returns:
      The output of the "CapsFC" capsule layer.
    """
    x    = self.conv1(inputs)                                                             # x.shape    : (None, 20, 20, 256)
    x    = self.conv2(x)                                                                  # x.shape    : (None, 6, 6, 256)
    x    = self.CapsPrimary(x)                                                            # x.shape    : (None, 1152, 8)
    outs = self.CapsFC(x)                                                                 # outs.shape : (None, 10, 16)
    return outs

  @tf.function
  def regenerate_image(self, inputs):
    """Decode (masked) capsule outputs into a flattened image.

    Args:
      inputs: Capsule outputs; flattened to
        `digit_num_capsules * digit_capsule_dimension` features per example.

    Returns:
      Sigmoid activations with 784 values per example.
    """
    x = tf.reshape(inputs, [-1, self.digit_num_capsules * self.digit_capsule_dimension])  # x.shape                   : (None, 160)
    x = self.dense_1(x)                                                                   # x.shape                   : (None, 512)
    x = self.dense_2(x)                                                                   # x.shape                   : (None, 1024)
    reconstructed_image = self.dense_3(x)                                                 # reconstructed_image.shape : (None, 784)
    return reconstructed_image