Skip to content

Commit

Permalink
Support Tensorflow 2.2
Browse files Browse the repository at this point in the history
1. Change `keras` to `tf.keras`
2. Change `K.*` to `tf.*`. In TensorFlow 2, they actually call the same lower-level functions.
3. Change `K.batch_dot` to `tf.matmul` in class CapsuleLayer. The batch_dot changed its behavior in version 2.3 (or earlier). But I find `tf.matmul` is sufficient to implement the class CapsuleLayer.
  • Loading branch information
XifengGuo committed May 19, 2020
1 parent 923809b commit 9d7e641
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 70 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# CapsNet-Keras
[![License](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/XifengGuo/CapsNet-Keras/blob/master/LICENSE)

A Keras implementation of CapsNet in the paper:
A Keras/TensorFlow2.2 implementation of CapsNet in the paper:
[Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017](https://arxiv.org/abs/1710.09829)
The current `average test error = 0.34%` and `best test error = 0.30%`.

Expand All @@ -28,11 +28,9 @@ Open an issue or contact me with E-mail `guoxifeng1990@163.com` or WeChat `wenlo
## Usage

**Step 1.
Install [Keras>=2.0.7](https://github.com/fchollet/keras)
with [TensorFlow>=1.2](https://github.com/tensorflow/tensorflow) backend.**
Install [TensorFlow>=2.0](https://github.com/tensorflow/tensorflow) backend.**
```
pip install tensorflow-gpu
pip install keras
pip install tensorflow==2.2.0
```

**Step 2. Clone this repository to local.**
Expand Down
66 changes: 33 additions & 33 deletions capsulelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Keras`
"""

import keras.backend as K
import tensorflow as tf
from keras import initializers, layers
import tensorflow.keras.backend as K
from tensorflow.keras import initializers, layers


class Length(layers.Layer):
Expand All @@ -20,7 +20,7 @@ class Length(layers.Layer):
output: shape=[None, num_vectors]
"""
def call(self, inputs, **kwargs):
return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())
return tf.sqrt(tf.reduce_sum(tf.square(inputs), -1) + K.epsilon())

def compute_output_shape(self, input_shape):
return input_shape[:-1]
Expand Down Expand Up @@ -50,15 +50,15 @@ def call(self, inputs, **kwargs):
inputs, mask = inputs
else: # if no true label, mask by the max length of capsules. Mainly used for prediction
# compute lengths of capsules
x = K.sqrt(K.sum(K.square(inputs), -1))
x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
# generate the mask which is a one-hot code.
# mask.shape=[None, n_classes]=[None, num_capsule]
mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])
mask = tf.one_hot(indices=tf.argmax(x, 1), depth=x.shape[1])

# inputs.shape=[None, num_capsule, dim_capsule]
# mask.shape=[None, num_capsule]
# masked.shape=[None, num_capsule * dim_capsule]
masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
masked = K.batch_flatten(inputs * tf.expand_dims(mask, -1))
return masked

def compute_output_shape(self, input_shape):
Expand All @@ -79,18 +79,18 @@ def squash(vectors, axis=-1):
:param axis: the axis to squash
:return: a Tensor with same shape as input vectors
"""
s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + K.epsilon())
return scale * vectors


class CapsuleLayer(layers.Layer):
"""
The capsule layer. It is similar to Dense layer. Dense layer has `in_num` inputs, each is a scalar, the output of the
The capsule layer. It is similar to Dense layer. Dense layer has `in_num` inputs, each is a scalar, the output of the
neuron from the former layer, and it has `out_num` output neurons. CapsuleLayer just expand the output of the neuron
from scalar to vector. So its input shape = [None, input_num_capsule, input_dim_capsule] and output shape = \
[None, num_capsule, dim_capsule]. For Dense Layer, input_dim_capsule = dim_capsule = 1.
:param num_capsule: number of capsules in this layer
:param dim_capsule: dimension of the output vectors of the capsules in this layer
:param routings: number of iterations for the routing algorithm
Expand All @@ -109,7 +109,7 @@ def build(self, input_shape):
self.input_num_capsule = input_shape[1]
self.input_dim_capsule = input_shape[2]

# Transform matrix
# Transform matrix, from each input capsule to each output capsule, there's a unique weight as in Dense layer.
self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
self.dim_capsule, self.input_dim_capsule],
initializer=self.kernel_initializer,
Expand All @@ -119,48 +119,48 @@ def build(self, input_shape):

