Commit
fix #33
xifeng committed Dec 8, 2017
1 parent 38b5d66 commit 6e616da
Showing 2 changed files with 34 additions and 44 deletions.
66 changes: 28 additions & 38 deletions capsulenet-multi-gpu.py
@@ -19,7 +19,7 @@

K.set_image_data_format('channels_last')

from capsulenet import CapsNet, margin_loss, load_mnist
from capsulenet import CapsNet, margin_loss, load_mnist, manipulate_latent, test


def train(model, data, args):
@@ -73,26 +73,6 @@ def train_generator(x, y, batch_size, shift_fraction=0.):
return model


def test(model, data):
x_test, y_test = data
y_pred, x_recon = model.predict(x_test, batch_size=100)
print('-'*50)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image

img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
image = img * 255
Image.fromarray(image.astype(np.uint8)).save("real_and_recon.png")
print()
print('Reconstructed images are saved to ./real_and_recon.png')
print('-'*50)
plt.imshow(plt.imread("real_and_recon.png", ))
plt.show()
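The function removed above is not gone: it now lives in capsulenet.py and is pulled in through the new import at the top of this file. Judging from the updated calls further down, the shared version takes an extra args parameter and presumably writes its output under args.save_dir. A minimal sketch, assuming it otherwise mirrors the deleted code:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from utils import combine_images  # repo helper that tiles many images into one grid

def test(model, data, args):
    # Evaluate accuracy and save a grid of real vs. reconstructed digits.
    x_test, y_test = data
    y_pred, x_recon = model.predict(x_test, batch_size=100)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1)) / y_test.shape[0])

    img = combine_images(np.concatenate([x_test[:50], x_recon[:50]]))
    Image.fromarray((img * 255).astype(np.uint8)).save(args.save_dir + '/real_and_recon.png')
    print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir)
    plt.imshow(plt.imread(args.save_dir + '/real_and_recon.png'))
    plt.show()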


if __name__ == "__main__":
import numpy as np
import tensorflow as tf
@@ -104,17 +84,26 @@ def test(model, data):

# setting the hyper parameters
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=300, type=int)
parser = argparse.ArgumentParser(description="Capsule Network on MNIST.")
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--lam_recon', default=0.392, type=float) # 784 * 0.0005, paper uses sum of SE, here uses MSE
parser.add_argument('--num_routing', default=3, type=int) # num_routing should > 0
parser.add_argument('--shift_fraction', default=0.1, type=float)
parser.add_argument('--debug', default=0, type=int) # debug>0 will save weights by TensorBoard
parser.add_argument('--batch_size', default=300, type=int)
parser.add_argument('--lam_recon', default=0.392, type=float,
help="The coefficient for the loss of decoder")
parser.add_argument('-r', '--routings', default=3, type=int,
help="Number of iterations used in routing algorithm. should > 0")
parser.add_argument('--shift_fraction', default=0.1, type=float,
help="Fraction of pixels to shift at most in each direction.")
parser.add_argument('--debug', default=0, type=int,
help="Save weights by TensorBoard")
parser.add_argument('--save_dir', default='./result')
parser.add_argument('--is_training', default=1, type=int)
parser.add_argument('--weights', default=None)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('-t', '--testing', action='store_true',
help="Test the trained model on testing dataset")
parser.add_argument('--digit', default=5, type=int,
help="Digit to manipulate")
parser.add_argument('-w', '--weights', default=None,
help="The path of the saved weights. Should be specified when testing")
parser.add_argument('--lr', default=0.001, type=float,
help="Initial learning rate")
parser.add_argument('--gpus', default=2, type=int)
args = parser.parse_args()
print(args)
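With the reworked flags, a typical run of this script would presumably follow the usage lines in capsulenet.py's docstring (the weight path below is just the default save location):
    python capsulenet-multi-gpu.py --gpus 2 --epochs 50
    python capsulenet-multi-gpu.py --testing --weights ./result/trained_model.h5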
@@ -126,23 +115,24 @@ def test(model, data):

# define model
with tf.device('/cpu:0'):
model, eval_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
num_routing=args.num_routing)
model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
routings=args.routings)
model.summary()
plot_model(model, to_file=args.save_dir+'/model.png', show_shapes=True)

# define multi-gpu model
multi_model = multi_gpu_model(model, gpus=args.gpus)
# train or test
if args.weights is not None: # init the model weights with provided one
model.load_weights(args.weights)
if args.is_training:
if not args.testing:
# define multi-gpu model
multi_model = multi_gpu_model(model, gpus=args.gpus)
train(model=multi_model, data=((x_train, y_train), (x_test, y_test)), args=args)
model.save_weights(args.save_dir + '/trained_model.h5')
print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)
test(model=eval_model, data=(x_test, y_test))
test(model=eval_model, data=(x_test, y_test), args=args)
else: # as long as weights are given, will run testing
if args.weights is None:
print('No weights are provided. Will test using random initialized weights.')
test(model=eval_model, data=(x_test, y_test))
manipulate_latent(manipulate_model, (x_test, y_test), args)
test(model=eval_model, data=(x_test, y_test), args=args)
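manipulate_latent is the other helper now imported from capsulenet.py; its body is not part of this diff. A rough sketch of the idea, with the input signature of manipulate_model and the save path assumed: perturb one dimension of the 16-D DigitCaps vector at a time for the class chosen by --digit and save the resulting reconstructions.

import numpy as np
from PIL import Image
from utils import combine_images  # repo helper that tiles many images into one grid

def manipulate_latent(model, data, args):
    # Pick one test example of the digit class requested on the command line.
    x_test, y_test = data
    index = np.argmax(y_test, 1) == args.digit
    x = np.expand_dims(x_test[index][0], 0)
    y = np.expand_dims(y_test[index][0], 0)

    # Sweep each capsule dimension over a small range and collect reconstructions.
    # Assumed inputs of manipulate_model: (image, one-hot label, additive noise on the capsule vector).
    noise = np.zeros([1, y_test.shape[1], 16])
    x_recons = []
    for dim in range(16):
        for r in np.linspace(-0.25, 0.25, 11):
            tmp = np.copy(noise)
            tmp[:, :, dim] = r
            x_recons.append(model.predict([x, y, tmp]))

    img = combine_images(np.concatenate(x_recons))
    Image.fromarray((img * 255).astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)
    print('Manipulated result saved to %s/manipulate-%d.png' % (args.save_dir, args.digit))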
12 changes: 6 additions & 6 deletions capsulenet.py
@@ -6,7 +6,7 @@
Usage:
python CapsNet.py
python CapsNet.py --epochs 50
python CapsNet.py --epochs 50 --num_routing 3
python CapsNet.py --epochs 50 --routings 3
... ...
Result:
@@ -28,12 +28,12 @@
K.set_image_data_format('channels_last')


def CapsNet(input_shape, n_class, num_routing):
def CapsNet(input_shape, n_class, routings):
"""
A Capsule Network on MNIST.
:param input_shape: data shape, 3d, [width, height, channels]
:param n_class: number of classes
:param num_routing: number of routing iterations
:param routings: number of routing iterations
:return: Two Keras Models, the first one used for training, and the second one for evaluation.
`eval_model` can also be used for training.
"""
@@ -46,7 +46,7 @@ def CapsNet(input_shape, n_class, num_routing):
primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

# Layer 3: Capsule layer. Routing algorithm works here.
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, num_routing=num_routing,
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
name='digitcaps')(primarycaps)

# Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
@@ -217,7 +217,7 @@ def load_mnist():
parser.add_argument('--lam_recon', default=0.392, type=float,
help="The coefficient for the loss of decoder")
parser.add_argument('-r', '--routings', default=3, type=int,
help="Number of iterations used in routing algorithm. should > 0") # num_routing should > 0
help="Number of iterations used in routing algorithm. should > 0")
parser.add_argument('--shift_fraction', default=0.1, type=float,
help="Fraction of pixels to shift at most in each direction.")
parser.add_argument('--debug', action='store_true',
@@ -241,7 +241,7 @@ def load_mnist():
# define model
model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
num_routing=args.routings)
routings=args.routings)
model.summary()
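With the renamed keyword, callers now construct the network with routings instead of num_routing. A sketch using MNIST shapes (not a line from this commit); CapsNet returns three models: one for training, one for evaluation, and one for latent-space manipulation.

from capsulenet import CapsNet

# 28x28 grayscale inputs, 10 digit classes, 3 routing iterations.
model, eval_model, manipulate_model = CapsNet(input_shape=(28, 28, 1),
                                              n_class=10,
                                              routings=3)
model.summary()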


