In [39]:
import torch
import torch.utils.data as Data

BATCH_SIZE = 4
SEED = 0  # seed for the index shuffle below, so the split is the same on every run

# Toy data: 100 evenly spaced inputs in [0.01, 0.1] and 100 evenly spaced targets in [1, 10].
x = torch.linspace(0.01, 0.1, 100)
y = torch.linspace(1, 10, 100)

# Wrap both tensors so that sample i is the pair (x[i], y[i]).
torch_dataset = Data.TensorDataset(x, y)

n_train = len(x)  # total number of samples (existing name kept for later cells)
split = n_train // 3  # boundary index: one third of the shuffled indices lies before it

# Random permutation of 0 .. n_train - 1. A dedicated, seeded generator makes the
# permutation reproducible after a kernel restart without touching the global RNG.
indices = torch.randperm(n_train, generator=torch.Generator().manual_seed(SEED))

In [40]:
# SubsetRandomSampler samples elements randomly from a given list of indices,
# without replacement (no index is repeated within one pass).
#
# Sketch of how the SubsetRandomSampler class is implemented:
#
#     class SubsetRandomSampler(Sampler):
#         def __init__(self, indices):
#             self.indices = indices
#
#         def __iter__(self):
#             return (self.indices[i] for i in torch.randperm(len(self.indices)))
#
#         def __len__(self):
#             return len(self.indices)

# Test set: the first third of the shuffled indices.
test_sampler = Data.SubsetRandomSampler(indices[:split])
# Training set: the remaining two thirds (slice `indices` differently to change the ratio).
train_sampler = Data.SubsetRandomSampler(indices[split:])

In [41]:
# Exhaust the sampler once to inspect which indices it yields for the test split.
[*test_sampler]

[tensor(82),
 tensor(55),
 tensor(4),
 tensor(39),
 tensor(20),
 tensor(94),
 tensor(45),
 tensor(69),
 tensor(65),
 tensor(23),
 tensor(5),
 tensor(8),
 tensor(56),
 tensor(97),
 tensor(35),
 tensor(86),
 tensor(95),
 tensor(34),
 tensor(66),
 tensor(14),
 tensor(2),
 tensor(64),
 tensor(44),
 tensor(51),
 tensor(92),
 tensor(3),
 tensor(0),
 tensor(7),
 tensor(96),
 tensor(13),
 tensor(76),
 tensor(12),
 tensor(49)]

In [42]:
# Both loaders read from the same dataset; the sampler passed to each one defines
# the strategy used to draw samples from it. When a sampler is specified,
# `shuffle` must be False, so it is left at its default here.
train_loader = Data.DataLoader(torch_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
test_loader = Data.DataLoader(torch_dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
def show_batch(dataset, num_epochs=3):
    """Print every batch yielded by ``dataset``, once per epoch.

    Args:
        dataset: Any object that can be iterated repeatedly and yields batches.
            The notebook passes ``DataLoader`` instances; the parameter name is
            kept for backward compatibility.
        num_epochs: Number of full passes over ``dataset``. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        None. One line is printed per batch, followed by a blank line after
        each epoch.
    """
    for _ in range(num_epochs):
        for step, batch_data in enumerate(dataset):
            # With a DataLoader, each tensor in the batch holds up to batch_size elements.
            print("batch:{}, batch_data:{}".format(step, batch_data))
        print()  # blank line separates consecutive epochs

In [43]:
# Print the batches drawn from the training split (the loader built with train_sampler).
show_batch(train_loader)

batch:0, batch_data:[tensor([0.0800, 0.0318, 0.0618, 0.0709]), tensor([8.0000, 3.1818, 6.1818, 7.0909])]
batch:1, batch_data:[tensor([0.0473, 0.0555, 0.0182, 0.0191]), tensor([4.7273, 5.5455, 1.8182, 1.9091])]
batch:2, batch_data:[tensor([0.0464, 0.0436, 0.0836, 0.0782]), tensor([4.6364, 4.3636, 8.3636, 7.8182])]
batch:3, batch_data:[tensor([0.0809, 0.0355, 0.0864, 0.0873]), tensor([8.0909, 3.5455, 8.6364, 8.7273])]
batch:4, batch_data:[tensor([0.0327, 0.0245, 0.0382, 0.0855]), tensor([3.2727, 2.4545, 3.8182, 8.5455])]
batch:5, batch_data:[tensor([0.0482, 0.0573, 0.0745, 0.0818]), tensor([4.8182, 5.7273, 7.4545, 8.1818])]
batch:6, batch_data:[tensor([0.1000, 0.0291, 0.0891, 0.0945]), tensor([10.0000,  2.9091,  8.9091,  9.4545])]
batch:7, batch_data:[tensor([0.0400, 0.0536, 0.0627, 0.0255]), tensor([4.0000, 5.3636, 6.2727, 2.5455])]
batch:8, batch_data:[tensor([0.0900, 0.0773, 0.0345, 0.0264]), tensor([9.0000, 7.7273, 3.4545, 2.6364])]
batch:9, batch_data:[tensor([0.0764, 0.0645, 0.0236

In [44]:
# Print the batches drawn from the test split (the loader built with test_sampler).
show_batch(test_loader)


batch:0, batch_data:[tensor([0.0964, 0.0136, 0.0700, 0.0509]), tensor([9.6364, 1.3636, 7.0000, 5.0909])]
batch:1, batch_data:[tensor([0.0309, 0.0218, 0.0600, 0.0100]), tensor([3.0909, 2.1818, 6.0000, 1.0000])]
batch:2, batch_data:[tensor([0.0564, 0.0145, 0.0936, 0.0409]), tensor([5.6364, 1.4545, 9.3636, 4.0909])]
batch:3, batch_data:[tensor([0.0455, 0.0545, 0.0173, 0.0500]), tensor([4.5455, 5.4545, 1.7273, 5.0000])]
batch:4, batch_data:[tensor([0.0791, 0.0882, 0.0727, 0.0282]), tensor([7.9091, 8.8182, 7.2727, 2.8182])]
batch:5, batch_data:[tensor([0.0955, 0.0691, 0.0118, 0.0164]), tensor([9.5455, 6.9091, 1.1818, 1.6364])]
batch:6, batch_data:[tensor([0.0682, 0.0227, 0.0845, 0.0127]), tensor([6.8182, 2.2727, 8.4545, 1.2727])]
batch:7, batch_data:[tensor([0.0973, 0.0609, 0.0982, 0.0418]), tensor([9.7273, 6.0909, 9.8182, 4.1818])]
batch:8, batch_data:[tensor([0.0209]), tensor([2.0909])]

batch:0, batch_data:[tensor([0.0127, 0.0791, 0.0418, 0.0409]), tensor([1.2727, 7.9091, 4.1818, 4.0909]