In [1]:
import torch
from torch import nn
from torchinfo import summary
from positional_encodings.torch_encodings import PositionalEncoding1D
from einops import rearrange, repeat

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
class Classification_block(nn.Module):
    def __init__(self):
        super(Classification_block, self).__init__()
        self.linear_1 = nn.Linear(768, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(256, 2)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        return self.softmax(self.linear_2(self.dropout(self.relu(self.linear_1(x)))))
    
class Transformer(nn.Module):
    def __init__(self,d_in=1365,d_model=768,nhead=6,num_layers=3):
        super(Transformer, self).__init__()
        self.linear_projection = nn.Linear(d_in, 768)
        self.positionEncoding=PositionalEncoding1D(d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.classification_block = Classification_block()
        self.cls_token= nn.Parameter(torch.randn(1,d_model))
    def forward(self, x):
        x= rearrange(x, 'a b c d -> a d (b c)')
        x = self.linear_projection(x)
        cls_tokens = repeat(self.cls_token, 'n d -> b n d', b=x.size(0))
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.positionEncoding(x)
        x = self.transformer_encoder(x)
        x = self.classification_block(x[:,0,:])
        return x

In [4]:
summary(Transformer(d_in=65*22,nhead=6,num_layers=3), input_size=(10, 22, 65,9))

  return torch._transformer_encoder_layer_fwd(


Layer (type:depth-idx)                        Output Shape              Param #
Transformer                                   [10, 2]                   5,514,752
├─Linear: 1-1                                 [10, 9, 768]              1,099,008
├─PositionalEncoding1D: 1-2                   [10, 10, 768]             --
├─TransformerEncoder: 1-3                     [10, 10, 768]             --
│    └─ModuleList: 2-1                        --                        --
│    │    └─TransformerEncoderLayer: 3-1      [10, 10, 768]             5,513,984
│    │    └─TransformerEncoderLayer: 3-2      [10, 10, 768]             5,513,984
│    │    └─TransformerEncoderLayer: 3-3      [10, 10, 768]             5,513,984
├─Classification_block: 1-4                   [10, 2]                   --
│    └─Linear: 2-2                            [10, 256]                 196,864
│    └─ReLU: 2-3                              [10, 256]                 --
│    └─Dropout: 2-4                           [10, 256]

In [15]:
summary(Transformer(d_in=65*22,nhead=8,num_layers=8), input_size=(10, 22, 65,21))