### Goal
* how to decouple the training code from the dataset code
* how to process large amounts of data
* understand PyTorch particularities
* end-to-end example

### Imports

In [None]:
import torch as th
import torchvision.datasets
from torchvision.transforms import ToTensor
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pickle
import os
from PIL import Image

## Data loading

Two data primitives to achieve decoupling: 
* <b>torch.utils.data.Dataset</b> as a data store for (sample, label) pairs
* <b>torch.utils.data.DataLoader</b> wraps an iterable around the data store for easy (batched) access to the samples

Different built-in datasets that subclass <b>torch.utils.data.Dataset</b>:
* <b>Image datasets</b>: classification, object detection, segmentation, optical flow, stereo matching, 3D reconstruction, captioning, video classification, etc.
* <b>Text datasets</b>: text classification, language modeling, machine translation, tagging, question answering, etc.
* <b>Audio datasets</b>: speaker verification, music genre recognition, emotion recognition, source separation, etc.

More on TorchData project (Beta): https://pytorch.org/data/beta/index.html

In [None]:
dataset = torchvision.datasets.CIFAR10(root='./cifar10',
                                       train=True,  # False for test set
                                       download=True,
                                       transform=ToTensor())  # PIL image -> [C, H, W] float tensor in [0, 1]; more on transforms later

Visualize a sample:

In [None]:
num_samples = len(dataset)
idx = np.random.randint(0, num_samples)
sample, label = dataset[idx] # sample, label
plt.imshow(sample.numpy().transpose(1, 2, 0))
plt.title(dataset.classes[label])

<b>Conclusions:</b> <br>
A PyTorch dataset object behaves like a Python sequence (e.g., a list):
* We can index elements
* We can iterate over it one element at a time

It acts as a data store:
* Each element is a (sample, label) pair
* We need a separate mechanism, the DataLoader, to pass the samples (in batches) to the training loop
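Why can we iterate over a dataset that only defines indexing? A minimal pure-Python sketch: a class with only \_\_len__ and \_\_getitem__ (as torch.utils.data.Dataset subclasses typically have) can be indexed, and a for-loop falls back to calling \_\_getitem__ with 0, 1, 2, ... until IndexError is raised. SquaresStore is a hypothetical toy data store, not part of PyTorch.

```python
class SquaresStore:
    """Hypothetical toy data store: item idx is the pair (idx**2, idx % 2)."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)  # tells the fallback iteration to stop
        return idx * idx, idx % 2  # (sample, label)


store = SquaresStore(4)
print(store[2])     # (4, 0)
print(list(store))  # [(0, 0), (1, 1), (4, 0), (9, 1)]
```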

### Iterator design pattern

Q: How is a data structure traversed?

![data_structure_navigation](../Presentations/assets_data_handling/data_structure_navigation.png)

An iterator yields one item at a time without exposing the underlying data structure (dict, list, set, file, tuple and generator objects are already iterable in Python). It doesn't matter whether the data structure is linear or not (e.g., a tree) => <b>no need to understand the internal representation of the data structure.</b>

The pattern involves 3 objects:
1. container: aggregate, collection, object whose content is iterable
2. item: elements of the container
3. iterator: sequential access to items <br>

It powers for-loops and list comprehensions.
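A minimal sketch of what a for-loop does with these objects: it asks the container for an iterator with iter(), calls next() until StopIteration is raised, and stops. The variable names are illustrative only.

```python
numbers = [10, 20, 30]

# Roughly what `for x in numbers: collected.append(x)` does under the hood
collected = []
iterator = iter(numbers)      # calls numbers.__iter__() -> a list iterator
while True:
    try:
        x = next(iterator)    # calls iterator.__next__() -> next item
    except StopIteration:     # raised once the items are exhausted
        break
    collected.append(x)

print(collected)                   # [10, 20, 30]
print(iter(iterator) is iterator)  # True: an iterator's __iter__ returns itself
```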

In [None]:
class CustomIterator:
    """Iterator: holds the traversal state (current index), separate from the container."""
    def __init__(self, custom_list):
        self._custom_list = custom_list
        self._idx = 0

    def __iter__(self):  # iterators must also be iterable: return themselves
        return self

    def __next__(self):
        if self._idx < len(self._custom_list):
            value = self._custom_list[self._idx]
            self._idx += 1
            return value
        else:
            raise StopIteration

class CustomList:
    """Container: stores the items and hands out a fresh iterator on request."""
    def __init__(self, n=10):
        self.elems = range(0, n)
        self.len = n

    def __iter__(self):  # returns a new Iterator object on each call
        return CustomIterator(self.elems)

    def __len__(self):
        return self.len

#### Why not combine?

In [None]:
class CustomListIterator:
    def __init__(self, n=10):
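        self._custom_list = list(range(0, n))  # a list (unlike range) supports __setitem__
        self._idx = 0  # the traversal state now lives in the container itself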
        self.len = n
        
    def __len__(self):
        return self.len

    def __iter__(self):
        return self

    def __setitem__(self, idx, value):
        self._custom_list[idx] = value

    def __getitem__(self, idx):
        return self._custom_list[idx]

    def __next__(self):
        if self._idx < len(self._custom_list):
            value = self._custom_list[self._idx]
            self._idx += 1
            return value
        else:
            raise StopIteration

In [None]:
custom_list_iter = CustomListIterator(10)
for x in custom_list_iter:
    print(x+1, "/", len(custom_list_iter))
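One answer to "why not combine?", sketched with two hypothetical toy classes: when the container is its own iterator, the traversal index is never reset, so a second pass (or a nested loop over the same object) sees no items. Returning a fresh iterator from \_\_iter__, as CustomList does, avoids this.

```python
class CombinedList:
    """Container and iterator in one object: the index lives on the container."""
    def __init__(self, n=3):
        self._items = list(range(n))
        self._idx = 0

    def __iter__(self):
        return self  # always the same object, so the index is never reset

    def __next__(self):
        if self._idx < len(self._items):
            value = self._items[self._idx]
            self._idx += 1
            return value
        raise StopIteration


class SeparateList:
    """Container only: every call to __iter__ hands out a fresh iterator."""
    def __init__(self, n=3):
        self._items = list(range(n))

    def __iter__(self):
        return iter(self._items)


combined = CombinedList()
first_pass, second_pass = list(combined), list(combined)
print(first_pass, second_pass)  # [0, 1, 2] []

separate = SeparateList()
print(list(separate), list(separate))  # [0, 1, 2] [0, 1, 2]

# Nested loops share the single index of the combined object
c, s = CombinedList(2), SeparateList(2)
nested_combined = [(a, b) for a in c for b in c]
nested_separate = [(a, b) for a in s for b in s]
print(nested_combined)  # [(0, 1)]
print(nested_separate)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```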

### Generators

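A generator is a function that returns an iterator producing a (potentially very large) sequence of values lazily, one value at a time, so the whole sequence never has to fit into memory at once.

Iterables such as lists must be stored in memory in full, but they give you more flexibility over the state (e.g., indexing, iterating several times).

A generator preserves its local state between two consecutive yield calls.

A function containing yield => calling it returns a generator object.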

In [None]:
def my_generator(n):
    # n can be very large
    cnt = 0
    while cnt < n:
        yield cnt
        cnt += 1
        
for value in my_generator(2):
    print(value)
    
generator = my_generator(2)
print(next(generator))
print(next(generator))
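A minimal sketch of why laziness matters for large data: a list comprehension materializes every value, while a generator expression computes values on demand. The hypothetical helper lazy_batches also shows how a generator can group any stream into batches without ever holding the whole stream, which is the idea the DataLoader builds on.

```python
import itertools
import sys


def lazy_batches(iterable, batch_size):
    """Hypothetical helper: lazily yield lists of up to batch_size items from any iterable."""
    iterator = iter(iterable)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if not batch:
            return  # ends the generator (the caller sees StopIteration)
        yield batch


squares_list = [i * i for i in range(1_000_000)]  # all values are materialized in memory
squares_gen = (i * i for i in range(1_000_000))   # values are computed on demand
print(sys.getsizeof(squares_list) > sys.getsizeof(squares_gen))  # True

batches = list(lazy_batches(range(10), 4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```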

### Datasets

When subclassing torch.utils.data.Dataset, one should implement:
* <b>\_\_len__</b> so len(dataset) returns the total number of elements 
* <b>\_\_getitem__</b> to support indexing over elements

In [None]:
class CustomImageDatasetV1(th.utils.data.Dataset):
    def __init__(self, dataset):
        # Simplification: a real large-scale dataset would store only file paths here
        # and read each file lazily in __getitem__; CIFAR10 is already in memory
        self.images = dataset.data
        self.labels = dataset.targets

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        return image, label

PyTorch supports 2 types of datasets:
* map-style (subclasses Dataset & implements \_\_getitem__() and \_\_len__())
* iterable-style (subclasses IterableDataset & implements \_\_iter__(), i.e., there is no notion of key or index): suitable when random reads are expensive or impossible, or when the batch size depends on the fetched data (e.g., data comes from a stream)

Also check specializations of these 2 categories: https://pytorch.org/docs/stable/data.html#torch.utils.data.StackDataset 
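A minimal iterable-style sketch, assuming a hypothetical in-process "stream" (CounterStream) in place of a real socket or log file: the dataset only implements \_\_iter__, and the DataLoader batches whatever it yields. Note that shuffle=True is not accepted for an IterableDataset, since there are no indices to shuffle.

```python
import torch as th


class CounterStream(th.utils.data.IterableDataset):
    """Hypothetical stream: yields (sample, label) pairs one by one, with no indexing."""
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        for i in range(self.n):  # a real stream could read from a socket, a log file, ...
            yield th.tensor([float(i)]), i % 2


loader = th.utils.data.DataLoader(CounterStream(5), batch_size=2)
batches = [(tuple(x.shape), y.tolist()) for x, y in loader]
print(batches)  # [((2, 1), [0, 1]), ((2, 1), [0, 1]), ((1, 1), [0])]
```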

When dealing with large datasets, a good practice is:
1. Create a dictionary called partition where you gather:
    * in partition['train'] a list of training IDs (where an ID can be the index of an image, the path of a file, etc.)
    * in partition['validation'] a list of validation IDs
2. Create a dictionary called labels where, for each sample ID, the associated label is given by labels[ID]

The Dataset then keeps only the IDs and labels in memory and loads each sample from disk in \_\_getitem__.

In [None]:
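class Dataset(th.utils.data.Dataset):
    """Keeps only the sample IDs and labels in memory; samples are loaded on demand."""
    def __init__(self, list_IDs, labels):
        self.labels = labels
        self.list_IDs = list_IDs

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        ID = self.list_IDs[index]

        # load the sample from disk only when it is requested
        X = th.load(os.path.join('data', ID + '.pt'))
        y = self.labels[ID]

        return X, y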
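A self-contained sketch of the partition/labels recipe, assuming toy tensors saved as .pt files in a temporary directory. LazyTensorDataset is a hypothetical variant of the Dataset class above that takes the data directory as a parameter instead of hard-coding 'data/'.

```python
import os
import tempfile

import torch as th


class LazyTensorDataset(th.utils.data.Dataset):
    """Hypothetical variant of Dataset with a configurable data_dir."""
    def __init__(self, list_IDs, labels, data_dir):
        self.list_IDs = list_IDs
        self.labels = labels
        self.data_dir = data_dir

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        ID = self.list_IDs[index]
        X = th.load(os.path.join(self.data_dir, ID + '.pt'))  # read only when requested
        y = self.labels[ID]
        return X, y


# Toy data: six small tensors, one file per sample ID
data_dir = tempfile.mkdtemp()
ids = [f"id-{i}" for i in range(6)]
for i, ID in enumerate(ids):
    th.save(th.full((3,), float(i)), os.path.join(data_dir, ID + '.pt'))

partition = {'train': ids[:4], 'validation': ids[4:]}
labels = {ID: i % 2 for i, ID in enumerate(ids)}

training_set = LazyTensorDataset(partition['train'], labels, data_dir)
validation_set = LazyTensorDataset(partition['validation'], labels, data_dir)
X, y = validation_set[1]  # loads id-5.pt from disk
print(len(training_set), len(validation_set), X.tolist(), y)  # 4 2 [5.0, 5.0, 5.0] 1
```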

### Data transformations

Goal: manipulate data to make it suitable for training. <br>
Use cases:
* Preprocessing  <br>
* Augmentation <br>

Transforms

In [None]:
class Rotate90Left(object):
    def __call__(self, image):
        return np.rot90(image)

<b>\_\_getitem__</b> can optionally support preprocessing/augmentation functionality => Transforms.

In [None]:
class CustomImageDatasetV2(th.utils.data.Dataset):
    def __init__(self, dataset, transforms=None, target_transforms=None):
        # Simplification: a real large-scale dataset would store only file paths here
        # and read each file lazily in __getitem__; CIFAR10 is already in memory
        self.images = dataset.data
        self.labels = dataset.targets
        self.transforms = transforms
        self.target_transforms = target_transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transforms:
            image = self.transforms(image)
        if self.target_transforms:
            label = self.target_transforms(label)
        return image, label

In [None]:
custom_dataset = CustomImageDatasetV2(dataset, transforms=None)
plt.imshow(custom_dataset[0][0])
plt.show()
custom_dataset = CustomImageDatasetV2(dataset, transforms=Rotate90Left())
plt.imshow(custom_dataset[0][0])

Composing Transforms

In [None]:
from torchvision.transforms import v2
# v2 transforms are faster and also accept arbitrary input structures like dicts, lists, tuples
# Also check: https://albumentations.ai/

transforms = v2.Compose([
    # preprocessing
    v2.ToImage(),
    v2.ToDtype(th.float32, scale=True),
    # augmentation
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
])

In [None]:
custom_dataset = CustomImageDatasetV2(dataset, transforms=None)
plt.imshow(custom_dataset[0][0])
plt.show()
custom_dataset = CustomImageDatasetV2(dataset, transforms=transforms)
plt.imshow(custom_dataset[0][0].permute(1, 2, 0))
plt.show()

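Augmentation before training vs. during training:

Before training (offline) = fixed amount of augmentation: augmented copies are generated once and stored<br>
During training (online) = variable amount of augmentation: random transforms are re-applied every time a sample is loaded, so each epoch sees new variants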
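A toy sketch of the difference, assuming a single 1x1x4 "image" and a hand-written random horizontal flip (RandomFlipDataset is hypothetical, standing in for v2.RandomHorizontalFlip): offline augmentation stores a fixed set of variants, while online augmentation draws a new random flip each time the same index is read.

```python
import torch as th


class RandomFlipDataset(th.utils.data.Dataset):
    """Online augmentation: a random horizontal flip is drawn on every access."""
    def __init__(self, images, p=0.5):
        self.images = images
        self.p = p

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if th.rand(1).item() < self.p:
            image = th.flip(image, dims=[-1])  # flip along the width axis
        return image


th.manual_seed(0)
images = [th.arange(4.0).reshape(1, 1, 4)]  # one toy "image": [0, 1, 2, 3]

# Before training: the augmented copy is computed once -> always the same 2 samples
offline = images + [th.flip(img, dims=[-1]) for img in images]

# During training: re-reading index 0 (e.g., once per epoch) gives varying results
online = RandomFlipDataset(images)
seen = {tuple(online[0].flatten().tolist()) for _ in range(50)}
print(len(offline), sorted(seen))
```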

### Dataloaders

A Dataset retrieves samples one at a time. => iterable over samples.

During training, a DataLoader will pass samples in minibatches. => iterable over batches of samples.

Parameters:
* dataset: an object implementing torch.utils.data.Dataset
* batch_size: number of samples per batch
* shuffle: reshuffle the data at every epoch to reduce overfitting
* num_workers: number of worker processes (Python multiprocessing) used to speed up data retrieval


In [None]:
data_loader = th.utils.data.DataLoader(dataset,
                                       batch_size=4,
                                       shuffle=True,
                                       num_workers=4)

#### Inner workings of DataLoader

Multiple components:
* Generator: yields batches of data
* Collator (the collate_fn argument): combines a list of samples into a batch
* Sampler: selects the indices of the samples to be batched; a sequential or a random (shuffled) sampler is constructed depending on the shuffle argument

Example for CIFAR10:
Each sample is an (image, label) tuple.
The (batch) sampler produces one list of indices per batch.
For each list, the corresponding samples are fetched from the dataset and combined by collate_fn: a list of batch_size (image, label) tuples becomes batched images (shape [batch_size, 3, 32, 32]) and batched labels (shape [batch_size]), returned together as a list.
The DataLoader then yields these batches.

1. Sampler

In [None]:
indices = list(th.utils.data.BatchSampler(th.utils.data.RandomSampler(range(10)), batch_size=3, drop_last=False))
print(indices)

2. (Default) Collator
* prepends the batch dimension
* automatically converts Numpy and Python data into PyTorch tensors
* preserves the data structure (e.g., if each sample is a dict, the output will be a dict of batched values; the same holds for lists, while tuples come back as lists)

In [None]:
th.utils.data.default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
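The default collator stacks tensors, so it fails when samples have different lengths. A minimal sketch of a custom collate_fn, assuming toy variable-length 1-D sequences: pad_collate is a hypothetical helper that pads every sequence to the longest one in the batch with torch.nn.utils.rnn.pad_sequence and also returns the original lengths.

```python
import torch as th
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Hypothetical collate_fn: pads variable-length 1-D samples to the longest one."""
    sequences, labels = zip(*batch)  # list of (seq, label) -> two tuples
    lengths = th.tensor([len(s) for s in sequences])
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0.0)
    return padded, lengths, th.tensor(labels)


samples = [(th.tensor([1.0, 2.0, 3.0]), 0), (th.tensor([4.0]), 1), (th.tensor([5.0, 6.0]), 0)]
loader = th.utils.data.DataLoader(samples, batch_size=3, collate_fn=pad_collate)
padded, lengths, labels = next(iter(loader))
print(padded.tolist())   # [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0], [5.0, 6.0, 0.0]]
print(lengths.tolist())  # [3, 1, 2]
```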

3. Generator

Roughly, this is what the DataLoader's iterator does in single-process mode (a simplification of the actual implementation):

In [None]:
def data_loader(dataset, batch_sampler, collate_fn):
    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])

In [None]:
loader = data_loader(dataset,
                    th.utils.data.BatchSampler(th.utils.data.RandomSampler(range(len(dataset))), batch_size=2048, drop_last=False),
                    th.utils.data.default_collate)

for batch in loader:
    img, labels = batch
    print(img.shape, labels.shape)
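A quick sanity check of the simplified generator on toy data (assumed: seven 2-element tensors with integer labels): with a sequential batch sampler, it should produce the same batches as th.utils.data.DataLoader with shuffle=False, which also uses a SequentialSampler under the hood.

```python
import torch as th


def simple_loader(dataset, batch_sampler, collate_fn):
    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])


toy = [(th.tensor([float(i), float(i)]), i) for i in range(7)]
sampler = th.utils.data.BatchSampler(th.utils.data.SequentialSampler(toy), batch_size=3, drop_last=False)

ours = list(simple_loader(toy, sampler, th.utils.data.default_collate))
theirs = list(th.utils.data.DataLoader(toy, batch_size=3, shuffle=False))

# Both produce the same images and labels batch by batch
same = all(th.equal(x1, x2) and th.equal(y1, y2) for (x1, y1), (x2, y2) in zip(ours, theirs))
print(same, len(ours), [y.tolist() for _, y in ours])  # True 3 [[0, 1, 2], [3, 4, 5], [6]]
```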

#### Multiprocessing in DataLoaders

Because of the GIL, Python threads cannot execute Python code in parallel, so in a single process the computation/training code is blocked by the data-loading code --> setting num_workers > 0 switches to multi-process data loading.

Two strategies:

1. Single-process data loading (default): data fetching is done in the same process in which the DataLoader is initialized (e.g., the main process). It may be preferred when the dataset is small (e.g., fits entirely in memory).
2. Multi-process data loading: num_workers worker processes are created, and each of them receives the dataset, collate_fn and worker_init_fn => fetching, transforms & collation run inside the workers. Only the workers retrieve data; the main process does not. The batch sampler generates the batch indices in the main process and sends them to the workers, while the main process waits until the assigned worker returns the batch --> these exchanges between processes can cause a high I/O load, and memory consumption grows because every worker holds its own copy of the dataset object and several batches are loaded at once.

On Windows (where worker processes are started with the slower spawn method) you might get an error: <br>
one solution is to not re-spawn the worker processes every epoch: persistent_workers=True
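With an iterable-style dataset, every worker receives its own copy of the dataset, so without extra care each worker replays the full stream and samples are duplicated. A minimal sketch, assuming a toy integer range: inside \_\_iter__, th.utils.data.get_worker_info() tells each worker its id, and the hypothetical helper shard_bounds gives it a disjoint slice. The call below uses num_workers=0 (get_worker_info() then returns None), so it runs anywhere; with num_workers > 0 each worker would read only its own slice.

```python
import math

import torch as th


def shard_bounds(start, end, worker_id, num_workers):
    """Split [start, end) into num_workers contiguous chunks; return this worker's chunk."""
    per_worker = math.ceil((end - start) / num_workers)
    lo = start + worker_id * per_worker
    hi = min(lo + per_worker, end)
    return lo, hi


class ShardedRange(th.utils.data.IterableDataset):
    """Iterable-style dataset in which each DataLoader worker reads a disjoint slice."""
    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end

    def __iter__(self):
        info = th.utils.data.get_worker_info()
        if info is None:  # single-process loading: read everything
            lo, hi = self.start, self.end
        else:             # inside a worker: read only this worker's slice
            lo, hi = shard_bounds(self.start, self.end, info.id, info.num_workers)
        return iter(range(lo, hi))


print([shard_bounds(0, 7, w, 2) for w in range(2)])  # [(0, 4), (4, 7)]

# batch_size=None disables batching; in a script you could pass e.g.
# num_workers=2, persistent_workers=True and run the loop under if __name__ == "__main__":
items = [int(v) for v in th.utils.data.DataLoader(ShardedRange(0, 7), batch_size=None, num_workers=0)]
print(items)  # [0, 1, 2, 3, 4, 5, 6]
```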

### Saving and loading

* Numpy Archives
* Pickle

In [None]:
class CustomImageDatasetV3(th.utils.data.Dataset):
    def __init__(self, dataset, transforms=None, target_transforms=None):
        # Simplification: a real large-scale dataset would store only file paths here
        # and read each file lazily in __getitem__; CIFAR10 is already in memory
        self.images = dataset.data
        self.labels = dataset.targets
        self.transforms = transforms
        self.target_transforms = target_transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transforms:
            image = self.transforms(image)
        if self.target_transforms:
            label = self.target_transforms(label)
        return image, label

    def save_data(self, path):
        if isinstance(self.images, np.ndarray):
            # np.save for a single array -> .npy
            # np.savez for multiple arrays -> .npz (the extension is appended to path)
            np.savez(path, images=self.images, labels=self.labels)
        else:
            # slower fallback for arbitrary Python objects
            # (pickle.dumps would instead return the data as a bytes object)
            with open(path, "wb") as file:
                pickle.dump({"images": self.images, "labels": self.labels}, file)

    def load_data(self, path):
        if os.path.exists(path + ".npz"):
            with np.load(path + ".npz") as npz:  # closes the archive after reading
                self.images, self.labels = npz['images'], npz['labels']
        else:
            # only unpickle trusted files: pickle can execute arbitrary code
            with open(path, "rb") as file:
                pkl = pickle.load(file)
            self.images, self.labels = pkl['images'], pkl['labels']

In [None]:
custom_dataset = CustomImageDatasetV3(dataset)
custom_dataset.save_data("ds_v3")
custom_dataset.load_data("ds_v3")
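For arrays too large for RAM, NumPy can memory-map a saved file instead of reading it. A minimal sketch, assuming random toy images saved as a single-array .npy file in a temporary directory: np.load(..., mmap_mode="r") returns a np.memmap, and slicing it reads only the requested images, which a Dataset's \_\_getitem__ could do per index.

```python
import os
import tempfile

import numpy as np

tmp_dir = tempfile.mkdtemp()
images_path = os.path.join(tmp_dir, "images.npy")

images = np.random.randint(0, 256, size=(1000, 32, 32, 3), dtype=np.uint8)
np.save(images_path, images)  # single array -> .npy

# mmap_mode="r" maps the file read-only instead of loading it into RAM
lazy_images = np.load(images_path, mmap_mode="r")
batch = np.asarray(lazy_images[10:14])  # copies only these 4 images into memory
print(type(lazy_images).__name__, lazy_images.shape, batch.shape)
# memmap (1000, 32, 32, 3) (4, 32, 32, 3)
```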

### End-to-end example: Adaptive Cruise Control (ACC)

![acc_explained.png](../Presentations/assets_data_handling/acc_explained.png)

### ACC data

In [None]:
for root, dirs, files in os.walk("../Presentations/assets_data_handling/ACC"):
    if files:
        print(root, ":", files)

In [None]:
dataset = torchvision.datasets.ImageFolder(root="../Presentations/assets_data_handling/ACC")

print(len(dataset))
print_details = lambda x: print(x[0].size, x[1])  # PIL images expose .size (width, height), not .shape

print_details(dataset[0])
print_details(dataset[78])

idx = np.random.randint(0, len(dataset))
sample = dataset[idx][0]
plt.imshow(sample)

### Hyperparameters

In [None]:
BATCH_SIZE = 32
NUM_EPOCHS = 50
device = ("cuda" if th.cuda.is_available() else "cpu")

### Data transformations

![hsv_space.png](../Presentations/assets_data_handling/hsv_space.png)

In [None]:
class HighlightRoad:
    """Brightens grayish (low-saturation, mid-brightness) pixels, which are likely road."""
    def __init__(self, intensity_factor=2):
        self.intensity_factor = intensity_factor

    def __call__(self, img):
        # img is a Pillow image; in HSV mode every channel is a uint8 in [0, 255]
        img_hsv = np.array(img.convert('HSV'))
        val_min = 0.2
        val_max = 0.8
        sat_max = 0.3
        saturation_values = img_hsv[:, :, 1] / 255.0
        value_values = img_hsv[:, :, 2] / 255.0  # channel 2 is V (brightness), not hue
        gray_mask = np.logical_and(value_values >= val_min, value_values <= val_max)
        gray_mask = np.logical_and(gray_mask, saturation_values <= sat_max)

        # multiply in float to avoid uint8 overflow, then clip to 255
        boosted = np.minimum(255.0, img_hsv[:, :, 2].astype(np.float32) * self.intensity_factor)
        img_hsv[:, :, 2] = np.where(gray_mask, boosted, img_hsv[:, :, 2]).astype(np.uint8)
        img_rgb = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
        return img_rgb

In [None]:
plt.imshow(sample)
plt.show()
sample2 = HighlightRoad()(sample)
plt.imshow(sample2)

In [None]:
def to_tensor(x):
    return th.tensor([x])

def to_one_hot(y):
    return th.nn.functional.one_hot(y, num_classes=2)

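def simple_norm(x):
    # note: ToDtype(..., scale=True) already maps values to [0, 1]; this divides by 255 again
    # (test() multiplies by 255.0 before displaying the image)
    return x / 255.0

transforms = v2.Compose([
    HighlightRoad(),
    # preprocessing
    v2.ToImage(),
    v2.ToDtype(th.float32, scale=True),
    v2.Resize((128, 128), antialias=True),
    v2.Lambda(simple_norm),
    # augmentation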
    v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 0.1))], p=0.5)
])


target_transforms = v2.Compose([
    v2.Lambda(to_tensor),
    # v2.Lambda(to_one_hot),
])

More examples here: https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py

In [None]:
dataset = torchvision.datasets.ImageFolder(root="../Presentations/assets_data_handling/ACC",
                                           transform=transforms,
                                           target_transform=target_transforms)

### Train/Val/[Test] Split

In [None]:
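# Rejection sampling: re-split until the validation set holds exactly 5 positive samples
# (i.e., 5 of each class, assuming 100 images with binary labels)
while True:
    train_set, val_set = th.utils.data.random_split(dataset, [90, 10])
    # read labels from ImageFolder.targets instead of loading and transforming every image
    val_labels = [dataset.targets[i] for i in val_set.indices]
    print(val_labels)
    if sum(val_labels) == 5:
        break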
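An alternative to re-splitting in a loop is a stratified split: shuffle the indices of each class separately and take the same fraction from each. A minimal sketch, assuming integer class targets such as ImageFolder's dataset.targets; stratified_split is a hypothetical helper (toy targets are used here), and the returned index lists would be passed to th.utils.data.Subset(dataset, ...).

```python
import torch as th


def stratified_split(targets, val_fraction, seed=0):
    """Hypothetical helper: return (train_indices, val_indices) keeping class proportions."""
    generator = th.Generator().manual_seed(seed)
    targets = th.as_tensor(targets)
    train_idx, val_idx = [], []
    for cls in targets.unique():
        cls_idx = th.nonzero(targets == cls).flatten()
        cls_idx = cls_idx[th.randperm(len(cls_idx), generator=generator)]  # shuffle within class
        n_val = round(len(cls_idx) * val_fraction)
        val_idx += cls_idx[:n_val].tolist()
        train_idx += cls_idx[n_val:].tolist()
    return train_idx, val_idx


toy_targets = [0] * 50 + [1] * 50
train_idx, val_idx = stratified_split(toy_targets, val_fraction=0.1)
print(len(train_idx), len(val_idx), sum(toy_targets[i] for i in val_idx))  # 90 10 5
```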

In [None]:
train_dataloader = th.utils.data.DataLoader(train_set,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

val_dataloader = th.utils.data.DataLoader(val_set,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

### Epoch-level utility functions

In [None]:
def train(train_dataloader):
    # relies on the global model, optimizer, criterion and device defined below
    model.train()  # BatchNorm uses batch statistics and updates its running estimates
    avg_train_loss = 0.0
    for batch in train_dataloader:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels.float())  # BCELoss expects float targets
        loss.backward()
        optimizer.step()
        avg_train_loss += loss.item()
    return avg_train_loss / len(train_dataloader)  # mean of the per-batch losses


def evaluate(eval_dataloader):
    model.eval()  # BatchNorm uses its running statistics
    avg_val_loss = 0.0
    with th.no_grad():
        for batch in eval_dataloader:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            avg_val_loss += criterion(outputs, labels.float()).item()
    return avg_val_loss / len(eval_dataloader)

### Model

In [None]:
class Net(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = th.nn.Conv2d(3, 16, 3)
        self.pool1 = th.nn.MaxPool2d(4, 4)
        self.bn1 = th.nn.BatchNorm2d(16)
        
        self.conv2 = th.nn.Conv2d(16, 32, 3)
        self.pool2 = th.nn.MaxPool2d(4, 4)
        self.bn2 = th.nn.BatchNorm2d(32)

        self.conv3 = th.nn.Conv2d(32, 64, 3)
        self.pool3 = th.nn.MaxPool2d(4, 4)
        self.bn3 = th.nn.BatchNorm2d(64)
        
        # self.gap = th.nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = th.nn.Flatten()
        
        self.fc = th.nn.Linear(64, 1)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = th.nn.functional.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = th.nn.functional.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = th.nn.functional.relu(x)
        x = self.pool3(x)

        x = self.flatten(x)  # [B, 64, 1, 1] -> [B, 64] for 128x128 inputs
        # x = self.gap(x)
        # x = x.view(x.size(0), -1)

        x = self.fc(x)
        x = th.sigmoid(x)  # probability in [0, 1], as expected by BCELoss
        return x

model = Net()
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

model.to(device)

criterion = th.nn.BCELoss()
optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
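Why does th.nn.Linear(64, 1) fit? A small arithmetic check, assuming the 128x128 inputs produced by the Resize transform: each unpadded 3x3 convolution shrinks a side by 2, and each MaxPool2d(4, 4) divides it by 4 (rounding down), so three blocks reduce 128 to 1 and flattening leaves 64 features. conv_out is a hypothetical helper implementing the standard output-size formula.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 128  # images are resized to 128x128
sizes = []
for block in range(3):
    size = conv_out(size, kernel=3)            # Conv2d(..., 3): stride 1, no padding
    size = conv_out(size, kernel=4, stride=4)  # MaxPool2d(4, 4)
    sizes.append(size)
print(sizes)  # [31, 7, 1]

flat_features = 64 * sizes[-1] * sizes[-1]
print(flat_features)  # 64, matching th.nn.Linear(64, 1)
```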

### Actual training

In [None]:
cnt = 0
for epoch in range(NUM_EPOCHS):
    train_loss = train(train_dataloader)
    val_loss = evaluate(val_dataloader)
    print(f"Epoch \t {epoch} train loss \t {train_loss} val loss \t {val_loss}")

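    # relative generalization gap: how much of the val loss exceeds the train loss
    if (val_loss - train_loss) / val_loss > 0.5:
        cnt += 1
    else:
        cnt = 0

    # stop after 3 consecutive epochs with a gap above 50%
    if cnt == 3:
        print("Early stopping due to overfitting")
        break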
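A commonly used alternative criterion, swapped in here for comparison (not the rule used above): stop when the validation loss has not improved for a given number of epochs ("patience") and keep a copy of the best weights. EarlyStopping is a hypothetical helper; the loss values below are made up for illustration.

```python
import copy


class EarlyStopping:
    """Hypothetical helper: stop when val loss has not improved by min_delta for `patience` epochs."""
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss, model=None):
        """Record this epoch's val loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
            if model is not None:
                self.best_state = copy.deepcopy(model.state_dict())  # keep the best weights
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.9, 0.7, 0.72, 0.75, 0.6]  # made-up validation losses
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)  # 3 0.7
```

In the training loop this would read `if stopper.step(val_loss, model): break`, followed by `model.load_state_dict(stopper.best_state)`.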

### Test

In [None]:
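def test():
    model.eval()  # use the BatchNorm running statistics
    idx = np.random.randint(0, len(val_set))
    img, gt = val_set[idx]
    img = img.to(device)
    with th.no_grad():
        pred = model(img[None, ...])[0]  # add a batch dimension of size 1
    # multiply by 255.0 to undo simple_norm's extra division before displaying
    plt.imshow(img.cpu().permute(1, 2, 0).numpy() * 255.0)
    plt.title(f"Pred prob {round(pred.item(), 2)}  GT {gt.item()}")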

In [None]:
test()

### Resources:
* https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel
* https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
* https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
* https://pytorch.org/docs/stable/data.html