## BASIC LSTM USING PYTORCH

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim 

In [2]:
class TimeSeriesRNN(nn.Module):
    """LSTM encoder followed by a small MLP head for sequence-to-one prediction.

    The LSTM runs over the whole input sequence. Only its output at the final
    time step is passed through the fully-connected head.

    Args:
        n_lstm_layers: Number of stacked LSTM layers.
        n_input: Number of features per time step in the input.
        n_hidden: Hidden-state size of each LSTM layer.
        n_output: Number of values produced per input sequence.
    """

    def __init__(self, n_lstm_layers, n_input, n_hidden, n_output):
        super().__init__()

        self.n_lstm_layers = n_lstm_layers
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output

        # batch_first=True: inputs/outputs are laid out as (batch, seq_len, features).
        self.LSTM_layer = nn.LSTM(
            self.n_input, self.n_hidden, self.n_lstm_layers, batch_first=True
        )

        # Two-layer MLP head: n_hidden -> 2 * n_hidden -> n_output.
        self.fc_layers = nn.Sequential(
            nn.Linear(self.n_hidden, int(self.n_hidden * 2)),
            nn.ReLU(),
            nn.Linear(int(self.n_hidden * 2), self.n_output),
        )

    def forward(self, x):
        """Map a batch of sequences to one output vector per sequence.

        Args:
            x: Tensor of shape (batch, seq_len, n_input).

        Returns:
            Tensor of shape (batch, n_output).
        """
        # Zero initial hidden and cell states. Build them from ``x`` so they get
        # the input's device and dtype (works on GPU and in float64). They are
        # constants, so no gradient tracking is needed.
        state_shape = (self.n_lstm_layers, x.shape[0], self.n_hidden)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)

        output, _ = self.LSTM_layer(x, (h0, c0))

        # output is (batch, seq_len, n_hidden); keep only the last time step.
        output = output[:, -1, :]
        return self.fc_layers(output)