In [2]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
import numpy as np
import torchvision
from torch.autograd import Variable
from torchvision import transforms 
BATCH_SIZE = 100
NUM_EPOCHS = 3
NUM_ROUTING_ITERATIONS = 3
USE_CUDA = True

# 读取mnist数据集

In [3]:
class Mnist:
    def __init__(self, batch_size):
        dataset_transform = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])

        train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=dataset_transform)
        test_dataset = torchvision.datasets.MNIST('./data', train=False, download=True, transform=dataset_transform)
        
        self.train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)        

# Model

In [4]:
def squash(v):
    epsilon = 0.00000001
    vector_norm = (v ** 2).sum(-1,keepdim = True) + epsilon
    output = vector_norm * v/((1. + vector_norm) * torch.sqrt(vector_norm))
    return output

In [5]:
#第一层，使用普通卷积得到基础特征
#(batch,28,28,1)
class Convlayer(nn.Module):
    def __init__(self,in_channels = 1,out_channels = 256,kernel_size = 9):
        super(Convlayer, self).__init__()
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride = 1)
    def forward(self,x):
        #(batch_size,20,20,256)
        return F.relu(self.conv(x))

In [6]:
class PrimaryCapslayer(nn.Module):
    def __init__(self,num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCapslayer,self).__init__()
        self.capsules = nn.ModuleList([nn.Conv2d(in_channels = in_channels,out_channels = out_channels,kernel_size = kernel_size,stride = 2,padding=0)
                                      for _ in range(num_capsules)])

    def forward(self,x):
        u = [capsule(x).view(x.size(0),-1,1) for capsule in self.capsules]
        #u:(batch_size,8,6,6,32)
        u = torch.cat(u,dim=-1)
        #u:(batch_size,1152,8)
        return squash(u)

In [7]:
class DigitCaps(nn.Module):
    def __init__(self,num_capsules = 10,num_routes = 32*6*6,in_channels = 8,out_channels = 16):
        super(DigitCaps,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_capsules = num_capsules
        self.num_routes = num_routes
        
        self.W = nn.Parameter(torch.randn(num_routes,num_capsules,out_channels,in_channels))
    
    def forward(self,x):
        u = x.unsqueeze(3)
        u = u.unsqueeze(2)
        #print("u.shape:",u.shape)
        #x:(batch_size,1152,1,8,1)
        #W:(1152,10,16,8)
        #W*x = (batch_size,1152,10,16,1)
        #W = self.W.unsqueeze(0)
        #print("w.shape",self.W.shape)
        u_hat = torch.matmul(self.W,u)
        #u_hat(batch_size, 1152, 10, 16, 1)
        u_hat = u_hat.squeeze(-1)
        #print("u_hat.shape",u_hat.shape)
        u_hat=u_hat.permute(0,2,1,3)
        
        b_ij = torch.zeros(u_hat.size(0),self.num_capsules,1,self.num_routes)
        if USE_CUDA:
            b_ij = b_ij.cuda()
            
        num_iterations = NUM_ROUTING_ITERATIONS
        for iteration in range(num_iterations):
            #print(b_ij.shape)
            c_ij = F.softmax(b_ij,-1)
            #print("u_hat.shape",u_hat.shape)
            #print("c_ij.shape:",c_ij.shape)
            s_j = torch.matmul(c_ij,u_hat)
            #print("s_j:",s_j.shape)
            v_j = squash(s_j)
            if iteration < num_iterations -1:
                a_ij = torch.matmul(v_j,u_hat.permute(0,1,3,2))
                b_ij = b_ij + a_ij
        v_j = v_j.permute(0,1,3,2).squeeze(-1)
        #print("v_j",v_j.shape)
        return v_j

In [8]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet,self).__init__()
        self.conv_layer = Convlayer()
        self.primarycaps_layer = PrimaryCapslayer()
        self.DigitCaps_layer = DigitCaps()
        self.mse_loss = nn.MSELoss()
        
    def forward(self,data):
        data = self.conv_layer(data)
        data = self.primarycaps_layer(data)
        output = self.DigitCaps_layer(data)
        return output
    def loss(self, data, x, target):
        return self.margin_loss(x, target) 
    
    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2))
        #print("vc",v_c.shape)
        #(batch_size,10)
        left = F.relu(0.9 - v_c)**2
        right = F.relu(v_c - 0.1)**2
        labels = torch.sparse.torch.eye(10).index_select(dim=0, index=labels.data.cpu())
        #print(labels.shape)
        if USE_CUDA:
            labels = labels.cuda()
        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()
        return loss
    
    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss * 0.001

