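CapsNet on MNIST in TensorFlow 1.x, based on the [naturomics/CapsNet-Tensorflow](https://github.com/naturomics/CapsNet-Tensorflow) implementation of *Dynamic Routing Between Capsules* (Sabour et al., 2017).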

#### Download and extract the MNIST data

In [1]:
import os
import sys
import gzip
import shutil
from six.moves import urllib

In [2]:
# Source of the raw MNIST archives.
# NOTE(review): this host has been reported to refuse direct downloads at times; if that
# happens, point `homepage` at a mirror that serves the same four files.
homepage = 'http://yann.lecun.com/exdb/mnist/'

# gzip archive names, relative to `homepage`.
train_imgs_path, train_labels_path = (
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
)
test_imgs_path, test_labels_path = (
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
)

def download_and_uncompress(url,dataset_dir,force=False):
    """Download a gzip archive (unless already cached) and decompress it next to itself.

    Args:
        url: URL of the ``.gz`` file. Its last path component is used as the local file name.
        dataset_dir: directory that receives both the archive and the extracted file.
            It is created, including any missing parent directories, if it does not exist.
        force: if True, download again even when the archive is already present.

    The extracted file is the archive path with its final extension removed. It is
    (re)written on every call, whether or not the archive was downloaded.
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)
    extract_name = os.path.splitext(filepath)[0]
    # makedirs (unlike mkdir) also creates missing parents, and exist_ok avoids the
    # check-then-create race.
    os.makedirs(dataset_dir, exist_ok=True)

    if not force and os.path.exists(filepath):
        print('file %s already exist'%(filename))
    else:
        # Download under a temporary name and move it into place only on success. An
        # interrupted download then never leaves a truncated archive that a later
        # force=False call would mistake for a complete one.
        partial_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filepath)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print()
        print('Successfully Downloaded',filename)

    with gzip.open(filepath,'rb') as f_in, open(extract_name,'wb') as f_out:
        print('Extracting', filename)
        shutil.copyfileobj(f_in,f_out)
        print('Successfully extracted')
        print()

In [3]:
# Fetch (unless already cached) and decompress each of the four MNIST archives into ./mnist.
for mnist_archive in (train_imgs_path, train_labels_path,
                      test_imgs_path, test_labels_path):
    download_and_uncompress(url=homepage + mnist_archive, dataset_dir='mnist', force=False)

file train-images-idx3-ubyte.gz already exist
Extracting train-images-idx3-ubyte.gz
Successfully extracted

file train-labels-idx1-ubyte.gz already exist
Extracting train-labels-idx1-ubyte.gz
Successfully extracted

file t10k-images-idx3-ubyte.gz already exist
Extracting t10k-images-idx3-ubyte.gz
Successfully extracted

file t10k-labels-idx1-ubyte.gz already exist
Extracting t10k-labels-idx1-ubyte.gz
Successfully extracted



In [1]:
import tensorflow as tf
import numpy as np
import os
from tqdm import tqdm

  from ._conv import register_converters as _register_converters


In [2]:
# Numerical-stability constant added inside square roots (squash, capsule lengths).
epsilon = 1e-9

# Training schedule: examples per batch and number of passes over the training set.
batch_size, epoch = 8, 1

# Margin-loss hyper-parameters:
#   L_k = T_k * max(0, m_plus - ||v_k||)^2 + lambda_val * (1 - T_k) * max(0, ||v_k|| - m_minus)^2
lambda_val = 0.5
m_plus, m_minus = 0.9, 0.1

# Number of dynamic-routing iterations.
# NOTE: routing() has its own default of 3, so this value only matters where it is
# passed to routing() explicitly.
iter_routing = 3

# Directory for checkpoints/summaries, and the directory holding the extracted MNIST files.
logdir, dataset_path = 'logdir', 'mnist'

# Build the training graph (True) or the inference-only graph (False).
is_training = True

In [3]:
#load mnist data
def _read_idx_bytes(filename, header_len):
    """Return the uint8 payload of an IDX file with its ``header_len``-byte header stripped."""
    # Binary mode, and `with` so the handle is closed as soon as the bytes are read.
    with open(filename, 'rb') as f:
        return np.fromfile(f, dtype=np.uint8)[header_len:]


def load_mnist(path,is_training):
    """Load one MNIST split from the uncompressed IDX files in ``path``.

    Args:
        path: directory containing ``train-images-idx3-ubyte``, ``train-labels-idx1-ubyte``,
            ``t10k-images-idx3-ubyte`` and ``t10k-labels-idx1-ubyte``.
        is_training: True for the training split, False for the test split. Only the
            requested split is read from disk and added to the graph.

    Returns:
        (X, y): X is a float32 tensor of shape [num_examples, 28, 28, 1] with pixel values
        scaled to [0, 1]; y is a float32 one-hot tensor of shape [num_examples, 10].

    Raises:
        ValueError: if the image and label files disagree on the number of examples.
    """
    split = 'train' if is_training else 't10k'
    # IDX headers are 16 bytes for the image files and 8 bytes for the label files.
    images = _read_idx_bytes(os.path.join(path, split + '-images-idx3-ubyte'), 16)
    labels = _read_idx_bytes(os.path.join(path, split + '-labels-idx1-ubyte'), 8)

    # np.float no longer exists in NumPy >= 1.24; use float32, which is what the graph consumes.
    X = images.reshape((-1, 28, 28, 1)).astype(np.float32) / 255.
    if X.shape[0] != labels.shape[0]:
        raise ValueError('%s split has %d images but %d labels'
                         % (split, X.shape[0], labels.shape[0]))

    X = tf.convert_to_tensor(X, tf.float32)
    # tf.one_hot takes integer indices, so keep the labels integral.
    y = tf.one_hot(labels.astype(np.int64), depth=10, axis=1, dtype=tf.float32)
    return X, y

In [4]:
#retrieve data by batch size
def get_batch_data():
    """Return shuffled (images, one-hot labels) mini-batches of the MNIST training set.

    Builds a TF1 input-queue pipeline: slice_input_producer yields single examples and
    shuffle_batch groups them into batches of ``batch_size``, dropping any partial final
    batch. The queues need running queue runners (the Supervisor in the training cell
    starts them).
    """
    images, labels = load_mnist(dataset_path, True)
    single_example = tf.train.slice_input_producer([images, labels])

    # The queue holds up to 64 batches and keeps at least 32 batches' worth buffered for mixing.
    queue_capacity = batch_size * 64
    min_buffered = batch_size * 32
    batched_images, batched_labels = tf.train.shuffle_batch(
        single_example,
        batch_size=batch_size,
        capacity=queue_capacity,
        min_after_dequeue=min_buffered,
        allow_smaller_final_batch=False,
    )
    return batched_images, batched_labels

In [5]:
class CapsLayer(object):
    """A layer of capsules.

    Two configurations are implemented:
      * layer_type='CONV', with_routing=False: PrimaryCaps. A strided convolution whose
        output channels are regrouped into capsules of length ``vec_len`` and squashed.
      * layer_type='FC', with_routing=True: DigitCaps. ``num_outputs`` capsules computed
        from every input capsule by dynamic routing.
    Calling a layer built with any other combination raises NotImplementedError.

    Args:
        num_outputs: number of capsules (for 'CONV': capsule channels) produced by the layer.
        vec_len: length of each output capsule vector (used by the 'CONV' layer).
        with_routing: whether the layer uses dynamic routing.
        layer_type: 'CONV' or 'FC'.
    """
    def __init__(self,num_outputs,vec_len,with_routing=True,layer_type='FC'):
        self.num_outputs = num_outputs # the number of capsules in the current layer
        self.vec_len = vec_len # the len of a capsule output vector
        self.with_routing = with_routing # if get dynamic routing or not
        self.layer_type = layer_type # 'CONV' or 'FC'

    def __call__(self,input,kernel_size=None,stride=None):
        """Apply the layer to ``input``.

        Args:
            input: for 'CONV', the first conv layer's feature map, asserted to be
                [batch_size, 20, 20, 256]. For 'FC', capsules shaped
                [batch_size, num_caps_in, len_in, 1].
            kernel_size, stride: convolution parameters, used only by the 'CONV' layer.

        Returns:
            'CONV': squashed capsules, asserted to be [batch_size, 1152, 8, 1].
            'FC': routed capsules, [batch_size, num_outputs, len_out, 1]. len_out is set by
            routing()'s weight matrix; self.vec_len is not forwarded to routing().

        Raises:
            NotImplementedError: for any (layer_type, with_routing) combination other than
                ('CONV', False) and ('FC', True).
        """
        if self.layer_type=='CONV' and not self.with_routing:
            # PrimaryCaps (second layer of CapsNet): no dynamic routing.
            self.kernel_size = kernel_size
            self.stride = stride
            assert input.get_shape()==[batch_size,20,20,256]

            capsules = tf.contrib.layers.conv2d(input, self.num_outputs * self.vec_len,
                                                self.kernel_size, self.stride, padding="VALID")
            # Regroup the conv channels into capsule vectors of length vec_len.
            capsules = tf.reshape(capsules, (batch_size, -1, self.vec_len, 1))
            capsules = squash(capsules)
            assert capsules.get_shape()==[batch_size,1152,8,1]
            return capsules

        if self.layer_type=='FC' and self.with_routing:
            # DigitCaps (third layer of CapsNet): fully connected to every input capsule.
            # [batch_size, num_caps_in, 1, len_in, 1]
            self.input=tf.reshape(input,shape=(batch_size,-1,1,input.shape[-2].value,1))
            with tf.variable_scope('routing'):
                # Initial routing logits b_IJ: [1, num_caps_in, num_outputs, 1, 1], all zero.
                b_IJ = tf.constant(np.zeros([1,input.shape[1].value,self.num_outputs,1,1],
                                            dtype=np.float32))
                # Forward the configured iteration count instead of relying on routing()'s default.
                capsules = routing(self.input,b_IJ,iter_routing=iter_routing)
                # Drop the singleton input-capsule axis: [batch_size, num_outputs, len_out, 1].
                capsules = tf.squeeze(capsules,axis=1)
            return capsules

        raise NotImplementedError(
            'CapsLayer does not support layer_type=%r with with_routing=%r'
            % (self.layer_type, self.with_routing))

In [6]:
def routing(input,b_IJ,iter_routing=3,len_v_j=16):
    """Dynamic routing-by-agreement between two capsule layers (Sabour et al., 2017, Procedure 1).

    Args:
        input: lower-level capsules u_i, shape [batch_size, num_caps_i, 1, len_u_i, 1].
        b_IJ: initial routing logits, shape [1, num_caps_i, num_caps_j, 1, 1]. They are
            replicated so each example in the batch is routed independently.
        iter_routing: number of routing iterations (must be >= 1).
        len_v_j: length of each output capsule vector (16 for DigitCaps).

    Returns:
        v_J: output capsules, shape [batch_size, 1, num_caps_j, len_v_j, 1].

    Raises:
        ValueError: if iter_routing < 1.

    Creates the variable 'Weight' of shape [1, num_caps_i, num_caps_j, len_u_i, len_v_j]
    in the current variable scope.
    """
    if iter_routing < 1:
        raise ValueError('iter_routing must be >= 1, got %d' % iter_routing)

    num_caps_i = input.get_shape()[1].value
    len_u_i = input.get_shape()[3].value
    num_caps_j = b_IJ.get_shape()[2].value

    W = tf.get_variable('Weight',shape=(1,num_caps_i,num_caps_j,len_u_i,len_v_j),dtype=tf.float32,
                        initializer=tf.random_normal_initializer(stddev=0.01))
    # Replicate u_i once per output capsule and W once per example, so a single batched
    # matmul produces every prediction vector.
    input = tf.tile(input,[1,1,num_caps_j,1,1])
    W = tf.tile(W,[batch_size,1,1,1,1])
    assert input.get_shape()==[batch_size,num_caps_i,num_caps_j,len_u_i,1]
    assert W.get_shape()==[batch_size,num_caps_i,num_caps_j,len_u_i,len_v_j]

    # Prediction vectors u_hat_{j|i} = W_ij^T u_i.
    u_hat = tf.matmul(W,input,transpose_a=True)
    assert u_hat.get_shape()==[batch_size,num_caps_i,num_caps_j,len_v_j,1]

    # Keep routing logits per example. Summing agreements over the batch would let one
    # example's routing depend on the other examples in the same batch.
    b_IJ = tf.tile(b_IJ,[batch_size,1,1,1,1])

    for r_iter in range(iter_routing):
        with tf.variable_scope('iter_'+str(r_iter)):
            # Coupling coefficients: softmax over the output capsules j (axis 2). Axis 3 has
            # size 1, so a softmax there would make every c_ij == 1 and disable routing.
            c_IJ = tf.nn.softmax(b_IJ,axis=2)
            assert c_IJ.get_shape()==[batch_size,num_caps_i,num_caps_j,1,1]

            # s_j = sum_i c_ij * u_hat_{j|i}
            s_J = tf.reduce_sum(tf.multiply(c_IJ,u_hat),axis=1,keepdims=True)
            assert s_J.get_shape()==[batch_size,1,num_caps_j,len_v_j,1]

            v_J = squash(s_J)
            assert v_J.get_shape()==[batch_size,1,num_caps_j,len_v_j,1]

            # The logits are not read after the last iteration, so skip that update.
            if r_iter < iter_routing - 1:
                # Agreement a_ij = u_hat_{j|i} . v_j is added to the logits.
                v_J_tiled = tf.tile(v_J,[1,num_caps_i,1,1,1])
                u_produce_v = tf.matmul(u_hat,v_J_tiled,transpose_a=True)
                assert u_produce_v.get_shape()==[batch_size,num_caps_i,num_caps_j,1,1]
                b_IJ += u_produce_v

    return v_J

In [7]:
def squash(vector):
    """Capsule non-linearity: v = ||s||^2 / (1 + ||s||^2) * s / ||s||.

    The norm is taken over axis -2, the capsule-vector axis in this notebook's
    [..., vec_len, 1] layout. ``epsilon`` is added before the square root so that a
    zero vector does not divide by zero.
    """
    norm_sq = tf.reduce_sum(tf.square(vector), -2, keepdims=True)
    # Shrinks short vectors toward length 0 and long ones toward length 1...
    shrink = norm_sq / (1 + norm_sq)
    # ...after dividing by ||s|| turns the input into a unit vector.
    factor = shrink / tf.sqrt(norm_sq + epsilon)
    return factor * vector

In [8]:
class CapsNet():
    """CapsNet for 28x28x1 MNIST digits: Conv1 -> PrimaryCaps -> DigitCaps, plus an FC decoder.

    Args:
        is_training: if True, the graph reads shuffled batches from the MNIST input queue
            and adds the loss, summaries and an Adam train op. Otherwise ``X`` is a
            placeholder of shape [batch_size, 28, 28, 1], only the forward pass is built,
            and the decoder is fed the capsule of the predicted class.
    """
    def __init__(self,is_training=True):
        self.is_training = is_training
        self.graph = tf.Graph()
        with self.graph.as_default():
            if is_training:
                self.X,self.y=get_batch_data()
                self.build_arch()
                self.loss()

                self.optimizer=tf.train.AdamOptimizer()
                self.global_step=tf.Variable(0,name='global_step',trainable=False)
                self.train_op = self.optimizer.minimize(self.total_loss,global_step=self.global_step)

            else:
                self.X = tf.placeholder(tf.float32,shape=(batch_size,28,28,1))
                self.build_arch()

        tf.logging.info('Setting up the main structure')

    def build_arch(self):
        """Build the forward pass.

        Sets self.caps2, self.v_length, self.masked_v and self.decoded, plus
        self.argmax_idx in inference mode.
        """
        with tf.variable_scope('Conv1_layer'):
            conv1 = tf.contrib.layers.conv2d(self.X, num_outputs=256,
                                             kernel_size=9, stride=1,
                                             padding='VALID')
            assert conv1.get_shape()==[batch_size,20,20,256]

        with tf.variable_scope('PrimaryCaps_layer'):
            primaryCaps=CapsLayer(num_outputs=32,vec_len=8,with_routing=False,layer_type='CONV')
            caps1=primaryCaps(conv1,kernel_size=9,stride=2)
            assert caps1.get_shape()==[batch_size,1152,8,1]

        with tf.variable_scope('DigitCaps_layer'):
            digitCaps=CapsLayer(num_outputs=10,vec_len=16,with_routing=True,layer_type='FC')
            # [batch_size, 10, 16, 1]
            self.caps2=digitCaps(caps1)

        with tf.variable_scope('Masking'):
            # Length ||v_j|| of each digit capsule: [batch_size, 10, 1, 1].
            self.v_length = tf.sqrt(tf.reduce_sum(tf.square(self.caps2),axis=2,keepdims=True)+epsilon)
            if getattr(self, 'is_training', True):
                # Training: the decoder sees only the capsule of the true class.
                mask = self.y
            else:
                # Inference: no labels are available, so mask with the longest capsule.
                self.argmax_idx = tf.argmax(tf.reshape(self.v_length,(batch_size,10)),axis=1)
                mask = tf.one_hot(self.argmax_idx,depth=10,dtype=tf.float32)
            # Squeeze only the trailing axis so that batch_size == 1 keeps its batch dimension.
            # Result: [batch_size, 16, 1].
            self.masked_v=tf.matmul(tf.squeeze(self.caps2,axis=3),tf.reshape(mask,(-1,10,1)),transpose_a=True)

        with tf.variable_scope('Decoder'):
            vector_j = tf.reshape(self.masked_v,shape=(batch_size,-1))
            fc1=tf.contrib.layers.fully_connected(vector_j,num_outputs=512)
            assert fc1.get_shape()==[batch_size,512]
            fc2=tf.contrib.layers.fully_connected(fc1,num_outputs=1024)
            assert fc2.get_shape()==[batch_size,1024]
            # 784 = 28*28 reconstructed pixel intensities in (0, 1).
            self.decoded=tf.contrib.layers.fully_connected(fc2,num_outputs=784,activation_fn=tf.sigmoid)

    def loss(self):
        """Margin loss plus scaled reconstruction loss; also registers TensorBoard summaries."""
        # max(0, m+ - ||v_k||)^2 and max(0, ||v_k|| - m-)^2, each [batch_size, 10, 1, 1].
        max_l = tf.square(tf.maximum(0.,m_plus-self.v_length))
        max_r = tf.square(tf.maximum(0.,self.v_length-m_minus))
        assert max_l.get_shape()==[batch_size,10,1,1]

        max_l = tf.reshape(max_l,shape=(batch_size,-1))
        max_r = tf.reshape(max_r,shape=(batch_size,-1))

        T_c = self.y
        L_c = T_c*max_l + lambda_val*(1-T_c)*max_r

        self.margin_loss = tf.reduce_mean(tf.reduce_sum(L_c,axis=1))

        orgin=tf.reshape(self.X,shape=(batch_size,-1))
        squared = tf.square(self.decoded-orgin)
        self.reconstruction_err=tf.reduce_mean(squared)

        # The paper scales the *sum* of squared pixel errors by 0.0005. reconstruction_err
        # is a per-pixel mean, so multiply by the 784 pixels to recover that weighting.
        regularization_scale = 0.0005 * 784
        self.total_loss=self.margin_loss+regularization_scale*self.reconstruction_err

        tf.summary.scalar('margin_loss',self.margin_loss)
        tf.summary.scalar('reconstruction_loss',self.reconstruction_err)
        tf.summary.scalar('total_loss',self.total_loss)
        recon_img = tf.reshape(self.decoded,shape=(batch_size,28,28,1))
        tf.summary.image('reconstruction_img',recon_img)
        self.merged_sum=tf.summary.merge_all()

In [None]:
if __name__ == '__main__':
    capsNet=CapsNet(is_training=is_training)
    tf.logging.info('Graph loaded')

    # save_model_secs=0 disables the Supervisor's periodic checkpoints; we save once per epoch.
    sv = tf.train.Supervisor(graph=capsNet.graph,
                             logdir=logdir,
                             save_model_secs=0)
    with sv.managed_session() as sess:
        # Whole batches only, because shuffle_batch drops the partial final batch.
        # 60000 is the size of the MNIST training split.
        num_batch = 60000 // batch_size
        # Use a separate loop variable. `for epoch in range(epoch)` rebinds the global
        # `epoch` (to 0 after one epoch), so re-running this cell would train zero epochs.
        for epoch_idx in range(epoch):
            if sv.should_stop():
                break
            for step in tqdm(range(num_batch),total=num_batch,ncols=70,leave=False,unit='b'):
                if sv.should_stop():
                    break
                sess.run(capsNet.train_op)
            global_step=sess.run(capsNet.global_step)
            sv.saver.save(sess,logdir+'/model_epoch_%04d_step_%02d'%(epoch_idx,global_step))
    tf.logging.info('Training done!')

INFO:tensorflow:Setting up the main structure
INFO:tensorflow:Graph loaded
Instructions for updating:
Please switch to tf.train.MonitoredTrainingSession
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Starting standard services.
INFO:tensorflow:Starting queue runners.


  0%|                                         | 0/7500 [00:00<?, ?b/s]

INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Recording summary at step 0.


  5%|█▋                             | 396/7500 [02:00<34:46,  3.41b/s]

INFO:tensorflow:Recording summary at step 396.


  5%|█▋                             | 397/7500 [02:00<37:17,  3.17b/s]

INFO:tensorflow:global_step/sec: 3.52252


  8%|██▌                            | 633/7500 [03:08<33:25,  3.42b/s]