In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse


In [None]:
def parse_args(argv=None):
    """Parse command-line options.

    Parameters
    ----------
    argv : list[str] or None, default None
        Arguments to parse. ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Namespace with ``conv_in`` (int, default 4).

    Notes
    -----
    Unrecognised arguments are ignored (``parse_known_args``) rather than
    rejected. Inside a Jupyter kernel ``sys.argv`` carries the kernel's own
    flags (e.g. ``-f <connection file>``), which ``parse_args`` would turn
    into a SystemExit.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--conv-in', type=int, default=4,
                        help='Input sequence features')
    args, _unknown = parser.parse_known_args(argv)
    return args

In [None]:
def squash(tensor, dim=-1, eps=1e-12):
    """Classic capsule squash: ``(|s|^2 / (1 + |s|^2)) * s / |s|`` along ``dim``.

    Parameters
    ----------
    tensor : torch.Tensor
        Capsule vectors; the vector axis is ``dim``.
    dim : int, default -1
        Axis holding the capsule vector components.
    eps : float, default 1e-12
        Lower bound applied to the squared norm before the square root.
        A zero-length vector therefore maps to zero instead of NaN (0/0).
        Vectors whose squared norm is >= eps are computed exactly as before.

    Returns
    -------
    torch.Tensor
        Same shape as ``tensor``; each vector keeps its direction and gets a
        length of ``|s|^2 / (1 + |s|^2)``.
    """
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    # clamp_min guards the sqrt/division for all-zero vectors (0/0 -> NaN).
    return scale * tensor / torch.sqrt(squared_norm.clamp_min(eps))


In [None]:
class Squash(nn.Module):
    """Efficient-CapsNet squash nonlinearity.

    Computes ``(1 - 1 / (exp(|s|) + eps)) * s / (|s| + eps)``, where the norm
    is taken over the last dimension. Each capsule vector keeps its direction
    and gets a length in [0, 1).

    Parameters
    ----------
    eps : float, default 10e-21 (i.e. 1e-20)
        Small constant that keeps both divisions finite for zero-length vectors.
    **kwargs
        Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, eps=10e-21, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def forward(self, s):
        """Squash ``s`` along its last dimension; the output has the same shape as ``s``."""
        # torch equivalents of TF's tf.norm / tf.math.exp (axis -> dim, keepdims -> keepdim).
        n = torch.linalg.vector_norm(s, dim=-1, keepdim=True)
        return (1 - 1 / (torch.exp(n) + self.eps)) * (s / (n + self.eps))

