In [1]:
import torch
import torch.nn as nn

In [8]:
# Seed the global RNG so the random inputs and the LSTM's randomly initialised
# weights are identical on every Restart & Run All.
torch.manual_seed(0)

# Toy dimensions: batch size, sequence length, input feature size, hidden size.
bs, T, i_size, h_size = 2, 3, 4, 5
# NOTE: `input` shadows the Python builtin; the name is kept because the later
# cell that calls lstm_forward refers to it.
input = torch.randn(bs, T, i_size)
c0 = torch.randn(bs, h_size)
h0 = torch.randn(bs, h_size)

lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# The initial states are passed with an extra leading axis (unsqueeze(0)),
# turning the [bs, h_size] tensors into [1, bs, h_size].
output, (hn, cn) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output)
print(hn)
print(cn)

tensor([[[-0.0467, -0.2262,  0.1009,  0.0181, -0.4716],
         [ 0.0770, -0.1335, -0.0828,  0.1359, -0.2228],
         [ 0.1128, -0.1604, -0.1379,  0.4182,  0.0183]],

        [[ 0.4830,  0.3796,  0.0830, -0.0610, -0.0134],
         [ 0.0589,  0.2820, -0.1063, -0.1542, -0.1215],
         [ 0.0594,  0.1878, -0.1361, -0.1986, -0.1460]]],
       grad_fn=<TransposeBackward0>)
tensor([[[ 0.1128, -0.1604, -0.1379,  0.4182,  0.0183],
         [ 0.0594,  0.1878, -0.1361, -0.1986, -0.1460]]],
       grad_fn=<StackBackward0>)
tensor([[[ 0.3187, -0.3272, -0.3735,  0.6051,  0.0393],
         [ 0.1298,  0.4271, -0.2945, -0.3860, -0.2693]]],
       grad_fn=<StackBackward0>)


In [10]:
# Walk over the layer's named parameters, printing each tensor followed by its shape.
for param_name, param in lstm_layer.named_parameters():
    print(param_name, param)
    print(param_name, param.shape)

weight_ih_l0 Parameter containing:
tensor([[ 0.4169, -0.0408,  0.2963, -0.3897],
        [ 0.2836, -0.2703, -0.3928,  0.3190],
        [ 0.0222,  0.1646,  0.3935, -0.3035],
        [ 0.3735, -0.4393, -0.3617, -0.3809],
        [ 0.3914,  0.0273,  0.3561, -0.1903],
        [-0.2129,  0.3855,  0.3294, -0.4384],
        [ 0.1983,  0.1146, -0.2064,  0.3244],
        [-0.0149,  0.0017,  0.1221,  0.0919],
        [ 0.2423,  0.4331, -0.0146,  0.0933],
        [ 0.0081, -0.1136,  0.1429,  0.2141],
        [-0.1832,  0.0664, -0.1759,  0.2322],
        [ 0.0890, -0.1788,  0.2994, -0.2285],
        [-0.1244, -0.1792,  0.2920,  0.1311],
        [-0.4446,  0.2604, -0.1670, -0.1984],
        [-0.1053,  0.2581, -0.2646, -0.1538],
        [-0.1556, -0.4160,  0.0323,  0.3732],
        [-0.2581, -0.0085,  0.3566, -0.2454],
        [ 0.3432, -0.0930, -0.1764, -0.0991],
        [-0.3629,  0.3339, -0.3523,  0.4205],
        [ 0.0654,  0.1552,  0.1192,  0.0624]], requires_grad=True)
weight_ih_l0 torch.Size(

In [17]:
# i_s is the tuple (h0, c0)
def lstm_forward(input, i_s, w_ih, w_hh, b_ih=None, b_hh=None):
    """Hand-written forward pass of a single-layer, unidirectional, batch-first LSTM.

    Args:
        input: Tensor of shape [bs, T, i_size].
        i_s: Tuple ``(h0, c0)`` of initial hidden and cell states, each of
            shape [bs, h_size] (no leading layer axis).
        w_ih: Input-to-hidden weights, shape [4*h_size, i_size], with the four
            gates stacked along dim 0 in the order input, forget, cell, output.
        w_hh: Hidden-to-hidden weights, shape [4*h_size, h_size], same gate order.
        b_ih: Optional input-to-hidden bias, shape [4*h_size]. ``None`` means no bias.
        b_hh: Optional hidden-to-hidden bias, shape [4*h_size]. ``None`` means no bias.

    Returns:
        ``(output, (h_n, c_n))`` where ``output`` has shape [bs, T, h_size] and
        holds the hidden state at every time step, and ``h_n`` / ``c_n`` are the
        final hidden / cell states of shape [bs, h_size].
    """
    h0, c0 = i_s
    bs, T = input.shape[0], input.shape[1]
    h_size = w_ih.shape[0] // 4

    # The input projection does not depend on the recurrent state, so compute
    # it for every time step in one matmul instead of once per loop iteration.
    x_proj = input @ w_ih.t()  # [bs, T, 4*h_size]
    if b_ih is not None:
        x_proj = x_proj + b_ih
    if b_hh is not None:
        x_proj = x_proj + b_hh
    w_hh_t = w_hh.t()  # [h_size, 4*h_size]

    prev_h, prev_c = h0, c0
    hidden_steps = []
    for t in range(T):
        gates = x_proj[:, t, :] + prev_h @ w_hh_t  # [bs, 4*h_size]

        # Split into input gate (i), forget gate (f), cell candidate (g), output gate (o).
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=-1)
        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        g_t = torch.tanh(g_t)
        o_t = torch.sigmoid(o_t)

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)
        hidden_steps.append(prev_h)

    if hidden_steps:
        # Stacking the per-step states keeps their dtype and device, unlike
        # writing into a pre-allocated default float32 CPU buffer.
        output = torch.stack(hidden_steps, dim=1)
    else:
        # Empty sequence: return an empty output and the untouched initial states.
        output = h0.new_zeros((bs, 0, h_size))
    return output, (prev_h, prev_c)

In [18]:
# Run the hand-written LSTM with the nn.LSTM layer's own weights and biases;
# the cell displays the returned (output, (h_n, c_n)) tuple.
lstm_forward(
    input,
    (h0, c0),
    lstm_layer.weight_ih_l0,
    lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0,
    lstm_layer.bias_hh_l0,
)

(tensor([[[-0.0467, -0.2262,  0.1009,  0.0181, -0.4716],
          [ 0.0770, -0.1335, -0.0828,  0.1359, -0.2228],
          [ 0.1128, -0.1604, -0.1379,  0.4182,  0.0183]],
 
         [[ 0.4830,  0.3796,  0.0830, -0.0610, -0.0134],
          [ 0.0589,  0.2820, -0.1063, -0.1542, -0.1215],
          [ 0.0594,  0.1878, -0.1361, -0.1986, -0.1460]]], grad_fn=<CopySlices>),
 (tensor([[ 0.1128, -0.1604, -0.1379,  0.4182,  0.0183],
          [ 0.0594,  0.1878, -0.1361, -0.1986, -0.1460]], grad_fn=<MulBackward0>),
  tensor([[ 0.3187, -0.3272, -0.3735,  0.6051,  0.0393],
          [ 0.1298,  0.4271, -0.2945, -0.3860, -0.2693]], grad_fn=<AddBackward0>)))

In [1]:
# NOTE(review): presumably a leftover from ending the interactive session —
# confirm it is intentional, since anything placed after this call would not
# be expected to run.
quit()