In [1]:
import torch
import torch.nn as nn

In [11]:
class FE(nn.Module):
    """1-D convolutional feature extractor.

    Pipeline::

        conv1 -> activation -> MaxPool1d -> Dropout
              -> [conv -> activation] * num_conv_layers

    All convolutions use stride 1 and share ``kernel_size``, ``dilation`` and
    ``padding``. ``conv1`` maps ``in_channels -> out_channels``; every
    convolution in ``convlist`` maps ``out_channels -> out_channels``.

    Args:
        in_channels: Channels of the input (dim 1 of ``x``).
        out_channels: Channels produced by every convolution.
        kernel_size: Convolution kernel width.
        dilation: Convolution dilation.
        padding: Zero-padding added to both sides by each convolution.
            ``dilation * (kernel_size - 1) // 2`` keeps each convolution
            length-preserving when ``kernel_size`` is odd.
        mp_kernel_size: Kernel size of the single ``MaxPool1d`` after
            ``conv1``. Its stride defaults to the kernel size, so the length
            is divided by it (floor).
        dropout: Drop probability, applied once right after the max-pool.
        num_conv_layers: Number of extra convolutions after the dropout.
        activation: Optional module applied after every convolution (e.g.
            ``nn.ReLU()``). The same instance is shared by all convolutions.
            ``None`` (default) applies no non-linearity, reproducing the
            original behaviour. Note that without one, the convolutions in
            ``convlist`` compose to a single affine map.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 dilation,
                 padding,
                 mp_kernel_size,
                 dropout,
                 num_conv_layers,
                 activation=None,
                 ):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=1,
                               dilation=dilation,
                               padding=padding,
                               )
        self.mp = nn.MaxPool1d(mp_kernel_size)
        self.dropout1 = nn.Dropout(dropout)
        self.convlist = nn.ModuleList([nn.Conv1d(in_channels=out_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=kernel_size,
                                                 stride=1,
                                                 dilation=dilation,
                                                 padding=padding,
                                                 ) for _ in range(num_conv_layers)])
        # nn.Identity has no parameters, so the state_dict keys are unchanged
        # when no activation is requested.
        self.activation = activation if activation is not None else nn.Identity()

    def forward(self, x):
        """Apply the extractor.

        Args:
            x: Input in ``nn.Conv1d`` layout, ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, out_channels, length_out)``. With
            length-preserving padding, ``length_out == length // mp_kernel_size``.
        """
        x = self.activation(self.conv1(x))
        x = self.dropout1(self.mp(x))
        for conv in self.convlist:
            x = self.activation(conv(x))
        return x

class IW_FE(nn.Module):
    """Two-branch 1-D feature extractor.

    Applies two ``FE`` stacks to the same input and concatenates their outputs
    along the channel dimension. The branches are identical (kernel 5,
    max-pool 2, dropout 0.5, three extra convolutions) except for dilation:
    1 for ``fe1`` and 4 for ``fe2``.

    Args:
        in_channels: Channels of the input signal. Default 2.
        out_channels: Channels produced by *each* branch. Default 128.

    Shape:
        input ``(batch, in_channels, length)`` ->
        output ``(batch, 2 * out_channels, length // 2)``.
    """

    def __init__(self, in_channels=2, out_channels=128):
        super().__init__()
        kernel_size = 5
        common = dict(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      mp_kernel_size=2,
                      dropout=0.5,
                      num_conv_layers=3)
        # padding = dilation * (kernel_size - 1) // 2 makes every stride-1 conv
        # length-preserving, so both branches end with the same length and can
        # be concatenated on the channel axis (gives 2 and 8 here).
        self.fe1 = FE(dilation=1, padding=1 * (kernel_size - 1) // 2, **common)
        self.fe2 = FE(dilation=4, padding=4 * (kernel_size - 1) // 2, **common)

    def forward(self, x):
        """Run both branches on ``x`` and stack their features channel-wise."""
        left = self.fe1(x)
        right = self.fe2(x)
        return torch.cat([left, right], dim=1)
        
# Stand-alone branch extractors for interactive shape / parameter checks.
# Their configuration matches the two branches built inside IW_FE.
torch.manual_seed(0)  # seed the global torch RNG so weight init (and later random draws) are reproducible

FE_KERNEL_SIZE = 5
FE_COMMON_KWARGS = dict(
    in_channels=2,
    out_channels=128,
    kernel_size=FE_KERNEL_SIZE,
    mp_kernel_size=2,
    dropout=0.5,
    num_conv_layers=3,
)


def same_padding(dilation, kernel_size=FE_KERNEL_SIZE):
    """Padding that keeps a stride-1 Conv1d length-preserving (odd ``kernel_size``)."""
    return dilation * (kernel_size - 1) // 2


fe1 = FE(dilation=1, padding=same_padding(1), **FE_COMMON_KWARGS)  # padding == 2
fe2 = FE(dilation=4, padding=same_padding(4), **FE_COMMON_KWARGS)  # padding == 8

In [16]:
a = nn.LSTM(
    input_size=128 * 120,  # 15360 features per time step
    hidden_size=100,
    num_layers=1,
    batch_first=True,
    bidirectional=True,
)
# torch.numel(t) is the functional form of t.numel(); summing it over every
# weight/bias tensor of the LSTM gives its total parameter count.
total_params = sum(map(torch.numel, a.parameters()))
print(f"Total parameters: {total_params}")

Total parameters: 12369600


In [4]:
# NOTE(review): this cell's saved output comes from execution count 4, i.e.
# before `a` was (re)defined in In [16]. Re-running it now prints the same
# count as the LSTM cell above. It duplicates that cell and could be removed.
total_params = sum(map(torch.numel, a.parameters()))
print(f"Total parameters: {total_params}")

Total parameters: 5099200


In [5]:
# NOTE(review): the saved output is from execution count 5, i.e. before FE and
# fe1 were (re)defined in In [11]. For the current fe1 config the count is
#   conv1: 2*128*5 + 128 = 1408, each of 3 inner convs: 128*128*5 + 128 = 82048
#   -> 1408 + 3 * 82048 = 247552. Re-run to refresh.
total_params = sum(map(torch.numel, fe1.parameters()))
print(f"Total parameters: {total_params}")

Total parameters: 252032


In [23]:
# Probe: run both branches on a dummy input (batch 1, 2 channels, 120 samples)
# and stack their features along the channel axis.
# fe1/fe2 are still in training mode, so dropout is active and the values are
# random. Only the shape is meaningful here.
x = torch.randn(1, 2, 120)
u1 = torch.cat([fe2(x), fe1(x)], dim=1)
# Each branch: 128 channels, length 120 // 2 (max-pool) = 60.
assert u1.shape == (1, 256, 60), u1.shape
# TODO: wire u1 into the LSTM `a`. Its input_size (128 * 120) equals 256 * 60,
# the width of u1.flatten(1). Note that nn.LSTM treats a 2-D tensor as an
# unbatched (seq_len, features) input, so decide the intended sequence axis first.

In [24]:
u1.size()  # torch.Size of the concatenated branch features (same object type as u1.shape)

torch.Size([1, 256, 60])