In [13]:
from __future__ import print_function
from keras import backend as K
from keras.layers import Layer
from keras import activations
from keras import utils
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import *
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam


# the squashing function.
# we use 0.5 in stead of 1 in hinton's paper.
# if 1, the norm of vector will be zoomed out.
# if 0.5, the norm will be zoomed in while original norm is less than 0.5
# and be zoomed out while original norm is greater than 0.5.
def squash(x, axis=-1):
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
    return scale * x


# define our own softmax function instead of K.softmax
# because K.softmax can not specify axis.
def softmax(x, axis=-1):
    ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return ex / K.sum(ex, axis=axis, keepdims=True)


# define the margin loss like hinge loss
def margin_loss(y_true, y_pred):
    lamb, margin = 0.5, 0.1
    return K.sum(y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
        1 - y_true) * K.square(K.relu(y_pred - margin)), axis=-1)


class Capsule(Layer):
    """A Capsule Implement with Pure Keras
    There are two vesions of Capsule.
    One is like dense layer (for the fixed-shape input),
    and the other is like timedistributed dense (for various length input).

    The input shape of Capsule must be (batch_size,
                                        input_num_capsule,
                                        input_dim_capsule
                                       )
    and the output shape is (batch_size,
                             num_capsule,
                             dim_capsule
                            )

    Capsule Implement is from https://github.com/bojone/Capsule/
    Capsule Paper: https://arxiv.org/abs/1710.09829
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 share_weights=True,
                 activation='squash',
                 **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            self.kernel = self.add_weight(
                name='capsule_kernel',
                shape=(1, input_dim_capsule,
                       self.num_capsule * self.dim_capsule),
                initializer='glorot_uniform',
                trainable=True)
        else:
            input_num_capsule = input_shape[-2]
            self.kernel = self.add_weight(
                name='capsule_kernel',
                shape=(input_num_capsule, input_dim_capsule,
                       self.num_capsule * self.dim_capsule),
                initializer='glorot_uniform',
                trainable=True)

    def call(self, inputs):
        """Following the routing algorithm from Hinton's paper,
        but replace b = b + <u,v> with b = <u,v>.

        This change can improve the feature representation of Capsule.

        However, you can replace
            b = K.batch_dot(outputs, hat_inputs, [2, 3])
        with
            b += K.batch_dot(outputs, hat_inputs, [2, 3])
        to realize a standard routing.
        """

        if self.share_weights:
            hat_inputs = K.conv1d(inputs, self.kernel)
        else:
            hat_inputs = K.local_conv1d(inputs, self.kernel, [1], [1])

        batch_size = K.shape(inputs)[0]
        input_num_capsule = K.shape(inputs)[1]
        hat_inputs = K.reshape(hat_inputs,
                               (batch_size, input_num_capsule,
                                self.num_capsule, self.dim_capsule))
        hat_inputs = K.permute_dimensions(hat_inputs, (0, 2, 1, 3))

        b = K.zeros_like(hat_inputs[:, :, :, 0])
        for i in range(self.routings):
            c = softmax(b, 1)
            o = self.activation(K.batch_dot(c, hat_inputs, [2, 2]))
            if i < self.routings - 1:
                b = K.batch_dot(o, hat_inputs, [2, 3])
                if K.backend() == 'theano':
                    o = K.sum(o, axis=1)

        return b

    def compute_output_shape(self, input_shape):
        return (None, self.num_capsule, self.dim_capsule)


In [20]:
input_image = Input(shape=(128, 128, 3))
x = Conv2D(32, (5, 5), activation='relu')(input_image)
x = Dropout(.2)(x,training=True)
#x = Conv2D(64, (5, 5), activation='relu')(x)
x = Conv2D(32, (5, 5), activation='relu')(x)
x = AveragePooling2D((2, 2))(x)
x = Dropout(.2)(x,training=True)
x = Conv2D(32, (5, 5), activation='relu')(x)


"""now we reshape it as (batch_size, input_num_capsule, input_dim_capsule)
then connect a Capsule layer.

the output of final model is the lengths of 10 Capsule, whose dim=16.

the length of Capsule is the proba,
so the problem becomes a 10 two-classification problem.
"""

x = Reshape((-1, 128))(x)
capsule = Capsule(6, 16, 3, True)(x)
x = Lambda(lambda z: K.sqrt(K.sum(K.square(z), 2)))(capsule)
x = Dense(128,activation='relu')(x)
x = Dense(128,activation='relu')(x)
x = Dropout(.2)(x, training=True)
x = Dense(6)(x)
output = Softmax(axis=-1)(x)
#output = softmax(axis=-1)(output)
model = Model(inputs=input_image, outputs=output)
optimizer = Adam(lr=0.0001,decay=1e-5)
# we use a margin loss
model.compile(optimizer,loss= 'categorical_crossentropy', metrics=['accuracy'])
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 124, 124, 32)      2432      
_________________________________________________________________
dropout_16 (Dropout)         (None, 124, 124, 32)      0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 120, 120, 32)      25632     
_________________________________________________________________
average_pooling2d_8 (Average (None, 60, 60, 32)        0         
_________________________________________________________________
dropout_17 (Dropout)         (None, 60, 60, 32)        0         
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 56, 56, 32)        25632     
__________

In [21]:
from keras.utils import plot_model 
from dataLoader import data_loader
import time
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser()

In [24]:
# Data flow
parser.add_argument('--train_data_dir', help = "Directory containing training data", type = str, default = 'data/train')
parser.add_argument('--val_data_dir', help = "Directory for validation data", type =str, default = "data/val")
parser.add_argument('--test_data_dir', help = "Directory for test data", type = str, default = "data/test")
args = parser.parse_args()

usage: ipykernel_launcher.py [-h] [--train_data_dir TRAIN_DATA_DIR]
                             [--val_data_dir VAL_DATA_DIR]
                             [--test_data_dir TEST_DATA_DIR]
ipykernel_launcher.py: error: unrecognized arguments: -f /run/user/1003/jupyter/kernel-852189b0-07ca-4655-903d-34cb58235a04.json


SystemExit: 2

In [23]:
dataLoader = data_loader("data/train", "data/val", "data/test", data_size = 200)
train_generator, validation_generator, test_generator = dataLoader.load_images()

start = time.time()
model.fit_generator(train_generator, steps_per_epoch =100, epochs = 10, validation_data  = validation_generator,

                    validation_steps = 2)
print('--------Test data--------')

x = model.evaluate_generator(test_generator, steps = 2, verbose = 1)
print(x)
end = time.time()

Found 15996 images belonging to 6 classes.
Found 1998 images belonging to 6 classes.
Found 1998 images belonging to 6 classes.
Epoch 1/10
Epoch 2/10
Epoch 3/10

KeyboardInterrupt: 