In [1]:
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST splits from the gzipped IDX archives under datasets/.
# No one_hot argument is passed; later cells feed the labels into an int64 placeholder of shape [None].
mnist = input_data.read_data_sets("datasets/")

Extracting datasets/train-images-idx3-ubyte.gz
Extracting datasets/train-labels-idx1-ubyte.gz
Extracting datasets/t10k-images-idx3-ubyte.gz
Extracting datasets/t10k-labels-idx1-ubyte.gz


In [3]:
# Image arrays of the three MNIST splits.
train_data = mnist.train.images
test_data = mnist.test.images
validation_data = mnist.validation.images

# Report how many examples each split holds.
for split_label, split_images in (("Training -", train_data),
                                  ("Testing -", test_data),
                                  ("Validating -", validation_data)):
    print(split_label, len(split_images))

Training - 55000
Testing - 10000
Validating - 5000


In [4]:
# Network input: batches of 28x28 single-channel images; the batch dimension is left open.
X = tf.placeholder(
    dtype=tf.float32,
    shape=[None, 28, 28, 1],
    name="X",
)
X

<tf.Tensor 'X:0' shape=(?, 28, 28, 1) dtype=float32>

In [5]:
# First convolutional layer: 256 feature maps, 9x9 kernels, stride 1, no padding, ReLU.
conv1 = tf.layers.conv2d(
    X,
    filters=256,
    kernel_size=9,
    strides=1,
    padding="valid",
    activation=tf.nn.relu,
    name="conv1",
)

In [6]:
# Second convolutional layer: 256 feature maps, 9x9 kernels, stride 2, no padding, ReLU.
# Its output is reshaped into the primary capsules in a later cell.
conv2 = tf.layers.conv2d(
    conv1,
    filters=256,
    kernel_size=9,
    strides=2,
    padding="valid",
    activation=tf.nn.relu,
    name="conv2",
)

In [7]:
# Primary capsule layer hyperparameters.
caps1_dimension = 8                    # length of each primary capsule vector
caps1_maps = 32                        # number of capsule feature maps
caps1_capsules = caps1_maps * 6 * 6    # 32 maps on a 6x6 grid -> 1152 capsules

In [8]:
# Regroup the conv2 activations into primary capsules: (batch_size, 1152, 8).
caps1 = tf.reshape(
    conv2,
    [-1, caps1_capsules, caps1_dimension],
    name="caps1",
)
caps1

<tf.Tensor 'caps1:0' shape=(?, 1152, 8) dtype=float32>

In [9]:
def squash(s, axis=-1, epsilon=1e-7, name=None):
    """Squash vectors so that their length lies between 0 and 1.

    Computes ``(|s|^2 / (1 + |s|^2)) * (s / |s|)`` where the norm is taken
    along ``axis``.

    Args:
        s: tensor holding the vectors to squash.
        axis: axis along which each vector is laid out.
        epsilon: small constant added under the square root so the division
            by the norm stays finite for zero vectors.
        name: optional name scope for the created ops (defaults to "squash").

    Returns:
        A tensor shaped like ``s`` containing the squashed vectors.
    """
    with tf.name_scope(name, default_name="squash"):
        sq_length = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
        length = tf.sqrt(sq_length + epsilon)
        scale = sq_length / (1. + sq_length)
        direction = s / length
        return scale * direction

In [10]:
# Squash each primary capsule vector along the default (last) axis.
caps1_output = squash(caps1, name="caps1_output")

In [11]:
# Display the symbolic tensor to check its static shape.
caps1_output

<tf.Tensor 'caps1_output/mul:0' shape=(?, 1152, 8) dtype=float32>

In [12]:
# Digit capsule layer hyperparameters.
caps2_dimension = 16    # length of each digit capsule vector
caps2_capsules = 10     # number of digit capsules (depth of the one-hot label encoding)

In [13]:
# Transformation matrices: one 8x16 matrix per (primary capsule, digit capsule) pair,
# created with a leading axis of length 1 so it can be tiled across the batch.
W_init = tf.random_normal(
    shape=(1, caps1_capsules, caps2_capsules, caps1_dimension, caps2_dimension),
    stddev=0.01,
    dtype=tf.float32,
)
W = tf.Variable(W_init, name="W")
W

