# Implementation of CapsNet / CapsuleNet
## using Keras / TensorFlow

### from CapsNet-Keras project on GitHub
https://github.com/XifengGuo/CapsNet-Keras

### by Xifeng Guo

Play with a fresh Keras / TensorFlow implementation of CapsNet / CapsuleNet (aka Hinton's capsules) by Xifeng Guo, which reaches a test error below 0.4%.

The code is based on the NIPS paper "Dynamic Routing Between Capsules" (abstract: https://goo.gl/tDwtT2, full PDF: https://goo.gl/6r24Uv).

<hr style="border-width:4px;border-color:black;"/>
Keras implementation of CapsNet from Hinton's paper "Dynamic Routing Between Capsules". The current version may only work with the TensorFlow backend; it should be straightforward to rewrite it in pure TensorFlow code.

Adapting it to other backends should be easy, but I have not tested this.

Usage:

* python CapsNet.py
* python CapsNet.py --epochs 100
* python CapsNet.py --epochs 100 --num_routing 3
       ... ...
       
Result:
    Validation accuracy > 99.5% after 20 epochs. Still under-fitting.
    About 110 seconds per epoch on a single GTX1070 GPU card
    
Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Keras`
<hr style="border-width:4px;border-color:black;"/>


#### Note: 

This Notebook is a small contribution by Claude Coulombe, PhD candidate, TÉLUQ / UQAM

I had to install the native library `graphviz` on my MacBook Pro (it is needed by `plot_model`, used below to draw the model):

> \>brew install graphviz

## Import the Python libraries

In [45]:
import numpy as np
import os

import tensorflow as tf

from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model

from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask

print("Python libraries imported!")

Python libraries imported!


## Load the MNIST data

In [46]:
def load_mnist():
    # the data, shuffled and split between train and test sets
    from keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
    x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
    y_train = to_categorical(y_train.astype('float32'))
    y_test = to_categorical(y_test.astype('float32'))
    return (x_train, y_train), (x_test, y_test)

print("MNIST loading code ready!")

MNIST loading code ready!
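The preprocessing inside `load_mnist()` can be checked without downloading MNIST. The sketch below repeats the same reshape / cast / divide-by-255 steps on a small synthetic batch and uses `np.eye(num_classes)[labels]` as a NumPy stand-in for `keras.utils.to_categorical` (the helper name `preprocess` is hypothetical, not part of the repo).

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Hypothetical helper mirroring load_mnist():
    uint8 [N, 28, 28] images -> float32 [N, 28, 28, 1] in [0, 1],
    integer labels -> one-hot float32 [N, num_classes]."""
    x = images.reshape(-1, 28, 28, 1).astype('float32') / 255.
    y = np.eye(num_classes, dtype='float32')[labels]  # one row per label, a single 1 at the label index
    return x, y

# Tiny synthetic batch standing in for MNIST (no download needed)
fake_images = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 0, 9, 3])
x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (4, 28, 28, 1) (4, 10)
```

Note that `to_categorical` infers the number of classes from the largest label when none is given; passing `num_classes=10` explicitly avoids surprises on subsets that happen to lack the digit 9.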


## Define the Capsules Network Model

<img src="result/model.png" width=500/>

In [47]:
def CapsNet(input_shape, n_class, num_routing):
    """
    A Capsule Network on MNIST.
    :param input_shape: data shape, 3d, [height, width, channels]
    :param n_class: number of classes
    :param num_routing: number of routing iterations
    :return: A Keras Model with 2 inputs and 2 outputs
    """
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_vector]
    primarycaps = PrimaryCap(conv1, dim_vector=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

    # Layer 3: Capsule layer. Routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_vector=16, num_routing=num_routing, name='digitcaps')(primarycaps)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    # If using tensorflow, this will not be necessary. :)
    out_caps = Length(name='out_caps')(digitcaps)

    # Decoder network.
    y = layers.Input(shape=(n_class,))
    masked = Mask()([digitcaps, y])  # The true label is used to mask the output of capsule layer.
    x_recon = layers.Dense(512, activation='relu')(masked)
    x_recon = layers.Dense(1024, activation='relu')(x_recon)
    x_recon = layers.Dense(np.prod(input_shape), activation='sigmoid')(x_recon)
    x_recon = layers.Reshape(target_shape=input_shape, name='out_recon')(x_recon)

    # two-input-two-output keras Model
    return models.Model([x, y], [out_caps, x_recon])

print("Capsules Network Model code ready!")

Capsules Network Model code ready!
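The `capsulelayers` module is not shown in this notebook. To make the "squash" activation and the "routing algorithm" mentioned in the comments concrete, here is a minimal NumPy sketch of the squash nonlinearity and of routing-by-agreement as described in the paper (Procedure 1), for a single example. It is an illustration under those assumptions, not Xifeng Guo's `CapsuleLayer` code; the names `squash`, `softmax`, and `dynamic_routing` are my own.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-9):
    """Squash nonlinearity: keeps the vector's direction, maps its length into [0, 1).
    v = (|s|^2 / (1 + |s|^2)) * s / |s|  (eps avoids division by zero)."""
    sq_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm) / np.sqrt(sq_norm + eps)
    return scale * s

def softmax(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))  # subtract max for numerical stability
    return e / np.sum(e, axis=axis, keepdims=True)

def dynamic_routing(u_hat, num_routing=3):
    """Routing-by-agreement for one example.

    u_hat: prediction vectors u_hat[j|i], shape [num_in, num_out, dim_out]
           (e.g. [1152, 10, 16] for PrimaryCaps -> DigitCaps on MNIST).
    Returns the output capsules v, shape [num_out, dim_out].
    """
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))                 # routing logits, start uniform
    for _ in range(num_routing):
        c = softmax(b, axis=1)                      # coupling coefficients: each input sums to 1 over outputs
        s = np.einsum('ij,ijd->jd', c, u_hat)       # weighted sum of predictions per output capsule
        v = squash(s)                               # output capsules, [num_out, dim_out]
        b = b + np.einsum('ijd,jd->ij', u_hat, v)   # agreement = dot product u_hat[j|i] . v_j
    return v

rng = np.random.default_rng(0)
u_hat = rng.normal(size=(1152, 10, 16)) * 0.01
v = dynamic_routing(u_hat, num_routing=3)
lengths = np.linalg.norm(v, axis=-1)  # what the Length layer produces: one score per class
print(v.shape, lengths.shape)  # (10, 16) (10,)
```

Because every output capsule passes through `squash`, its length stays below 1, which is why the `Length` layer's output can be compared directly against one-hot labels in the margin loss.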


In [48]:
def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq.(4). When y_true[i, :] contains more than one `1`, this loss should still work, but this has not been tested.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))

print("Loss function ready!")

Loss function ready!
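For reference, Eq. (4) of the paper defines, for each class k, L_k = T_k max(0, m+ - ||v_k||)^2 + λ (1 - T_k) max(0, ||v_k|| - m-)^2, with m+ = 0.9, m- = 0.1 and λ = 0.5; `margin_loss` above sums L_k over classes and averages over the batch. The NumPy sketch below (the name `margin_loss_np` is hypothetical) reproduces the same arithmetic so it can be checked by hand.

```python
import numpy as np

def margin_loss_np(y_true, y_pred, m_plus=0.9, m_minus=0.1, lam=0.5):
    """NumPy version of margin_loss: y_true is one-hot [batch, n_class],
    y_pred holds capsule lengths [batch, n_class]."""
    L = y_true * np.square(np.maximum(0., m_plus - y_pred)) + \
        lam * (1 - y_true) * np.square(np.maximum(0., y_pred - m_minus))
    return np.mean(np.sum(L, axis=1))  # sum over classes, mean over the batch

y_true = np.array([[0., 1., 0.]])
y_pred = np.array([[0.2, 0.7, 0.05]])
# class 0: 0.5 * (0.2 - 0.1)^2 = 0.005; class 1: (0.9 - 0.7)^2 = 0.04; class 2: 0
loss = margin_loss_np(y_true, y_pred)
print(loss)  # approximately 0.045
```

A prediction with the correct capsule longer than 0.9 and every other capsule shorter than 0.1 gives zero loss, so the loss stops pushing once those margins are met.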


## Define the training code

In [49]:
def train(model,
          data,
          save_dir,
          batch_size,
          debug,
          learning_rate,
          lam_recon,
          shift_fraction):
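    """
    Train a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param save_dir: directory where logs, checkpoints and the final weights are written
    :param batch_size: number of samples per batch
    :param debug: histogram frequency passed to TensorBoard (0 disables histograms)
    :param learning_rate: initial learning rate for Adam, multiplied by 0.9 after each epoch
    :param lam_recon: weight of the reconstruction (decoder) loss
    :param shift_fraction: fraction of the image size used for random shifts during augmentation
    :return: The trained model
    Note: the number of epochs is read from the global variable `epochs` set in the hyperparameters cell.
    """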
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=save_dir + '/tensorboard-logs',
                               batch_size=batch_size, 
                               histogram_freq=debug)
    checkpoint = callbacks.ModelCheckpoint(save_dir + '/weights-{epoch:02d}.h5',
                                           save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: learning_rate * (0.9 ** epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=learning_rate),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., lam_recon],
                  metrics={'out_caps': 'accuracy'})

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=batch_size, epochs=epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while True:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. With shift_fraction=0., no augmentation is applied.
    model.fit_generator(generator=train_generator(x_train, y_train, batch_size, shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / batch_size),
                        epochs=epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % save_dir)

    from utils import plot_log
    plot_log(save_dir + '/log.csv', show=True)

    return model

print("Training code ready!")

Training code ready!
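Two pieces of `train()` are easy to check in isolation: the learning-rate rule given to `LearningRateScheduler` (the rate is multiplied by 0.9 each epoch, with Keras counting epochs from 0) and the two-input / two-output batch format that the generator yields, `([images, labels], [labels, images])`. The sketch below is a hypothetical NumPy stand-in with no augmentation; unlike a Keras iterator it simply drops a final partial batch, which matches how `steps_per_epoch` is computed with `int(...)` above.

```python
import numpy as np

def lr_schedule(epoch, learning_rate=0.001):
    """Same rule as the lambda in train(): decay the learning rate by 10% per epoch."""
    return learning_rate * (0.9 ** epoch)

def two_io_batches(x, y, batch_size, seed=0):
    """Hypothetical generator: yields ([images, labels], [labels, images]) forever,
    reshuffling on every pass and dropping any final partial batch."""
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    while True:
        idx = rng.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            b = idx[start:start + batch_size]
            # inputs: image + true label (for masking); targets: label (margin loss) + image (reconstruction)
            yield ([x[b], y[b]], [y[b], x[b]])

x = np.random.rand(250, 28, 28, 1).astype('float32')
y = np.eye(10, dtype='float32')[np.arange(250) % 10]
steps_per_epoch = int(y.shape[0] / 100)  # 2 full batches per epoch here
inputs, targets = next(two_io_batches(x, y, batch_size=100))
print([lr_schedule(e) for e in range(3)])
```

With the default `epochs = 30`, the last epoch (index 29) trains at roughly 0.001 * 0.9^29, about 4.7e-5.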


## Define the test code

In [50]:
def test(model, data):
    x_test, y_test = data
    y_pred, x_recon = model.predict([x_test, y_test], batch_size=100)
    print('-'*50)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

    import matplotlib.pyplot as plt
    from utils import combine_images
    from PIL import Image

    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save("real_and_recon.png")
    print()
    print('Reconstructed images are saved to ./real_and_recon.png')
    print('-'*50)
    plt.imshow(plt.imread("real_and_recon.png"))
    plt.show()
    
print("Testing code ready!")

Testing code ready!
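`test()` relies on `combine_images` from the repo's `utils` module, which is not shown here. The sketch below is a hypothetical `tile_images` helper that lays out `[N, H, W, 1]` images in a roughly square grid, row by row, which is the kind of picture `real_and_recon.png` shows: with 50 test images followed by 50 reconstructions, the top five rows are originals and the bottom five are reconstructions. It also repeats the accuracy computation from `test()`.

```python
import math
import numpy as np

def tile_images(images):
    """Hypothetical helper: tile [N, H, W, 1] images into one 2-D array, filling rows left to right."""
    n, h, w = images.shape[:3]
    cols = int(math.sqrt(n))
    rows = int(math.ceil(n / cols))
    grid = np.zeros((rows * h, cols * w), dtype=images.dtype)
    for k in range(n):
        r, c = divmod(k, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = images[k, :, :, 0]
    return grid

def accuracy(y_pred, y_true):
    """Fraction of samples whose longest capsule matches the one-hot label, as in test()."""
    return np.mean(np.argmax(y_pred, 1) == np.argmax(y_true, 1))

real = np.random.rand(50, 28, 28, 1).astype('float32')
recon = np.random.rand(50, 28, 28, 1).astype('float32')
grid = tile_images(np.concatenate([real, recon]))
print(grid.shape)  # (280, 280)
```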


## Setting hyperparameters

Below are the default values:

In [51]:
batch_size = 100
epochs = 30
lam_recon = 0.392
num_routing = 3
shift_fraction = 0.1
debug = 0
save_dir = './result'
is_training = True
weights = None
learning_rate = 0.001
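The value `lam_recon = 0.392` is presumably chosen to match the paper, which scales a sum-of-squared-errors reconstruction loss by 0.0005. Keras' `'mse'` averages over the 28 * 28 = 784 pixels instead of summing, so the equivalent weight is 0.0005 * 784 = 0.392. The short check below (variable names are illustrative) confirms the two formulations agree on random data.

```python
import numpy as np

paper_scale = 0.0005            # scale applied to the summed squared error in the paper
num_pixels = 28 * 28            # Keras 'mse' takes the mean over these pixels
lam_recon_equiv = paper_scale * num_pixels

rng = np.random.default_rng(0)
target = rng.random((28, 28, 1))
recon = rng.random((28, 28, 1))
sse_loss = paper_scale * np.sum(np.square(recon - target))       # paper formulation
mse_loss = lam_recon_equiv * np.mean(np.square(recon - target))  # notebook formulation
print(round(lam_recon_equiv, 6), np.isclose(sse_loss, mse_loss))  # 0.392 True
```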

## Train the model

In [52]:
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# load data
(x_train, y_train), (x_test, y_test) = load_mnist()

# define model
model = CapsNet(input_shape=[28, 28, 1],
                n_class=len(np.unique(np.argmax(y_train, 1))),
                num_routing=num_routing)
model.summary()
plot_model(model, to_file=save_dir+'/model.png', show_shapes=True)

# train or test
if weights is not None:  # init the model weights with provided one
    model.load_weights(weights)
if is_training:
    train(model=model,
          data=((x_train, y_train), (x_test, y_test)),
          save_dir=save_dir,
          batch_size=batch_size,
          debug=debug,
          learning_rate=learning_rate,
          lam_recon=lam_recon,
          shift_fraction=shift_fraction)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 20, 20, 256)  20992       input_5[0][0]                    
__________________________________________________________________________________________________
primarycap_conv2d (Conv2D)      (None, 6, 6, 256)    5308672     conv1[0][0]                      
__________________________________________________________________________________________________
primarycap_reshape (Reshape)    (None, 1152, 8)      0           primarycap_conv2d[0][0]          
__________________________________________________________________________________________________
primarycap

KeyboardInterrupt: 
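The shapes and parameter counts in the (interrupted) summary follow from simple arithmetic: a 'valid' convolution shrinks each spatial side to (size - kernel) // stride + 1, and a Conv2D layer has kernel * kernel * in_channels * out_channels weights plus one bias per filter. The PrimaryCap convolution has 8 * 32 = 256 filters (8-D vectors in 32 capsule channels), and its 6 * 6 * 32 = 1152 capsules match the `(None, 1152, 8)` reshape. The helper names below are illustrative.

```python
def conv_out(size, kernel, stride):
    """Spatial size after a 'valid' (no padding) convolution."""
    return (size - kernel) // stride + 1

def conv_params(kernel, in_ch, out_ch):
    """Trainable parameters of a square Conv2D layer: weights plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch

conv1_size = conv_out(28, 9, 1)                 # conv1: 28x28 -> 20x20
prim_size = conv_out(conv1_size, 9, 2)          # primarycap_conv2d: 20x20 -> 6x6
num_primary_caps = prim_size * prim_size * 32   # 32 capsule channels of 8-D vectors
print(conv1_size, prim_size, num_primary_caps)  # 20 6 1152
print(conv_params(9, 1, 256), conv_params(9, 256, 8 * 32))  # 20992 5308672
```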

## Test the model

In [None]:
is_training = False
if not is_training:
    # testing uses whatever weights the model currently holds (trained above or loaded from `weights`)
    if weights is None:
        print('No weights are provided. Will test using randomly initialized weights.')
    test(model=model, data=(x_test, y_test))