In [1]:
import torch 
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
import torch.nn.functional as F
import math

from torch.optim import lr_scheduler

In [3]:
import argparse
import torch.optim as optim
from torchvision import datasets, transforms

In [4]:

# MNIST Dataset
train_dataset = dsets.MNIST(root='./data/',
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data/',
                           train=False, 
                           transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=100, 
                                          shuffle=False)

In [5]:
def squash(x):
    lengths2 = x.pow(2).sum(dim=2)
    lengths = lengths2.sqrt()
    x = x * (lengths2 / (1 + lengths2) / lengths).view(x.size(0), x.size(1), 1)
    return x

In [6]:
class AgreementRouting(nn.Module):
    def __init__(self, input_caps, output_caps, n_iterations):
        super(AgreementRouting, self).__init__()
        self.n_iterations = n_iterations
        self.b = nn.Parameter(torch.zeros((input_caps, output_caps)))

    def forward(self, u_predict):
        batch_size, input_caps, output_caps, output_dim = u_predict.size()

        c = F.softmax(self.b)
        s = (c.unsqueeze(2) * u_predict).sum(dim=1)
        v = squash(s)

        if self.n_iterations > 0:
            b_batch = self.b.expand((batch_size, input_caps, output_caps))
            for r in range(self.n_iterations):
                v = v.unsqueeze(1)
                b_batch = b_batch + (u_predict * v).sum(-1)

                c = F.softmax(b_batch.view(-1, output_caps)).view(-1, input_caps, output_caps, 1)
                s = (c * u_predict).sum(dim=1)
                v = squash(s)

        return v

In [7]:
class Encode_layer(nn.Module):
    def __init__(self, input_dim, output_dim, output_caps):
        super(Encode_layer, self).__init__()
        
        self.layer = nn.Linear(input_dim, output_dim*output_caps)
        self.output_dim = output_dim
        self.output_caps = output_caps
        
    def forward(self, input) : 
        output = self.layer(input)
        output = F.relu(output)
        output = output.view(output.shape[0], self.output_caps, self.output_dim)
        output = squash(output)
        
        return output

In [8]:
class Neuron_layer(nn.Module):
    def __init__(self, input_neurons, input_dim, output_dim, output_neurons, routing):
        super(Neuron_layer, self).__init__()
        self.input_neurons = input_neurons
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.output_neurons = output_neurons
        self.routing = routing
        self.weights = nn.Parameter(torch.Tensor(input_neurons, input_dim, output_neurons * output_dim))
        self.reset_parameters()
        
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.input_neurons)
        self.weights.data.uniform_(-stdv, stdv)
        
    def forward(self, input):
        
        input = input.unsqueeze(2)
        predict = input.matmul(self.weights)
        predict = predict.view(predict.size(0), self.input_neurons, self.output_neurons, self.output_dim)
        v = self.routing(predict)
        
        return v
        

In [9]:
class Learning(nn.Module) :
    def __init__(self, routing_iterations) :
        super(Learning, self).__init__()
        
        self.network1 = Encode_layer(784, 128, 4)
        
        routing_module1 = AgreementRouting(4, 3, routing_iterations)
        
        self.network2 = Neuron_layer(4, 128, 32, 3, routing_module1)
        
        routing_module2 = AgreementRouting(3, 10, routing_iterations)
        
        self.network3 = Neuron_layer(3, 32, 16, 10, routing_module2)
        
    def forward(self, input):
        
        output1 = self.network1(input)
        output2 = self.network2(output1)
        output3 = self.network3(output2)
        
        probs = output3.pow(2).sum(dim=2).sqrt()
        
        return probs

In [10]:
model=Learning(3)

In [11]:
class MarginLoss(nn.Module):
    def __init__(self, m_pos, m_neg, lambda_):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, targets, size_average=True):
        t = torch.zeros(lengths.size()).long()
        if targets.is_cuda:
            t = t.cuda()
        t = t.scatter_(1, targets.data.view(-1, 1), 1)
        targets = Variable(t)
        losses = targets.float() * F.relu(self.m_pos - lengths).pow(2) + \
                 self.lambda_ * (1. - targets.float()) * F.relu(lengths - self.m_neg).pow(2)
        return losses.mean() if size_average else losses.sum()