# Batch size read from the input tensor at run time.
# NOTE(review): the training cell later rebinds `batch_size` to a Python int;
# re-running graph-building cells after that cell would bake the int into the graph.
batch_size = tf.shape(X)[0]

# Repeat W once per example in the batch.
W_tiled = tf.tile(
    W,
    [batch_size, 1, 1, 1, 1],
    name="W_tiled",
)
W_tiled

<tf.Tensor 'W_tiled:0' shape=(?, 1152, 10, 8, 16) dtype=float32>

In [14]:
# Insert a length-1 axis before the last one: (?, 1152, 8) -> (?, 1152, 1, 8).
caps1_output_expanded1 = tf.expand_dims(
    caps1_output,
    -2,
    name="caps1_output_expanded1",
)
caps1_output_expanded1

<tf.Tensor 'caps1_output_expanded1:0' shape=(?, 1152, 1, 8) dtype=float32>

In [15]:
# Insert another length-1 axis for the digit capsules: (?, 1152, 1, 8) -> (?, 1152, 1, 1, 8).
caps1_output_expanded2 = tf.expand_dims(
    caps1_output_expanded1,
    -3,
    name="caps1_output_expanded2",
)
caps1_output_expanded2

<tf.Tensor 'caps1_output_expanded2:0' shape=(?, 1152, 1, 1, 8) dtype=float32>

In [16]:
# Repeat every primary capsule output once per digit capsule: (?, 1152, 10, 1, 8).
caps1_output_tiled = tf.tile(
    caps1_output_expanded2,
    [1, 1, caps2_capsules, 1, 1],
    name="caps1_output_tiled",
)
caps1_output_tiled

<tf.Tensor 'caps1_output_tiled:0' shape=(?, 1152, 10, 1, 8) dtype=float32>

In [17]:
# Predicted digit capsule vectors: (1x8 row vector) @ (8x16 matrix) for every capsule pair,
# giving (?, 1152, 10, 1, 16) with the 16-D vector on the last axis.
# NOTE(review): the op name "caps2_perdicted" looks like a typo; it is kept as-is because
# it is the name under which the op is registered in the graph.
caps2_predicted = tf.matmul(
    caps1_output_tiled,
    W_tiled,
    name="caps2_perdicted",
)
caps2_predicted

<tf.Tensor 'caps2_perdicted:0' shape=(?, 1152, 10, 1, 16) dtype=float32>

In [18]:
# Routing by agreement


In [19]:
# Round 1

# Routing logits start at zero, so the first softmax couples every primary capsule
# equally to all digit capsules.
raw_weights = tf.zeros([batch_size, caps1_capsules, caps2_capsules, 1, 1],
                        dtype=tf.float32, name="raw_weights")
# Softmax over axis 2, the digit-capsule axis.
routing_weights = tf.nn.softmax(raw_weights, axis=2, name="routing_weights")
weighted_predictions = tf.multiply(routing_weights, caps2_predicted, name="weighted_predictions")
# Sum over axis 1, the primary-capsule axis -> (?, 1, 10, 1, 16).
weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keepdims=True, name="weighted_sum")
# Squash along the last axis: in this layout it holds each digit capsule's 16-D vector.
# Axis 2 indexes the 10 digit capsules, so squashing there would normalise across
# classes instead of bounding each capsule's own length.
caps2_output_round1 = squash(weighted_sum, axis=-1, name="caps2_output_round1")

In [20]:
# Display the round-1 digit capsule outputs to check their static shape.
caps2_output_round1

<tf.Tensor 'caps2_output_round1/mul:0' shape=(?, 1, 10, 1, 16) dtype=float32>

In [21]:
# Repeat the round-1 outputs once per primary capsule so they line up with caps2_predicted.
caps2_output_round1_tiled = tf.tile(
    caps2_output_round1,
    [1, caps1_capsules, 1, 1, 1],
    name="caps2_output_round1_tiled",
)
caps2_output_round1_tiled

<tf.Tensor 'caps2_output_round1_tiled:0' shape=(?, 1152, 10, 1, 16) dtype=float32>

