In [13]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset
import datetime
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view as sliding_window_view
import pickle
import math

In [113]:
 class MiLSTM(nn.Module):
    """LSTM-style recurrent layer with one main input and three auxiliary inputs.

    At every time step four candidate cell updates are produced (from the main
    input ``y`` and the auxiliary inputs ``p``, ``n`` and ``index``). Each
    candidate is gated by its own input gate, and the gated candidates are
    combined with softmax attention weights scored against the previous cell
    state. Forget and output gates are driven by the main input only.

    Args:
        input_sz: feature size of the main input ``y``.
        hidden_sz: size of the hidden and cell state.
        p_sz: feature size of the ``p`` input; defaults to ``input_sz``.
        n_sz: feature size of the ``n`` input; defaults to ``input_sz``.
        index_sz: feature size of the ``index`` input; defaults to ``input_sz``.
    """

    def __init__(self, input_sz: int, hidden_sz: int, p_sz=None, n_sz=None, index_sz=None):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # Auxiliary feature sizes are resolved here instead of being read from
        # notebook globals, so the class works on a fresh kernel.
        self.p_size = input_sz if p_sz is None else p_sz
        self.n_size = input_sz if n_sz is None else n_sz
        self.index_size = input_sz if index_sz is None else index_sz

        def matrix(rows):
            # (rows, hidden) weight; values are filled in by init_weights().
            return nn.Parameter(torch.empty(rows, hidden_sz))

        def vector(size):
            # 1-D parameter; values are filled in by init_weights().
            return nn.Parameter(torch.empty(size))

        # NOTE: registration order is kept stable because init_weights() draws
        # random values parameter by parameter.

        # f_t (forget gate)
        self.Wfh = matrix(hidden_sz)
        self.Wfy = matrix(input_sz)
        self.bf = vector(hidden_sz)

        # o_t (output gate)
        self.Woh = matrix(hidden_sz)
        self.Woy = matrix(input_sz)
        self.bo = vector(hidden_sz)

        # c_t (candidate from the main input)
        self.Wch = matrix(hidden_sz)
        self.Wcy = matrix(input_sz)
        self.bc = vector(hidden_sz)

        # c_pt (candidate from the p input)
        self.Wcph = matrix(hidden_sz)
        self.Wcpp = matrix(self.p_size)
        self.bcp = vector(hidden_sz)

        # c_nt (candidate from the n input)
        self.Wcnh = matrix(hidden_sz)
        self.Wcnn = matrix(self.n_size)
        self.bcn = vector(hidden_sz)

        # c_it (candidate from the index input)
        self.Wcih = matrix(hidden_sz)
        self.Wcii = matrix(self.index_size)
        self.bci = vector(hidden_sz)

        # i_t (input gate for c_t)
        self.Wih = matrix(hidden_sz)
        self.Wiy = matrix(input_sz)
        self.bi = vector(hidden_sz)

        # i_pt (input gate for c_pt)
        self.Wiph = matrix(hidden_sz)
        self.Wipy = matrix(input_sz)
        self.bip = vector(hidden_sz)

        # i_nt (input gate for c_nt)
        self.Winh = matrix(hidden_sz)
        self.Winy = matrix(input_sz)
        self.bin = vector(hidden_sz)

        # i_it (input gate for c_it)
        self.Wiih = matrix(hidden_sz)
        self.Wiiy = matrix(input_sz)
        self.bii = vector(hidden_sz)

        # attn
        # alpha_* are legacy scalar mixing weights. forward() uses the
        # softmax attention weights instead and does not read them; they stay
        # registered so previously saved state_dicts still load.
        self.alpha_t = vector(1)
        self.alpha_pt = vector(1)
        self.alpha_nt = vector(1)
        self.alpha_it = vector(1)
        self.Wattn = matrix(hidden_sz)
        self.ba = vector(1)
        self.bap = vector(1)
        self.ban = vector(1)
        self.bai = vector(1)

        self.init_weights()

    def init_weights(self):
        """Initialise all parameters, then zero the auxiliary candidate weights.

        Every parameter is drawn from U(-1/sqrt(hidden), 1/sqrt(hidden)). The
        weights and biases of the three auxiliary candidates are then set to
        zero, so those candidates start out as tanh(0) = 0.
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

        auxiliary = (
            self.Wcph, self.Wcpp, self.bcp,  # c_pt
            self.Wcnh, self.Wcnn, self.bcn,  # c_nt
            self.Wcih, self.Wcii, self.bci,  # c_it
        )
        for weight in auxiliary:
            nn.init.zeros_(weight)

    def forward(self, y_tilde, p_tilde, n_tilde, index_tilde, init_stats=None):
        """Run the layer over a window of time steps.

        Args:
            y_tilde: main input, shape (batch, win_len, input_sz).
            p_tilde: auxiliary input, shape (batch, win_len, p_sz).
            n_tilde: auxiliary input, shape (batch, win_len, n_sz).
            index_tilde: auxiliary input, shape (batch, win_len, index_sz).
            init_stats: optional ``(h_0, c_0)`` pair, each (batch, hidden_sz).
                Zeros are used when omitted.

        Returns:
            ``(hidden_seqs, (h_t, cell_t))`` where ``hidden_seqs`` has shape
            (batch, win_len, hidden_sz) and ``h_t`` / ``cell_t`` are the final
            hidden and cell states, each (batch, hidden_sz).
        """
        batch_size, win_len, _ = y_tilde.shape

        if init_stats is None:
            # new_zeros keeps the states on the input's device and dtype.
            h_t = y_tilde.new_zeros(batch_size, self.hidden_size)
            cell_t = y_tilde.new_zeros(batch_size, self.hidden_size)
        else:
            h_t, cell_t = init_stats

        # The four attention biases, in the same order as the candidates.
        attn_bias = torch.cat((self.ba, self.bap, self.ban, self.bai))

        hidden_seqs = []
        for t in range(win_len):
            y_t = y_tilde[:, t, :]
            p_t = p_tilde[:, t, :]
            n_t = n_tilde[:, t, :]
            index_t = index_tilde[:, t, :]

            f_t = torch.sigmoid(y_t @ self.Wfy + h_t @ self.Wfh + self.bf)
            o_t = torch.sigmoid(y_t @ self.Woy + h_t @ self.Woh + self.bo)

            # Candidate updates, one per input stream.
            c_t = torch.tanh(y_t @ self.Wcy + h_t @ self.Wch + self.bc)
            c_pt = torch.tanh(p_t @ self.Wcpp + h_t @ self.Wcph + self.bcp)
            c_nt = torch.tanh(n_t @ self.Wcnn + h_t @ self.Wcnh + self.bcn)
            c_it = torch.tanh(index_t @ self.Wcii + h_t @ self.Wcih + self.bci)

            # Input gates: all driven by the main input and the hidden state.
            i_t = torch.sigmoid(y_t @ self.Wiy + h_t @ self.Wih + self.bi)
            i_pt = torch.sigmoid(y_t @ self.Wipy + h_t @ self.Wiph + self.bip)
            i_nt = torch.sigmoid(y_t @ self.Winy + h_t @ self.Winh + self.bin)
            i_it = torch.sigmoid(y_t @ self.Wiiy + h_t @ self.Wiih + self.bii)

            # Gated candidates stacked as (batch, 4, hidden).
            gated = torch.stack((c_t * i_t, c_pt * i_pt, c_nt * i_nt, c_it * i_it), dim=1)

            # Bilinear score of each gated candidate against the previous cell
            # state (l^T W c), plus its bias -> (batch, 4).
            # NOTE(review): scores go into the softmax unsquashed; if the
            # reference formulation applies tanh to them, add it here.
            scores = ((gated @ self.Wattn) * cell_t.unsqueeze(1)).sum(dim=2) + attn_bias
            attn = torch.softmax(scores, dim=1)

            # Attention-weighted mix of the gated candidates -> (batch, hidden).
            L_t = (attn.unsqueeze(2) * gated).sum(dim=1)

            cell_t = cell_t * f_t + L_t
            h_t = torch.tanh(cell_t) * o_t

            hidden_seqs.append(h_t)

        if not hidden_seqs:
            # Zero-length window: torch.stack cannot handle an empty list.
            return y_tilde.new_zeros(batch_size, 0, self.hidden_size), (h_t, cell_t)

        # Stack along dim 1 to get (batch, win_len, hidden) directly.
        return torch.stack(hidden_seqs, dim=1), (h_t, cell_t)


In [115]:
# Smoke test: push one random batch through MiLSTM and print the output shapes.

# Feature sizes for the main and auxiliary inputs, and the hidden state size.
input_sz = 64
hidden_sz = 64
p_sz, n_sz, index_sz = 64, 64, 64

# Four random inputs of shape (512, 10, 64), created in forward-argument order.
y_tilde = torch.rand(512, 10, 64)
p_tilde = torch.rand(512, 10, 64)
n_tilde = torch.rand(512, 10, 64)
index_tilde = torch.rand(512, 10, 64)

model = MiLSTM(input_sz, hidden_sz)
hidden_seqs, (h_t, cell_t) = model(y_tilde, p_tilde, n_tilde, index_tilde)

# Expect (512, 10, 64) for the sequence and (512, 64) for each final state.
print(hidden_seqs.shape, h_t.shape, cell_t.shape)

torch.Size([512, 10, 64]) torch.Size([512, 64]) torch.Size([512, 64])
