In [189]:
# Third Party
import torch
import torch.nn as nn


############
# COMPONENTS
############



class LSTM(nn.Module):
    def __init__(self, input_dim, out_dim=1, h_dims=[], h_activ=nn.Tanh(),
                 out_activ=nn.Sigmoid()):
        super(LSTM, self).__init__()
        
        layer_dims = [input_dim] + h_dims + [out_dim]
        self.num_layers = len(layer_dims) - 1
        self.layers = nn.ModuleList()
        
        for index in range(self.num_layers-1):
            layer = nn.LSTM(
                input_size=layer_dims[index],
                hidden_size=layer_dims[index + 1],
                num_layers=1,
                batch_first=True
            )
            self.layers.append(layer)

        self.h_activ, self.out_activ = h_activ, out_activ
        
        self.linear=nn.Linear(layer_dims[-2],out_dim)
        
    def forward(self, x):
        
 
        for index, layer in enumerate(self.layers):
          
            x, (h_n, c_n) = layer(x)
            x = self.h_activ(x)
            
        x=self.linear(x[:,-1])
        print(x.shape)
        x=self.out_activ(x)
        print(x.shape)
                

        return x
    


In [194]:
m=LSTM(58, 1, h_dims=[16,13,5,3], h_activ=nn.Tanh(),
                 out_activ=nn.Sigmoid())

In [195]:
m

LSTM(
  (layers): ModuleList(
    (0): LSTM(58, 16, batch_first=True)
    (1): LSTM(16, 13, batch_first=True)
    (2): LSTM(13, 5, batch_first=True)
    (3): LSTM(5, 3, batch_first=True)
  )
  (h_activ): Tanh()
  (out_activ): Sigmoid()
  (linear): Linear(in_features=3, out_features=1, bias=True)
)

In [196]:
x=torch.randn(32,16,58)

In [197]:
m(x).shape

torch.Size([32, 1])
torch.Size([32, 1])


torch.Size([32, 1])