In [22]:
# Agreement = dot product between each prediction (1x16) and the matching round-1 output
# (transposed to 16x1), giving one scalar per capsule pair: (?, 1152, 10, 1, 1).
agreement = tf.matmul(
    caps2_predicted,
    caps2_output_round1_tiled,
    transpose_b=True,
    name="agreement",
)
agreement

<tf.Tensor 'agreement:0' shape=(?, 1152, 10, 1, 1) dtype=float32>

In [23]:
# Round 2

# Update the routing logits with the agreement measured after round 1.
raw_weights_round2 = tf.add(raw_weights, agreement, name="raw_weights_round2")
# Softmax over axis 2, the digit-capsule axis.
routing_weights_round2 = tf.nn.softmax(raw_weights_round2, axis=2, name="routing_weights_round2")
weighted_predictions_round2 = tf.multiply(routing_weights_round2, caps2_predicted, name="weighted_predictions_round2")
# Sum over axis 1, the primary-capsule axis -> (?, 1, 10, 1, 16).
weighted_sum_round2 = tf.reduce_sum(weighted_predictions_round2, axis=1, keepdims=True, name="weighted_sum_round2")
# Squash along the last axis, which holds the 16-D capsule vector in this layout.
# Axis -2 has length 1 here, so squashing over it would treat every component as its
# own vector and would not bound the capsule length by 1, which the norm-based
# class probabilities and the margin loss rely on.
caps2_output_round2 = squash(weighted_sum_round2, axis=-1, name="caps2_output_round2")

In [24]:
# The digit capsule outputs are the result of the second routing round.
caps2_output = caps2_output_round2
caps2_output

<tf.Tensor 'caps2_output_round2/mul:0' shape=(?, 1, 10, 1, 16) dtype=float32>

In [25]:
# Estimated class probabilities

def safe_norm(s, axis=-1, epsilon=1e-7, keepdims=False, name=None):
    """Euclidean norm of ``s`` along ``axis`` with an epsilon under the square root.

    Args:
        s: tensor holding the vectors.
        axis: axis along which the norm is taken.
        epsilon: small constant added to the squared norm before the square root.
        keepdims: whether the reduced axis is kept with length 1.
        name: optional name scope for the created ops (defaults to "safe_norm").

    Returns:
        Tensor of norms, ``sqrt(sum(s^2) + epsilon)``.
    """
    with tf.name_scope(name, default_name="safe_norm"):
        sq_length = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keepdims)
        return tf.sqrt(sq_length + epsilon)

In [26]:
# Length of every digit capsule vector (norm over the last axis), used as the class score.
y_prob = safe_norm(caps2_output, axis=-1, name="y_prob")
y_prob

<tf.Tensor 'y_prob/Sqrt:0' shape=(?, 1, 10, 1) dtype=float32>

In [27]:
# Index of the longest digit capsule along axis 2 (the digit-capsule axis).
# NOTE(review): the op name "t_prob_argmax" looks like a typo for "y_prob_argmax";
# it is the name registered in the graph, so it is left untouched here.
y_prob_argmax = tf.argmax(y_prob, axis=2, name="t_prob_argmax")
y_prob_argmax

<tf.Tensor 't_prob_argmax:0' shape=(?, 1, 1) dtype=int64>

In [28]:
# Drop the two length-1 axes so the prediction is one class index per example.
y_pred = tf.squeeze(y_prob_argmax, axis=[1,2], name="y_pred")
y_pred

<tf.Tensor 'y_pred:0' shape=(?,) dtype=int64>

In [29]:
# Ground-truth labels: one integer class id per example.
y = tf.placeholder(
    dtype=tf.int64,
    shape=[None],
    name="y",
)
y

<tf.Tensor 'y:0' shape=(?,) dtype=int64>

In [30]:
# Margin loss constants.
m_plus = 0.9     # capsule length below which the labelled class is penalised
m_minus = 0.1    # capsule length above which the other classes are penalised
lambda_ = 0.5    # weight applied to the other-classes term

In [31]:
# One-hot encoding of the labels: 1 for the labelled class, 0 elsewhere.
T = tf.one_hot(y, depth=caps2_capsules, name="T")
T

