In [1]:
import matplotlib.pyplot as plt
import numpy as np
from chainer import *
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.dataset import concat_examples

In [2]:
class Caps(Chain):
    """Capsule-network classifier with routing-by-agreement.

    Layers, in the order they are applied by :meth:`output`:

    * ``conv1`` -- 9x9 convolution (stride 1, 256 channels) followed by ReLU.
    * ``conv2`` -- 9x9 convolution (stride 2) whose output channels are split
      into ``n_primary_maps`` groups of ``primary_dim`` channels; every
      spatial position of every group is one primary capsule.
    * ``Ws`` -- one 1x1 convolution per primary-capsule map, producing a
      ``class_dim``-dimensional prediction for each of the ``n_classes``
      output capsules from every primary capsule.
    * ``n_iterations`` rounds of routing combine those predictions into one
      output capsule per class.

    Args:
        n_iterations (int): Number of routing iterations (must be >= 1).
        in_channels (int): Number of channels of the input images.
        n_classes (int): Number of output capsules / classes.
        n_primary_maps (int): Number of primary-capsule feature maps.
        primary_dim (int): Dimensionality of each primary capsule.
        class_dim (int): Dimensionality of each output capsule.

    ``Caps()`` with no arguments builds exactly the configuration that used
    to be hard-coded (3 input channels, 32 maps of 8-d primary capsules,
    10 classes of 16-d output capsules, 3 routing iterations).
    """

    def __init__(self, n_iterations=3, in_channels=3, n_classes=10,
                 n_primary_maps=32, primary_dim=8, class_dim=16):
        super(Caps, self).__init__()
        self.n_iterations = n_iterations
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_primary_maps = n_primary_maps
        self.primary_dim = primary_dim
        self.class_dim = class_dim
        # Grid width of the primary capsules layer for 28x28 inputs
        # (28 -> 20 after conv1 -> 6 after conv2). Kept for backward
        # compatibility only: output() measures the real grid instead.
        self.n_grids = 6
        self.n_raw_grids = self.n_grids
        self.init = initializers.Uniform(scale=0.05)
        with self.init_scope():
            self.conv1 = L.Convolution2D(
                in_channels, 256, ksize=9, stride=1, initialW=self.init)
            self.conv2 = L.Convolution2D(
                256, n_primary_maps * primary_dim, ksize=9, stride=2,
                initialW=self.init)
            self.Ws = ChainList(
                *[L.Convolution2D(primary_dim, class_dim * n_classes,
                                  ksize=1, stride=1, initialW=self.init)
                  for _ in range(n_primary_maps)])

    def __call__(self, x, t):
        """Run the network on images ``x`` and return the loss against labels ``t``."""
        vs_norm, vs = self.output(x)
        return self.calculate_loss(vs_norm, t, vs, x)

    def output(self, x):
        """Compute the output capsules for a batch of images.

        Args:
            x: Image batch fed to ``conv1``; the first axis is the batch.

        Returns:
            tuple: ``(vs_norm, vs)`` where ``vs`` holds the output capsules
            with shape ``(batch, class_dim, n_classes)`` and ``vs_norm``
            their lengths with shape ``(batch, n_classes)``.

        Raises:
            ValueError: If ``self.n_iterations`` is smaller than 1.
        """
        n_iters = self.n_iterations
        if n_iters < 1:
            raise ValueError(
                'n_iterations must be >= 1, got {}'.format(n_iters))

        batchsize = x.shape[0]
        n_maps = self.n_primary_maps
        n_classes = self.n_classes

        h1 = F.relu(self.conv1(x))
        h2 = self.conv2(h1)
        # Number of spatial positions per primary-capsule map, taken from the
        # actual feature map so inputs of other sizes than 28x28 also work.
        gg = h2.shape[2] * h2.shape[3]
        pr_caps = F.split_axis(h2, n_maps, axis=1)

        # Each primary capsule predicts every output capsule.
        preds = []
        for W, pr_cap in zip(self.Ws, pr_caps):
            pred = W(pr_cap)
            preds.append(
                pred.reshape((batchsize, self.class_dim, n_classes, gg)))
        # -> (batch, class_dim, n_classes, n_maps, gg)
        Preds = F.stack(preds, axis=3)

        # Routing logits, one per (class, primary capsule) pair.
        bs = self.xp.zeros((batchsize, n_classes, n_maps, gg), dtype='f')
        for i_iter in range(n_iters):
            # Coupling coefficients: softmax over the class axis.
            cs = F.softmax(bs, axis=1)
            Cs = F.broadcast_to(cs[:, None], Preds.shape)
            # Weighted sum over all primary capsules, then squash.
            ss = F.sum(Cs * Preds, axis=(3, 4))
            vs = self.squash(ss)

            # The logits are not needed after the final iteration.
            if i_iter != n_iters - 1:
                Vs = F.broadcast_to(vs[:, :, :, None, None], Preds.shape)
                # Agreement = dot product over the capsule dimension.
                bs = bs + F.sum(Vs * Preds, axis=1)

        vs_norm = F.sqrt(F.sum(vs ** 2, axis=1))
        return vs_norm, vs

    def calculate_loss(self, vs_norm, t, vs, x):
        """Margin loss on the capsule lengths, averaged over the batch.

        For the target class the loss is ``relu(0.9 - norm) ** 2``; for every
        other class it is ``0.5 * relu(norm - 0.1) ** 2``.

        Args:
            vs_norm: Capsule lengths, shape ``(batch, n_classes)``.
            t: Integer class labels used to index the class axis; must be a
                plain array (it is used for fancy indexing).
            vs: Unused; kept so existing callers keep working.
            x: Unused; kept so existing callers keep working.

        Returns:
            Scalar variable: summed loss divided by the batch size.
        """
        xp = self.xp
        batchsize = t.shape[0]
        rows = xp.arange(batchsize)

        # One-hot targets.
        T = xp.zeros(vs_norm.shape, dtype='f')
        T[rows, t] = 1.
        # Per-entry margin: 0.9 for the target class, 0.1 for the others.
        m = xp.full(vs_norm.shape, 0.1, dtype='f')
        m[rows, t] = 0.9

        loss = T * F.relu(m - vs_norm) ** 2 + \
            0.5 * (1. - T) * F.relu(vs_norm - m) ** 2
        return F.sum(loss) / batchsize

    def squash(self, ss):
        """Rescale each capsule along axis 1 to length ``|s|^2 / (1 + |s|^2)``.

        Implemented as ``s * |s| / (1 + |s|^2)``, which is the same vector
        without dividing by the norm.
        """
        ss_norm2 = F.sum(ss ** 2, axis=1, keepdims=True)
        norm_div_1pnorm2 = F.sqrt(ss_norm2) / (1. + ss_norm2)
        norm_div_1pnorm2 = F.broadcast_to(norm_div_1pnorm2, ss.shape)
        vs = norm_div_1pnorm2 * ss
        return vs

