In [None]:
class TestModel(nn.Module):
    """Two-branch bidirectional-LSTM classifier over sequence + tabular features.

    Each of the two sequence inputs is flattened into a single time step and
    run through its own bidirectional LSTM. The LSTM output is concatenated
    with that branch's baseline feature and reduced to one value by a linear
    layer. The two branch values are concatenated with the tabular features
    and mapped to ``output_dim`` log-probabilities.

    Every constructor argument defaults to the previously hard-coded value, so
    ``TestModel()`` builds exactly the original architecture (same submodule
    names and parameter shapes, hence the same ``state_dict`` layout).

    Args:
        input_size: Number of features in each flattened sequence input.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        output_dim: Number of output classes.
        baseline_dim: Width of each baseline tensor (``f1`` and ``f2``).
        tabular_dim: Combined width of the tabular tensors ``f3`` and ``f4``.
    """

    def __init__(self, input_size=70, hidden_size=8, num_layers=2, output_dim=2,
                 baseline_dim=1, tabular_dim=2):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.baseline_dim = baseline_dim
        self.tabular_dim = tabular_dim

        self.lstm_feature_1 = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True,
                                      bidirectional=True)
        self.lstm_feature_2 = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True,
                                      bidirectional=True)

        # Bidirectional LSTM output is 2 * hidden_size wide; the baseline
        # feature is appended before reducing each branch to one value.
        branch_in = self.hidden_size * 2 + self.baseline_dim
        self.fc_feature_1 = nn.Linear(branch_in, 1)
        self.fc_feature_2 = nn.Linear(branch_in, 1)

        # One scalar per sequence branch plus all tabular features.
        self.fc = nn.Linear(2 + self.tabular_dim, self.output_dim)

    def forward(self, f, device=None):
        """Compute class log-probabilities.

        Args:
            f: Six-tuple ``(x_f1, f1, x_f2, f2, f3, f4)``:
                ``x_f1`` / ``x_f2`` are the feature-1 / feature-2 sequences
                (flattened to ``input_size`` features per sample),
                ``f1`` / ``f2`` their baselines (``baseline_dim`` wide), and
                ``f3`` / ``f4`` tabular features (``tabular_dim`` wide in total).
                All tensors share the batch dimension 0.
            device: Device for the initial LSTM states. When ``None`` (the
                default), each branch uses its own input's device, so the
                states always live next to the data they are combined with.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with log-probabilities
            (``log_softmax`` over dim 1).
        """
        x_f1, f1, x_f2, f2, f3, f4 = f

        x_f1 = self._encode_sequence(self.lstm_feature_1, x_f1, device)
        x_f2 = self._encode_sequence(self.lstm_feature_2, x_f2, device)

        x_f1 = self.fc_feature_1(torch.cat((x_f1, f1), 1))
        x_f2 = self.fc_feature_2(torch.cat((x_f2, f2), 1))

        x = torch.cat((x_f1, x_f2, f3, f4), 1)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def _encode_sequence(self, lstm, x, device):
        """Run one sequence input through ``lstm`` as a single time step."""
        # Treat all features of a sample as one time step:
        # (batch, 1, input_size) with batch_first=True. reshape (not view)
        # also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], 1, -1)
        h0, c0 = self.init_hidden(x, device)
        # nn.LSTM returns (output, (h_n, c_n)). With a single time step the
        # output holds both directions of the last layer: (batch, 1, 2*hidden).
        output, _ = lstm(x, (h0, c0))
        return output.reshape(output.shape[0], -1)

    def init_hidden(self, x, device=None):
        """Return zero initial ``(h0, c0)`` states for a batch-first input.

        Args:
            x: Batch-first LSTM input; only its batch size, device and dtype
                are used.
            device: Target device for the states; defaults to ``x.device``.

        Returns:
            Tuple ``(h0, c0)``, each of shape
            ``(num_layers * 2, batch, hidden_size)`` (``* 2`` for the two
            directions), with ``x``'s dtype.
        """
        if device is None:
            device = x.device
        shape = (self.num_layers * 2, x.size(0), self.hidden_size)
        # Allocate directly on the target device with the input's dtype so
        # the states match the LSTM input (e.g. float64 models work too).
        h0 = torch.zeros(shape, device=device, dtype=x.dtype)
        c0 = torch.zeros(shape, device=device, dtype=x.dtype)
        return h0, c0