In [1]:
# Szymon Manduk
# August 5, 2021
# Import data from your own dataset and iterate over data
# The notebook is a part of "Data preparation with Dataset and DataLoader in Pytorch" blog post
# https://aigeekprogrammer.com/data-preparation-with-dataset-and-dataloader-in-pytorch/

In [11]:
import torch
from torch.utils.data import Dataset, DataLoader

In [16]:
# The first step is to define our own class that inherits from the "abstract" Dataset class
# The implementation requires overriding 2 methods: __getitem__ and __len__
# Also, we move the data generation to the init method

class RandomIntDataset(Dataset):
    """Map-style dataset of random integer feature rows paired with random integer labels.

    The data is generated once, at construction time, and kept in memory as two
    tensors: ``data`` with shape ``(x, y)`` and ``labels`` with shape ``(x,)``.
    """

    def __init__(self, start, stop, x, y, num_classes=10, generator=None):
        """Generate the random data and labels.

        Args:
            start: lowest integer (inclusive) that may appear in ``data``.
            stop: upper bound (exclusive) for the integers in ``data``.
            x: number of samples, i.e. rows of ``data`` and length of ``labels``.
            y: number of values per sample, i.e. columns of ``data``.
            num_classes: labels are drawn from ``[0, num_classes)``. Defaults to
                10, the value that used to be hard-coded.
            generator: optional ``torch.Generator`` used for both draws so the
                dataset can be reproduced without touching the global RNG.
                ``None`` (the default) uses the global RNG as before.
        """
        # Data is drawn before labels; keeping this order means a run seeded via
        # torch.manual_seed() yields the same tensors as the previous version.
        self.data = torch.randint(start, stop, (x, y), generator=generator)
        # One label per row of data.
        self.labels = torch.randint(0, num_classes, (x,), generator=generator)

    def __len__(self):
        """Return the number of samples (the length of the label vector)."""
        return len(self.labels)

    def __str__(self):
        """Return the text form of a single table: data columns followed by a label column."""
        # unsqueeze(1) turns the (x,) label vector into an (x, 1) column so it
        # can be concatenated to the right of the (x, y) data matrix.
        labels_column = self.labels.unsqueeze(1)
        return str(torch.cat((self.data, labels_column), 1))

    def __getitem__(self, i):
        """Return the ``(data, label)`` pair stored at index ``i``."""
        return self.data[i], self.labels[i]
    


In [17]:
# 500 samples of 10 integers each, with values drawn from [100, 1000)
dataset = RandomIntDataset(start=100, stop=1000, x=500, y=10)

In [18]:
# len() calls RandomIntDataset.__len__, which reports the number of labels
len(dataset)

500

In [19]:
# print() goes through RandomIntDataset.__str__: the data columns with the labels appended as the last column
print(dataset)

tensor([[627, 160, 881,  ..., 485, 457,   9],
        [705, 511, 947,  ..., 744, 465,   5],
        [692, 427, 701,  ..., 639, 378,   9],
        ...,
        [601, 228, 749,  ..., 155, 823,   4],
        [599, 627, 802,  ..., 179, 693,   4],
        [740, 861, 697,  ..., 286, 160,   4]])


In [20]:
# Serve the dataset in shuffled mini-batches of 4 samples
dataset_loader = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
)

In [21]:
# Pull a single batch from the loader; each batch unpacks into a (data, labels) pair
data, labels = next(iter(dataset_loader))

In [24]:
# Batch of samples: one row per sample (batch_size=4), one column per value (y=10)
data

tensor([[724, 232, 501, 555, 369, 142, 504, 226, 849, 924],
        [170, 510, 711, 502, 641, 458, 378, 927, 324, 701],
        [838, 482, 299, 379, 181, 394, 473, 739, 888, 265],
        [945, 421, 983, 531, 237, 106, 261, 399, 161, 459]])

In [25]:
# Labels for the same batch, one per row of `data`
labels

tensor([3, 6, 9, 7])