# Example of loading a dataset using the CustomImageDataset class

## Import libraries

In [None]:
import torchvision.transforms as T
import torch.utils.data as data
import os

from CustomDataset import CustomImageDataset
from CustomDataset import augment_dataset as augment_dataset
from CustomDataset import show_augmented_dataset_info as show_augmented_dataset_info

## Parameters

In [None]:
# Constants
DATASET_ROOT = '../dataset'        # root folder; each dataset type is a sub-folder of it
DATASET_DIR_SUFFIX = 'images'      # sub-folder of a dataset type that holds the images
LABEL_FILE_NAME = 'metadata.csv'   # label file, stored alongside the images sub-folder

# Directory of the dataset
dataset_type = 'clean_road'
dataset_dir = os.path.join(DATASET_ROOT, dataset_type, DATASET_DIR_SUFFIX)

# Directory and file name of the labels
label_dir = os.path.join(DATASET_ROOT, dataset_type)
# Build the path with os.path.join (like the paths above) instead of '+'.
label_file = os.path.join(label_dir, LABEL_FILE_NAME)

# Output directory: placeholder, no output path is configured in this notebook yet.


# Base transform applied to every sample: convert the loaded image to a tensor.
# Alternative that resizes instead: transform = T.Resize((256, 256))
transform = T.Compose([T.ToTensor()])

# Additional parameters for the DataLoader
batch_size = 4
num_workers = 2  # number of worker subprocesses used for loading

## Create the dataset and load it with a DataLoader

In [None]:
# Build the dataset from the label CSV, the image directory and the base
# transform. use_cv2=True presumably selects OpenCV-based image loading
# (see CustomDataset to confirm).
dataset = CustomImageDataset(label_file, dataset_dir, transform, use_cv2=True)
print(len(dataset))  # number of samples

# Wrap the dataset in a shuffling DataLoader; the last partial batch is kept.
data_batch = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=False,
)
print(len(data_batch))  # number of batches per epoch

## Inspect dataset

In [None]:
from itertools import islice

# Show the first n_batches
n_batches = 3

# Each batch unpacks into (imgs, labels). imgs is indexed by the "name" and
# "file" keys below, so it is presumably a dict of batched fields; confirm in
# CustomImageDataset. islice yields exactly n_batches batches and stops. The
# previous enumerate/break loop pulled one extra batch from the loader
# before breaking.
for imgs, labels in islice(data_batch, n_batches):
    print(imgs["name"])
    print(imgs["file"].shape[0])  # leading dim of the "file" tensor (presumably the batch size)
    print(labels.shape)
    print(labels)

## Augmentation

In [None]:
# Augmentation pipeline. The base dataset transform ends with T.ToTensor(),
# so samples are presumably tensors here. They are converted back to PIL
# images for AutoAugment and turned into tensors again afterwards. To try a
# different policy, swap AutoAugment for one of the commented-out ones.
augmentation_transform = T.Compose([
    T.ToPILImage(),
    T.AutoAugment(),
    # T.RandAugment(),
    # T.AugMix(),
    # T.TrivialAugmentWide(),
    T.ToTensor(),
])
transform_name = "AutoAugment"

# Each entry handed to augment_dataset is a dict with "name" and "transform" keys.
transform_dict = {"name": transform_name, "transform": augmentation_transform}
augmentation_transform_list = [transform_dict]

# Build the augmented dataset. Judging by the alternative call below,
# create_dict=True additionally returns an augmentation dictionary:
#   augmented_dataset, augmentation_dictionary = augment_dataset(
#       dataset, augmentation_transform_list, create_dict=True)
augmented_dataset = augment_dataset(dataset, augmentation_transform_list, create_dict=False)

print(len(augmented_dataset))

# Load the augmented dataset with the same DataLoader settings as the base dataset.
augmented_batch = data.DataLoader(
    augmented_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=False,
)
print(len(augmented_batch))

# Summarise the augmented dataset. When create_dict=True was used, the
# dictionary can be passed as well:
#   show_augmented_dataset_info(augmented_dataset, augmentation_dictionary)
show_augmented_dataset_info(augmented_dataset)

## Split dataset

In [None]:
# Split the dataset with its own split_train_test method. Then print the
# sizes of the training subset, the test subset and the full dataset, in
# that order.
train_dataset, test_dataset = dataset.split_train_test()

for subset in (train_dataset, test_dataset, dataset):
    print(len(subset))