# CapsuleNet

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plot
from IPython import display

%matplotlib inline

import time
import os

In [39]:
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets("./data/", one_hot=False, seed=0)

Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz


In [52]:
Trueimage_size = 28
label_size = 10
# Input image dimensions (single channel)
input_dims = [image_size, image_size, 1]

# Features and labels
def inputs():
    features = tf.placeholder(tf.float32, shape=[None] + input_dims, name="features")
    labels   = tf.placeholder(tf.int32,   shape=[None], name="labels")
    return features, labels

def conv2d_relu_layer(x, kernel_size, map_count=1, stride=1, padding='SAME', name=''):
    with tf.name_scope(name):
        # Assuming x is already a 4D tensor
        tf.assert_rank(x, 4)
        shape = [kernel_size, kernel_size, x.shape[3].value, map_count]
        W = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0.01, shape=[map_count]),   name="b")
        conv = tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding)
        return tf.nn.relu(conv + b)

# >------ Capsule networks. -------------    
# Primary capsule.
def primary_capsule(x, caps_size, num_caps, kernel_size, stride, name):
    with tf.name_scope(name):
        shape = [kernel_size, kernel_size, x.shape[3].value, caps_size * num_caps]
        print(shape)
        W    = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="W")
        conv = tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding='VALID')
        # Flatten to [N, total_num_caps, caps_size] shape.
        conv = tf.reshape(conv, [tf.shape(conv)[0], conv.shape[1].value * conv.shape[2].value * num_caps, caps_size])
        return capsule_act(conv)

# Capsule activation (squashing) function.
def capsule_act(s):
    with tf.name_scope('capsule_act'):
        # Assuming input is [N, num_caps, caps_size] dims.
        tf.assert_rank(s, 3)
        l2_n = tf.expand_dims(tf.norm(s, ord=2, axis=2), -1)
        v    = s * l2_n / (1 + l2_n * l2_n)
        return v
    
def capsule_layer(u, caps_size, num_caps, name):
    with tf.name_scope(name):
        def get_u_hat(u, u_num_caps, v_num_caps, W):
            with tf.name_scope('get_u_hat'):
                # Assuming input is [u_num_caps, u_caps_size] dims.
                tf.assert_rank(u, 2)
                # Tile input (horrible hack!)
                u = tf.tile(tf.expand_dims(u, 0), [v_num_caps, 1, 1])
                # Reshape to make compatible with W.
                with tf.control_dependencies([tf.assert_equal(W.shape[0].value, u_num_caps * v_num_caps)]):
                    u = tf.reshape(u, [W.shape[0].value, u.shape[2].value])
                # Compute u_hat.
                u_hat = tf.matmul(tf.expand_dims(u, -2), W)
                # Squeeze and reshape.
                u_hat = tf.reshape(tf.squeeze(u_hat), [v_num_caps, u_num_caps, -1])
                return u_hat
        # Assuming input is [N, u_num_caps, u_caps_size] dims.
        tf.assert_rank(u, 3)
        u_num_caps  = u.shape[1].value
        u_caps_size = u.shape[2].value
        # There is a separate W_ij matrix for each (u, v) capsule.
        W_shape = [u_num_caps * num_caps, u_caps_size, caps_size]
        W = tf.Variable(tf.truncated_normal(W_shape, stddev=0.1), name="W")
        # Use the same W matrix for each sample in a batch. Is there a way to avoid map_fn?
        u_hat = tf.map_fn(lambda sample_caps: get_u_hat(sample_caps, u_num_caps, num_caps, W), u)
        # Routing logits.
        # Get b_ij dimensions and init. There is a separate b_ij for each sample in a batch.
        b_ij_dims = tf.stack([tf.shape(u)[0], num_caps, u_num_caps])
        b_ij      = tf.fill(b_ij_dims, 0.0)
        # 1
        c_ij = routing_softmax(b_ij)
        s = tf.reduce_sum(tf.expand_dims(c_ij, -1) * u_hat, 2)
        v = capsule_act(s)
        a_ij = tf.squeeze(tf.matmul(u_hat, tf.expand_dims(v, -1)), [-1])
        b_ij = b_ij + a_ij
        # 2
        c_ij = routing_softmax(b_ij)
        s = tf.reduce_sum(tf.expand_dims(c_ij, -1) * u_hat, 2)
        v = capsule_act(s)
        a_ij = tf.squeeze(tf.matmul(u_hat, tf.expand_dims(v, -1)), [-1])
        b_ij = b_ij + a_ij
        # 3
        c_ij = routing_softmax(b_ij)
        s = tf.reduce_sum(tf.expand_dims(c_ij, -1) * u_hat, 2)
        v = capsule_act(s)
        a_ij = tf.squeeze(tf.matmul(u_hat, tf.expand_dims(v, -1)), [-1])
        b_ij = b_ij + a_ij
        return v

