
**Import necessary libraries**

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# **DataLoader**

**DataLoader**: A DataLoader object

* partitions a dataset into batches
* is an iterable that consists of these batches

```
dataloader = DataLoader(dataset, batch_size = ..., shuffle = ..., drop_last = ..., collate_fn = ...)
```

* ```dataset```: A ```Dataset``` object. In this chapter, each sample is a dictionary:

```
dataset[index] = {key_0: val_0[index], ..., key_{M-1}: val_{M-1}[index]}
```

* ```batch_size```: Number of samples in each batch

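* ```shuffle```: Whether the dataset indices are randomly reshuffled at every epoch before they are split into batches. If ```False```, samples are visited in index order.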

* ```drop_last```: Whether to drop the last batch if it contains fewer than ```batch_size``` samples
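The effect of these arguments can be checked on a tiny example. The sketch below is only an illustration: `toy_dataset` is a hypothetical 10-sample `TensorDataset` (tuple samples rather than the dictionary format used in this chapter), chosen so that the batch sizes are easy to count.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 samples, each a single number 0..9
toy_dataset = TensorDataset(torch.arange(10))

# batch_size = 4 gives batches of sizes 4, 4, 2 when the last batch is kept
loader_keep = DataLoader(toy_dataset, batch_size = 4, shuffle = False, drop_last = False)
# drop_last = True discards the incomplete batch of size 2
loader_drop = DataLoader(toy_dataset, batch_size = 4, shuffle = False, drop_last = True)

sizes_keep = [batch[0].shape[0] for batch in loader_keep]
sizes_drop = [batch[0].shape[0] for batch in loader_drop]
print(sizes_keep)  # [4, 4, 2]
print(sizes_drop)  # [4, 4]

# shuffle = True still visits every index exactly once per epoch, in a random order
loader_shuffled = DataLoader(toy_dataset, batch_size = 4, shuffle = True)
seen = torch.cat([batch[0] for batch in loader_shuffled])
print(sorted(seen.tolist()))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```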

Each item in **dataloader**:

```
batch = {key_0: batch_val_0, ..., key_{M-1}: batch_val_{M-1}}
```

* Let $B$ be the set of indices of samples that are in the batch. For each ```key_m```,

```
batch_val_m = torch.stack([val_m[index] for index in B], dim = 0)
```
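As a quick check of the stacking rule, the following sketch stacks four hypothetical per-sample tensors of shape `(2, 3)`. The values are made up for illustration; the point is that `torch.stack` adds a new leading batch dimension.

```python
import torch

# Hypothetical per-sample values: 4 samples, each of shape (2, 3)
values = [torch.full((2, 3), float(i)) for i in range(4)]

# Stacking along a new leading dimension gives shape (4, 2, 3)
batch_val = torch.stack(values, dim = 0)
print(batch_val.shape)     # torch.Size([4, 2, 3])
print(batch_val[2, 0, 0])  # tensor(2.)
```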

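**Collate function** ```collate_fn```: Merges the samples of a batch into a single batch. For each key ```key_m```, it must make ```val_m[index]``` **$\color{red}{\text{the same shape for all indices}}$** in the batch, so that the values can be stacked.

* If ```val_m[index]``` already has the same shape for all indices, then DataLoader's default ```collate_fn``` is sufficient (pass ```collate_fn = None```).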

* Otherwise, create a customized ```collate_fn```.
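To see why a customized function is needed, the sketch below feeds two hypothetical samples of different lengths to a DataLoader that uses the default collation. The list `ragged_samples` is made up for this illustration; a plain list of dictionaries is used here as a map-style dataset.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical samples whose 'x' values have different lengths (3 and 5)
ragged_samples = [{'x': torch.zeros(3)}, {'x': torch.zeros(5)}]

# A plain list of dictionaries serves as a map-style dataset here
loader = DataLoader(ragged_samples, batch_size = 2)

try:
    next(iter(loader))
    default_collate_failed = False
except RuntimeError as error:
    # The default collation stacks the tensors, which requires equal sizes
    default_collate_failed = True
    print("default collate failed:", error)

print(default_collate_failed)  # True
```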



**Example: Build a dataloader object with the image dataset that we built in Chapter "PyTorch - Dataset"**

* Raw dataset
    * Image data in numpy with shape ```(num_samples, height, width, num_channels)```
    * Labels

*  Dataset to construct
    * Image data in tensor with shape ```(num_samples, num_channels, height, width)```, normalized to the range $[0, 1]$
    * Labels

In [None]:
# Build a raw dataset

num_samples = 100
size = (128, 256, 3) # shape of each sample (height, width, num_channels)
num_classes = 5

images = np.random.randint(low = 0, high = 256, size = (num_samples, *size), dtype = np.uint8)
labels = np.random.randint(low = 0, high = num_classes, size = (num_samples,), dtype = np.int64)

In [None]:
class MyDataset(Dataset):
    def __init__(self, images_numpy, labels):
        self.images_tensor = torch.from_numpy(images_numpy).to(torch.float32) / 255
        self.images_tensor = self.images_tensor.permute(0, 3, 1, 2)
        self.labels = torch.from_numpy(labels).to(torch.int64)

    def __len__(self):
        return self.images_tensor.shape[0]

    def __getitem__(self, index):
        image = self.images_tensor[index]
        label = self.labels[index]
        return {'image': image, 'label': label, 'sample_index': torch.tensor(index)}

In [None]:
# Construct dataset
dataset = MyDataset(images, labels)
print(type(dataset[0]))
print(dataset[0].keys())
print(dataset[0]['image'].shape)
print(dataset[0]['label'].shape)
print(dataset[0]['sample_index'])

<class 'dict'>
dict_keys(['image', 'label', 'sample_index'])
torch.Size([3, 128, 256])
torch.Size([])
tensor(0)


**Build dataloader with the default ```collate_fn```**

In [None]:
batch_size = 4

dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = False, collate_fn = None)

batch = next(iter(dataloader))
print(type(batch))
print(batch['image'].shape)
print(batch['label'].shape)
print(batch['sample_index'])

<class 'dict'>
torch.Size([4, 3, 128, 256])
torch.Size([4])
tensor([45, 99, 83, 96])


# **Collation functions**

**Build dataloader with a customized ```collate_fn```**

* Input argument that ```dataloader``` sends to a collate function:

    * A list whose length is the batch size
    * Each item in the list is ```dataset[index]```
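One way to inspect this input is a collate function that returns its argument unchanged. The sketch below is a minimal illustration; `samples` and `passthrough_collate_fn` are hypothetical names that do not appear elsewhere in this chapter.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical dataset of 6 dictionary samples
samples = [{'value': torch.tensor(i)} for i in range(6)]

def passthrough_collate_fn(batch_before_collation):
    # Return the input unchanged so that we can inspect it
    return batch_before_collation

loader = DataLoader(samples, batch_size = 3, shuffle = False, collate_fn = passthrough_collate_fn)
first = next(iter(loader))

print(type(first))  # <class 'list'>
print(len(first))   # 3
print(first[0])     # {'value': tensor(0)}
```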

In [None]:
def my_collate_fn(batch_before_collation):
    '''
    # Let us understand the input argument that dataloader sends to this collate function
    print("batch before collation:", batch_before_collation)
    print("type of batch before collation:", type(batch_before_collation))
    print("length of batch before collation:", len(batch_before_collation))
    print("One item in batch_before_collation:", batch_before_collation[0])
    print("type of one item in batch_before_collation:", type(batch_before_collation[0]))

    index_in_dataset = batch_before_collation[0]['sample_index'] # Find index in dataset
    print([torch.all(batch_before_collation[0][key] == dataset[index_in_dataset][key]) for key in batch_before_collation[0].keys()])
    '''

    batch_size = len(batch_before_collation) # Compute batch size
    dictionary = {key: []  for key in batch_before_collation[0].keys() } # Initialize dictionary

    for i in range(batch_size):
        for key in dictionary.keys():
            dictionary[key].append(batch_before_collation[i][key])

    # Construct a collated batch
    batch = dict()
    for key in dictionary.keys():
        batch[key] = torch.stack(dictionary[key], dim = 0)
    return batch

In [None]:
batch_size = 4

dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = False, collate_fn = my_collate_fn)

batch = next(iter(dataloader))
print(type(batch))
print(batch['image'].shape)
print(batch['label'].shape)
print(batch['sample_index'].shape)

<class 'dict'>
torch.Size([4, 3, 128, 256])
torch.Size([4])
torch.Size([4])


# **Exercise: Build a dataloader for the Next Sentence Prediction (NSP) task, using the dataset built in Chapter "PyTorch - Dataset"**

**NSP dataset:**

* Each sample is a dictionary with keys ```input``` and ```label```
* An input is the concatenation of two sequences of token IDs
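The layout of one input can be sketched with two short, made-up sentences. As in the code of this chapter, each sentence ends with token ID 2 and the input starts with token ID 1; reading these as "end of sentence" and "start of input" markers is an assumption, since the chapter does not name them.

```python
import torch

# Hypothetical sentences; each ends with token ID 2, as in this chapter's raw dataset
sen1 = torch.tensor([5, 9, 2])
sen2 = torch.tensor([7, 3, 8, 2])

# Prepend token ID 1 and concatenate the two sentences into one input
nsp_input = torch.cat([torch.tensor([1]), sen1, sen2], dim = 0)
print(nsp_input)  # tensor([1, 5, 9, 2, 7, 3, 8, 2])
```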

**Build a raw dataset**

In [None]:
num_samples = 10
vocab_size = 30 # token IDs: 0, 1, ..., vocab_size - 1
length_LB = 5
length_UB = 12

raw_dataset = []
for n in range(num_samples):
    # Generate sentence 1
    length_sen1 = torch.randint(low = length_LB, high = length_UB, size = ())
    sen1 = torch.randint(low = 3, high = vocab_size, size = (length_sen1,))
    sen1 = torch.cat([sen1, torch.tensor([2])], dim = 0)

    # Generate sentence 2
    length_sen2 = torch.randint(low = length_LB, high = length_UB, size = ())
    sen2 = torch.randint(low = 3, high = vocab_size, size = (length_sen2,))
    sen2 = torch.cat([sen2, torch.tensor([2])], dim = 0)

    # Store the two sentences as a pair
    raw_dataset.append([sen1, sen2])

**Define Dataset class**

In [None]:
class MyNSPDataset(Dataset):
    def __init__(self, raw_dataset, num_noisy_samples):
        num_samples_raw = len(raw_dataset) # Number of samples in the raw dataset
        self.num_samples = num_samples_raw * (1 + num_noisy_samples) # Number of samples in the dataset that we are constructing
        self.inputs = [] # List of all paired sentences
        self.labels = [] # List of labels of all paired sentences

        # Extract sentence 1s and 2s, respectively, from the raw dataset
        sen1_raw_list = []
        sen2_raw_list = []
        for n in range(num_samples_raw):
            sen1_raw_list.append(raw_dataset[n][0])
            sen2_raw_list.append(raw_dataset[n][1])

        # Compute probability distribution of sentence 2s that are randomly drawn
        prob_sen2_indices = 1/num_samples_raw * torch.ones(num_samples_raw)

        for n in range(num_samples_raw):
            # Add the nth ground-truth pair of sentences to the dataset
            self.inputs.append(torch.cat([torch.tensor([1]), sen1_raw_list[n], sen2_raw_list[n]]))
            self.labels.append(torch.tensor(True))

            # Randomly generate sentence 2 indices
            sen2_indices = torch.multinomial(input = prob_sen2_indices, num_samples = num_noisy_samples, replacement = True)

            # Add the nth sentence 1 and each randomly selected sentence 2 to the dataset
            for m in range(num_noisy_samples):
                self.inputs.append(torch.cat([torch.tensor([1]), sen1_raw_list[n], sen2_raw_list[sen2_indices[m]]]))
                self.labels.append(sen2_indices[m] == n)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        input_ids = self.inputs[index] # Avoid shadowing the built-in input()
        label = self.labels[index]
        return {'input': input_ids, 'label': label}


**Construct dataset**

In [None]:
dataset_NSP = MyNSPDataset(raw_dataset, 2)

**Define a customized collate function**

* Inputs from different samples have different lengths
* Shorter inputs are padded with token ID 0
* After padding, all inputs in a batch have the same length
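The padding step can also be sketched with PyTorch's built-in `torch.nn.utils.rnn.pad_sequence`. This is an alternative to the manual padding used in this chapter, not the method it follows; the three token-ID sequences below are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical token-ID sequences of lengths 3, 5, and 2
sequences = [
    torch.tensor([1, 7, 2]),
    torch.tensor([1, 4, 9, 6, 2]),
    torch.tensor([1, 2]),
]

# Pad on the right with token ID 0 up to the longest sequence in the list
padded = pad_sequence(sequences, batch_first = True, padding_value = 0)
print(padded.shape)  # torch.Size([3, 5])
print(padded)
# tensor([[1, 7, 2, 0, 0],
#         [1, 4, 9, 6, 2],
#         [1, 2, 0, 0, 0]])
```

With `batch_first = True`, the result has shape `(batch_size, max_length)`, which matches the shape produced by the manual padding.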

In [None]:
def collate_NSP_fn(batch_before_collation):
    batch_size = len(batch_before_collation)
    dictionary = {key: []  for key in batch_before_collation[0].keys() } # Initialize dictionary
    for i in range(batch_size):
        for key in dictionary.keys():
            dictionary[key].append(batch_before_collation[i][key])

    length_inputs = [dictionary['input'][i].shape[0] for i in range(batch_size)]
    max_length = max(length_inputs)

    # Padding
    for i in range(batch_size):
        padding = torch.zeros(max_length - length_inputs[i], dtype = torch.int64)
        dictionary['input'][i] = torch.cat([dictionary['input'][i], padding], dim = 0)

    # Construct a collated batch
    batch = dict()
    for key in dictionary.keys():
        batch[key] = torch.stack(dictionary[key], dim = 0)
    return batch

**Define a dataloader for the NSP task**

In [None]:
batch_size = 4

dataloader_NSP = DataLoader(dataset_NSP, batch_size = batch_size, shuffle = True, collate_fn = collate_NSP_fn)

batch = next(iter(dataloader_NSP))
print(batch['input'].shape)
print(batch['label'].shape)

torch.Size([4, 24])
torch.Size([4])


**Copyright  Beaver-Edge AI Institute. All Rights Reserved. No part of this document may be copied or reproduced without the written permission of Beaver-Edge AI Institute.**