In [1]:
import torch 
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
import torch.nn.functional as F
import math

from torch.optim import lr_scheduler

In [3]:
import argparse
import torch.optim as optim
from torchvision import datasets, transforms

In [4]:

# MNIST Dataset
train_dataset = dsets.MNIST(root='./data/',
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data/',
                           train=False, 
                           transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=1000, 
                                          shuffle=False)

In [5]:
def squash(x):
    lengths2 = x.pow(2).sum(dim=2)
    lengths = lengths2.sqrt()
    x = x * (lengths2 / (1 + lengths2) / lengths).view(x.size(0), x.size(1), 1)
    return x

In [6]:
class FirstCapsuleLayer(nn.Module) :
    def __init__(self, input_dim, output_dim, output_caps):
        super(FirstCapsuleLayer, self).__init__()
        
        self.layer = nn.Linear(input_dim, output_dim*output_caps)
        self.output_dim = output_dim
        self.output_caps = output_caps
        
    def forward(self, input) : 
        output = self.layer(input)
        output = F.relu(output)
        output = output.view(output.shape[0], self.output_caps, self.output_dim)
        output = squash(output)
        
        return output

In [7]:
class CapsuleLayer(nn.Module) :
    def __init__(self, input_caps, input_dim, output_dim, output_caps):
        super(CapsuleLayer, self).__init__()
        
        self.input_caps = input_caps
        self.output_dim = output_dim
        self.output_caps = output_caps
        self.weights = nn.Parameter(torch.Tensor(input_caps, input_dim, output_caps * output_dim))
        self.reset_parameters()
        
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.input_caps)
        self.weights.data.uniform_(-stdv, stdv)
        
    def forward(self, caps_output) : # shape of caps_output is (num_capsules of under_layer * output_dim)
        
        caps_output = caps_output.unsqueeze(2)
        u_predict = caps_output.matmul(self.weights)
        u_predict = u_predict.view(u_predict.size(0), self.input_caps, self.output_caps, self.output_dim)
        
        return u_predict

In [8]:
class routing(nn.Module):
    def __init__(self, b):
        super(routing, self).__init__()
        
        self.b=nn.Parameter(b)
        
    def forward(self, u_predict) :
        
        batch_size, input_caps, output_caps, output_dim = u_predict.size()
        
        c = F.softmax(self.b.view(-1, output_caps), dim=1).view(-1, input_caps, output_caps, 1)
        s = (c * u_predict).sum(dim=1)
        v = squash(s)
        
        v_hat = v.unsqueeze(1)
        reward = (u_predict * v_hat).sum(-1)
        
        return v, reward
    

In [9]:
class Mynet(nn.Module) :
    def __init__(self) :
        super(Mynet, self).__init__()
        
        self.Firstlayer = FirstCapsuleLayer(784, 16, 32)
        self.Capsulelayer0 = CapsuleLayer(32, 16, 16, 16)
        self.Capsulelayer1 = CapsuleLayer(16, 16, 16, 10)
        
    def forward(self, input, b0, b1) :
        
        self.routing0 = routing(b0)
        self.routing1 = routing(b1)
        
        output = self.Firstlayer(input)
        u0 = self.Capsulelayer0(output)
        v0, reward0 = self.routing0(u0)
        #u1 = self.Capsulelayer1(v0)
        #v1, reward1 = self.routing1(u1)
        probs = v0.pow(2).sum(dim=2).sqrt()
        
        return probs

In [10]:
model = Mynet()

In [11]:
class MarginLoss(nn.Module):
    def __init__(self, m_pos, m_neg, lambda_):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, targets, size_average=True):
        t = torch.zeros(lengths.size()).long()
        if targets.is_cuda:
            t = t.cuda()
        t = t.scatter_(1, targets.data.view(-1, 1), 1)
        targets = Variable(t)
        losses = targets.float() * F.relu(self.m_pos - lengths).pow(2) + \
                 self.lambda_ * (1. - targets.float()) * F.relu(lengths - self.m_neg).pow(2)
        return losses.mean() if size_average else losses.sum()


