In [1]:
import torch
import torch.nn as nn

In [28]:
# Toy sizes: batch, sequence length, input features, hidden units (and projection size).
bs, T, i_size, h_size = 2, 3, 4, 5
proj_size = 3
torch.manual_seed(0)  # fix the RNG so inputs, initial states and LSTM weights are reproducible
inputs = torch.randn(bs, T, i_size)
h_0 = torch.randn(bs, h_size)
c_0 = torch.randn(bs, h_size)
# With proj_size set, h_0 would instead be (bs, proj_size): the projection compresses h.

lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# nn.LSTM expects initial states shaped (num_layers * num_directions, bs, h_size),
# i.e. (1, bs, h_size) here, so add the leading layer axis.
output, (hn, cn) = lstm_layer(inputs, (h_0.unsqueeze(0), c_0.unsqueeze(0)))

print(output)
print(hn)
print(cn)

tensor([[[-0.2335,  0.0344, -0.2780,  0.5216,  0.0529],
         [-0.2386,  0.3331, -0.3291,  0.3712, -0.3144],
         [-0.1060,  0.0798, -0.2508,  0.0960, -0.2757]],

        [[ 0.0434,  0.0873, -0.0879,  0.1100,  0.5907],
         [-0.0907,  0.0738,  0.1541,  0.0971,  0.2843],
         [-0.2087,  0.2055,  0.2546,  0.2179, -0.1388]]],
       grad_fn=<TransposeBackward0>)
tensor([[[-0.1060,  0.0798, -0.2508,  0.0960, -0.2757],
         [-0.2087,  0.2055,  0.2546,  0.2179, -0.1388]]],
       grad_fn=<StackBackward0>)
tensor([[[-0.2106,  0.1742, -0.4524,  0.1758, -0.5244],
         [-0.5820,  0.2689,  0.3088,  0.3347, -0.1883]]],
       grad_fn=<StackBackward0>)


In [29]:
# nn.LSTM results: output is (bs, T, h_size); hn and cn carry a leading layer axis, (1, bs, h_size).
for result in (output, hn, cn):
    print(result.shape)

# Parameters: the four gate matrices (i, f, g, o) are stacked along dim 0, so
# weight_ih is (4*h_size, i_size) = (20, 4) and weight_hh is (4*h_size, h_size) = (20, 5).
for param_name, param in lstm_layer.named_parameters():
    print(param_name, param.shape)

torch.Size([2, 3, 5])
torch.Size([1, 2, 5])
torch.Size([1, 2, 5])
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])


In [31]:
# Hand-written LSTM forward pass, checked against nn.LSTM below.
def lstm_forward(inputs, initial_states, w_ih, w_hh, b_ih, b_hh, w_rh=None):
    """Single-layer, unidirectional, batch-first LSTM forward pass.

    Reproduces ``nn.LSTM(batch_first=True)`` (optionally with ``proj_size``)
    from the layer's raw parameters.

    Args:
        inputs: (bs, T, input_size) input sequence.
        initial_states: tuple ``(h_0, c_0)``. ``h_0`` is (bs, output_size) and
            ``c_0`` is (bs, h_size). Unlike ``nn.LSTM`` there is no leading
            num_layers axis.
        w_ih: (4*h_size, input_size) input-to-hidden weights (``weight_ih_l0``),
            gates stacked in the order input (i), forget (f), cell (g), output (o).
        w_hh: (4*h_size, output_size) hidden-to-hidden weights (``weight_hh_l0``),
            same gate order.
        b_ih, b_hh: (4*h_size,) biases, or ``None`` for a layer built with
            ``bias=False``.
        w_rh: optional (proj_size, h_size) projection matrix (``weight_hr_l0``).
            When given, the hidden state is projected down to proj_size.

    Returns:
        ``(output, (h_T, c_T))``:
        - ``output`` is (bs, T, output_size), with the same dtype and device
          as ``inputs``.
        - ``h_T`` is (bs, output_size).
        - ``c_T`` is (bs, h_size).
        ``output_size`` is proj_size when ``w_rh`` is given, else h_size.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = inputs.shape
    h_size = w_ih.shape[0] // 4
    output_size = h_size if w_rh is None else w_rh.shape[0]

    # Allocate with the inputs' dtype/device (a bare torch.zeros would force
    # float32 on CPU and silently downcast float64 results).
    output = inputs.new_zeros(bs, T, output_size)

    # The input contribution does not depend on the recurrence, so compute it
    # for every time step at once: (bs, T, input_size) @ (input_size, 4*h_size).
    x_proj = inputs @ w_ih.T
    if b_ih is not None:
        x_proj = x_proj + b_ih

    for t in range(T):
        # Recurrent contribution: (bs, output_size) @ (output_size, 4*h_size).
        gates = x_proj[:, t, :] + prev_h @ w_hh.T  # (bs, 4*h_size)
        if b_hh is not None:
            gates = gates + b_hh

        # Gate preactivations are stacked as i, f, g, o along the last dim.
        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, dim=-1)
        i_t = torch.sigmoid(in_gate)
        f_t = torch.sigmoid(forget_gate)
        g_t = torch.tanh(cell_gate)
        o_t = torch.sigmoid(out_gate)

        prev_c = f_t * prev_c + i_t * g_t  # (bs, h_size)
        prev_h = o_t * torch.tanh(prev_c)  # (bs, h_size)
        if w_rh is not None:
            # Projection: (bs, h_size) @ (h_size, proj_size) -> (bs, proj_size).
            prev_h = prev_h @ w_rh.T
        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

# Run the hand-written implementation with the same weights and initial states as lstm_layer.
forward_output, (forward_hn, forward_cn) = lstm_forward(
    inputs, (h_0, c_0),
    lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0,
)
print(forward_output)
print(forward_hn)
print(forward_cn)

# Verify programmatically instead of by eye. nn.LSTM's hn/cn carry a leading
# layer axis (1, bs, h_size), so drop it before comparing.
torch.testing.assert_close(forward_output, output, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(forward_hn, hn.squeeze(0), rtol=1e-4, atol=1e-5)
torch.testing.assert_close(forward_cn, cn.squeeze(0), rtol=1e-4, atol=1e-5)

tensor([[[-0.2335,  0.0344, -0.2780,  0.5216,  0.0529],
         [-0.2386,  0.3331, -0.3291,  0.3712, -0.3144],
         [-0.1060,  0.0798, -0.2508,  0.0960, -0.2757]],

        [[ 0.0434,  0.0873, -0.0879,  0.1100,  0.5907],
         [-0.0907,  0.0738,  0.1541,  0.0971,  0.2843],
         [-0.2087,  0.2055,  0.2546,  0.2179, -0.1388]]], grad_fn=<CopySlices>)
tensor([[-0.1060,  0.0798, -0.2508,  0.0960, -0.2757],
        [-0.2087,  0.2055,  0.2546,  0.2179, -0.1388]], grad_fn=<MulBackward0>)
tensor([[-0.2106,  0.1742, -0.4524,  0.1758, -0.5244],
        [-0.5820,  0.2689,  0.3088,  0.3347, -0.1883]], grad_fn=<AddBackward0>)


In [32]:
# Projected LSTM: h_t is projected from h_size down to proj_size, so h_0 must be
# (bs, proj_size) while c_0 keeps its (bs, h_size) shape.
# NOTE: this rebinds h_0 (and output/hn/cn). The earlier cell that calls
# lstm_forward with lstm_layer will fail on a re-run until the setup cell is executed again.
h_0 = torch.randn(bs, proj_size)
pro_lstm_layer = nn.LSTM(
    i_size,
    h_size,
    batch_first=True,
    proj_size=proj_size,
)
output, (hn, cn) = pro_lstm_layer(inputs, (h_0[None], c_0[None]))

print(output)
print(hn)
print(cn)

tensor([[[ 0.0423, -0.0554, -0.0014],
         [ 0.0168,  0.0114, -0.0809],
         [ 0.0308,  0.0179, -0.0685]],

        [[ 0.2101,  0.0282, -0.1115],
         [ 0.2058, -0.0541, -0.1189],
         [ 0.2325, -0.0897, -0.1474]]], grad_fn=<TransposeBackward0>)
tensor([[[ 0.0308,  0.0179, -0.0685],
         [ 0.2325, -0.0897, -0.1474]]], grad_fn=<StackBackward0>)
tensor([[[ 0.0141, -0.2495, -0.5108, -0.1488, -0.0325],
         [-0.3937,  0.1852, -0.2235, -0.5974,  0.3615]]],
       grad_fn=<StackBackward0>)


In [33]:
# Same check with projection: pass weight_hr_l0 as the projection matrix w_rh.
forward_output, (forward_hn, forward_cn) = lstm_forward(
    inputs, (h_0, c_0),
    pro_lstm_layer.weight_ih_l0, pro_lstm_layer.weight_hh_l0,
    pro_lstm_layer.bias_ih_l0, pro_lstm_layer.bias_hh_l0,
    pro_lstm_layer.weight_hr_l0,
)
print(forward_output)
print(forward_hn)
print(forward_cn)

# Verify programmatically. nn.LSTM's hn/cn carry a leading layer axis, so drop it before comparing.
torch.testing.assert_close(forward_output, output, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(forward_hn, hn.squeeze(0), rtol=1e-4, atol=1e-5)
torch.testing.assert_close(forward_cn, cn.squeeze(0), rtol=1e-4, atol=1e-5)

tensor([[[ 0.0423, -0.0554, -0.0014],
         [ 0.0168,  0.0114, -0.0809],
         [ 0.0308,  0.0179, -0.0685]],

        [[ 0.2101,  0.0282, -0.1115],
         [ 0.2058, -0.0541, -0.1189],
         [ 0.2325, -0.0897, -0.1474]]], grad_fn=<CopySlices>)
tensor([[ 0.0308,  0.0179, -0.0685],
        [ 0.2325, -0.0897, -0.1474]], grad_fn=<SqueezeBackward1>)
tensor([[ 0.0141, -0.2495, -0.5108, -0.1488, -0.0325],
        [-0.3937,  0.1852, -0.2235, -0.5974,  0.3615]], grad_fn=<AddBackward0>)


In [40]:
# Without projection weight_hh is (4*h_size, h_size). With proj_size it shrinks to
# (4*h_size, proj_size), and an extra (proj_size, h_size) projection weight_hr appears.
print(lstm_layer.weight_ih_l0.shape, lstm_layer.weight_hh_l0.shape)
print(*(getattr(pro_lstm_layer, attr).shape
        for attr in ("weight_ih_l0", "weight_hh_l0", "weight_hr_l0")))

torch.Size([20, 4]) torch.Size([20, 5])
torch.Size([20, 4]) torch.Size([20, 3]) torch.Size([3, 5])


In [None]:
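projection（`proj_size`）对隐状态做了一次线性投影 $h_t \leftarrow W_{hr} h_t$，把它的维度从 h_size 降到 p_size，从而减少 $W_{hh}$ 的参数量：
原本 $W_{hh}$ 的形状是 `(4*h_size, h_size)`，投影后变为 `(4*h_size, p_size)`；但同时多了一个形状为 `(p_size, h_size)` 的投影矩阵 $W_{hr}$。本例中 h_size=5、p_size=3：$W_{hh}$ 从 20×5=100 个参数减为 20×3=60 个，加上 $W_{hr}$ 的 3×5=15 个，共 75 个。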