def call(self, inputs, training=None):
# inputs.shape=[None, input_num_capsule, input_dim_capsule]
# inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
inputs_expand = K.expand_dims(inputs, 1)
# inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule, 1]
inputs_expand = tf.expand_dims(tf.expand_dims(inputs, 1), -1)

# Replicate num_capsule dimension to prepare being multiplied by W
# inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])
# inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule, 1]
inputs_tiled = tf.tile(inputs_expand, [1, self.num_capsule, 1, 1, 1])

# Compute `inputs * W` by scanning inputs_tiled on dimension 0.
# x.shape=[num_capsule, input_num_capsule, input_dim_capsule]
# W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
# Regard the first two dimensions as `batch` dimension,
# then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
# x.shape=[num_capsule, input_num_capsule, input_dim_capsule, 1]
# Regard the first two dimensions as `batch` dimension, then
# matmul(W, x): [..., dim_capsule, input_dim_capsule] x [..., input_dim_capsule, 1] -> [..., dim_capsule, 1].
# inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)
inputs_hat = tf.squeeze(tf.map_fn(lambda x: tf.matmul(self.W, x), elems=inputs_tiled))

# Begin: Routing algorithm ---------------------------------------------------------------------#
# The prior for coupling coefficient, initialized as zeros.
# b.shape = [None, self.num_capsule, self.input_num_capsule].
b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
# b.shape = [None, self.num_capsule, 1, self.input_num_capsule].
b = tf.zeros(shape=[inputs.shape[0], self.num_capsule, 1, self.input_num_capsule])

assert self.routings > 0, 'The routings should be > 0.'
for i in range(self.routings):
# c.shape=[batch_size, num_capsule, input_num_capsule]
c = tf.nn.softmax(b, dim=1)
# c.shape=[batch_size, num_capsule, 1, input_num_capsule]
c = tf.nn.softmax(b, axis=1)

# c.shape = [batch_size, num_capsule, input_num_capsule]
# c.shape = [batch_size, num_capsule, 1, input_num_capsule]
# inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
# The first two dimensions as `batch` dimension,
# then matmal: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
# outputs.shape=[None, num_capsule, dim_capsule]
outputs = squash(K.batch_dot(c, inputs_hat, [2, 2])) # [None, 10, 16]
# then matmal: [..., 1, input_num_capsule] x [..., input_num_capsule, dim_capsule] -> [..., 1, dim_capsule].
# outputs.shape=[None, num_capsule, 1, dim_capsule]
outputs = squash(tf.matmul(c, inputs_hat)) # [None, 10, 1, 16]

if i < self.routings - 1:
# outputs.shape = [None, num_capsule, dim_capsule]
# outputs.shape = [None, num_capsule, 1, dim_capsule]
# inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
# The first two dimensions as `batch` dimension,
# then matmal: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
# b.shape=[batch_size, num_capsule, input_num_capsule]
b += K.batch_dot(outputs, inputs_hat, [2, 3])
# The first two dimensions as `batch` dimension, then
# matmal:[..., 1, dim_capsule] x [..., input_num_capsule, dim_capsule]^T -> [..., 1, input_num_capsule].
# b.shape=[batch_size, num_capsule, 1, input_num_capsule]
b += tf.matmul(outputs, inputs_hat, transpose_b=True)
# End: Routing algorithm -----------------------------------------------------------------------#

return outputs
return tf.squeeze(outputs)

def compute_output_shape(self, input_shape):
return tuple([None, self.num_capsule, self.dim_capsule])
Expand Down
66 changes: 34 additions & 32 deletions capsulenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
"""

import numpy as np
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image
Expand All @@ -28,16 +29,17 @@
K.set_image_data_format('channels_last')


def CapsNet(input_shape, n_class, routings):
def CapsNet(input_shape, n_class, routings, batch_size):
"""
A Capsule Network on MNIST.
:param input_shape: data shape, 3d, [width, height, channels]
:param n_class: number of classes
:param routings: number of routing iterations
:param batch_size: size of batch
:return: Two Keras Models, the first one used for training, and the second one for evaluation.
`eval_model` can also be used for training.
"""
x = layers.Input(shape=input_shape)
x = layers.Input(shape=input_shape, batch_size=batch_size)

# Layer 1: Just a conventional Conv2D layer
conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)
Expand All @@ -46,8 +48,7 @@ def CapsNet(input_shape, n_class, routings):
primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

# Layer 3: Capsule layer. Routing algorithm works here.
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
name='digitcaps')(primarycaps)
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings, name='digitcaps')(primarycaps)

# Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
# If using tensorflow, this will not be necessary. :)
Expand All @@ -60,7 +61,7 @@ def CapsNet(input_shape, n_class, routings):

# Shared Decoder model in training and prediction
decoder = models.Sequential(name='decoder')
decoder.add(layers.Dense(512, activation='relu', input_dim=16*n_class))
decoder.add(layers.Dense(512, activation='relu', input_dim=16 * n_class))
decoder.add(layers.Dense(1024, activation='relu'))
decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))
Expand All @@ -84,13 +85,15 @@ def margin_loss(y_true, y_pred):
:param y_pred: [None, num_capsule]
:return: a scalar loss value.
"""
L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
# return tf.reduce_mean(tf.square(y_pred))
L = y_true * tf.square(tf.maximum(0., 0.9 - y_pred)) + \
0.5 * (1 - y_true) * tf.square(tf.maximum(0., y_pred - 0.1))

