# Work with PyTorch Datasets
---

In [None]:
# Reload imported modules automatically before each cell executes, so edits
# to local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

---

## Creating a custom dataset
PyTorch provides an easy mechanism for working with datasets. You just need to inherit from `torch.utils.data.Dataset` and override two methods:
 - `__len__`, so that `len(dataset)` returns the size of the dataset.
 - `__getitem__`, to support indexing so that `dataset[i]` can be used to get the i-th sample.

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset

In [None]:
class RandomVectorDataset(Dataset):
    """Dataset of vectors drawn from a standard normal distribution.

    The whole array is generated once in ``__init__``. The first entry of
    ``random_shape`` is the number of samples; the remaining entries are the
    shape of a single sample.
    """

    def __init__(self, random_shape, transform=None, seed=None):
        """
        Args:
            random_shape (list): Shape of random data in dataset. Must contain
                at least one dimension (the number of samples).
            transform (callable, optional): Optional transformation to be
                applied on a sample.
            seed (int, optional): Seed for a private random generator. When
                omitted, NumPy's global generator is used, so an earlier
                ``np.random.seed(...)`` call is still honoured.

        Raises:
            ValueError: If ``random_shape`` is empty.
        """
        if len(random_shape) == 0:
            # Without a leading dimension there is nothing to index or count;
            # fail here instead of later inside __len__.
            raise ValueError("random_shape must have at least one dimension")

        if seed is None:
            self.raw_data = np.random.randn(*random_shape)
        else:
            # A dedicated generator keeps the global NumPy state untouched.
            self.raw_data = np.random.RandomState(seed).randn(*random_shape)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return self.raw_data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with the key ``'random_vector'``."""
        sample = {'random_vector': self.raw_data[idx]}
        # Compare against None explicitly: a callable that happens to be falsy
        # (e.g. an empty container-like transform) must still be applied.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

In [None]:
# Ten samples, each a 5-element random vector; no transform attached yet.
random_vector_dataset = RandomVectorDataset(
    random_shape=[10, 5],
)

In [None]:
# Number of samples, i.e. the first entry of `random_shape`.
len(random_vector_dataset)

In [None]:
# Indexing returns a dict that holds the sample under the 'random_vector' key.
random_vector_dataset[5]

---

## Apply transformations to dataset
We can create objects with a `__call__` method that apply transformations to data from the dataset. To chain several transformations together, we can use `torchvision.transforms.Compose`. PyTorch also provides many ready-made image transformations in `torchvision.transforms`.

In [None]:
from torchvision.transforms import Compose

In [None]:
class Add2:
    """Transform that adds 2 to the ``'random_vector'`` entry of a sample."""

    def __call__(self, sample):
        """Return a new dict containing only the shifted ``'random_vector'``."""
        shifted = sample['random_vector'] + 2
        return {'random_vector': shifted}

class ToTorchTensor:
    """Transform that converts the NumPy ``'random_vector'`` into a torch tensor."""

    def __call__(self, sample):
        """Return a new dict whose ``'random_vector'`` is a ``torch.Tensor``."""
        vector = sample['random_vector']
        return {'random_vector': torch.from_numpy(vector)}

In [None]:
# Pipeline of the two transforms, listed in the order shift-by-2, then
# NumPy -> torch conversion.
transformations = Compose(
    [
        Add2(),
        ToTorchTensor(),
    ]
)

In [None]:
# Rebuild the dataset with the transform pipeline attached; constructing it
# again draws a fresh random array.
random_vector_dataset = RandomVectorDataset(
    transform=transformations,
    random_shape=[10, 5],
)
# The indexed sample now comes back already passed through the pipeline.
random_vector_dataset[5]

---

## Sampling batches from dataset
PyTorch provides the iterator `torch.utils.data.DataLoader` for working with datasets based on the `torch.utils.data.Dataset` class.   
It enables
 - batching the data
 - shuffling the data  
 - loading the data in parallel using multiprocessing workers


In [None]:
from torch.utils.data import DataLoader

In [None]:
# A batch size of 10 matches the dataset size, so one batch holds every sample.
# NOTE(review): num_workers=1 loads data in a separate worker process; on
# platforms that spawn workers, a dataset class defined inside the notebook
# may not be importable there -- confirm on the target platform.
data_loader = DataLoader(
    random_vector_dataset,
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Pull a single batch from a fresh iterator over the loader.
next(iter(data_loader))

---

## Explore the prepared Fashion MNIST dataset

In [None]:
import pandas as pd
from torchvision import datasets, transforms

In [None]:
# NOTE: this rebinds `transformations`, which earlier held the random-vector
# pipeline; from here on it is the image pipeline with a single ToTensor step.
transformations = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

### Training data

In [None]:
# Training split, stored under ./dataset_fashion_mnist/.
# download=True presumably fetches the files when they are missing -- confirm
# against the torchvision documentation.
train_dataset = datasets.FashionMNIST(
    './dataset_fashion_mnist/',
    train=True,
    download=True,
    transform=transformations,
)
# Shuffled batches of 100 samples.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=100,
)

In [None]:
# Class names; the plotting code in this notebook looks names up here by class id.
train_dataset.classes

In [None]:
# Number of samples in the training split.
len(train_dataset)

In [None]:
# First training sample; presumably an (image, label) pair, since batches from
# the loader are unpacked that way elsewhere in this notebook.
train_dataset[0]

In [None]:
# Pull one batch from a fresh iterator; the loader shuffles, so the content
# differs between runs.
next(iter(train_loader))

In [None]:
import matplotlib.pyplot as plt

# Take one shuffled batch and move it to NumPy for plotting.
img_batch, label_batch = next(iter(train_loader))
# squeeze(dim=1) drops axis 1 when it has size 1 (presumably the channel axis).
img_batch = img_batch.squeeze(dim=1).numpy()
label_batch = label_batch.numpy()

# Derive the grid from the actual batch instead of assuming exactly 100 images.
n_cols = 10
n_images = len(img_batch)
n_rows = -(-n_images // n_cols)  # ceiling division

# The figure is created first. The previous plt.subplots_adjust(...) call ran
# before any figure existed, so it only created and adjusted a stray empty
# figure and never affected this grid; it is dropped to keep the rendered
# layout unchanged. 2.5 inches per row gives the former (20, 25) for 10 rows.
fig = plt.figure(figsize=(20, 2.5 * n_rows))

for img_id in range(n_images):
    ax = fig.add_subplot(n_rows, n_cols, img_id + 1)
    img = img_batch[img_id]

    # Title each image with the class name looked up by its class id.
    class_id = label_batch[img_id]
    class_name = train_dataset.classes[class_id]
    ax.imshow(img, cmap='gray')
    ax.set_title(class_name)
    ax.set_axis_off()

### Validation data

In [None]:
from image_processing_workshop.visual import plot_image

In [None]:
# Validation split (train=False), stored in the same root directory as the
# training split.
valid_dataset = datasets.FashionMNIST(
    './dataset_fashion_mnist/',
    train=False,
    download=True,
    transform=transformations,
)
# Use the already imported DataLoader name, consistent with the training
# loader. No shuffling, so the iteration order is stable.
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

In [None]:
# Number of samples in the validation split.
len(valid_dataset)

In [None]:
# Pass element 0 of validation sample 21 to the workshop's plot_image helper
# (presumably the image part of the sample -- confirm against the helper).
plot_image(valid_dataset[21][0], figsize=(5, 5))

In [None]:
# Convert the targets to a NumPy array before handing them to pandas, so that
# 'class_ids' is guaranteed to be a plain numeric column regardless of how
# pandas treats the original container (presumably a torch tensor -- confirm).
labels = np.asarray(valid_dataset.targets)
# Look up the human-readable name for every class id.
class_names = [valid_dataset.classes[class_id] for class_id in labels]
df = pd.DataFrame({'class_names': class_names, 'class_ids': labels})
df.head(10)

In [None]:
n_classes = len(valid_dataset.classes)

fig, ax = plt.subplots(figsize=(10, 10))
# One bin per class id, centred on the integer id, instead of relying on the
# default 10 bins and hand-tuned tick offsets. Assumes the ids run from 0 to
# n_classes - 1, which matches how they index `valid_dataset.classes`.
df.loc[:, 'class_ids'].plot(
    kind='hist',
    ax=ax,
    bins=np.arange(n_classes + 1) - 0.5,
    rwidth=0.5,
)
ax.set_title('Validation set: samples per class')
# Ticks sit exactly on the class ids and are labelled with the class names.
ax_ticks = ax.xaxis.set_ticks(np.arange(n_classes))
ax_labels = ax.xaxis.set_ticklabels(list(valid_dataset.classes), rotation=70)