In [None]:
def dynamic_routing(x, iterations=3):
    """Routing-by-agreement between two capsule layers.

    Follows the scheme of 'Dynamic Routing Between Capsules' (Sabour et al., 2017).

    Parameters
    ----------
    x : torch.Tensor
        Prediction vectors u_hat of shape (B, N1, N, D, 1), where:

        - B is the batch size,
        - N1 is the number of next-layer capsules,
        - N is the number of previous-layer capsules,
        - D is the output capsule dimension.

        This is the layout produced by ``CapsLayer`` (``W @ x[:, None, ..., None]``).
        N and N1 are read from ``x.shape``, so any layer sizes are supported.
    iterations : int, default 3
        Number of routing iterations; must be >= 1.

    Returns
    -------
    torch.Tensor
        Output capsules v of shape (B, N1, D), squashed with ``squash``.

    Raises
    ------
    ValueError
        If ``iterations < 1``; no output capsules would be computed.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    batch, n_next, n_prev = x.shape[:3]
    # Routing logits start at zero (uniform coupling). Allocate with x's
    # dtype/device so half/double inputs and any device work.
    b = x.new_zeros(batch, n_next, n_prev, 1, 1)
    for i in range(iterations):
        # Coupling coefficients: softmax over the next-layer capsules (dim=1).
        c = F.softmax(b, dim=1)
        # Weighted sum over previous-layer capsules: (B, N1, N, D, 1) -> (B, N1, D)
        s = torch.sum(x.matmul(c), dim=2).squeeze(-1)
        v = squash(s)
        if i < iterations - 1:
            # Agreement term: (B, N1, 1, 1, D) @ (B, N1, N, D, 1) -> (B, N1, N, 1, 1).
            # It is skipped after the last iteration, where it would be unused.
            b = b + v[:, :, None, None, :].matmul(x)
    return v


In [None]:
class PrimaryCapsuleLayer(nn.Module):
    """
    Primary capsule layer, modelled on 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'.

    Each of the ``conv_num`` convolutions maps the input feature map to ``conv_out``
    channels. Each convolution's output is reshaped into capsules of
    ``feature_dimension`` properties. The capsules from all convolutions are then
    concatenated and squashed with ``Squash``.

    NOTE(review): the paper extracts capsule properties with a 2D *depthwise*
    convolution; this implementation uses a standard ``nn.Conv2d`` (no ``groups``).

    Attributes
    ----------
    conv_in: int
        input channels of each convolution
    conv_out: int
        output channels of each convolution
    kernel_size: int or (h, w)
        convolution kernel dimension
    conv_stride: int or (h, w)
        convolution strides
    feature_dimension: int
        primary capsule dimension (number of properties)
    conv_num: int
        number of parallel convolutions whose capsules are concatenated

    Methods
    -------
    forward(x)
        Compute the primary capsules.

        - Input: x of shape (B, conv_in, H, W).
        - Output: shape (B, n_caps, feature_dimension), where
          n_caps = conv_num * conv_out * H_out * W_out / feature_dimension.
        - conv_out * H_out * W_out must be divisible by feature_dimension.
    """
    def __init__(self, conv_in, conv_out, kernel_size, conv_stride, feature_dimension, conv_num=1):
        super().__init__()
        self.feature_dimension = feature_dimension
        self.conv_num = conv_num
        self.primary_capsule_layer = nn.ModuleList(
            [nn.Conv2d(conv_in, conv_out, kernel_size, conv_stride) for _ in range(conv_num)])
        # Parameter-free nonlinearity: build it once instead of on every forward pass.
        self.squash = Squash()

    def forward(self, x):
        # Assumes a batched input (B, conv_in, H, W); keep B as the leading dim
        # so each sample gets its own capsules.
        batch = x.shape[0]
        capsules = [conv(x).reshape(batch, -1, self.feature_dimension)
                    for conv in self.primary_capsule_layer]
        # Squash operates on a tensor: join the capsules of all convs along dim 1.
        return self.squash(torch.cat(capsules, dim=1))

In [None]:
class CapsLayer(nn.Module):
    """Class-capsule layer: per-pair linear transforms followed by dynamic routing.

    Parameters
    ----------
    nclasses : int
        Number of output (class) capsules.
    out_channels_dim : int
        Dimension of each output capsule.
    in_caps : int, default 32*6*6
        Number of input (primary) capsules.
    in_dim : int, default 8
        Dimension of each input capsule.
    routing_iterations : int, default 3
        Number of iterations passed to ``dynamic_routing``.
    """
    def __init__(self, nclasses, out_channels_dim, in_caps=32 * 6 * 6, in_dim=8,
                 routing_iterations=3):
        super().__init__()
        self.routing_iterations = routing_iterations
        # One (out_channels_dim x in_dim) matrix per (class, input capsule) pair.
        # The leading singleton dim broadcasts over the batch.
        self.W = nn.Parameter(1e-3 * torch.randn(1, nclasses, in_caps, out_channels_dim, in_dim))

    def forward(self, x):
        """Map input capsules (B, in_caps, in_dim) to class capsules (B, nclasses, out_channels_dim)."""
        # (B, in_caps, in_dim) -> (B, 1, in_caps, in_dim, 1) so W @ x broadcasts over classes;
        # u_hat has shape (B, nclasses, in_caps, out_channels_dim, 1).
        u_hat = self.W.matmul(x[:, None, ..., None])
        class_capsules = dynamic_routing(u_hat, iterations=self.routing_iterations)
        return class_capsules

In [None]:
class CapsNet(nn.Module):
    """Capsule network: conv stem -> primary capsules -> class capsules.

    The defaults produce the 32*6*6 primary capsules of dimension 8 that
    ``CapsLayer`` expects by default. For example, a single-channel 28x28 input
    goes 28 -(9x9 conv)-> 20 -(9x9 conv, stride 2)-> 6, and the primary conv
    has 32*8 = 256 output channels.

    Parameters
    ----------
    in_channels : int, default 1
        Channels of the input image.
    nclasses : int, default 10
        Number of class capsules.
    """
    def __init__(self, in_channels=1, nclasses=10):
        super().__init__()
        # 2-D stem: PrimaryCapsuleLayer is built from nn.Conv2d and needs a
        # 4-D (B, C, H, W) feature map; a Conv1d stem yields a 3-D tensor.
        self.conv_layer = nn.Conv2d(in_channels, 256, 9)
        self.primary_layer = PrimaryCapsuleLayer(
            conv_in=256, conv_out=32 * 8, kernel_size=9, conv_stride=2,
            feature_dimension=8)
        self.caps_layer = CapsLayer(nclasses=nclasses, out_channels_dim=16)

    def forward(self, x):
        # ReLU after the stem, as in the original CapsNet; without it the stem
        # and the primary conv compose into a single linear map.
        x = F.relu(self.conv_layer(x))
        x = self.primary_layer(x)
        x = self.caps_layer(x)
        return x

In [None]:
# Parse command-line options (--conv-in) into `args`.
# NOTE(review): `args` is not referenced by the cells visible here — confirm it is used downstream.
args = parse_args()