In [48]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

创建初始数据 (7 条长度各不相同的序列), 后面会使用`pad_sequence`将它们填充成等长的数据

In [49]:
# Seven 1-D float sequences of decreasing length: the k-th one holds the
# values k, k+1, ..., 7, so no two samples have the same length.
train_x = [torch.Tensor(list(range(first, 8))) for first in range(1, 8)]

In [50]:
class MyData(Dataset):
    """Minimal map-style dataset that serves samples from an in-memory sequence."""

    def __init__(self, train_x):
        """Store a reference to ``train_x``; nothing is copied or converted."""
        self.train_x = train_x

    def __getitem__(self, item):
        """Return whatever is stored at position ``item`` of ``train_x``."""
        return self.train_x[item]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.train_x)


train_data = MyData(train_x)
# No collate_fn is passed here, so the DataLoader uses its default collation.
train_dataloader = DataLoader(train_data, batch_size=2)

# The samples have different lengths, so iterating fails with a RuntimeError
# about mismatched tensor sizes. That failure is the point of this cell, so it
# is caught and displayed instead of stopping a full re-run of the notebook.
try:
    for i in train_dataloader:
        print(i)
except RuntimeError as err:
    print("RuntimeError:", err)

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 7 and 6 in dimension 1 at /tmp/pip-req-build-58y_cjjl/aten/src/TH/generic/THTensor.cpp:689

上面的报错是因为同一个 batch 中的样本长度不一致, 默认的整理方式无法把它们拼成一个张量. 因此这里定义一个`collate_fn`函数: 先对数据按长度从大到小排序, 再提取每条数据的实际长度, 最后对数据进行填充, 这些准备是用于后面的`pack_padded_sequence()`

