In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse


In [None]:
def parse_args(argv=None):
    """Build the command-line parser and parse the run options.

    Unknown arguments are ignored instead of being treated as errors. A
    Jupyter kernel is normally launched with its own flags in ``sys.argv``
    (for example ``-f <connection-file>``); ``ArgumentParser.parse_args``
    raises ``SystemExit`` on such unrecognized arguments, which breaks
    running this cell inside a notebook.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attribute ``conv_in`` (int, default 4).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--conv-in', type=int, default=4,
                        help='Input sequence features')
    # parse_known_args returns (namespace, leftover_strings); the leftovers
    # are deliberately dropped so foreign flags do not abort the run.
    known, _unknown = parser.parse_known_args(argv)
    return known

In [None]:
def squash(tensor, dim=-1, eps=1e-8):
    """Squash capsule vectors so that their length lies in [0, 1).

    Computes ``(|v|^2 / (1 + |v|^2)) * v / sqrt(|v|^2 + eps)`` along ``dim``.

    Args:
        tensor: Input tensor holding capsule vectors along ``dim``.
        dim: Axis over which the vector norm is taken (default: last axis).
        eps: Small constant added under the square root. Without it an
            all-zero capsule divides 0 by 0 and yields NaN.

    Returns:
        Tensor with the same shape as ``tensor``.
    """
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    # eps keeps the denominator strictly positive for zero-length capsules.
    return scale * tensor / torch.sqrt(squared_norm + eps)


In [None]:
class Squash(nn.Module):
    """
    Squash activation used in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'.

    Computes ``(1 - 1 / (exp(|s|) + eps)) * s / (|s| + eps)`` where ``|s|`` is
    the L2 norm of ``s`` along its last axis.

    Attributes
    ----------
    eps: float
        fuzz factor added to both denominators of the expression

    Methods
    -------
    forward(s)
        compute the activation from input capsules
    """

    def __init__(self, eps=10e-21, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def forward(self, s):
        """Apply the activation to capsule vectors stored along the last axis.

        Parameters
        ----------
        s: torch.Tensor
            input capsules; the norm is taken over the last dimension

        Returns
        -------
        torch.Tensor
            tensor of the same shape as ``s``
        """
        # torch.nn has no `norm` or `math.exp`; the tensor-level functions
        # live in torch.linalg / torch.
        n = torch.linalg.vector_norm(s, dim=-1, keepdim=True)
        return (1 - 1 / (torch.exp(n) + self.eps)) * (s / (n + self.eps))

In [None]:
def dynamic_routing(x, iterations=3):
    """Route prediction vectors to output capsules by iterative agreement.

    Args:
        x: Prediction vectors of shape (B, N_out, N_in, D, 1), where N_out is
            the number of output capsules, N_in the number of input capsules
            and D the output capsule dimension. The capsule counts are read
            from ``x`` itself.
        iterations: Number of routing rounds; must be at least 1.

    Returns:
        Output capsules of shape (B, N_out, D), passed through the
        module-level ``squash`` function.

    Raises:
        ValueError: If ``iterations`` is smaller than 1 (there would be no
            output capsule to return).
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1, got {}".format(iterations))

    batch_size, n_out, n_in = x.shape[:3]

    # Routing logits, one per (output capsule, input capsule) pair. Created
    # directly on x's device and in x's dtype so the matmuls below do not mix
    # dtypes.
    b = torch.zeros(batch_size, n_out, n_in, 1, 1, device=x.device, dtype=x.dtype)

    for step in range(iterations):
        # Coupling coefficients: for every input capsule they sum to 1 over
        # the output capsules (softmax over dim=1).
        # (B, N_out, N_in, 1, 1)
        c = F.softmax(b, dim=1)

        # Weighted sum of the predictions over the input capsules.
        # (B, N_out, D)
        s = torch.sum(x.matmul(c), dim=2).squeeze(-1)

        # (B, N_out, D)
        v = squash(s)

        # Agreement <v, x> raises the logit of well-aligned pairs. The update
        # after the final round would never be read, so it is skipped.
        if step < iterations - 1:
            # (B, N_out, N_in, 1, 1)
            b = b + v[:, :, None, None, :].matmul(x)

    return v