<tf.Tensor 'T:0' shape=(?, 10) dtype=float32>

In [32]:
# Digit capsule lengths, keeping the reduced axis so the shape stays 5-D.
caps2_output_norm = safe_norm(
    caps2_output,
    axis=-1,
    keepdims=True,
    name="caps2output_norm",
)
caps2_output_norm

<tf.Tensor 'caps2output_norm/Sqrt:0' shape=(?, 1, 10, 1, 1) dtype=float32>

In [33]:
# Present-class term of the margin loss: max(0, m+ - |v|)^2 ...
present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm))
# ... flattened to one row of 10 values per example.
present_error = tf.reshape(
    present_error_raw,
    shape=(-1, 10),
    name="present_error",
)
present_error_raw

<tf.Tensor 'Square:0' shape=(?, 1, 10, 1, 1) dtype=float32>

In [34]:
# Absent-class term of the margin loss: max(0, |v| - m-)^2 ...
absent_error_raw = tf.square(
    tf.maximum(0., caps2_output_norm - m_minus),
    name="absent_error_raw",
)
# ... flattened to one row of 10 values per example.
absent_error = tf.reshape(
    absent_error_raw,
    shape=(-1, 10),
    name="absent_error",
)
absent_error

<tf.Tensor 'absent_error:0' shape=(?, 10) dtype=float32>

In [35]:
# Per-class margin loss: the present term counts for the labelled class,
# the down-weighted absent term for every other class.
present_term = T * present_error
absent_term = lambda_ * (1 - T) * absent_error
L = tf.add(present_term, absent_term, name="L")
L

<tf.Tensor 'L:0' shape=(?, 10) dtype=float32>

In [36]:
# Sum the per-class losses for every example, then average over the batch.
per_example_loss = tf.reduce_sum(L, axis=1)
margin_loss = tf.reduce_mean(per_example_loss, name="margin_loss")
margin_loss

<tf.Tensor 'margin_loss:0' shape=() dtype=float32>

In [37]:
# Reconstruction

In [38]:
# Boolean switch (False unless fed) that selects what the reconstruction mask is built from.
mask_with_labels = tf.placeholder_with_default(
    False,
    shape=(),
    name="mask_with_labels",
)
mask_with_labels

<tf.Tensor 'mask_with_labels:0' shape=() dtype=bool>

In [39]:
# Class whose capsule is kept for reconstruction: the true label when
# mask_with_labels is True, otherwise the network's own prediction.
reconstruction_targets = tf.cond(
    pred=mask_with_labels,
    true_fn=lambda: y,
    false_fn=lambda: y_pred,
    name="reconstruction_targets",
)
reconstruction_targets

<tf.Tensor 'reconstruction_targets/Merge:0' shape=(?,) dtype=int64>

In [40]:
# One-hot mask selecting the target class for each example.
# NOTE(review): the op name "resruction_mask" looks like a typo; it is the name
# registered in the graph, so it is left untouched here.
reconstruction_mask = tf.one_hot(reconstruction_targets, depth=caps2_capsules, name="resruction_mask")
reconstruction_mask

<tf.Tensor 'resruction_mask:0' shape=(?, 10) dtype=float32>

In [41]:
# Display caps2_output again to recall the shape the mask must broadcast against.
caps2_output

<tf.Tensor 'caps2_output_round2/mul:0' shape=(?, 1, 10, 1, 16) dtype=float32>

In [42]:
# Reshape the mask to (?, 1, 10, 1, 1) so it broadcasts over the capsule vectors.
reconstruction_mask_reshaped = tf.reshape(
    reconstruction_mask,
    [-1, 1, caps2_capsules, 1, 1],
    name="reconstruction_mask_reshaped",
)
reconstruction_mask_reshaped

<tf.Tensor 'reconstruction_mask_reshaped:0' shape=(?, 1, 10, 1, 1) dtype=float32>

In [43]:
# Zero out every digit capsule except the target one.
# NOTE(review): the op name "casp2_output_masked" looks like a typo; kept because it is
# the name registered in the graph.
caps2_output_masked = tf.multiply(
    reconstruction_mask_reshaped,
    caps2_output,
    name="casp2_output_masked",
)
caps2_output_masked