In [12]:
optimizer = optim.SGD(model.parameters(), lr=0.001)

scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=15, min_lr=1e-6)

loss_fn = MarginLoss(0.9, 0.1, 0.5)

In [13]:
import numpy as np

In [14]:
a=0.8
batch_size=100
data_size=60000

input_caps0 = 32
output_caps0=16
output_caps1 = 10
num_epochs=5
iteration=1

for epoch in range(num_epochs) :
    
    num_batches = int(data_size//batch_size)
    
    b0 = torch.zeros((data_size, input_caps0, output_caps0))
    b1 = torch.zeros((data_size, output_caps0, output_caps1))
    
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        
        input = Variable(data.view(data.shape[0], -1))
        target = Variable(target, requires_grad=False)
        optimizer.zero_grad()
        
        b0_batch = b0[batch_idx*batch_size:(batch_idx+1)*batch_size]
        b1_batch = b1[batch_idx*batch_size:(batch_idx+1)*batch_size]
        
        for r in range(iteration) :
            output= model(input, b0_batch, b1_batch)
            
            print(output[0])
            if r == (iteration-1) :
                break
            b1_batch=reward1.data
            
            for i in range(batch_size):
                for j in range(b0_batch.shape[1]) :
                    max_value, _ =torch.max(b1_batch[i], dim=1)
                    b0_batch[i][j]=(1-a)*b0_batch[i][j] + a*(reward0.data[i][j] + max_value)
            
        loss = loss_fn(output, target)
        loss.backward()
        
        optimizer.step()
        
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.data[0]))
            
    print('Accuracy : {:.2f}%'.format(100. * correct / len(test_loader.dataset)))

Variable containing:
1.00000e-03 *
  1.7837
  1.1233
  1.1275
  0.6593
  1.4274
  1.3118
  1.3566
  1.8059
  0.9126
  1.1679
  0.6104
  0.7570
  1.2866
  1.2827
  0.5909
  0.7880
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.0086
  0.9499
  0.6722
  1.4452
  0.7104
  0.8043
  0.6273
  1.2505
  0.3995
  0.6695
  0.3708
  0.3545
  1.2206
  1.0467
  0.8623
  1.2075
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.4489
  0.5627
  0.8177
  0.4188
  0.5529
  1.0901
  0.3657
  0.5937
  0.5385
  0.8893
  0.3956
  0.5561
  0.7578
  0.9902
  0.4427
  0.6508
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-04 *
  9.0497
  6.7354
  2.1445
  4.1088
  4.1751
  7.4977
  4.4261
  6.5090
  2.2972
  6.8398
  3.6283
  7.2590
  4.2475
  8.7224
  4.1671
  7.1158
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.2557
  1.4357
  0.6046
  1.3099
  1.0104
  1.7046
  0.4603
  2.0414
  1.0047
  0.3509
  0.2603
  0.6733
  1.3687
  0.

Variable containing:
1.00000e-03 *
  1.0304
  0.9845
  0.8422
  0.6182
  1.3632
  1.3103
  0.8600
  1.5544
  0.5913
  0.5275
  0.5353
  1.0737
  0.9846
  0.9029
  0.3760
  1.4099
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.7953
  0.7424
  0.8262
  0.8867
  0.7655
  0.6240
  0.5865
  1.5204
  1.0110
  0.6277
  0.6347
  0.6114
  0.8731
  1.1244
  0.8111
  0.7437
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-04 *
  2.3410
  2.5404
  1.6130
  3.1585
  2.9189
  3.6123
  3.4469
  5.9156
  3.5589
  1.6217
  2.2093
  2.8550
  3.8736
  3.7779
  2.8298
  5.7043
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.5078
  0.4159
  0.4150
  0.3196
  0.6383
  0.7241
  0.3218
  0.6893
  0.4966
  0.5811
  0.2907
  0.2898
  0.6781
  0.4447
  0.3821
  1.0370
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.9790
  1.3202
  2.4512
  2.8423
  2.4197
  2.5436
  1.9163
  3.4715
  1.5586
  1.5485
  0.7699
  1.8752
  2.6090
  1.