In [None]:
class PrimaryCapsuleLayer(nn.Module):
    """Parallel 1-D convolutions whose outputs are regrouped into capsules.

    Args:
        conv_in: Number of input channels of every convolution.
        conv_out: Number of output channels of every convolution.
        conv_k: Kernel size of every convolution.
        conv_stride: Stride of every convolution.
        conv_num: Number of parallel convolutions.
        capsule_dim: Length of each primary capsule vector (default 8).
    """

    def __init__(self, conv_in, conv_out, conv_k, conv_stride, conv_num=1, capsule_dim=8):
        super().__init__()
        self.capsule_dim = capsule_dim
        self.primary_capsule_layer = nn.ModuleList(
            [nn.Conv1d(conv_in, conv_out, conv_k, conv_stride) for _ in range(conv_num)]
        )

    def forward(self, x):
        """Compute squashed primary capsules.

        Args:
            x: Batched input of shape (B, conv_in, L).

        Returns:
            Tensor of shape (B, N, capsule_dim), where N is the total number
            of capsules produced by all convolutions.
        """
        batch_size = x.shape[0]
        # Keep the batch axis fixed and let the capsule count be inferred.
        # With a fixed (-1, 8, 36) target, any per-sample size other than
        # 8 * 36 would be folded into the batch axis and mix samples.
        capsules = [
            conv(x).reshape(batch_size, self.capsule_dim, -1)
            for conv in self.primary_capsule_layer
        ]
        # (B, capsule_dim, N) -> (B, N, capsule_dim)
        s = torch.cat(capsules, dim=-1).permute(0, 2, 1)
        return squash(s)

In [None]:
class CapsLayer(nn.Module):
    """Class-capsule layer: per-pair linear predictions followed by routing.

    Args:
        nclasses: Number of output (class) capsules.
        out_channels_dim: Dimension of every output capsule.
        in_capsules: Number of input capsules (default 32*6*6).
        in_channels_dim: Dimension of every input capsule (default 8).
    """

    def __init__(self, nclasses=10, out_channels_dim=16, in_capsules=32*6*6, in_channels_dim=8):
        super().__init__()
        # One (out_channels_dim x in_channels_dim) matrix per
        # (class capsule, input capsule) pair; leading 1 broadcasts over batch.
        self.W = nn.Parameter(
            1e-3 * torch.randn(1, nclasses, in_capsules, out_channels_dim, in_channels_dim)
        )

    def forward(self, x):
        """Predict and routing

        Args:
            x: Input vectors, (B, in_capsules, in_channels_dim)

        Return:
            class capsules as returned by ``dynamic_routing``; per the
            original shape comment this is (B, nclasses, out_channels_dim)
        """
        # (B, in_capsules, in_dim) -> (B, 1, in_capsules, in_dim, 1)
        x = x[:, None, ..., None]
        u_hat = self.W.matmul(x)  # (B, nclasses, in_capsules, out_dim, 1)
        # Check against the sizes this layer was built with instead of fixed
        # literals, so non-default constructor arguments are honoured.
        _, nclasses, in_capsules, out_dim, _ = self.W.shape
        assert u_hat.shape[1:] == (nclasses, in_capsules, out_dim, 1)
        # NOTE(review): dynamic_routing must accept these capsule counts;
        # confirm when using non-default sizes.
        class_capsules = dynamic_routing(u_hat)
        return class_capsules

In [None]:
class CapsNet(nn.Module):
    """Capsule network: 1-D convolution -> primary capsules -> class capsules.

    Args:
        conv_in: Number of input channels of the first convolution
            (default 1).
    """

    def __init__(self, conv_in=1):
        super().__init__()
        self.conv_layer = nn.Conv1d(conv_in, 256, 9)
        # PrimaryCapsuleLayer has four required arguments. 32 parallel
        # convolutions with 8 output channels each match the (32*6*6, 8)
        # capsule layout used downstream.
        # NOTE(review): kernel 9 / stride 2 are the usual CapsNet
        # primary-capsule settings and fit the earlier 20 -> 6 shape
        # comments; confirm they are the intended values.
        self.primary_layer = PrimaryCapsuleLayer(
            conv_in=256, conv_out=8, conv_k=9, conv_stride=2, conv_num=32
        )
        self.caps_layer = CapsLayer(nclasses=10, out_channels_dim=16)

    def forward(self, x):
        """
        Args:
            x : Input sequence, (B, conv_in, L). The class-capsule layer is
                built for 32*6*6 = 1152 input capsules, i.e. 36 positions per
                primary convolution; with the kernel/stride settings used
                here that corresponds to L = 87 or 88.

        Return:
            the class capsules, each capsule is a 16 dimension vector
        """
        x = self.conv_layer(x)  # (B, 256, L - 8)
        x = self.primary_layer(x)  # (B, 32*36, 8) for L in {87, 88}
        x = self.caps_layer(x)  # (B, 10, 16)
        return x

In [None]:
# Parse the run options once into a module-level namespace.
# NOTE(review): argparse falls back to sys.argv when no explicit argument list
# is supplied; inside a Jupyter kernel sys.argv typically carries the kernel's
# own launch flags, so confirm parse_args() tolerates them on Restart & Run All.
args = parse_args()