In [2]:
import numpy as np
import torch as t
from torch.utils.data import Dataset

In [3]:
# Seed the shared generator so that a fresh "Restart Kernel -> Run All" draws
# the same random labels every time. Printed outputs saved from earlier,
# unseeded runs will differ until the cells are re-executed.
SEED = 42
rng = np.random.default_rng(SEED)

In [4]:
class MyMappedDataset(Dataset):
    """Map-style toy dataset: ``m`` feature rows of width ``n`` with random binary labels.

    Row ``i`` holds the consecutive integers ``i * n`` through ``i * n + n - 1``.
    Each label is drawn from ``[0, 1]`` with probabilities ``0.7`` and ``0.3``.

    Args:
        n: Number of values in each feature row.
        m: Number of rows, which is also the dataset length.
        generator: Optional ``numpy.random.Generator`` used to draw the labels.
            When omitted, the notebook-level ``rng`` is used, exactly as before.
            Pass a seeded generator to get reproducible labels.
    """

    def __init__(self, n=5, m=10, generator=None):
        # Only fall back to the shared notebook generator when none is injected,
        # so existing ``MyMappedDataset()`` calls keep their behaviour.
        label_rng = rng if generator is None else generator
        self._x = np.arange(n * m).reshape(m, n)
        self._y = label_rng.choice([0, 1], size=m, p=[0.7, 0.3])

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self._x[idx], self._y[idx]

    def __len__(self):
        """Return the number of rows in the dataset."""
        return self._x.shape[0]

In [6]:
# Build the dataset with its default sizes and print every
# (features, label) pair in index order.
ds = MyMappedDataset()
num_rows = len(ds)
for i in range(num_rows):
    pair = ds[i]
    print(pair)

(array([0, 1, 2, 3, 4]), 0)
(array([5, 6, 7, 8, 9]), 0)
(array([10, 11, 12, 13, 14]), 1)
(array([15, 16, 17, 18, 19]), 0)
(array([20, 21, 22, 23, 24]), 1)
(array([25, 26, 27, 28, 29]), 0)
(array([30, 31, 32, 33, 34]), 0)
(array([35, 36, 37, 38, 39]), 0)
(array([40, 41, 42, 43, 44]), 0)
(array([45, 46, 47, 48, 49]), 1)


In [7]:
from torch.utils.data import IterableDataset

In [8]:
class MyStreamingDataset(IterableDataset):
    """Iterable-style toy dataset yielding an endless stream of ``(x, y)`` pairs.

    The k-th sample (counting from 0) has ``x = [k * n, ..., k * n + n - 1]``
    and ``y`` a length-1 array drawn from ``[0, 1]`` with probabilities
    ``0.7`` and ``0.3``. Every call to ``iter()`` restarts ``x`` at 0.

    The stream never ends, so consumers must stop iterating themselves.
    Per the PyTorch ``IterableDataset`` documentation, a ``DataLoader`` with
    ``num_workers > 0`` gives each worker its own copy of the dataset; this
    class does no per-worker sharding, so each copy would restart ``x`` at 0.

    Args:
        n: Number of consecutive integers in each ``x``.
        generator: Optional ``numpy.random.Generator`` used to draw the labels.
            When omitted, the notebook-level ``rng`` is used, exactly as before.
            Pass a seeded generator to get reproducible labels.
    """

    def __init__(self, n, generator=None):
        super().__init__()
        self._n = n
        self._generator = generator

    def __iter__(self):
        """Yield ``(x, y)`` pairs forever, advancing ``x`` by ``n`` each step."""
        # Only fall back to the shared notebook generator when none was injected,
        # so existing ``MyStreamingDataset(n)`` calls keep their behaviour.
        label_rng = rng if self._generator is None else self._generator
        start = 0
        while True:
            x = np.arange(start, start + self._n)
            # size=1 keeps the label as a shape-(1,) array rather than a scalar.
            y = label_rng.choice([0, 1], size=1, p=[0.7, 0.3])
            yield x, y
            start += self._n

In [10]:
# Preview the first six samples of the stream. The dataset never stops on its
# own, so the loop has to break explicitly once the counter reaches 5.
for ctr, (x, y) in enumerate(MyStreamingDataset(5)):
    print(x, y)
    if ctr >= 5:
        break

[0 1 2 3 4] [0]
[5 6 7 8 9] [0]
[10 11 12 13 14] [0]
[15 16 17 18 19] [0]
[20 21 22 23 24] [1]
[25 26 27 28 29] [0]
