In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import utils

import numpy as np

from sklearn.model_selection import train_test_split

In [2]:
batch_size = 128
num_classes = 10
img_rows, img_cols = 28, 28

In [3]:
mnist = keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()

In [4]:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)

In [5]:
def squash(x, axis=-1):
    squared_norm = keras.backend.sum(keras.backend.square(x), axis, keepdims=True) + keras.backend.epsilon()
    scale = keras.backend.sqrt(squared_norm)/(1+squared_norm)
    return scale*x

In [6]:
class digitCaps(keras.layers.Layer):
    def __init__(self,in_type=1152, in_shape=8, out_type=10, out_shape=16, routings=3, activation='squash',**kwargs):
        self.in_type = in_type
        self.in_shape = in_shape       
        self.out_type = out_type
        self.out_shape = out_shape
        self.routings = routings
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = keras.layers.Activation(activation)
        super(digitCaps,self).__init__(**kwargs)
        
    def build(self,input_shape):
        self.kernel1 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel2 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel3 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel4 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel5 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel6 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel7 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel8 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel9 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel10 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        super(digitCaps,self).build(input_shape)

    def call(self,inputs):
        type1 = tf.squeeze(tf.matmul(inputs, self.kernel1),axis=2)
        type2 = tf.squeeze(tf.matmul(inputs, self.kernel2),axis=2)
        type3 = tf.squeeze(tf.matmul(inputs, self.kernel3),axis=2)
        type4 = tf.squeeze(tf.matmul(inputs, self.kernel4),axis=2)
        type5 = tf.squeeze(tf.matmul(inputs, self.kernel5),axis=2)
        type6 = tf.squeeze(tf.matmul(inputs, self.kernel6),axis=2)
        type7 = tf.squeeze(tf.matmul(inputs, self.kernel7),axis=2)
        type8 = tf.squeeze(tf.matmul(inputs, self.kernel8),axis=2)
        type9 = tf.squeeze(tf.matmul(inputs, self.kernel9),axis=2)
        type10 = tf.squeeze(tf.matmul(inputs, self.kernel10),axis=2)
        result = tf.stack([type1,type2,type3,type4,type5,type6,type7,type8,type9,type10],1)
        B = tf.zeros_like(tf.transpose(tf.transpose(result)[1]))
        for i in range(self.routings):
            C = tf.nn.softmax(B,1)
            O = tf.einsum('ijk,ijkl->ijl', C, result)
            V = squash(O,-1)
            B = B + tf.einsum('ijk,ijlk->ijl', V, result)
        return V

In [7]:
class selfCaps(keras.layers.Layer):
    def __init__(self,in_type=1152, in_shape=8, out_type=10, out_shape=16, activation='squash',**kwargs):
        self.in_type = in_type
        self.in_shape = in_shape       
        self.out_type = out_type
        self.out_shape = out_shape
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = keras.layers.Activation(activation)
        super(selfCaps,self).__init__(**kwargs)
        
    def build(self,input_shape):
        self.kernel1 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel2 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel3 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel4 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel5 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel6 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel7 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel8 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel9 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.kernel10 = self.add_weight(shape=(self.in_type,self.in_shape,self.out_shape),initializer="uniform", trainable=True)
        self.O = self.add_weight(shape=(self.out_type,self.out_shape,self.out_shape),initializer="ones", trainable=True)
        super(selfCaps,self).build(input_shape)

    def call(self,inputs):
        type1 = tf.squeeze(tf.matmul(inputs, self.kernel1),axis=2)
        type2 = tf.squeeze(tf.matmul(inputs, self.kernel2),axis=2)
        type3 = tf.squeeze(tf.matmul(inputs, self.kernel3),axis=2)
        type4 = tf.squeeze(tf.matmul(inputs, self.kernel4),axis=2)
        type5 = tf.squeeze(tf.matmul(inputs, self.kernel5),axis=2)
        type6 = tf.squeeze(tf.matmul(inputs, self.kernel6),axis=2)
        type7 = tf.squeeze(tf.matmul(inputs, self.kernel7),axis=2)
        type8 = tf.squeeze(tf.matmul(inputs, self.kernel8),axis=2)
        type9 = tf.squeeze(tf.matmul(inputs, self.kernel9),axis=2)
        type10 = tf.squeeze(tf.matmul(inputs, self.kernel10),axis=2)
        result = tf.stack([type1,type2,type3,type4,type5,type6,type7,type8,type9,type10],1)
        V = tf.reduce_sum(result, axis=2, keepdims=True)
        V = squash(V,-1)
        V = tf.squeeze(tf.matmul(V, self.O),axis=2)
        return V

