In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

**LSTM**

In [104]:
N = 2
L = 32
D = 1
H_in = 64
H_cell = 128
H_out = H_cell
num_layers = 1

input = torch.rand(N, L, H_in) # N * L * D
h0 = torch.rand(D*num_layers, N, H_out)
c0 = torch.rand(D*num_layers, N, H_cell)

In [105]:
layer = nn.LSTM(H_in, H_cell, 1, batch_first=True, bidirectional=False)
torch_output, (torch_h_n, torch_c_n) = layer(input, (h0, c0))
print(torch_output.shape)
print(torch_h_n.shape)
print(torch_c_n.shape)
for k, v in layer.named_parameters():
    print(k)

torch.Size([2, 32, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
weight_ih_l0
weight_hh_l0
bias_ih_l0
bias_hh_l0


In [109]:
# we only take the initial_states of one layer
def LSTM_forward(input, initial_states, W_ih, W_hh, B_ih, B_hh):
    h_0, c_0 = initial_states
    
    L = input.shape[1]
    hidden_dim = W_ih.shape[0] // 4
    output = torch.zeros(input.shape[0], input.shape[1], W_ih.shape[0]//4)
    h_prev = h_0
    c_prev = c_0
    
    for t in range(0, L):
        x_t = input[:, t, :]
        
        sum_ = x_t @ W_ih.T + B_ih + h_prev @ W_hh.T + B_hh
        
        it = torch.sigmoid(sum_[:, 0 : hidden_dim])
        ft = torch.sigmoid(sum_[:, hidden_dim : 2*hidden_dim])
        gt = torch.tanh(sum_[:, 2*hidden_dim : 3*hidden_dim])
        ot = torch.sigmoid(sum_[:, 3*hidden_dim : ])

        c_prev = ft * c_prev + it * gt
        h_prev = ot * torch.tanh(c_prev)
        output[:, t, :] = h_prev
        
    return output, (torch.stack([h_prev]), torch.stack([c_prev]))


my_output, (my_h_n, my_c_n) = LSTM_forward(input, (h0[0], c0[0]), layer.weight_ih_l0, layer.weight_hh_l0, layer.bias_ih_l0, layer.bias_hh_l0)
# print(my_output)
# print(torch_output)
# print(my_h_n)
# print(torch_h_n)
# print(my_c_n)
# print(torch_c_n)
print(my_output.shape)
print(torch_output.shape)
print(my_h_n.shape)
print(torch_h_n.shape)
print(my_c_n.shape)
print(torch_c_n.shape)

torch.Size([2, 32, 128])
torch.Size([2, 32, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])
torch.Size([1, 2, 128])


**Bidirectional**

In [116]:
N = 2
L = 32
D = 2 # because it is bidirectional
H_in = 64
H_cell = 128
H_out = H_cell
num_layers = 1

input = torch.rand(N, L, H_in) # N * L * D
h0 = torch.rand(D*num_layers, N, H_out)
c0 = torch.rand(D*num_layers, N, H_cell)

In [117]:
layer = nn.LSTM(H_in, H_cell, num_layers, batch_first=True, bidirectional=True)
torch_output, (torch_h_n, torch_c_n) = layer(input, (h0, c0))
print(torch_output.shape)
print(torch_h_n.shape)
print(torch_c_n.shape)
for k, v in layer.named_parameters():
    print(k)

torch.Size([2, 32, 256])
torch.Size([2, 2, 128])
torch.Size([2, 2, 128])
weight_ih_l0
weight_hh_l0
bias_ih_l0
bias_hh_l0
weight_ih_l0_reverse
weight_hh_l0_reverse
bias_ih_l0_reverse
bias_hh_l0_reverse


In [124]:
# we only take the initial_states of one layer
def bidirectional_LSTM_forward(input, initial_states, initial_states_reverse, W_ih, W_hh, B_ih, B_hh, W_ih_reverse, W_hh_reverse, B_ih_reverse, B_hh_reverse):
    output, (h_prev, c_prev) = LSTM_forward(input, initial_states, W_ih, W_hh, B_ih, B_hh)
    output_reverse, (h_prev_reverse, c_prev_reverse) = LSTM_forward(torch.flip(input, [1]), initial_states_reverse, W_ih_reverse, W_hh_reverse, B_ih_reverse, B_hh_reverse)
        
    output_reverse = torch.flip(output_reverse, [1])
    return torch.concatenate((output, output_reverse), dim=2), (torch.cat([h_prev, h_prev_reverse], dim=0), torch.cat([c_prev, c_prev_reverse], dim=0))
        
my_output, (my_h_n, my_c_n) = bidirectional_LSTM_forward(input, \
                                                         (h0[0], c0[0]), (h0[1], c0[1]), \
                                                         layer.weight_ih_l0, layer.weight_hh_l0, layer.bias_ih_l0, layer.bias_hh_l0,\
                                                         layer.weight_ih_l0_reverse, layer.weight_hh_l0_reverse, layer.bias_ih_l0_reverse, layer.bias_hh_l0_reverse)
print(my_output)
print(torch_output)
print(my_h_n)
print(torch_h_n)
print(my_c_n)
print(torch_c_n)

print((my_c_n - torch_c_n)<0.000001)

print(my_output.shape)
print(torch_output.shape)
print(my_h_n.shape)
print(torch_h_n.shape)
print(my_c_n.shape)
print(torch_c_n.shape)

tensor([[[ 0.1741, -0.0336,  0.1421,  ..., -0.1452,  0.2346, -0.0763],
         [ 0.0464, -0.0366,  0.0995,  ..., -0.1819,  0.2252, -0.0486],
         [ 0.0178, -0.0497,  0.0449,  ..., -0.1323,  0.2441, -0.0690],
         ...,
         [ 0.0156, -0.1277,  0.0246,  ..., -0.0953,  0.2692, -0.0467],
         [ 0.0059, -0.1244,  0.0263,  ..., -0.0844,  0.2678, -0.0483],
         [ 0.0160, -0.0961, -0.0162,  ..., -0.0085,  0.2007, -0.0068]],

        [[ 0.2781, -0.0514,  0.0826,  ..., -0.1563,  0.2397, -0.0396],
         [ 0.0668, -0.0968,  0.0734,  ..., -0.1255,  0.2682, -0.0106],
         [ 0.0387, -0.1219,  0.0125,  ..., -0.1425,  0.2543, -0.0961],
         ...,
         [-0.0287, -0.0612,  0.0715,  ..., -0.1488,  0.2387,  0.0297],
         [ 0.0024, -0.0624,  0.0432,  ..., -0.1322,  0.2051,  0.0322],
         [ 0.0312, -0.1018,  0.0526,  ..., -0.0866,  0.2448,  0.1347]]],
       grad_fn=<CatBackward0>)
tensor([[[ 0.1741, -0.0336,  0.1421,  ..., -0.1452,  0.2346, -0.0763],
         [ 0.0

**Multilayers**

In [132]:
N = 2
L = 32
D = 1
H_in = 64
H_cell = 128
H_out = H_cell
num_layers = 3

input = torch.rand(N, L, H_in) # N * L * H_in
h0 = torch.rand(D*num_layers, N, H_out)
c0 = torch.rand(D*num_layers, N, H_cell)

In [133]:
layer = nn.LSTM(H_in, H_cell, num_layers, batch_first=True, bidirectional=False)
torch_output, (torch_h_n, torch_c_n) = layer(input, (h0, c0))
print(torch_output.shape)
print(torch_h_n.shape)
print(torch_c_n.shape)
for k, v in layer.named_parameters():
    print(k)

torch.Size([2, 32, 128])
torch.Size([3, 2, 128])
torch.Size([3, 2, 128])
weight_ih_l0
weight_hh_l0
bias_ih_l0
bias_hh_l0
weight_ih_l1
weight_hh_l1
bias_ih_l1
bias_hh_l1
weight_ih_l2
weight_hh_l2
bias_ih_l2
bias_hh_l2


In [140]:
# we only take the initial_states of one layer
def LSTM_forward_multilayers(input, initial_states, W_ih, W_hh, B_ih, B_hh):
    layers = len(W_ih)
    
    layer_outputs = []
    layer_h_n = []
    layer_c_n = []
    prev_output = input
    
    for layer in range(0, layers):
        prev_output, (h_n, c_n) = LSTM_forward(prev_output, (initial_states[0][layer], initial_states[1][layer]), W_ih[layer], W_hh[layer], B_ih[layer], B_hh[layer])
        layer_outputs.append(prev_output)
        layer_h_n.append(h_n)
        layer_c_n.append(c_n)
        
    return prev_output, (torch.concatenate(layer_h_n, dim=0), torch.concatenate(layer_c_n, dim=0))



W_ih = [layer.weight_ih_l0, layer.weight_ih_l1, layer.weight_ih_l2]
W_hh = [layer.weight_hh_l0, layer.weight_hh_l1, layer.weight_hh_l2]
B_ih = [layer.bias_ih_l0, layer.bias_ih_l1, layer.bias_ih_l2]
B_hh = [layer.bias_hh_l0, layer.bias_hh_l1, layer.bias_hh_l2]

my_output, (my_h_n, my_c_n) = LSTM_forward_multilayers(input, (h0, c0), W_ih, W_hh, B_ih, B_hh)
# print(my_output)
# print(torch_output)
# print(my_h_n)
# print(torch_h_n)
# print(my_c_n)
# print(torch_c_n)

print((my_output - torch_output)<0.000001)
print((my_h_n - torch_h_n)<0.000001)
print((my_c_n - torch_c_n)<0.000001)

# print(my_output.shape)
# print(torch_output.shape)
# print(my_h_n.shape)
# print(torch_h_n.shape)
# print(my_c_n.shape)
# print(torch_c_n.shape)

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])
tensor([[[True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, 

**Bidirectional Multilayer**

In [161]:
N = 2
L = 32
D = 2
H_in = 64
H_cell = 128
H_out = H_cell
num_layers = 3

input = torch.rand(N, L, H_in) # N * L * H_in
h0 = torch.rand(D*num_layers, N, H_out)
c0 = torch.rand(D*num_layers, N, H_cell)

In [162]:
layer = nn.LSTM(H_in, H_cell, num_layers, batch_first=True, bidirectional=True)
torch_output, (torch_h_n, torch_c_n) = layer(input, (h0, c0))
print(torch_output.shape)
print(torch_h_n.shape)
print(torch_c_n.shape)
for k, v in layer.named_parameters():
    print(k)
    
print(layer.weight_ih_l0.shape)

torch.Size([2, 32, 256])
torch.Size([6, 2, 128])
torch.Size([6, 2, 128])
weight_ih_l0
weight_hh_l0
bias_ih_l0
bias_hh_l0
weight_ih_l0_reverse
weight_hh_l0_reverse
bias_ih_l0_reverse
bias_hh_l0_reverse
weight_ih_l1
weight_hh_l1
bias_ih_l1
bias_hh_l1
weight_ih_l1_reverse
weight_hh_l1_reverse
bias_ih_l1_reverse
bias_hh_l1_reverse
weight_ih_l2
weight_hh_l2
bias_ih_l2
bias_hh_l2
weight_ih_l2_reverse
weight_hh_l2_reverse
bias_ih_l2_reverse
bias_hh_l2_reverse
torch.Size([512, 64])


In [151]:
# we only take the initial_states of one layer
def bidirectional_LSTM_forward_multilayers(input, initial_states, initial_states_reverse, W_ih, W_hh, B_ih, B_hh, W_ih_reverse, W_hh_reverse, B_ih_reverse, B_hh_reverse):
    layers = len(W_ih)
    
    layer_outputs = []
    layer_h_n = []
    layer_c_n = []
    prev_output = input
    
    for layer in range(0, layers):
        prev_output, (h_n, c_n) = bidirectional_LSTM_forward(prev_output, (initial_states[0][layer], initial_states[1][layer]), (initial_states_reverse[0][layer], initial_states_reverse[1][layer]), W_ih[layer], W_hh[layer], B_ih[layer], B_hh[layer], W_ih_reverse[layer], W_hh_reverse[layer], B_ih_reverse[layer], B_hh_reverse[layer])
        layer_outputs.append(prev_output)
        layer_h_n.append(h_n)
        layer_c_n.append(c_n)
        
    return prev_output, (torch.concatenate(layer_h_n, dim=0), torch.concatenate(layer_c_n, dim=0))



W_ih = [layer.weight_ih_l0, layer.weight_ih_l1, layer.weight_ih_l2]
W_hh = [layer.weight_hh_l0, layer.weight_hh_l1, layer.weight_hh_l2]
B_ih = [layer.bias_ih_l0, layer.bias_ih_l1, layer.bias_ih_l2]
B_hh = [layer.bias_hh_l0, layer.bias_hh_l1, layer.bias_hh_l2]
W_ih_reverse = [layer.weight_ih_l0_reverse, layer.weight_ih_l1_reverse, layer.weight_ih_l2_reverse]
W_hh_reverse = [layer.weight_hh_l0_reverse, layer.weight_hh_l1_reverse, layer.weight_hh_l2_reverse]
B_ih_reverse = [layer.bias_ih_l0_reverse, layer.bias_ih_l1_reverse, layer.bias_ih_l2_reverse]
B_hh_reverse = [layer.bias_hh_l0_reverse, layer.bias_hh_l1_reverse, layer.bias_hh_l2_reverse]

my_output, (my_h_n, my_c_n) = bidirectional_LSTM_forward_multilayers(input, (h0[0::2], c0[0::2]), (h0[1::2], c0[1::2]), W_ih, W_hh, B_ih, B_hh, W_ih_reverse, W_hh_reverse, B_ih_reverse, B_hh_reverse)
# print(my_output)
# print(torch_output)
# print(my_h_n)
# print(torch_h_n)
# print(my_c_n)
# print(torch_c_n)

print((my_output - torch_output)<0.000001)
print((my_h_n - torch_h_n)<0.000001)
print((my_c_n - torch_c_n)<0.000001)

# print(my_output.shape)
# print(torch_output.shape)
# print(my_h_n.shape)
# print(torch_h_n.shape)
# print(my_c_n.shape)
# print(torch_c_n.shape)

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        

In [164]:
for a in W_ih:
    print(a.shape)
    
for a in B_ih:
    print(a.shape)

torch.Size([512, 64])
torch.Size([512, 256])
torch.Size([512, 256])
torch.Size([512])
torch.Size([512])
torch.Size([512])


**Module**

In [165]:
class LSTM(nn.Module):
    def __init__(self, dims, hidden_dims, num_layers, bidirectional=False):
        super(LSTM, self).__init__()
        
        self.bidirectional = bidirectional
        
        self.W_ih = []
        self.W_hh = []
        self.B_ih = []
        self.B_hh = []
        
        hidden_dims = 4 * hidden_dims
        
        if bidirectional:
            self.W_ih_reverse = []
            self.W_hh_reverse = []
            self.B_ih_reverse = []
            self.B_hh_reverse = []
            for layer in range(0, num_layers):
                if layer == 0:
                    self.W_ih.append(torch.rand(4 * hidden_dims, dims))
                    self.W_ih_reverse.append(torch.rand(4 * hidden_dims, dims))
                else:
                    self.W_ih.append(torch.rand(4 * hidden_dims, 2 * hidden_dims))
                    self.W_ih_reverse.append(torch.rand(4 * hidden_dims, 2 * hidden_dims))
                self.W_hh.append(torch.rand(4 * hidden_dims, hidden_dims))
                self.W_hh_reverse.append(torch.rand(4 * hidden_dims, hidden_dims))
                self.B_ih.append(torch.rand(4 * hidden_dims, ))
                self.B_ih_reverse.append(torch.rand(4 * hidden_dims, ))
                self.B_hh.append(torch.rand(4 * hidden_dims, ))
                self.B_hh_reverse.append(torch.rand(4 * hidden_dims, ))
            self.W_ih_reverse = nn.ParameterList(self.W_ih_reverse)
            self.W_hh_reverse = nn.ParameterList(self.W_hh_reverse)
            self.B_ih_reverse = nn.ParameterList(self.B_ih_reverse)
            self.B_hh_reverse = nn.ParameterList(self.B_hh_reverse)
        else:
            for layer in range(0, num_layers):
                if layer == 0:
                    self.W_ih.append(torch.rand(4 * hidden_dims, dims))
                else:
                    self.W_ih.append(torch.rand(4 * hidden_dims, hidden_dims))
                self.W_hh.append(torch.rand(4 * hidden_dims, hidden_dims))
                self.B_ih.append(torch.rand(4 * hidden_dims, ))
                self.B_hh.append(torch.rand(4 * hidden_dims, ))
                
        self.W_ih = nn.ParameterList(self.W_ih)
        self.W_hh = nn.ParameterList(self.W_hh)
        self.B_ih = nn.ParameterList(self.B_ih)
        self.B_hh = nn.ParameterList(self.B_hh)
                        
    def forward(self, x):
        if self.bidirectional:
            h0 = torch.zeros(2*len(self.W_ih), x.shape[0], self.W_ih[0].shape[0]//4)
            c0 = torch.zeros(2*len(self.W_ih), x.shape[0], self.W_ih[0].shape[0]//4)
            return bidirectional_LSTM_forward_multilayers(x, (h0[0::2], c0[0::2]), (h0[1::2], c0[1::2]), self.W_ih, self.W_hh, self.B_ih, self.B_hh, self.W_ih_reverse, self.W_hh_reverse, self.B_ih_reverse, self.B_hh_reverse)
        else:
            h0 = torch.zeros(len(self.W_ih), x.shape[0], self.W_ih[0].shape[0]//4)
            c0 = torch.zeros(len(self.W_ih), x.shape[0], self.W_ih[0].shape[0]//4)

            return LSTM_forward_multilayers(x, (h0, c0), self.W_ih, self.W_hh, self.B_ih, self.B_hh)

         

In [167]:
input = torch.rand(2, 32, 128) # N * L * D
layer = LSTM(128, 256, 3, bidirectional=True)
my_output, (my_h_n, my_c_n) = layer(input)
print(my_output.shape)
print(my_h_n.shape)
print(my_c_n.shape)

for k, v in layer.named_parameters():
    print(k)

torch.Size([2, 32, 2048])
torch.Size([6, 2, 1024])
torch.Size([6, 2, 1024])
W_ih_reverse.0
W_ih_reverse.1
W_ih_reverse.2
W_hh_reverse.0
W_hh_reverse.1
W_hh_reverse.2
B_ih_reverse.0
B_ih_reverse.1
B_ih_reverse.2
B_hh_reverse.0
B_hh_reverse.1
B_hh_reverse.2
W_ih.0
W_ih.1
W_ih.2
W_hh.0
W_hh.1
W_hh.2
B_ih.0
B_ih.1
B_ih.2
B_hh.0
B_hh.1
B_hh.2
