In [39]:
import torch
import torch.nn as nn

# Problem dimensions shared by every cell below.
batch_size = 100
seq_len = 1
input_size = 6
h_size = 32

# Random input sequence, laid out batch-first: [batch_size, seq_len, input_size].
input_6666 = torch.randn(batch_size, seq_len, input_size)

# Initial cell / hidden states, [batch_size, h_size]. They are plain tensors
# (not parameters), so they do not take part in training.
c_0 = torch.randn(batch_size, h_size)
h_0 = torch.randn(batch_size, h_size)


In [40]:
# Reference result from PyTorch's built-in LSTM (num_layers defaults to 1).
lstm_layer = nn.LSTM(input_size=input_size, hidden_size=h_size, batch_first=True)
# nn.LSTM expects initial states shaped (D*num_layers=1, b, hidden_size),
# so a leading axis is added to the 2-D h_0 / c_0.
output, (h_n, c_n) = lstm_layer(
    input_6666,
    (h_0[None], c_0[None]),
)


In [41]:
# One shape per line, in the order output / h_n / c_n.
print(
    output.shape,  # [b, seq_len, hidden_size]
    h_n.shape,     # [num_layers, b, hidden_size]
    c_n.shape,     # [num_layers, b, hidden_size]
    sep="\n",
)


torch.Size([100, 1, 32])
torch.Size([1, 100, 32])
torch.Size([1, 100, 32])


In [42]:
# List every learnable tensor of the layer together with its shape.
for name, para in lstm_layer.named_parameters():
    print(f"{name} {para.shape}")


weight_ih_l0 torch.Size([128, 6])
weight_hh_l0 torch.Size([128, 32])
bias_ih_l0 torch.Size([128])
bias_hh_l0 torch.Size([128])