#def routing(u_hat, r):
#    def body(i, b_ij):
#        c_ij = routing_softmax(b_ij)
#        return i + 1, b_ij
#    
#    num_caps   = u_hat.shape[1].value
#    u_num_caps = u_hat.shape[2].value
#    caps_size  = u_hat.shape[3].value
#    b_ij = tf.Variable(tf.constant(0.0, shape=[num_caps, u_num_caps]), trainable=False)
#    b_ij = tf.Variable(tf.constant(0.0, shape=[num_caps, u_num_caps]))
#    i = tf.constant(0)
#    cond = lambda i, b_ij: tf.less(i, r)
#    tf.while_loop(cond, body, (i, b_ij))
    
def routing_softmax(b_ij):
    with tf.name_scope('routing_softmax'):
        # Assuming [N, num_caps, u_num_caps] input.
        tf.assert_rank(b_ij, 3)
        c_ij = tf.nn.softmax(b_ij, dim=1)
        return c_ij
    
# Simple capsule net.
def capsule_net_01(features):
    conv1 = conv2d_relu_layer(features, 9, 256, padding='VALID', name='conv1')
    prim_caps   = primary_capsule(conv1, 8, 32, 9, 2, 'primary_caps')
    digits_caps = capsule_layer(prim_caps, 16, 10, 'digits_caps')
    return digits_caps

def margin_loss(v, labels, m_pos=0.9, m_neg=0.1, neg_scale=0.5):
    with tf.name_scope("margin_loss"):
        v_norm = tf.norm(v, ord=2, axis=2)
        pos = tf.maximum(0.0, m_pos - v_norm)
        pos = pos * pos
        neg = tf.maximum(0.0, v_norm - m_neg)
        neg = neg * neg
        loss = labels * pos + neg_scale * (1 - labels) * neg
        return tf.reduce_mean(loss)
# <------ Capsule networks. -------------    

# Training step operation
def create_train_op(loss, learning_rate, momentum):
    tf.summary.scalar(loss.op.name, loss)
    #step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    return step

def get_next_feed(dataset, batch_size, features, labels):
    f, l = dataset.next_batch(batch_size)
    return {features: f.reshape([batch_size] + input_dims), labels: l}
    
# Training action - performs training using provided train step and inputs
def do_train(sess, train_step, loss, next_feed, batch_size, epoch_size, num_epochs, 
             summary_op, summary_writer):
    for epoch in range(0, num_epochs):
        epoch_loss = 0.0
        for i in range(0, epoch_size, batch_size):
            cur_batch_size = min(batch_size, epoch_size - i)
            feed_dict = next_feed(cur_batch_size)
            sess.run(train_step,      feed_dict=feed_dict)
            loss_val = sess.run(loss, feed_dict=feed_dict)
            epoch_loss += np.sum(loss_val)
            if i % 1000 == 0:
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, epoch * epoch_size + i)
                summary_writer.flush()
        print("Epoch %d: loss = %0.05f" % (epoch, epoch_loss / epoch_size))