Variable containing:
1.00000e-04 *
  5.4810
  4.1977
  4.8823
  3.1948
  4.1004
  8.2755
  3.6329
  9.5092
  6.5070
  6.3370
  3.3655
  6.2014
  7.4985
  7.3532
  3.5413
  7.3173
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.3122
  1.8263
  1.3214
  1.2343
  2.4021
  2.4109
  1.0891
  3.0672
  1.4614
  1.7293
  0.8853
  1.0221
  1.9530
  1.7823
  1.2499
  3.3615
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.2125
  0.6133
  0.9914
  1.0863
  1.1702
  0.8303
  0.5763
  1.4433
  0.9767
  0.5723
  0.6553
  0.7421
  1.2282
  1.7154
  0.7830
  1.5124
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.8014
  1.2728
  1.1128
  1.0089
  0.8966
  1.1736
  1.1897
  1.8758
  1.5201
  1.1366
  0.6826
  1.2581
  1.4470
  1.9238
  0.7932
  1.6524
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.0455
  1.5349
  1.2986
  1.1467
  1.6556
  1.6991
  0.7871
  3.0875
  1.2701
  1.6260
  0.7118
  1.2652
  2.1962
  0.

Variable containing:
1.00000e-03 *
  0.9456
  1.5128
  2.0606
  0.7723
  1.0640
  0.6761
  0.7148
  1.6083
  0.8860
  1.1291
  0.4539
  0.6482
  1.0763
  1.1286
  0.5673
  1.1320
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.7015
  0.5672
  0.5285
  1.0492
  0.8167
  0.9970
  0.6176
  1.2860
  0.9023
  0.9040
  0.5355
  0.6372
  0.8191
  1.0581
  0.8953
  0.9395
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.1258
  0.9019
  1.5020
  1.4250
  0.7797
  0.6540
  1.1722
  2.4764
  1.2450
  1.1596
  0.5681
  0.9176
  0.6166
  0.7110
  0.4709
  0.9153
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-04 *
  4.2546
  4.1491
  4.8222
  3.4960
  4.6437
  8.9625
  5.2173
  6.6560
  5.9205
  5.3009
  3.5668
  8.0536
  8.9593
  5.9550
  3.2387
  8.9615
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.9857
  1.9151
  0.6265
  1.0921
  1.3821
  2.2811
  1.1850
  2.2375
  1.1900
  1.3091
  1.0664
  1.3534
  2.2441
  1.

Variable containing:
1.00000e-03 *
  1.1385
  0.6049
  1.0719
  0.9722
  0.6575
  0.9171
  1.3867
  1.2438
  1.1438
  0.6663
  0.9733
  0.6167
  1.2718
  1.5827
  0.5622
  1.2338
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.4570
  2.3123
  2.1900
  1.3553
  1.9541
  1.5064
  1.0013
  2.0332
  1.9715
  1.4487
  0.6991
  1.8484
  2.5730
  2.1991
  1.3984
  3.3025
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.4389
  0.5385
  0.3858
  0.5564
  0.6204
  0.8506
  0.5916
  0.8359
  0.7167
  0.3928
  0.3331
  0.4803
  0.6566
  0.7759
  0.5117
  1.3674
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.5004
  0.4212
  0.6129
  0.6016
  0.5311
  0.5851
  0.5617
  1.2077
  0.6374
  0.3600
  0.3723
  0.4675
  0.6044
  0.7912
  0.2967
  1.1043
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.5761
  0.6709
  0.2915
  0.9630
  0.9434
  1.3471
  1.2135
  1.3863
  0.3678
  0.6814
  0.6030
  0.4660
  0.7053
  1.

