In [None]:
#default_exp data.load

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

In [None]:
#export
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
def twoepochs(d):
    "Join each batch of `d` into a string, over two full passes of `d`, separated by spaces."
    batch_strs = [''.join(batch) for _epoch in range(2) for batch in d]
    return ' '.join(batch_strs)
bs = 4

## DataLoader

In [None]:
#export
class Dataset():
    "Iterable over `items` (or a custom stream): sample indexes, fetch items, optionally batch them, then collate."
    # Methods that can be replaced per instance by passing them as keyword arguments
    # (e.g. `Dataset(batches=...)`, `Dataset(t, collate_fn=noops)`).
    # NOTE(review): `indexes` is listed here but has no default method in this class — confirm it is still used.
    _methods = 'collate_fn indexes batches reset wif sampler'.split()
    @kwargs(_methods)
    def __init__(self, items=None, bs=None, drop_last=False, shuffle=False, indexed=None, **kwargs):
        # By default, fetch by index when `items` supports `__getitem__`; otherwise consume it with `next`.
        if indexed is None: indexed = items is not None and hasattr(items,'__getitem__')
        self.items,self.bs,self.drop_last,self.shuffle,self.indexed = items,bs,drop_last,shuffle,indexed
        # `nw`/`offs`: worker count and this copy's worker index (defaults describe a single process);
        # `_wif` overwrites them, together with `seed`, inside each loader worker.
        self.seed,self.rng,self.nw,self.offs = None,random.Random(),1,0
        replace_methods(self, kwargs)
        try: self.n = len(self.items)
        except TypeError: self.n = None  # unsized items (e.g. a `map`): length unknown
        # NOTE(review): as written this only rejects `drop_last` without `bs` when method overrides are
        # also passed; `assert not (bs is None and drop_last)` may be what was meant — confirm before tightening.
        assert not kwargs or not (bs is None and drop_last)

    def __iter__(self):
        if self.seed is not None: set_seed(self.seed)
        # Compare with None rather than testing truthiness. `bool()` raises for multi-element numpy arrays
        # and tensors, and for objects whose `__len__` raises. An empty container is still fine to iterate.
        self.it = iter(self.items) if self.items is not None else None
        # Keep every `nw`-th sampled index starting at `offs`, so each worker gets a disjoint share of indexes.
        idxs = (b for i,b in enumerate(self.sampler()) if i%self.nw==self.offs)
        self.reset()
        return map(self.collate_fn, self.batches(iter(idxs)))

    def __len__(self):
        "Number of items (no `bs`) or of batches; raises TypeError when the length is unknown."
        # TypeError (not another exception type) lets consumers such as `list()` fall back to plain iteration.
        if self.n is None: raise TypeError("Dataset has unknown length (`items` is None or has no `__len__`)")
        if self.bs is None: return self.n
        # A trailing partial batch counts as a batch unless `drop_last` is set.
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)

    def batches(self, idxs):
        "Map sampled indexes to items and, when `bs` is set, group them into chunks of `bs`."
        res = map(self.item, idxs)
        return res if self.bs is None else chunked(res, self.bs, self.drop_last)

    def sampler(self):
        "Indexes to fetch: integers counting from 0 when `indexed`, else `None`s (meaning 'take the next item')."
        res = Inf.count if self.indexed else Inf.nones
        if self.n is None: return res  # unsized: an endless stream, never shuffled
        res = list(itertools.islice(res, self.n))
        # Shuffling uses this instance's own `random.Random`, not the global RNG reseeded by `set_seed`.
        # NOTE(review): presumably so that every worker's copy produces the same permutation — confirm.
        return self.rng.sample(res,len(res)) if self.shuffle else res

    # Hooks: `reset` runs at the start of every `__iter__`; `wif` runs inside each worker via `_wif`.
    reset = wif = noop
    # Batched output (`bs` set) goes through `default_collate`; single items go through `default_convert`.
    def collate_fn(self, b): return (default_collate,default_convert)[self.bs is None](b)
    # A `None` index means "take the next item from the iterator"; otherwise index into `items`.
    def item(self, s): return next(self.it) if s is None else self.items[s]

