In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Seed the global RNG right before the embedding layer is created, so its
# randomly initialised weights (and every tensor displayed further down) are
# the same after each Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Toy embedding layer: `vocab_size` distinct token ids, each mapped to a
# vector with `embedding_dim` components.
vocab_size = 10
embedding_dim = 5
embedding = nn.Embedding(vocab_size, embedding_dim)

# Three sentences of different lengths, already encoded as token ids:
# [1, 2, 3], [4, 5] and [6, 7, 8, 9].
tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6, 7, 8, 9])]

# Look up the embeddings one sentence at a time; each list entry has shape
# (sentence_length, embedding_dim).
embedded_tokens = [embedding(token) for token in tokens]

In [6]:
# Show the list of per-sentence embedding tensors (one tensor per sentence,
# one row per token, embedding_dim columns).
embedded_tokens

[tensor([[-0.3956,  0.4439,  0.5085, -0.5840,  0.8709],
         [ 0.2143,  1.6778,  0.0112, -0.0320, -0.3915],
         [ 0.9667, -0.9229,  0.0081,  1.8913, -0.6835]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[ 1.3310, -0.3161,  0.5835,  0.8231, -0.1400],
         [ 1.6454, -0.3439, -1.3837,  0.8876,  1.4397]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[-0.9114, -1.2176,  0.2858, -0.5130,  0.0268],
         [-0.0803,  0.6734, -1.3797, -0.8677,  0.9274],
         [-1.1376, -0.3452,  0.3709,  0.3842,  0.8517],
         [-2.8832,  1.1551,  0.2756,  0.0649, -0.1000]],
        grad_fn=<EmbeddingBackward0>)]

In [3]:
# Pad the shorter embedded sentences with zero vectors so that all of them can
# be stacked into one batch-first tensor: (batch, max_seq_len, embedding_dim).
padded_inputs = nn.utils.rnn.pad_sequence(
    embedded_tokens,
    padding_value=0,
    batch_first=True,
)



In [7]:
# Show the padded batch; positions past a sentence's real length are all-zero rows.
padded_inputs   

tensor([[[-0.3956,  0.4439,  0.5085, -0.5840,  0.8709],
         [ 0.2143,  1.6778,  0.0112, -0.0320, -0.3915],
         [ 0.9667, -0.9229,  0.0081,  1.8913, -0.6835],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 1.3310, -0.3161,  0.5835,  0.8231, -0.1400],
         [ 1.6454, -0.3439, -1.3837,  0.8876,  1.4397],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.9114, -1.2176,  0.2858, -0.5130,  0.0268],
         [-0.0803,  0.6734, -1.3797, -0.8677,  0.9274],
         [-1.1376, -0.3452,  0.3709,  0.3842,  0.8517],
         [-2.8832,  1.1551,  0.2756,  0.0649, -0.1000]]], grad_fn=<CopySlices>)

In [None]:
# Valid (unpadded) length of every sequence, read from the tensors that were
# padded instead of being typed in by hand, so it cannot drift out of sync
# with the data. Building it from Python ints yields the CPU integer tensor
# that pack_padded_sequence expects.
lengths = torch.tensor([seq.size(0) for seq in embedded_tokens])

# Pack the padded batch together with the true lengths so the LSTM does not
# process the padding positions. enforce_sorted=False accepts a batch that is
# not ordered by decreasing length.
packed_input = pack_padded_sequence(padded_inputs, lengths, batch_first=True, enforce_sorted=False)

# Single-layer LSTM that consumes embedding_dim features per time step
hidden_size = 10
lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)

# Run the LSTM on the packed batch; h_n and c_n are the final hidden and cell states
packed_output, (h_n, c_n) = lstm(packed_input)

# Unpack back into a dense, zero-padded (batch, max_seq_len, hidden_size) tensor
output, _ = pad_packed_sequence(packed_output, batch_first=True)

# Sanity check: one row per sentence, padded to the longest one, hidden_size features
assert output.shape == (len(embedded_tokens), int(lengths.max()), hidden_size)

print("Output shape:", output.shape)
