<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_0824pytorch_dataset_data_loader_sampler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using PyTorch `Dataset`, `DataLoader`, `Sampler`, and `Transforms`

These basic concepts make it easy to work with large datasets.

## Importing the required libraries, helper functions, and utilities

In [None]:
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
from IPython.core.debugger import set_trace

# `Dataset`

It is easy to create your own `Dataset`, but PyTorch (through torchvision) also comes with a number of
[built-in datasets](https://pytorch.org/docs/stable/torchvision/datasets.html), for example:

- MNIST
- Fashion-MNIST
- KMNIST
- EMNIST
- FakeData
- COCO
  - Captions
  - Detection
- LSUN
- ImageFolder
- DatasetFolder
- Imagenet-12
- CIFAR
- STL10
- SVHN
- PhotoTour
- SBU
- Flickr
- VOC
- Cityscapes

A `Dataset` tells you how many samples it contains (by implementing `__len__`) and returns the sample at a given index (by implementing `__getitem__`).
This is a simple and convenient abstraction for working with data.

In [None]:
from torch.utils.data import Dataset

```python
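# Simplified sketch of torch.utils.data.Dataset (the real class has more features)
from torch.utils.data import ConcatDataset


class Dataset(object):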
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
```
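As a concrete illustration of this interface, here is a tiny map-style dataset. `SquaresDataset` is a hypothetical example made up for this notebook; it returns `(x, x**2)` pairs as float tensors. Raising `IndexError` for out-of-range indices is a sensible convention, because Python's fallback iteration protocol (used e.g. by `list(dataset)`) calls `__getitem__` with 0, 1, 2, ... until it sees an `IndexError`.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2) as float tensors."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # Number of samples in the dataset
        return self.n

    def __getitem__(self, index: int):
        # Return the sample at the given index as an (input, target) pair
        if not 0 <= index < self.n:
            raise IndexError(index)
        x = torch.tensor(float(index))
        return x, x ** 2


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor(3.), tensor(9.))

# __add__ inherited from Dataset concatenates datasets via ConcatDataset
combined = ds + SquaresDataset(2)
print(len(combined))  # 7
```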

The `ImageFolder` dataset is quite useful and follows the usual convention for folder layouts, with one sub-directory per class:

```
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
```
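`ImageFolder` treats each sub-directory of `root` as one class and numbers the classes in sorted order of their directory names, so in the layout above `cat` becomes label 0 and `dog` label 1. The stdlib-only sketch below re-implements just that label assignment for illustration; `find_classes_sketch` is a hypothetical helper written here, not an import from torchvision, and the files it creates are empty placeholders.

```python
import os
import tempfile


def find_classes_sketch(root: str):
    """Mimic how ImageFolder derives labels: every sub-directory of ``root``
    is a class, classes are sorted by name and numbered from 0."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


# Recreate the example layout with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    for rel in ["dog/xxx.png", "dog/xxy.png", "cat/123.png", "cat/nsdf3.png"]:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()
    classes, class_to_idx = find_classes_sketch(root)

print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
```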

## Example

In [None]:
# %load my_datasets.py
import os
import tarfile

from torchvision.datasets.folder import ImageFolder, default_loader
from torchvision.datasets.utils import download_url, check_integrity

################################################################################
# PyTorch
class DogsCatsDataset(ImageFolder):
    """
    The 'Dogs and Cats' dataset from kaggle.

    https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/

    Args:
        root: the directory where the dataset is (or will be) stored
        suffix: path of the split to load, relative to ``root/dogscats``,
            e.g. ``"train"``, ``"valid"`` or ``"sample/train"``.
            See the folder structure below.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version.
            E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.
        loader: A function to load an image given its path.
        download: if ``True``, download the data.


    The folder structure of the dataset is as follows::

        └── dogscats
            ├── sample
            │   ├── train
            │   │   ├── cats
            │   │   └── dogs
            │   └── valid
            │       ├── cats
            │       └── dogs
            ├── train
            │   ├── cats
            │   └── dogs
            └── valid
                ├── cats
                └── dogs

    """

    url = 'https://files.fast.ai/data/examples/dogscats.tgz'
    filename = "dogscats.tgz"
    checksum = 'ad2c4e646241a6dc06aedb4b59ef7687'

    def __init__(
        self,
        root: str,
        suffix: str,
        transform=None,
        target_transform=None,
        loader=default_loader,
        download=False,
    ):
        self.root = os.path.expanduser(root)

        if download:
            self._download()
            self._extract()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                "You can use download=True to download it"
            )

        path = os.path.join(self.root, "dogscats", suffix)
        print(f"Loading data from {path}.")
        assert os.path.isdir(path), f"'{suffix}' is not valid."

        super().__init__(path, transform, target_transform, loader)

    def _download(self):
        if self._check_integrity():
            print("Dataset already downloaded and verified.")
            return

        root = self.root
        print("Downloading dataset... (this might take a while)")
        download_url(self.url, root, self.filename, self.checksum)

    def _extract(self):
        path_to_tgz = os.path.join(self.root, self.filename)

        # Extract the .tgz archive into root (creates root/dogscats/...)
        with tarfile.open(path_to_tgz) as file:
            file.extractall(self.root)

    def _check_integrity(self):
        # Compare the MD5 checksum of the downloaded archive
        path_to_tgz = os.path.join(self.root, self.filename)
        return check_integrity(path_to_tgz, self.checksum)


In [None]:
train_ds = DogsCatsDataset("../data/raw", "sample/train", download=True)

In [None]:
!apt install tree

In [None]:
!tree -d ../data/raw/dogscats/

In [None]:
train_ds

In [None]:
# the __len__ method
len(train_ds)

In [None]:
# the __getitem__ method returns an (image, class_index) tuple
img, label = train_ds[0]
print(type(img), label)  # <class 'PIL.Image.Image'> 0
print(img.size)  # (width, height), e.g. (499, 375)

In [None]:
train_ds[15][0]

In [None]:
train_ds[14][1]

1

Optionally, some datasets offer convenience functions and attributes such as the ones below.
These are not enforced by the `Dataset` interface, so don't rely on them!

In [None]:
train_ds.classes

In [None]:
train_ds.class_to_idx

In [None]:
train_ds.imgs

In [None]:
import random

In [None]:
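# Sample 4 random indices instead of loading every image with list(train_ds)
for idx in random.sample(range(len(train_ds)), 4):
    img, label_id = train_ds[idx]
    print(label_id, train_ds.classes[label_id])
    display(img)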

# `torchvision.transforms`

For common image transformations that can be composed and chained, see the [torchvision transforms documentation](https://pytorch.org/vision/stable/transforms.html).
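Composition itself is a simple idea: `transforms.Compose` calls each transform in turn and feeds its output into the next one. The sketch below re-implements only that idea in plain Python; it is not torchvision's class (which also provides e.g. a readable `repr`), and the lambda "transforms" on numbers are hypothetical stand-ins for image transforms.

```python
class Compose:
    """Minimal sketch of transform composition: apply each callable in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        # The output of one transform is the input of the next
        for t in self.transforms:
            x = t(x)
        return x


# Toy stand-in transforms on plain numbers, just to show the chaining order
pipeline = Compose([
    lambda x: x + 1,
    lambda x: x * 10,
])
print(pipeline(2))  # (2 + 1) * 10 = 30
```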

In [None]:
from torchvision import transforms

In [None]:
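_image_size = 224
# Per-channel mean/std of ImageNet, commonly used with pretrained models
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]


trans = transforms.Compose([
    # RandomCrop fails on images smaller than _image_size;
    # pass pad_if_needed=True to pad such images instead
    transforms.RandomCrop(_image_size),
    # transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(.3, .3, .3),
    transforms.ToTensor(),              # PIL image -> float tensor C x H x W in [0, 1]
    transforms.Normalize(_mean, _std),  # per channel: (x - mean) / std
])

trans(train_ds[13][0])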
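For an 8-bit RGB image, `ToTensor` turns the H x W x C array of values 0 to 255 into a float tensor of shape C x H x W scaled to [0, 1], and `Normalize(mean, std)` then computes `(x - mean[c]) / std[c]` for each channel `c`. As a rough sketch that needs only NumPy and torch (no torchvision), the same arithmetic can be written out by hand on a tiny, made-up 2 x 2 image:

```python
import numpy as np
import torch

_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]

# Hypothetical 2x2 RGB image: uint8, H x W x C layout like a decoded PIL image
img = np.array([[[255, 0, 124], [0, 255, 0]],
                [[10, 20, 30], [200, 100, 50]]], dtype=np.uint8)

# ToTensor-like step: H x W x C uint8 in [0, 255] -> C x H x W float in [0, 1]
x = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

# Normalize-like step: subtract the channel mean, divide by the channel std
mean = torch.tensor(_mean).view(3, 1, 1)
std = torch.tensor(_std).view(3, 1, 1)
y = (x - mean) / std

print(x.shape)     # torch.Size([3, 2, 2])
print(y[0, 0, 0])  # (1.0 - 0.485) / 0.229, approximately 2.2489
```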

## `torchvision.transforms.functional`

<blockquote>

Functional transforms give you fine-grained control over the transformation pipeline.
As opposed to the transformations above, functional transforms don't contain a random number generator for their parameters.
That means you have to specify or generate all parameters yourself, but you can reuse the functional transform.
For example, you can apply a functional transform to multiple images like this:

https://pytorch.org/vision/stable/transforms.html
</blockquote>

```python
import random

import torchvision.transforms.functional as TF


def my_segmentation_transforms(image, segmentation):
    # Draw the random parameters once and apply them to both image and mask
    if random.random() > 0.5:
        angle = random.randint(-30, 30)
        image = TF.rotate(image, angle)
        segmentation = TF.rotate(segmentation, angle)
    # more transforms ...
    return image, segmentation
```
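The same pattern can be checked without torchvision. In the sketch below, `paired_random_hflip` is a hypothetical helper, and `torch.flip` along the width dimension stands in for `TF.hflip`: the random decision is drawn once and reused for the image and its mask, so they always stay aligned.

```python
import random

import torch


def paired_random_hflip(image: torch.Tensor, mask: torch.Tensor, p: float = 0.5):
    """Flip image and mask horizontally together with probability p.

    The coin is tossed once and the result is applied to both tensors.
    torch.flip on the last (width) dimension stands in for TF.hflip.
    """
    if random.random() < p:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    return image, mask


image = torch.arange(6.0).view(1, 2, 3)        # C x H x W
mask = torch.tensor([[[0, 1, 2], [3, 4, 5]]])  # 1 x H x W, same layout
out_img, out_mask = paired_random_hflip(image, mask, p=1.0)  # force the flip
print(out_mask.tolist())  # [[[2, 1, 0], [5, 4, 3]]]
```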

Ref:
- https://pytorch.org/vision/stable/transforms.html
- https://pytorch.org/vision/stable/transforms.html#functional-transforms
- https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
- https://github.com/mdbloice/Augmentor
- https://github.com/aleju/imgaug


Shout-out:
- High-performance image augmentation with pillow-simd [[github]](https://github.com/uploadcare/pillow-simd) [[benchmark]](http://python-pillow.org/pillow-perf/)
- Improving Deep Learning Performance with AutoAugment [[blog]](https://ai.googleblog.com/2018/06/improving-deep-learning-performance.html) [[paper]](https://arxiv.org/abs/1805.09501) [[pytorch implementation]](https://github.com/DeepVoltaire/AutoAugment)

# `DataLoader`

The `DataLoader` class offers batched loading of datasets, with multiprocessing and different sampling strategies.
Official documentation: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

The signature looks something like this (simplified; newer PyTorch versions accept additional arguments):

```python
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=default_collate,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None
)
```
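With `shuffle=False` and no sampler, `DataLoader` walks the dataset in index order and groups `batch_size` consecutive samples, so `len(loader)` is ceil(N / batch_size), or floor(N / batch_size) with `drop_last=True`. A small check, using the built-in `TensorDataset` wrapper on 10 hypothetical samples:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))  # 10 samples: 0, 1, ..., 9

dl = DataLoader(ds, batch_size=3)                       # last batch is partial
dl_drop = DataLoader(ds, batch_size=3, drop_last=True)  # partial batch is dropped

print(len(dl), [b[0].tolist() for b in dl])
# 4 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(len(dl_drop), [b[0].tolist() for b in dl_drop])
# 3 [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```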

In [None]:
from torch.utils.data import DataLoader

In [None]:
train_ds = DogsCatsDataset("../data/raw", "sample/train", transform=trans)
train_dl = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)
#train_dl = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

Loading data from ../data/raw/dogscats/sample/train.


In [None]:
train_iter = iter(train_dl)
X, y = next(train_iter)

In [None]:
print("X:", X.shape)
print("y:", y.shape)

X: torch.Size([2, 3, 224, 224])
y: torch.Size([2])


Note that we passed `trans`, which returns a `torch.Tensor`, not a PIL image.
The default collation in `DataLoader` expects tensors, NumPy arrays, numbers, dicts or lists; with raw PIL images it fails:

In [None]:
_train_ds = DogsCatsDataset("../data/raw", "sample/train", transform=None)
_train_dl = DataLoader(_train_ds, batch_size=2, shuffle=True)

try:
    for batch in _train_dl:
        pass
except TypeError as e:
    print("ERROR")
    print(e)

Loading data from ../data/raw/dogscats/sample/train.
ERROR
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>


## `collate_fn`

The `collate_fn` argument of `DataLoader` lets you customize how single datapoints are put together into a batch.
`collate_fn` is a simple callable that receives a list of datapoints (i.e. a list of whatever `dataset.__getitem__` returns) and returns the batch.
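To see what the default behaviour is, `default_collate` can be called directly on a list of `(tensor, int)` pairs, which is what an image dataset with a `ToTensor` transform returns. The two datapoints below are made up for illustration; the import path `torch.utils.data.dataloader` is used because it also works on older PyTorch versions.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two hypothetical datapoints: (image tensor C x H x W, class index)
batch = [(torch.zeros(3, 4, 4), 0), (torch.ones(3, 4, 4), 1)]

images, labels = default_collate(batch)
print(images.shape)  # torch.Size([2, 3, 4, 4]): tensors stacked along a new dim 0
print(labels)        # tensor([0, 1]): Python ints gathered into one tensor
```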

Example of a custom `collate_fn`
(taken from [here](https://discuss.pytorch.org/t/how-to-create-a-dataloader-with-variable-size-input/8278/3)):

In [None]:
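def my_collate_fn(list_of_x_y):
    # Keep the inputs as a plain list (they may differ in size and therefore
    # cannot be stacked), but batch the integer targets into one tensor
    data = [item[0] for item in list_of_x_y]
    target = [item[1] for item in list_of_x_y]
    target = torch.tensor(target, dtype=torch.long)
    return [data, target]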
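This kind of `collate_fn` is useful when the inputs have different sizes, because the default collation would try to stack them with `torch.stack` and fail. A minimal usage sketch, assuming a hypothetical `VariableLengthDataset` whose item `i` is a 1-D tensor of length `i + 1` (the collate function is repeated so the cell runs on its own):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableLengthDataset(Dataset):
    """Hypothetical toy dataset: item i is a 1-D tensor of length i + 1, labelled i % 2."""

    def __init__(self, n: int = 5):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.ones(idx + 1), idx % 2


def my_collate_fn(list_of_x_y):
    # Inputs stay a list of differently sized tensors; targets become a tensor
    data = [item[0] for item in list_of_x_y]
    target = torch.tensor([item[1] for item in list_of_x_y], dtype=torch.long)
    return [data, target]


ds = VariableLengthDataset()
loader = DataLoader(ds, batch_size=2, collate_fn=my_collate_fn)
for data, target in loader:
    print([len(x) for x in data], target.tolist())
# [1, 2] [0, 1]
# [3, 4] [0, 1]
# [5] [0]
```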

# `Sampler`

A `Sampler` defines **how** to sample from the dataset [[docs]](https://pytorch.org/docs/stable/data.html#torch.utils.data.sampler.Sampler).

Examples:
- `SequentialSampler`
- `RandomSampler`
- `SubsetRandomSampler`
- `WeightedRandomSampler`
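A common use of `WeightedRandomSampler` is rebalancing an imbalanced dataset. The sketch below assumes made-up labels (8 samples of class 0, 2 of class 1) and weights each sample by the inverse frequency of its class, so each class carries the same total weight and is drawn roughly equally often. Note that `shuffle` must stay `False` when a `sampler` is passed to `DataLoader`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # make the random draws reproducible

# Hypothetical imbalanced labels: 8 samples of class 0, 2 samples of class 1
labels = torch.tensor([0] * 8 + [1] * 2)
ds = TensorDataset(torch.arange(len(labels)), labels)

# Weight every sample by the inverse frequency of its class
class_counts = torch.bincount(labels)                # tensor([8, 2])
sample_weights = 1.0 / class_counts[labels].float()  # 1/8 for class 0, 1/2 for class 1

# Draw 1000 indices with replacement according to the weights
sampler = WeightedRandomSampler(sample_weights, num_samples=1000, replacement=True)
loader = DataLoader(ds, batch_size=100, sampler=sampler)

drawn_labels = torch.cat([y for _, y in loader])
print(drawn_labels.float().mean())  # close to 0.5: both classes appear about equally often
```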

Write your own sampler by simply implementing `__iter__` so that it iterates over the indices of the dataset.

```python
# Simplified sketch of torch.utils.data.Sampler (the real class has more features)
class Sampler(object):
    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
```
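As a minimal sketch of a custom sampler, the hypothetical `ReverseSampler` below yields the dataset indices from last to first; `DataLoader` then fetches samples in exactly that order and batches them as usual.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class ReverseSampler(Sampler):
    """Hypothetical example sampler: yields dataset indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Indices n-1, n-2, ..., 0
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)


ds = TensorDataset(torch.arange(6))
loader = DataLoader(ds, batch_size=2, sampler=ReverseSampler(ds))
batches = [batch[0].tolist() for batch in loader]
print(batches)  # [[5, 4], [3, 2], [1, 0]]
```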

# Recap

- `Dataset`: get one datapoint
- `transforms`: composable transformations
- `DataLoader`: combine single datapoints into batches (plus multiprocessing and more)
- `Sampler`: **how** to sample from a dataset

**Simple but extensible interfaces**

# Exercises

- Extend `DogsCatsDataset` so that you can specify the size of the dataset, i.e. the number of samples.
- Try `Subset` [[docs]](https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) to create smaller datasets.
- Create a `SubsetFraction` where you can specify the size of the dataset as a fraction between 0 and 1.
- Write a custom collate function for `DogsCatsDataset` that makes it suitable for an autoencoder setting (i.e. the target of each sample is the input image itself).

In [None]:
def autoencoder_collate_fn(list_of_x_y):
    # TODO implement me
    pass

In [None]:
class MyDataSet(Dataset):
    def __init__(self):
        super().__init__()
        # TODO implement me

    def __len__(self):
        # TODO implement me
        pass

    def __getitem__(self, idx):
        # TODO implement me
        pass