# Notes on Using the BatchGenerators Modules

Here are some simple notes and examples on how to use the BatchGenerators modules. In places they may be more precise than the official example notebooks.

1. The first part covers how to build a multithreaded DataLoader from our own data, using two different implementations:

    **~ `batchgenerators.dataloading.data_loader.DataLoaderFromDataset`**

    **~ `batchgenerators.dataloading.data_loader.DataLoader`**

2. The second part covers transformations for data augmentation (to be finished).

## Building a Multithreaded DataLoader

### Via `batchgenerators.dataloading.data_loader.DataLoaderFromDataset`

In the following examples we use the MNIST dataset. Instead of loading it directly through
```python
    torchvision.datasets.MNIST
```
we read the original **xxx.gz** files to get MNIST as NumPy arrays.

First, a function that loads the MNIST training data:

```python
import numpy as np
import gzip
import os

def load_data(data_folder):
    files = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz']

    paths = [os.path.join(data_folder, fname) for fname in files]
    # images: skip the 16-byte IDX header, then read uint8 pixels
    with gzip.open(paths[0], 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16
        ).reshape(-1, 28, 28)
    # labels: skip the 8-byte IDX header, then read uint8 labels
    with gzip.open(paths[1], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    return x_train, y_train

x_train, y_train = load_data('./MNIST/raw')
print(x_train.shape)
print(y_train.shape)

# The integer labels must be np.int64; otherwise building batches
# in the dataloader fails (see the ATTENTION notes below).
data_dict = {'data': x_train, 'labels': y_train.astype(np.int64)}
```

```
(60000, 28, 28)
(60000,)
```
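The offsets 16 and 8 above skip the headers of the IDX files. In the IDX format the first 4 bytes are a magic number (two zero bytes, a data-type code where `0x08` means unsigned byte, and the number of dimensions), followed by one big-endian 32-bit size per dimension. The header is therefore `4 + 4 * ndim` bytes: 16 for the 3-D image file and 8 for the 1-D label file. As a sketch, the header can be read instead of hard-coding the offset; the helper `parse_idx_ubyte` is hypothetical and not part of batchgenerators:

```python
import struct
import numpy as np

def parse_idx_ubyte(raw):
    """Parse an uncompressed IDX buffer whose elements are unsigned bytes (type code 0x08)."""
    zero1, zero2, type_code, ndim = struct.unpack('>BBBB', raw[:4])
    if zero1 != 0 or zero2 != 0 or type_code != 0x08:
        raise ValueError('not an unsigned-byte IDX buffer')
    # one big-endian uint32 size per dimension follows the magic number
    dims = struct.unpack('>' + 'I' * ndim, raw[4:4 + 4 * ndim])
    header_size = 4 + 4 * ndim  # 16 for images (ndim=3), 8 for labels (ndim=1)
    return np.frombuffer(raw, np.uint8, offset=header_size).reshape(dims)

# Usage with a tiny synthetic buffer: 2 images of 2x3 pixels
fake = struct.pack('>BBBB', 0, 0, 0x08, 3) + struct.pack('>III', 2, 2, 3) + bytes(range(12))
arr = parse_idx_ubyte(fake)
print(arr.shape)  # (2, 2, 3)
```

For the real files, the buffer would come from `gzip.open(path, 'rb').read()`, as in `load_data`.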


Following the official example **batchgenerators/examples/cifar.py** and the implementation of `CifarDataset` in batchgenerators, we first write our own Dataset, which should inherit from the `Dataset` class in batchgenerators.

This is almost the same as writing a general PyTorch dataset, except that we replace
```python
    torch.utils.data.Dataset
```
with
```python
    batchgenerators.dataloading.dataset.Dataset
```


```python
from batchgenerators.dataloading.dataset import Dataset

class OurOwnDataset(Dataset):
    def __init__(self, data, train=True, transform=None):
        super(OurOwnDataset, self).__init__()

        self.data = data['data']
        self.labels = data['labels']

        # `train` and `transform` are stored but not used here,
        # since we only load the MNIST training set
        self.train = train
        self.transform = transform

    def __getitem__(self, item):
        # the slice item:item+1 keeps a leading axis of size 1 (see the Notice below)
        data_dict = {'data': self.data[item:item+1], 'label': self.labels[item]}
        return data_dict

    def __len__(self):
        return len(self.data)

ds_mnist = OurOwnDataset(data_dict)
print(len(ds_mnist))
```

```
60000
```


$\textbf{Notice}: $ In `__getitem__` we use the slice `self.data[item:item+1]` rather than `self.data[item]`. The collate function stacks the per-sample arrays along the first axis, so with `self.data[item]` (shape (28, 28)) the batch data in `MultiThreadedAugmenter` has shape **(batch_size * 28, 28)**, while with `self.data[item:item+1]` (shape (1, 28, 28)) it has shape **(batch_size, 28, 28)** in this example.
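The shape difference can be reproduced with NumPy alone. This sketch assumes, as in the `default_collate` function of the batchgenerators **data_loader.py** we looked at, that per-sample arrays are stacked with `np.vstack`; the zero array is only a stand-in for `x_train`:

```python
import numpy as np

images = np.zeros((60, 28, 28), dtype=np.uint8)  # stand-in for x_train

# self.data[item] -> each sample has shape (28, 28)
per_item = [images[i] for i in range(5)]
# self.data[item:item+1] -> each sample has shape (1, 28, 28)
per_slice = [images[i:i+1] for i in range(5)]

print(np.vstack(per_item).shape)   # (140, 28)
print(np.vstack(per_slice).shape)  # (5, 28, 28)
```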

Then we can use
```python
    batchgenerators.dataloading.data_loader.DataLoaderFromDataset
```
to build the DataLoader from the Dataset, and wrap that DataLoader in
```python
    batchgenerators.dataloading.multi_threaded_augmenter.MultiThreadedAugmenter
```
so that batches are generated by several background worker processes.

$\color{red}{\textbf{ATTENTION}}$: 

1. The parameter `num_processes` in `MultiThreadedAugmenter` must be the same as the parameter `num_threads_in_multithreaded` in `DataLoaderFromDataset`. Each worker uses this number to skip over the batches taken by the other workers, so a mismatch can make the workers return duplicated or skipped batches.
2. Integer data must be of type `np.int64` (the raw MNIST labels are `np.uint8`); otherwise the `default_collate` function in **data_loader.py** fails with an error inside `MultiThreadedAugmenter`.
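A quick check of the second point: `np.uint8` scalars are neither Python `int` nor `np.int64`. The sketch assumes that the collate function dispatches on an `isinstance(x, (int, np.int64))` check, as `default_collate` does in the version we looked at; the tuple `integer_types` is only an illustration of that check:

```python
import numpy as np

y_raw = np.array([5, 0, 4, 1, 9], dtype=np.uint8)  # labels as read from the .gz file
y_int64 = y_raw.astype(np.int64)

integer_types = (int, np.int64)  # the kind of check assumed for default_collate
print(isinstance(y_raw[0], integer_types))    # False
print(isinstance(y_int64[0], integer_types))  # True
```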

```python
from batchgenerators.dataloading.data_loader import DataLoaderFromDataset
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter

num_threads_in_mt = 10
dl_train = DataLoaderFromDataset(ds_mnist, 5, num_threads_in_mt, shuffle=False)
mt_dl_train = MultiThreadedAugmenter(dl_train, None, num_threads_in_mt)

batch_dict = next(mt_dl_train)
print(batch_dict['data'].shape)
print(batch_dict['label'])
mt_dl_train._finish()
```

```
(5, 28, 28)
[5 0 4 1 9]
```


We can also iterate over `mt_dl_train` with a normal `for` loop:

```python
for i, batch_dict in enumerate(mt_dl_train):
    print(f'iter: {i}, batch shape: {batch_dict["data"].shape}, batch label: {batch_dict["label"]}')
    if i > 10:
        break
mt_dl_train._finish()  # after the early `break`, reset so the next pass starts from the first batch
```

```
iter: 0, batch shape: (5, 28, 28), batch label: [5 0 4 1 9]
iter: 1, batch shape: (5, 28, 28), batch label: [2 1 3 1 4]
iter: 2, batch shape: (5, 28, 28), batch label: [3 5 3 6 1]
iter: 3, batch shape: (5, 28, 28), batch label: [7 2 8 6 9]
iter: 4, batch shape: (5, 28, 28), batch label: [4 0 9 1 1]
iter: 5, batch shape: (5, 28, 28), batch label: [2 4 3 2 7]
iter: 6, batch shape: (5, 28, 28), batch label: [3 8 6 9 0]
iter: 7, batch shape: (5, 28, 28), batch label: [5 6 0 7 6]
iter: 8, batch shape: (5, 28, 28), batch label: [1 8 7 9 3]
iter: 9, batch shape: (5, 28, 28), batch label: [9 8 5 9 3]
iter: 10, batch shape: (5, 28, 28), batch label: [3 0 7 4 9]
iter: 11, batch shape: (5, 28, 28), batch label: [8 0 9 4 1]
```


$\textbf{Notice}: $ Whenever we stop drawing batches from `mt_dl` with `next(mt_dl)`, we must close it with `mt_dl._finish()` or re-create `mt_dl`. Otherwise, the next time we use `mt_dl` it continues from the batch where it stopped, and once the number of calls exceeds the number of available batches it raises `StopIteration`. See the following simple example:

```python
class NumDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)

data_num = np.arange(1, 11, dtype=np.int64)  # 10 samples: 1, 2, ..., 10
ds_num = NumDataset(data_num)
dl_num = DataLoaderFromDataset(ds_num, 7, num_threads_in_mt, shuffle=False)  # the batch size is 7
mt_dl_num = MultiThreadedAugmenter(dl_num, None, num_threads_in_mt)
```

```python
for epoch in range(2):
    for i in range(len(ds_num) // 7):
        data = next(mt_dl_num)
        print(f'epoch: {epoch}, iter: {i}, data: {data}')
    mt_dl_num._finish()
```

```
epoch: 0, iter: 0, data: [1 2 3 4 5 6 7]
epoch: 1, iter: 0, data: [1 2 3 4 5 6 7]
```


```python
for epoch in range(2):
    for i in range(len(ds_num) // 7):
        data = next(mt_dl_num)
        print(f'epoch: {epoch}, iter: {i}, data: {data}')
    # mt_dl_num._finish()
```

```
epoch: 0, iter: 0, data: [1 2 3 4 5 6 7]

StopIteration
```
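Why is there only one batch per epoch here? In the batchgenerators version we read, with `shuffle=False` worker `t` of `N` workers takes batch `t`, then batch `t + N`, and so on, and an incomplete last batch is dropped unless `return_incomplete=True`. The hypothetical helper `batches_for_worker` below is a pure-NumPy sketch of that scheme under those assumptions, not the library code:

```python
import numpy as np

def batches_for_worker(n_samples, batch_size, num_workers, thread_id, return_incomplete=False):
    """Sketch: index batches that one worker would produce in one epoch (no shuffling)."""
    indices = np.arange(n_samples)
    batches = []
    position = thread_id * batch_size  # each worker starts at its own batch
    while position < n_samples:
        batch = indices[position:position + batch_size]
        if len(batch) < batch_size and not return_incomplete:
            break  # drop the incomplete last batch
        batches.append(batch.tolist())
        position += num_workers * batch_size  # skip the batches of the other workers
    return batches

# NumDataset example: 10 samples, batch size 7, 10 workers
for t in range(10):
    print(t, batches_for_worker(10, 7, 10, t))
# worker 0 -> [[0, 1, 2, 3, 4, 5, 6]] (the values 1..7); every other worker -> []
```

When every worker uses the same `num_workers`, one epoch covers each index exactly once (for example 20 samples, batch size 5, 2 workers), which is why `num_processes` and `num_threads_in_multithreaded` should agree.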

### Via `batchgenerators.dataloading.data_loader.DataLoader`

We can also build our own dataloader and wrap it with `MultiThreadedAugmenter` to get `mt_dl`. This only requires overriding the method `generate_train_batch` of the class `batchgenerators.dataloading.data_loader.DataLoader`. Here is `OurOwnMNISTDataloader` for the MNIST dataset:

```python
from batchgenerators.dataloading.data_loader import DataLoader

class OurOwnMNISTDataloader(DataLoader):
    def __init__(self, data, batch_size, num_threads_in_multithreaded=1, seed_for_shuffle=None,
                 return_incomplete=False, shuffle=True, infinite=False, sampling_probabilities=None):
        super().__init__(data, batch_size, num_threads_in_multithreaded, seed_for_shuffle,
                         return_incomplete, shuffle, infinite, sampling_probabilities)
        # necessary: DataLoader draws the batch indices from this list
        self.indices = list(range(len(data['data'])))

    def generate_train_batch(self):
        idx = self.get_indices()
        img_for_batch = [self._data['data'][i] for i in idx]
        label_for_batch = [self._data['labels'][i] for i in idx]

        # size by len(idx) rather than self.batch_size so incomplete batches also work
        data = np.zeros((len(idx), 1, 28, 28), dtype=np.float32)
        data_label = np.zeros(len(idx), dtype=np.int64)
        for i, (img, label) in enumerate(zip(img_for_batch, label_for_batch)):
            data[i] = img.reshape(-1, 28, 28)
            data_label[i] = label
        return {'data': data, 'label': data_label}
```

```python
dl_mnist = OurOwnMNISTDataloader(data_dict, 5, num_threads_in_mt)
mt_dl_mnist = MultiThreadedAugmenter(dl_mnist, None, num_threads_in_mt)
data = next(mt_dl_mnist)

print(data['data'].shape)
print(data['label'])
mt_dl_mnist._finish()
```

```
(5, 1, 28, 28)
[4 3 2 7 1]
```
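The per-sample loop in `generate_train_batch` can also be written with NumPy fancy indexing. The helper `make_mnist_batch` below is hypothetical (it is not part of batchgenerators); inside `generate_train_batch` one could call it as `return make_mnist_batch(self._data, self.get_indices())`. The toy dataset only mimics the layout of `data_dict`:

```python
import numpy as np

def make_mnist_batch(data_dict, idx):
    """Build a batch dict from a list of sample indices using fancy indexing."""
    idx = np.asarray(idx)
    # (len(idx), 28, 28) -> (len(idx), 1, 28, 28): add the channel axis
    data = data_dict['data'][idx][:, None].astype(np.float32)
    label = data_dict['labels'][idx].astype(np.int64)
    return {'data': data, 'label': label}

# Usage with a small synthetic dataset in the same layout as data_dict
toy = {'data': np.random.randint(0, 256, size=(20, 28, 28), dtype=np.uint8),
       'labels': np.arange(20, dtype=np.int64) % 10}
batch = make_mnist_batch(toy, [4, 13, 2, 7, 11])
print(batch['data'].shape)  # (5, 1, 28, 28)
print(batch['label'])       # [4 3 2 7 1]
```

Because the batch size is taken from `len(idx)`, incomplete last batches (with `return_incomplete=True`) are handled without extra code.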


$\textbf{Notice}: $

1. Since our dataloader inherits from `DataLoader`, we **do not need to store** `data` and `batch_size` ourselves: the `DataLoader` constructor already sets `self._data = data` and `self.batch_size = batch_size`;

2. We must set the attribute `self.indices` to the list of sample indices, here `list(range(len(data['data'])))`. The method `self.get_indices()` draws each batch of indices from this list and, with `shuffle=True`, reshuffles it at the start of every epoch.

## Transformation for Augmentation

To be finished.