In [1]:
import numpy as np
import keras.backend as K
#import tensorflow.keras.backend as K
import tensorflow as tf
from keras import layers, models, optimizers, initializers
from keras.utils import to_categorical

print("backend:", K.backend())
tf_session = K.get_session()
#from utils import plot_log

Using TensorFlow backend.


backend: tensorflow


In [2]:
def squash(vector, axis=-1):
  squared_norm = K.sum(K.square(vector), axis, keepdims=True)
  scale = squared_norm / (1 + squared_norm) / K.sqrt(squared_norm + K.epsilon())
  return vector * scale


In [3]:
def PrimaryCapsules(inputs, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid'):
  output = layers.Conv2D(filters=dim_capsule*n_channels, kernel_size=kernel_size, strides=strides, 
                       padding=padding, name='primarycap_conv2d')(inputs)
  outputs = layers.Reshape(target_shape=[-1, dim_capsule], name='primarycap_reshape')(output)
  return layers.Lambda(squash, name='primarycap_squash')(outputs)


In [4]:
class CapsuleLayer(layers.Layer):
  def __init__(self, num_capsule, dim_capsule, routings=3, kernel_initializer='glorot_uniform', **kwargs):
    super(CapsuleLayer, self).__init__(**kwargs)
    self.num_capsule = num_capsule
    self.dim_capsule = dim_capsule
    self.routings = routings
    self.kernel_initializer = initializers.get(kernel_initializer)
    
  def build(self, input_shape):
    assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
    self.input_num_capsule = input_shape[1]
    self.input_dim_capsule = input_shape[2]
    
    self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule, self.dim_capsule, self.input_dim_capsule], 
                            initializer=self.kernel_initializer, name='W')
    
    print("w shape: ", K.int_shape(self.W)) # (10, 1152, 16, 8)
    
    #self.W_init = tf.random_normal(shape=(1, 1152, 10, 16, 8),
    #                         stddev=0.01, dtype=tf.float32, name="W_init")
    #self.W = tf.Variable(self.W_init, name="W")
    #self.W_tiled = tf.tile(self.W, [batch_size, 1, 1, 1, 1], name="W_tiled")
    #print("W_tiled shape: ", self.W_tiled.shape)
    
    self.built = True
    
  def call(self, inputs, training=None):
    
    print("inputs shape: ", K.int_shape(inputs)) # (None, 1152, 8)
    
    inputs_expand = K.expand_dims(inputs, 1)
    #inputs_expand = layers.Lambda(lambda x: K.expand_dims(x, 1))(inputs)
    print("inputs_expand shape: ", K.int_shape(inputs_expand)) # (None, 1, 1152, 8)
    
    inputs_tiled =  K.tile(inputs_expand, [1, self.num_capsule, 1, 1])
    #inputs_tiled = layers.Lambda(lambda x: K.tile(x, [1, self.num_capsule, 1, 1]))(inputs_expand)
    print("inputs_tiled shape: ", K.int_shape(inputs_tiled)) # (None, 10, 1152, 8)
    
    inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)
    print("inputs_hat shape: ", K.int_shape(inputs_hat)) # (None, 10, None, 1152, 16)
    
    #inputs_expand = tf.expand_dims(inputs, -1, name="inputs_expand")
    #inputs_tile = tf.expand_dims(inputs_expand, 2, name="inputs_tile")
    #inputs_tiled = tf.tile(inputs_tile, [1, 1, 10, 1, 1], name="inputs_tiled")
    #print("inputs_tiled shape: ", inputs_tiled.shape)
    #inputs_hat = tf.matmul(self.W_tiled, inputs_tiled, name="inputs_hat") #prediction
    #print("inputs_hat shape: ", inputs_hat.shape)
    
    
    # Start: Routing algorithm
    b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
    print("b (weights) shape: ", K.int_shape(b)) # (None, 10, 1152)
    #raw_weights = tf.zeros([batch_size, 1152, 10, 1, 1], dtype=np.float32, name="raw_weights")
    
    assert self.routings > 0, 'Routings should be > 0'
    for i in range(self.routings):
            c = tf.nn.softmax(b, dim=1)
            #c = tf.nn.softmax(raw_weights, dim=2, name="c")

            # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule] <-- nicht erfüllt!
            # The first two dimensions as `batch` dimension,
            # then matmal: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]
            
            #weighted_predictions = tf.multiply(c, inputs_hat, name="weighted_predictions")
            #weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keep_dims=True, name="weighted_sum")
            #outputs = squash(weighted_sum)
            #print("outputs shape: ", outputs.shape)
            #outputs_tiled = tf.tile(outputs, [1, 1152, 1, 1, 1], name="outputs_tiled")
            #print("outputs_tiled shape: ", outputs_tiled)

            if i < self.routings - 1:
                # The first two dimensions as `batch` dimension,
                # then matmal: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
                
                print("outputs shape: ", K.int_shape(outputs)) # (None, 10, 10, 1152, 16)
                b += K.batch_dot(outputs, inputs_hat, [2, 3])
                #raw_weights += tf.matmul(inputs_hat, outputs_tiled, transpose_a=True)
                
      # End: Routing algorithm

    return outputs
  
  def compute_output_shape(self, input_shape):
    return tuple([None, self.num_capsule, self.dim_capsule])
  
  def get_config(self):
    config = {
        'num_capsule': self.num_capsule,
        'dim_capsule': self.dim_capsule,
        'routings': self.routings
    }
    base_config = super(CapsuleLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  
  
    

In [5]:
class Length(layers.Layer):
  def call(self, inputs, **kwargs):
      return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())

  def compute_output_shape(self, input_shape):
      return input_shape[:-1]

  def get_config(self):
      config = super(Length, self).get_config()
      return config  

