## Capsule Network Implementation

Importing all dependencies

In [1]:
import numpy as np
import tensorflow as tf

from tqdm import tqdm
from datetime import datetime




Setting Parameters

In [14]:
# Architecture hyper-parameters; unpacked as keyword arguments into CapsuleNetwork.
params = {
    "conv_kernels": 256,          # filters in the first convolution
    "primary_capsules": 32,       # primary capsule channels
    "secondary_capsules": 10,     # output capsules (one per digit class)
    "primary_cap_vector": 8,      # length of each primary capsule vector
    "secondary_cap_vector": 16,   # length of each output capsule vector
    "r": 3,                       # dynamic-routing iterations
}

epsilon = 1e-7    # numerical-stability term used by the squash function
lambda_ = 0.5     # weight of the margin-loss term for absent classes
alpha = 0.0005    # weight of the reconstruction loss in the total loss
epochs = 2        # passes over the training set
m_plus = 0.9      # margin for the present class
m_minus = 0.1     # margin for absent classes

# Derived from `params` so the loss function and the model cannot disagree
# about the number of output capsules.
secondary_capsules = params["secondary_capsules"]

optimizer_adam = tf.keras.optimizers.Adam()

Retrieving MNIST dataset

In [4]:
# Fetch MNIST (downloaded on first use) and unpack the train / test splits.
train_split, test_split = tf.keras.datasets.mnist.load_data()
train_x, train_y = train_split
test_x, test_y = test_split

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Preprocessing data

In [5]:
# Scale pixels to [0, 1], convert to float32 and append a trailing channel axis.
# NOTE: this cell overwrites train_x / test_x, so it must be run exactly once
# after loading the data; running it again would rescale the images a second time.
train_x = tf.expand_dims(tf.cast(train_x / 255.0, dtype=tf.float32), axis=-1)
test_x = tf.expand_dims(tf.cast(test_x / 255.0, dtype=tf.float32), axis=-1)

# Number of samples per split, used later to turn hit counts into accuracies.
train_size = train_x.shape[0]
test_size = test_x.shape[0]

# The training set is reshuffled over its full length every epoch, then batched;
# the test set keeps its original order.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .shuffle(buffer_size=train_size, reshuffle_each_iteration=True)
    .batch(batch_size=64)
)
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(batch_size=64)

Defining Capsule Network class with all its functions

