# Imports
---

In [23]:
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter

import numpy as np
import time
import math

Defining an "Epoch Dataloader with Shuffle"
---
This is a dataloader as generally defined in most deep learning frameworks, *with*
all the benefits of `batchgenerators`. The dataloader, built on the `SlimDataLoaderBase`
class from `batchgenerators`, has the following properties:

1. Each epoch covers the dataset once
2. Each new epoch 'deterministically' shuffles the data before creating batches
3. It works with `MultiThreadedAugmenter`, which uses multiple processes to load and transform
data _while_ using batch size > 1

Comments have been added throughout the code to explain important parts of writing
a dataloader with `batchgenerators`. 

Also highlighted is the interplay between
`SlimDataLoaderBase` and `MultiThreadedAugmenter`, the latter of which is used for rapid mini-batched
loading and augmenting.

In [24]:
class EpochDLWithShuffle(SlimDataLoaderBase):
    """Epoch-based dataloader that reshuffles deterministically on every epoch.

    Each epoch covers the dataset exactly once. The loader is meant to be
    wrapped in a MultiThreadedAugmenter: every worker process holds its own
    copy of this object, shuffles with the same per-epoch seed, and then
    serves an interleaved, non-overlapping subset of the batches.

    The dataset passed in is never reordered. Shuffling is applied to a
    private list of indices, so the caller's ``data`` object keeps its
    original order even when the loader is iterated in the main process.

    Args:
        data: Indexable dataset (e.g. a list or a numpy array).
        num_threads_in_mt: Number of worker processes that will consume this
            loader. Must match the ``num_processes`` given to
            MultiThreadedAugmenter, otherwise batches are skipped or repeated.
        batch_size: Number of samples per batch. The final batch of an epoch
            may be smaller.
    """

    def __init__(self, data, num_threads_in_mt=12, batch_size=4):
        # This initializes self._data, self.batch_size and
        # self.number_of_threads_in_multithreaded
        super(EpochDLWithShuffle, self).__init__(data, batch_size, num_threads_in_mt)

        self.num_restarted = 0
        self.current_position = 0
        self.was_initialized = False

        # Order in which samples are visited. This list is shuffled instead
        # of self._data so that the caller's dataset is left untouched.
        self._sample_order = list(range(len(self._data)))

    def reset(self):
        """Start a new epoch: reshuffle and move to this worker's first batch."""
        # Seeding with the epoch counter gives a different order each epoch
        # while keeping the order identical across all worker processes.
        rs = np.random.RandomState(self.num_restarted)
        rs.shuffle(self._sample_order)
        self.was_initialized = True
        self.num_restarted = self.num_restarted + 1

        # Select a starting point for this subprocess. self.thread_id is set
        # by MultiThreadedAugmenter and is in the range [0, num_threads_in_mt).
        # Multiplying it by batch_size gives every subprocess a unique
        # starting point that accounts for the size of a batch.
        self.current_position = self.thread_id * self.batch_size

    def generate_train_batch(self):
        """Return the next batch for this worker.

        This method HAS to be defined; it is what produces the batches.

        Returns:
            The next batch of samples: a numpy array if the dataset is a
            numpy array, otherwise a list.

        Raises:
            StopIteration: When this worker has no batches left in the
                current epoch. The following call starts a new epoch.
        """
        # Lazy initialization so that it happens inside each subprocess
        # spawned by MultiThreadedAugmenter.
        if not self.was_initialized:
            self.reset()

        idx = self.current_position
        if idx >= len(self._sample_order):
            # Epoch exhausted for this worker; reshuffle on the next call.
            self.was_initialized = False
            raise StopIteration

        # Next starting point. This skips one batch for this process AS WELL
        # AS one batch for each of the other processes. Since the processes
        # start at unique, contiguous offsets (see reset()), they never
        # overlap.
        self.current_position = idx + self.batch_size * self.number_of_threads_in_multithreaded

        # Slicing clamps at the end of the list, so the last batch is simply
        # shorter when the dataset size is not a multiple of batch_size.
        batch_indices = self._sample_order[idx: idx + self.batch_size]
        if isinstance(self._data, np.ndarray):
            return self._data[batch_indices]
        return [self._data[i] for i in batch_indices]

# MultiThreadedAugmenter and running the Dataloader
---

The `MultiThreadedAugmenter` allows for rapid data loading and augmentation
(`batchgenerators` provides augmentations as well) using multiple sub-processes. It is designed
to be used seamlessly with the dataloader defined above.