# Evaluation of a model on a given dataset    
def do_eval(sess, logits, labels, next_feed, batch_size, epoch_size):
    batch_correct = tf.reduce_sum(tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.int32))
    correct = 0
    total = 0
    for i in range(0, epoch_size, batch_size):
        cur_batch_size = min(batch_size, epoch_size - i)
        feed_dict = next_feed(cur_batch_size)
        correct += np.sum(sess.run(batch_correct, feed_dict=feed_dict))
        total += cur_batch_size

    return total, correct
        
def train_and_evaluate(net):
    with tf.Graph().as_default():
        # Build the network for training
        features, labels = inputs()
        
        #logits = fc_layer(net(features), label_size, "digit")
        #probs = probs_layer(logits)
        #loss = ce_sm_loss(logits, labels)
        digits_caps = net(features)
        digits_caps_norm = tf.norm(digits_caps, ord=2, axis=2)
        loss = margin_loss(digits_caps, tf.one_hot(tf.to_int32(labels), 10))
        
        # Create train operation
        train_step = create_train_op(loss, learning_rate=0.001, momentum=0)

        # Summary operation
        summary_op = tf.summary.merge_all()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        
        with tf.Session(config=config) as sess:
            # TensorBoard needs a separate folder for each run.
            log_dir = os.path.join("/data/src/ipnbys/log", time.strftime("%Y%m%d%H%M%S"))
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            # Train model
            batch_size = 64
            epoch_size = mnist.train.num_examples
            num_epochs = 1
            next_feed = lambda b: get_next_feed(mnist.train, b, features, labels)
            do_train(sess, train_step, loss, next_feed, batch_size, epoch_size, num_epochs,
                     summary_op, summary_writer)

            # Evaluate model on test dataset
            next_feed = lambda b: get_next_feed(mnist.test, b, features, labels)
            total, correct_count = do_eval(sess, digits_caps_norm, labels, next_feed, 100, mnist.test.num_examples)
            print('Test set accuracy: %0.04f (%d/%d)' % (correct_count / float(total), correct_count, total))

def playground():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            def body(i, v):
                return i + 1, v + 1
            x = tf.Variable(tf.truncated_normal([2, 3, 4], stddev=0.1, seed=0))
            #x = tf.assign(x[0,0,0], [1, 1, 1])
            W = tf.Variable(tf.constant(0.0, shape=[2, 4, 1])) #tf.Variable(tf.truncated_normal([2, 4, 1], stddev=0.1, seed=0))
            m = tf.matmul(x, W)
            #W = W + 1.0
            #i = tf.constant(0)
            #c = lambda i, v: tf.less(i, 10)
            #r = tf.while_loop(c, body, (i, W))
            #m1 = tf.map_fn(lambda caps: tf.matmul(caps, W), x)
            #m2 = tf.tile(tf.expand_dims([[1, 2], [3, 4]], 0), [2, 1, 1])
            f, l = mnist.train.next_batch(batch_size=2)
            l_h = tf.one_hot(l, 10, axis=1)
            sm = tf.nn.softmax_cross_entropy_with_logits(labels=l_h, logits=tf.truncated_normal([2, 10], stddev=0.1, seed=0))
            sess.run(tf.global_variables_initializer())
            #print(sess.run(r))
            print(sm.eval())
            print(l)
            print(l_h.eval())
            #print('--')
            #print(W.eval())
            #print('--')
            #print(m.eval())
            #print(m2.eval())
            #r1, r2, r3 = sess.run([x, W, m])
            #print(sess.run(W))
            #print(r2)
            #print(r3)
            
train_and_evaluate(capsule_net_01)
#playground()


[9, 9, 256, 256]
Epoch 0: loss = 0.00010
Test set accuracy: 0.9821 (9821/10000)
