In [1]:
import pickle, gzip
from pathlib import Path
import torch
import fastcore.all as fc
from matplotlib import pylab as plt

In [2]:
data = Path('data')

In [3]:
gzip.open(data/'mnist.pkl.gz')

<gzip _io.BufferedReader name='data/mnist.pkl.gz' 0x7fa2142db5b0>

In [4]:
((x_train, y_train), (x_valid, y_valid), (x_test, y_test)) = pickle.load(gzip.open(data/'mnist.pkl.gz'), encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(torch.tensor, [x_train, y_train, x_valid, y_valid])

In [5]:
class Dataset:
    def __init__(self, *ds): self.ds = list(zip(*ds))
    def __len__(self): return len(self.ds)
    def __getitem__(self, i): return self.ds[i]

In [6]:
ds = Dataset(x_train,y_train)
ds[5][0].shape, ds[5][1], type(ds[2])

(torch.Size([784]), tensor(2), tuple)

In [7]:
import random

class Sampler:
    def __init__(self, ds, shuffle): self.n,self.shuffle = len(ds),shuffle
    def __iter__(self): 
        res = list(range(self.n))
        if self.shuffle: random.shuffle(res)
        return iter(res)

In [8]:
s = Sampler(ds, True)
for _ in range(5):
    print(next(iter(s)))

44240
14156
28325
19433
28567


In [9]:
from itertools import islice
ix = islice(s,5)
list(ix)

[46129, 9899, 39302, 7516, 44726]

In [10]:
import fastcore.all as fc
# fc.chunked??

In [11]:
class BatchSampler:
    def __init__(self, sampler, bs, drop_last=False): fc.store_attr()
    def __iter__(self): yield from fc.chunked(iter(self.sampler), self.bs, self.drop_last)
        

In [12]:
batch_sampler = BatchSampler(Sampler(ds, shuffle=True), bs=256, drop_last=True)
batch = next(iter(batch_sampler))

In [13]:
def collate_fn(b):
    xs, ys = list(zip(*b))
    return torch.stack(xs), torch.stack(ys)

In [14]:
class Dataloader:
    def __init__(self, ds, batches, collate_fn): fc.store_attr()
    def __iter__(self): yield from (self.collate_fn(ds[i] for i in b) for b in self.batches)

In [15]:
dl = Dataloader(ds, batch_sampler, collate_fn)

In [16]:
next(iter(dl))

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([3, 2, 7, 0, 7, 6, 9, 2, 4, 6, 3, 3, 5, 9, 2, 0, 8, 8, 6, 9, 4, 6, 3, 4,
         1, 6, 7, 9, 9, 8, 2, 8, 6, 1, 2, 4, 3, 5, 9, 7, 1, 0, 5, 6, 5, 3, 6, 3,
         4, 5, 5, 6, 9, 3, 5, 9, 8, 1, 2, 3, 0, 7, 4, 9, 8, 1, 9, 9, 3, 1, 8, 1,
         4, 0, 1, 7, 2, 6, 1, 9, 7, 5, 8, 3, 1, 3, 5, 4, 1, 9, 8, 2, 3, 0, 9, 6,
         0, 3, 0, 6, 9, 0, 6, 0, 9, 8, 3, 2, 3, 1, 2, 3, 7, 6, 4, 1, 8, 6, 1, 5,
         4, 2, 7, 5, 5, 0, 7, 2, 3, 4, 9, 7, 2, 7, 1, 6, 2, 5, 1, 8, 5, 3, 1, 8,
         3, 0, 7, 1, 1, 2, 8, 8, 2, 3, 3, 9, 0, 0, 9, 0, 3, 5, 7, 7, 4, 8, 4, 9,
         2, 7, 5, 6, 6, 9, 1, 4, 6, 5, 9, 8, 7, 0, 8, 0, 1, 2, 3, 4, 3, 0, 2, 8,
         3, 1, 8, 9, 4, 1, 1, 1, 7, 6, 3, 2, 7, 5, 4, 4, 6, 8, 3, 4, 1, 9, 6, 8,
         

In [17]:
# import torch.multiprocessing as mp

# class DataLoader():
#     def __init__(self, ds, batchs, n_workers=1, collate_fn=collate_fn): fc.store_attr()
#     def __iter__(self):
#         with mp.Pool(self.n_workers) as ex: yield from ex.map(self.ds.__getitem__, iter(self.batchs))

In [18]:
class Dataset:
    def __init__(self, xs, ys): fc.store_attr()
    def __len__(self): return len(self.xs)
    def __getitem__(self, i): return self.xs[i], self.ys[i]

In [19]:
ds = Dataset(x_train, y_train)
# train_dl = DataLoader(ds, batch_sampler, collate_fn=collate_fn, n_workers=2)
# it = iter(train_dl)

In [20]:
# next(it)

In [21]:
%timeit for _ in dl: pass

246 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
from torch.multiprocessing import Pool

In [23]:
class MDataloader:
    def __init__(self, ds, batches, collate_fn, n_workers=1): fc.store_attr()
    def __iter__(self):
        with Pool(self.n_workers) as p: yield from p.map(self.ds.__getitem__, iter(self.batches))

In [24]:
dl = MDataloader(ds, batch_sampler, collate_fn, 2)

In [25]:
it = iter(dl)

In [26]:
next(it)

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([0, 9, 4, 7, 0, 1, 7, 5, 7, 2, 6, 1, 9, 7, 9, 3, 0, 7, 6, 8, 7, 5, 8, 6,
         4, 1, 5, 7, 9, 2, 0, 3, 9, 9, 6, 2, 2, 9, 6, 8, 9, 7, 9, 5, 7, 4, 2, 6,
         8, 7, 4, 6, 7, 3, 1, 6, 8, 2, 3, 3, 4, 4, 4, 3, 7, 4, 7, 2, 2, 9, 1, 9,
         9, 8, 1, 2, 8, 2, 4, 2, 0, 7, 5, 9, 3, 3, 8, 9, 6, 0, 5, 3, 1, 6, 6, 7,
         1, 1, 2, 1, 2, 8, 8, 9, 4, 7, 4, 4, 5, 0, 6, 4, 4, 2, 0, 8, 6, 5, 0, 9,
         7, 0, 7, 4, 0, 7, 4, 0, 2, 3, 1, 3, 0, 0, 2, 5, 3, 3, 6, 1, 6, 1, 7, 4,
         2, 4, 8, 0, 4, 7, 1, 4, 7, 8, 8, 3, 7, 8, 0, 2, 9, 4, 0, 4, 9, 3, 6, 4,
         1, 4, 3, 0, 4, 8, 3, 8, 1, 7, 6, 7, 7, 1, 2, 6, 1, 1, 7, 2, 3, 9, 4, 1,
         8, 9, 2, 8, 7, 0, 0, 1, 0, 7, 6, 7, 1, 2, 2, 4, 6, 8, 0, 9, 5, 9, 5, 2,
         

In [27]:
%timeit for _ in dl:pass

777 ms ± 50.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
