# CapsuleNet

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plot
from IPython import display

%matplotlib inline

import time
import os

In [2]:
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
mnist = read_data_sets("./data/", one_hot=False)


Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz


In [71]:
image_size = 28
label_size = 10
# Input image dimensions (single channel)
input_dims = [image_size, image_size, 1]

# Features and labels
def inputs():
    features = tf.placeholder(tf.float32, shape=[None] + input_dims, name="features")
    labels   = tf.placeholder(tf.float32, shape=[None], name="labels")
    return features, labels

# Outputs (probabilities)
def probs_layer(logits):
    return tf.nn.softmax(logits)

# Cross-entropy loss with softmax activation function
def ce_sm_loss(logits, labels):
    with tf.name_scope("loss"):
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=tf.cast(labels, tf.int32), name="ce")
        return tf.reduce_mean(ce, name="ce_mean")

def conv2d_relu_layer(x, kernel_size, map_count=1, stride=1, padding='SAME', name=''):
    with tf.name_scope(name) as scope:
        # assuming x is already a 4D tensor
        shape = [kernel_size, kernel_size, x.shape[3].value, map_count]
        W = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="weights")
        b = tf.Variable(tf.constant(0.01, shape=[map_count]),    name="biases")
        conv = tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding, name="conv")
        return tf.nn.relu(conv + b, name=scope)

# FC layer, no activation
def fc_layer(x, size, name):
    with tf.name_scope(name):
        # Flatten dims for the fully-connected layer.
        x_flat = tf.reshape(x, [-1, x.shape[1:].num_elements()])
        shape = [x_flat.shape[1].value, size]
        W = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="weights")
        b = tf.Variable(tf.constant(0.01, shape=[shape[1]]), name="biases")
        return tf.matmul(x_flat, W) + b

# >------ Capsule networks. -------------    
# Primary capsule.
def primary_capsule(x, caps_size, num_caps, kernel_size, stride, name):
    with tf.name_scope(name) as scope:
        shape = [kernel_size, kernel_size, x.shape[3].value, caps_size * num_caps]
        W = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="weights")
        conv = tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding='VALID', name="conv")
        # Flatten to [N, num_caps, caps_size] shape.
        conv = tf.reshape(conv, [tf.shape(conv)[0], conv.shape[1].value * conv.shape[2].value, caps_size])
        return capsule_act(conv)

# Capsule activation (squashing) function.
def capsule_act(s):
    # Assuming input is [N, num_caps, caps_size] dims.
    tf.assert_rank(s, 3)
    l2_n = tf.expand_dims(tf.norm(s, ord=2, axis=2), -1)
    v    = s * l2_n / (1 + l2_n * l2_n)
    return v
    
def capsule_layer(u, caps_size, num_caps, name):
    with tf.name_scope(name) as scope:
        def get_u_hat(u, u_num_caps, v_num_caps, W):
            # Assuming input is [u_num_caps, u_caps_size] dims.
            tf.assert_rank(u, 2)
            # Tile input (horrible hack!)
            u = tf.tile(tf.expand_dims(u, 0), [v_num_caps, 1, 1])
            # Reshape to make compatible with W.
            with tf.control_dependencies([tf.assert_equal(W.shape[0].value, u_num_caps * v_num_caps)]):
                u = tf.reshape(u, [W.shape[0].value, u.shape[2].value])
            # Compute u_hat.
            u_hat = tf.matmul(tf.expand_dims(u, -2), W)
            # Squeeze and reshape.
            u_hat = tf.reshape(tf.squeeze(u_hat), [v_num_caps, u_num_caps, -1])
            return u_hat
        # Assuming input is [N, u_num_caps, u_caps_size] dims.
        tf.assert_rank(u, 3)
        u_num_caps  = u.shape[1].value
        u_caps_size = u.shape[2].value
        W_shape = [u_num_caps * num_caps, u_caps_size, caps_size]
        W = tf.Variable(tf.truncated_normal(W_shape, stddev=0.1), name="W")
        # Use the same W matrix for each sample in a batch.
        u_hat = tf.map_fn(lambda sample_caps: get_u_hat(sample_caps, u_num_caps, num_caps, W), u)
        print(u_hat.shape)
        return u_hat
    