return K.mean(K.sum(L, 1))
return tf.reduce_mean(tf.reduce_sum(L, 1))


def train(model, data, args):
def train(model, # type: models.Model
data, args):
"""
Training a CapsuleNet
:param model: the CapsuleNet model
Expand All @@ -103,8 +106,6 @@ def train(model, data, args):

# callbacks
log = callbacks.CSVLogger(args.save_dir + '/log.csv')
tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
batch_size=args.batch_size, histogram_freq=int(args.debug))
checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
save_best_only=True, save_weights_only=True, verbose=1)
lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))
Expand All @@ -128,14 +129,14 @@ def train_generator(x, y, batch_size, shift_fraction=0.):
generator = train_datagen.flow(x, y, batch_size=batch_size)
while 1:
x_batch, y_batch = generator.next()
yield ([x_batch, y_batch], [y_batch, x_batch])

# Training with data augmentation. If shift_fraction=0., also no augmentation.
model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
steps_per_epoch=int(y_train.shape[0] / args.batch_size),
epochs=args.epochs,
validation_data=[[x_test, y_test], [y_test, x_test]],
callbacks=[log, tb, checkpoint, lr_decay])
yield (x_batch, y_batch), (y_batch, x_batch)

# Training with data augmentation. If shift_fraction=0., no augmentation.
model.fit(train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
steps_per_epoch=int(y_train.shape[0] / args.batch_size),
epochs=args.epochs,
validation_data=((x_test, y_test), (y_test, x_test)), batch_size=args.batch_size,
callbacks=[log, checkpoint, lr_decay])
# End: Training with data augmentation -----------------------------------------------------------------------#

model.save_weights(args.save_dir + '/trained_model.h5')
Expand All @@ -150,10 +151,10 @@ def train_generator(x, y, batch_size, shift_fraction=0.):
def test(model, data, args):
x_test, y_test = data
y_pred, x_recon = model.predict(x_test, batch_size=100)
print('-'*30 + 'Begin: test' + '-'*30)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])
print('-' * 30 + 'Begin: test' + '-' * 30)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1)) / y_test.shape[0])

img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
image = img * 255
Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png")
print()
Expand All @@ -164,7 +165,7 @@ def test(model, data, args):


def manipulate_latent(model, data, args):
print('-'*30 + 'Begin: manipulate' + '-'*30)
print('-' * 30 + 'Begin: manipulate' + '-' * 30)
x_test, y_test = data
index = np.argmax(y_test, 1) == args.digit
number = np.random.randint(low=0, high=sum(index) - 1)
Expand All @@ -175,22 +176,22 @@ def manipulate_latent(model, data, args):
for dim in range(16):
for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
tmp = np.copy(noise)
tmp[:,:,dim] = r
tmp[:, :, dim] = r
x_recon = model.predict([x, y, tmp])
x_recons.append(x_recon)

x_recons = np.concatenate(x_recons)

img = combine_images(x_recons, height=16)
image = img*255
image = img * 255
Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)
print('manipulated result saved to %s/manipulate-%d.png' % (args.save_dir, args.digit))
print('-' * 30 + 'End: manipulate' + '-' * 30)


def load_mnist():
# the data, shuffled and split between train and test sets
from keras.datasets import mnist
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
Expand All @@ -203,8 +204,8 @@ def load_mnist():
if __name__ == "__main__":
import os
import argparse
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import callbacks

# setting the hyper parameters
parser = argparse.ArgumentParser(description="Capsule Network on MNIST.")
Expand Down Expand Up @@ -241,7 +242,8 @@ def load_mnist():
# define model
model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
routings=args.routings)
routings=args.routings,
batch_size=args.batch_size)
model.summary()

# train or test
Expand Down

0 comments on commit 9d7e641

Please sign in to comment.