In [2]:
import torch
import torch.nn as nn
from conformer import Conformer
from components import signalpreprocess, cdecoder
from tqdm import tqdm


# Frame geometry; `dim` is also passed to the Conformer as its input_dim.
sequence_length = 50
dim = 100

# Prefer the first CUDA device when present, otherwise run on CPU.
cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if cuda else torch.device('cpu')

class ConformerModel(nn.Module):
    """Conformer-based model mapping an input signal to normalized component weights.

    Pipeline: preprocessing -> Conformer encoder -> Decoder -> ComponentsNormalization.

    Uses the module-level ``dim``, ``sequence_length`` and ``device`` settings.
    """

    def __init__(self):
        super().__init__()
        # When True, inputs go through SignalPreprocess; when False (default) they are
        # reshaped directly into (-1, 1, sequence_length, dim) frames.
        self.require_pooling = False
        self.signalpreprocess = signalpreprocess.SignalPreprocess(blocks=1)
        self.encoder = Conformer(
                    input_dim=dim,
                    encoder_dim=32,
                    num_encoder_layers=4,
                    need_fc=False)
        self.decoder = Decoder()
        self.componentsnorm = ComponentsNormalization()

    def _preprocess(self, x):
        """Convert a raw input batch into encoder input placed on ``device``.

        With ``require_pooling`` the project's SignalPreprocess is applied. Otherwise
        ``x`` must contain a multiple of ``sequence_length * dim`` elements and is
        reshaped to ``(-1, 1, sequence_length, dim)``.
        """
        if self.require_pooling:
            x = self.signalpreprocess(x)
        else:
            # reshape (rather than view) also accepts non-contiguous inputs.
            x = x.reshape(-1, 1, sequence_length, dim)
        return x.to(device)

    def forward(self, x):
        """Run the full pipeline; returns decoder scores normalized along the last axis."""
        x = self._preprocess(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return self.componentsnorm(x)

    def encode(self, x):
        """Return the Conformer encoder output for ``x`` (no decoding)."""
        return self.encoder(self._preprocess(x))

    def last_state(self, x):
        """Return the Decoder output before component normalization.

        Uses the same preprocessing as ``forward``/``encode`` (honouring
        ``require_pooling``), so the result is exactly what ``forward`` passes
        to ``componentsnorm``.
        """
        x = self._preprocess(x)
        x = self.encoder(x)
        return self.decoder(x)
        
class Decoder(nn.Module):
    """Head mapping encoder features to per-component scores in (0, 1).

    Applies the project's ``cDecoder``, flattens the result into rows of
    ``in_features`` values, then a linear layer followed by a sigmoid.

    Args:
        in_features: width of each flattened row fed to the linear layer
            (default 320; presumably the cDecoder output size per sample — confirm).
        out_features: number of output scores per row (default 8).
    """

    def __init__(self, in_features=320, out_features=8):
        super().__init__()
        self.decoder = cdecoder.cDecoder()
        self.linear1 = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Return sigmoid scores of shape ``(-1, out_features)``."""
        x = self.decoder(x)
        # Flatten to rows matching the linear layer's input width; reshape (rather
        # than view) also works if cDecoder returns a non-contiguous tensor.
        x = x.reshape(-1, self.linear1.in_features)
        x = self.linear1(x)
        return torch.sigmoid(x)
    
class ComponentsNormalization(nn.Module):
    """Rescale each slice along the last axis so its entries sum to approximately one.

    Every slice is multiplied by ``1 / (sum + eps)``; the small ``eps`` keeps the
    reciprocal finite when a slice sums to zero.

    Args:
        eps: stabiliser added to each sum before taking the reciprocal (default 1e-5).
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Return ``x`` normalized along its last dimension; output shape equals input shape."""
        # keepdim=True lets the per-slice reciprocal broadcast back over the last
        # axis for any number of leading dims (including an empty batch).
        inv_sum = 1 / (torch.sum(x, -1, keepdim=True) + self.eps)
        return x * inv_sum
        

In [3]:
from torchinfo import summary

# Build the model on the configured device and show its layer/parameter table.
model = ConformerModel().to(device)
print(summary(model))

# Total count of scalar parameters (displayed as the cell output).
sum(p.nelement() for p in model.parameters())

Layer (type:depth-idx)                                                      Param #
ConformerModel                                                              --
├─SignalPreprocess: 1-1                                                     --
│    └─ModuleList: 2-1                                                      --
│    │    └─SignalPooling: 3-1                                              --
├─Conformer: 1-2                                                            --
│    └─ConformerEncoder: 2-2                                                --
│    │    └─Conv2dSubampling: 3-2                                           9,568
│    │    └─Sequential: 3-3                                                 24,608
│    │    └─ModuleList: 3-4                                                 106,072
│    └─Linear: 2-3                                                          --
│    │    └─Linear: 3-5                                                     2,816
├─Decoder: 1-3                  

155544

In [4]:
# Smoke test: one random signal of sequence_length * dim samples through the full model.
x = torch.rand(1, sequence_length * dim).to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    out = model(x)
# Decoder emits 8 component scores per sample by default.
assert out.shape == (1, 8), out.shape
out.shape

torch.Size([1, 8])