In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse

In [None]:
def parse_args(argv=None):
    """Parse the command-line options used by this notebook.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse reads
        ``sys.argv[1:]``, so existing zero-argument calls keep working.

    Returns
    -------
    argparse.Namespace
        Namespace exposing ``conv_in`` (int, default 4).

    Notes
    -----
    Unrecognised arguments are ignored instead of aborting with
    ``SystemExit``. A Jupyter kernel is launched with its own flags
    (e.g. ``-f <connection-file>``), which a strict parse would reject.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conv-in", type=int, default=4, help="Input sequence features"
    )
    # parse_known_args returns (namespace, leftover_args); the leftovers are
    # host-injected flags we deliberately do not care about.
    known_args, _unknown = parser.parse_known_args(argv)
    return known_args

In [8]:
def squash(tensor, dim=-1, eps=1e-8):
    """Capsule "squash" non-linearity.

    Rescales each vector along ``dim`` to length ``|s|^2 / (1 + |s|^2)``
    while keeping its direction.

    Parameters
    ----------
    tensor : torch.Tensor
        Input vectors.
    dim : int
        Axis holding the vector components (default: last axis).
    eps : float
        Small constant added under the square root so that all-zero vectors
        map to zero instead of producing NaN from a 0/0 division.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as ``tensor``.
    """
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    # eps keeps the denominator strictly positive for zero-length vectors.
    return scale * tensor / torch.sqrt(squared_norm + eps)

In [9]:
class Squash(nn.Module):
    """Squash activation ``(1 - 1 / exp(|s|)) * s / |s|`` over the last axis.

    Parameters
    ----------
    eps : float
        Constant added to both denominators to avoid division by zero.
    **kwargs
        Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, eps=10e-21, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def forward(self, s):
        """Apply the squash to ``s``; the output has the same shape as ``s``."""
        # L2 norm of each vector along the last axis, kept as a size-1 axis
        # so it broadcasts against ``s`` below. (``nn`` has no ``norm``
        # function; the norm lives in ``torch.linalg``.)
        n = torch.linalg.vector_norm(s, dim=-1, keepdim=True)
        return (1 - 1 / (torch.exp(n) + self.eps)) * (s / (n + self.eps))