In [8]:
class mask(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(mask,self).__init__(**kwargs)
        
    def build(self,input_shape):
        super(mask,self).build(input_shape)

    def call(self,inputs):
        info = keras.backend.sqrt(keras.backend.sum(keras.backend.square(inputs), 2))
        inf = tf.argmax(info,axis=1)[1]
        out = inputs[:,inf,:]
        return out

In [9]:
output = keras.layers.Lambda(lambda x: keras.backend.sqrt(keras.backend.sum(keras.backend.square(x), 2)), output_shape=(10,))

In [10]:
inputs = tf.keras.layers.Input(shape=(28,28,1))
C1 = tf.keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation="relu")(inputs)#24
C2 = tf.keras.layers.Conv2D(filters=256, kernel_size=(7,7), strides=(1,1), activation="relu")(C1)#18
C3 = tf.keras.layers.Conv2D(filters=256, kernel_size=(7,7), strides=(1,1), activation="relu")(C2)#12
C4 = tf.keras.layers.Conv2D(filters=256, kernel_size=(7,7), strides=(1,1), activation="relu")(C3)#6
U3 = tf.keras.layers.UpSampling2D(size=(2,2))(C4)
P3 = tf.keras.layers.Add()([U3,C3])
U1 = tf.keras.layers.UpSampling2D(size=(2,2))(P3)
P1 = tf.keras.layers.Add()([U1,C1])
R4 = tf.keras.layers.Reshape(target_shape=(1152,1,8))(C4)
R3 = tf.keras.layers.Reshape(target_shape=(1152,1,32))(P3)
R2 = tf.keras.layers.Reshape(target_shape=(1152,1,128))(P1)
F4 = digitCaps(in_type=1152, in_shape=8, out_type=10, out_shape=16)(R4)
F3 = digitCaps(in_type=1152, in_shape=32, out_type=10, out_shape=16)(R3)
F2 = digitCaps(in_type=1152, in_shape=128, out_type=10, out_shape=16)(R2)
query =tf.stack(F4+F3+F2)
query = tf.keras.layers.Reshape(target_shape=(10,1,16))(query)
key = tf.stack([F4,F3,F2])
key = tf.transpose(key,[1,2,0,3])
o = tf.keras.layers.Attention()([query,key])
o = tf.keras.layers.Reshape(target_shape=(10,16))(o)
output = keras.layers.Lambda(lambda x: keras.backend.sqrt(keras.backend.sum(keras.backend.square(x), 2)), output_shape=(10,))(o)
model1 =  tf.keras.models.Model(inputs=inputs, outputs = output)
model1.summary()
model1.compile(optimizer=tf.keras.optimizers.Adam(0.00001),
            loss=lambda y_true,y_pred:K.mean(K.sum(y_true * K.square(K.maximum(0., 0.9 - y_pred)) + 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1)),1)),
            metrics=['accuracy'])

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 24, 24, 256)  6656        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 18, 18, 256)  3211520     conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 12, 12, 256)  3211520     conv2d_1[0][0]                   
______________________________________________________________________________________________

In [11]:
x_val = x_train[:12000]
y_val = y_train[:12000]
x_train = x_train[12000:]
y_train = y_train[12000:]

In [12]:
model1his = model1.fit(x_train, y_train, batch_size=128, epochs=20, validation_data=[x_test, y_test])

Train on 48000 samples, validate on 10000 samples
Epoch 1/20
  128/48000 [..............................] - ETA: 2:38:40

KeyboardInterrupt: 

In [None]:
inputs = tf.keras.layers.Input(shape=(28,28,1))
C1 = keras.layers.Conv2D(filters=256, kernel_size=(9,9), strides=(1,1), activation="relu")(inputs)
C2 = keras.layers.Conv2D(filters=256, kernel_size=(9,9), strides=(2,2), activation="relu")(C1)
R1 = keras.layers.Reshape(target_shape=(1152,1,8))(C2)
P1 = digitCaps(in_type=1152, in_shape=8, out_type=10, out_shape=16)(R1)
out_caps = keras.layers.Lambda(lambda x: keras.backend.sqrt(keras.backend.sum(keras.backend.square(x), 2)), output_shape=(10,))(P1)
capsule =  tf.keras.models.Model(inputs=inputs, outputs = out_caps)
capsule.summary()
capsule.compile(optimizer=tf.keras.optimizers.Adam(0.00001),
            loss=[lambda y_true,y_pred:K.mean(K.sum(y_true * K.square(K.maximum(0., 0.9 - y_pred)) + 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1)),1))],
            metrics=['accuracy'])

In [None]:
model1his = capsule.fit(x_train, y_train, batch_size=128, epochs=20, validation_data=[x_test, y_test])

In [None]:
import matplotlib.pyplot as plt
loss = [0.1943,0.0429,0.0283,0.0223,0.0186,0.0165,0.0145,0.0132,0.0120,0.0112,0.0100,0.0090, 0.0088,0.0079,0.0073,0.0070,0.0064,0.0059,0.0055,0.0053]
val_loss = [0.0515,0.0298,0.0223,0.0184,0.0156,0.0159,0.0134,0.0127,0.0115, 0.0109,0.0098,0.0091,0.0098,0.0092,0.0086,0.0083,0.0094,0.0078,0.0073,0.0081]
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig("result2.png")

In [None]:
acc = [0.8196,0.9551,0.9719,0.9780,0.9814,0.9839,0.9852,0.9868,0.9876,0.9895,0.9902,0.9912,0.9911,0.9921,0.9928,0.9932,0.9937,0.9946,0.9948,0.9949]
val_acc = [0.9474,0.9704,0.9788,0.9813,0.9852,0.9845,0.9874,0.9872,0.9877,0.9889,0.9901,0.9907,0.9899,0.9915,0.9914,0.9915,0.9905,0.9922,0.9927,0.9928]
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig("result1.png")

In [None]:
acc = model1his.history['accuracy']
val_acc = model1his.history['val_accuracy']
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig("result3.png")

In [None]:
import matplotlib.pyplot as plt
loss = model1his.history['loss']
val_loss = model1his.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig("result4.png")

In [None]:
model1.evaluate(x_test,y_test)

In [None]:
capsule.evaluate(x_test,y_test)