In [6]:
class CapsuleNetwork(tf.keras.Model):
    """Capsule network with dynamic routing and a dense reconstruction decoder.

    Pipeline: 9x9 convolution -> primary-capsule convolution -> reshape into
    capsule vectors -> learned pose transform `w` -> `r` rounds of dynamic
    routing -> output capsules. A three-layer dense decoder reconstructs a
    784-pixel image from the (masked) output capsules.

    The pose-transform weights are sized for 28x28 inputs: two 9x9 'valid'
    convolutions with strides 1 and 2 leave a 6x6 grid of primary capsules.

    Args:
        conv_kernels: number of filters in the first convolution.
        primary_capsules: number of primary capsule channels.
        primary_cap_vector: length of each primary capsule vector.
        secondary_capsules: number of output capsules.
        secondary_cap_vector: length of each output capsule vector.
        r: number of dynamic-routing iterations (must be >= 1).

    Raises:
        ValueError: if `r` is smaller than 1.
    """

    # Cells in the primary-capsule feature map for a 28x28 input (6 x 6).
    PRIMARY_GRID_CELLS = 6 * 6

    # Initialization
    def __init__(self, conv_kernels, primary_capsules, primary_cap_vector, secondary_capsules, secondary_cap_vector, r):

        super().__init__()

        if r < 1:
            # With no routing iteration there would be no output capsule to return.
            raise ValueError("r (number of routing iterations) must be >= 1")

        self.conv_kernels = conv_kernels
        self.primary_capsules = primary_capsules
        self.primary_cap_vector = primary_cap_vector
        self.secondary_capsules = secondary_capsules
        self.secondary_cap_vector = secondary_cap_vector
        self.r = r

        # Total number of primary capsule vectors per image (1152 for 32 channels).
        # Derived from the constructor argument so `w` always matches the reshape
        # performed in `_capsule_outputs`.
        self.num_primary_vectors = self.primary_capsules * self.PRIMARY_GRID_CELLS

        # Assigning the layers
        with tf.name_scope("Variables"):

            self.convolution = tf.keras.layers.Conv2D(self.conv_kernels, [9,9], strides=[1,1], name='ConvolutionLayer', activation='relu')
            self.primary_capsule = tf.keras.layers.Conv2D(self.primary_capsules * self.primary_cap_vector, [9,9], strides=[2,2], name="PrimaryCapsule")
            # One (secondary_cap_vector x primary_cap_vector) matrix per
            # (primary vector, output capsule) pair; the leading 1 broadcasts over the batch.
            self.w = tf.Variable(
                tf.random_normal_initializer()(
                    shape=[1, self.num_primary_vectors, self.secondary_capsules, self.secondary_cap_vector, self.primary_cap_vector]
                ),
                dtype=tf.float32,
                name="PoseEstimation",
                trainable=True,
            )
            self.dense_1 = tf.keras.layers.Dense(units = 512, activation='relu')
            self.dense_2 = tf.keras.layers.Dense(units = 1024, activation='relu')
            self.dense_3 = tf.keras.layers.Dense(units = 784, activation='sigmoid', dtype='float32')

    def build(self, input_shape):
        """Intentionally a no-op: `w` and the layers are created in `__init__`."""
        pass

    # Squash function
    def squash(self, s):
        """Scale each vector along the last axis to a length in [0, 1), keeping its direction.

        Computes ||s||^2 / (1 + ||s||^2) * s / ||s||. The norm is taken as
        sqrt(sum(s^2) + epsilon) rather than `tf.norm`, because the gradient of a
        plain square root is undefined at zero and would yield NaNs for an
        all-zero vector. `epsilon` is the module-level constant.
        """
        with tf.name_scope("SquashFunction"):

            squared_norm = tf.reduce_sum(tf.square(s), axis=-1, keepdims=True)
            guarded_norm = tf.sqrt(squared_norm + epsilon)
            return squared_norm / (1 + squared_norm) * s / guarded_norm

    def _capsule_outputs(self, images):
        """Run the convolutions, capsule formation and dynamic routing.

        Shared by `call` and `predict_capsule_output` so both follow one code path.

        Returns:
            Output capsules of shape (batch, 1, secondary_capsules, secondary_cap_vector).
        """
        x = self.convolution(images)
        x = self.primary_capsule(x)

        # Defining capsules
        with tf.name_scope("CapsuleFormation"):

            # (batch, primary vectors, primary_cap_vector)
            u = tf.reshape(x, (-1, self.primary_capsules * x.shape[1] * x.shape[2], self.primary_cap_vector))
            u = tf.expand_dims(u, axis=-2)
            u = tf.expand_dims(u, axis=-1)
            # Prediction of every primary vector for every output capsule:
            # (batch, primary vectors, secondary_capsules, secondary_cap_vector)
            u_hat = tf.squeeze(tf.matmul(self.w, u), [4])

        # Routing mechanism
        with tf.name_scope("DynamicRouting"):

            # Routing logits start at zero. Built from u_hat so the batch size may
            # be unknown at trace time: (batch, primary vectors, secondary_capsules, 1)
            b = tf.zeros_like(u_hat[..., :1])
            for iteration in range(self.r):
                c = tf.nn.softmax(b, axis=-2)
                s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1, keepdims=True)
                v = self.squash(s)
                if iteration < self.r - 1:
                    # Agreement = dot product between each prediction and the
                    # current output capsule. Skipped after the final iteration,
                    # where the updated logits would never be read.
                    b += tf.reduce_sum(u_hat * v, axis=-1, keepdims=True)

        return v

    def _decode(self, capsules):
        """Flatten capsule vectors and run them through the dense decoder.

        Returns:
            Tensor of shape (batch, 784) with sigmoid-activated pixel values.
        """
        with tf.name_scope("Reconstruction"):

            v_ = tf.reshape(capsules, [-1, self.secondary_capsules * self.secondary_cap_vector])
            reconstructed_image = self.dense_1(v_)
            reconstructed_image = self.dense_2(reconstructed_image)
            reconstructed_image = self.dense_3(reconstructed_image)

        return reconstructed_image

    @tf.function
    # Input function for the NN
    def call(self, inputs):
        """Forward pass used for training.

        Args:
            inputs: pair `(input_x, input_y)`. `input_x` is the image batch;
                `input_y` is multiplied onto the output capsules as a mask, so it
                is expected to be one-hot labels of shape (batch, secondary_capsules).

        Returns:
            `(v, reconstructed_image)`: the unmasked output capsules and the
            decoder's reconstruction from the masked capsules.
        """
        input_x, input_y = inputs
        v = self._capsule_outputs(input_x)

        # Masking: zero out every capsule except the one selected by input_y.
        with tf.name_scope("Masking"):

            y = tf.expand_dims(input_y, axis=-1)
            y = tf.expand_dims(y, axis=1)
            mask = tf.cast(y, dtype=tf.float32)
            v_masked = tf.multiply(mask, v)

        # Reconstructing the images
        reconstructed_image = self._decode(v_masked)

        return v, reconstructed_image

    @tf.function
    # Prediction function
    def predict_capsule_output(self, inputs):
        """Return the output capsules for an image batch (no masking, no decoder)."""
        return self._capsule_outputs(inputs)

    @tf.function
    # Regeneration function
    def regenerate_image(self, inputs):
        """Reconstruct images from capsule vectors supplied by the caller."""
        return self._decode(inputs)

