In [1]:
import numpy as np
import chainer
from chainer import *
import chainer.functions as F
import chainer.links as L
from chainer.dataset import concat_examples

In [2]:
class Caps(Chain):
    """Capsule-style classifier with iterative routing between two capsule layers.

    Pipeline (see ``output``): two 9x9 convolutions produce ``n_primary``
    groups of ``d_primary`` channels ("primary capsules"); each group is mapped
    by its own 1x1 convolution to a prediction for every output capsule; the
    predictions are combined by ``n_iterations`` rounds of softmax-weighted
    routing, and the length of each output capsule vector is used as the class
    score in a margin loss (see ``calculate_loss``).

    Args:
        n_output: Number of output capsules (one per class).
        channels: Number of channels of the input image.
        n_primary: Number of primary-capsule groups produced by ``conv2``.
        d_primary: Dimensionality (channel count) of each primary capsule.
        n_grids: Side length of the spatial grid at the output of ``conv2``.
            ``output`` reshapes with ``n_grids ** 2`` positions, so this must
            match the real feature-map size (6 for a 28x28 input:
            28 -> 20 after conv1, 20 -> 6 after conv2 with stride 2).
        d_output: Dimensionality of each output capsule vector.
        n_iterations: Number of routing iterations.
    """

    def __init__(self, n_output, channels, n_primary=32, d_primary=8, n_grids=6, d_output=16, n_iterations=3):
        super(Caps, self).__init__()
        self.n_iterations = n_iterations
        self.n_grids = n_grids
        self.n_raw_grids = self.n_grids
        self.n_primary = n_primary
        self.d_primary = d_primary
        self.n_output = n_output
        self.d_output = d_output
        # Every convolution shares the same small uniform weight initializer.
        self.init = initializers.Uniform(scale=0.05)
        with self.init_scope():
            self.conv1 = L.Convolution2D(channels, 64, ksize=9, stride=1, initialW=self.init)
            self.conv2 = L.Convolution2D(64, n_primary * d_primary, ksize=9, stride=2, initialW=self.init)
            # One 1x1 convolution per primary-capsule group: maps d_primary
            # channels to n_output * d_output prediction channels.
            self.Ws = ChainList(
                *[L.Convolution2D(d_primary, n_output * d_output, ksize=1, stride=1, initialW=self.init)
                  for _ in range(n_primary)])

    def __call__(self, x, t):
        """Run the forward pass on ``x`` and return the margin loss against labels ``t``."""
        vs_norm, vs = self.output(x)
        return self.calculate_loss(vs_norm, t, vs, x)

    def output(self, x):
        """Compute the output capsules for a batch of images.

        Args:
            x: Input batch; the first axis is the batch axis.

        Returns:
            Tuple ``(vs_norm, vs)`` where ``vs`` has shape
            ``(batch, d_output, n_output)`` and ``vs_norm`` is its Euclidean
            norm over the ``d_output`` axis, shape ``(batch, n_output)``.
        """
        batchsize = x.shape[0]
        n_iters = self.n_iterations
        n_prim = self.n_primary
        d_out = self.d_output
        n_out = self.n_output
        gg = self.n_grids * self.n_grids  # number of spatial positions

        h1 = F.relu(self.conv1(x))
        # Split the conv2 channels into n_prim groups of d_primary channels.
        pr_caps = F.split_axis(self.conv2(h1), n_prim, axis=1)

        preds = []
        for i in range(n_prim):
            pred = self.Ws[i](pr_caps[i])
            pred = pred.reshape((batchsize, d_out, n_out, gg))
            # Fixed: the original appended the undefined name ``Pred``,
            # which raised NameError on the first forward pass.
            preds.append(pred)
        # -> (batch, d_out, n_out, n_prim, gg)
        preds = F.stack(preds, axis=3)

        # Routing logits, one per (output capsule, primary group, position).
        bs = self.xp.zeros((batchsize, n_out, n_prim, gg), dtype='f')
        for i_iter in range(n_iters):
            # Coupling coefficients: softmax over the output-capsule axis.
            cs = F.softmax(bs, axis=1)
            Cs = F.broadcast_to(cs[:, None], preds.shape)
            # Weighted sum over primary groups and positions -> (batch, d_out, n_out)
            ss = F.sum(Cs * preds, axis=(3, 4))
            vs = self.squash(ss)

            # The logits are not needed after the final iteration.
            if i_iter != n_iters - 1:
                Vs = F.broadcast_to(vs[:, :, :, None, None], preds.shape)
                # Agreement: dot product over d_out between each prediction
                # and the current output capsule.
                bs = bs + F.sum(Vs * preds, axis=1)

        # NOTE(review): the derivative of sqrt is unbounded at 0, so a capsule
        # of exactly zero length would give a non-finite gradient here.
        vs_norm = F.sqrt(F.sum(vs ** 2, axis=1))
        return vs_norm, vs

    def calculate_loss(self, vs_norm, t, vs, x):
        """Margin loss on the capsule lengths, summed over classes, averaged over the batch.

        The target class is penalised when its length is below 0.9; every
        other class is penalised, with weight 0.5, when its length exceeds 0.1.

        Args:
            vs_norm: Capsule lengths, shape ``(batch, n_output)``.
            t: Integer class labels, used to index the class axis of
                ``vs_norm``; ``t.shape[0]`` is taken as the batch size.
            vs: Unused; kept for interface compatibility.
            x: Unused; kept for interface compatibility.

        Returns:
            Scalar loss variable.
        """
        xp = self.xp
        batchsize = t.shape[0]
        I = xp.arange(batchsize)
        # One-hot mask of the target class.
        T = xp.zeros(vs_norm.shape, dtype='f')
        T[I, t] = 1.
        # Per-entry margin: 0.9 for the target class, 0.1 for all others.
        m = xp.full(vs_norm.shape, 0.1, dtype='f')
        m[I, t] = 0.9

        loss = T * F.relu(m - vs_norm) ** 2 + \
            0.5 * (1. - T) * F.relu(vs_norm - m) ** 2
        return F.sum(loss) / batchsize

    def squash(self, ss):
        """Rescale each vector along axis 1 to ``s * |s| / (1 + |s|^2)``.

        The direction is preserved and the resulting length,
        ``|s|^2 / (1 + |s|^2)``, lies in [0, 1).
        """
        ss_norm2 = F.sum(ss ** 2, axis=1, keepdims=True)
        norm_div_1pnorm2 = F.sqrt(ss_norm2) / (1. + ss_norm2)
        norm_div_1pnorm2 = F.broadcast_to(norm_div_1pnorm2, ss.shape)
        vs = norm_div_1pnorm2 * ss
        return vs

