In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse

In [None]:
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conv-in", type=int, default=4, help="Input sequence features"
    )
    return parser.parse_args()

In [36]:
def squash(tensor, dim=-1):
    squared_norm = (tensor**2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * tensor / torch.sqrt(squared_norm)

In [37]:
class Squash(nn.Module):
    def __init__(self, eps=10e-21, **kwargs):
        super(Squash,self).__init__(**kwargs)
        self.eps = eps

    def forward(self, s):
        n = torch.norm(s, dim=-1, keepdim=True)
        return (1 - 1 / (torch.exp(n) + self.eps)) * (s / (n + self.eps))

In [45]:
def dynamic_routing(x, iterations=3):
    x=x.unsqueeze(-1)
    N = x.shape[1]  # previous layer
    N1 = 1 # next layer
    B = x.shape[0]
    #feature_dim = x.shape[2]

    b = torch.zeros(B, N1, N,1,1).to(x.device)
    for _ in range(iterations):
        print('x shape: {}'.format(x.shape))
        c = F.softmax(b, dim=1)
        print('c shape: {}'.format(c.shape))
        a = x.matmul(c)
        #print('a shape: {}'.format(a.shape))
        s = torch.sum(a, dim=2).squeeze(-1)
        #print('s shape: {}'.format(s.shape))
        v = squash(s)
        #print('v shape: {}'.format(v.shape))
        #print('x shape: {}'.format(x.shape))
        y=v.matmul(x)
        #print('y shape: {}'.format(y.shape))
        #print('b shape: {}'.format(b.shape))
        b = b + y

    return v

In [16]:
c=F.softmax(torch.zeros(1, 1, 5,1,1), dim=1)
#print(c)
print(c.shape)
x = torch.randn(1, 5, 21 * 5,1)
a = x.matmul(c)
#print(a)
print(a.shape)
s = torch.sum(x.matmul(c), dim=2).squeeze(-1)
#print(s)
print(s.shape)

torch.Size([1, 1, 5, 1, 1])
torch.Size([1, 1, 5, 105, 1])
torch.Size([1, 1, 105])


In [33]:
class PrimaryCapsuleLayer(nn.Module):
    """
    Create a primary capsule layer with the methodology described in 'Efficient-CapsNet: Capsule Network with Self-Attention Routing'.
    Properties of each capsule s_n are exatracted using a 1D depthwise convolution.

    ...

    Attributes
    ----------
    kernel_size[w]: int
        depthwise conv kernel dimension
    conv_num: int
        number of primary capsules
    feature_dimension: int
        primary capsules dimension (number of properties)
    conv_stride: int
        depthwise conv strides
    Methods
    -------
    forward(inputs)
        compute the primary capsule layer
    """

    def __init__(
        self, conv_in=2, feature_dimension=21*5, kernel_size=2, conv_num=5,base_num=21
    ):
        super().__init__()
        self.conv_out = feature_dimension//(conv_num*base_num)
        self.conv_num = conv_num
        self.primary_capsule_layer = nn.ModuleList(
            [
                nn.Conv1d(conv_in, self.conv_out, kernel_size,dilation=conv_stride, padding='same')
                for conv_stride in range(1,conv_num+1)
            ]
        )

    def forward(self, x):
        capsules = [conv(x) for conv in self.primary_capsule_layer]
        #capsules_reshaped = [
        #    c.reshape(self.conv_num, self.feature_dimension) for c in capsules
        #]
        output_tensor = torch.cat(capsules, dim=1)
        return Squash()(output_tensor)
def test_for_primary_capsule_layer():
    input = torch.rand(1,2,105)
    layer = PrimaryCapsuleLayer()
    print(layer(input).shape)
test_for_primary_capsule_layer()

torch.Size([1, 5, 105])


In [15]:
import torch.nn as nn

# Define the input and output channels
in_channels = 2
out_channels = 1

# Define the kernel size and dilation
kernel_size = 2

# Define the 1D dilated convolution layers
conv1d_list = nn.ModuleList()
for dilation in range(1, 6):
    padding = 'same'
    conv1d_list.append(nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation,padding = padding))

