In [None]:
#default_exp data.load

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

In [None]:
#export
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# PyTorch loader-iterator classes, indexed by `num_workers==0`: index 0 (False) is multi-process, index 1 (True) is single-process
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
def twoepochs(d):
    "Join the items of each batch in `d` into a string, over two passes of `d`, separated by spaces"
    n_passes = 2
    return ' '.join(''.join(batch) for _ in range(n_passes) for batch in d)
bs = 4
letters = [c for c in string.ascii_lowercase]

## DataLoader

In [None]:
#export
def _wif(worker_id):
    "Worker init fn: tell the wrapped `DataLoader` the worker count and this worker's id, seed, then run its `wif` hook"
    info = get_worker_info()
    # `info.dataset` is the `_FakeLoader`; its `d` is the real `DataLoader` (this worker's copy)
    dl = info.dataset.d
    dl.nw = info.num_workers
    dl.offs = info.id
    set_seed(info.seed)
    dl.wif()

class _FakeLoader(GetAttr):
    "Stand-in passed to PyTorch's loader iterators in place of a PyTorch `DataLoader`, wrapping our `DataLoader` `d`"
    # Present the data to PyTorch as an iterable dataset with no automatic collation: our `DataLoader` batches and
    # collates itself in `_iter`, so PyTorch's `collate_fn` is a no-op and its index sampler just counts.
    _auto_collation,collate_fn,drop_last,dataset_kind,_index_sampler = False,noops,False,_DatasetKind.Iterable,Inf.count
    def __init__(self, d, pin_memory, num_workers, timeout):
        # `dataset` is this object itself, so iterating the "dataset" ends up in `__iter__` below;
        # `default` is the `GetAttr` target (presumably other attribute lookups are forwarded to `d` — confirm);
        # `_wif` is used as PyTorch's worker init function
        self.dataset,self.default,self.worker_init_fn = self,d,_wif
        store_attr(self, 'd,pin_memory,num_workers,timeout')
    # Iterate the wrapped DataLoader's `_iter`, which yields already-collated batches
    def __iter__(self): return iter(self.d._iter())

In [None]:
#export
@methods_kwargs
class DataLoader():
    "Load `items` in batches of `bs` (or one element at a time when `bs` is None), optionally shuffled, via PyTorch's single- or multi-process loader iterators"
    # Hooks that do nothing by default; override them in a subclass or pass them as keyword arguments
    reset=item_tfm=batch_tfm=wif = noops
    # Names that can be replaced by passing them as keyword arguments to `__init__`
    _methods = 'collate_fn batches reset wif sampler item batch_tfm item_tfm'.split()
    def __init__(self, items=None, bs=None, drop_last=False, shuffle=False, indexed=None,
                 num_workers=0, pin_memory=False, timeout=0, **kwargs):
        # By default, index into `items` when it supports `__getitem__`, otherwise iterate over it
        if indexed is None: indexed = items is not None and hasattr(items,'__getitem__')
        store_attr(self, 'items,bs,drop_last,shuffle,indexed')
        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout)
        # `nw`/`offs` (worker count / this worker's id) are overwritten inside each worker by `_wif`
        self.lock,self.rng,self.nw,self.offs = Lock(),random.Random(),1,0
        try: self.n = len(self.items)
        except TypeError: self.n = None  # unknown length (no `items`, or an unsized iterable)
        # Both checks must hold. (The previous single `or` assert was vacuous whenever `kwargs` was empty.)
        # NOTE(review): assumes `methods_kwargs` removes the `_methods` overrides from `kwargs` before `__init__` runs.
        assert not kwargs, f"Unexpected keyword arguments: {list(kwargs)}"
        assert not (bs is None and drop_last), "`drop_last=True` requires `bs` to be set"

    def __iter__(self):
        # Draw a fresh rng state every epoch. With `num_workers>0` the sampler only runs in worker copies of this
        # object, so the parent's `rng` would otherwise never advance and every epoch would shuffle identically.
        # All workers copy the same new state, so they still agree on one permutation and split it consistently.
        self.rng = random.Random(self.rng.randint(0, 2**32-1))
        return _loaders[self.fake_l.num_workers==0](self.fake_l)
    def _iter(self):
        "Return an iterable of collated, batch-transformed batches for one pass over the data"
        # Compare with `None` explicitly: `bool()` of a multi-element tensor/array raises
        self.it = iter(self.items) if self.items is not None else None
        self.reset()
        batches = self.batches(self.sampler())
        return maps(self.collate_fn, self.batch_tfm, batches)
    
    def __len__(self):
        "Number of batches per epoch (number of elements when `bs` is None); `TypeError` if the length is unknown"
        if self.n is None: raise TypeError
        if self.bs is None: return self.n
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)
    
    def batches(self, idxs):
        "Map sample indexes `idxs` to transformed items, grouped into lists of `bs` unless `bs` is None"
        res = maps(self.item, self.item_tfm, idxs)
        return res if self.bs is None else chunked(res, self.bs, self.drop_last)

    def sampler(self):
        "This worker's indexes (or `None`s when not `indexed`) for one pass; shuffled only when the length is known"
        res = Inf.count if self.indexed else Inf.nones
        if self.n is not None:
            res = list(itertools.islice(res, self.n))
            res = self.rng.sample(res,len(res)) if self.shuffle else res
        # Batch number `i//bs` belongs to worker `(i//bs)%nw`: workers take whole batches in round-robin order
        return (b for i,b in enumerate(res) if i//(self.bs or 1)%self.nw==self.offs)

    def collate_fn(self, b):
        "Collate a list of items with `default_collate`, or only `default_convert` when elements are already batches (`bs` is None)"
        return (default_collate,default_convert)[self.bs is None](b)
    def item(self, s):
        "Element at index `s`, or the next element from the `items` iterator when `s` is None"
        return next(self.it) if s is None else self.items[s]

Override `batches` to return some specific finite iterable.

In [None]:
class LettersDL(DataLoader):
    "Yield the alphabet in chunks of 4 letters, ignoring `idxs`"
    def batches(self, idxs):
        alphabet = string.ascii_lowercase
        return (alphabet[start:start+4] for start in range(0, len(alphabet), 4))

test_eq(L(LettersDL()), ['abcd','efgh','ijkl','mnop','qrst','uvwx','yz'])

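Use `idxs` to get the indexes of the samples to load, if needed. If you didn't pass `items`, then `idxs` is just an endless stream of `None`s, so your `batches` must decide when to stop. (If you pass `items` that can't be indexed, you get one `None` per element instead.)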

In [None]:
class RandDL(DataLoader):
    "Draw one random float per entry of `idxs`, for as long as `lt(0.95)` holds"
    def batches(self, idxs):
        draw = lambda _: random.random()
        return gen(draw, idxs, lt(0.95))

L(RandDL())

(#24) [0.07752714554053142,0.20955938119915585,0.10367323351625546,0.15011170292525056,0.9157223411714032,0.432686603823224,0.3370742266525697,0.4438111811572829,0.21863541374677542,0.423844665703637...]

Pass the method as a keyword argument to `__init__` instead of overriding it, if you prefer.

In [None]:
def _batches(self, idxs):
    "Random floats while `lt(0.95)` holds, passed to `DataLoader` as a keyword instead of subclassing"
    return gen(lambda _: random.random(), idxs, lt(0.95))
L(DataLoader(batches=_batches))

(#6) [0.7472562585802442,0.7149032324417728,0.8157503365898758,0.10763048711119405,0.5474217530192991,0.6601405423134689]

Override `item` and use the default infinite sampler to get a stream of unknown length (call `stop()` when you want to end the stream).

In [None]:
class RandDL(DataLoader):
    "A stream of random floats of unknown length, ended with `stop()` once a draw reaches 0.95"
    def item(self, s):
        val = random.random()
        if val >= 0.95: return stop()
        return val

L(RandDL())

(#21) [0.8735032074727076,0.7700000752165664,0.7249022639733494,0.887712785383386,0.6677797749492357,0.031527541659260194,0.15276200815511343,0.3296331246526736,0.21424501268351115,0.07735320003778468...]

In [None]:
# The random stream grouped into batches of 4 and split across 4 workers; `drop_last` discards incomplete batches, so every length is 4
L(RandDL(bs=4, num_workers=4, drop_last=True)).mapped(len)

(#13) [4,4,4,4,4,4,4,4,4,4...]

If you don't set `bs`, then `items` is assumed to provide an iterator or a `__getitem__` that returns a batch.

In [None]:
# Without `bs`, each element of `items` is already a batch: it is only converted (`default_convert`), not collated
dl1 = DataLoader(letters)
test_eq(dl1, letters)
test_eq(len(dl1), 26)

test_shuffled(L(DataLoader(letters, shuffle=True)), letters)

# Iterating instead of indexing gives the same result
dl1 = DataLoader(letters, indexed=False)
test_eq(dl1, letters)
test_eq(len(dl1), 26)

# Tensors pass through unchanged, and numpy arrays are converted to tensors...
tensors = L(tensor([0,1,2]),tensor([3,4,5]))
test_eq_type(L(DataLoader(tensors)), tensors)

arrays = L(array([0,1,2]),array([3,4,5]))
test_eq_type(L(DataLoader(arrays)), tensors)

# ...unless `collate_fn` is replaced
test_eq_type(L(DataLoader(arrays, collate_fn=noops)), arrays)

If you do set `bs`, then `items` is assumed to provide an iterator or a `__getitem__` that returns a single item of a batch.

In [None]:
# `drop_last` discards the final incomplete batch ('yz'), and every epoch repeats the same order
dl = DataLoader(letters, bs=4, drop_last=True, num_workers=1)
test_eq(twoepochs(dl), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

dl = DataLoader(letters, 4, num_workers=2)
test_eq(twoepochs(dl), 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')

# Order is preserved across workers, including when there are more workers than batches
test_eq_type(L(DataLoader(range(12), bs=4, num_workers=3)),
             L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

test_eq_type(L(DataLoader([str(i) for i in range(11)], bs=4, num_workers=5)),
             L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))

# An unsized iterator works too; just take the first three batches
batch_iter = iter(DataLoader(map(noop,range(20)), bs=4))
test_eq_type([next(batch_iter) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

In [None]:
class SleepyDL(list):
    def __getitem__(self,i):
        time.sleep(random.random()/50)
        return super().__getitem__(i)

t = SleepyDL(letters)

%time test_eq(DataLoader(t, num_workers=0), letters)
%time test_eq(DataLoader(t, num_workers=2), letters)
%time test_eq(DataLoader(t, num_workers=4), letters)
test_shuffled(L(DataLoader(t, shuffle=True, num_workers=4)), letters)

CPU times: user 4.85 ms, sys: 272 µs, total: 5.12 ms
Wall time: 291 ms
CPU times: user 11.8 ms, sys: 16.1 ms, total: 27.9 ms
Wall time: 160 ms
CPU times: user 12.2 ms, sys: 22.1 ms, total: 34.3 ms
Wall time: 109 ms


In [None]:
class SleepyQueue():
    "Simulate a queue with varying latency"
    def __init__(self, q): self.q=q
    def __iter__(self):
        while True:
            time.sleep(random.random()/100)
            try: yield self.q.get_nowait()
            except queues.Empty: return

q = Queue()
for o in range(30): q.put(o)
it = SleepyQueue(q)

%time test_shuffled(L(DataLoader(it, num_workers=4)), range(30))

CPU times: user 22.7 ms, sys: 25.1 ms, total: 47.7 ms
Wall time: 106 ms


## Export -

In [None]:
#hide
# Regenerate the library scripts from every notebook (`all_fs=True`), not just this one
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_dataloader.ipynb.
Converted 01a_script.ipynb.
Converted 02_transforms.ipynb.
Converted 03_pipeline.ipynb.
Converted 04_data_external.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_source.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 09a_rect_augment.ipynb.
Converted 10_data_block.ipynb.
Converted 11_layers.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 30_text_core.ipynb.
Converted 31_text_data.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 33_test_models_core.ipynb.
Converted 34_callback_rnn.ipynb.
Converted 35_tutorial_wikitex