In [14]:
def dynamic_routing(x, iterations=3):
    """Route a set of input capsules into a single output capsule.

    Parameters
    ----------
    x : torch.Tensor
        Input capsules of shape ``(B, N, D)``: batch, number of input
        capsules, capsule dimension.
    iterations : int
        Number of routing iterations; must be at least 1.

    Returns
    -------
    torch.Tensor
        Output capsule(s) of shape ``(B, 1, D)``.

    Raises
    ------
    ValueError
        If ``iterations`` is smaller than 1 (no output would be computed).

    Notes
    -----
    The number of output capsules is fixed to 1, so the softmax over the
    output-capsule axis always yields coupling coefficients equal to 1.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")

    B = x.shape[0]
    N = x.shape[1]  # previous layer
    N1 = 1  # next layer

    # (B, N, D) -> (B, 1, N, D, 1): the explicit size-1 output-capsule axis
    # keeps the batch axis aligned with the logits under matmul broadcasting.
    # Without it a batch larger than one broadcasts into (B, B, ...) tensors.
    u = x.unsqueeze(-1).unsqueeze(1)

    # Routing logits, one per (output capsule, input capsule) pair. Created
    # with x's dtype/device so non-float32 inputs do not hit a dtype mismatch.
    b = torch.zeros(B, N1, N, 1, 1, dtype=x.dtype, device=x.device)
    for _ in range(iterations):
        c = F.softmax(b, dim=1)  # (B, N1, N, 1, 1)
        # Weighted sum over the N input capsules -> (B, N1, D)
        s = torch.sum(u.matmul(c), dim=2).squeeze(-1)
        v = squash(s)

        # Agreement <v, u_n> for every input capsule:
        # (B, N1, 1, 1, D) @ (B, 1, N, D, 1) -> (B, N1, N, 1, 1)
        b = b + v[:, :, None, None, :].matmul(u)

    return v

In [10]:
# Scratch check of the tensor shapes used by the routing step.
# Softmax over an axis of size 1 yields all ones.
c = F.softmax(torch.zeros(1, 1, 5, 1, 1), dim=1)
print(c)
print(c.shape)
# One sample, 5 capsules of 21 * 5 features, with a trailing matrix axis.
x = torch.randn(1, 5, 21 * 5, 1)
# Batched matmul: (105, 1) @ (1, 1) per capsule, batch axes broadcast.
a = x.matmul(c)
print(a)
print(a.shape)
# Reuse `a` instead of recomputing x.matmul(c); sum over the capsule axis.
s = torch.sum(a, dim=2).squeeze(-1)
print(s)
print(s.shape)

tensor([[[[[1.]],

          [[1.]],

          [[1.]],

          [[1.]],

          [[1.]]]]])
torch.Size([1, 1, 5, 1, 1])
tensor([[[[[-7.8850e-01],
           [ 3.1928e-03],
           [-4.9653e-02],
           [-1.4989e+00],
           [-7.6802e-01],
           [ 1.1338e+00],
           [ 3.3159e-02],
           [-5.5799e-01],
           [-7.2579e-01],
           [ 3.3569e-01],
           [ 3.6142e-01],
           [-6.3129e-01],
           [-5.9700e-01],
           [-9.4282e-02],
           [-5.2561e-01],
           [ 1.4987e-01],
           [-1.3559e+00],
           [-2.9956e-02],
           [-1.1692e+00],
           [-4.3946e-01],
           [-3.9858e-01],
           [ 2.0871e+00],
           [-1.8853e+00],
           [ 6.1805e-01],
           [ 7.8107e-01],
           [ 2.4735e-01],
           [ 2.8662e-01],
           [-2.9941e-01],
           [-3.6614e-01],
           [-1.4977e+00],
           [-5.7680e-01],
           [-1.1964e+00],
           [-2.6039e+00],
           [-2.38

In [None]:
class PrimaryCapsuleLayer(nn.Module):
    """
    Primary capsule layer built from a bank of parallel dilated 1D convolutions
    (inspired by 'Efficient-CapsNet: Capsule Network with Self-Attention Routing').

    One ``nn.Conv1d`` is created per dilation rate ``1 .. conv_num``. All of
    them use stride 1 and ``padding='same'``, so every branch keeps the input
    sequence length. The branch outputs are concatenated along the channel
    axis and passed through ``Squash``.

    Parameters
    ----------
    conv_in: int
        number of input channels of every convolution
    feature_dimension: int
        total feature budget; together with ``conv_num`` and ``base_num`` it
        fixes the per-branch output channels as
        ``feature_dimension // (conv_num * base_num)``
    kernel_size: int
        convolution kernel width
    conv_num: int
        number of parallel convolutions (dilation rates 1..conv_num)
    base_num: int
        divisor used when deriving the per-branch output channels

    Attributes
    ----------
    conv_out: int
        output channels of each convolution
    conv_num: int
        number of parallel convolutions
    primary_capsule_layer: nn.ModuleList
        the convolution bank

    Methods
    -------
    forward(x)
        compute the primary capsule layer
    """

    def __init__(
        self, conv_in=2, feature_dimension=21*5, kernel_size=2, conv_num=5,base_num=21
    ):
        super().__init__()
        self.conv_out = feature_dimension // (conv_num * base_num)
        if self.conv_out < 1:
            raise ValueError(
                "feature_dimension must be at least conv_num * base_num"
            )
        self.conv_num = conv_num
        # padding='same' is only valid for stride 1 in PyTorch, so only the
        # dilation varies between branches.
        self.primary_capsule_layer = nn.ModuleList(
            [
                nn.Conv1d(
                    conv_in,
                    self.conv_out,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding='same',
                )
                for dilation in range(1, conv_num + 1)
            ]
        )

    def forward(self, x):
        """Run every branch on ``x`` and squash the channel-wise concatenation."""
        capsules = [conv(x) for conv in self.primary_capsule_layer]
        # Stack the branches along the channel axis: conv_num * conv_out channels.
        output_tensor = torch.cat(capsules, dim=1)
        return Squash()(output_tensor)

In [15]:
import torch.nn as nn

# Define the input and output channels
in_channels = 2
out_channels = 1

# Define the kernel size and dilation
kernel_size = 2

# Define the 1D dilated convolution layers (dilation rates 1..5).
# padding='same' keeps the sequence length unchanged for every branch.
padding = 'same'
conv1d_list = nn.ModuleList()
for dilation in range(1, 6):
    conv1d_list.append(
        nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding)
    )

# Define the input tensor
input_tensor = torch.randn(1, in_channels, 21 * 5)

# Apply the 1D dilated convolutions to the input tensor.
# Each convolution is evaluated once and the result reused for both the
# shape print and the collected list.
output_tensor_list = []
for conv1d in conv1d_list:
    branch_output = conv1d(input_tensor)
    print(branch_output.shape)
    output_tensor_list.append(branch_output)

# Concatenate the output tensors along the channel dimension
output_tensor = torch.cat(output_tensor_list, dim=1)

print(output_tensor.shape)
print(dynamic_routing(output_tensor).shape)

torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 1, 105])
torch.Size([1, 5, 105])
torch.Size([1, 1, 105])


In [None]:
class CapsLayer(nn.Module):
    """Capsule layer: a learned linear map followed by ``dynamic_routing``.

    Parameters
    ----------
    num_capsules, num_route_nodes, in_channels, out_channels : int
        Sizes of the four axes of the weight tensor ``W``, in that order.

    NOTE(review): ``W``'s trailing axes are ``(in_channels, out_channels)``,
    so ``W.matmul(x)`` only type-checks when the last axis of the input equals
    ``out_channels``, and the product then has ``in_channels`` rows. Confirm
    this orientation is intended.
    NOTE(review): the matmul result has more than three axes; confirm
    ``dynamic_routing`` is meant to receive it in that form.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels):
        super().__init__()
        weight_shape = (num_capsules, num_route_nodes, in_channels, out_channels)
        # Weights are drawn from a standard normal distribution.
        self.W = nn.Parameter(torch.randn(*weight_shape))

    def forward(self, x):
        """Project ``x`` with ``W`` and return the routed capsules."""
        # Insert a size-1 axis after the batch axis and another at the end so
        # that ``x`` broadcasts against ``W`` as a batch of column vectors.
        expanded = x.unsqueeze(1).unsqueeze(-1)
        return dynamic_routing(self.W.matmul(expanded))

