# LSTM implementation in PyTorch

In [2]:
import numpy as np
import torch 
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import random
import copy
import matplotlib.pyplot as plt
import pandas as pd


In [3]:
class LSTM(nn.Module):
    """Single-step LSTM cell assembled from ``nn.Linear`` layers.

    Every call to :meth:`forward` advances the cell by one time step: the
    input is concatenated with the previous hidden state, and that combined
    vector drives the forget gate, the input ("creation") gate, the candidate
    values and the output gate.

    Unlike the textbook formulation, the cell state and the hidden state may
    have different widths; ``W_CS_to_output`` projects the updated cell state
    to ``hidden_width`` before it is squashed and gated.

    Args:
        cell_state_width: Number of features in the cell state.
        hidden_width: Number of features in the hidden state / output.
        input_width: Number of features in each input vector.
    """

    def __init__(self, cell_state_width, hidden_width, input_width):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the layers below cannot be registered.
        super().__init__()

        self.cell_state_width = cell_state_width
        self.hidden_width = hidden_width
        self.input_width = input_width

        # All gates read the concatenation [input, previous hidden state].
        combined_width = input_width + hidden_width

        # Forget gate: how much of the previous cell state is kept.
        self.W_forget = nn.Linear(in_features=combined_width,
                                  out_features=cell_state_width)

        # Input gate: how much of the candidate values is written.
        self.W_i_update = nn.Linear(in_features=combined_width,
                                    out_features=cell_state_width)

        # Candidate values to be written into the cell state.
        self.W_i_update_values = nn.Linear(in_features=combined_width,
                                           out_features=cell_state_width)

        # Projection of the updated cell state to the hidden width.
        self.W_CS_to_output = nn.Linear(in_features=cell_state_width,
                                        out_features=hidden_width)

        # Output gate: which parts of the projected cell state are emitted.
        self.W_combined_to_output = nn.Linear(in_features=combined_width,
                                              out_features=hidden_width)

    def forward(self, X, cell_state=None, hidden_state=None):
        """Advance the cell by one time step.

        Args:
            X: Input tensor of shape ``(batch, input_width)``.
            cell_state: Previous cell state of shape
                ``(batch, cell_state_width)``. Defaults to zeros matching
                the batch size, device and dtype of ``X``.
            hidden_state: Previous hidden state of shape
                ``(batch, hidden_width)``. Defaults to zeros matching the
                batch size, device and dtype of ``X``.

        Returns:
            Tuple ``(new_hidden_state, new_cell_state)`` with shapes
            ``(batch, hidden_width)`` and ``(batch, cell_state_width)``.
        """
        # Size default states from the input so batched inputs work.
        batchsize = X.shape[0]
        if cell_state is None:
            cell_state = self.get_blank_cell_state(
                batchsize, device=X.device, dtype=X.dtype)
        if hidden_state is None:
            hidden_state = self.get_blank_hidden_state(
                batchsize, device=X.device, dtype=X.dtype)

        # torch.cat takes a *sequence* of tensors; join along the features.
        combined_state = torch.cat((X, hidden_state), dim=-1)

        forget_factor = torch.sigmoid(self.W_forget(combined_state))
        creation_gate_gating = torch.sigmoid(self.W_i_update(combined_state))
        creation_gate_values = torch.tanh(
            self.W_i_update_values(combined_state))

        # Out-of-place update: forget part of the old state, add new content.
        new_cell_state = (cell_state * forget_factor
                          + creation_gate_gating * creation_gate_values)

        cell_state_output_contrib = torch.tanh(
            self.W_CS_to_output(new_cell_state))
        combined_state_output_contrib = torch.sigmoid(
            self.W_combined_to_output(combined_state))

        new_hidden_state = (cell_state_output_contrib
                            * combined_state_output_contrib)

        return new_hidden_state, new_cell_state

    def get_blank_hidden_state(self, batchsize=1, device=None, dtype=None):
        """Return an all-zero hidden state of shape ``(batchsize, hidden_width)``.

        ``device`` and ``dtype`` are forwarded to ``torch.zeros``; ``None``
        selects the PyTorch defaults.
        """
        return torch.zeros(batchsize, self.hidden_width,
                           device=device, dtype=dtype)

    def get_blank_cell_state(self, batchsize=1, device=None, dtype=None):
        """Return an all-zero cell state of shape ``(batchsize, cell_state_width)``.

        ``device`` and ``dtype`` are forwarded to ``torch.zeros``; ``None``
        selects the PyTorch defaults.
        """
        return torch.zeros(batchsize, self.cell_state_width,
                           device=device, dtype=dtype)

Source(s):
- https://colah.github.io/posts/2015-08-Understanding-LSTMs/