In [None]:
%matplotlib inline


Datasets & DataLoaders
======================




PyTorch provides two data primitives: ``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``
that allow you to use pre-loaded datasets as well as your own data.
``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around
the ``Dataset`` to enable easy access to the samples.
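
As a minimal sketch of how the two primitives fit together, the snippet below builds a small in-memory dataset and wraps it in a loader. The toy tensors are invented for illustration; ``TensorDataset`` is used only as a convenient ready-made ``Dataset``.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples with four features each, plus one integer label per sample.
features = torch.arange(40, dtype=torch.float32).reshape(10, 4)
labels = torch.arange(10) % 2

# TensorDataset is a ready-made Dataset that pairs up tensors by index.
dataset = TensorDataset(features, labels)
sample, label = dataset[3]  # one (sample, label) pair

# DataLoader wraps the Dataset in an iterable that yields batches.
loader = DataLoader(dataset, batch_size=5)
batch_shapes = [tuple(x.shape) for x, _ in loader]

print(len(dataset), sample, label, batch_shapes)
```

The ``Dataset`` answers "what is sample ``i``?", while the ``DataLoader`` decides how samples are grouped and in what order they are served.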


Loading a Dataset
-------------------
PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass ``torch.utils.data.Dataset``. They are available in `torchvision.datasets`, `torchtext.datasets` and `torchaudio.datasets`.

Here is an example of how to load the Fashion-MNIST dataset from TorchVision.
Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples.
Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the `FashionMNIST Dataset <https://pytorch.org/vision/stable/datasets.html#fashion-mnist>`_ with the following parameters:
 - ``root`` is the path where the train/test data is stored.
 - ``train`` selects the training split (``True``) or the test split (``False``).
 - ``download=True`` downloads the data from the internet if it is not available at ``root``.
 - ``transform`` and ``target_transform`` specify the feature and label transformations. With ``transform=None``, as used here, each sample is returned as an unmodified PIL image.
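
To make the last two parameters more concrete, here is a rough sketch of what a feature transform and a label transform might look like when applied by hand. The helper names ``scale_to_unit`` and ``one_hot`` are hypothetical, and a plain tensor stands in for the image so that the example needs no download. With torchvision datasets the transform receives a PIL image, so real pipelines usually start with a conversion such as ``ToTensor``.

```python
import torch

def scale_to_unit(image):
    """Hypothetical feature transform: map uint8 pixels to floats in [0, 1]."""
    return image.to(torch.float32) / 255.0

def one_hot(label, num_classes=10):
    """Hypothetical label transform: integer class index -> one-hot vector."""
    encoded = torch.nn.functional.one_hot(torch.tensor(label), num_classes)
    return encoded.to(torch.float32)

# Stand-ins for one raw sample: a 28x28 all-white image and class index 3.
raw_image = torch.full((28, 28), 255, dtype=torch.uint8)
raw_label = 3

image = scale_to_unit(raw_image)   # what `transform` would produce
target = one_hot(raw_label)        # what `target_transform` would produce

print(image.dtype, image.shape, target)
```

Callables like these would be passed as ``transform=scale_to_unit`` and ``target_transform=one_hot``; the dataset then applies them each time a sample is fetched.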



In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=None,
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=None,
)

In [None]:
training_data

In [None]:
test_data

Iterating and Visualizing the Dataset
-------------------------------------

We can index a ``Dataset`` manually like a list: ``training_data[index]`` returns an ``(image, label)`` tuple.
We use ``matplotlib`` to visualize a random selection of samples from our training data.
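
The random-index pattern can be tried without any image data. The sketch below assumes a toy ``TensorDataset`` in which sample ``i`` simply holds the value ``i``, so it is easy to see that each randomly drawn index retrieves the matching sample and label.

```python
import torch
from torch.utils.data import TensorDataset

# Toy dataset: sample i holds the value i, and its label is i modulo 10.
toy = TensorDataset(torch.arange(100).reshape(100, 1), torch.arange(100) % 10)

picked = []
for _ in range(9):
    # randint draws from [0, len(toy)); .item() turns the 1-element tensor into an int.
    idx = torch.randint(len(toy), size=(1,)).item()
    x, y = toy[idx]  # list-style indexing returns one (sample, label) pair
    picked.append((idx, x.item(), y.item()))

print(picked)
```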



In [None]:
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img, cmap="gray")
plt.show()

..
 .. figure:: /_static/img/basics/fashion_mnist.png
   :alt: fashion_mnist



--------------




Creating a Custom Dataset for your files
---------------------------------------------------

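A custom ``Dataset`` class must implement three methods: `__init__`, `__len__`, and `__getitem__`.
`__init__` runs once when the dataset object is created, `__len__` returns the number of samples,
and `__getitem__` loads and returns the sample at a given index. In the example below, the label of
each image is derived from the prefix of its filename (``bird``, ``cat`` or ``dog``).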
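
Before reading files from disk, it may help to see the three methods in the smallest possible setting. The class below is a hypothetical toy (``SquaresDataset`` is not part of PyTorch) that generates its samples on the fly.

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the number i, its label is i squared."""

    def __init__(self, n, transform=None):
        # Runs once: store whatever is needed to find samples later.
        self.n = n
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset.
        return self.n

    def __getitem__(self, idx):
        # Return the (sample, label) pair at position idx.
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        x = torch.tensor(float(idx))
        if self.transform:
            x = self.transform(x)
        return x, idx * idx

ds = SquaresDataset(5, transform=lambda x: x * 2)
x3, y3 = ds[3]
all_labels = [label for _, label in ds]

print(len(ds), x3, y3, all_labels)
```

Raising ``IndexError`` for out-of-range indices is what lets plain Python iteration over the dataset stop at the end.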


In [None]:
import os
from PIL import Image

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None, target_transform=None):  
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Keep only files whose name prefix maps to a known class, so that
        # filenames and labels stay aligned index by index.
        class_prefixes = {"bird": 0, "cat": 1, "dog": 2}
        filenames = []
        labels = []
        for filename in sorted(os.listdir(img_dir)):  # sorted for a stable order
            for prefix, class_idx in class_prefixes.items():
                if filename.startswith(prefix):
                    filenames.append(filename)
                    labels.append(class_idx)
                    break
        self.filenames = filenames
        self.labels = labels

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.filenames[idx])
        image = Image.open(img_path)
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
from torchvision.transforms import Resize, ToTensor, Compose

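# ToTensor converts each PIL image to a float tensor of shape (C, H, W) with values in [0, 1].
# If the images differ in size, use the Resize variant instead: the default DataLoader
# collation can only stack tensors of identical shape into a batch.
# pet_dataset = CustomImageDataset("./data/pets", transform=Compose([Resize((32, 32)), ToTensor()]))
pet_dataset = CustomImageDataset("./data/pets", transform=ToTensor())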

In [None]:
pet_dataset[1]

In [None]:
from torchvision.transforms import ToPILImage

raw_image, label = pet_dataset[1]
figure = plt.figure(figsize=(8, 8))
print("type of the image: ", type(raw_image))
if isinstance(raw_image, torch.Tensor):
    plt.imshow(ToPILImage()(raw_image))
else:
    plt.imshow(raw_image)
    
print(f"label: {label}")

In [None]:
raw_image.shape

Preparing your data for training with DataLoaders
-------------------------------------------------
The ``Dataset`` retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to
pass samples in "minibatches", reshuffle the data at every epoch to reduce model overfitting, and use Python's ``multiprocessing`` to
speed up data retrieval.

``DataLoader`` is an iterable that abstracts this complexity for us in an easy API.
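
The sketch below illustrates minibatching and per-epoch reshuffling on a toy dataset. The tensors and the seeded ``generator`` are assumptions made for the example; ``num_workers=0`` (the default) loads data in the main process, while a positive value would start that many worker subprocesses.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples whose label equals their position in the dataset.
dataset = TensorDataset(torch.arange(10).reshape(10, 1).float(), torch.arange(10))

# A seeded generator makes the shuffling reproducible from run to run.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    num_workers=0, generator=generator)

epochs = []
for epoch in range(2):
    seen = []
    for features, labels in loader:  # each iteration yields one minibatch
        seen.extend(labels.tolist())
    epochs.append(seen)
    print(f"epoch {epoch}: {seen}")
```

Every epoch visits each sample exactly once; with ``shuffle=True`` the order is redrawn each time the loader is iterated.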



In [None]:
from torch.utils.data import DataLoader

pet_dataloader = DataLoader(pet_dataset, batch_size=12, shuffle=True)

Iterate through the DataLoader
------------------------------
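
One detail worth knowing when iterating: if the dataset size is not a multiple of ``batch_size``, the final batch is smaller unless ``drop_last=True`` is set. A minimal sketch on a toy dataset of ten samples (the tensors are invented for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(10, 3), torch.arange(10))

# Default behaviour: the last batch holds the leftover samples.
loader = DataLoader(dataset, batch_size=4)
sizes = [len(labels) for _, labels in loader]

# drop_last=True discards the incomplete final batch.
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)
sizes_drop = [len(labels) for _, labels in loader_drop]

# next(iter(...)) fetches just the first batch.
first_features, first_labels = next(iter(loader))

print(sizes, sizes_drop, len(loader), len(loader_drop))
print(first_features.shape, first_labels)
```

``len(loader)`` reports the number of batches rather than the number of samples.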

In [None]:
for feature, label in pet_dataloader:
    print(f"Feature batch shape: {feature.size()}")
    print(label)

In [None]:
# Display image and label.
train_features, train_labels = next(iter(pet_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = ToPILImage()(train_features[0])
label = train_labels[0]
plt.imshow(img)
plt.show()
print(f"Label: {label}")

In [None]:
train_features