<tf.Tensor 'casp2_output_masked:0' shape=(?, 1, 10, 1, 16) dtype=float32>

In [44]:
# Flatten the masked capsules into one 160-value vector (10 capsules x 16 dims) per example.
decoder_input = tf.reshape(
    caps2_output_masked,
    [-1, caps2_capsules * caps2_dimension],
    name="decoder_input",
)
decoder_input

<tf.Tensor 'decoder_input:0' shape=(?, 160) dtype=float32>

In [45]:
# Decoder

In [46]:
# Decoder layer sizes.
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28    # one output unit per pixel of the 28x28 input image

In [47]:
with tf.name_scope("decoder"):
    # The dense layers are created without explicit names, so their variables keep
    # TensorFlow's default layer naming.
    hidden1 = tf.layers.dense(decoder_input, n_hidden1, activation=tf.nn.relu)
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu)
    # Sigmoid output, as in the reference CapsNet decoder: it bounds every reconstructed
    # pixel to [0, 1] and cannot get stuck at exactly zero with zero gradient the way a
    # ReLU output unit can.
    # NOTE(review): assumes the loader's images are scaled to [0, 1] -- confirm.
    decoder_output = tf.layers.dense(hidden2, n_output, activation=tf.nn.sigmoid)

In [48]:
# Reconstruction Loss

In [49]:
# Flatten each input image to a vector of n_output pixels so it can be compared with the decoder output.
X_flat = tf.reshape(X, [-1, n_output], name="X_flat")

In [50]:
# Mean squared error between the flattened input image and its reconstruction.
reconstruction_error = X_flat - decoder_output
squared_difference = tf.square(reconstruction_error, name="squared_difference")
reconstruction_loss = tf.reduce_mean(squared_difference, name="reconstruction_loss")

In [51]:
# Final loss

In [52]:
# Weight of the reconstruction loss relative to the margin loss.
alpha = 0.0005

# Total loss = margin loss + scaled-down reconstruction loss.
scaled_reconstruction_loss = alpha * reconstruction_loss
loss = tf.add(margin_loss, scaled_reconstruction_loss, name="loss")

In [53]:
# Accuracy: fraction of examples whose predicted class equals the label.

correct = tf.equal(y, y_pred, name="correct")
correct_as_float = tf.cast(correct, tf.float32)
accuracy = tf.reduce_mean(correct_as_float, name="accuracy")

In [54]:
# Training Operations
# Adam with its default settings minimises the combined loss.

optimizer = tf.train.AdamOptimizer()
training_op = optimizer.minimize(loss, name="training_op")

In [55]:
# Op that initialises all variables, and a saver for writing/restoring checkpoints.
init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [56]:
# Training

In [57]:
# Training configuration.
n_epochs = 1
# NOTE(review): this rebinds `batch_size`, which an earlier cell bound to the dynamic
# tf.shape(X)[0] tensor. Re-running graph-building cells after this one would bake the
# Python int 50 into the graph.
batch_size = 50
restore_checkpoint = True

n_iterations_per_epoch = mnist.train.num_examples // batch_size
n_iterations_validation = mnist.validation.num_examples // batch_size
best_loss_val = np.inf  # np.inf instead of the np.infty alias, which NumPy 2.0 removed
checkpoint_path = "my_capsule_network/capsule.ckpt"
summary_path = "my_capsule_network/"