Instantiating the model and declaring the safe-norm helper function

In [7]:
# Build the network from the hyper-parameters in `params` (unpacked as keyword arguments).
model = CapsuleNetwork(**params)




In [8]:
def safe_norm(v, axis=-1, epsilon=1e-7):
    """Euclidean length of `v` along `axis`, with `epsilon` added under the root.

    The reduced axis is kept with size 1. Adding `epsilon` before the square
    root keeps the result strictly positive, even for an all-zero vector.
    """
    squared_length = tf.reduce_sum(tf.square(v), axis=axis, keepdims=True)
    guarded = squared_length + epsilon
    return tf.sqrt(guarded)


Defining loss function

In [9]:
def loss_function(v, reconstructed_image, y, y_image):
    """Total loss: margin loss plus `alpha` times the reconstruction loss.

    Args:
        v: output capsules; their lengths along the last axis act as class scores.
        reconstructed_image: decoder output, compared against the flattened `y_image`.
        y: one-hot labels, multiplied element-wise with the per-class margin terms.
        y_image: target images, reshaped here to rows of 784 pixels.

    Returns:
        Scalar loss tensor.
    """
    # Capsule lengths, one column per output capsule.
    lengths = tf.reshape(safe_norm(v), [-1, secondary_capsules])

    # Penalize a short capsule for the true class and long capsules for the others.
    present_term = y * tf.square(tf.maximum(0.0, m_plus - lengths))
    absent_term = lambda_ * (1.0 - y) * tf.square(tf.maximum(0.0, lengths - m_minus))
    per_sample = tf.reduce_sum(present_term + absent_term, axis=-1)
    margin_loss = tf.reduce_mean(per_sample)

    # Mean squared pixel error between the target and its reconstruction.
    target_pixels = tf.reshape(y_image, [-1, 784])
    reconstruction_loss = tf.reduce_mean(tf.square(target_pixels - reconstructed_image))

    return margin_loss + alpha * reconstruction_loss