In [None]:
class CapsNet(nn.Module):
    """Three-stage model: Conv1d -> PrimaryCapsuleLayer -> CapsLayer.

    NOTE(review): the wiring below looks inconsistent with the layer
    definitions in this notebook (see the inline notes); confirm the intended
    channel counts and capsule sizes before using this class.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 256 output channels, kernel width 9.
        self.conv_layer = nn.Conv1d(1, 256, 9)
        # NOTE(review): built with defaults (conv_in=2) although conv_layer
        # emits 256 channels — verify the expected input channel count.
        self.primary_layer = PrimaryCapsuleLayer()
        # NOTE(review): CapsLayer.__init__ takes (num_capsules,
        # num_route_nodes, in_channels, out_channels); the keywords used here
        # do not match that signature.
        self.caps_layer = CapsLayer(nclasses=10, out_channels_dim=16)

    def forward(self, x):
        """Pass ``x`` through the three stages in order and return the result."""
        x = self.conv_layer(x)
        x = self.primary_layer(x)
        x = self.caps_layer(x)
        return x

In [None]:
# Read the notebook's command-line options (currently only --conv-in).
# NOTE(review): inside a Jupyter kernel sys.argv carries the kernel's own
# flags — confirm parse_args tolerates them before relying on this cell.
args = parse_args()