# Summary

- `Dataset` is used to provide a finite number of samples. By default, each sample is fetched at most once per pass over the `DataLoader`.
- `IterableDataset` is used when the number of samples is not known in advance (e.g. a stream). Samples can repeat if `__iter__` yields them that way.

## Dataset

#### `for-loop`

In [45]:
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Map-style dataset wrapping an indexable sequence.

    Every lookup prints the requested index before returning the item, so the
    order in which a ``DataLoader`` requests indices is visible in the output.
    """

    def __init__(self, data):
        # ``data`` must support ``len()`` and ``[]`` indexing.
        self.data = data

    def __len__(self):
        """Return how many samples the wrapped sequence holds."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Print ``idx`` and return ``self.data[idx]``."""
        print(idx)
        sample = self.data[idx]
        return sample


# Seed the global RNG so both the random data and the shuffle order are repeatable.
torch.manual_seed(1337)
data = torch.randn(size=(10,))
print(data)

dataset = CustomDataset(data)
loader = DataLoader(dataset=dataset, shuffle=True)

# A single pass over the loader: each index is requested once, in shuffled
# order, and every batch holds one sample (DataLoader's default batch size).
for batch in loader:
    print(batch)


tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
4
tensor([-1.2502])
7
tensor([0.0544])
6
tensor([-0.2071])
9
tensor([-0.3889])
2
tensor([-1.2054])
0
tensor([-2.0260])
3
tensor([-0.9122])
5
tensor([0.8032])
1
tensor([-2.0655])
8
tensor([0.1378])


#### `next(iter())`

In [46]:
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Thin map-style wrapper around ``data`` that logs each requested index."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Echo the index so the sampler's choices show up in the cell output.
        print(idx)
        item = self.data[idx]
        return item


torch.manual_seed(1337)
data = torch.randn(size=(10,))
print(data)

dataset = CustomDataset(data)
loader = DataLoader(dataset=dataset, shuffle=True)

# Keep a single iterator and pull batches from it on demand; only the first
# three shuffled samples are ever fetched.
it = iter(loader)
print(next(it))
print(next(it))
print(next(it))


tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
4
tensor([-1.2502])
7
tensor([0.0544])
6
tensor([-0.2071])


## Iterable Dataset

#### Works: infinite loop in `__iter__`, fetch data with `next()` on a single `iter(loader)`

In [47]:
import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomIterableDataset(IterableDataset):
    """Endless stream of randomly chosen elements of ``data``.

    ``__iter__`` never returns, so the consumer has to decide when to stop,
    e.g. by calling ``next()`` a fixed number of times.
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        # Sample with replacement forever: each yield indexes ``self.data`` with
        # a 1-element index tensor drawn from [0, len(self.data)).
        while True:
            pick = torch.randint(0, len(self.data), size=(1,))
            yield self.data[pick]


torch.manual_seed(1337)
data = torch.randn(size=(10,))
print(data)

dataset = CustomIterableDataset(data)
loader = DataLoader(dataset=dataset)

# The underlying stream is infinite, so take a bounded number of batches
# from one iterator instead of looping over the loader.
it = iter(loader)
print(next(it))
print(next(it))
print(next(it))


tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
tensor([[-1.2054]])
tensor([[-2.0260]])
tensor([[-1.2054]])


#### The same behavior can also be reproduced with a map-style `Dataset`

In [51]:

import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Map-style dataset that ignores the requested index.

    ``__len__`` still reports ``len(data)``, so a ``DataLoader`` issues that
    many requests per pass. Each request, however, returns a randomly chosen
    element, which mimics sampling with replacement on a finite Dataset.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Deliberately discard the sampler's index and draw a fresh random one;
        # the result is a 1-element index tensor.
        idx = torch.randint(0, len(self.data), size=(1,))  # override the index
        print(idx.item())
        return self.data[idx]


torch.manual_seed(1337)
data = torch.randn(size=(10,))
print(data)

dataset = CustomDataset(data)
# No shuffle needed: __getitem__ ignores the sampler's index anyway.
loader = DataLoader(dataset=dataset)

# Terminates after len(dataset) batches even though each sample is random.
for batch in loader:
    print(batch)


tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
2
tensor([[-1.2054]])
0
tensor([[-2.0260]])
2
tensor([[-1.2054]])
1
tensor([[-2.0655]])
6
tensor([[-0.2071]])
5
tensor([[0.8032]])
9
tensor([[-0.3889]])
4
tensor([[-1.2502]])
5
tensor([[0.8032]])
9
tensor([[-0.3889]])


#### Does NOT work: infinite loop in `__iter__`, fetch data with a `for` loop (the loop never ends)

In [41]:
# This will create an infinite loop!!!
import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomIterableDataset(IterableDataset):
    """Infinite stream of uniformly sampled elements of ``data`` (with replacement)."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        # Never returns: a plain ``for`` loop over a DataLoader built on this
        # dataset keeps going until it is interrupted.
        while True:
            random_index = torch.randint(0, len(self.data), size=(1,))
            sample = self.data[random_index]
            yield sample