In [12]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)

scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=15, min_lr=1e-6)

loss_fn = MarginLoss(0.9, 0.1, 0.5)

In [13]:
def train(epoch):
    model.train()
    correct=0.
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data.view(data.shape[0], -1)), Variable(target, requires_grad=False)
        optimizer.zero_grad()
        probs= model(data)
        
        loss = loss_fn(probs, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.data[0]))
            
        pred = probs.data.max(1, keepdim=True)[1]  # get the index of the max probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        
    print('Accuracy : {:.2f}%'.format(100. * correct / len(train_loader.dataset)))

In [14]:
import numpy as np

In [15]:
for epoch in range(100) :
    print(epoch)
    train(epoch)  #More stable 

0


  # Remove the CWD from sys.path while we load stuff.


Accuracy : 82.17%
1
Accuracy : 94.89%
2


Accuracy : 96.30%
3
Accuracy : 97.22%
4
Accuracy : 97.86%
5


Accuracy : 98.32%
6
Accuracy : 98.65%
7
Accuracy : 98.97%
8


Accuracy : 99.16%
9
Accuracy : 99.28%
10


Accuracy : 99.43%
11
Accuracy : 99.49%
12
Accuracy : 99.56%
13


Accuracy : 99.70%
14
Accuracy : 99.75%
15
Accuracy : 99.78%
16


Accuracy : 99.71%
17
Accuracy : 99.81%
18


Accuracy : 99.84%
19
Accuracy : 99.86%
20
Accuracy : 99.90%
21


Accuracy : 99.92%
22
Accuracy : 99.86%
23
Accuracy : 99.90%
24


Accuracy : 99.92%
25
Accuracy : 99.94%
26
Accuracy : 99.88%
27


Accuracy : 99.94%
28
Accuracy : 99.93%
29


Accuracy : 99.95%
30
Accuracy : 99.97%
31
Accuracy : 99.92%
32


Accuracy : 99.90%
33
Accuracy : 99.94%
34
Accuracy : 99.95%
35


Accuracy : 99.94%
36
Accuracy : 99.91%
37


Accuracy : 99.97%
38
Accuracy : 99.97%
39
Accuracy : 99.92%
40


Accuracy : 99.92%
41
Accuracy : 99.94%
42
Accuracy : 99.94%
43


Accuracy : 99.98%
44
Accuracy : 99.98%
45


Accuracy : 99.95%
46
Accuracy : 99.91%
47
Accuracy : 99.95%
48


Accuracy : 99.95%
49
Accuracy : 99.95%
50
Accuracy : 99.94%
51


Accuracy : 99.98%
52
Accuracy : 99.98%
53


Accuracy : 99.97%
54
Accuracy : 99.89%
55
Accuracy : 99.98%
56


Accuracy : 99.98%
57
Accuracy : 99.95%
58
Accuracy : 99.98%
59


Accuracy : 99.98%
60
Accuracy : 99.96%
61


Accuracy : 99.96%
62
Accuracy : 99.98%
63
Accuracy : 99.96%
64


Accuracy : 99.98%
65
Accuracy : 99.99%
66
Accuracy : 99.97%
67


Accuracy : 99.95%
68
Accuracy : 99.97%
69


Accuracy : 99.99%
70
Accuracy : 99.98%
71
Accuracy : 99.93%
72


Accuracy : 99.98%
73
Accuracy : 99.97%
74
Accuracy : 99.99%
75


Accuracy : 99.99%
76
Accuracy : 99.99%
77
Accuracy : 99.91%
78


Accuracy : 99.99%
79
Accuracy : 99.99%
80


Accuracy : 99.99%
81
Accuracy : 99.99%
82
Accuracy : 100.00%
83


Accuracy : 99.90%
84
Accuracy : 99.99%
85
Accuracy : 99.99%
86


Accuracy : 100.00%
87
Accuracy : 100.00%
88


Accuracy : 100.00%
89
Accuracy : 100.00%
90
Accuracy : 100.00%
91


Accuracy : 99.89%
92
Accuracy : 99.94%
93
Accuracy : 100.00%
94