Override `batches` to return some specific finite iterable.

In [None]:
class LettersDS(Dataset):
    "Toy dataset whose batches are consecutive 4-letter chunks of the alphabet; the sampled `idxs` are ignored."
    def batches(self, idxs):
        alphabet = string.ascii_lowercase
        return (alphabet[start:start+4] for start in range(0, len(alphabet), 4))

test_eq(L(LettersDS()), 'abcd,efgh,ijkl,mnop,qrst,uvwx,yz'.split(','))

`batches` receives `idxs`, an iterator over the sampled indexes; use it if your batches need them.

In [None]:
class RandDS(Dataset):
    "Stream of random floats built in `batches`; presumably `gen` stops once a value fails `lt(0.95)`."
    def batches(self, idxs):
        draw = lambda _idx: random.random()
        return gen(draw, idxs, lt(0.95))

L(RandDS())

(#36) [0.06818208925726543,0.8161129101047011,0.3541381449304538,0.23452748508589716,0.04810473341773647,0.7190683538386087,0.6988808013110566,0.48821459727967476,0.9236314892435064,0.5019581553066688...]

In [None]:
def _batches(self, idxs):
    "Replacement `batches` passed as a kwarg: one random float per sampled index, filtered through `gen`/`lt(0.95)`."
    draw = lambda _idx: random.random()
    return gen(draw, idxs, lt(0.95))
L(Dataset(batches=_batches))

(#11) [0.2633041256099563,0.21067821178279578,0.8022175455289592,0.8412889993343979,0.19269867281749564,0.8470297360848011,0.2461200121766085,0.5689809894419389,0.5840661089416174,0.3201851958561802...]

Override `item` instead and keep the default sampler, which is infinite when `items` has no length, to get a stream of unknown length (raise `StopIteration`, e.g. with `stop()`, when you want to end the stream).

In [None]:
class RandItemDS(Dataset):
    "Random-float stream produced by overriding `item`; ends at the first draw >= 0.95."
    # Named differently from the `batches`-based `RandDS` above so that class is not silently shadowed.
    def item(self, s):
        r = random.random()
        # `stop()` presumably raises `StopIteration`, which ends the stream.
        return r if r<0.95 else stop()

L(RandItemDS())

(#16) [0.517183380076361,0.6216792367853115,0.795131539496894,0.7222952309094344,0.13578518881114499,0.26933490932579907,0.9113684524149119,0.6457585545124523,0.28624094577564274,0.54287981489844...]

When `items` is not indexable (e.g. a `map` or a generator), it is consumed with `next`, one item per call; indexable `items` are fetched by index.

In [None]:
letters = [ch for ch in string.ascii_lowercase]

In [None]:
# Indexed items (a list): iteration yields the items in order, and `len` is the item count.
ds1 = Dataset(letters)
test_eq(ds1, letters)
test_eq(len(ds1), 26)

# `shuffle=True`: the result is checked with `test_shuffled` against the unshuffled items.
test_shuffled(L(Dataset(letters, shuffle=True)), letters)

# Forcing `indexed=False` consumes the items with `next` instead; the result and length are unchanged.
ds1 = Dataset(letters, indexed=False)
test_eq(ds1, letters)
test_eq(len(ds1), 26)

# Without `bs`, each item goes through `default_convert`: tensors come back as tensors...
t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = Dataset(t2)
test_eq_type(L(ds2), t2)

# ...and numpy arrays come back as tensors (hence the comparison against `t2`).
t3 = L(array([0,1,2]),array([3,4,5]))
ds3 = Dataset(t3)
test_eq_type(L(ds3), t2)

# Overriding `collate_fn` with `noops` leaves the arrays untouched.
ds4 = Dataset(t3, collate_fn=noops)
test_eq_type(L(ds4), t3)

In [None]:
# `bs=4` with `drop_last=True`: the trailing partial batch ('yz') is dropped, and a second epoch repeats the first.
ds1 = Dataset(letters,4,drop_last=True)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

# Batches of ints are collated into tensors.
ds1 = Dataset(range(12), bs=4)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

# Batches of strings stay lists; without `drop_last` the final partial batch is kept.
ds1 = Dataset([str(i) for i in range(11)], bs=4)
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))

# Unsized, non-indexed items (a `map`) are still batched lazily.
it = iter(Dataset(map(noop,range(20)), bs=4))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

In [None]:
class RandBatchDS(Dataset):
    "Batched stream of random floats; a draw above 0.9 raises `StopIteration` to end the stream."
    def item(self, s):
        r = random.random()
        if r <= 0.9: return r
        raise StopIteration

ds = RandBatchDS(bs=4)
L(ds)

(#10) [tensor([0.0831, 0.5347, 0.3338, 0.3114], dtype=torch.float64),tensor([0.7578, 0.6927, 0.1735, 0.5516], dtype=torch.float64),tensor([0.2478, 0.1672, 0.5012, 0.2118], dtype=torch.float64),tensor([0.8800, 0.5981, 0.8254, 0.7693], dtype=torch.float64),tensor([0.2011, 0.0635, 0.7847, 0.5918], dtype=torch.float64),tensor([0.3841, 0.6587, 0.7413, 0.8817], dtype=torch.float64),tensor([0.6461, 0.0513, 0.6019, 0.3876], dtype=torch.float64),tensor([0.4892], dtype=torch.float64),tensor([0.7487, 0.7090, 0.0338, 0.2595], dtype=torch.float64),tensor([0.8138, 0.0474, 0.2306, 0.7621], dtype=torch.float64)]

In [None]:
#export
def _wif(worker_id):
    info = get_worker_info()
    ds = info.dataset
    ds.nw,ds.offs,ds.seed = info.num_workers,info.id,info.seed
    ds.wif()

In [None]:
#export
class DataLoader(GetAttr):
    "Drive a `Dataset` with PyTorch's single- or multi-process DataLoader iterators, applying `tfm` to each output."
    # Attributes assumed to be read by PyTorch's private `_*DataLoaderIter` classes, which receive `self`.
    # `Dataset` already batches and collates, hence the iterable kind, no auto-collation and a no-op `collate_fn`.
    _auto_collation,collate_fn,drop_last,dataset_kind = False,noops,False,_DatasetKind.Iterable
    @delegates(Dataset.__init__)
    def __init__(self, items, num_workers=0, pin_memory=False, timeout=0, tfm=None, **kwargs):
        # Reuse an existing `Dataset`; otherwise build one from `items` and the extra `kwargs`.
        if isinstance(items, Dataset): ds = items
        else: ds = Dataset(items, **kwargs)
        # `default` is presumably the object `GetAttr` forwards unknown attribute lookups to.
        self.default = self.dataset = ds
        self.pin_memory = pin_memory
        self.tfm = tfm or noop
        self.worker_init_fn = _wif
        self._index_sampler = Inf.count
        # Negative values are clamped to 0.
        if num_workers < 0: num_workers = 0
        if timeout < 0: timeout = 0
        self.num_workers,self.timeout = num_workers,timeout
        # NOTE(review): nothing visible uses `lock`, and `threading.Lock` objects cannot be pickled. That may break
        # worker start methods that pickle the dataset (e.g. 'spawn') — confirm the lock is needed.
        self.dataset.lock = threading.Lock()

    def __iter__(self):
        # `num_workers==0` selects the single-process iterator (index True); otherwise the multiprocessing one.
        iter_cls = _loaders[self.num_workers==0]
        return map(self.tfm, iter_cls(self))

    def __len__(self): return len(self.dataset)

In [None]:
# Epoch length varies from run to run because `RandBatchDS` ends its stream at a random draw.
[len(L(DataLoader(ds))) for _ in range(4)]

[12, 23, 1, 10]

In [None]:
# Same with 4 workers; presumably each worker runs until its own random stop, so lengths stay random.
[len(L(DataLoader(ds, num_workers=4))) for _ in range(4)]

[77, 40, 56, 23]

In [None]:
class SleepyDS(list):
    "A `list` whose indexing first sleeps for up to 50ms, to simulate slow item loading."
    def __getitem__(self, idx):
        delay = random.random() / 20
        time.sleep(delay)
        return list.__getitem__(self, idx)

In [None]:
t = SleepyDS(letters)
# Same check with 0, 1 and 2 workers, then shuffled with 4; the wall times below compare the settings.
%time test_eq(DataLoader(t, num_workers=0), letters)
%time test_eq(DataLoader(t, num_workers=1), letters)
%time test_eq(DataLoader(t, num_workers=2), letters)
%time test_shuffled(L(DataLoader(t, shuffle=True, num_workers=4)), letters)

CPU times: user 1.82 ms, sys: 995 µs, total: 2.81 ms
Wall time: 633 ms
CPU times: user 11.8 ms, sys: 3.9 ms, total: 15.7 ms
Wall time: 774 ms
CPU times: user 12.9 ms, sys: 9.14 ms, total: 22 ms
Wall time: 344 ms
CPU times: user 10.6 ms, sys: 23.9 ms, total: 34.5 ms
Wall time: 215 ms


In [None]:
# Non-indexed, unsized items (a `map`) with 4 workers.
# NOTE(review): the output of `list(dl)` in the next cell shows every letter 4 times, presumably because each worker
# iterates its own copy of the `map`. The commented-out expectation below (each letter twice) matches neither that
# output nor one-copy-per-item semantics — confirm the intended behavior.
dl = DataLoader(map(noop,t), num_workers=4)
# %time test_eq(dl, sum(zip(letters,letters), ()))

In [None]:
# Time one full pass; the output repeats each letter 4 times, presumably once per worker.
%time list(dl)

CPU times: user 21.6 ms, sys: 20.9 ms, total: 42.4 ms
Wall time: 47.7 ms


['a',
 'a',
 'a',
 'a',
 'b',
 'b',
 'b',
 'b',
 'c',
 'c',
 'c',
 'c',
 'd',
 'd',
 'd',
 'd',
 'e',
 'e',
 'e',
 'e',
 'f',
 'f',
 'f',
 'f',
 'g',
 'g',
 'g',
 'g',
 'h',
 'h',
 'h',
 'h',
 'i',
 'i',
 'i',
 'i',
 'j',
 'j',
 'j',
 'j',
 'k',
 'k',
 'k',
 'k',
 'l',
 'l',
 'l',
 'l',
 'm',
 'm',
 'm',
 'm',
 'n',
 'n',
 'n',
 'n',
 'o',
 'o',
 'o',
 'o',
 'p',
 'p',
 'p',
 'p',
 'q',
 'q',
 'q',
 'q',
 'r',
 'r',
 'r',
 'r',
 's',
 's',
 's',
 's',
 't',
 't',
 't',
 't',
 'u',
 'u',
 'u',
 'u',
 'v',
 'v',
 'v',
 'v',
 'w',
 'w',
 'w',
 'w',
 'x',
 'x',
 'x',
 'x',
 'y',
 'y',
 'y',
 'y',
 'z',
 'z',
 'z',
 'z']

## Export -

In [None]:
#hide
from local.notebook.export import notebook2script
# Presumably writes the `#export` cells of every notebook into the `local` package (see the "Converted ..." log).
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_dataloader.ipynb.
Converted 01a_script.ipynb.
Converted 02_transforms.ipynb.
Converted 03_pipeline.ipynb.
Converted 04_data_external.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_source.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 09a_rect_augment.ipynb.
Converted 10_data_block.ipynb.
Converted 11_layers.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 30_text_core.ipynb.
Converted 31_text_data.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 33_test_models_core.ipynb.
Converted 34_callback_rnn.ipynb.
Converted 35_tutorial_wikitex