# In [1]:
import torch
import torch.nn as nn


class LSTMNet(nn.Module):
    """Linear projection stack feeding a multi-layer LSTM, followed by a Linear head.

    The input is first passed through ``fc_layers`` Linear layers
    (``linear_input_dim`` -> ``linear_hidden_dim`` -> ... -> ``linear_output_dim``).
    Their output goes through an ``nn.LSTM`` built with ``batch_first=True``.
    Finally, the LSTM output at every timestep is mapped to ``output_dim`` by
    one more Linear layer.

    Args:
        linear_input_dim: Feature size of the raw input.
        linear_hidden_dim: Width of the intermediate Linear layers (used only
            when ``fc_layers >= 2``).
        linear_output_dim: Feature size produced by the last Linear layer and
            fed into the LSTM (used only when ``fc_layers >= 1``).
        hidden_dim: LSTM hidden-state size.
        output_dim: Feature size of the final per-timestep output.
        num_layers: Number of stacked LSTM layers.
        fc_layers: Number of Linear layers before the LSTM. With 0 the raw
            input goes straight into the LSTM. With 1 a single Linear layer
            maps ``linear_input_dim`` directly to ``linear_output_dim``.

    Raises:
        ValueError: If ``fc_layers`` is negative.
    """

    def __init__(self, linear_input_dim, linear_hidden_dim, linear_output_dim, hidden_dim, output_dim, num_layers, fc_layers):
        super().__init__()
        if fc_layers < 0:
            raise ValueError(f"fc_layers must be >= 0, got {fc_layers}")
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Layer widths: the first layer always reads linear_input_dim and the
        # last always emits linear_output_dim, with linear_hidden_dim in
        # between. This keeps the LSTM input size consistent for any
        # fc_layers, including the 0- and 1-layer edge cases.
        if fc_layers == 0:
            widths = [linear_input_dim]
        else:
            widths = ([linear_input_dim]
                      + [linear_hidden_dim] * (fc_layers - 1)
                      + [linear_output_dim])

        # Fully connected layers applied to the input before the LSTM.
        # NOTE(review): no nonlinearity is applied between these layers, so
        # for fc_layers >= 2 the stack collapses to a single affine map.
        # Insert an activation in forward() if that is not intended.
        self.fc_layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])]
        )

        # LSTM whose input size is whatever the projection stack emits.
        self.lstm = nn.LSTM(widths[-1], hidden_dim,
                            num_layers, batch_first=True)

        # Output head applied to the LSTM output at every timestep.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the projection stack, the LSTM and the output head.

        Args:
            x: Input of shape ``(batch, seq_len, linear_input_dim)``. A 2-D
                unbatched ``(seq_len, linear_input_dim)`` input is also
                accepted on PyTorch versions whose ``nn.LSTM`` supports
                unbatched input.

        Returns:
            Tensor of shape ``(batch, seq_len, output_dim)``, or
            ``(seq_len, output_dim)`` for unbatched input.
        """
        for layer in self.fc_layers:
            x = layer(x)

        # Leaving out (h0, c0) makes nn.LSTM start from zero hidden and cell
        # states. Those zeros automatically match x's device, dtype and
        # batched/unbatched layout. A hand-built float32 zeros tensor would
        # break float64/half models and 2-D input.
        out, _ = self.lstm(x)

        return self.fc(out)


# Usage example
linear_input_dim = 8
linear_hidden_dim = 15
linear_output_dim = 10
hidden_dim = 20
output_dim = 3
num_layers = 2
fc_layers = 2  # number of fully connected layers before LSTM
model = LSTMNet(linear_input_dim, linear_hidden_dim, linear_output_dim, hidden_dim,
                output_dim, num_layers, fc_layers)

# 32 randomly generated sequences, each of length 5; every element is a
# linear_input_dim-dimensional (here 8-dimensional) feature vector.
x = torch.randn(32, 5, linear_input_dim)

# output has shape (32, 5, output_dim): one prediction per timestep.
output = model(x)