In [43]:
def lstm_forward(input_6666, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Hand-written forward pass of a single-layer, unidirectional LSTM.

    Mirrors ``nn.LSTM(..., batch_first=True)`` when given that layer's
    parameters.

    Args:
        input_6666: input sequence, [b, seq_len, input_size].
        initial_states: tuple ``(h_0, c_0)``, each [b, hidden_size].
        w_ih: input-to-hidden weights, [4*hidden_size, input_size]; the four
            gate blocks are stacked along dim 0 in the order i | f | g | o.
        w_hh: hidden-to-hidden weights, [4*hidden_size, hidden_size], same
            gate order.
        b_ih: input-to-hidden bias, [4*hidden_size].
        b_hh: hidden-to-hidden bias, [4*hidden_size].

    Returns:
        ``(output, (h_n, c_n))`` where ``output`` is [b, seq_len, hidden_size]
        and ``h_n`` / ``c_n`` are [1, b, hidden_size] (a leading axis is added
        so the result has the same rank as the official API's states).
    """
    h_prev, c_prev = initial_states  # running states, each [b, hidden_size]
    b_size, seq_len, _ = input_6666.shape
    h_size = h_prev.shape[-1]

    # Both biases are always added together, so sum them once outside the loop.
    bias = b_ih + b_hh  # [4*hidden_size]

    # Allocate with the input's dtype and device; a bare torch.zeros would
    # silently fall back to the default dtype on the default device.
    output = input_6666.new_zeros(b_size, seq_len, h_size)

    for t in range(seq_len):
        x_t = input_6666[:, t, :]  # [b, input_size]

        # [b, in] @ [in, 4h] + [b, h] @ [h, 4h] + [4h] -> [b, 4h].
        # A plain matmul broadcasts the shared weights over the batch, so no
        # per-sample copy of the weight matrices is needed.
        gates = x_t @ w_ih.t() + h_prev @ w_hh.t() + bias

        # Split into input (i), forget (f), cell (g) and output (o) gate
        # pre-activations, each [b, hidden_size].
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)
        i_t = torch.sigmoid(i_pre)
        f_t = torch.sigmoid(f_pre)
        g_t = torch.tanh(g_pre)
        o_t = torch.sigmoid(o_pre)

        c_prev = f_t * c_prev + i_t * g_t
        h_prev = o_t * torch.tanh(c_prev)

        output[:, t, :] = h_prev

    return output, (h_prev.unsqueeze(0), c_prev.unsqueeze(0))


In [44]:
# Feed the hand-written LSTM the parameters held by lstm_layer.
# The "_me" suffix marks results produced by the hand-written version.
output_me, (h_n_me, c_n_me) = lstm_forward(
    input_6666,
    (h_0, c_0),
    w_ih=lstm_layer.weight_ih_l0,
    w_hh=lstm_layer.weight_hh_l0,
    b_ih=lstm_layer.bias_ih_l0,
    b_hh=lstm_layer.bias_hh_l0,
)


In [45]:
# Dump both result sets so they can be compared by eye.
# output is [b, seq_len, hidden_size]; h_n / c_n are [num_layers, b, hidden_size].
print("PyTorch API output:")
print(output, h_n, c_n, sep="\n")
print("\nlstm_forward function output:")
print(output_me, h_n_me, c_n_me, output_me.shape, sep="\n")


PyTorch API output:
tensor([[[ 0.2675, -0.1014,  0.0701,  ..., -0.4516, -0.4680,  0.3969]],

        [[-0.2050, -0.1308,  0.3732,  ...,  0.1511, -0.3707,  0.3849]],

        [[-0.0065, -0.0293,  0.0346,  ...,  0.4093,  0.1275,  0.0865]],

        ...,

        [[-0.1264,  0.1540, -0.0821,  ...,  0.3826,  0.5380, -0.2471]],

        [[-0.0446, -0.1559,  0.0485,  ...,  0.1540, -0.3700, -0.1222]],

        [[-0.3258, -0.3508, -0.1058,  ..., -0.1239, -0.3522, -0.3892]]],
       grad_fn=<TransposeBackward0>)
tensor([[[ 0.2675, -0.1014,  0.0701,  ..., -0.4516, -0.4680,  0.3969],
         [-0.2050, -0.1308,  0.3732,  ...,  0.1511, -0.3707,  0.3849],
         [-0.0065, -0.0293,  0.0346,  ...,  0.4093,  0.1275,  0.0865],
         ...,
         [-0.1264,  0.1540, -0.0821,  ...,  0.3826,  0.5380, -0.2471],
         [-0.0446, -0.1559,  0.0485,  ...,  0.1540, -0.3700, -0.1222],
         [-0.3258, -0.3508, -0.1058,  ..., -0.1239, -0.3522, -0.3892]]],
       grad_fn=<StackBackward0>)
tensor([[[ 1.053

In [None]:
#自己写一个LSTM模型
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Forward pass of a single-layer, unidirectional LSTM written by hand.

    NOTE: this definition reuses the name ``lstm_forward``; running the cell
    replaces any earlier function of that name in the kernel.

    Args:
        input: input sequence, [bs, T, i_size]. (The name shadows the builtin
            ``input`` inside this function; it is kept for caller
            compatibility.)
        initial_states: tuple ``(h0, c0)``, each [bs, h_size].
        w_ih: input-to-hidden weights, [4*h_size, i_size]; gate blocks are
            stacked along dim 0 in the order i | f | g | o.
        w_hh: hidden-to-hidden weights, [4*h_size, h_size], same gate order.
        b_ih: input-to-hidden bias, [4*h_size].
        b_hh: hidden-to-hidden bias, [4*h_size].

    Returns:
        ``(output, (h_T, c_T))`` where ``output`` is [bs, T, h_size] and the
        final states are [bs, h_size]. No leading num_layers axis is added
        here, so callers comparing against ``nn.LSTM`` states must
        ``unsqueeze(0)`` themselves.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 4  # w_ih stacks the four gate blocks on dim 0

    # The two biases only ever appear as a sum; compute it once.
    bias = b_ih + b_hh  # [4*h_size]

    # Allocate with the input's dtype and device; a bare torch.zeros would
    # silently fall back to the default dtype on the default device.
    output = input.new_zeros(bs, T, h_size)

    for t in range(T):
        x = input[:, t, :]  # input vector at the current step, [bs, i_size]

        # [bs, i] @ [i, 4h] + [bs, h] @ [h, 4h] + [4h] -> [bs, 4h].
        # matmul shares the weights across the batch, so the weight matrices
        # do not have to be tiled bs times.
        gates = x @ w_ih.t() + prev_h @ w_hh.t() + bias

        # Input (i), forget (f), cell (g) and output (o) gates, each [bs, h_size].
        i_t = torch.sigmoid(gates[:, :h_size])
        f_t = torch.sigmoid(gates[:, h_size:2 * h_size])
        g_t = torch.tanh(gates[:, 2 * h_size:3 * h_size])
        o_t = torch.sigmoid(gates[:, 3 * h_size:])

        # Update the memory cell, then derive the new hidden state from it.
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)
        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)




# Run the hand-written LSTM on the same input, initial states and parameters
# as the official layer. The original call passed the builtin `input` and the
# undefined names `h0` / `c0`; the tensors created in this notebook are
# `input_6666`, `h_0` and `c_0`.
output_custom, (h_final_custom, c_final_custom) = lstm_forward(
    input_6666,
    (h_0, c_0),
    lstm_layer.weight_ih_l0,
    lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0,
    lstm_layer.bias_hh_l0,
)

# Official result first, hand-written result second, for comparison by eye.
print(output)
print(output_custom)