Accuracy : 100.00%
95
Accuracy : 100.00%
96


Accuracy : 100.00%
97
Accuracy : 100.00%
98
Accuracy : 100.00%
99


Accuracy : 100.00%


In [16]:
#Test
total=0
correct=0.
#batch size 다를 땐 테스트가 안 됌?
for images, labels in test_loader:
    images = Variable(images.view(images.shape[0], -1))
    outputs= model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))


  # Remove the CWD from sys.path while we load stuff.


Test Accuracy of the model on the 10000 test images: 98 %


In [17]:
adv_test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=1, 
                                          shuffle=False)

In [18]:
dummy=1
s=0.
t=0.


for (x, y) in adv_test_loader :
    x, y = Variable(x, requires_grad=True), Variable(y, requires_grad=False)
    probs= model(x.view(x.shape[0], -1))
    y_pred = np.argmax(probs.data.numpy())
    loss = nn.CrossEntropyLoss()(probs, y)
    loss.backward()
    epsilon = 0.1 
    x_grad   = torch.sign(x.grad.data)
    x_adversarial = torch.clamp(x.data + epsilon * x_grad, 0, 1)
    adversarial_probs= model(Variable(x_adversarial).view(x.shape[0], -1))
    y_pred_adversarial = np.argmax(adversarial_probs.data.numpy())
    
    print("{0} th example ".format(dummy))
    print ("True value: "+ str(y.data.numpy()[0])+"\nPredicted value : "+str(y_pred)+ "\nAdversarial :" + str(y_pred_adversarial)+"\n" )
    
    dummy+=1
    
    if y.data.numpy()[0]!=y_pred :
        t+=1
    
    if y_pred!=y_pred_adversarial :
        s+=1
    if dummy==100:
        break
        
print("Accuracy of test_model : {0:.3f} , Adversarials : {1:.3f}".format((t/dummy)*100, (s/dummy)*100))

1 th example 
True value: 7
Predicted value : 7
Adversarial :3

2 th example 
True value: 2
Predicted value : 2
Adversarial :3

3 th example 
True value: 1
Predicted value : 1
Adversarial :8

4 th example 
True value: 0
Predicted value : 0
Adversarial :0

5 th example 
True value: 4
Predicted value : 4
Adversarial :9

6 th example 
True value: 1
Predicted value : 1
Adversarial :7

7 th example 
True value: 4
Predicted value : 4
Adversarial :8

8 th example 
True value: 9
Predicted value : 9
Adversarial :3

9 th example 
True value: 5
Predicted value : 5
Adversarial :4

10 th example 
True value: 9
Predicted value : 9
Adversarial :9

11 th example 
True value: 0
Predicted value : 0
Adversarial :2

12 th example 
True value: 6
Predicted value : 6
Adversarial :6

13 th example 
True value: 9
Predicted value : 9
Adversarial :7

14 th example 
True value: 0
Predicted value : 0
Adversarial :0

15 th example 
True value: 1
Predicted value : 1
Adversarial :8

16 th example 
True value: 5
Predi

  # Remove the CWD from sys.path while we load stuff.


29 th example 
True value: 0
Predicted value : 0
Adversarial :0

30 th example 
True value: 1
Predicted value : 1
Adversarial :3

31 th example 
True value: 3
Predicted value : 3
Adversarial :3

32 th example 
True value: 1
Predicted value : 1
Adversarial :8

33 th example 
True value: 3
Predicted value : 3
Adversarial :3

34 th example 
True value: 4
Predicted value : 4
Adversarial :0

35 th example 
True value: 7
Predicted value : 7
Adversarial :3

36 th example 
True value: 2
Predicted value : 2
Adversarial :2

37 th example 
True value: 7
Predicted value : 7
Adversarial :2

38 th example 
True value: 1
Predicted value : 1
Adversarial :8

39 th example 
True value: 2
Predicted value : 2
Adversarial :3

40 th example 
True value: 1
Predicted value : 1
Adversarial :1

41 th example 
True value: 1
Predicted value : 1
Adversarial :8

42 th example 
True value: 7
Predicted value : 7
Adversarial :3

43 th example 
True value: 4
Predicted value : 4
Adversarial :9

44 th example 
True value