In [6]:
class Mask(layers.Layer):
  def call(self, inputs, **kwargs):
    if type(inputs) is list: # label is provided with shape
      assert len(inputs) == 2
      inputs, mask = inputs
    else:
      x = K.sqrt(K.sum(K.square(inputs), -1))
      mask = K.one_hot(indices=K.argmax(x, 1), num_classes = x.get_shape().as_list()[1])
      
    masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
    return masked
  
  def compute_output_shape(self, input_shape):
    if type(input_shape[0]) is tuple: #with class label
      return tuple([None, input_shape[0][1] * input_shape[0][2]])
    else:
      return tuple([None, input_shape[1] * input_shape[2]])
    
  def get_config(self):
    config = super(Mask, self).get_config()
    return config

In [7]:
def CapsNet(input_shape, n_class, routings):
  x = layers.Input(shape=input_shape)
  
  
  # Start: Encoder ---------------------------
  #Layer 1: Conv2D
  conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)
  
  #Layer 2: Conv2D with squash activation, then reshape to [None, num_capsule, dim_capsule]
  primary_caps = PrimaryCapsules(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
  
  #Layer 3: Capsule Layer, Routing here
  digit_caps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings, name='digitcaps')(primary_caps)
  
  # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
  # If using tensorflow, this will not be necessary. 
  out_caps = Length(name='capsnet')(digit_caps)
  # End: Encoder ---------------------------
  
  
  #Start: Decoder --------------------------
  y = layers.Input(shape=(n_class,))
  masked_by_y = Mask()([digit_caps, y]) # The true label is used to mask the output of capsule layer. For training
  masked = Mask()(digit_caps) # Mask using the capsule with maximal length. For prediction
  
  decoder = models.Sequential(name='decoder')
  decoder.add(layers.Dense(512, activation='relu', input_dim=16*n_class))
  decoder.add(layers.Dense(1024, activation='relu'))
  decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
  decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))
  
  #End: Decoder ----------------------------
  
  train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
  eval_model = models.Model(x, [out_caps, decoder(masked)])
  
  #manipulate model
  
  #noise = layers.Input(shape=(n_class, 16))
  #noised_digitcaps = layers.Add()([digit_caps, noise])
  #masked_noised_y = Mask()([noised_digitcaps, y])
  #manipulate_model = models.Model([x,y,noise], decoder(masked_noised_y))
  return train_model, eval_model#, manipulate_model
  
  
  

In [8]:
def margin_loss(y_true, y_pred):
  L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.9))
  return K.mean(K.sum(L, 1))

In [9]:
def sum_squared_error(y_true, y_pred):
    if not K.is_tensor(y_pred):
        y_pred = K.constant(y_pred)
    y_true = K.cast(y_true, y_pred.dtype)
    return K.square(y_pred - y_true)

