In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [2]:
class Squash(nn.Module):
    """Capsule squashing non-linearity.

    Shrinks each vector along the last dimension so that its length falls
    below 1 while its direction is kept:

        squash(s) = |s|^2 / (1 + |s|^2) * s / sqrt(|s|^2 + epsilon)

    Args:
        epsilon: small constant added under the square root so that an
            all-zero vector does not cause a division by zero.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, s):
        # Squared length of every vector, kept as a trailing singleton dim for broadcasting.
        squared_norm = s.pow(2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        direction = s / torch.sqrt(squared_norm + self.epsilon)
        return scale * direction

In [3]:
class Router(nn.Module):
    """Routing-by-agreement between two capsule layers (Sabour et al., 2017).

    Every input capsule ``i`` predicts every output capsule ``j`` through a
    learned ``in_d x out_d`` matrix. ``iterations`` rounds of routing then
    refine the coupling coefficients ``c_ij`` towards the outputs that the
    predictions agree with.

    Args:
        in_caps: number of input capsules.
        out_caps: number of output capsules.
        in_d: length of each input capsule vector.
        out_d: length of each output capsule vector.
        iterations: number of routing rounds; must be at least 1.

    Raises:
        ValueError: if ``iterations`` is less than 1 (no output would be produced).
    """

    def __init__(self,in_caps,out_caps,in_d,out_d,iterations):
        super().__init__()
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")

        self.in_caps = in_caps
        self.out_caps = out_caps
        self.iterations = iterations
        # c_ij = softmax over j of b_ij: each input capsule distributes its
        # output among the output capsules, which are the last axis of b.
        self.softmax = nn.Softmax(dim=-1)
        self.squash = Squash()

        # Transformation matrices W_ij; nn.Parameter already requires grad.
        self.weight = nn.Parameter(torch.randn(in_caps,out_caps,in_d,out_d))

    def forward(self,u):
        """Route ``u`` of shape ``(batch, in_caps, in_d)`` to ``(batch, out_caps, out_d)``."""
        # Prediction vectors u_hat[b, i, j] = u[b, i] @ W[i, j] -> (batch, in_caps, out_caps, out_d).
        u_hat = torch.einsum('ijnm,bin->bijm',self.weight,u)
        # Routing logits start at zero, i.e. uniform coupling over output capsules.
        b = u.new_zeros(u.shape[0],self.in_caps,self.out_caps)

        for step in range(self.iterations):
            c = self.softmax(b)
            s = torch.einsum('bij,bijm->bjm',c,u_hat)
            v = self.squash(s)
            if step == self.iterations - 1:
                break  # a logit update after the final round would never be used
            # Agreement v_j . u_hat_{j|i} increases b_ij for agreeing pairs.
            b = b + torch.einsum('bjm,bijm->bij',v,u_hat)

        return v

In [4]:
class MarginLoss(nn.Module):
    """Margin loss on capsule lengths (after Sabour et al., 2017, eq. 4).

    For each class ``k``, with capsule output ``v_k`` and target ``T_k``:

        L_k = T_k * relu(m_poz - |v_k|) + lambda_ * (1 - T_k) * relu(|v_k| - m_neg)

    Both hinge terms are squared when ``squared=True``, which matches the
    paper. The per-sample loss is summed over classes and averaged over the batch.

    Args:
        n_labels: number of classes; used to one-hot encode integer labels.
        squared: square both hinge terms. Defaults to False (unsquared hinges,
            the original behaviour of this class).
    """

    def __init__(self,*,n_labels,squared=False):
        super().__init__()

        self.m_poz = 0.9     # target-class capsule length is pushed above this
        self.m_neg = 0.1     # other-class capsule lengths are pushed below this
        self.lambda_ =0.5    # weight of the non-target (absent class) term
        self.n_labels = n_labels
        self.squared = squared

    def forward(self,v,labels):
        """Compute the batch-mean margin loss.

        Args:
            v: capsule outputs. The last dim is the capsule vector and the dim
                before it indexes classes (presumably ``(batch, n_labels, d)``).
            labels: either integer class indices with one dim fewer than the
                lengths tensor (e.g. ``(batch,)``), or a one-hot/float target
                tensor broadcastable to it (e.g. ``(batch, n_labels)``).

        Returns:
            Scalar tensor: the loss summed over classes and averaged over the batch.
        """
        # vector_norm has a finite gradient at the zero vector; sqrt(sum(v**2)) gives NaN there.
        v_norm = torch.linalg.vector_norm(v, dim=-1)
        if not labels.is_floating_point() and labels.dim() == v_norm.dim() - 1:
            # Integer class indices -> one-hot targets over n_labels classes.
            labels = F.one_hot(labels.long(), num_classes=self.n_labels).to(v_norm.dtype)

        present = F.relu(self.m_poz - v_norm)
        absent = F.relu(v_norm - self.m_neg)
        if self.squared:
            present, absent = present ** 2, absent ** 2
        loss = labels * present + self.lambda_ * (1.0 - labels) * absent

        return loss.sum(dim=-1).mean()

In [5]:
class MNISTCapsuleNetworkModel(nn.Module):
    """Capsule network for 28x28 single-channel (MNIST) images.

    Pipeline:
    - ``conv1`` (ReLU);
    - primary capsules built from ``conv2``: 256 channels on a 6x6 grid,
      regrouped into 32*6*6 = 1152 capsules of 8 dims;
    - dynamic routing to 10 digit capsules of 16 dims;
    - a decoder that reconstructs the image from the masked digit capsules.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1,out_channels=256,kernel_size=9,stride=1)
        self.conv2 = nn.Conv2d(in_channels=256,out_channels=32*8,kernel_size=9,stride=2,padding=0)
        self.squash = Squash()

        self.digit_capsules = Router(32*6*6,10,8,16,3)

        self.decoder = nn.Sequential(
            nn.Linear(16*10,512),
            nn.ReLU(),
            nn.Linear(512,1024),
            nn.ReLU(),
            nn.Linear(1024,784),
            nn.Sigmoid()
        )

    def forward(self,x,labels=None):
        """Classify ``x`` and reconstruct it from the selected digit capsule.

        Args:
            x: images of shape ``(batch, 1, 28, 28)``. The 6x6 primary-capsule
                grid and the 28x28 reconstruction size are hard-coded for this size.
            labels: optional targets choosing which digit capsule feeds the
                decoder (the paper masks with the true label during training).
                Either integer indices ``(batch,)`` or a one-hot ``(batch, 10)``
                tensor. When None, the predicted class is used.

        Returns:
            caps: digit capsules ``(batch, 10, 16)``.
            reconstructions: ``(batch, 1, 28, 28)`` sigmoid outputs.
            pred: predicted class indices ``(batch,)``, i.e. the longest capsule.
        """
        batch_size = x.shape[0]
        x = F.relu(self.conv1(x))  # [bs, 256, 20, 20]
        x = self.conv2(x)          # [bs, 32*8, 6, 6]

        # Regroup 256 channels x 36 positions into 1152 primary capsules of 8 dims.
        caps = x.view(batch_size,8,32*6*6).permute(0,2,1)
        caps = self.squash(caps)
        caps = self.digit_capsules(caps)  # [bs, 10, 16]

        with torch.no_grad():
            # Squared capsule length has the same argmax as the length itself.
            pred = (caps **2).sum(-1).argmax(-1)
            if labels is None:
                mask = torch.eye(10,device=x.device)[pred]
            elif labels.dim() == 1:
                mask = torch.eye(10,device=x.device)[labels.to(x.device).long()]
            else:
                mask = labels.to(device=x.device,dtype=caps.dtype)

        # Zero every digit capsule except the selected one before decoding.
        reconstructions = self.decoder((caps*mask[:,:,None]).view(batch_size,-1))
        reconstructions = reconstructions.view(-1,1,28,28)

        return caps,reconstructions,pred