In [1]:
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import *
import numpy as np
import os
import sys

In [2]:
# Inspect what the bare name `sampler` refers to.
# NOTE(review): presumably this name is only in scope because of the star import
# from torch.utils.data in the first cell — confirm before relying on it.
sampler

<module 'torch.utils.data.sampler' from 'C:\\Users\\Administrator\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\sampler.py'>

In [28]:
# First, build a Dataset
class my_tracking_dataset(Dataset):
    """Map-style dataset that pairs each input sample with its target.

    Args:
        inputs: Sized, indexable collection of samples (``len()`` and ``[]``
            must both work on it).
        targets: Sized, indexable collection of targets, aligned with
            ``inputs`` index by index.

    Raises:
        ValueError: If ``inputs`` and ``targets`` do not have the same length.
    """

    def __init__(self, inputs, targets):
        super().__init__()
        # Fail fast on misaligned data: otherwise a shorter `targets` raises an
        # IndexError in the middle of an epoch and a longer one is silently
        # truncated to len(inputs).
        if len(inputs) != len(targets):
            raise ValueError(
                'inputs and targets must have the same length, got {} and {}'.format(
                    len(inputs), len(targets)
                )
            )
        self.inp = inputs
        self.tgt = targets
        self.len = len(inputs)

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return self.len

    def __getitem__(self, idx):
        """Return the item at ``idx`` as a dict with keys 'inputs' and 'targets'."""
        # Each data item is returned as a dictionary
        return {'inputs': self.inp[idx], 'targets': self.tgt[idx]}

In [29]:
# Build the dummy inputs and dummy targets:
# 20 samples of shape (3, 128, 128) filled with consecutive floats, plus targets 0..19.
dummy_input = torch.arange(20 * 3 * 128 * 128, dtype=torch.float32).reshape(20, 3, 128, 128)
dummy_target = torch.arange(0, 20)

In [30]:
# Shape of the first sample in dummy_input (the leading sample axis is indexed away)
dummy_input[0].shape

torch.Size([3, 128, 128])

In [31]:
# Display the dummy targets
dummy_target

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

In [32]:
# Wrap the dummy tensors in the dataset defined above
data = my_tracking_dataset(inputs=dummy_input, targets=dummy_target)

In [33]:
# Sample count as stored on the dataset's `len` attribute
data.len

20

In [34]:
# Sample count via the built-in len(), which goes through the dataset's __len__
len(data)

20

In [35]:
# Build a BatchSampler: indices 0..19 in order, grouped into batches of 4;
# drop_last=True discards a trailing batch that is smaller than batch_size.
the_batchSampler = BatchSampler(
    sampler=SequentialSampler(data_source=range(20)),
    batch_size=4,
    drop_last=True,
)

In [36]:
# Materialize the batch sampler to see the lists of indices it yields
list(the_batchSampler)

[[0, 1, 2, 3],
 [4, 5, 6, 7],
 [8, 9, 10, 11],
 [12, 13, 14, 15],
 [16, 17, 18, 19]]

In [37]:
# Build a DataLoader whose batches are dictated by the batch sampler
dataloader = DataLoader(dataset=data, batch_sampler=the_batchSampler)

In [38]:
# Walk the sequential loader and print every batch's inputs and targets
for idx, batch in enumerate(dataloader):
    print(
        'idx:{}; input:{}, target:{}'.format(
            idx,
            batch['inputs'],
            batch['targets'],
        )
    )

idx:0; input:tensor([[[[0.0000e+00, 1.0000e+00, 2.0000e+00,  ..., 1.2500e+02,
           1.2600e+02, 1.2700e+02],
          [1.2800e+02, 1.2900e+02, 1.3000e+02,  ..., 2.5300e+02,
           2.5400e+02, 2.5500e+02],
          [2.5600e+02, 2.5700e+02, 2.5800e+02,  ..., 3.8100e+02,
           3.8200e+02, 3.8300e+02],
          ...,
          [1.6000e+04, 1.6001e+04, 1.6002e+04,  ..., 1.6125e+04,
           1.6126e+04, 1.6127e+04],
          [1.6128e+04, 1.6129e+04, 1.6130e+04,  ..., 1.6253e+04,
           1.6254e+04, 1.6255e+04],
          [1.6256e+04, 1.6257e+04, 1.6258e+04,  ..., 1.6381e+04,
           1.6382e+04, 1.6383e+04]],

         [[1.6384e+04, 1.6385e+04, 1.6386e+04,  ..., 1.6509e+04,
           1.6510e+04, 1.6511e+04],
          [1.6512e+04, 1.6513e+04, 1.6514e+04,  ..., 1.6637e+04,
           1.6638e+04, 1.6639e+04],
          [1.6640e+04, 1.6641e+04, 1.6642e+04,  ..., 1.6765e+04,
           1.6766e+04, 1.6767e+04],
          ...,
          [3.2384e+04, 3.2385e+04, 3.2386e+04,

In [39]:
# Build a randomly shuffled data loader.
# A dedicated, seeded generator drives the shuffle so that the batch order is
# reproducible after "Restart Kernel -> Run All", without touching the global
# RNG state. The seed value itself is arbitrary.
dataloader2 = DataLoader(
    data,
    shuffle=True,
    batch_size=4,
    generator=torch.Generator().manual_seed(0),
)

In [40]:
# Walk the shuffled loader and print every batch's inputs and targets
for idx, batch in enumerate(dataloader2):
    print(
        'idx:{}; input:{}, target:{}'.format(
            idx,
            batch['inputs'],
            batch['targets'],
        )
    )

idx:0; input:tensor([[[[3.9322e+05, 3.9322e+05, 3.9322e+05,  ..., 3.9334e+05,
           3.9334e+05, 3.9334e+05],
          [3.9334e+05, 3.9334e+05, 3.9335e+05,  ..., 3.9347e+05,
           3.9347e+05, 3.9347e+05],
          [3.9347e+05, 3.9347e+05, 3.9347e+05,  ..., 3.9360e+05,
           3.9360e+05, 3.9360e+05],
          ...,
          [4.0922e+05, 4.0922e+05, 4.0922e+05,  ..., 4.0934e+05,
           4.0934e+05, 4.0934e+05],
          [4.0934e+05, 4.0934e+05, 4.0935e+05,  ..., 4.0947e+05,
           4.0947e+05, 4.0947e+05],
          [4.0947e+05, 4.0947e+05, 4.0947e+05,  ..., 4.0960e+05,
           4.0960e+05, 4.0960e+05]],

         [[4.0960e+05, 4.0960e+05, 4.0960e+05,  ..., 4.0972e+05,
           4.0973e+05, 4.0973e+05],
          [4.0973e+05, 4.0973e+05, 4.0973e+05,  ..., 4.0985e+05,
           4.0985e+05, 4.0986e+05],
          [4.0986e+05, 4.0986e+05, 4.0986e+05,  ..., 4.0998e+05,
           4.0998e+05, 4.0998e+05],
          ...,
          [4.2560e+05, 4.2560e+05, 4.2560e+05,