# Example of loading a dataset using the CustomImageDataset class

## Library imports

In [None]:
import torchvision.transforms as T
from CustomDataset import CustomImageDataset
import torch.utils.data as data

## Parameters

In [None]:
# Dataset directory (passed to CustomImageDataset together with the label file)
dataset_dir = '../dataset/clean_road/images'

# Label directory and the path of the label file inside it
label_dir = '../dataset/clean_road'
label_file = f'{label_dir}/metadata.csv'

# Transformation applied by the dataset; here only a conversion to tensor.
# Alternative kept for reference: transform = T.Resize((256, 256))
transform = T.Compose([T.ToTensor()])

# Extra settings forwarded to the DataLoader
batch_size = 4
num_workers = 2

## Create dataset and load with DataLoader

In [None]:
# Build the dataset from the label file, the image directory and the transform
dataset = CustomImageDataset(
    label_file,
    dataset_dir,
    transform,
    use_cv2=True,
)
print(len(dataset))

# Wrap the dataset in a DataLoader; printing its length reports the number of batches
data_batch = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=False,
)
print(len(data_batch))

## Inspect dataset

In [None]:
# Print the contents of the first n_batches batches produced by the loader
n_batches = 3

# Pairing the loader with range(n_batches) stops the iteration after n_batches batches
for _, (imgs, labels) in zip(range(n_batches), data_batch):
    print(imgs["name"])
    # First dimension of the "file" entry of the batch
    print(imgs["file"].shape[0])
    print(labels.shape)
    print(labels)

## Augmentation

In [None]:
# Augmentation pipeline: convert to a PIL image, apply AutoAugment, convert back to a tensor.
# Other policies that can be swapped in for T.AutoAugment():
#   T.RandAugment(), T.AugMix(), T.TrivialAugmentWide()
augmentation_steps = [
    T.ToPILImage(),
    T.AutoAugment(),
    T.ToTensor(),
]
augmentation_transform = T.Compose(augmentation_steps)

# Second dataset over the same labels and images, using the augmentation pipeline
transformed_dataset = CustomImageDataset(
    label_file,
    dataset_dir,
    augmentation_transform,
    use_cv2=True,
)

# The augmented dataset is the original dataset followed by the transformed one
dataset_list = [dataset, transformed_dataset]
augmented_dataset = data.ConcatDataset(dataset_list)

# Record, for each source dataset, its position in dataset_list and the transform it uses
augmentation_dictionary = {
    "original": {"index": 0, "transform": transform},
    "AutoAugment": {"index": 1, "transform": augmentation_transform},
}


# Number of samples in the concatenated (augmented) dataset
print(len(augmented_dataset))

# Load the augmented dataset using DataLoader
augmented_batch = data.DataLoader(
    augmented_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=False,
)
# Report the batch count of the augmented loader itself; the previous code
# printed len(data_batch), i.e. the batch count of the non-augmented loader.
print(len(augmented_batch))