# Define the input tensor
input_tensor = torch.randn(1, in_channels, 21 * 5)

# Apply the 1D dilated convolutions to the input tensor
output_tensor_list = []
for conv1d in conv1d_list:
    print(conv1d(input_tensor).shape)
    output_tensor_list.append(conv1d(input_tensor))

# Concatenate the output tensors along the channel dimension
output_tensor = torch.cat(output_tensor_list, dim=1)

print(output_tensor.shape)
print(dynamic_routing(output_tensor).shape)

torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 5, 105])
torch.Size([1, 1, 105])


In [61]:
class CapsLayer(nn.Module):
    def __init__(self, num_capsules=1, num_route_nodes=5, in_channels=105, out_channels=20):
        super(CapsLayer,self).__init__()
        self.W = nn.Parameter(0.01 * 
            torch.randn(1,num_capsules, num_route_nodes, out_channels, in_channels)
        )
        print('W shape: {}'.format(self.W.shape))

    def forward(self, x):
        x = x[:,None,...,None]#x.unsqueeze(1).unsqueeze(4)
        #x = x.unsqueeze(-1)
        #print('W expand shape: {}'.format(self.W[:, None, :, :, :].shape))
        #print('CapsLayer input shape: {}'.format(x.shape))
        #print('CapsLayer input expand shape: {}'.format(x[ :, :, None, :].shape))
        # (batch_size, num_caps, num_route_nodes, out_channels, 1)
        print('x shape: {}'.format(x.shape))
        u_hat = torch.matmul(self.W, x)#(x @ self.W).squeeze(2)
        #u=u_hat.squeeze(-1)
        u_hat = u_hat.squeeze(-1)
        print('u_hat shape: {}'.format(u_hat.shape))
        class_capsules = dynamic_routing(u_hat)
        return class_capsules

In [54]:
a = torch.rand(1, 10, 10, 20, 1)
b = torch.rand(1, 1, 10, 1, 1)
c = a.matmul(b)

In [62]:
input = torch.rand(1,5,105)
layer = CapsLayer()
print(layer(input).shape)

W shape: torch.Size([1, 1, 5, 20, 105])
x shape: torch.Size([1, 1, 5, 105, 1])
u_hat shape: torch.Size([1, 1, 5, 20])
x shape: torch.Size([1, 1, 5, 20, 1])
c shape: torch.Size([1, 1, 1, 1, 1])
x shape: torch.Size([1, 1, 5, 20, 1])
c shape: torch.Size([1, 1, 5, 1, 1])
x shape: torch.Size([1, 1, 5, 20, 1])
c shape: torch.Size([1, 1, 5, 1, 1])
torch.Size([1, 1, 20])


In [63]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet,self).__init__()
        self.primary_layer = PrimaryCapsuleLayer()
        self.caps_layer = CapsLayer()

    def forward(self, x):
        x = self.primary_layer(x)
        x = self.caps_layer(x)
        return x
def test_for_caps_net():
    input = torch.rand(1,2,105)
    model = CapsNet()
    print(model(input).shape)
test_for_caps_net()

W shape: torch.Size([1, 1, 5, 20, 105])
x shape: torch.Size([1, 1, 5, 105, 1])
u_hat shape: torch.Size([1, 1, 5, 20])
x shape: torch.Size([1, 1, 5, 20, 1])
c shape: torch.Size([1, 1, 1, 1, 1])
x shape: torch.Size([1, 1, 5, 20, 1])
c shape: torch.Size([1, 1, 5, 1, 1])
x shape: torch.Size([1, 1, 5, 20, 1])
c shape: torch.Size([1, 1, 5, 1, 1])
torch.Size([1, 1, 20])


In [None]:
args = parse_args()