In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  _dtype_to_storage = {data_type(0).dtype: data_type for data_type in _storages}


## Network architecture

<img src="att2.jpg" width="480"/>

In [13]:
class Attention(torch.nn.Module):
    """Learned attention that re-weights the segments of a 5-D feature map.

    Each segment's trailing features (dims 3 onward, flattened) are scored
    with a learned vector. An optional per-segment bias is added, the scores
    are soft-maxed along ``attn_dim``, and the input is scaled by the result.

    Args:
        feature_dim: shape of the scoring weight, ``[F, 1]``, where ``F`` is the
            product of the input's dims from 3 onward (``[6*46, 1]`` in the
            model below).
        seg_dim: size of the segment axis; sets the per-segment bias shape
            ``[seg_dim, 1]``.
        bias: if True, add a learnable per-segment bias to the scores.
        attn_dim: axis of the score tensor ``[..., seg_dim, 1]`` to normalize
            over. Defaults to ``-2``, the segment axis.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, feature_dim, seg_dim, bias=True, attn_dim=-2, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.bias = bias
        self.feature_dim = feature_dim
        self.seg_dim = seg_dim
        # Collapse every axis from dim 3 onward into a single feature axis.
        self.flatten = nn.Flatten(start_dim=3, end_dim=-1)
        # Normalize across segments. Softmax over dim 0 (the batch axis) would
        # couple independent samples, and with a batch of 1 it would make every
        # weight exactly 1.
        self.softmax = nn.Softmax(dim=attn_dim)

        weight = torch.zeros(feature_dim)  # scoring vector, e.g. [F, 1]
        # Kaiming *uniform* initialization (in-place on `weight`).
        nn.init.kaiming_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            # One learnable bias per segment; broadcasts against [..., seg_dim, 1].
            self.b = nn.Parameter(torch.zeros(seg_dim, 1))

    def forward(self, x):
        """Scale ``x`` by per-segment attention weights.

        Args:
            x: 5-D tensor whose dim 2 is the segment axis and whose dims 3+
               flatten to ``F`` features, e.g. ``[25, 64, 100, 6, 46]`` in the
               model below.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # [B, C, S, H, W] -> [B, C, S, H*W] -> [B, C, S, 1]
        eij = torch.matmul(self.flatten(x), self.weight)
        if self.bias:
            eij = eij + self.b

        a = self.softmax(eij)       # [B, C, S, 1], normalized along attn_dim
        a = torch.unsqueeze(a, -1)  # [B, C, S, 1, 1], broadcasts over H and W

        return x * a                # [B, C, S, H, W]