In [3]:
# Load CIFAR-10 through Chainer's dataset helper. The result is unpacked
# further down as a (train, test) pair.
# NOTE(review): presumably this downloads and caches the data on first use --
# confirm before running offline.
data = datasets.get_cifar10()

In [None]:
# Training hyperparameters.
max_epoch, batchsize = 10, 32

# Network and the Adam optimizer that updates its parameters.
model = Caps()
optimizer = optimizers.Adam(alpha=0.001)
optimizer.setup(model)

# Dataset splits and their batch iterators. The test iterator makes a single
# ordered pass; the training iterator keeps its defaults.
train, test = data
train_iter = iterators.SerialIterator(train, batchsize)
test_iter = iterators.SerialIterator(
    test, batchsize, shuffle=False, repeat=False)

# Loss bookkeeping: per-iteration losses of the running epoch, plus the
# per-epoch means for both splits.
train_losses = []
mean_train_loss, mean_test_loss = [], []

In [None]:
def _cropped_batch(batch, margin=2):
    """Stack a list of (image, label) examples and crop the image borders.

    Removes ``margin`` pixels from every side of the two trailing image
    axes and returns ``(images, targets)``.
    """
    images, targets = concat_examples(batch)
    height, width = images.shape[2:4]
    return images[:, :, margin:height - margin, margin:width - margin], targets


# Train until `max_epoch` epochs are done; evaluate on the test split at the
# end of every epoch.
while train_iter.epoch < max_epoch:
    image_train, target_train = _cropped_batch(train_iter.next())

    loss = model(image_train, target_train)
    # Store a plain float so the history does not depend on the array backend.
    train_losses.append(float(loss.data))
    print(loss.data)

    model.cleargrads()  # drop gradients left over from the previous step
    loss.backward()     # error backpropagation
    optimizer.update()  # apply the parameter update

    if train_iter.is_new_epoch:
        print('epoch {0:2d}'.format(train_iter.epoch))

        test_losses = []
        test_sizes = []
        # Rewind through the public API instead of assigning iterator
        # attributes (one of them private) by hand.
        test_iter.reset()
        # Evaluation never calls backward(), so skip building the graph.
        with chainer.no_backprop_mode():
            for test_batch in test_iter:
                image_test, target_test = _cropped_batch(test_batch)
                loss_test = model(image_test, target_test)
                test_losses.append(float(loss_test.data))
                test_sizes.append(len(test_batch))
        # Leave the iterator rewound, as the original loop did.
        test_iter.reset()

        # Track mean losses for visualization later. Each batch loss is
        # already a per-example mean, so weight by batch size: the final
        # test batch may be smaller than the others.
        mean_test_loss.append(np.average(test_losses, weights=test_sizes))
        mean_train_loss.append(np.mean(train_losses))
        train_losses = []

3.304143190383911
3.6086864471435547
3.6206626892089844
3.59881854057312
3.5822975635528564
3.5487542152404785
3.268223285675049
2.909348487854004
2.153256416320801
1.3822581768035889
1.0791442394256592
0.9709808826446533
0.7704699039459229
0.6184738874435425
0.5658478736877441
0.6060736775398254
0.5596253871917725
0.5563259124755859
0.5550771951675415
0.5442459583282471
0.5604561567306519
0.5687301158905029
0.5271015167236328
0.5607495307922363
0.5243494510650635
0.5176807045936584
0.5676111578941345
0.5882319808006287
0.5556985139846802
0.5517815351486206
0.5712131261825562
0.5615161657333374
0.5466750860214233
0.5635944604873657
0.5678510665893555
0.5464619398117065
0.5362943410873413
0.5554027557373047
0.5552487373352051
0.5967618823051453
0.5650333166122437
0.5680869817733765
0.5451664328575134
0.5478667616844177
0.5594124794006348
0.5141820311546326
0.5683400630950928
0.5717233419418335
0.5569087862968445
0.5408111810684204
0.5486119985580444
0.5751840472221375
0.5622426271438599

In [None]:
# Plot the per-epoch mean losses collected by the training loop.
fig, ax = plt.subplots()

# Entry i of each history was recorded when epoch i + 1 finished, so the
# x axis is numbered from 1.
epochs = range(1, len(mean_train_loss) + 1)
ax.plot(epochs, mean_train_loss, label='train')
ax.plot(range(1, len(mean_test_loss) + 1), mean_test_loss, label='test')

ax.set_xlabel('nr of epochs')
ax.set_ylabel('loss')
ax.set_title('Caps Loss')
ax.legend()
plt.show()