In [10]:
def train(model, data, save_dir):
  (x_train, y_train), (x_test, y_test) = data
  
  log = callbacks.CSVLogger(save_dir + '/log.csv')
  tb = callbacks.TensorBoard(log_dir = save_dir + '/tensorboard-logs', batch_size = batch_size, histogram_freq = int(debug))
  checkpoint = callbacks.ModelCheckpoint(save_dir + '/weights-{epoch:02d}.h5', monitor = 'val_capsnet_acc', save_best_only = True, save_weights_only = True, verbose = 1)
  #lr_decay = callbacks.LearningRateScheduler(schedule = lambda epoch : args.lr * (args.lr_decay ** epoch))
  
  model.compile(optimizer = optimizers.Adam(), loss = [margin_loss, sum_squared_error], loss_weights = [1., 0.0005], metrics = {'capsnet': 'accuracy'})
  
  model.fit([x_train, y_train], [y_train, x_train], batch_size = batch_size, epochs = epochs, validation_data = [[x_test, y_test], [y_test, x_test]], callbacks = [log, tb, checkpoint])
  
  model.fit([x_train, y_train], [y_train, x_train], batch_size=batch_size, epochs=epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint])

  model.save_weights(save_dir + '/trained_model.h5')
  print('Trained model saved to \ %s/trained_model.h5\'' %  (save_dir))
  #plot_log(save_dir + '/log.csv', show = True)
  return model

In [11]:
def test(model, data, save_dir):
    x_test, y_test = data
    y_pred, x_recon = model.predict(x_test, batch_size=100)
    print('-'*30 + 'Begin: test' + '-'*30)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])
'''
    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save(save_dir + "/real_and_recon.png")
    print()
    print('Reconstructed images are saved to %s/real_and_recon.png' % save_dir)
    print('-' * 30 + 'End: test' + '-' * 30)
    plt.imshow(plt.imread(save_dir + "/real_and_recon.png"))
    plt.show()'''

'\n    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))\n    image = img * 255\n    Image.fromarray(image.astype(np.uint8)).save(save_dir + "/real_and_recon.png")\n    print()\n    print(\'Reconstructed images are saved to %s/real_and_recon.png\' % save_dir)\n    print(\'-\' * 30 + \'End: test\' + \'-\' * 30)\n    plt.imshow(plt.imread(save_dir + "/real_and_recon.png"))\n    plt.show()'

In [12]:
def load_mnist():
    # the data, shuffled and split between train and test sets
    from keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
    x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
    y_train = to_categorical(y_train.astype('float32'))
    y_test = to_categorical(y_test.astype('float32'))
    return (x_train, y_train), (x_test, y_test)

In [13]:
epochs = 50
batch_size = 100
routings = 3
shift_fraction = 0.1
testing = False
debug = True
save_dir = './result'
weights = None

if __name__ == "__main__":
    import os
    import argparse
    from keras.preprocessing.image import ImageDataGenerator
    from keras import callbacks

    # setting the hyper parameters
    #parser = argparse.ArgumentParser(description="Capsule Network on MNIST.")
    #parser.add_argument('--digit', default=5, type=int,
    #                    help="Digit to manipulate")

    #args = parser.parse_args()
    #print(args)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # load data
    (x_train, y_train), (x_test, y_test) = load_mnist()

    # define model
    model, eval_model = CapsNet(input_shape=x_train.shape[1:],
                                                  n_class=len(np.unique(np.argmax(y_train, 1))),
                                                  routings=routings)
    model.summary()

    # train or test
    if weights is not None:  # init the model weights with provided one
        model.load_weights(weights)
    if not testing:
        train(model=model, data=((x_train, y_train), (x_test, y_test)), save_dir=save_dir)
    else:  # as long as weights are given, will run testing
        if weights is None:
            print('No weights are provided. Will test using random initialized weights.')
        #manipulate_latent(manipulate_model, (x_test, y_test), args)
        test(model=eval_model, data=(x_test, y_test), save_dir=save_dir)

w shape:  (10, 1152, 16, 8)
inputs shape:  (None, 1152, 8)
inputs_expand shape:  (None, 1, None, 8)
inputs_tiled shape:  (None, 10, None, 8)
inputs_hat shape:  (None, 10, 1152, 16)
b (weights) shape:  (None, 10, 1152)
Instructions for updating:
dim is deprecated, use axis instead
outputs shape:  (None, 10, 16)
outputs shape:  (None, 10, 16)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 20, 20, 256)  20992       input_1[0][0]                    
__________________________________________________________________________________________________
primarycap_conv2d (Conv2D)      (None, 6, 6, 256)    5308672   


Epoch 00015: val_capsnet_acc did not improve from 0.76840
Epoch 16/50

Epoch 00016: val_capsnet_acc improved from 0.76840 to 0.78920, saving model to ./result/weights-16.h5
Epoch 17/50

Epoch 00017: val_capsnet_acc improved from 0.78920 to 0.82150, saving model to ./result/weights-17.h5
Epoch 18/50

Epoch 00018: val_capsnet_acc did not improve from 0.82150
Epoch 19/50

Epoch 00019: val_capsnet_acc improved from 0.82150 to 0.84860, saving model to ./result/weights-19.h5
Epoch 20/50