# Simple capsule net.
def capsule_net_01(features):
    conv1 = conv2d_relu_layer(features, 9, 256, padding='VALID', name='conv1')
    prim_caps   = primary_capsule(conv1, 8, 1, 9, 2, 'primary_caps')
    digits_caps = capsule_layer(prim_caps, 16, 10, 'digits_caps')
    return digits_caps
# <------ Capsule networks. -------------    

# Training step operation
def create_train_op(loss, learning_rate, momentum):
    tf.summary.scalar(loss.op.name, loss)
    step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    return step

def get_next_feed(dataset, batch_size, features, labels):
    f, l = dataset.next_batch(batch_size)
    return {features: f.reshape([batch_size] + input_dims), labels: l}
    
# Training action - performs training using provided train step and inputs
def do_train(sess, train_step, loss, next_feed, batch_size, epoch_size, num_epochs, 
             summary_op, summary_writer):
    for epoch in range(0, num_epochs):
        epoch_loss = 0.0
        for i in range(0, epoch_size, batch_size):
            cur_batch_size = min(batch_size, epoch_size - i)
            feed_dict = next_feed(cur_batch_size)
            _, loss_val = sess.run([train_step, loss], feed_dict=feed_dict)
            epoch_loss += np.sum(loss_val)
            if i % 1000 == 0:
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, epoch * epoch_size + i)
                summary_writer.flush()
            break
        print("Epoch %d: loss = %0.05f" % (epoch, epoch_loss / epoch_size))

# Evaluation of a model on a given dataset    
def do_eval(sess, logits, labels, next_feed, batch_size, epoch_size):
    batch_correct = tf.reduce_sum(tf.cast(tf.nn.in_top_k(logits, tf.cast(labels, tf.int32), 1), tf.int32))
    correct = 0
    total = 0
    for i in range(0, epoch_size, batch_size):
        cur_batch_size = min(batch_size, epoch_size - i)
        feed_dict = next_feed(cur_batch_size)
        correct += np.sum(sess.run(batch_correct, feed_dict=feed_dict))
        total += cur_batch_size

    return total, correct
        
def train_and_evaluate(net):
    with tf.Graph().as_default():
        # Build the network for training
        features, labels = inputs()
        
        logits = fc_layer(net(features), label_size, "logits")
        probs = probs_layer(logits)
        loss = ce_sm_loss(logits, labels)

        # Create train operation
        train_step = create_train_op(loss, learning_rate=0.1, momentum=0)

        # Summary operation
        summary_op = tf.summary.merge_all()

        with tf.Session() as sess:
            # TensorBoard needs a separate folder for each run.
            log_dir = os.path.join("/data/src/ipnbys/log", time.strftime("%Y%m%d%H%M%S"))
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            # Train model
            batch_size = 64
            epoch_size = mnist.train.num_examples
            num_epochs = 1
            next_feed = lambda b: get_next_feed(mnist.train, b, features, labels)
            do_train(sess, train_step, loss, next_feed, batch_size, epoch_size, num_epochs,
                     summary_op, summary_writer)

            # Evaluate model on test dataset
            next_feed = lambda b: get_next_feed(mnist.test, b, features, labels)
            total, correct_count = do_eval(sess, probs, labels, next_feed, 100, mnist.test.num_examples)
            print('Test set accuracy: %0.04f (%d/%d)' % (correct_count / float(total), correct_count, total))

def playground():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            x = tf.Variable(tf.truncated_normal([2, 2, 1, 3], stddev=0.1, seed=0))
            #x = tf.assign(x[0,0,0], [1, 1, 1])
            W = tf.Variable(tf.truncated_normal([2, 3, 4], stddev=0.1, seed=0))
            #m = tf.matmul(x, W)
            m1 = tf.map_fn(lambda caps: tf.matmul(caps, W), x)
            m2 = tf.tile(tf.expand_dims([[1, 2], [3, 4]], 0), [2, 1, 1])
            sess.run(tf.global_variables_initializer())
            #print(x.eval())
            #print('--')
            #print(W.eval())
            #print('--')
            print(m1.eval())
            print(m2.eval())
            #r1, r2, r3 = sess.run([x, W, m])
            #print(r1)
            #print(r2)
            #print(r3)
            
train_and_evaluate(capsule_net_01)
#playground()


(?, 10, 36, 16)
Epoch 0: loss = 0.00005
Test set accuracy: 0.2215 (2215/10000)