In [3]:
# Load CIFAR-100 via Chainer's dataset helper. The result is unpacked as
# (train, test) in the setup cell below.
# NOTE(review): presumably downloads and caches the data on first use - confirm
# against the Chainer docs before running offline.
data = datasets.get_cifar100()

In [None]:
# Run configuration.
max_epoch = 10
batchsize = 32

# Network with 100 output capsules on 3-channel input, trained with Adam.
model = Caps(n_output=100, channels=3, n_primary=16, d_primary=6, d_output=8)
optimizer = optimizers.Adam(alpha=0.001)
optimizer.setup(model)

# The training iterator repeats and shuffles (defaults); the test iterator
# makes a single ordered pass.
train, test = data
train_iter = iterators.SerialIterator(train, batch_size=batchsize)
test_iter = iterators.SerialIterator(
    test, batch_size=batchsize, repeat=False, shuffle=False)

# Loss bookkeeping: per-iteration losses of the current epoch, plus the
# per-epoch means that are plotted at the end of the notebook.
train_losses = []
mean_train_loss = []
mean_test_loss = []

In [None]:
# Train for `max_epoch` epochs; after each epoch, evaluate on the full test set
# and record the mean train/test losses for the plot below.
while train_iter.epoch < max_epoch:

    train_batch = train_iter.next()
    image_train, target_train = concat_examples(train_batch)
    # Drop a 2-pixel border on every side. The model is built with the default
    # n_grids=6, which corresponds to a 28x28 input (28 -> 20 -> 6 through the
    # two 9x9 convolutions).
    image_train = image_train[:, :, 2:-2, 2:-2]

    loss = model(image_train, target_train)
    train_losses.append(loss.data)
    print(loss.data)
    model.cleargrads()  # clear gradients left over from the previous step
    loss.backward()  # error backpropagation

    optimizer.update()  # apply the parameter update
    if train_iter.is_new_epoch:
        print('epoch {0:2d}'.format(train_iter.epoch))
        test_losses = []
        test_accuracies = []  # never filled: accuracy is not computed here

        # Rewind through the public API instead of assigning to the iterator's
        # internal fields (including the private `_pushed_position`), which
        # depends on SerialIterator implementation details.
        test_iter.reset()
        # Evaluation needs no gradients: skip building the computational graph
        # to save memory, and run in test mode.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            # repeat=False makes the iterator stop after one pass.
            for test_batch in test_iter:
                image_test, target_test = concat_examples(test_batch)
                image_test = image_test[:, :, 2:-2, 2:-2]
                loss_test = model(image_test, target_test)
                test_losses.append(loss_test.data)

        mean_test_loss.append(np.mean(test_losses))
        mean_train_loss.append(np.mean(train_losses))
        # track mean losses for visualization later
        train_losses = []

0.7952632904052734
0.919095516204834
0.780173659324646
0.7952120304107666
0.7902535796165466
0.7897486090660095
0.7807108163833618
0.7729950547218323
0.7544610500335693
0.7346677184104919
0.7425596714019775
0.7218933701515198
0.7102875709533691
0.6971392035484314
0.6957331299781799
0.6944629549980164
0.6982715725898743
0.7002586126327515
0.701023519039154
0.6970616579055786
0.7029879093170166
0.7004846334457397
0.6884542107582092
0.6963273286819458
0.6957551836967468
0.6852080225944519
0.695458710193634
0.7028652429580688
0.6795175671577454
0.6699709892272949
0.6788751482963562
0.6681269407272339
0.6784891486167908
0.6891242861747742
0.6739841103553772
0.6841603517532349
0.6913002729415894
0.6860209107398987
0.6781191825866699
0.6775566339492798
0.6577186584472656
0.6635931730270386
0.6668238639831543
0.6631033420562744
0.6629984974861145
0.66387939453125
0.6676069498062134
0.6649587154388428
0.6572043895721436
0.6580466032028198
0.6689885854721069
0.6537253856658936
0.6493048071861267

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch mean losses collected by the training loop.
fig, ax = plt.subplots()
epochs = range(len(mean_train_loss))
# Label both curves: the original drew them with no legend, so train and test
# could not be told apart.
ax.plot(epochs, mean_train_loss, label='train')
ax.plot(epochs, mean_test_loss, label='test')
ax.set_xlabel('nr of epochs')
ax.set_ylabel('loss')
ax.set_title('Caps Loss')
ax.legend()
plt.show()