In [51]:
def collate_fn(train_data):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        train_data: list of tensors making up one batch (as handed over by
            the ``DataLoader``). The list itself is not modified.

    Returns:
        A ``(padded, data_length)`` tuple. ``padded`` is the result of
        ``pad_sequence(..., batch_first=True, padding_value=0)`` applied to
        the samples ordered from longest to shortest; for the 1-D samples in
        this notebook that is a ``(batch, max_len)`` tensor. ``data_length``
        is a list with the unpadded length of each row, in the same order.
    """
    # sorted() returns a new list, so the caller's batch keeps its order
    # (list.sort() would reorder it in place).
    ordered = sorted(train_data, key=len, reverse=True)
    data_length = [len(sample) for sample in ordered]
    padded = rnn_utils.pad_sequence(ordered, batch_first=True, padding_value=0)
    return padded, data_length

将`collate_fn`函数传入 `DataLoader` 函数

In [52]:
# Rebuild the dataset and loader, this time routing every batch through collate_fn.
train_data = MyData(train_x)
train_dataloader = DataLoader(dataset=train_data, batch_size=2, collate_fn=collate_fn)

# Each batch is the (padded tensor, list of true lengths) pair returned by collate_fn.
for data, length in train_dataloader:
    print(data, length, sep="\n")

tensor([[1., 2., 3., 4., 5., 6., 7.],
        [2., 3., 4., 5., 6., 7., 0.]])
[7, 6]
tensor([[3., 4., 5., 6., 7.],
        [4., 5., 6., 7., 0.]])
[5, 4]
tensor([[5., 6., 7.],
        [6., 7., 0.]])
[3, 2]
tensor([[7.]])
[1]


这里使用`pack_padded_sequence`对填充后的数据进行打包 (压缩), 去掉填充的 0, 只保留每条序列的有效部分

In [53]:
for data, length in train_dataloader:
    # Pack the padded batch; `length` gives the number of real (unpadded)
    # entries in each row, longest row first.
    data = rnn_utils.pack_padded_sequence(input=data, lengths=length, batch_first=True)
    print(data)

PackedSequence(data=tensor([1., 2., 2., 3., 3., 4., 4., 5., 5., 6., 6., 7., 7.]), batch_sizes=tensor([2, 2, 2, 2, 2, 2, 1]), sorted_indices=None, unsorted_indices=None)
PackedSequence(data=tensor([3., 4., 4., 5., 5., 6., 6., 7., 7.]), batch_sizes=tensor([2, 2, 2, 2, 1]), sorted_indices=None, unsorted_indices=None)
PackedSequence(data=tensor([5., 6., 6., 7., 7.]), batch_sizes=tensor([2, 2, 1]), sorted_indices=None, unsorted_indices=None)
PackedSequence(data=tensor([7.]), batch_sizes=tensor([1]), sorted_indices=None, unsorted_indices=None)


定义好`LSTM`, 然后将数据输入`LSTM`中, 看输出的第一个batch的数据格式

In [57]:
# Seed the global RNG first so the randomly initialised LSTM weights, and
# therefore the outputs printed by the cells that call `net`, are the same on
# every fresh run of the notebook.
torch.manual_seed(0)
# One input feature per time step, five hidden units; batch_first=True means
# padded inputs/outputs are laid out as (batch, seq, feature).
net = nn.LSTM(input_size=1, hidden_size=5, batch_first=True)

In [55]:
def collate_fn(train_data):
    """Collate variable-length samples into a zero-padded batch with a feature axis.

    Same contract as the earlier ``collate_fn`` in this notebook, except that
    a trailing dimension of size 1 is appended to the padded tensor so it
    matches the ``input_size=1`` expected by the LSTM.

    Args:
        train_data: list of tensors making up one batch (as handed over by
            the ``DataLoader``). The list itself is not modified.

    Returns:
        A ``(padded, data_length)`` tuple. ``padded`` is the zero-padded
        batch, longest sample first, with ``unsqueeze(-1)`` applied; for the
        1-D samples in this notebook that is ``(batch, max_len, 1)``.
        ``data_length`` is a list with the unpadded length of each row, in
        the same order.
    """
    # sorted() returns a new list, so the caller's batch keeps its order
    # (list.sort() would reorder it in place).
    ordered = sorted(train_data, key=len, reverse=True)
    data_length = [len(sample) for sample in ordered]
    padded = rnn_utils.pad_sequence(ordered, batch_first=True, padding_value=0)
    return padded.unsqueeze(-1), data_length

In [59]:
train_data = MyData(train_x)
train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=collate_fn)

# Every batch is run through the LSTM, but only the first batch's output is
# displayed; enumerate() replaces a hand-maintained "already printed" flag.
for batch_idx, (data, length) in enumerate(train_dataloader):
    data = rnn_utils.pack_padded_sequence(data, length, batch_first=True)
    output, hidden = net(data)
    if batch_idx == 0:
        # The input was packed, so `output` is printed as a PackedSequence.
        print(output)

PackedSequence(data=tensor([[-0.0359, -0.0036,  0.0825,  0.1019, -0.1004],
        [ 0.0155,  0.0222,  0.0926,  0.1369, -0.0548],
        [-0.0054,  0.0196,  0.1241,  0.1759, -0.1449],
        [ 0.0495,  0.0504,  0.1263,  0.2017, -0.0534],
        [ 0.0374,  0.0475,  0.1405,  0.2131, -0.1426],
        [ 0.0729,  0.0720,  0.1338,  0.2225,  0.0114],
        [ 0.0656,  0.0693,  0.1410,  0.2237, -0.0812],
        [ 0.0792,  0.0866,  0.1280,  0.2228,  0.1560],
        [ 0.0743,  0.0844,  0.1319,  0.2203,  0.0601],
        [ 0.0737,  0.0962,  0.1156,  0.2154,  0.3757],
        [ 0.0701,  0.0946,  0.1179,  0.2117,  0.2878],
        [ 0.0630,  0.1021,  0.1004,  0.2058,  0.6103],
        [ 0.0604,  0.1011,  0.1020,  0.2022,  0.5502]], grad_fn=<CatBackward>), batch_sizes=tensor([2, 2, 2, 2, 2, 2, 1]), sorted_indices=None, unsorted_indices=None)


将输出的数据通过`pad_packed_sequence`还原成填充后的等长张量, 可以看到输出的形状为`(batch, seq_len, hidden_size)`, 较短序列末尾的填充位置输出全为 0, 符合我们的要求

In [64]:
train_data = MyData(train_x)
train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=collate_fn)

# Every batch is run through the LSTM, but only the first batch's output is
# unpacked and displayed; enumerate() replaces a hand-maintained flag.
for batch_idx, (data, length) in enumerate(train_dataloader):
    data = rnn_utils.pack_padded_sequence(data, length, batch_first=True)
    output, hidden = net(data)
    if batch_idx == 0:
        # Convert the packed output back into a dense padded tensor. The
        # second return value is kept only for unpacking
        # (per the PyTorch docs it holds the sequence lengths).
        output, out_len = rnn_utils.pad_packed_sequence(output, batch_first=True)
        print(output.shape)
        print(output)

torch.Size([2, 7, 5])
tensor([[[-0.0359, -0.0036,  0.0825,  0.1019, -0.1004],
         [-0.0054,  0.0196,  0.1241,  0.1759, -0.1449],
         [ 0.0374,  0.0475,  0.1405,  0.2131, -0.1426],
         [ 0.0656,  0.0693,  0.1410,  0.2237, -0.0812],
         [ 0.0743,  0.0844,  0.1319,  0.2203,  0.0601],
         [ 0.0701,  0.0946,  0.1179,  0.2117,  0.2878],
         [ 0.0604,  0.1011,  0.1020,  0.2022,  0.5502]],

        [[ 0.0155,  0.0222,  0.0926,  0.1369, -0.0548],
         [ 0.0495,  0.0504,  0.1263,  0.2017, -0.0534],
         [ 0.0729,  0.0720,  0.1338,  0.2225,  0.0114],
         [ 0.0792,  0.0866,  0.1280,  0.2228,  0.1560],
         [ 0.0737,  0.0962,  0.1156,  0.2154,  0.3757],
         [ 0.0630,  0.1021,  0.1004,  0.2058,  0.6103],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<TransposeBackward0>)
