In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Toy setup: an embedding table with 10 entries (the assumed vocabulary size),
# each mapped to an `embedding_dim`-sized vector. Id 0 is registered as the
# padding index.
embedding_dim = 5
embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim, padding_idx=0)

# Three sentences that have already been encoded as token ids:
#   [1, 2, 3], [4, 5], [6, 7, 8, 9]
# Shorter rows are right-padded with 0 so the batch forms a 3 x 4 tensor.
tokens = torch.tensor(
    [
        [1, 2, 3, 0],
        [4, 5, 0, 0],
        [6, 7, 8, 9],
    ],
    dtype=torch.int32,
)

# Look up the embedding vector for every token id (padding slots included),
# giving a tensor of shape (3, 4, embedding_dim).
embedded_tokens = embedding(tokens)

In [2]:
# Display the embedded batch. Positions holding the padding id (0) appear as
# all-zero rows because the embedding layer was created with padding_idx=0.
embedded_tokens

tensor([[[-2.3816,  0.8971,  0.9209, -0.1366, -0.0108],
         [-0.7181, -1.2318,  0.2207, -1.4396,  0.4422],
         [-2.0836, -0.1481,  1.2221, -0.7765,  0.9815],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.3544, -0.3383,  0.1941,  0.3798,  0.6775],
         [ 0.2916,  0.6503,  0.8309, -1.5693,  1.0405],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.0277,  0.2214, -0.1070,  1.3576, -2.7300],
         [ 0.3351,  0.3744,  1.3038, -0.7976,  0.5450],
         [ 1.6915, -0.6120,  1.9171, -1.1943, -0.7973],
         [-2.2708, -0.7220, -0.8340,  1.7299,  1.3875]]],
       grad_fn=<EmbeddingBackward0>)

In [3]:
# The embedded batch is already padded to a common length, so it is used
# directly (same tensor object, new name) as the padded input for packing.
padded_inputs = embedded_tokens
padded_inputs

tensor([[[-2.3816,  0.8971,  0.9209, -0.1366, -0.0108],
         [-0.7181, -1.2318,  0.2207, -1.4396,  0.4422],
         [-2.0836, -0.1481,  1.2221, -0.7765,  0.9815],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.3544, -0.3383,  0.1941,  0.3798,  0.6775],
         [ 0.2916,  0.6503,  0.8309, -1.5693,  1.0405],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.0277,  0.2214, -0.1070,  1.3576, -2.7300],
         [ 0.3351,  0.3744,  1.3038, -0.7976,  0.5450],
         [ 1.6915, -0.6120,  1.9171, -1.1943, -0.7973],
         [-2.2708, -0.7220, -0.8340,  1.7299,  1.3875]]],
       grad_fn=<EmbeddingBackward0>)

In [4]:
# Effective (unpadded) length of every sequence. It is derived from the token
# ids instead of being typed by hand, so it cannot drift out of sync with the
# padding in `tokens`. Counting non-padding ids is valid here because padding
# only ever appears at the end of a row. For this batch the result is
# [3, 2, 4]. pack_padded_sequence requires tensor lengths to live on the CPU.
lengths = (tokens != embedding.padding_idx).sum(dim=1).cpu()

# Pack the padded batch together with its true lengths. enforce_sorted=False
# is needed because the rows are not ordered by decreasing length.
packed_input = pack_padded_sequence(padded_inputs, lengths, batch_first=True, enforce_sorted=False)



In [5]:
# Single LSTM layer: consumes `embedding_dim`-sized vectors and produces a
# 10-dimensional hidden state; batch_first matches the layout used above.
lstm = nn.LSTM(embedding_dim, 10, batch_first=True)

# Run the packed batch through the LSTM. The result is again a packed
# sequence, alongside the final hidden and cell states.
packed_output, (h_n, c_n) = lstm(packed_input)

# Unpack back into a regular padded tensor (batch dimension first); the
# second return value (the per-sequence lengths) is discarded.
output, _ = pad_packed_sequence(sequence=packed_output, batch_first=True)

print("Output shape:", output.shape)


Output shape: torch.Size([3, 4, 10])


In [8]:
# Shape of the packed data tensor: one row per real (non-padding) time step
# across the batch (3 + 2 + 4 = 9), each holding the 10-dim LSTM output.
packed_output.data.shape

torch.Size([9, 10])