In [25]:
# Toy dataset. The length is deliberately not a multiple of the batch size so
# that the last batch of the epoch comes out short.
data = list(range(107))
batch_size = 4

# Number of subprocesses used by MultiThreadedAugmenter. The dataloader
# (EpochDLWithShuffle, a SlimDataLoaderBase subclass) needs the very same
# number for its position arithmetic, so this one variable feeds both.
num_threads_in_mt = 12

# Size of the original dataset, for comparison with what the epoch returns
print(f"Total samples in Dataset: {len(data)}")

# Instance of the dataloader defined in the previous code cell
dl = EpochDLWithShuffle(
    data=data,
    batch_size=batch_size,
    num_threads_in_mt=num_threads_in_mt,    # must equal num_processes below
)

# Multithreaded augmenter without any transforms. With a real dataset you
# would pass transforms from batchgenerators here. The essentials are the
# data_loader, the transform and the number of processes: each sub-process
# loads and transforms batches and puts them on a queue, from which the
# training loop pulls while iterating over the MultiThreadedAugmenter object.
# See the other MultiThreadedAugmenter parameters for further tuning.
mt = MultiThreadedAugmenter(
    data_loader=dl,
    transform=None,                     # batchgenerators provides augmentations for 2D and 3D data
    num_processes=num_threads_in_mt,    # must equal the value given to the dataloader
)

# Run a single epoch and show every batch that comes back
for epoch in range(1):
    l = []
    for batch_id, batch in enumerate(mt):
        print(f"Batch {batch_id+1}:{batch}")

        # Keep all returned samples for the checks below
        l += batch

    # Number of samples the augmenter returned in total
    print(f"Total samples returned by augmentor: {len(l)}")

    # Number of distinct original samples that appeared in the batches
    print(f"Coverage of samples returned by augmentor of original dataset: {len(set(range(len(data))) & set(l))}/{len(data)}")

Total samples in Dataset: 107
Batch 1:[84, 10, 75, 2]
Batch 2:[24, 99, 106, 7]
Batch 3:[16, 86, 68, 22]
Batch 4:[45, 60, 76, 52]
Batch 5:[13, 73, 85, 54]
Batch 6:[102, 8, 26, 92]
Batch 7:[33, 3, 66, 48]
Batch 8:[30, 6, 78, 94]
Batch 9:[89, 93, 100, 59]
Batch 10:[27, 18, 61, 51]
Batch 11:[63, 71, 43, 1]
Batch 12:[79, 42, 41, 4]
Batch 13:[15, 17, 40, 38]
Batch 14:[5, 53, 98, 56]
Batch 15:[0, 34, 28, 55]
Batch 16:[50, 11, 62, 35]
Batch 17:[23, 31, 96, 57]
Batch 18:[82, 91, 32, 90]
Batch 19:[14, 74, 19, 29]
Batch 20:[49, 104, 105, 69]
Batch 21:[80, 20, 101, 72]
Batch 22:[77, 25, 37, 81]
Batch 23:[46, 97, 39, 65]
Batch 24:[58, 12, 95, 88]
Batch 25:[70, 87, 36, 21]
Batch 26:[83, 9, 103, 67]
Batch 27:[64, 47, 44]
Total samples returned by augmentor: 107
Coverage of samples returned by augmentor of original dataset: 107/107


# Timing an epoch
---
The major speed benefits are seen while working with actual data and an augmentation
pipeline (check `batchgenerators` README for an example of this). 

However, an example of timing code is provided below and can be used to check for speedups when using real data.

In [26]:
# Timing a full epoch. Note that no transformations are applied here; the
# speedups are usually observed when transformations are in use.
batch_times = []
for _ in range(math.ceil(len(data)/batch_size)):
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short intervals; time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    # The batch itself is not needed, only the time taken to fetch it
    next(mt)
    batch_times.append(time.perf_counter() - start)
    print("This batch took %02.3f s" % batch_times[-1])

avg_batch_time_mt = np.mean(batch_times)
# Report the configured worker count instead of a hard-coded number
print("Multi threaded batch generation using %d workers took %02.3f s on average per batch" % (num_threads_in_mt, avg_batch_time_mt))

This batch took 0.002 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.021 s
This batch took 0.001 s
This batch took 0.000 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.002 s
This batch took 0.001 s
This batch took 0.001 s
This batch took 0.001 s
Multi threaded batch generation using 12 workers took 0.001 s on average per batch
