# Implementing and using iterable datasets: What Could Go Wrong?


## Context

```python
for batch in DataLoader(dataset, batch_size=..., num_workers=...):
    # Training here
```


``dataset`` <-- This is what we'll talk about.
There are 2 types of datasets:

- Indexable, "MapStyle"
- Iterable (glorified Python generator). Potentially more powerful, but also much more dangerous


There are good reasons to use Iterable datasets (streaming, packing data into chunks to reduce I/O latency and bandwidth use, etc.).
And there are good reasons to avoid them.

## Goal

Understand different issues *users* (you 🫵) have to deal with when using **iterable** datasets.

**Disclaimer**: this talk might be confusing. It's actually the point (kinda).


## Let's start with the basics

#### Map-style datasets

In [1]:
import torch
import torch.utils.data as data


class MyMapStyleDS:
    
    def __init__(self, size=100):
        self.size = size
        
    def __getitem__(self, idx):  # Returns the i'th sample
        # Here: read from disk [+ decoding] [+ transforms]
        sample = idx
        return sample
    
    def __len__(self):
        return self.size
    
    
mapstyle_ds = MyMapStyleDS()
mapstyle_dl = data.DataLoader(mapstyle_ds, batch_size=10)

for batch in mapstyle_dl:
    print(batch)
    # Here: forward and backward passes

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])
tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])
tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])
tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])


#### Iterable datasets

In [2]:
class MyIterableDS(data.IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):  # iterate over samples
        # Here: read from disk [+ decoding] [+ transforms]
        for sample in range(self.size):
            yield sample
    
    def __len__(self):
        return self.size
    
    
iter_ds = MyIterableDS()
iter_dl = data.DataLoader(iter_ds, batch_size=10)

for batch in iter_dl:
    print(batch)
    # Here: forward and backward passes

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])
tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])
tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])
tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])


### So far so good
## Let's add some parallelism  🕺💃

We'll cover:

- DataLoader parallelism
- DDP parallelism (if we have time, which we won't)

Fun fact: they're not mutually exclusive



### DataLoader parallelism

#### Map-style - EZPZ lemon squeezy

In [3]:
mapstyle_dl = data.DataLoader(mapstyle_ds, batch_size=10, num_workers=4)

for batch in mapstyle_dl:
    print(batch)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])
tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])
tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])
tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])


#### Iterable - ~EZPZ lemon squeezy~

In [4]:
iter_dl = data.DataLoader(iter_ds, batch_size=10, num_workers=4)

for batch in iter_dl:
    print(batch)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])
tensor([50,

### Oops. What went wrong?


Let's dive into the [DataLoader internals](./imgs/from_mapstyle_to_iterable.html).


Mapstyle

<img src="imgs/mapstyle.png" width="500"/>

Iterable

<img src="imgs/iterable.png" width="500"/>



<img src="imgs/sharding.png" width="500"/>


Recap

- **Map-style dataset**: main DataLoader process is able to request specific indices from each worker
- **Iterable dataset**: there's no notion of "indices". All the DataLoader can do is request the "next" sample from each worker, via `None`.
  - So we need to tell each worker which samples belong to them.
  - We have to do that **manually**.
  - There's no standard or canonical way.
  
 
TL;DR: it's hard 🥲
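
The duplicated batches shown earlier can be reproduced without a DataLoader. The sketch below is a simplified model, not the actual DataLoader implementation: it assumes the main process asks each worker in turn for its next batch, and that every worker owns a full, independent copy of the dataset iterator. `simulate_unsharded_loader` is a hypothetical helper name.

```python
from collections import Counter
from itertools import islice


def simulate_unsharded_loader(size=100, batch_size=10, num_workers=4):
    """Simplified model of DataLoader workers iterating an unsharded iterable dataset."""
    # Every worker gets its own, complete copy of the sample stream.
    worker_iters = [iter(range(size)) for _ in range(num_workers)]
    active = list(range(num_workers))
    batches = []
    while active:
        # The main process asks each remaining worker, in turn, for its "next" batch.
        for worker_id in list(active):
            batch = list(islice(worker_iters[worker_id], batch_size))
            if batch:
                batches.append(batch)
            else:
                active.remove(worker_id)  # this worker's copy is exhausted
    return batches


batches = simulate_unsharded_loader()
counts = Counter(sample for batch in batches for sample in batch)
print(len(batches))          # 40 batches instead of 10
print(set(counts.values()))  # {4}: every sample is seen once per worker
```

Nothing in this model tells a worker which samples are "its own", which is why sharding has to happen inside `__iter__`.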

<img src="imgs/iterable_with_sharding.png" width="500"/>



In [5]:
class MyIterableDS(data.IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):  # iterate over samples
        worker_info = data.get_worker_info()
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
        
        for i, s in enumerate(range(self.size)):
            if i % num_workers == worker_id:
                yield s
    
    def __len__(self):
        return self.size
    
    
iter_ds = MyIterableDS()
iter_dl = data.DataLoader(iter_ds, batch_size=10, num_workers=4)

for batch in iter_dl:
    print(batch)

tensor([ 0,  4,  8, 12, 16, 20, 24, 28, 32, 36])
tensor([ 1,  5,  9, 13, 17, 21, 25, 29, 33, 37])
tensor([ 2,  6, 10, 14, 18, 22, 26, 30, 34, 38])
tensor([ 3,  7, 11, 15, 19, 23, 27, 31, 35, 39])
tensor([40, 44, 48, 52, 56, 60, 64, 68, 72, 76])
tensor([41, 45, 49, 53, 57, 61, 65, 69, 73, 77])
tensor([42, 46, 50, 54, 58, 62, 66, 70, 74, 78])
tensor([43, 47, 51, 55, 59, 63, 67, 71, 75, 79])
tensor([80, 84, 88, 92, 96])
tensor([81, 85, 89, 93, 97])
tensor([82, 86, 90, 94, 98])
tensor([83, 87, 91, 95, 99])


Works OK, but:

- Manual and boilerplate code, and there's no standard
- Notice the difference with the map-style dataset: batches are strided (0, 4, 8, ...) instead of contiguous (not a big deal tho)
- **Notice the batch size at the end!!** Each worker emits its own short final batch. This can have [bad consequences](https://github.com/pytorch/data/issues/302) when batch-norm is involved. Solution: use `drop_last=True`; but ideally we shouldn't need to.
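
To see where the small batches come from, and what `drop_last=True` costs, the arithmetic can be sketched in plain Python. It assumes the strided sharding used above and that batching happens independently inside each worker; `per_worker_batch_sizes` is a hypothetical helper.

```python
def per_worker_batch_sizes(size, batch_size, num_workers, drop_last=False):
    """Batch sizes produced by each worker under strided sharding (i % num_workers)."""
    all_sizes = []
    for worker_id in range(num_workers):
        # Number of samples i in [0, size) with i % num_workers == worker_id.
        num_samples = len(range(worker_id, size, num_workers))
        num_full, remainder = divmod(num_samples, batch_size)
        sizes = [batch_size] * num_full
        if remainder and not drop_last:
            sizes.append(remainder)  # one short batch per worker
        all_sizes.append(sizes)
    return all_sizes


print(per_worker_batch_sizes(100, 10, 4))
# [[10, 10, 5], [10, 10, 5], [10, 10, 5], [10, 10, 5]]
print(per_worker_batch_sizes(100, 10, 4, drop_last=True))
# [[10, 10], [10, 10], [10, 10], [10, 10]]
```

With `drop_last=True`, each of the 4 workers discards its own 5-sample remainder, so 20 of the 100 samples are skipped in that epoch.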

------








#### Now let's take a look at DDP parallelism

What is DDP (Distributed Data Parallel)?
- N copies of the model, typically on N GPUs (== N DDP processes).
- The N models see different parts of the data  <-- **That's the important part**
- The N models' weights are kept equal via gradient synchronization


Let's look at this outside of this notebook (if we have time): see [this file](https://github.com/NicolasHug/iterable_ds_pres/blob/main/issues_with_ddp.py)

\<insert pain here\>

TL;DR:

- The exact same sharding issue happens (but for other reasons)
- So we need to shard across **DDP workers** (just like we sharded across DataLoader workers above)
- Each DDP process can run its own multi-process DataLoader:
  - **So we need 2 levels of sharding**: DDP *and* DataLoader:

In [6]:
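import torch.distributed as dist


class MyIterableDS(data.IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):  # iterate over samples
        
        worker_info = data.get_worker_info()
        num_dl_workers = worker_info.num_workers
        dl_worker_id = worker_info.id

        num_ddp_workers = dist.get_world_size()
        ddp_worker_id = dist.get_rank()

        # We need 2 levels of sharding!! Combine them into a single global shard id.
        # Testing `i % num_ddp_workers` and `i % num_dl_workers` on the same `i`
        # leaves some workers with no samples whenever the two counts share a factor.
        num_shards = num_ddp_workers * num_dl_workers
        shard_id = ddp_worker_id * num_dl_workers + dl_worker_id
        
        for i, s in enumerate(range(self.size)):
            if i % num_shards == shard_id:
                yield s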
    
    def __len__(self):
        return self.size

But IRL you'll need a lot of glue code (e.g. only shard when DDP is on).
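
As one example of that glue, here is a minimal sketch of a helper that collapses both levels into a single shard id and falls back to "no sharding" when a level is absent. `global_shard` and `iter_shard` are hypothetical names, and the helper takes plain tuples so that it runs without a DataLoader or a process group. In real code the first argument would presumably be built from `data.get_worker_info()` (which returns `None` in the main process) and the second from `dist.get_rank()` / `dist.get_world_size()`, guarded by `dist.is_available() and dist.is_initialized()`.

```python
from typing import Optional, Tuple


def global_shard(
    dl_worker: Optional[Tuple[int, int]] = None,  # (worker_id, num_workers), None if num_workers=0
    ddp: Optional[Tuple[int, int]] = None,        # (rank, world_size), None if DDP is off
) -> Tuple[int, int]:
    """Return (shard_id, num_shards) covering both DDP and DataLoader workers."""
    dl_id, num_dl = dl_worker if dl_worker is not None else (0, 1)
    rank, world_size = ddp if ddp is not None else (0, 1)
    # Rank-major numbering: each DDP process owns a contiguous block of shard ids.
    return rank * num_dl + dl_id, world_size * num_dl


def iter_shard(size, shard_id, num_shards):
    """Yield the samples of range(size) that belong to one shard."""
    for i in range(size):
        if i % num_shards == shard_id:
            yield i


print(global_shard())                              # (0, 1): no sharding at all
print(global_shard(dl_worker=(3, 4)))              # (3, 4): DataLoader workers only
print(global_shard(dl_worker=(1, 4), ddp=(1, 2)))  # (5, 8): both levels
```

Sharding on the combined id assigns each sample to exactly one (rank, worker) pair and keeps every worker busy.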

## More fun: shuffling

#### Map-style: EZPZ lemon squeezy

In [7]:
sampler = data.RandomSampler(mapstyle_ds)
mapstyle_dl = data.DataLoader(mapstyle_ds, batch_size=10, num_workers=4, sampler=sampler)

for batch in mapstyle_dl:
    print(batch)

tensor([67, 35, 27, 93, 50, 68, 33,  7, 19, 51])
tensor([37, 26, 87, 82, 60,  0, 62,  1, 69, 42])
tensor([ 5, 57, 32, 58, 73, 75, 17, 64, 91, 29])
tensor([44, 36, 81, 22, 28, 55, 24, 18, 53, 77])
tensor([48, 61, 56, 74, 52, 83, 21, 89,  6, 20])
tensor([54, 80, 92, 40, 72, 16, 65, 79, 31,  3])
tensor([14, 63, 13, 95, 46, 94,  2, 88, 41, 45])
tensor([70, 11, 25, 85, 47,  9, 84, 15, 90, 66])
tensor([43, 78, 71, 34, 30, 86, 96, 12,  4, 59])
tensor([97, 98, 10, 99, 76,  8, 49, 39, 23, 38])


#### Iterable: ~EZPZ lemon squeezy~

In [8]:
iter_ds = MyIterableDS()
sampler = data.RandomSampler(iter_ds)
# Samplers produce indices, and iterable datasets have none: this raises a ValueError
# iter_dl = data.DataLoader(iter_ds, batch_size=10, num_workers=4, sampler=sampler)

#### OK, let's shuffle manually then

In [9]:
import random

class MyIterableDS(data.IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):  # iterate over samples
        worker_info = data.get_worker_info()
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
        
        buffer = []
        
        for i, s in enumerate(range(self.size)):
            if i % num_workers == worker_id:
                buffer.append(s)
        
        random.shuffle(buffer)
        
        yield from buffer
    
    def __len__(self):
        return self.size

iter_ds = MyIterableDS()
iter_dl = data.DataLoader(iter_ds, batch_size=10, num_workers=4)

for batch in iter_dl:
    print(batch)

tensor([ 4, 20,  0, 76, 56, 40, 92, 96, 60, 44])
tensor([73, 53, 33, 69, 41,  9, 77, 97, 37, 93])
tensor([62,  2, 46, 34, 50, 66, 82, 86, 74, 26])
tensor([67, 51, 59, 27, 71, 47, 35, 91, 55, 95])
tensor([52, 68,  8, 80, 12, 48, 32, 24, 88, 36])
tensor([85, 13, 45, 57, 21, 89,  5, 61, 49, 29])
tensor([22, 58, 94, 14, 54, 18, 42, 10, 70, 78])
tensor([ 7, 75, 11, 63, 87, 43, 23, 83, 31, 15])
tensor([84, 64, 72, 16, 28])
tensor([81, 65,  1, 25, 17])
tensor([90, 30, 98,  6, 38])
tensor([99, 79,  3, 39, 19])


#### Looks random 🤩

Narrator: *It's not*

And it's **not obvious** to diagnose. Each individual worker is only shuffling **within its own shard**!

<img src="imgs/shard_before_shuffle.png" width="500"/>



Blue with blue, yellow with yellow... That's not uniform shuffling: the same samples are always being batched together.

It can have dramatic effects on the accuracy, especially when the underlying files are stored in a per-class folder structure.
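
A quick way to check this claim without a DataLoader: simulate the shard-then-shuffle logic and look at which samples can ever share a batch. This is a plain-Python sketch; `shard_then_shuffle` is a hypothetical helper and the seeds are arbitrary.

```python
import random


def shard_then_shuffle(size, num_workers, worker_id, seed):
    """Samples seen by one worker when it shards first and shuffles afterwards."""
    shard = [i for i in range(size) if i % num_workers == worker_id]
    random.Random(seed).shuffle(shard)
    return shard


# Try many epochs (seeds): sample 0 only ever meets other multiples of 4.
companions_of_0 = set()
for seed in range(1000):
    companions_of_0.update(shard_then_shuffle(100, 4, worker_id=0, seed=seed))

print(len(companions_of_0))                      # 25, never the other 75 samples
print(all(s % 4 == 0 for s in companions_of_0))  # True
```

However many epochs are run, a batch from worker 0 can only ever draw from the same 25 samples.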

## We need to shuffle before we shard


So let's do that:

In [22]:
class MyIterableDS(data.IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):  # iterate over samples
        worker_info = data.get_worker_info()
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
        
        buffer = []
        for s in range(self.size):
            buffer.append(s)
                
        random.shuffle(buffer)  # Shuffle ...
        
        for i, s in enumerate(buffer):  # ... then shard
            if i % num_workers == worker_id:
                yield s
    
    def __len__(self):
        return self.size

iter_ds = MyIterableDS()
iter_dl = data.DataLoader(iter_ds, batch_size=10, num_workers=4)

batches = []
for batch in iter_dl:
    print(batch)
    batches.append(batch)

tensor([56,  0, 87,  8, 61, 38,  6,  1, 33, 19])
tensor([30, 68, 96, 60, 24, 27, 14, 41, 17, 19])
tensor([62, 21, 54, 50, 65,  6,  9,  0, 90, 80])
tensor([16, 41, 81, 90, 99, 82, 49, 94,  9, 25])
tensor([43, 20, 55, 83, 71, 39, 79, 31, 30, 17])
tensor([46, 16, 89, 43, 80, 63, 31, 81, 64, 87])
tensor([30, 98, 87, 47, 40, 56, 41, 44, 16,  4])
tensor([87, 42, 69,  6, 57, 15, 67, 14, 18, 30])
tensor([57, 22, 54, 13, 15])
tensor([ 4, 44, 22, 52, 72])
tensor([64, 36, 51, 69, 37])
tensor([73, 71, 12, 10, 83])


#### Did it work now? 🤩


Narrator: *it didn't*

In [23]:
all_samples = torch.cat(batches)
_, counts = torch.unique(all_samples, return_counts=True)
counts, counts.shape

(tensor([2, 1, 2, 3, 1, 2, 1, 1, 1, 2, 2, 3, 2, 1, 2, 1, 1, 2, 1, 1, 1, 4, 2, 1,
         1, 1, 1, 1, 1, 3, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 2,
         1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 2, 4, 1, 2, 1, 1, 1, 1]),
 torch.Size([67]))

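**Reason**: each worker gets a [different RNG seed](https://github.com/pytorch/pytorch/blob/3ac27e78ca5429b47a63826a1bb678031d20bffd/torch/utils/data/_utils/worker.py#L217-L223), so each one shuffles the dataset into a different permutation before taking its shard. Some samples end up duplicated and others are missing: only 67 of the 100 samples appear above.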
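
The effect is easy to reproduce in plain Python. The sketch below mimics the dataset above, with each worker shuffling the full index list with its own seed before taking its shard. The seed values are arbitrary stand-ins for the per-worker seeds, and `shuffle_then_shard` and `epoch_counts` are hypothetical helpers.

```python
import random
from collections import Counter


def shuffle_then_shard(size, num_workers, worker_id, seed):
    """Samples seen by one worker that shuffles everything, then keeps its shard."""
    indices = list(range(size))
    random.Random(seed).shuffle(indices)
    return [s for i, s in enumerate(indices) if i % num_workers == worker_id]


def epoch_counts(size, num_workers, seeds):
    """How many times each sample is seen in one epoch, across all workers."""
    counts = Counter()
    for worker_id, seed in enumerate(seeds):
        counts.update(shuffle_then_shard(size, num_workers, worker_id, seed))
    return counts


counts = epoch_counts(100, 4, seeds=[0, 1, 2, 3])  # one seed per worker
print(sum(counts.values()))  # 100 samples are yielded in total ...
print(len(counts))           # ... but fewer than 100 of them are distinct
```

Each worker's positions `i % num_workers == worker_id` are taken from a different permutation, so nothing guarantees the four shards are disjoint.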


<img src="imgs/iterable_with_seed.png" width="500"/>


#### So we need the same RNG seed across workers... Right?

In [10]:
def worker_init_fn(worker_id):
    # This is wrong for *at least* 2 reasons
    random.seed(0)  

iter_ds = MyIterableDS()
iter_dl = data.DataLoader(iter_ds, batch_size=10, num_workers=4, worker_init_fn=worker_init_fn)

batches = []
for batch in iter_dl:
    print(batch)
    batches.append(batch)


all_samples = torch.cat(batches)
_, counts = torch.unique(all_samples, return_counts=True)
counts, counts.shape

tensor([56,  0, 92, 72, 24, 20, 28, 40, 88, 80])
tensor([57,  1, 93, 73, 25, 21, 29, 41, 89, 81])
tensor([58,  2, 94, 74, 26, 22, 30, 42, 90, 82])
tensor([59,  3, 95, 75, 27, 23, 31, 43, 91, 83])
tensor([16,  8, 84, 12, 68, 44, 76, 36, 96, 60])
tensor([17,  9, 85, 13, 69, 45, 77, 37, 97, 61])
tensor([18, 10, 86, 14, 70, 46, 78, 38, 98, 62])
tensor([19, 11, 87, 15, 71, 47, 79, 39, 99, 63])
tensor([64, 32,  4, 52, 48])
tensor([65, 33,  5, 53, 49])
tensor([66, 34,  6, 54, 50])
tensor([67, 35,  7, 55, 51])


(tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]),
 torch.Size([100]))

#### Shuffling works now 🤓 🤓 🤓 🤓
### But we broke the random augmentations 💀💀💀💀
### Oh, and no, shuffling still doesn't work: the seed should be epoch-dependent

We still want the RNG of the sample transformations (done within the workers) to be different.

We **only** want the seed for **shuffling** to be the same across workers (meaning: across DataLoader workers **and** DDP workers).


# We need to separate RNG streams
## - for shuffling: a single one, shared across all workers
## - for random augmentation: one or more

🎶 *Hello Statefulness my old friend...* 🎶

Idea: pass a shuffling seed in `__init__()` and create a new RNG stream there. But **be careful**: we still need to make sure the seed will be different across epochs (remember [`DistributedSampler.set_epoch()`](https://github.com/pytorch/pytorch/blob/401179f263d5ba22731de107874f41fdd256737f/torch/utils/data/distributed.py#L100)?)
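
A minimal sketch of that idea, in plain Python so it runs without a DataLoader. Everything here is an assumption rather than an established API: `shuffled_shard` and `one_epoch` are hypothetical helpers, and deriving the per-epoch seed as `shuffle_seed + epoch` is just one simple choice. The key point is the dedicated `random.Random` instance, which leaves the global `random` stream (used here as a stand-in for augmentation randomness) untouched.

```python
import random


def shuffled_shard(size, shuffle_seed, epoch, shard_id, num_shards):
    """Shuffle with a stream shared by all workers, then shard."""
    # Dedicated RNG: same (seed, epoch) on every worker -> same permutation everywhere.
    shuffle_rng = random.Random(shuffle_seed + epoch)
    indices = list(range(size))
    shuffle_rng.shuffle(indices)
    return [s for i, s in enumerate(indices) if i % num_shards == shard_id]


def one_epoch(epoch, num_shards=4):
    """Shards seen by all workers during one epoch."""
    return [
        shuffled_shard(100, shuffle_seed=0, epoch=epoch, shard_id=k, num_shards=num_shards)
        for k in range(num_shards)
    ]


epoch_0, epoch_1 = one_epoch(0), one_epoch(1)

# Every sample is seen exactly once per epoch ...
print(sorted(s for shard in epoch_0 for s in shard) == list(range(100)))  # True
# ... and the order changes from one epoch to the next.
print(epoch_0 != epoch_1)  # True

# The global stream is not consumed by shuffling, so augmentations stay independent.
random.seed(123)
expected = random.random()
random.seed(123)
shuffled_shard(100, shuffle_seed=0, epoch=0, shard_id=0, num_shards=4)
print(random.random() == expected)  # True
```

In a real `IterableDataset`, the current epoch would still have to reach every DataLoader and DDP worker, for instance through a `set_epoch()`-style method called before each epoch, which is where the statefulness comes in.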


And remember: everything is at least twice as hard with multiple GPUs.

TL;DR: it's hard 🥲

*The implementation of an iterable dataset that properly handles distributed training and multiprocessing, shuffling and sharding, all within a simple user-friendly API that does not expose low-level implementation details, is left as an exercise to the reader.*

## Take away:

We've seen a bunch of issues with iterable datasets:
  - Related to sharding / multiprocessing
    - inside the DataLoader
    - outside the DataLoader (DDP)
    - when both are involved
  - Related to shuffling
  - Related to the *order* of these operations: shuffling before sharding

Hopefully, you're confused. It is confusing! **You shouldn't be expected to understand all this**.



# Don't implement Iterable datasets yourself
# Use existing solutions. (But don't assume they're all correct!)
# If you're given the power to shuffle and partition manually: be EXTREMELY careful