Variable containing:
1.00000e-04 *
  2.7064
  2.0897
  4.0680
  4.9562
  2.0276
  4.0785
  3.8008
  5.6203
  2.2727
  2.8929
  2.7056
  2.9922
  4.1776
  3.6706
  3.5010
  3.6794
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.2799
  1.8019
  2.7435
  3.1539
  3.5241
  3.4905
  2.1262
  3.5187
  1.7932
  2.2064
  1.3563
  2.1649
  2.3597
  2.1348
  1.8746
  1.8287
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.0272
  0.7610
  1.2201
  0.6943
  0.5486
  0.8091
  1.0590
  0.7136
  0.4024
  0.8951
  0.4111
  0.9951
  0.7795
  1.1055
  0.4544
  1.3455
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.9840
  1.5583
  0.6815
  0.7154
  0.9704
  0.9451
  1.7096
  2.2470
  0.8197
  0.8288
  1.2204
  0.8420
  0.6387
  0.8588
  0.9653
  0.6732
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.9257
  1.6018
  1.3794
  1.8934
  1.7091
  2.5642
  1.1579
  2.8012
  1.2125
  1.1190
  1.1113
  1.1460
  1.8003
  1.

Variable containing:
1.00000e-04 *
  7.0186
  4.4581
  7.0887
  6.3313
  8.5999
  9.7015
  4.3874
  6.8392
  5.0303
  4.0755
  3.9671
  6.3734
  8.4243
  5.6350
  2.1533
  8.3234
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.2331
  0.8649
  1.2495
  1.2715
  1.4164
  1.4375
  1.7372
  1.6324
  0.9141
  0.8118
  0.6648
  0.8747
  1.5178
  1.6665
  0.7310
  2.3286
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.5165
  0.4429
  0.5477
  0.8265
  0.7788
  0.5098
  0.6232
  1.0544
  0.6907
  0.4251
  0.4254
  0.3337
  1.0267
  1.1484
  0.5853
  0.4653
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  2.0597
  1.0814
  1.7005
  1.2565
  2.4398
  1.7200
  1.8610
  2.4982
  0.8649
  2.4607
  0.8053
  1.0302
  1.4293
  0.7758
  1.3289
  1.2492
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.7455
  0.6384
  0.9844
  0.6031
  0.8544
  0.3781
  0.8823
  0.7887
  0.6013
  0.5654
  0.5076
  0.5602
  0.4081
  0.

Variable containing:
1.00000e-03 *
  1.0507
  0.9464
  0.6344
  0.8149
  0.5395
  0.6340
  0.6914
  1.4906
  0.7970
  0.9864
  0.7099
  0.7575
  1.1273
  1.8121
  0.9259
  0.8223
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  0.9789
  0.6702
  1.0161
  0.5185
  0.4394
  1.1500
  0.7965
  0.9738
  0.8380
  0.6909
  0.6644
  0.7405
  0.6440
  0.6167
  0.5692
  1.2308
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.2808
  0.8867
  0.8882
  0.6218
  1.1953
  1.1398
  0.8731
  1.6852
  0.3307
  0.4803
  0.2810
  0.6608
  1.1939
  0.7301
  0.4665
  1.2533
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-03 *
  1.8098
  1.1906
  1.0722
  1.2680
  1.0783
  1.5101
  1.1434
  1.9348
  0.6674
  0.9034
  0.9423
  1.8424
  1.1115
  0.8240
  1.2057
  1.9422
[torch.FloatTensor of size 16]

Variable containing:
1.00000e-04 *
  6.4768
  9.7809
  4.4546
  6.0939
  4.3920
  7.8863
  3.0811
  8.7584
  6.3746
  4.6623
  5.9624
  5.1250
  7.2139
  7.

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
[torch.FloatTensor of size 16]

Variable containing:
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan

KeyboardInterrupt: 