In [1]:
import torch
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import math

In [None]:
class LSTM(torch.nn.Module):
    """Single-layer LSTM written out gate by gate, with a linear head.

    Each gate owns an input projection ``U*`` of shape
    ``(input_dim, hidden_dim)``, a recurrent projection ``V*`` of shape
    ``(hidden_dim, hidden_dim)`` and a bias ``b*`` of shape ``(hidden_dim,)``.
    The suffixes are ``i`` (input gate), ``f`` (forget gate), ``c`` (candidate
    cell value) and ``o`` (output gate).  ``linear`` maps the final hidden
    state to ``output_dim`` logits.

    Args:
        input_dim: Size of the feature vector at every timestep.
        hidden_dim: Size of the hidden and cell state.
        output_dim: Number of logits produced by the linear head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # NOTE: the registration order below determines the order in which
        # init_weights() draws random numbers, so keep it stable.
        self.Ui = torch.nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.Vi = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.bi = torch.nn.Parameter(torch.empty(hidden_dim))

        self.Uf = torch.nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.Vf = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.bf = torch.nn.Parameter(torch.empty(hidden_dim))

        self.Uc = torch.nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.Vc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.bc = torch.nn.Parameter(torch.empty(hidden_dim))

        self.Uo = torch.nn.Parameter(torch.empty(input_dim, hidden_dim))
        self.Vo = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.bo = torch.nn.Parameter(torch.empty(hidden_dim))

        self.linear = torch.nn.Linear(hidden_dim, output_dim)

        self.init_weights()

    def init_weights(self):
        """Fill every parameter with U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)).

        This also overwrites the default initialisation of ``self.linear``,
        because it iterates over ``self.parameters()``.
        """
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``.
            init_states: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(batch, hidden_dim)``.  Both default to zeros.

        Returns:
            A tuple ``(logits, hidden_seq, (h_t, c_t))`` where ``logits`` has
            shape ``(batch, output_dim)`` and is computed from the last hidden
            state, ``hidden_seq`` has shape ``(batch, seq_len, hidden_dim)``,
            and ``h_t`` / ``c_t`` are the final hidden and cell states.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape (batch, seq_len, input_dim), "
                f"got {tuple(x.shape)}"
            )
        bs = x.size(0)

        if init_states is None:
            # Allocate directly on x's device/dtype instead of creating on the
            # default device and copying.
            h_t = torch.zeros(
                bs, self.hidden_dim, device=x.device, dtype=x.dtype
            )
            c_t = torch.zeros(
                bs, self.hidden_dim, device=x.device, dtype=x.dtype
            )
        else:
            h_t, c_t = init_states

        hidden_seq = []
        for x_t in x.unbind(dim=1):
            i_t = torch.sigmoid(x_t @ self.Ui + h_t @ self.Vi + self.bi)
            f_t = torch.sigmoid(x_t @ self.Uf + h_t @ self.Vf + self.bf)
            o_t = torch.sigmoid(x_t @ self.Uo + h_t @ self.Vo + self.bo)
            g_t = torch.tanh(x_t @ self.Uc + h_t @ self.Vc + self.bc)

            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

            hidden_seq.append(h_t)

        if hidden_seq:
            # Stack along dim=1 to get (batch, seq_len, hidden_dim) directly.
            hidden_seq = torch.stack(hidden_seq, dim=1)
        else:
            # Zero-length sequence: stack() rejects an empty list, so return
            # an empty sequence with the right trailing dimensions.
            hidden_seq = h_t.new_zeros((bs, 0, self.hidden_dim))

        logits = self.linear(h_t)  # last timestep hidden state

        return logits, hidden_seq, (h_t, c_t)








