# Using a Sequential DataLoader to Create a Training Loop



## Import Libraries

In [None]:
import os
import glob
import numpy as np
import cv2
from urllib.request import urlretrieve
from functools import partial

from metavision_ml.data import SequentialDataLoader
from metavision_ml.data import box_processing as box_api

## Download the Data Sample

In [None]:
from zipfile import ZipFile

dataset_path = "dataset_precomputed"
# a dictionary containing the correspondence map for class indices
label_map_path = os.path.join(dataset_path, 'label_map_dictionary.json')

# getting the data for this tutorial
if not os.path.isdir(dataset_path):
    if not os.path.exists("sample_dataset.zip"):
        urlretrieve("https://dataset.prophesee.ai/index.php/s/QFrGIKZ13fr1oQa/download", filename="sample_dataset.zip")
    with ZipFile('sample_dataset.zip', 'r') as zipObj:
        # Extract all the contents of zip file in current directory
        zipObj.extractall()
!ls "sample_dataset.zip" {dataset_path}

**Suggestions on organizing your dataset**

The data sample provided here is for illustration purposes only. In practice, we suggest separating train, test and validation data into distinct folders, using the following structure:

```
    dataset_folder/
    ├── train/
    │   ├── file_1.h5
    │   ├── file_1_bbox.npy
    │   ├── file_2.h5
    │   ├── file_2_bbox.npy
    ├── test/
    │   ├── file_3.h5
    │   ├── file_3_bbox.npy
    │   ├── file_4.h5
    │   ├── file_4_bbox.npy
    ├── val/
    │   ├── file_5.h5
    │   ├── file_5_bbox.npy
    ├── possibly a readme and some metadata file (JSON etc.)
```


## Load Labels

In supervised training, in addition to the training data, we need ground truth labels as well. Depending on the type of training, the labels might come in very different formats. To facilitate training with various label formats, we provide a **template function** that can be used to write your own label loading functions.


First, let's see what this **template function** looks like in our ML module.

In [None]:
from metavision_ml.data.sequential_dataset import load_labels_stub
# Print the signature and docstring of the template label-loading function.
help(load_labels_stub)

As you can see, the function returns both a list of labels and a boolean mask indicating if the corresponding time bins are labeled or not. During training, this boolean mask will be used to filter out unlabeled time bins so that no loss will be computed on them.

### Customize a Label Loading Function

Now let's create a function to load detection bounding boxes.

In [None]:
def custom_load_boxes(metadata, batch_start_time, duration, tensor, **kwargs):
    """Load the bounding boxes of one time window and split them per time bin.

    Follows the signature of ``load_labels_stub`` plus one extra keyword argument,
    ``class_lookup``, which must be bound beforehand (e.g. with ``functools.partial``).

    Args:
        metadata: file metadata forwarded to ``box_api.load_box_events``.
        batch_start_time: start of the window to load (same time unit as ``duration``,
            presumably microseconds — TODO confirm).
        duration: length of the window; it is split evenly into ``tensor.shape[0]`` bins.
        tensor: input tensor of the window; only its first dimension (number of time
            bins) is used here.
        **kwargs: must contain ``class_lookup``, an array indexed by the dataset's class
            ids that gives the contiguous class id used for training.

    Returns:
        tuple: ``(box_events, frame_is_labeled)`` where ``box_events`` is the list of
        box-event arrays returned by ``box_api.split_boxes`` (one per time bin) and
        ``frame_is_labeled`` is a boolean array of the same length, all ``True``.
    """
    # load the box events of the requested window from file
    box_events = box_api.load_box_events(metadata, batch_start_time, duration)

    # remap the dataset class ids to contiguous class ids for training
    class_lookup = kwargs['class_lookup']
    box_events['class_id'] = class_lookup[box_events['class_id']]

    # split the box events into one array per time bin
    num_tbins = tensor.shape[0]
    box_events = box_api.split_boxes(box_events, batch_start_time=batch_start_time,
                                     delta_t=duration // num_tbins, num_tbins=num_tbins)

    # mark every time bin as labeled so the loss is computed on all of them.
    # The builtin ``bool`` is used because the ``np.bool`` alias was removed in NumPy 1.24.
    all_frames_are_okay = np.ones(len(box_events), dtype=bool)
    return box_events, all_frames_are_okay

Compared to our **template function**, the function above requires an additional argument: `class_lookup`. We therefore need to bind that argument so that the resulting function's signature is exactly the one we expect. You can use the [partial](https://docs.python.org/3/library/functools.html#functools.partial) function from the [functools](https://docs.python.org/3/library/functools.html) module to pass additional arguments.

In [None]:
# the labels of the class we want to load from the dataset
wanted_keys = ['car', 'pedestrian', 'two wheeler']

# create a look up table to get the lookup IDs from the selected classes
class_lookup = box_api.create_class_lookup(label_map_path, wanted_keys)

custom_load_boxes_fn = partial(custom_load_boxes, class_lookup=class_lookup)

## Event-Based SequentialDataLoader

Before instantiating the ``SequentialDataLoader`` class, let's first define some input parameters, then pass our custom label loading function ``custom_load_boxes_fn``.

In [None]:
# Use (at most) two recordings. glob returns files in an arbitrary, filesystem-dependent
# order, so sort first to make the selection reproducible across machines.
files = sorted(glob.glob(os.path.join(dataset_path, "*.h5")))[:2]
preprocess_function_name = "histo"
delta_t = 50000  # time-slice duration, presumably in microseconds — TODO confirm
channels = 2  # histograms have two channels
num_tbins = 3  # number of time bins per sample
height, width = 360, 640
batch_size = 2
max_incr_per_pixel = 2.5  # forwarded to the "histo" preprocessing through preprocess_kwargs
# expected shape of one sample: (time bins, channels, height, width)
array_dim = [num_tbins, channels, height, width]


**Instantiate the class**

In [None]:
# data loader over the selected files, labels produced by our custom box loader
seq_dataloader = SequentialDataLoader(
    files,
    delta_t,
    preprocess_function_name,
    array_dim,
    load_labels=custom_load_boxes_fn,
    batch_size=batch_size,
    num_workers=0,
    preprocess_kwargs=dict(max_incr_per_pixel=max_incr_per_pixel),
)

**Let's iterate over the loaded data and visualize its metadata.**

In [None]:
# Inspect only the first batch (remove the final ``break`` to process all data).
# The ``break`` sits at the end of the loop body, so ``batch`` still refers to this first
# batch after the loop (it is inspected again in a later cell), and no extra batch is
# loaded just to be discarded.
for batch in seq_dataloader:
    print("available keys: ", batch.keys(), "\n")
    print("input shape:", batch["inputs"].shape, "\n")
    print("metadata:", batch["video_infos"], "\n")
    print("box events: ", len(batch['labels']), "lists (corresponding to the no. of time bins), each containing a batch sized lists :", [len(labels) for labels in batch['labels']])
    break

As you can see, at each iteration ``SequentialDataLoader`` produces a dictionary containing the keys `inputs`, `labels`, `mask_keep_memory`, `frame_is_labeled` and `video_infos`.

The inputs are tensors of shape $[T \times N \times C \times H \times W]$ instead of $[N \times C \times H \times W]$, because we need to handle temporal information during training: this layout allows processing the data sequentially, from the first time bin to the last.

* T: number of time bins
* N: batch size
* C: feature size
* H: height
* W: width



Similarly, the bounding boxes are organized in $T$ lists of $N$ nested lists, so that labels and tensor are indexed consistently. 

The ``mask_keep_memory`` is a binary tensor of length $N$, where a value of 0 indicates the beginning of a new recording. This is useful when we want to reset the memory between different recordings.


**Let's also take a closer look at the labels of one time bin in the batch.**

For instance, the bounding boxes in the 2nd time bin of the 1st sequence in the batch are:

In [None]:
# labels are indexed as [time_bin][sample_in_batch]: this is time bin 1 (the 2nd bin)
# of sample 0 (the 1st sequence) of ``batch``
batch['labels'][1][0]

### Visualization Utility of SequentialDataLoader

The ``SequentialDataLoader`` class provides a visualization method named ``show``. It can visualize the batches of the ``SequentialDataLoader`` in parallel with OpenCV.

**Let's visualize the frames we have just loaded.**

In [None]:
# Stream the loaded batches in an OpenCV window; press ESC to stop early.
# Display is skipped when the DOC_DISPLAY environment variable is set to "OFF".
if os.environ.get("DOC_DISPLAY", "ON") != "OFF":
    ESC_KEY = 27
    cv2.namedWindow('sequential_dataloader')
    try:
        for frame in seq_dataloader.show():
            # reverse the channel axis, presumably RGB -> BGR for OpenCV
            cv2.imshow('sequential_dataloader', frame[..., ::-1])
            if cv2.waitKey(1) == ESC_KEY:
                break
    finally:
        # close the window even if streaming is interrupted (e.g. KeyboardInterrupt)
        cv2.destroyWindow('sequential_dataloader')


The ``show()`` method can be called with a custom label visualization function so as to stream the labels together with the input data. Its signature should match the following:

In [None]:
def draw_labels(frame, labels):
    """Template label visualizer: draw ``labels`` on ``frame``.

    This placeholder draws nothing and hands the frame back unchanged.

    Args:
        frame (np.ndarray): frame of size height x width x 3.
        labels: labels for one file and one time bin.

    Returns:
        np.ndarray: the frame with the labels drawn on it (here, the input frame itself).
    """
    del labels  # not used by this template
    return frame

**Let's now visualize the batch data together with the labels.** 


In [None]:
# use the library's box loader, bound to the same class lookup table
load_boxes_fn = partial(box_api.load_boxes, class_lookup=class_lookup)

loader_kwargs = dict(batch_size=batch_size, num_workers=0,
                     preprocess_kwargs={"max_incr_per_pixel": max_incr_per_pixel})
seq_dataloader = SequentialDataLoader(files, delta_t, preprocess_function_name, array_dim,
                                      load_labels=load_boxes_fn, **loader_kwargs)


In [None]:
from metavision_ml.detection_tracking.display_frame import draw_box_events

# class names indexed by class id: 'background' first, then the selected classes in order
label_map = ['background'] + wanted_keys

# adding box visualization. Notice how here again we rely on partial.
viz_labels = partial(draw_box_events, label_map=label_map)

# Stream batches with their boxes drawn on top; press ESC to stop early.
if os.environ.get("DOC_DISPLAY", "ON") != "OFF":
    ESC_KEY = 27
    cv2.namedWindow('sequential_dataloader')
    try:
        for frame in seq_dataloader.show(viz_labels):
            cv2.imshow('sequential_dataloader', frame[..., ::-1])
            if cv2.waitKey(1) == ESC_KEY:
                break
    finally:
        # close the window even if streaming is interrupted (e.g. KeyboardInterrupt)
        cv2.destroyWindow('sequential_dataloader')

## Training Loop Example

The data loader presented in this tutorial can be used to create a custom training loop.
The following pseudo-code shows how you can train your own event-based models:

```python
for data in seq_dataloader:
    # we first need to reset the memory for each new sequence in the batch and detach the gradients.
    # Detaching the gradients prevents the computational graph from growing as long as the full sequence.
    # This is called *truncated backpropagation*.
    net.reset(data['mask_keep_memory'])
    
    # clear the optimiser
    optimizer.zero_grad()
    
    # we compute the predictions chronologically. This is the forward pass.
    predictions = []
    for batch in data['inputs']:
        predictions.append(net.forward(batch))
    predictions = torch.stack(predictions)
        
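    # the loss is computed only on labeled time bins.
    # NOTE: the loader stores ground truth under 'labels'; 'targets' stands for those labels converted to a tensor.
    loss = compute_loss(predictions[data["frame_is_labeled"]], data['targets'][data["frame_is_labeled"]])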

    # we then compute the backward pass and update the network's weights.
    loss.backward()
    optimizer.step()
```


## CDProcessorDataLoader

The implementation above is based on the PyTorch map-style `Dataset` (which requires overriding the `__getitem__` method). We present here an alternative implementation based on the PyTorch `IterableDataset`. It is advantageous when frequent file seeking is costly or not possible: with it, you can stream RAW files directly, without having to convert them to DAT or HDF5 files.

In [None]:
from metavision_ml.data.cd_processor_dataset import CDProcessorDataLoader
from metavision_core.utils import get_sample

# we grab a folder of .raw files
# if the file doesn't exist, it will be downloaded from Prophesee's public sample server 
files = ["driving_sample.raw", "hand_spinner.raw", "spinner.raw", "80_balls.raw"]

for sample_name in files:
    get_sample(sample_name)
    # explicit check instead of ``assert``, which is stripped when Python runs with -O
    if not os.path.isfile(sample_name):
        raise FileNotFoundError(f"sample recording could not be retrieved: {sample_name}")

# Stream the RAW files directly. mode='n_events' presumably slices the stream by event
# count (n_events) rather than by duration (delta_t=0 here) — TODO confirm.
dataloader = CDProcessorDataLoader(
    files,
    mode='n_events',
    delta_t=0,
    n_events=10000,
    max_duration=10000000,
    preprocess_function_name="diff",
    height=240,
    width=320,
    num_tbins=5,
    batch_size=4,
    num_workers=2,
    load_labels=None,
    padding_mode='zeros')


In [None]:
from metavision_ml.data.sequential_dataset_common import show_dataloader

# Iterable of visualization frames for the streamed batches. The final argument is None;
# presumably it is an optional label-visualization function — TODO confirm.
show = show_dataloader(
    dataloader,
    dataloader.height,
    dataloader.width,
    dataloader.batch_size,
    dataloader.get_vis_func(),
    None)

# Display the stream in an OpenCV window; press ESC to stop early.
if os.environ.get("DOC_DISPLAY", "ON") != "OFF":
    ESC_KEY = 27
    cv2.namedWindow('stream_dataloader')
    try:
        for frame in show:
            cv2.imshow('stream_dataloader', frame[..., ::-1])
            if cv2.waitKey(1) == ESC_KEY:
                break
    finally:
        # close the window even if streaming is interrupted (e.g. KeyboardInterrupt)
        cv2.destroyWindow('stream_dataloader')