# Summary

- `Dataset` is used to provide a finite number of samples. By default, each sample is fetched at most once per pass over the `DataLoader`.
- `IterableDataset` is used when the number of samples is not known in advance. Samples can repeat.

## Dataset

#### `for-loop`

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Map-style dataset wrapping an indexable container of samples.

    Each lookup prints the requested index, which makes the order in which
    samples are requested visible on stdout.
    """

    def __init__(self, data):
        # Keep a reference to the samples; they are only read on lookup.
        self.data = data

    def __len__(self) -> int:
        """Return how many samples the wrapped container holds."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        """Print ``idx`` and return the sample stored at that position."""
        # Log first so the index is shown even if the lookup itself fails.
        print(f"Index - {idx}")
        sample = self.data[idx]
        return sample


# Seed the global RNG so the run is reproducible.
torch.manual_seed(1337)

# Ten random values serve as the toy samples.
data = torch.randn(10)
print(f"Original data: {data}")

dataset = CustomDataset(data)
loader = DataLoader(dataset, shuffle=True)

# One full pass over the loader with a plain for-loop.
for batch in loader:
    print(batch)


Original data: tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
Index - 4
tensor([-1.2502])
Index - 7
tensor([0.0544])
Index - 6
tensor([-0.2071])
Index - 9
tensor([-0.3889])
Index - 2
tensor([-1.2054])
Index - 0
tensor([-2.0260])
Index - 3
tensor([-0.9122])
Index - 5
tensor([0.8032])
Index - 1
tensor([-2.0655])
Index - 8
tensor([0.1378])


#### `next(iter())`

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Indexable dataset over ``data`` that reports every index it is asked for."""

    def __init__(self, data):
        self.data = data  # samples are read lazily in __getitem__

    def __len__(self) -> int:
        """Number of samples, taken from the wrapped container."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return ``data[idx]`` after printing the requested index."""
        message = f"Index - {idx}"
        print(message)
        return self.data[idx]


# Seed the global RNG so the run is reproducible.
torch.manual_seed(1337)

data = torch.randn(10)
print(f"Original data: {data}")

dataset = CustomDataset(data)
loader = DataLoader(dataset, shuffle=True)

# Create the iterator once and pull exactly three batches from it by hand.
it = iter(loader)
for _ in range(3):
    batch = next(it)
    print(batch)


Original data: tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
Index - 4
tensor([-1.2502])
Index - 7
tensor([0.0544])
Index - 6
tensor([-0.2071])


## Iterable Dataset

#### Works! Infinite loop in `__iter__`, calling `next()` on `iter(loader)` to fetch data

In [5]:
import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomIterableDataset(IterableDataset):
    """Iterable-style dataset that yields randomly chosen samples without end."""

    def __init__(self, data):
        # Container of samples to draw from.
        self.data = data

    def __iter__(self):
        """Yield ``data[index]`` forever, with a freshly drawn random index each time."""
        while True:
            upper = len(self.data)
            # A one-element index tensor in [0, upper); indexing with it keeps
            # a leading dimension of size 1 on the yielded sample.
            position = torch.randint(low=0, high=upper, size=(1,))
            sample = self.data[position]
            yield sample


# Seed the global RNG so the run is reproducible.
torch.manual_seed(1337)

data = torch.randn(10)
print(f"Original data: {data}")

dataset = CustomIterableDataset(data)
loader = DataLoader(dataset)

# The dataset never runs out, so the number of batches is decided here
# by how many times next() is called.
it = iter(loader)
for _ in range(10):
    batch = next(it)
    print(batch)


Original data: tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
tensor([[-1.2054]])
tensor([[-2.0260]])
tensor([[-1.2054]])
tensor([[-2.0655]])
tensor([[-0.2071]])
tensor([[0.8032]])
tensor([[-0.3889]])
tensor([[-1.2502]])
tensor([[0.8032]])
tensor([[-0.3889]])


#### The same behavior can also be reproduced with `Dataset`

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Map-style dataset that ignores the requested index and returns a random sample."""

    def __init__(self, data):
        # Container of samples to draw from.
        self.data = data

    def __len__(self) -> int:
        """Return the number of samples in the wrapped container."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        """Return a randomly chosen sample; the incoming ``idx`` is discarded."""
        # Deliberately replace the caller's index with a random one-element
        # index tensor in [0, len(data)).
        random_idx = torch.randint(low=0, high=len(self.data), size=(1,))
        return self.data[random_idx]


# Seed the global RNG so the run is reproducible.
torch.manual_seed(1337)

data = torch.randn(10)
print(f"Original data: {data}")

dataset = CustomDataset(data)
loader = DataLoader(dataset)

# The loop length comes from len(dataset); which sample is returned at each
# step is decided inside __getitem__.
for batch in loader:
    print(batch)


Original data: tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
tensor([[-1.2054]])
tensor([[-2.0260]])
tensor([[-1.2054]])
tensor([[-2.0655]])
tensor([[-0.2071]])
tensor([[0.8032]])
tensor([[-0.3889]])
tensor([[-1.2502]])
tensor([[0.8032]])
tensor([[-0.3889]])


#### DOES NOT WORK! Single `yield` (no loop) in `__iter__`, calling `next()` on `iter(loader)` to fetch data

In [7]:
# Only the first sample can be fetched; the second `next(it)` raises StopIteration.
import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomIterableDataset(IterableDataset):
    """Iterable-style dataset whose iterator produces exactly one random sample."""

    def __init__(self, data):
        # Container of samples to draw from.
        self.data = data

    def __iter__(self):
        """Yield a single randomly chosen sample, then finish."""
        # There is no loop here: after this one yield the generator is exhausted.
        position = torch.randint(low=0, high=len(self.data), size=(1,))
        sample = self.data[position]
        yield sample


# Seed the global RNG so the run is reproducible.
torch.manual_seed(1337)

data = torch.randn(10)
print(f"Original data: {data}")

dataset = CustomIterableDataset(data)
loader = DataLoader(dataset)

# Ask for ten batches even though the dataset yields only one; the second
# next() call is expected to fail with StopIteration.
it = iter(loader)
for _ in range(10):
    batch = next(it)
    print(batch)


Original data: tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
tensor([[-1.2054]])


StopIteration: 

#### DOES NOT WORK! Infinite loop in `__iter__`, using a `for` loop to fetch data (the loop never terminates)

In [12]:
import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomIterableDataset(IterableDataset):
    """Endless stream of randomly picked samples from ``data``."""

    def __init__(self, data):
        self.data = data  # source container for the random draws

    def __iter__(self):
        """Generate random samples indefinitely; this iterator never stops by itself."""
        while True:
            # Draw a one-element index tensor in [0, len(data)).
            pick = torch.randint(low=0, high=len(self.data), size=(1,))
            yield self.data[pick]


# Seed the global RNG so the run is reproducible.
torch.manual_seed(1337)

data = torch.randn(10)
print(f"Original data: {data}")

dataset = CustomIterableDataset(data)
loader = DataLoader(dataset)

# Show which sampler object the loader holds for this dataset.
print(loader.sampler)

# WARNING: the dataset's __iter__ never finishes, so this loop never exits
# on its own and has to be interrupted manually.
for batch in loader:
    print(batch)


Original data: tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
<torch.utils.data.dataloader._InfiniteConstantSampler object at 0x7fe80453f3a0>