# Train

In [9]:
capsule_net = CapsNet()
if USE_CUDA:
    capsule_net = capsule_net.cuda()
optimizer = optim.Adam(capsule_net.parameters())

In [35]:
batch_size = BATCH_SIZE
mnist = Mnist(batch_size)
print(batch_size)
n_epochs = NUM_EPOCHS

for epoch in range(n_epochs):
    capsule_net.train()
    train_loss = 0
    for batch_id, (data,target) in enumerate(mnist.train_loader):
        if USE_CUDA:
            data,target = data.cuda(),target.cuda()
            
        optimizer.zero_grad()
        output = capsule_net(data)
        loss = capsule_net.loss(data,output,target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        if batch_id % 10 == 0:
            output = torch.sqrt((output**2).sum(dim=2, keepdim=True))
            output = output.squeeze(-1)
#             print(np.argmax(output.cpu().detach().numpy(),1))
#             print(target.cpu().detach())
            print(loss.item())
            print("train accuracy:",sum(np.argmax(output.cpu().detach().numpy(),1) == target.data.cpu().numpy())/float(batch_size))
    
    capsule_net.eval()
    test_loss = 0
    for batch_id, (data, target) in enumerate(mnist.test_loader):
    #         labels = torch.sparse.torch.eye(10).index_select(dim=0, index=labels.data.cpu())

            target = target.long()

            data, target = Variable(data), Variable(target)

            if USE_CUDA:
                data, target = data.cuda(), target.cuda()

            output = capsule_net(data)
            loss = capsule_net.loss(data, output, target)

    #         test_loss += loss.data[0]
            test_loss += loss.item()

            if batch_id % 100 == 0:
                output = torch.sqrt((output**2).sum(dim=2, keepdim=True))
                output = output.squeeze(-1)
                print("test accuracy:", sum(np.argmax(output.cpu().detach().numpy(),1) == target.data.cpu().numpy())/float(batch_size))

    
    print(test_loss / len(mnist.test_loader))

100
0.13491101562976837
train accuracy: 0.86
0.09772814810276031
train accuracy: 0.93
0.07092873752117157
train accuracy: 0.94
0.0762912929058075
train accuracy: 0.92
0.06049880012869835
train accuracy: 0.96
0.054591625928878784
train accuracy: 0.98
0.037535686045885086
train accuracy: 0.97
0.021849989891052246
train accuracy: 1.0
0.052823055535554886
train accuracy: 0.96
0.03668174520134926
train accuracy: 0.97
0.04867825284600258
train accuracy: 0.95
0.0379539430141449
train accuracy: 0.95
0.02712220698595047
train accuracy: 0.96
0.02531808242201805
train accuracy: 0.99
0.027278419584035873
train accuracy: 0.98
0.03970300406217575
train accuracy: 0.96
0.03321731090545654
train accuracy: 0.97
0.03557023033499718
train accuracy: 0.98
0.023673653602600098
train accuracy: 0.98
0.03448548540472984
train accuracy: 0.97
0.025972655043005943
train accuracy: 0.98
0.023576343432068825
train accuracy: 0.99
0.04173723980784416
train accuracy: 0.96
0.011991513893008232
train accuracy: 1.0
0.02321