In [7]:
from torch import nn

class PatchEncoder(nn.Module):
    """Embed patch features and patch positions with separate linear maps and sum them.

    Args:
        input_shape: size of the last dimension of the patch tensor.
        input_position_shape: size of the last dimension of the position tensor.
        projection_dim: size of the last dimension of the returned embedding.
    """

    def __init__(self, input_shape, input_position_shape, projection_dim):
        super().__init__()
        self.projection = nn.Linear(in_features=input_shape, out_features=projection_dim)
        self.position_embedding = nn.Linear(in_features=input_position_shape, out_features=projection_dim)

    def forward(self, patch, position):
        """Return ``projection(patch) + position_embedding(position)``.

        Both linear maps act on the last dimension only, so ``patch`` and
        ``position`` must agree (or broadcast) on their leading dimensions.
        """
        patch_embedding = self.projection(patch)
        positional_embedding = self.position_embedding(position)
        return patch_embedding + positional_embedding


class SCSTransformer(nn.Module):
    """Transformer encoder over embedded patches with two linear output heads.

    Patches and their positions are embedded by ``PatchEncoder``, passed through
    a stack of ``nn.TransformerEncoderLayer`` blocks, and the representation of
    the first token (index 0 along the sequence axis) is layer-normalised,
    projected, and fed to the ``pos`` and ``binary`` heads.

    Args:
        class_num: output width of the ``pos`` head.
        input_shape: size of the last dimension of ``inputs``.
        input_position_shape: size of the last dimension of ``inputs_positions``.
        projection_dim: embedding width (``d_model`` of the encoder).
        num_heads: number of attention heads per encoder layer.
        transformer_units: ``dim_feedforward`` of each encoder layer.
        transformer_layers: number of stacked encoder layers.
        mlp_head_units: width of the feature vector fed to both heads.
    """

    def __init__(self, class_num, input_shape, input_position_shape, projection_dim,
                 num_heads, transformer_units, transformer_layers, mlp_head_units):
        super().__init__()
        self.patch_encoder = PatchEncoder(input_shape, input_position_shape, projection_dim)

        # batch_first=True is required: forward() passes (batch, seq, dim)
        # tensors and reads token 0 with x[:, 0, :]. With the default
        # (seq, batch, dim) layout, self-attention would run across the batch
        # axis and mix unrelated samples. Needs a PyTorch version that supports
        # the `batch_first` argument.
        # NOTE(review): nn.TransformerEncoder clones the layer it is given, so
        # the parameters held by this attribute are presumably not used in
        # forward(); the attribute is kept so existing state_dict keys still load.
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=projection_dim,
            nhead=num_heads,
            dim_feedforward=transformer_units,
            dropout=0.1,
            layer_norm_eps=1e-6,
            batch_first=True,
        )

        self.transformer_encoder = nn.TransformerEncoder(self.transformer_layer, transformer_layers)

        # Reduce the encoded sequence to a [batch_size, projection_dim] tensor.
        self.norm_cls = nn.LayerNorm(projection_dim, eps=1e-6)
        self.feature_cls = nn.Sequential(
            nn.Linear(projection_dim, mlp_head_units),
            nn.Dropout(0.5)
        )
        # Classifier heads sharing the same feature vector.
        self.pos = nn.Linear(mlp_head_units, class_num)
        self.binary = nn.Linear(mlp_head_units, 1)

    def forward(self, inputs, inputs_positions):
        """Run the model on a batch of patch sequences.

        Args:
            inputs: tensor of shape (batch, seq, input_shape).
            inputs_positions: tensor of shape (batch, seq, input_position_shape).

        Returns:
            dict with ``'pos'`` of shape (batch, class_num) and ``'binary'`` of
            shape (batch, 1). Both are raw linear outputs; no activation is
            applied here.
        """
        x = self.patch_encoder(inputs, inputs_positions)
        x = self.transformer_encoder(x)
        # Normalise exactly once. LayerNorm acts per token, so slicing token 0
        # first gives the same result as normalising the whole sequence and
        # then slicing, without the redundant second pass the code used to do.
        first_token = self.norm_cls(x[:, 0, :])
        feature = self.feature_cls(first_token)
        return {'pos': self.pos(feature), 'binary': self.binary(feature)}

In [8]:
import torch
# Random smoke-test batch: 10 samples, 50 patches each, 2000 features per patch,
# plus a 2-component position for every patch.
# NOTE(review): the name `input` shadows the Python builtin; later cells refer to
# it by this name, so it is left unchanged here.
input = torch.randn(10, 50, 2000)
position = torch.randn(10, 50, 2)

In [9]:
# Instantiate the model with dimensions matching the dummy batch
# (2000 features per patch, 2 position components).
model = SCSTransformer(
    class_num=16,
    input_shape=2000,
    input_position_shape=2,
    projection_dim=512,
    num_heads=8,
    transformer_units=512,
    transformer_layers=6,
    mlp_head_units=512,
)

In [13]:
# Smoke test: shape of the 'pos' head output for the dummy batch.
model(input, position)['pos'].shape

torch.Size([10, 16])

In [15]:
# Smoke test: shape of the 'binary' head output for the dummy batch.
model(input, position)['binary'].shape

torch.Size([10, 1])

In [14]:
# Cross-entropy criterion; it is only constructed here and not applied to any
# tensor in the cells shown.
loss = nn.CrossEntropyLoss()

In [None]:
# NOTE(review): this rebinds `input` (previously the (10, 50, 2000) model batch)
# to an unrelated (3, 5) tensor that tracks gradients. Re-running the model
# cells after this one would pass this tensor instead of the original batch;
# consider a distinct name.
input = torch.randn(3, 5, requires_grad=True)

In [19]:
# Scratch cell: unpacks a two-element list into x and y.
x,y = [1,2]