with tf.Session() as sess:
    if restore_checkpoint and tf.train.checkpoint_exists(checkpoint_path):
        # A checkpoint already exists: load it and skip training entirely.
        # NOTE(review): the training loop lives in the else-branch, so a restored model
        # is never trained further -- confirm this is intended.
        saver.restore(sess, checkpoint_path)
    else:
        init.run()
        for epoch in range(n_epochs):
            for iteration in range(1, n_iterations_per_epoch + 1):
                X_batch, y_batch = mnist.train.next_batch(batch_size)


                # Run the training operation and measure the loss.
                # The reconstruction mask is built from the true labels during training.
                _, loss_train = sess.run([training_op, loss],
                                          feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                                                     y: y_batch,
                                                     mask_with_labels:True})

                print("\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}".format(
                          iteration, n_iterations_per_epoch,iteration * 100 / n_iterations_per_epoch,loss_train),end="")

            # At the end of each epoch,
            # measure the validation loss and accuracy:
            loss_vals = []
            acc_vals = []
            for iteration in range(1, n_iterations_validation + 1):
                X_batch, y_batch = mnist.validation.next_batch(batch_size)
                loss_val, acc_val = sess.run([loss, accuracy],
                                             feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                                                        y: y_batch})
                loss_vals.append(loss_val)
                acc_vals.append(acc_val)
                print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                          iteration, n_iterations_validation,
                          iteration * 100 / n_iterations_validation),
                      end=" " * 10)
            loss_val = np.mean(loss_vals)
            acc_val = np.mean(acc_vals)
            print("\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.6f}{}".format(
                        epoch + 1, acc_val * 100, loss_val, " (improved)" if loss_val < best_loss_val else ""))

            # And save the model if it improved:
            if loss_val < best_loss_val:
                save_path = saver.save(sess, checkpoint_path)
                best_loss_val = loss_val

                # Write the graph for TensorBoard, then close the writer so the event
                # file is flushed and its file handle is released.
                writer = tf.summary.FileWriter(summary_path, sess.graph)
                writer.close()

INFO:tensorflow:Restoring parameters from my_capsule_network/capsule.ckpt


In [None]:
# Evaluate the checkpointed model on the test split.
n_iterations_test = mnist.test.num_examples // batch_size

with tf.Session() as sess:

    saver.restore(sess, checkpoint_path)

    loss_tests = []
    acc_tests = []

    # Exactly n_iterations_test batches, numbered from 1 so the progress display ends at
    # 100%. Starting the range at 0 ran one batch too many and began a second pass over
    # the test set.
    for iteration in range(1, n_iterations_test + 1):
        X_batch, y_batch = mnist.test.next_batch(batch_size)

        loss_test, acc_test = sess.run([loss, accuracy],
                                       feed_dict = {X: X_batch.reshape([-1, 28, 28, 1]),
                                                    y: y_batch})
        loss_tests.append(loss_test)
        acc_tests.append(acc_test)

        # "\r" rewrites the progress line in place, as the validation loop does.
        print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                    iteration, n_iterations_test, iteration * 100 / n_iterations_test), end=" " * 10)

    loss_test = np.mean(loss_tests)
    acc_test = np.mean(acc_tests)

    # "{:.4f}" gives four decimals, matching the validation report.
    print("\rFinal test accuracy: {:.4f}%  Loss: {:.6f}".format(acc_test * 100, loss_test))

INFO:tensorflow:Restoring parameters from my_capsule_network/capsule.ckpt

Evaluating the model: 0/200 (0.0%)          

In [None]:
# To inspect the saved graph, run from the notebook's directory: tensorboard --logdir="my_capsule_network/"

In [84]:
"""
a = mnist.test.images[0:1]
b = mnist.test.labels[0:1]

plt.imshow(a.reshape([28,28]))

sess = tf.Session()
    
saver.restore(sess, checkpoint_path)
    
e, f, g = sess.run([loss, accuracy, decoder_output], feed_dict = {X: a.reshape([-1, 28, 28, 1]), y: b})

h = sess.run(y_pred, feed_dict = {X: a.reshape([-1, 28, 28, 1])})
                                  
sess.close()
    
print(g.size)

x = g.reshape([28,28])
print("H = ", h)

plt.imshow(x)

"""

'\na = mnist.test.images[0:1]\nb = mnist.test.labels[0:1]\n\nplt.imshow(a.reshape([28,28]))\n\nsess = tf.Session()\n    \nsaver.restore(sess, checkpoint_path)\n    \ne, f, g = sess.run([loss, accuracy, decoder_output], feed_dict = {X: a.reshape([-1, 28, 28, 1]), y: b})\n\nh = sess.run(y_pred, feed_dict = {X: a.reshape([-1, 28, 28, 1])})\n                                  \nsess.close()\n    \nprint(g.size)\n\nx = g.reshape([28,28])\nprint("H = ", h)\n\nplt.imshow(x)\n\n'