<img align="center" style="max-width: 1000px" src="figures/banner.png">

<img align="right" style="max-width: 200px; height: auto" src="figures/hsg_logo.png">

## Lab 04 - Custom Datasets in PyTorch

Machine Learning, University of St. Gallen, Spring Term 2024


In this tutorial, we implement a custom PyTorch dataset that reads the images in a given dataset folder and prepares them as inputs for training and evaluation. Although the structure of datasets can vary significantly, the principles in this tutorial apply to any PyTorch dataset, regardless of its folder structure or file formats.

Lab Objectives:
- Understand dataset structures and how to process dataset files.
- Learn how to implement a PyTorch dataset class.
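PyTorch's documentation describes a *map-style* dataset as an object that implements `__getitem__()` and `__len__()`: `__len__` reports how many samples there are, and `__getitem__` returns the sample at a given index. The sketch below uses a hypothetical toy class, `SquaresDataset` (not part of PyTorch), written in plain Python. It follows the same two-method contract over computed values instead of image files. The dataset we build later subclasses `torch.utils.data.Dataset` and implements the same two methods.

```python
class SquaresDataset:
    """Hypothetical toy map-style dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; a DataLoader uses this to know which indices exist.
        return self.n

    def __getitem__(self, index):
        # Return one (input, label) pair for a valid integer index.
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index ** 2


ds = SquaresDataset(5)
print(len(ds))                            # 5
print(ds[3])                              # (3, 9)
print([ds[i] for i in range(len(ds))])    # [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]
```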


## Example: A Multi-Folder Dataset

In this example, we use a dataset called **Omniglot**, in which the images of each class are stored in a separate folder. We want to load the images inside each folder, which corresponds to one class, and return them as instances of that class.

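First, let's download the files we need from this link: https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip. The code below assumes that the `Greek` alphabet folder from this archive has been placed at `./dataset/Greek`.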

For more information about the dataset, please refer to:
https://github.com/brendenlake/omniglot/

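To read the list of files in a folder, we use the `glob` module. It is part of Python's standard library, so no installation is needed; the commented-out cell below would install `glob2`, a separate third-party package that this tutorial does not require: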

In [1]:
# ! pip install glob2
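Before applying `glob` to Omniglot, it helps to see what its wildcard patterns match. The sketch below builds a tiny hypothetical folder-per-class tree in a temporary directory (the folder and file names are made up) and compares three patterns: `*` (entries directly inside a folder), `*/*.png` (PNG files exactly one level down, the pattern used later in this lab) and `**` with `recursive=True` (any depth).

```python
import glob
import os
import tempfile

# Hypothetical tree in a temporary directory:
#   root/character01/a.png, root/character01/b.png, root/character02/c.png
root = tempfile.mkdtemp()
for cls, files in {"character01": ["a.png", "b.png"], "character02": ["c.png"]}.items():
    os.makedirs(os.path.join(root, cls))
    for name in files:
        # Empty placeholder files are enough: glob only matches names.
        open(os.path.join(root, cls, name), "w").close()

# "*" matches every entry directly inside root (here: the two class folders).
top_level = sorted(os.path.basename(p) for p in glob.glob(os.path.join(root, "*")))
print(top_level)        # ['character01', 'character02']

# "*/*.png" matches .png files exactly one folder below root.
one_level = glob.glob(os.path.join(root, "*", "*.png"))
print(len(one_level))   # 3

# "**" with recursive=True matches .png files at any depth below root.
any_depth = glob.glob(os.path.join(root, "**", "*.png"), recursive=True)
print(len(any_depth))   # 3
```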

Now, let's see how we can retrieve and print the list of folders for a given root directory:

In [2]:
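import glob
import os

# List the entries of the root directory and keep only the folder names.
# os.path.basename is used instead of split("/") so this also works on Windows.
folder_names = [os.path.basename(f) for f in glob.glob("dataset/Greek/*") if os.path.isdir(f)]
print(folder_names)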

['character11', 'character16', 'character20', 'character18', 'character19', 'character21', 'character17', 'character10', 'character03', 'character04', 'character05', 'character02', 'character15', 'character12', 'character24', 'character23', 'character22', 'character13', 'character14', 'character09', 'character07', 'character01', 'character06', 'character08']
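Note that `glob` returns entries in an arbitrary order (the list above is not sorted) and matches files as well as folders. As an alternative, here is a minimal sketch based on `pathlib` from the standard library; the helper name `list_class_folders` and the example tree are hypothetical. `Path.name` extracts the last path component with the correct separator on every operating system, and `is_dir()` skips stray files.

```python
import os
import tempfile
from pathlib import Path


def list_class_folders(root):
    """Return the sorted names of the subdirectories of root (hypothetical helper)."""
    return sorted(p.name for p in Path(root).iterdir() if p.is_dir())


# Hypothetical example tree: two class folders plus one stray file.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "character02"))
os.makedirs(os.path.join(root, "character01"))
open(os.path.join(root, "notes.txt"), "w").close()

print(list_class_folders(root))   # ['character01', 'character02']
```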


If each folder corresponds to a class, we need to map the class name to a class ID:

In [3]:
# Sort the names so that each class gets the same ID on every run and machine.
name_to_id = {name: idx for idx, name in enumerate(sorted(folder_names))}

print(name_to_id)

{'character01': 0, 'character02': 1, 'character03': 2, 'character04': 3, 'character05': 4, 'character06': 5, 'character07': 6, 'character08': 7, 'character09': 8, 'character10': 9, 'character11': 10, 'character12': 11, 'character13': 12, 'character14': 13, 'character15': 14, 'character16': 15, 'character17': 16, 'character18': 17, 'character19': 18, 'character20': 19, 'character21': 20, 'character22': 21, 'character23': 22, 'character24': 23}
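Sorting before enumerating is what makes this mapping reproducible: since `glob` may list the folders in a different order on another machine, enumerating the raw list could assign different IDs to the same class. The sketch below uses two hypothetical listings of the same folders and a hypothetical helper, `make_mapping`, to show that both produce the same mapping. It also builds the inverse dictionary, which turns a predicted class ID back into a readable folder name.

```python
# Two hypothetical listings of the same folders in different orders,
# as glob might return them on two different machines.
listing_a = ["character11", "character02", "character01"]
listing_b = ["character01", "character11", "character02"]


def make_mapping(names):
    # Sorting first makes the ID assignment independent of the listing order.
    return {name: idx for idx, name in enumerate(sorted(names))}


name_to_id_a = make_mapping(listing_a)
name_to_id_b = make_mapping(listing_b)
print(name_to_id_a == name_to_id_b)   # True
print(name_to_id_a)                   # {'character01': 0, 'character02': 1, 'character11': 2}

# Invert the dictionary to map class IDs back to folder names.
id_to_name = {idx: name for name, idx in name_to_id_a.items()}
print(id_to_name[2])                  # character11
```

Because `sorted` compares strings character by character, this works here thanks to the zero-padded folder names (`character01`, `character02`, ...). Without padding, `character10` would sort before `character2`.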


Next, we list all images in the dataset and assign each one the label ID of the folder it belongs to:

In [4]:
all_files = glob.glob("./dataset/Greek/*/*.png")
all_label = [name_to_id[path.split("/")[-2]] for path in all_files]

print(all_label)

[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
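The printed label list is long (and truncated above), so a quick way to check class balance is to count how often each label occurs with `collections.Counter`. The helper `class_counts` and the labels below are hypothetical. Each Omniglot character is drawn by 20 different people, so for the Greek folder one would expect 24 labels with 20 images each (24 × 20 = 480), assuming the folder is complete.

```python
from collections import Counter


def class_counts(labels):
    """Return {label: number of samples}, ordered by label (hypothetical helper)."""
    return dict(sorted(Counter(labels).items()))


# Hypothetical labels: 3 classes with 20 samples each, in an unsorted, glob-like order.
labels = [2] * 20 + [0] * 20 + [1] * 20
counts = class_counts(labels)
print(counts)                 # {0: 20, 1: 20, 2: 20}
print(sum(counts.values()))   # 60
```

On the real data, `class_counts(all_label)` would show the same kind of summary for the Omniglot labels.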

Then, let's define a dataset class that builds the file and label lists once and returns a single (image, label) sample for a given index:

In [5]:
# CODE TO BE IMPLEMENTED DURING THE TUTORIAL SESSION
import glob
import os

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class MyDataset(Dataset):
    """Folder-per-class image dataset: root/<class_name>/<image>.png."""

    def __init__(self, root, transform=None) -> None:
        super().__init__()
        self.transform = transform

        # Class names are the sub-folder names; sort them so the IDs are reproducible.
        folder_names = [os.path.basename(f) for f in glob.glob(os.path.join(root, "*")) if os.path.isdir(f)]
        name_to_id = {name: idx for idx, name in enumerate(sorted(folder_names))}

        # One entry per image; its label is the ID of its parent folder.
        self.all_paths = glob.glob(os.path.join(root, "*", "*.png"))
        self.all_label = [name_to_id[os.path.basename(os.path.dirname(path))] for path in self.all_paths]

    def __len__(self):
        # Total number of images in the dataset.
        return len(self.all_paths)

    def __getitem__(self, index):
        # Load the image only when this sample is requested.
        image = Image.open(self.all_paths[index])

        if self.transform is not None:
            image = self.transform(image)

        label = self.all_label[index]
        return image, label
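One possible refinement is sketched below as a hypothetical class, `FolderPerClassDataset`, written in plain Python so that it runs without PyTorch or image files. It makes two design changes. First, it sorts the file paths, so that index `i` refers to the same file on every machine. Second, it takes a `loader` callable that turns a path into a sample. In this lab, the loader would be something like `Image.open`. In the sketch, it defaults to returning the path itself, and the test tree contains empty placeholder files.

```python
import os
import tempfile
from pathlib import Path


class FolderPerClassDataset:
    """Hypothetical variant of MyDataset: sorted paths plus a pluggable loader."""

    def __init__(self, root, pattern="*.png", loader=None, transform=None):
        root = Path(root)
        class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
        self.name_to_id = {name: idx for idx, name in enumerate(class_names)}
        # One (path, label) pair per file, collected class by class in sorted order.
        self.samples = [
            (path, self.name_to_id[name])
            for name in class_names
            for path in sorted((root / name).glob(pattern))
        ]
        # The loader turns a path into a sample; by default the path itself is returned.
        self.loader = loader if loader is not None else (lambda p: p)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


# Hypothetical tree: character01 holds two placeholder images, character02 holds one.
root = tempfile.mkdtemp()
for cls, files in {"character02": ["x.png"], "character01": ["b.png", "a.png"]}.items():
    os.makedirs(os.path.join(root, cls))
    for name in files:
        open(os.path.join(root, cls, name), "w").close()

ds = FolderPerClassDataset(root, loader=lambda p: p.name)
print(len(ds))                            # 3
print([ds[i] for i in range(len(ds))])    # [('a.png', 0), ('b.png', 0), ('x.png', 1)]
```

If `MyDataset` sorted its paths in the same way, the file behind a given index (for example, `my_dataset[13]` below) would change. `torchvision.datasets.ImageFolder` implements a similar folder-per-class layout out of the box.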


Next, let's test the dataset class we just implemented.

In [6]:
my_transform = transforms.ToTensor()
my_dataset = MyDataset(root="./dataset/Greek", transform=my_transform)


In [7]:
len(my_dataset)

480

In [8]:
my_dataset[13]

(tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]),
 10)

Now, we use the dataset to create a dataloader and iterate through it in batches.

In [9]:
from torch.utils.data import DataLoader

my_dataloader = DataLoader(my_dataset, batch_size=32, num_workers=0)

In [10]:
for images, labels in my_dataloader:
    # images: [batch_size, channels, height, width]; labels: [batch_size]
    print(images.shape)

torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
torch.Size([32, 1, 105, 105])
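The loop prints 15 batches of 32 images because 480 / 32 = 15 exactly. If the dataset size is not a multiple of `batch_size`, `DataLoader` yields a smaller final batch by default, and setting its `drop_last=True` option discards that batch. The plain-Python sketch below reproduces this arithmetic with a hypothetical helper, `batch_sizes`. It does not call PyTorch itself.

```python
import math


def batch_sizes(num_samples, batch_size, drop_last=False):
    """Sizes of the batches produced when num_samples are split into batches in order."""
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)   # the smaller, final batch
    return sizes


print(len(batch_sizes(480, 32)))                   # 15  (480 = 15 * 32)
print(len(batch_sizes(500, 32)))                   # 16
print(batch_sizes(500, 32)[-1])                    # 20  (last, partial batch)
print(len(batch_sizes(500, 32, drop_last=True)))   # 15
print(math.ceil(500 / 32))                         # 16
```

Since `shuffle` defaults to `False`, the batches above follow the dataset's index order. For training, one usually passes `shuffle=True` so that each epoch sees the samples in a new random order.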