torch.manual_seed(1337)
data = torch.randn(size=(10,))
print(data)

dataset = CustomIterableDataset(data)
loader = DataLoader(dataset=dataset)

# The dataset's __iter__ never stops, so this loop only ends on interrupt.
for batch in loader:
    print(batch)

tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
tensor([[-1.2054]])
tensor([[-2.0260]])
tensor([[-1.2054]])
tensor([[-2.0655]])
tensor([[-0.2071]])
tensor([[0.8032]])
tensor([[-0.3889]])
tensor([[-1.2502]])
tensor([[0.8032]])
tensor([[-0.3889]])
tensor([[-0.2071]])
tensor([[-0.3889]])
tensor([[-0.3889]])
tensor([[-0.3889]])
tensor([[-0.3889]])
tensor([[-0.9122]])
tensor([[-0.3889]])
tensor([[0.1378]])
tensor([[-1.2502]])
tensor([[0.1378]])
tensor([[-0.9122]])
tensor([[0.8032]])
tensor([[-1.2502]])
tensor([[-0.9122]])
tensor([[-0.9122]])
tensor([[0.1378]])
tensor([[-0.3889]])
tensor([[-2.0655]])
tensor([[-0.2071]])
tensor([[0.1378]])
tensor([[-0.2071]])
tensor([[-2.0260]])
tensor([[-2.0260]])
tensor([[-0.9122]])
tensor([[-1.2502]])
tensor([[-0.9122]])
tensor([[0.1378]])
tensor([[-0.2071]])
tensor([[-0.2071]])
tensor([[-2.0655]])
tensor([[0.0544]])
tensor([[-0.2071]])
tensor([[0.1378]])
tensor([[-0.2071]])
tensor([[0.8032]])
tens

KeyboardInterrupt: 

#### Does NOT work: `__iter__` yields only once, fetch data with `next()` on `iter(loader)` (the second `next()` raises `StopIteration`)

In [40]:
# Only the 1st data can be fetched!!!
import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomIterableDataset(IterableDataset):
    """Iterable dataset whose iterator yields exactly one random element."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        # No loop: the generator is exhausted after one yield, so a second
        # next() on the same iterator raises StopIteration.
        choice = torch.randint(0, len(self.data), size=(1,))
        yield self.data[choice]


torch.manual_seed(1337)
data = torch.randn(size=(10,))
print(data)

dataset = CustomIterableDataset(data)
loader = DataLoader(dataset=dataset)

# The first next() succeeds; the second raises StopIteration because the
# dataset's __iter__ yields only once.
it = iter(loader)
print(next(it))
print(next(it))
print(next(it))


tensor([-2.0260, -2.0655, -1.2054, -0.9122, -1.2502,  0.8032, -0.2071,  0.0544,
         0.1378, -0.3889])
tensor([[-1.2054]])


StopIteration: 