In [4]:
from torch.utils.data import Dataset, DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence
import random

In [50]:
class TestDataset(Dataset):
    """Synthetic dataset of variable-length token-id sequences.

    Each item is a plain Python list of random integer token ids whose length
    is itself random, imitating an NLP corpus in which sentences have
    different lengths.

    Args:
        data_len: Number of sentences to generate.
        min_len: Smallest sentence length (inclusive).
        max_len: Largest sentence length (inclusive).
        min_token: Smallest token id (inclusive). The default of 3 leaves
            ids 0, 1 and 2 unused by the data; the notebook uses them for
            padding, <sos> and <eos>.
        max_token: Largest token id (inclusive).
        seed: Optional seed. When given, a private ``random.Random(seed)``
            generator is used so the data is reproducible and the global
            ``random`` state is not touched. When ``None``, the module-level
            ``random`` generator is used.

    Raises:
        ValueError: Propagated from ``randint`` if a lower bound exceeds its
            upper bound.
    """

    def __init__(self, data_len=2000, min_len=20, max_len=50,
                 min_token=3, max_token=20000, seed=None):
        # Both the ``random`` module and a ``random.Random`` instance expose
        # ``randint``, so either can serve as the generator here.
        rng = random if seed is None else random.Random(seed)
        self.data_len = data_len
        # For each sentence the length is drawn first, then one token id per
        # position.
        self.datas = [
            [rng.randint(min_token, max_token)
             for _ in range(rng.randint(min_len, max_len))]
            for _ in range(data_len)
        ]

    def __getitem__(self, idx):
        """Return the stored sentence at ``idx``.

        The returned list is the stored object itself, not a copy, so callers
        must not mutate it.
        """
        return self.datas[idx]

    def __len__(self):
        """Return the number of sentences held in ``self.datas``."""
        return len(self.datas)

In [51]:
# Build the synthetic dataset: 2000 sentences of random length (20-50 tokens).
test_dataset = TestDataset()

上述过程模拟了一个NLP文本dataset，其中每个文本的长度是不确定的。
使用DataLoader的collate_fn函数，先将batch内的句子按原有长度降序排列，再为每个句子加上`<sos>`和`<eos>`（分别用1和2表示），最后用0值将它们padding到相同长度。

In [53]:
def padding_sentence(batch, sos_id=1, eos_id=2, pad_id=0):
    """Collate variable-length token-id sentences into one padded tensor.

    The sentences are sorted by length in descending order, each one is
    wrapped with a start marker and an end marker, and the results are
    right-padded to the length of the longest sentence.

    The input sentences are never modified: new lists are built for the
    wrapped sequences. Inserting into the sentences in place would add another
    pair of markers to the underlying data every time the same sentence is
    collated again.

    Args:
        batch: Sequence of sentences, each a sequence of integer token ids.
            Sentence lengths may differ.
        sos_id: Token id prepended to every sentence (<sos>).
        eos_id: Token id appended to every sentence (<eos>).
        pad_id: Value used to fill positions after the end of shorter
            sentences.

    Returns:
        A LongTensor of shape ``(len(batch), longest_sentence_len + 2)`` with
        rows ordered from longest to shortest original sentence.
    """
    # sorted() returns a new list, so the caller's batch order is untouched.
    ordered = sorted(batch, key=len, reverse=True)
    # Build fresh lists instead of using insert()/append() on the inputs.
    batch_tensor = [
        torch.LongTensor([sos_id] + list(sentence) + [eos_id])
        for sentence in ordered
    ]
    return pad_sequence(batch_tensor, batch_first=True, padding_value=pad_id)

In [52]:
# Batch 10 sentences at a time; padding_sentence is passed as collate_fn to
# turn each list of sentences into a single batch.
dataloader = DataLoader(test_dataset, batch_size=10, collate_fn=padding_sentence)

In [54]:
# Walk through every batch and print the (batch index, collated batch) pair.
for step, padded in enumerate(dataloader):
    print((step, padded))

(0, tensor([[    1,  7631, 11125,  5756, 17441,  6587,  9309,  6101, 12151, 12079,
         11690, 10083,   725, 18528, 11018, 17837, 19473,  9109,  5279,  3600,
          6116,  5888, 17212, 14074,  9968, 16358, 19634,  2321, 15289,  2588,
          8407, 12106,  5031, 19391,  8743, 13097, 14550, 17068, 12239,  6782,
           760,  7901,  9795, 15099, 14190,  2655, 16442,  1363, 11481, 11130,
         19680,     2],
        [    1, 14224,  9276,  7112,  8264, 11739, 13426, 11227,  3619,  7461,
         10415,  3326, 13398, 11835, 12194,  9363, 18439,  9437,  8053,  1389,
         18726, 14316,  5145,  5965, 10040,  8281,  8601,  8175, 12106,  8766,
          1186,  1952, 14019,   828,  7619,   478,  9579, 12760, 14206,  5759,
         17770, 17658, 17631,  4952, 19246,  9434, 12826,  7972,   228,  8871,
         13790,     2],
        [    1,  9159, 15982,  6511,  8521,  8310,  7485, 13531, 13215, 18459,
         14800,  2906,  4649,  4809,  6590,  7403,  9410, 13422, 19773,  6794,