Defining training function

In [10]:
def train(x, y):
    """Run one optimization step on a batch and return its loss.

    Args:
        x: batch of images, fed to the model and used as the reconstruction target.
        y: integer class labels; one-hot encoded here with depth 10.
    """
    targets = tf.one_hot(y, depth=10)

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        capsules, reconstruction = model([x, targets])
        loss = loss_function(capsules, reconstruction, targets, x)

    weights = model.trainable_variables
    gradients = tape.gradient(loss, weights)
    optimizer_adam.apply_gradients(zip(gradients, weights))

    return loss

# Run a single training step on the first 32 training samples and discard the loss.
# NOTE(review): presumably a warm-up so the model's layers are built before the
# checkpoint object is created - confirm. It does apply a real optimizer update.
_ = train(train_x[:32],train_y[:32])


Prediction function declaration and creating a checkpoint

In [11]:
def predict(model, x):
    """Return the predicted class index for each image in `x` as a NumPy array.

    The predicted class is the output capsule with the greatest length.
    """
    capsule_lengths = safe_norm(model.predict_capsule_output(x))
    # Drop the singleton axis 1 so the capsule axis becomes axis 1.
    capsule_lengths = tf.squeeze(capsule_lengths, [1])
    winners = np.argmax(capsule_lengths, axis=1)
    # argmax leaves a trailing length-1 column (the kept norm axis); take column 0.
    return winners[:, 0]
# Checkpoint object tracking the model; the training loop calls checkpoint.save(...).
checkpoint = tf.train.Checkpoint(model=model)

Running the training function

In [15]:
# Prefix handed to checkpoint.save(). The loop referenced this name without it
# being defined anywhere in the notebook, which raised NameError on the first
# epoch that reached the save branch.
checkpoint_path = "./checkpoints/capsule_network"
# Save a checkpoint every this many epochs.
checkpoint_every = 10

# Per-epoch history: mean training loss and training-set accuracy.
losses = []
accuracy = []

num_batches = len(train_dataset)

for i in range(1, epochs + 1):

    with tqdm(total=num_batches) as progress_bar:

        progress_bar.set_description_str("Epoch " + str(i) + " of " + str(epochs))

        # One optimization step per batch; accumulate the loss for the epoch mean.
        loss = 0
        for X_batch, y_batch in train_dataset:
            loss += train(X_batch, y_batch)
            progress_bar.update(1)

        loss /= num_batches
        losses.append(loss.numpy())

        progress_bar.set_postfix_str("Loss :" + str(loss.numpy()) + " Evaluating Accuracy ...")

        # Second pass over the training set to measure accuracy with the updated
        # weights. np.sum counts matches in C instead of iterating the array in Python.
        training_sum = 0
        for X_batch, y_batch in train_dataset:
            training_sum += np.sum(predict(model, X_batch) == y_batch.numpy())
        accuracy.append(training_sum / train_size)

        if i % checkpoint_every == 0:
            checkpoint.save(checkpoint_path)

        progress_bar.set_postfix_str("Loss :" + str(loss.numpy()) + " Accuracy :" + str(accuracy[-1]))

Epoch 1 of 2: 100%|██████████| 938/938 [49:03<00:00,  3.14s/it, Loss :0.045737065 Accuracy :0.9859833333333333]
Epoch 2 of 2: 100%|██████████| 938/938 [44:44<00:00,  2.86s/it, Loss :0.016288744 Accuracy :0.9901]       


Showing accuracy of model

In [16]:
# Count correct predictions batch by batch, then report the fraction over the test set.
test_sum = 0
for X_batch, y_batch in test_dataset:
    matches = predict(model, X_batch) == y_batch.numpy()
    test_sum += np.sum(matches)
print(test_sum / test_size)

0.9853