class Model(torch.nn.Module):
    """Conv3d feature extractor + attention + linear head, one scalar per item.

    Input shape is ``[batch, 1, n_segments, height, width]`` (the demo cell
    uses ``[25, 1, 100, 128, 768]``).

    Pipeline:
        - Four stages of Conv3d(kernel 1x3x3, no padding) -> MaxPool3d(1x2x2)
          -> ReLU. These shrink only the last two axes; the size-1 kernel depth
          leaves the segment axis unchanged.
        - An ``Attention`` layer re-weights the resulting feature maps.
        - Two linear layers reduce the result to shape ``[batch, 1]``.

    Args:
        n_segments: size of input dim 2. Default 100.
        height: size of input dim 3. Default 128.
        width: size of input dim 4. Default 768.

    Raises:
        ValueError: if ``height``/``width`` are too small to survive the four
            conv + pool stages.
    """

    def __init__(self, n_segments=100, height=128, width=768):
        super(Model, self).__init__()
        # 1st conv layer (+ a single max-pool module reused after every conv)
        self.conv1 = nn.Conv3d(1, 8, kernel_size=[1, 3, 3], stride=1, padding=0)
        self.mp = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
        # 2nd conv layer
        self.conv2 = nn.Conv3d(8, 16, kernel_size=[1, 3, 3], stride=1, padding=0)
        # 3rd conv layer
        self.conv3 = nn.Conv3d(16, 32, kernel_size=[1, 3, 3], stride=1, padding=0)
        # 4th conv layer
        self.conv4 = nn.Conv3d(32, 64, kernel_size=[1, 3, 3], stride=1, padding=0)

        # Each stage maps a spatial size s -> (s - 2) // 2: the 3x3 unpadded
        # conv removes 2, then the 2x2 pool halves (flooring).
        # 128 x 768 -> 6 x 46 for the defaults.
        out_h, out_w = height, width
        for _ in range(4):
            out_h, out_w = (out_h - 2) // 2, (out_w - 2) // 2
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"height={height}, width={width} too small for four conv+pool stages"
            )
        n_features = out_h * out_w  # 6 * 46 = 276 for the defaults

        self.attention_layer = Attention(feature_dim=[n_features, 1], seg_dim=n_segments)

        self.flatten1 = nn.Flatten(start_dim=3, end_dim=-1)
        self.dense1 = nn.Linear(n_features, 1)
        self.flatten2 = nn.Flatten(start_dim=1, end_dim=-1)
        self.dense2 = nn.Linear(64 * n_segments, 1)  # 64 = conv4 out channels

    def forward(self, x):
        """Map ``[B, 1, S, H, W]`` to ``[B, 1]``.

        Shapes below are for the default ``[25, 1, 100, 128, 768]`` input.
        """
        x = self.conv1(x)  # [25, 1, 100, 128, 768] -> [25, 8, 100, 126, 766]
        x = self.mp(x)     # [25, 8, 100, 126, 766] -> [25, 8, 100, 63, 383]
        x = F.relu(x)

        x = self.conv2(x)  # [25, 8, 100, 63, 383] -> [25, 16, 100, 61, 381]
        x = self.mp(x)     # [25, 16, 100, 61, 381] -> [25, 16, 100, 30, 190]
        x = F.relu(x)

        x = self.conv3(x)  # [25, 16, 100, 30, 190] -> [25, 32, 100, 28, 188]
        x = self.mp(x)     # [25, 32, 100, 28, 188] -> [25, 32, 100, 14, 94]
        x = F.relu(x)

        x = self.conv4(x)  # [25, 32, 100, 14, 94] -> [25, 64, 100, 12, 92]
        x = self.mp(x)     # [25, 64, 100, 12, 92] -> [25, 64, 100, 6, 46]
        x = F.relu(x)

        x = self.attention_layer(x)  # shape unchanged: [25, 64, 100, 6, 46]

        x = self.flatten1(x)  # [25, 64, 100, 6, 46] -> [25, 64, 100, 276]
        x = self.dense1(x)    # [25, 64, 100, 276] -> [25, 64, 100, 1]
        x = self.flatten2(x)  # [25, 64, 100, 1] -> [25, 6400]
        x = self.dense2(x)    # [25, 6400] -> [25, 1]

        return x



In [3]:
# Seed so the random demo input and the model's initial weights are reproducible.
torch.manual_seed(0)

# Dummy input with dims [batch, channels, segments, ?, ?].
# NOTE(review): the original comment named only 4 of the 5 dims
# ("segments, channels, tokens, embeddings"). Dim 2 is the segment axis
# (Attention seg_dim=100); confirm the meaning of the last two dims (128, 768).
X = torch.randn(25, 1, 100, 128, 768)

model = Model()
model.eval()

# Call the module rather than .forward() so registered hooks run. no_grad
# avoids keeping autograd buffers for the very large intermediate activations.
with torch.no_grad():
    out = model(X)

print('Output shape: ', out.shape)

Output shape:  torch.Size([25, 1])


In [4]:
# Value of the model output for the first item along dim 0.
first_output = out[0].item()
print('Output: ', first_output)

Output:  -0.03974542021751404