Epoch 00020: val_capsnet_acc improved from 0.84860 to 0.85280, saving model to ./result/weights-20.h5
Epoch 21/50

Epoch 00021: val_capsnet_acc did not improve from 0.85280
Epoch 22/50

Epoch 00022: val_capsnet_acc did not improve from 0.85280
Epoch 23/50

Epoch 00023: val_capsnet_acc did not improve from 0.85280
Epoch 24/50

Epoch 00024: val_capsnet_acc improved from 0.85280 to 0.86280, saving model to ./result/weights-24.h5
Epoch 25/50

Epoch 00025: val_capsnet_acc did not improve from 0.86280
Epoch 26/50


Epoch 00038: val_capsnet_acc did not improve from 0.95220
Epoch 39/50

Epoch 00039: val_capsnet_acc did not improve from 0.95220
Epoch 40/50

Epoch 00040: val_capsnet_acc did not improve from 0.95220
Epoch 41/50

Epoch 00041: val_capsnet_acc did not improve from 0.95220
Epoch 42/50

Epoch 00042: val_capsnet_acc improved from 0.95220 to 0.95420, saving model to ./result/weights-42.h5
Epoch 43/50

Epoch 00043: val_capsnet_acc did not improve from 0.95420
Epoch 44/50

Epoch 00044: val_capsnet_acc improved from 0.95420 to 0.95500, saving model to ./result/weights-44.h5
Epoch 45/50

Epoch 00045: val_capsnet_acc did not improve from 0.95500
Epoch 46/50

Epoch 00046: val_capsnet_acc improved from 0.95500 to 0.96810, saving model to ./result/weights-46.h5
Epoch 47/50

Epoch 00047: val_capsnet_acc did not improve from 0.96810
Epoch 48/50

Epoch 00048: val_capsnet_acc did not improve from 0.96810
Epoch 49/50

Epoch 00049: val_capsnet_acc did not improve from 0.96810
Epoch 50/50

Epoch 00050: va


Epoch 00012: val_capsnet_acc did not improve from 0.97210
Epoch 13/50

Epoch 00013: val_capsnet_acc did not improve from 0.97210
Epoch 14/50

Epoch 00014: val_capsnet_acc did not improve from 0.97210
Epoch 15/50

Epoch 00015: val_capsnet_acc improved from 0.97210 to 0.97890, saving model to ./result/weights-15.h5
Epoch 16/50

Epoch 00016: val_capsnet_acc did not improve from 0.97890
Epoch 17/50

Epoch 00017: val_capsnet_acc did not improve from 0.97890
Epoch 18/50

Epoch 00018: val_capsnet_acc did not improve from 0.97890
Epoch 19/50

Epoch 00019: val_capsnet_acc improved from 0.97890 to 0.97930, saving model to ./result/weights-19.h5
Epoch 20/50

Epoch 00020: val_capsnet_acc did not improve from 0.97930
Epoch 21/50

Epoch 00021: val_capsnet_acc did not improve from 0.97930
Epoch 22/50

Epoch 00022: val_capsnet_acc did not improve from 0.97930
Epoch 23/50

Epoch 00023: val_capsnet_acc did not improve from 0.97930
Epoch 24/50

Epoch 00024: val_capsnet_acc improved from 0.97930 to 0.983


Epoch 00036: val_capsnet_acc did not improve from 0.98540
Epoch 37/50

Epoch 00037: val_capsnet_acc did not improve from 0.98540
Epoch 38/50

Epoch 00038: val_capsnet_acc improved from 0.98540 to 0.98640, saving model to ./result/weights-38.h5
Epoch 39/50

Epoch 00039: val_capsnet_acc did not improve from 0.98640
Epoch 40/50

Epoch 00040: val_capsnet_acc improved from 0.98640 to 0.98650, saving model to ./result/weights-40.h5
Epoch 41/50

Epoch 00041: val_capsnet_acc improved from 0.98650 to 0.98730, saving model to ./result/weights-41.h5
Epoch 42/50

Epoch 00042: val_capsnet_acc improved from 0.98730 to 0.98830, saving model to ./result/weights-42.h5
Epoch 43/50

Epoch 00043: val_capsnet_acc did not improve from 0.98830
Epoch 44/50

Epoch 00044: val_capsnet_acc did not improve from 0.98830
Epoch 45/50

Epoch 00045: val_capsnet_acc did not improve from 0.98830
Epoch 46/50

Epoch 00046: val_capsnet_acc did not improve from 0.98830
Epoch 47/50

Epoch 00047: val_capsnet_acc did not impro