In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

Large datasets often do not fit in memory, and even when they do, processing the whole dataset in a single step is rarely practical. So, we feed batches of data to neural network models during training.

In this tutorial we will learn how to create batches of data.

In [4]:
# Generate some data
x = torch.rand((10, 2), dtype=torch.float)
y = torch.rand((10, 1), dtype=torch.float)

# PyTorch's Dataset Class

`torch.utils.data.Dataset` is an abstract class representing a dataset. Your custom dataset should inherit from `Dataset` and override the following methods:

- `__len__` so that `len(dataset)` returns the size of the dataset.
- `__getitem__` to support indexing, so that `dataset[i]` returns the i-th sample.

Learn more: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class

In [6]:
# We will define a class that fetches one data point at a time
class DatasetIter(Dataset):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  # Determine the length of the dataset
  def __len__(self):
    return len(self.x)
  
  # Fetch a specific row
  def __getitem__(self, ix):
    return self.x[ix], self.y[ix]

In [8]:
# Create an instance of the dataset class
dataset_inst = DatasetIter(x, y)
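
As a quick check of the two overridden methods, the sketch below builds a small dataset from fixed values and calls `len()` and indexing on it. The class name `PairDataset` and the variables `x_demo` and `y_demo` are hypothetical names introduced only for this illustration; the class follows the same pattern as `DatasetIter`.

```python
import torch
from torch.utils.data import Dataset

# Hypothetical class for illustration; same pattern as DatasetIter
class PairDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, ix):
        return self.x[ix], self.y[ix]

# Fixed values (instead of torch.rand) so the result is reproducible
x_demo = torch.arange(20, dtype=torch.float).reshape(10, 2)
y_demo = torch.arange(10, dtype=torch.float).reshape(10, 1)
demo = PairDataset(x_demo, y_demo)

n_samples = len(demo)        # calls __len__
first_x, first_y = demo[0]   # calls __getitem__(0)
last_x, last_y = demo[-1]    # negative indices work because tensors accept them
print(n_samples, first_x.shape, first_y.shape)
```

Each sample is a tuple of one feature row and its matching target, which is the form `DataLoader` later stacks into batches.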

In [10]:
# Fetch data points one at a time from the dataset we just created
for i in range(len(dataset_inst)):
  print(dataset_inst[i])

(tensor([0.2668, 0.2282]), tensor([0.6828]))
(tensor([0.0741, 0.0221]), tensor([0.9067]))
(tensor([0.7355, 0.6016]), tensor([0.0859]))
(tensor([0.3360, 0.3168]), tensor([0.7913]))
(tensor([0.2151, 0.8686]), tensor([0.4917]))
(tensor([0.1432, 0.7005]), tensor([0.8465]))
(tensor([0.8164, 0.6122]), tensor([0.2009]))
(tensor([0.0711, 0.3155]), tensor([0.6406]))
(tensor([0.9035, 0.3447]), tensor([0.0497]))
(tensor([0.9868, 0.9642]), tensor([0.8722]))


Via the `for` loop we were able to fetch one data point at a time from our dataset. But we haven't achieved our objective of fetching batches of data. By using a plain `for` loop we also lose useful features such as shuffling and multiprocess loading.

PyTorch's `DataLoader` wraps a dataset and fetches batches from it.

In [12]:
# Call DataLoader
dataloader = DataLoader(dataset_inst, batch_size=4,
                        shuffle=True, num_workers=0)  # num_workers=0 disables multiprocessing
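
With 10 samples and `batch_size=4`, the last batch cannot be full. The sketch below shows how many batches a `DataLoader` reports in that case, and how the `drop_last` argument changes the count. It uses `TensorDataset` as a stand-in for the custom class, and the names `features`, `targets`, `loader_keep`, and `loader_drop` are hypothetical.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data with the same shapes as x and y in this tutorial
features = torch.zeros(10, 2)
targets = torch.zeros(10, 1)
ds = TensorDataset(features, targets)

loader_keep = DataLoader(ds, batch_size=4, shuffle=True)
loader_drop = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)

n_keep = len(loader_keep)   # ceil(10 / 4): the final batch holds the 2 leftover samples
n_drop = len(loader_drop)   # floor(10 / 4): the partial batch is skipped
print(n_keep, n_drop)
```

Dropping the partial batch can be convenient when every batch must have the same size; otherwise keeping it ensures every sample is seen in each pass.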

In [14]:
for batch in dataloader:
  print("New batch:\n", batch)

New batch:
 [tensor([[0.2668, 0.2282],
        [0.9868, 0.9642],
        [0.3360, 0.3168],
        [0.7355, 0.6016]]), tensor([[0.6828],
        [0.8722],
        [0.7913],
        [0.0859]])]
New batch:
 [tensor([[0.1432, 0.7005],
        [0.9035, 0.3447],
        [0.0711, 0.3155],
        [0.2151, 0.8686]]), tensor([[0.8465],
        [0.0497],
        [0.6406],
        [0.4917]])]
New batch:
 [tensor([[0.8164, 0.6122],
        [0.0741, 0.0221]]), tensor([[0.2009],
        [0.9067]])]
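
Each batch printed above is a pair of tensors: the stacked inputs and the stacked targets. In a training loop it is common to unpack them directly in the `for` statement. The following is a minimal sketch under the same setup (10 samples, `batch_size=4`); `TensorDataset` stands in for the custom class, and the names `features`, `targets`, `xb`, and `yb` are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.rand(10, 2)
targets = torch.rand(10, 1)
loader = DataLoader(TensorDataset(features, targets), batch_size=4, shuffle=True)

batch_shapes = []
for xb, yb in loader:  # each batch unpacks into inputs and targets
    batch_shapes.append((tuple(xb.shape), tuple(yb.shape)))
    print(xb.shape, yb.shape)
```

The first two batches have 4 rows each and the last has the remaining 2, matching the output above. Shuffling changes which rows land in which batch, not the batch shapes.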
