From d5b202146418bdfffad944a77b4ced6aa95b8089 Mon Sep 17 00:00:00 2001
From: InnerPeace-Wu
Date: Tue, 31 Oct 2017 22:35:16 +0800
Subject: [PATCH] fix primary caps issue and re-organise the code

---
 .gitignore       |   6 +
 CapsNet.py       | 400 +++++++++++++++++++++++++++++++++++++++++++++++
 README.md        |  28 +++-
 __init__.py      |   0
 capules_mnist.py | 237 ----------------------------
 config.py        |  62 ++++++++
 requirements.txt |   5 +
 train.py         |  81 ++++++++++
 8 files changed, 580 insertions(+), 239 deletions(-)
 create mode 100644 CapsNet.py
 create mode 100644 __init__.py
 delete mode 100644 capules_mnist.py
 create mode 100644 config.py
 create mode 100644 requirements.txt
 create mode 100644 train.py

diff --git a/.gitignore b/.gitignore
index 7bbc71c..9c4cd6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,9 @@
+# newly added
+output/
+tensorboard/
+.idea/
+data/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/CapsNet.py b/CapsNet.py
new file mode 100644
index 0000000..2541873
--- /dev/null
+++ b/CapsNet.py
@@ -0,0 +1,400 @@
+# ------------------------------------------------------------------
+# Capsules_mnist
+# By InnerPeace Wu
+# ------------------------------------------------------------------
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from six.moves import xrange
+from tensorflow.contrib import slim
+from tqdm import tqdm
+
+from config import cfg
+
+
+def squash(cap_input):
+    """
+    squash non-linearity that keeps the length of a capsule's output between 0 and 1
+    :arg
+        cap_input: total input of the capsules,
+            with shape: [None, h, w, c] or [None, n, d]
+    :return
+        cap_output: output of each capsule, with the same shape as cap_input
+    """
+
+    # compute the norm of the input over the last axis, keep dims for broadcasting
+    # ||s_j|| in the paper
+    input_norm = tf.norm(cap_input, ord=2, axis=-1,
+                         keep_dims=True, name='norm')
+    # input_norm shape: [None, h, w, 1]
+    # ||s_j||^2 in the paper
+    input_norm_square = tf.square(input_norm, name='norm_square')
+
+    # ||s_j||^2 / (1. + ||s_j||^2) * (s_j / ||s_j||)
+    with tf.name_scope('squash'):
+        cap_out = tf.div(input_norm_square,
+                         1. + input_norm_square) * tf.div(cap_input, input_norm)
+
+    return cap_out
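A quick way to sanity-check the squash non-linearity above is to evaluate it on plain NumPy vectors. This is an illustrative sketch, not part of the patch; the `np_squash` helper is hypothetical and adds a small `eps` (which the TF version above does not) to guard against a zero-length input:

```python
import numpy as np

def np_squash(s, axis=-1, eps=1e-9):
    # ||s||^2 / (1 + ||s||^2) * s / ||s|| from the paper,
    # with a small eps to avoid division by zero
    norm = np.linalg.norm(s, axis=axis, keepdims=True)
    return (norm ** 2 / (1. + norm ** 2)) * s / (norm + eps)

short = np.array([0.1, 0.0, 0.0])
long_ = np.array([10., 0.0, 0.0])
print(np.linalg.norm(np_squash(short)))  # ~0.0099, short vectors shrink toward 0
print(np.linalg.norm(np_squash(long_)))  # ~0.9901, long vectors approach length 1
```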
+
+
+class CapsNet(object):
+    def __init__(self, mnist):
+        """initialize the class with the mnist dataset"""
+        self._mnist = mnist
+
+        # keep track of the spatial dimension of the feature maps
+        self._dim = 28
+        # store the number of capsules in each capsule layer
+        # the conv1 layer has 0 capsules
+        self._num_caps = [0]
+
+    def _capsule(self, input, i_c, o_c, idx):
+        """
+        compute a capsule,
+        conv op with kernel: 9x9, stride: 2,
+        padding: VALID, output channels: 8 per capsule,
+        as described in the paper.
+        :arg
+            input: input for computing the capsule, shape: [None, w, h, c]
+            i_c: input channels
+            o_c: output channels
+            idx: index of the capsule about to be created
+
+        :return
+            capsule: the computed capsule
+        """
+        with tf.variable_scope('cap_' + str(idx)):
+            w = tf.get_variable('w', shape=[9, 9, i_c, o_c], dtype=tf.float32)
+            cap = tf.nn.conv2d(input, w, [1, 2, 2, 1],
+                               padding='VALID', name='cap_conv')
+            if cfg.USE_BIAS:
+                b = tf.get_variable('b', shape=[o_c, ], dtype=tf.float32,
+                                    initializer=self._b_initializer)
+                cap = cap + b
+            # cap with shape [None, 6, 6, 8] for the mnist dataset
+
+            # Note: use "squash" as its non-linearity.
+            capsule = squash(cap)
+            # capsule with shape: [None, 6, 6, 8]
+            # expand the dimensions to [None, 1, 6, 6, 8] for the following concat
+            capsule = tf.expand_dims(capsule, axis=1)
+
+            # return capsule with shape [None, 1, 6, 6, 8]
+            return capsule
+
+    def _dynamic_routing(self, primary_caps, layer_index):
+        """
+        dynamic routing between capsules
+        :arg
+            primary_caps: primary capsules with shape [None, 1, 32 x 6 x 6, 1, 8]
+            layer_index: index of the current capsule layer, i.e. the input layer for routing
+        :return
+            digit_caps: output of the digit capsule layer, with shape: [None, 10, 16]
+        """
+        # number of capsules in the current layer
+        num_caps = self._num_caps[layer_index]
+        # weight matrix for capsules in the "layer_index" layer
+        # W_ij
+        cap_ws = tf.get_variable('cap_w', shape=[10, num_caps, 8, 16],
+                                 dtype=tf.float32,
+                                 )
+        # initial value for "tf.scan", see the official doc for details
+        fn_init = tf.zeros([10, num_caps, 1, 16])
+
+        # x after tiling has shape: [10, num_caps, 1, 8]
+        # cap_ws has shape: [10, num_caps, 8, 16],
+        # an [8 x 16] matrix for each pair of capsules between the two layers
+        # u_hat_j|i = W_ij * u_i
+        cap_predicts = tf.scan(lambda ac, x: tf.matmul(tf.tile(x, [10, 1, 1, 1]), cap_ws),
+                               primary_caps, initializer=fn_init, name='cap_predicts')
+        # cap_predicts with shape: [None, 10, num_caps, 1, 16]
+        cap_predictions = tf.squeeze(cap_predicts, axis=[3])
+        # after squeeze, shape: [None, 10, num_caps, 16]
+
+        # log prior probabilities
+        log_prior = tf.get_variable('log_prior', shape=[10, num_caps], dtype=tf.float32,
+                                    initializer=tf.zeros_initializer())
+        # log_prior with shape: [10, num_caps]
+        for idx in xrange(cfg.ROUTING_ITERS):
+            with tf.name_scope('routing_%s' % idx):
+                # the first iteration
+                if idx == 0:
+                    c = tf.nn.softmax(log_prior, dim=0)
+                    # c shape: [10, num_caps]
+                    c_t = tf.expand_dims(c, axis=2)
+                    # c_t shape: [10, num_caps, 1]
+                # subsequent iterations
+                else:
+                    # [None, 10, num_caps]
+                    c = tf.nn.softmax(log_prior, dim=1)
+                    # [None, 10, num_caps, 1]
+                    c_t = tf.expand_dims(c, axis=3)
+
+                s_t = tf.multiply(cap_predictions, c_t)
+                # s_t shape: [None, 10, num_caps, 16]
+                # for each capsule in the next layer, sum all of its weighted
+                # prediction vectors to get its total input.
+
+                # s_j = Sum_i (c_ij u_hat_j|i)
+                s = tf.reduce_sum(s_t, axis=[2])
+
+                # s shape: [None, 10, 16]
+                digit_caps = squash(s)
+                # digit_caps shape: [None, 10, 16]
+
+                # u_hat_j|i * v_j
+                delta_prior = tf.reduce_sum(tf.multiply(tf.expand_dims(digit_caps, axis=2),
+                                                        cap_predictions),
+                                            axis=[-1])
+                # delta_prior shape: [None, 10, num_caps]
+
+                log_prior = log_prior + delta_prior
+
+        return digit_caps
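The routing loop above reads naturally as a few lines of NumPy. This standalone sketch (not part of the patch) drops the batch dimension for clarity and mirrors one pass of the algorithm: softmax over the coupling logits, a weighted sum of prediction vectors, squash, then the agreement update:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

num_caps, n_classes, d = 1152, 10, 16
u_hat = np.random.randn(n_classes, num_caps, d)   # prediction vectors u_hat_j|i
b = np.zeros((n_classes, num_caps))               # routing logits (the "log prior")

for _ in range(3):                                # cfg.ROUTING_ITERS
    c = softmax(b, axis=0)                        # coupling coefficients over classes
    s = (c[..., None] * u_hat).sum(axis=1)        # s_j = sum_i c_ij * u_hat_j|i
    norm = np.linalg.norm(s, axis=-1, keepdims=True)
    v = (norm ** 2 / (1. + norm ** 2)) * s / norm # squash -> [10, 16]
    b = b + (u_hat * v[:, None, :]).sum(axis=-1)  # agreement update u_hat_j|i . v_j

print(v.shape)  # (10, 16)
```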
+
+    def _reconstruct(self, digit_caps):
+        """
+        reconstruct the input from the digit capsules with 3 fully connected layers
+        :param
+            digit_caps: digit capsules with shape [None, 10, 16]
+        :return:
+            out: output of the reconstruction
+        """
+        # (TODO wu) there are two ways to do the reconstruction.
+        # 1. use only the target capsule, with dimension [None, 16] or [16,] (the default here)
+        # 2. use all the capsules, including the masked-out ones, which are mostly zeros
+        with tf.name_scope('reconstruct'):
+            y_ = tf.expand_dims(self._y_, axis=2)
+            # y_ shape: [None, 10, 1]
+
+            # for method 1.
+            target_cap = y_ * digit_caps
+            # target_cap shape: [None, 10, 16]
+            target_cap = tf.reduce_sum(target_cap, axis=1)
+            # target_cap: [None, 16]
+
+            # for method 2.
+            # target_cap = tf.reshape(y_ * digit_caps, [-1, 10*16])
+
+            fc = slim.fully_connected(target_cap, 512,
+                                      weights_initializer=self._w_initializer)
+            fc = slim.fully_connected(fc, 1024,
+                                      weights_initializer=self._w_initializer)
+            fc = slim.fully_connected(fc, 784,
+                                      weights_initializer=self._w_initializer,
+                                      activation_fn=None)
+            # the last layer with sigmoid activation
+            out = tf.sigmoid(fc)
+            # out with shape [None, 784]
+
+            return out
+
+    def _add_loss(self, digit_caps):
+        """
+        add the margin loss and the reconstruction loss
+        :arg
+            digit_caps: output of the digit capsule layer, shape [None, 10, 16]
+        :return
+            total_loss: scalar sum of the margin loss and the weighted reconstruction loss
+        """
+        with tf.name_scope('loss'):
+            self._digit_caps_norm = tf.norm(digit_caps, ord=2, axis=2,
+                                            name='digit_caps_norm')
+            # digit_caps_norm shape: [None, 10]
+            # loss for the positive class
+            # max(0, m+ - ||v_c||) ^ 2
+            pos_loss = tf.maximum(0., cfg.M_POS - tf.reduce_sum(self._digit_caps_norm * self._y_,
+                                                                axis=1), name='pos_max')
+            pos_loss = tf.square(pos_loss, name='pos_square')
+            # pos_loss shape before the mean: [None, ]
+            pos_loss = tf.reduce_mean(pos_loss)
+            tf.summary.scalar('pos_loss', pos_loss)
+
+            # mask out the positive class to get the negative classes
+            y_negs = 1. - self._y_
+            # max(0, ||v_c|| - m-) ^ 2
+            # note: this sums the norms of all negative classes before applying
+            # the margin; the paper applies max(0, ||v_k|| - m-)^2 per class and
+            # then sums the per-class losses
+            neg_loss = tf.maximum(0., tf.reduce_sum(self._digit_caps_norm * y_negs,
+                                                    axis=1) - cfg.M_NEG)
+            neg_loss = tf.square(neg_loss) * cfg.LAMBDA
+            # neg_loss shape before the mean: [None, ]
+            neg_loss = tf.reduce_mean(neg_loss)
+            tf.summary.scalar('neg_loss', neg_loss)
+
+            reconstruct = self._reconstruct(digit_caps)
+
+            # reconstruction loss: tf.nn.l2_loss computes sum(x^2) / 2,
+            # so multiply by 2 to get the plain sum of squared errors
+            reconstruct_loss = tf.nn.l2_loss(self._x - reconstruct, name='l2_loss') * 2.
+            tf.summary.scalar('reconstruct_loss', reconstruct_loss)
+
+            total_loss = pos_loss + neg_loss + \
+                cfg.RECONSTRUCT_W * reconstruct_loss
+
+            tf.summary.scalar('loss', total_loss)
+
+        return total_loss
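To make the margin loss concrete, here is a small numeric check (illustrative only, with made-up capsule lengths) of the paper's per-class form L_k = T_k max(0, m+ - ||v_k||)^2 + lambda (1 - T_k) max(0, ||v_k|| - m-)^2, using m+ = 0.9, m- = 0.1, lambda = 0.5 from config.py:

```python
import numpy as np

m_pos, m_neg, lam = 0.9, 0.1, 0.5
norms = np.array([0.75, 0.2, 0.05])   # made-up ||v_k|| for 3 classes
targets = np.array([1., 0., 0.])      # one-hot label T_k

pos = targets * np.maximum(0., m_pos - norms) ** 2
neg = lam * (1. - targets) * np.maximum(0., norms - m_neg) ** 2
print(pos.sum())  # (0.9 - 0.75)^2 = 0.0225
print(neg.sum())  # 0.5 * (0.2 - 0.1)^2 = 0.005; the 0.05 norm is under the margin
```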
+
+    def create_architecture(self):
+        """create the architecture of the whole network"""
+        # set up placeholders for the input data and labels
+        self._x = tf.placeholder(tf.float32, [None, 784])
+        self._y_ = tf.placeholder(tf.float32, [None, 10])
+
+        # set up initializers for the weights and bias
+        self._w_initializer = tf.truncated_normal_initializer(stddev=0.1)
+        self._b_initializer = tf.zeros_initializer()
+
+        with tf.variable_scope('CapsNet', initializer=self._w_initializer):
+            # build the net
+            self._build_net()
+
+        # set up an exponentially decaying learning rate
+        self._global_step = tf.Variable(0, trainable=False)
+        learning_rate = tf.train.exponential_decay(cfg.LR, self._global_step,
+                                                   cfg.STEP_SIZE, cfg.DECAY_RATIO,
+                                                   staircase=True)
+        tf.summary.scalar('learning_rate', learning_rate)
+
+        # set up the adam optimizer with default settings
+        self._optimizer = tf.train.AdamOptimizer(learning_rate)
+        gradients = self._optimizer.compute_gradients(self._loss)
+        # compute_gradients returns (gradient, variable) pairs;
+        # tf.global_norm expects a list of tensors, so strip the variables
+        tf.summary.scalar('grad_norm', tf.global_norm([g for g, _ in gradients]))
+
+        self._train_op = self._optimizer.apply_gradients(gradients,
+                                                         global_step=self._global_step)
+        # set up accuracy ops
+        self._accuracy()
+        self._summary_op = tf.summary.merge_all()
+
+        self.saver = tf.train.Saver()
+
+        # set up the summary writers
+        self.train_writer = tf.summary.FileWriter(cfg.TB_DIR + '/train')
+        self.val_writer = tf.summary.FileWriter(cfg.TB_DIR + '/val')
+
+    def _build_net(self):
+        """build the graph of the network"""
+
+        # reshape for conv ops
+        with tf.name_scope('x_reshape'):
+            x_image = tf.reshape(self._x, [-1, 28, 28, 1])
+
+        # initial conv1 op
+        # 1). conv1 with kernel 9x9, stride 1, output channels 256
+        with tf.variable_scope('conv1'):
+            # initialized with the xavier initializer (an arbitrary choice)
+            w = tf.get_variable('w', shape=[9, 9, 1, 256], dtype=tf.float32,
+                                initializer=tf.contrib.layers.xavier_initializer()
+                                )
+            # conv op
+            conv1 = tf.nn.conv2d(x_image, w, [1, 1, 1, 1],
+                                 padding='VALID', name='conv1')
+            if cfg.USE_BIAS:
+                # bias (TODO wu) the paper does not say whether it uses bias or not
+                b = tf.get_variable('b', shape=[256, ], dtype=tf.float32,
+                                    initializer=self._b_initializer)
+                conv1 = tf.nn.relu(conv1 + b)
+            else:
+                conv1 = tf.nn.relu(conv1)
+
+            # update the dimension of the feature map
+            self._dim = (self._dim - 9) // 1 + 1
+            assert self._dim == 20, "after conv1, the feature map " \
+                                    "should be 20x20"
+
+            # conv1 with shape [None, 20, 20, 256]
+
+        # build up the primary capsules
+        with tf.variable_scope('PrimaryCaps'):
+            # build up PrimaryCaps with 32 channels of 8-D capsules
+            caps = []
+            for idx in xrange(cfg.PRIMARY_CAPS_CHANNELS):
+                # get one 8-D capsule channel
+                cap = self._capsule(conv1, 256, 8, idx)
+                # cap with shape: [None, 1, 6, 6, 8]
+                caps.append(cap)
+
+            # concat all the primary capsules
+            primary_caps = tf.concat(caps, axis=1)
+            # primary_caps with shape: [None, 32, 6, 6, 8]
+
+            # update the dimension of the capsule grid
+            self._dim = (self._dim - 9) // 2 + 1
+            # number of primary caps: 6x6x32 = 1152
+            self._num_caps.append(self._dim ** 2 * cfg.PRIMARY_CAPS_CHANNELS)
+            assert self._dim == 6, "dims of the primary caps grid should be 6x6."
+
+        with tf.name_scope('primary_cap_reshape'):
+            # reshape and expand dims for broadcasting in dynamic routing
+            primary_caps_reshape = tf.reshape(primary_caps,
+                                              shape=[-1, 1, self._num_caps[1], 1, 8])
+            # primary_caps_reshape with shape: [None, 1, 1152, 1, 8]
+
+        # dynamic routing
+        with tf.variable_scope("digit_caps"):
+            self._digit_caps = self._dynamic_routing(primary_caps_reshape, 1)
+
+        # set up the losses
+        self._loss = self._add_loss(self._digit_caps)
+
+    def _accuracy(self):
+        with tf.name_scope('accuracy'):
+            correct_prediction = tf.equal(tf.argmax(self._y_, 1),
+                                          tf.argmax(self._digit_caps_norm, 1))
+            correct_prediction = tf.cast(correct_prediction, tf.float32)
+            self.accuracy = tf.reduce_mean(correct_prediction)
+            tf.summary.scalar('accuracy', self.accuracy)
+
+    def train_with_summary(self, sess, batch_size=100, iters=0):
+        batch = self._mnist.train.next_batch(batch_size)
+        loss, _, train_acc, train_summary = sess.run([self._loss, self._train_op,
+                                                      self.accuracy, self._summary_op],
+                                                     feed_dict={self._x: batch[0],
+                                                                self._y_: batch[1]})
+        if iters % cfg.PRINT_EVERY == 0 and iters > 0:
+            val_batch = self._mnist.validation.next_batch(batch_size)
+
+            self.train_writer.add_summary(train_summary, iters)
+            self.train_writer.flush()
+
+            print("iters: %d / %d, loss ==> %.4f " % (iters, cfg.MAX_ITERS, loss))
+            print('train accuracy: %.4f' % train_acc)
+
+            val_acc, val_summary = sess.run([self.accuracy, self._summary_op],
+                                            feed_dict={self._x: val_batch[0],
+                                                       self._y_: val_batch[1]})
+            print('val accuracy: %.4f' % val_acc)
+            self.val_writer.add_summary(val_summary, iters)
+            self.val_writer.flush()
+
+        if iters % cfg.SAVE_EVERY == 0 and iters > 0:
+            self.snapshot(sess, iters=iters)
+            self.test(sess)
+
+    def snapshot(self, sess, iters=0):
+        save_path = cfg.TRAIN_DIR + '/capsnet'
+        self.saver.save(sess, save_path, iters)
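The feature-map bookkeeping in `_build_net` above follows the usual VALID-padding formula, out = (in - kernel) // stride + 1. A standalone two-function sketch reproduces the 20 -> 6 -> 1152 numbers asserted there:

```python
def conv_out(size, kernel, stride):
    # VALID padding: no zero-padding around the input
    return (size - kernel) // stride + 1

conv1 = conv_out(28, 9, 1)            # 20 (after the 9x9 stride-1 conv)
grid = conv_out(conv1, 9, 2)          # 6  (after the 9x9 stride-2 capsule conv)
print(conv1, grid, grid * grid * 32)  # 20 6 1152 primary capsules
```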
+
+    def test(self, sess, split='validation'):
+        if split == 'test':
+            x = self._mnist.test.images
+            y_ = self._mnist.test.labels
+        else:
+            x = self._mnist.validation.images
+            y_ = self._mnist.validation.labels
+        acc = []
+        for i in tqdm(xrange(len(x) // 100), desc="calculating %s accuracy" % split):
+            x_i = x[i * 100: (i + 1) * 100]
+            y_i = y_[i * 100: (i + 1) * 100]
+            ac = sess.run(self.accuracy,
+                          feed_dict={self._x: x_i,
+                                     self._y_: y_i})
+            acc.append(ac)
+        all_ac = np.mean(np.array(acc))
+        print("whole {} accuracy: {}".format(split, all_ac))
diff --git a/README.md b/README.md
index 22afcd7..d0f30ae 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,31 @@
 # Dynamic Routing Between Capsules
 reference: [Dynamic routing between capsules](https://arxiv.org/abs/1710.09829v1) by **Sara Sabour, Nicholas Frosst, Geoffrey E Hinton**
 
-Note: this implementation strictly follow the instructions of the paper with `3` times dynamic routing iterations, check the paper for details.
+Note: this implementation strictly follows the instructions of the paper; check the paper for details.
+
+## Dependencies
+
+* The code is tested with `tensorflow 1.3` and `python 2.7`, but it should also be compatible with `python 3.x`.
+* The other dependencies are listed below; install them by running `pip install -r requirements.txt` in the `$ROOT` directory.
+
+```
+numpy>=1.7.1
+scipy>=0.13.2
+easydict>=1.6
+tqdm>=4.17.1
+```
+
+## Train
+
+* clone the repo
+* then:
+
+```bash
+cd $ROOT
+python train.py
+```
+
+NOTE: a first trial run of `50` iterations reached `69.91%` accuracy on the test set.
 
 ## TODO
-- [ ] report experiment results
+- [ ] report full experiment results
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/capules_mnist.py b/capules_mnist.py
deleted file mode 100644
index e4572c3..0000000
--- a/capules_mnist.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# --------------------------------------
-# Capsules_mnist
-# By InnerPeace Wu
-# This file is adapted from tensorflow
-# official tutorial of mnist.
-# -------------------------------------- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys -from six.moves import xrange - -from tensorflow.examples.tutorials.mnist import input_data -import tensorflow as tf -from tensorflow.contrib import slim - -FLAGS = None -num_PrimaryCaps = 32 -numcaps = 32*36 -num_routing = 3 -m_pos = 0.9 -m_neg = 0.1 -neg_reg = 0.5 -lr = 0.001 - - -class CapsuleMnist(object): - def __init__(self, mnist): - self._mnist = mnist - - - def _capsule(self, input, i_c, o_c, i): - with tf.variable_scope('cap_' + str(i)): - w = tf.get_variable('w', shape=[9, 9, i_c, o_c], dtype=tf.float32, - initializer=self._def_w_initializer) - # b = tf.get_variable('b', shape=[o_c, ], dtype=tf.float32, - # initializer=self._def_b_initializer) - cap = tf.nn.conv2d(input, w, [1, 2, 2, 1], padding='VALID', name='cap_conv') - # cap = tf.nn.relu(cap + b) - cap = tf.expand_dims(cap, axis=1) - - return cap - - def _squash(self, cap_s): - s_norm = tf.norm(cap_s, ord=2, axis=2, keep_dims=True) - s_norm_square = tf.pow(s_norm, 2) - cap_o = tf.div(s_norm_square, 1 + s_norm_square) * \ - tf.div(cap_s, s_norm) - - return cap_o - - def _dynamic_routing(self, primary_caps): - """"input with shape [None, 1, 32 x 6 x 6, 1, 8]""" - - with tf.name_scope('digit_caps'): - with tf.variable_scope('digit_caps'): - cap_ws = tf.get_variable('cap_w', shape=[10, numcaps, 8, 16], dtype=tf.float32, - initializer=self._def_w_initializer) - fn_init = tf.zeros([10, numcaps, 1, 16]) - # [None, 10, 1152, 1, 16] - cap_predictions = tf.scan(lambda ac, x: tf.matmul(x, cap_ws), - tf.tile(primary_caps, [1, 10, 1, 1, 1]), initializer=fn_init) - # [None, 10 ,1152, 16] - cap_predictions = tf.squeeze(cap_predictions, axis=[3]) - - log_prior = tf.get_variable('log_prior', shape=[numcaps, 10], dtype=tf.float32, - initializer=tf.zeros_initializer()) - for idx in xrange(num_routing): - # [numcaps, 10] - if idx == 0: - c = tf.nn.softmax(log_prior) - # [10, numcaps, 1] - c_t = tf.expand_dims(tf.transpose(c), axis=2) - else: - # [None, 10, numcaps] - c = tf.nn.softmax(log_prior, dim=1) - # [None, 10, numcaps, 1] - c_t = tf.expand_dims(c, axis=3) - - s_t = tf.multiply(cap_predictions, c_t) - # [None, 10, 16] - s = tf.reduce_sum(s_t, axis=[2]) - digit_caps = self._squash(s) - # [None, 10, 1152] - delta_prior = tf.reduce_sum(tf.multiply(tf.expand_dims(digit_caps, axis=2), - cap_predictions), - axis=[-1]) - if idx == 0: - log_prior = tf.transpose(log_prior) - log_prior = log_prior + delta_prior - - return digit_caps - - def _reconstruct(self, digit_caps): - with tf.name_scope('reconstruct'): - y_ = tf.expand_dims(self._y_, axis=2) - # [None, 16] - # gt_feature = tf.reduce_sum(y_ * digit_caps, axis=1) - # [None, 10, 16] - gt_feature = y_ * digit_caps - gt_feature = tf.reduce_sum(gt_feature, axis=1) - # gt_feature = y_ * gt_feature - fc = slim.fully_connected(gt_feature, 512, - weights_initializer=self._def_w_initializer) - fc = slim.fully_connected(fc, 1024, - weights_initializer=self._def_w_initializer) - fc = slim.fully_connected(fc, 784, - weights_initializer=self._def_w_initializer, - activation_fn=None) - out = tf.sigmoid(fc) - - return out - - def _add_loss(self, digit_caps): - with tf.name_scope('loss'): - # [None, 10 , 1] - # y_ = tf.expand_dims(self._y_, axis=2) - # [None, 10] - digit_caps_norm = tf.norm(digit_caps, ord=2, axis=2) - # [None, ] - loss_pos = tf.pow(tf.maximum(0., m_pos - tf.reduce_sum(digit_caps_norm * self._y_, - axis=1),), 2) 
- y_negs = 1. - self._y_ - # [None, ] - loss_neg = neg_reg * tf.pow(tf.maximum(0., tf.reduce_sum(digit_caps_norm * y_negs, - axis=1) - m_neg), 2) - reconstruct = self._reconstruct(digit_caps) - loss_resconstruct = tf.nn.l2_loss(self._x - reconstruct) * 2. - total_loss = tf.reduce_mean(loss_pos + loss_neg + 0.0005*loss_resconstruct) - - return total_loss - - def _build_net(self): - self._x = tf.placeholder(tf.float32, [None, 784]) - self._y_ = tf.placeholder(tf.float32, [None, 10]) - # set up initializer for weights and bias - self._def_w_initializer = tf.truncated_normal_initializer(stddev=0.1) - self._def_b_initializer = tf.zeros_initializer() - - # reshape for conv ops - with tf.name_scope('x_reshape'): - x_image = tf.reshape(self._x, [-1, 28, 28, 1]) - - # initial conv op - with tf.variable_scope('conv1'): - w = tf.get_variable('w', shape=[9, 9, 1, 256], dtype=tf.float32, - initializer=tf.contrib.layers.xavier_initializer()) - conv1 = tf.nn.conv2d(x_image, w, [1, 1, 1, 1], padding='VALID', name='conv1') - b = tf.get_variable('b', shape=[256, ], dtype=tf.float32, - initializer=self._def_b_initializer) - conv1 = tf.nn.relu(conv1 + b) - - # build up primary capsules - with tf.name_scope('primary_caps'): - caps = [] - for idx in xrange(num_PrimaryCaps): - cap = self._capsule(conv1, 256, 8, idx) - caps.append(cap) - # [None, 32, 6, 6, 8] - primary_caps = tf.concat(caps, axis=1) - # cap_shape = primary_caps.shape - # numcaps = tf.cast(cap_shape[1] * cap_shape[2] * cap_shape[3], tf.int32) - # [None, 32 x 6 x 6, 1, 8] - primary_caps = tf.reshape(primary_caps, shape=[-1, 1, numcaps, 1, 8]) - self._digit_caps = self._dynamic_routing(primary_caps) - self._loss = self._add_loss(self._digit_caps) - - self.global_step = tf.Variable(0, trainable=False) - learning_rate = tf.train.exponential_decay(lr, self.global_step, - 2000, 0.96, staircase=True) - optimizer = tf.train.AdamOptimizer(learning_rate) - self._train_op = optimizer.minimize(self._loss) - self._accuracy() - - def _accuracy(self): - with tf.name_scope('accuracy'): - digit_caps_norm = tf.norm(self._digit_caps, ord=2, axis=2) - correct_prediction = tf.equal(tf.argmax(self._y_, 1), tf.argmax(digit_caps_norm, 1)) - correct_prediction = tf.cast(correct_prediction, tf.float32) - - self.accuracy = tf.reduce_mean(correct_prediction) - - def train_with_predict(self,sess, batch_size=50, idx = 0): - batch = self._mnist.train.next_batch(batch_size) - loss, _ = sess.run([self._loss, self._train_op], feed_dict={self._x: batch[0], - self._y_: batch[1]}) - if idx % 10 == 0: - print('accuracy: {}'.format(self.accuracy.eval(feed_dict={ - self._x: batch[0], self._y_: batch[1]}))) - return loss - - -def model_test(): - model = CapsuleMnist(None) - model._build_net() - print('pass') - - -def main(_): - # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) - - tf.reset_default_graph() - - # Create the model - capsule_mnist = CapsuleMnist(mnist) - capsule_mnist._build_net() - - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - - - with tf.Session(config=config) as sess: - init = tf.global_variables_initializer() - sess.run(init) - # # Train - for i in range(1000): - loss = capsule_mnist.train_with_predict(sess, batch_size=50, idx=i) - if i % 10 == 0: - print("loss: {}".format(loss)) - # # Test trained model - # correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) - # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - # print(sess.run(accuracy, feed_dict={x: mnist.test.images, - # y_: 
mnist.test.labels}))
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--data_dir', type=str, default='./input_data',
-                        help='Directory for storing input data')
-    FLAGS, unparsed = parser.parse_known_args()
-    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
-
-    # model_test()
\ No newline at end of file
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..0ce5b41
--- /dev/null
+++ b/config.py
@@ -0,0 +1,62 @@
+# ------------------------------------------------------------------
+# Capsules_mnist
+# By InnerPeace Wu
+# ------------------------------------------------------------------
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from easydict import EasyDict as edict
+
+__C = edict()
+
+# get the config by: from config import cfg
+cfg = __C
+
+# number of channels of PrimaryCaps
+__C.PRIMARY_CAPS_CHANNELS = 32
+
+# iterations of dynamic routing
+__C.ROUTING_ITERS = 3
+
+# constant m+ in the margin loss
+__C.M_POS = 0.9
+
+# constant m- in the margin loss
+__C.M_NEG = 0.1
+
+# down-weighting constant lambda for the negative classes
+__C.LAMBDA = 0.5
+
+# weight of the reconstruction loss
+__C.RECONSTRUCT_W = 0.0005
+
+# initial learning rate
+__C.LR = 0.001
+
+# learning rate decay step size
+__C.STEP_SIZE = 500
+
+# learning rate decay ratio
+__C.DECAY_RATIO = 0.96
+
+# whether to use bias in the conv operations
+__C.USE_BIAS = True
+
+# print out the loss every x steps
+__C.PRINT_EVERY = 10
+
+# snapshot every x iterations
+__C.SAVE_EVERY = 1000
+
+# number of training iterations
+__C.MAX_ITERS = 5000
+
+# directory for saving data
+__C.DATA_DIR = './data'
+
+# directory for saving checkpoints
+__C.TRAIN_DIR = './output'
+
+# directory for saving tensorboard files
+__C.TB_DIR = './tensorboard'
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2bfb895
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+six
+numpy>=1.7.1
+scipy>=0.13.2
+easydict>=1.6
+tqdm>=4.17.1
\ No newline at end of file
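Because config.py stores all settings in an EasyDict, options read and write like plain attributes, which makes quick experiments easy to script. A short illustrative sketch (the override values here are arbitrary, e.g. for a smoke-test run):

```python
from config import cfg

# options read like attributes
print(cfg.ROUTING_ITERS, cfg.M_POS, cfg.LAMBDA)  # 3 0.9 0.5

# EasyDict allows attribute-style overrides before training starts
cfg.MAX_ITERS = 50
cfg.PRINT_EVERY = 5
```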
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e70440b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,81 @@
+# ------------------------------------------------------------------
+# Capsules_mnist
+# By InnerPeace Wu
+# This file is adapted from the tensorflow official tutorial of mnist.
+# ------------------------------------------------------------------
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+import time
+
+import tensorflow as tf
+from six.moves import xrange
+from tensorflow.examples.tutorials.mnist import input_data
+
+from CapsNet import CapsNet
+from config import cfg
+
+FLAGS = None
+
+
+def model_test():
+    model = CapsNet(None)
+    model.create_architecture()
+    print("pass")
+
+
+def main(_):
+    # Import data
+    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+
+    tf.reset_default_graph()
+
+    # Create the model
+    caps_net = CapsNet(mnist)
+    caps_net.create_architecture()
+
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+
+    train_dir = cfg.TRAIN_DIR
+    ckpt = tf.train.get_checkpoint_state(train_dir)
+
+    with tf.Session(config=config) as sess:
+        if ckpt:
+            print("Reading parameters from %s" % ckpt.model_checkpoint_path)
+            caps_net.saver.restore(sess, ckpt.model_checkpoint_path)
+        else:
+            print('Created model with fresh parameters.')
+            sess.run(tf.global_variables_initializer())
+            print('Num params: %d' % sum(v.get_shape().num_elements()
+                                         for v in tf.trainable_variables()))
+
+        caps_net.train_writer.add_graph(sess.graph)
+        iters = 0
+        tic = time.time()
+        for iters in xrange(cfg.MAX_ITERS):
+            sys.stdout.write('>>> %d / %d \r' % (iters % cfg.PRINT_EVERY, cfg.PRINT_EVERY))
+            sys.stdout.flush()
+            caps_net.train_with_summary(sess, batch_size=100, iters=iters)
+            if iters % cfg.PRINT_EVERY == 0 and iters > 0:
+                toc = time.time()
+                print('average time: %.2f secs' % (toc - tic))
+                tic = time.time()
+
+        caps_net.snapshot(sess, iters)
+        caps_net.test(sess, 'test')
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, default=cfg.DATA_DIR,
+                        help='Directory for storing input data')
+    FLAGS, unparsed = parser.parse_known_args()
+    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+
+    # for the model building test
+    # model_test()
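Since train.py restores automatically from `cfg.TRAIN_DIR` when a checkpoint exists, a trained model can be evaluated without further training. The following standalone sketch is illustrative, not part of the patch, and assumes a checkpoint has already been written to `cfg.TRAIN_DIR` by a prior training run:

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from CapsNet import CapsNet
from config import cfg

# rebuild the graph exactly as train.py does
mnist = input_data.read_data_sets(cfg.DATA_DIR, one_hot=True)
caps_net = CapsNet(mnist)
caps_net.create_architecture()

with tf.Session() as sess:
    # restore the latest snapshot and report whole-test-set accuracy
    ckpt = tf.train.get_checkpoint_state(cfg.TRAIN_DIR)
    caps_net.saver.restore(sess, ckpt.model_checkpoint_path)
    caps_net.test(sess, 'test')
```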