# WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS
## 1. Goal: To develop a DataLoader that provides 1) stacked (batched) images, 2) the paired segmentation masks, and 3) classification labels
```
Author: Joohyung Lee
References: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
            https://jdhao.github.io/2017/10/23/pytorch-load-data-and-make-batch/
            https://pytorch.org/docs/stable/data.html
```

## 2. Architecture in summary
```
Two classes are needed
```
* **Class 1:** dataset (torch.utils.data.Dataset or torch.utils.data.IterableDataset)
 * Map-style `Dataset`: gives access to each sample by its index; `IterableDataset`: yields samples by iteration instead
 * Each sample can be returned as a tuple, list, or dictionary, e.g., {'image': ..., 'mask': ..., 'category': ...}  
 * Built-in example: torchvision.datasets.ImageFolder
 * Augmentations are chained by passing a list of transforms to torchvision.transforms.Compose
 
&nbsp;
* **Class 2:** torch.utils.data.DataLoader
 * Groups samples into mini-batches
 * An iterable that wraps a torch.utils.data.Dataset object and adds useful functionality, e.g., batching, shuffling, and multi-process loading
&nbsp;

## 3. Dataset
### 3-1. torch.utils.data.Dataset
```
Subclass the built-in Dataset and override these methods:
```
* **`__init__:`** Constructor
 * Read the index/annotation files (e.g., *.csv, *.txt) but do NOT load the images themselves; loading is deferred to `__getitem__` so memory stays small
<br/><br/>

* **`__len__:`** `len(custom_dataset)` returns the number of samples
<br/><br/>

* **`__getitem__:`** `custom_dataset[i]` returns the i-th sample
 * Read the requested image from disk here
<br/><br/>

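* Example of overriding (the face-landmarks dataset from the PyTorch data loading tutorial): each CSV row holds an image file name followed by its landmark (x, y) coordinates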

In [None]:
import os
import torch
import pandas as pd
from skimage import io
import numpy as np
from torch.utils.data import Dataset

# Inherit the built-in Dataset class
class CustomDataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        # Read only the annotation table here; images are loaded lazily in __getitem__
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # One sample per CSV row
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        # Accept tensor indices as well as plain ints
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Column 0 holds the image file name relative to root_dir
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Remaining columns are flattened (x, y) pairs, reshaped to (N, 2)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        # Apply the optional transform pipeline to the whole sample dict
        if self.transform:
            sample = self.transform(sample)

        return sample
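The goal in Section 1 is a sample holding an image, its paired segmentation mask, and a classification label. A minimal sketch, assuming hypothetical in-memory arrays in place of files on disk (the names `SegClsDataset`, `images`, `masks`, and `categories` are illustrative), shows that returning a dictionary from `__getitem__` is enough for the default `DataLoader` collation to stack each field into a batch:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SegClsDataset(Dataset):
    """Hypothetical dataset returning an image, its paired mask, and a class label."""

    def __init__(self, images, masks, categories):
        # images: (N, C, H, W) floats; masks: (N, H, W) ints; categories: (N,) ints
        assert len(images) == len(masks) == len(categories)
        self.images, self.masks, self.categories = images, masks, categories

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # A dict keeps the fields named; default collation batches each key separately
        return {
            'image': torch.from_numpy(self.images[idx]).float(),
            'mask': torch.from_numpy(self.masks[idx]).long(),
            'category': int(self.categories[idx]),
        }

# Synthetic stand-in data: 6 RGB images of size 8x8, binary masks, 3 classes
rng = np.random.default_rng(0)
images = rng.random((6, 3, 8, 8), dtype=np.float32)
masks = rng.integers(0, 2, size=(6, 8, 8))
categories = rng.integers(0, 3, size=6)

loader = DataLoader(SegClsDataset(images, masks, categories), batch_size=4)
batch = next(iter(loader))
print(batch['image'].shape)     # torch.Size([4, 3, 8, 8])
print(batch['mask'].shape)      # torch.Size([4, 8, 8])
print(batch['category'].shape)  # torch.Size([4])
```

With 6 samples and `batch_size=4`, the loader yields two batches; the last one holds the remaining 2 samples unless `drop_last=True` is set.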

* **torchvision.transforms.Compose:** write transforms as callable classes, e.g., for augmentation or for resizing images to a common size so they can be stacked into a mini-batch
 * Classes rather than functions, so that parameters (e.g., the output size) are set once in `__init__` instead of being passed on every call
 * Its argument `transforms` is a list of transform objects, applied in list order (left to right)
 * Example:

In [1]:
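import numpy as np
from skimage import transform
from torchvision import transforms

class Rescale(object):
    def __init__(self, output_size):
        self.output_size = output_size
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):  # match the shorter side, keep aspect ratio
            s = self.output_size
            new_h, new_w = (s * h // w, s) if h > w else (s, s * w // h)
        else:  # tuple: exact (h, w)
            new_h, new_w = self.output_size
        image = transform.resize(image, (new_h, new_w))
        # Landmarks are (x, y) pairs: x scales with width, y with height
        return {'image': image, 'landmarks': landmarks * [new_w / w, new_h / h]}

class RandomCrop(object):  # crops a random square patch of side output_size
    def __init__(self, output_size):
        self.output_size = output_size
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        top = np.random.randint(0, h - self.output_size + 1)
        left = np.random.randint(0, w - self.output_size + 1)
        image = image[top:top + self.output_size, left:left + self.output_size]
        return {'image': image, 'landmarks': landmarks - [left, top]}

# Rescale followed by RandomCrop (left to right)
composed = transforms.Compose([Rescale(256), RandomCrop(224)])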

#### 3-1-1. torchvision.datasets.ImageFolder
* Retrieves images assuming the following hierarchy:
    * root/category/xxx.png (the subfolder name becomes the class label)
    * extensions can be mixed (png, jpg, jpeg, etc.)
    * Example:

### 3-2. torch.utils.data.IterableDataset
```
Represents an iterable over data samples (useful when data come from a stream and random access is impractical); subclass it and override these methods:
```
* **`__init__:`** Constructor
<br/><br/>

* **`__iter__:`**
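 * Returns an iterator over the samples in the dataset; with num_workers > 0 each worker process runs its own copy, so the data must be split across workers (e.g., via torch.utils.data.get_worker_info()) to avoid duplicate samples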
<br/><br/>
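A minimal sketch, following the worker-splitting pattern from the PyTorch `IterableDataset` documentation; the class `RangeStream` and its integer range are hypothetical stand-ins for a real data stream:

```python
import math
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class RangeStream(IterableDataset):
    """Hypothetical stream that yields the integers in [start, end)."""

    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: this process yields the whole range
            lo, hi = self.start, self.end
        else:
            # Multi-process loading: each worker yields a disjoint slice
            per_worker = math.ceil((self.end - self.start) / info.num_workers)
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))

loader = DataLoader(RangeStream(3, 7), batch_size=2)
batches = [b.tolist() for b in loader]
print(batches)  # [[3, 4], [5, 6]]
```

There is no `__len__` or `__getitem__`; the DataLoader simply pulls samples from `__iter__` and batches them. Shuffling via `shuffle=True` is not supported for iterable datasets, so any shuffling has to happen inside the stream itself.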


## 4. torch.utils.data.DataLoader
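* An iterable over a Dataset that provides:
    * Batching (`batch_size`, `drop_last`)
    * Shuffling (`shuffle=True`, or a custom `sampler`)
    * Collation: merging a list of samples into batched tensors (`collate_fn`)
    * Parallel loading in worker processes via `multiprocessing` (`num_workers`)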

In [None]:
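# DataLoader signature with its default arguments (as in earlier releases; newer releases
# add further keyword arguments, e.g., generator, prefetch_factor, persistent_workers)
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                            num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0,
                            worker_init_fn=None, multiprocessing_context=None)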
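To see what `batch_size`, `shuffle`, and `drop_last` do to tensor shapes, here is a minimal sketch on a hypothetical toy `TensorDataset` (10 samples of 3 features each); it needs no files on disk:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset: 10 samples, each a 3-dim feature vector with an integer label
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# drop_last=True keeps only full batches: 10 // 4 = 2 batches (the last 2 samples are dropped)
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=0)
shapes = [(x.shape, y.shape) for x, y in loader]
print(len(shapes))  # 2
print(shapes[0])    # (torch.Size([4, 3]), torch.Size([4]))
```

The default collation stacks each field along a new first dimension, so every batch is a `(features, labels)` pair whose first dimension is the batch size.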

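### Samplers (torch.utils.data.Sampler)
* A sampler decides the order in which dataset indices are drawn. `shuffle=True` is shorthand for a RandomSampler; passing a custom `sampler` together with `shuffle=True` raises an error.
    * SequentialSampler: indices in order (used when `shuffle=False`)
    * RandomSampler: indices in random order (used when `shuffle=True`)
    * SubsetRandomSampler: random order over a given subset of indices (e.g., a train/validation split)
    * WeightedRandomSampler: draws indices with probability proportional to per-sample weights (e.g., to rebalance classes)
    * BatchSampler: wraps another sampler and yields lists of indices, one list per mini-batch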
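Because this notebook's dataset carries classification labels, class imbalance is a likely concern. A minimal sketch of rebalancing with WeightedRandomSampler, assuming hypothetical labels (90 samples of class 0, 10 of class 1) and inverse-frequency weights as one common weighting choice:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Hypothetical imbalanced labels: 90 samples of class 0, 10 of class 1
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
dataset = TensorDataset(torch.randn(100, 2), labels)

# Weight each sample by the inverse frequency of its class
class_counts = torch.bincount(labels)                 # tensor([90, 10])
sample_weights = 1.0 / class_counts[labels].float()

# replacement=True lets minority samples be drawn more than once per epoch
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)  # no shuffle=True with a sampler

torch.manual_seed(0)
drawn = torch.cat([y for _, y in loader])
print(drawn.float().mean())  # roughly 0.5: both classes are now drawn about equally often
```

Both classes end up with the same total weight (90 x 1/90 = 10 x 1/10 = 1), which is why each is drawn with probability of about one half.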

### collate_fn
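* A function that merges a list of samples into one mini-batch of tensors (first dimension is the batch size)
* The default collate_fn stacks tensors, which requires equal shapes; a custom one is needed when samples differ in size, e.g., to zero-pad variable-length sequences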

In [None]:
import torch

def collate_fn(data):
    """
       data: a list of tuples (example, label, length), where 'example' is a
             2-D tensor of shape (length, n_features) and label/length are scalars
       returns: zero-padded features (batch, max_len, n_features), labels, lengths
    """
    _, labels, lengths = zip(*data)
    max_len = max(lengths)
    n_ftrs = data[0][0].size(1)
    # Start from zeros so rows beyond each example's length act as padding
    features = torch.zeros((len(data), max_len, n_ftrs))
    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)

    for i in range(len(data)):
        j = data[i][0].size(0)
        # Copy the j real time steps; the remaining rows stay zero
        features[i, :j] = data[i][0]

    return features.float(), labels.long(), lengths.long()
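PyTorch also ships `torch.nn.utils.rnn.pad_sequence`, which performs the same zero-padding. A minimal sketch using it as a swapped-in alternative, assuming a hypothetical in-memory list of `(example, label, length)` tuples; the function name `pad_collate` is illustrative. It also shows how a custom function is handed to `DataLoader` through `collate_fn=`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(data):
    """Hypothetical collate: zero-pad (length, n_features) examples to the longest in the batch."""
    examples, labels, lengths = zip(*data)
    # batch_first=True gives shape (batch, max_len, n_features)
    features = pad_sequence(list(examples), batch_first=True)
    return features.float(), torch.tensor(labels).long(), torch.tensor(lengths).long()

# Hypothetical variable-length sequences with 2 features per time step
data = [(torch.ones(n, 2), label, n) for n, label in [(3, 0), (1, 1), (2, 0)]]

# A plain list works as a map-style dataset (it has __len__ and __getitem__)
loader = DataLoader(data, batch_size=3, collate_fn=pad_collate)
features, labels, lengths = next(iter(loader))
print(features.shape)  # torch.Size([3, 3, 2])
print(lengths)         # tensor([3, 1, 2])
```

Returning `lengths` alongside the padded batch lets a downstream model (e.g., via `pack_padded_sequence`) ignore the padded rows.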

### Example: ImageFolder with DataLoader

In [None]:
import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
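        # RandomSizedCrop is deprecated; RandomResizedCrop is its replacement
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel (R, G, B) ImageNet mean and standard deviation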
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

## Next
* https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
* https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html