# build loader

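> Utilities for collating dataset samples into batches and building a PyTorch `DataLoader` (with distributed-sampling support) on top of a dataset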

In [None]:
#| default_exp build_loader

In [None]:
#|export
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist

In [None]:
#|eval: false
#|hide
from pillarnext_explained import dataset as pillarnext_dataset

In [None]:
#|exports
def collate(batch_list):
    """Merge a list of per-sample dicts into a single batch dict.

    Values are first grouped by key across all samples, then merged
    according to the key name / value type:

    * ``"token"``: returned as the plain list of per-sample values.
    * keys containing ``'point'``: every per-sample array gets one extra
      leading column holding the sample's index in the batch, and the
      padded arrays are concatenated along axis 0 into one tensor
      (the pad widths assume 2-D arrays).
    * values that are Python lists: the i-th items of all samples are
      converted to tensors and stacked, giving a list with one stacked
      tensor per position.
    * anything else: stacked with ``np.stack`` along a new leading axis
      and cast to ``float32``.
    """
    # Group the values of every key across the whole batch.
    grouped = defaultdict(list)
    for sample in batch_list:
        for name, value in sample.items():
            grouped[name].append(value)

    merged = {}
    for name, values in grouped.items():
        if name == "token":
            merged[name] = values
            continue

        if 'point' in name:
            # Prepend the batch index as column 0 so rows stay attributable
            # to their sample after concatenation.
            padded = [
                np.pad(
                    arr, ((0, 0), (1, 0)), mode="constant", constant_values=sample_idx
                )
                for sample_idx, arr in enumerate(values)
            ]
            merged[name] = torch.tensor(np.concatenate(padded, axis=0))
            continue

        if isinstance(values[0], list):
            # Bucket items by their position inside each sample's list.
            by_position = {}
            for sample_items in values:
                for position, item in enumerate(sample_items):
                    by_position.setdefault(position, []).append(torch.tensor(item))
            merged[name] = [torch.stack(bucket) for bucket in by_position.values()]
            continue

        merged[name] = torch.tensor(np.stack(values, axis=0)).float()

    return merged

In [None]:
#|eval: false
# Sample batch list of examples
# Two hand-made samples; each key exercises a different branch of `collate`.
batch_list = [
    {
        "token": [1, 2, 3],  # "token": kept as a plain list of per-sample values
        "point1": np.array([[1.0, 2.0], [3.0, 4.0]]),  # 'point' keys: batch index is prepended as column 0
        "point2": np.array([[5.0, 6.0]]),  # samples may contribute different numbers of rows
        "nested_list": [[1, 2], [3, 4]],  # list values: items are stacked position-wise across samples
        "value": np.array([1.0, 2.0])  # everything else: stacked along a new axis and cast to float32
    },
    {
        "token": [4, 5, 6],
        "point1": np.array([[7.0, 8.0]]),
        "point2": np.array([[9.0, 10.0], [11.0, 12.0]]),
        "nested_list": [[5, 6], [7, 8]],
        "value": np.array([3.0, 4.0])
    }
]

# Merge the per-sample dicts into a single batch dict
collated_batch = collate(batch_list)

# Print every merged entry; in the 'point' tensors the first column is the
# index of the sample (0 or 1) that each row came from
for key, value in collated_batch.items():
    print(f"{key}: {value}")

token: [[1, 2, 3], [4, 5, 6]]
point1: tensor([[0., 1., 2.],
        [0., 3., 4.],
        [1., 7., 8.]], dtype=torch.float64)
point2: tensor([[ 0.,  5.,  6.],
        [ 1.,  9., 10.],
        [ 1., 11., 12.]], dtype=torch.float64)
nested_list: [tensor([[1, 2],
        [5, 6]]), tensor([[3, 4],
        [7, 8]])]
value: tensor([[1., 2.],
        [3., 4.]])


In [None]:
#|exports
def build_dataloader(dataset, # Dataset object
                     batch_size=4, # Batch size
                     num_workers=8, # Number of workers
                     shuffle:bool=False, # Shuffle the data
                     pin_memory=False # Pin memory
                     ):
    """Create a `DataLoader` for `dataset` that batches samples with `collate`.

    When `torch.distributed` has been initialized, a `DistributedSampler`
    bound to the current rank and world size is used and shuffling is
    delegated to it; otherwise the loader itself applies `shuffle`.
    """
    sampler = None
    if dist.is_initialized():
        # Each process only iterates over its own shard of the dataset.
        sampler = DistributedSampler(
            dataset,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
            shuffle=shuffle,
        )

    # DataLoader's `shuffle` must stay off whenever an explicit sampler is passed.
    loader_shuffle = shuffle if sampler is None else False

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=loader_shuffle,
        num_workers=num_workers,
        collate_fn=collate,
        pin_memory=pin_memory,
    )

In [None]:
#|eval: false
# Instantiate the project's NuScenes dataset for the training split.
# NOTE(review): the first two arguments look like the info pickle and the dataset
# root, and the positional `10` presumably is the number of sweeps (the info file
# name mentions "10sweeps") — confirm against NuScenesDataset's signature.
train_dataset = pillarnext_dataset.NuScenesDataset("infos_train_10sweeps_withvelo_filterZero.pkl",
                                "/root/nuscenes-dataset/v1.0-mini",
                                10,
                                class_names=[["car"], ["truck", "construction_vehicle"], ["bus", "trailer"], ["barrier"], ["motorcycle", "bicycle"], ["pedestrian", "traffic_cone"]],
                                resampling=True)

# Rely on build_dataloader's defaults: batch_size=4, num_workers=8, no shuffling.
train_loader = build_dataloader(train_dataset)
# len() of a DataLoader is the number of batches, not the number of samples.
print(f"Number of batches: {len(train_loader)}")

Number of batches: 303


In [None]:
#| hide
# Run nbdev's export step so the cells marked for export are written to the
# library module (presumably the one named by `default_exp` at the top — confirm).
import nbdev; nbdev.nbdev_export()