LSTM实现

In [1]:
import torch
import torch.nn as nn

class lstm_m(nn.Module):
    """Single-layer LSTM written out gate by gate (for teaching/comparison).

    Implements the gate equations documented for ``torch.nn.LSTM``::

        i_t = sigmoid(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)
        f_t = sigmoid(W_if x_t + b_if + W_hf h_{t-1} + b_hf)
        g_t = tanh   (W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)
        o_t = sigmoid(W_io x_t + b_io + W_ho h_{t-1} + b_ho)
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Args:
        input_size: length of the feature vector fed in at each time step.
        hidden_size: length of both the hidden state and the cell state.
    """

    def __init__(self, input_size, hidden_size) -> None:
        super().__init__()
        # Length of the per-step input vector.
        self.input_size = input_size
        # Length of the cell state and the hidden state.
        self.hidden_size = hidden_size

        # Input-to-hidden (W_i*) and hidden-to-hidden (W_h*) projections for the
        # input (i), forget (f), candidate (g) and output (o) gates. Every
        # nn.Linear has its own bias, mirroring nn.LSTM's separate b_i* / b_h*.
        self.Wii = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.Whi = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.Wif = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.Whf = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.Wig = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.Whg = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.Wio = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.Who = nn.Linear(in_features=hidden_size, out_features=hidden_size)

    def forward(self, x, hx=None):
        """Run the LSTM over a batch-first input sequence.

        Args:
            x: tensor of shape [batch_size, seq_len, input_size].
            hx: optional ``(h_0, c_0)``, each of shape [batch_size, hidden_size].
                Defaults to zeros, matching ``nn.LSTM``.

        Returns:
            ``(output, (h_n, c_n))`` where ``output`` has shape
            [seq_len, batch_size, hidden_size] (time-major) and ``h_n``/``c_n``
            have shape [batch_size, hidden_size].
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if hx is None:
            # Deterministic zero state created on x's device/dtype.
            ht = x.new_zeros(batch_size, self.hidden_size)
            ct = x.new_zeros(batch_size, self.hidden_size)
        else:
            ht, ct = hx

        outputs = []
        for t in range(seq_len):
            xt = x[:, t, :]
            it = torch.sigmoid(self.Wii(xt) + self.Whi(ht))
            # The forget gate needs its own recurrent weight W_hf.
            ft = torch.sigmoid(self.Wif(xt) + self.Whf(ht))
            gt = torch.tanh(self.Wig(xt) + self.Whg(ht))
            ot = torch.sigmoid(self.Wio(xt) + self.Who(ht))
            ct = ft * ct + it * gt
            # The hidden state squashes the cell state with tanh, not sigmoid.
            ht = ot * torch.tanh(ct)
            outputs.append(ht)

        if not outputs:
            # Empty sequence: nothing to stack; return an empty time axis.
            lstm_output = x.new_zeros(0, batch_size, self.hidden_size)
        else:
            # Stack along a new leading time axis. Concatenating on dim=1 and
            # then view()-ing to [T, B, H] would mix up batch and time.
            lstm_output = torch.stack(outputs, dim=0)
        return lstm_output, (ht, ct)
# Seed so the random demo input and the model's initial weights are reproducible.
torch.manual_seed(0)
x = torch.rand([4, 100, 128])  # [batch_size, seq_len, input_size]
model = lstm_m(128, 256)       # input_size=128, hidden_size=256
output, (hn, cn) = model(x)
print(output.shape)  # [seq_len, batch_size, hidden_size]
print(hn.shape)      # [batch_size, hidden_size]
print(cn.shape)      # [batch_size, hidden_size]
# Inline checks so Restart & Run All fails loudly if the shapes change.
assert tuple(output.shape) == (100, 4, 256)
assert tuple(hn.shape) == tuple(cn.shape) == (4, 256)

torch.Size([100, 4, 256])
torch.Size([4, 256])
torch.Size([4, 256])


nn.Linear: $y = xA^T+b$，其中 $A$ 即 `layer.weight`，形状为 `[out_features, in_features]`；$b$ 即 `layer.bias`，形状为 `[out_features]`。

In [3]:
layer = nn.Linear(in_features=10, out_features=20)
print(layer.weight.size())  # weight A: [out_features, in_features]
print(layer.bias.size())    # bias b:   [out_features]

# Replace the parameters with known tensors so the layer output can be
# reproduced by hand from the formula.
w, b = torch.rand(20, 10), torch.rand(20)
layer.weight, layer.bias = nn.Parameter(w), nn.Parameter(b)

x = torch.rand(4, 10)
y1 = layer(x)                    # nn.Linear forward pass
y2 = torch.matmul(x, w.t()) + b  # y = x A^T + b, computed manually
print((y1 - y2).abs().sum())     # total absolute difference between the two

torch.Size([20, 10])
torch.Size([20])
tensor(0., grad_fn=<SumBackward0>)


基于注意力的LSTM

In [5]:
torch.manual_seed(0)  # reproducible demo input and LSTM weights
x = torch.rand([16, 100, 128])  # batch_size=16, seq_len=100, feature_dim=128
lstm = nn.LSTM(128, 256, batch_first=True)  # input_size=128, hidden_size=256
lstm_out, (hn, cn) = lstm(x)
print(lstm_out.shape)  # [batch_size, seq_len, hidden_size]
print(hn.shape)        # [num_layers * num_directions, batch_size, hidden_size]
print(cn.shape)        # [num_layers * num_directions, batch_size, hidden_size]

# Method 1: score each time step of tanh(lstm_out) against a randomly
# initialised, learnable vector w_omega. It is shared by every sample
# ([hidden_size, 1]) so the parameter does not depend on the batch size.
w_omega = torch.randn(lstm.hidden_size, 1, requires_grad=True)
scores_1 = torch.tanh(lstm_out) @ w_omega          # [batch_size, seq_len, 1]
attn_weight_1 = torch.softmax(scores_1, dim=1)     # normalised over seq_len
context_1 = (lstm_out * attn_weight_1).sum(dim=1)  # [batch_size, hidden_size]

# Method 2: use the final hidden state as the query for every time step.
# hn[-1] picks the last layer, so this also works for stacked (unidirectional) LSTMs.
# NOTE(review): unscaled dot product; scaled dot-product attention would divide
# the scores by sqrt(hidden_size) to keep the softmax from saturating.
query = hn[-1].unsqueeze(-1)                                    # [batch_size, hidden_size, 1]
attn_weight = torch.softmax(torch.bmm(lstm_out, query), dim=1)  # [batch_size, seq_len, 1]
context = (lstm_out * attn_weight).sum(dim=1)                   # [batch_size, hidden_size]


torch.Size([16, 100, 256])
torch.Size([1, 16, 256])
torch.Size([1, 16, 256])
