(dataloader_chapter)=
# PyTorch datasets and dataloaders

The PyTorch [dataset and dataloader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) classes take care of batching, shuffling, and similar chores for us. The official tutorial is a good place to start understanding how they work and is worth reading alongside this chapter in case anything is unclear.

You can use them by importing the ```DataLoader``` and ```Dataset``` classes from ```torch.utils.data```.

In [1]:
import os
from torch.utils.data import DataLoader, Dataset, Subset
from delphi.utils.tools import ToTensor

Unfortunately, I cannot provide Dataset classes for every use case unless everyone follows the conventions that we follow in the Biomedical Imaging Group. However, I do suggest that you organize your data in a directory tree like the one shown below:

```{image} ../_images/datastruct.png
:alt: example datastructure
:width: 150px
:align: center
```

This lets us quickly identify the label of a given sample just by looking up the name of its directory. In some use cases, for example with tabular data (such as .csv or .xls files), the labels are usually stored in one column of the data. In that case a directory structure like the one shown above is unnecessary.
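To make the directory-name convention concrete, here is a minimal sketch. It assumes one subdirectory per class with the files inside it; the helper name `list_files_with_labels`, the class names, and the file names are hypothetical and only serve as an illustration.

```python
import tempfile
from pathlib import Path


def list_files_with_labels(data_dir, pattern="*.nii.gz"):
    """Return sorted (file_path, label) pairs; the label is the parent directory name."""
    data_dir = Path(data_dir)
    pairs = []
    # every subdirectory of data_dir is treated as one class
    for class_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        for file_path in sorted(class_dir.glob(pattern)):
            pairs.append((str(file_path), class_dir.name))
    return pairs


# build a tiny throwaway tree with empty placeholder files (hypothetical names)
with tempfile.TemporaryDirectory() as tmp:
    for class_name, n_files in [("control", 2), ("patient", 3)]:
        class_dir = Path(tmp) / class_name
        class_dir.mkdir()
        for i in range(n_files):
            (class_dir / f"sub-{i:02d}.nii.gz").touch()
    pairs = list_files_with_labels(tmp)

labels = [label for _, label in pairs]
print(labels)  # ['control', 'control', 'patient', 'patient', 'patient']
```

The label never has to be stored separately: it is recovered from the path whenever it is needed.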

Let us now see what we need to implement for our custom dataset classes. The code block below is a template of the functions you absolutely need to implement when designing your own dataset class. To better understand what each function does, read the comments I provide:

In [2]:
# the general setup of a Dataset class
class MyDataset(Dataset):
    def __init__(self):
        r"""
        This is the so-called constructor of any class. As the name suggests
        this function is used to initialize important variables that the dataset
        needs to know. This construct function is automatically called when 
        you instantiate a new class, e.g., when you call data = MyDataset()
        This function can also have arguments. I suggest that the constructor
        should have at least these parameters: 
        
        Args:
            path_to_data (str): this could be either a path to multiple files
            or the file itself.
            
            device (torch.device): which device would you like to use 
            (default=torch.device("cpu"))
            
            shuffle_labels (bool): this is a nice addition in case one wants
            to create a null performance measure (default: False)
            
            transform: in case one wants to transform the data (i.e., normalize,
            rotate, scale, etc.)
        """
        pass
    
    def __len__(self):
        r"""
        returns the length of the dataset
        this is done by taking the len() of the dataset
        
        return len(self.data)
        """
        pass
    
    def __getitem__(self, idx):
        r"""
        This is the function in which we actually load the data and the labels.
        In here it is common to set up the data such that you return a tuple of
        the data and the labels as their own variables. In this function we also
        move the data and labels to the supplied device and transform them according
        to the supplied transformation functions in self.transform (if they exist).
        
        Args:
            idx (int): indicates which datum and label to return.
            
        Returns:
            tuple(data, label)
        """
        pass

Ok, now that we have an intuition of what a dataset class looks like let us look at some implementations that I provide with my code.
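Before turning to those, here is a minimal sketch of the template filled in for data that already lives in memory. The class name `InMemoryDataset` and the toy tensors are hypothetical; they only show how the three required functions fit together.

```python
import torch
from torch.utils.data import Dataset


class InMemoryDataset(Dataset):
    """Toy dataset that wraps samples and labels that are already in memory."""

    def __init__(self, samples, labels, transform=None):
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # the number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        label = self.labels[idx]
        # apply the optional transformation to the sample only
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


# six samples with two features each, and one label per sample
toy = InMemoryDataset(
    torch.arange(12, dtype=torch.float32).reshape(6, 2),
    torch.tensor([0, 1, 0, 1, 0, 1]),
)
sample, label = toy[3]  # indexing calls __getitem__(3)
print(len(toy), sample, label)
```

Note that `toy[3]` and `toy.__getitem__(3)` do the same thing; the bracket form is simply the more common way to write it.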

The first implementation we look at is a ```TabularDataset``` class. This dataset can take a ```.csv``` file or an Excel-compatible file (e.g., ```.xls, .xlsx, .odf, .ods``` among others) as input.

```{note}
Do not be discouraged by the extra code in this implementation. It is there to catch erroneous inputs and to guide the user. In most cases this is not necessary because you will commonly write code suited to your own use case. I tried to be as general as possible.
```

In [3]:
# let us create a simple dataset that takes .csv or maybe other tabular files as input
from torch.utils.data import Dataset
import os
import torch
import pandas as pd
import numpy as np

# we call this class TabularDataset since that is what it is
class TabularDataset(Dataset): 
    r"""
    Class to take care of tabular files such as .csv, .xls, etc.
    """
    
    # we can support different types of tabular dataformats.
    # for now we support the file types given in the variables below.
    POSSIBLE_EXCEL_EXTENSIONS = [".xls", ".xlsx", ".xlsm", ".xlsb", ".odf", ".ods", ".odt"]
    
    # in this implementation it is required that at least one column has one of the following 
    # names. If none of them exists in the data, an error is raised.
    EXPECTED_LABEL_COLUMN_NAMES = ["class", "label", "target"]
    
    ####### REQUIRED CLASS FUNCTIONS ########
    
    # now we come to the so-called constructor or the initializer function
    # I personally like having the option to add transformation functions and 
    # the option to shuffle the labels. This allows me to quickly create
    # a null distribution/performance estimate
    def __init__(
        self, 
        path_to_file, 
        device=torch.device("cpu"), 
        shuffle_labels=False, 
        transform=None
    ):
        r"""
        The constructor of the TabularDataset class.
        
        Args:
            path_to_file (str): the path to the file. Supports .csv, .xls file-types at the moment
            device (torch.device): the device on which to store the data
            shuffle_labels (bool): Default=False; permutes the class labels
            transform: can be a list of functions to transform the data
        """     
        super(TabularDataset, self).__init__()
        
        self.path_to_file   = path_to_file
        self.shuffle_labels = shuffle_labels
        self.transform      = transform
        self.device         = device
        
        # we can check what the file extension of the supplied file is. 
        # this informs us which function to use to read the file.
        filename, file_extension = os.path.splitext(self.path_to_file)
        
        # read the file 'path_to_file' with pandas reading functions
        if file_extension == '.csv':
            self.data = pd.read_csv(self.path_to_file)
            
        elif file_extension in self.POSSIBLE_EXCEL_EXTENSIONS:
            self.data = pd.read_excel(self.path_to_file)
            
        else:
            raise ValueError(f"{file_extension} is not \'.csv\' or one of {self.POSSIBLE_EXCEL_EXTENSIONS}")
        
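        # check if a ["class", "label", "target"] column is found in the data
        # if not, raise an error
        self.label_column = self._check_for_label_column()
        
        # permute the labels if requested (useful for a null performance estimate)
        if self.shuffle_labels:
            self.data[self.label_column] = np.random.permutation(
                self.data[self.label_column].to_numpy()
            )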
        
    def __len__(self):
        r"""
        returns the length, i.e. the number of samples, of the dataset
        """
        return len(self.data)
    
    def __getitem__(self, idx):
        r"""
        get the (batch) sample and label. A sample is one row of the dataset.
        A column represents one feature. 
        
        Returns:
            tuple(sample, label)
        """
        # select all the data except for the label column
        data = self.data.loc[:, self.data.columns != self.label_column].to_numpy()
        
        # extract only the requested row
        sample = data[idx, :]
        
        # assign the corresponding label to the row
        label = self.data[self.label_column].to_numpy()[idx]
        
        # In case you provide a set of transformations execute them here
        if self.transform:
            label = self.transform(label).to(self.device)
            sample = self.transform(sample).float().to(self.device)
        
        return (sample, label)
    
    ####### CUSTOM / HELPER FUNCTIONS ########
    
    def _check_for_label_column(self):
        r"""
        make sure the dataset has a column indicating the label, class, or target
        
        Returns:
            label_column: the column name of the target/class/label
        """
        # map the lowercased column names to the original column names, so the
        # lookup is case-insensitive but we return the name as it appears in the data
        columns = {str(column_name).lower(): column_name for column_name in self.data.columns}
        
        # check if there is a column with one of these descriptions: ["class", "label", "target"]
        # if not, raise an error and indicate to the user that they need a label column with
        # the description in EXPECTED_LABEL_COLUMN_NAMES
        label_column = [columns[col_name] for col_name in self.EXPECTED_LABEL_COLUMN_NAMES if col_name in columns]
        
        if not label_column:
            raise ValueError(f"Did not find a column named one of {self.EXPECTED_LABEL_COLUMN_NAMES}")
        
        return label_column[0]

Right, so what was all this for? The goal of these datasets is to have a general setup such that we can exploit the power of the PyTorch ```DataLoader``` class.

The ```DataLoader``` class can take care of a number of things for us. 
* It is an iterable, which means we can easily loop over all datapoints in our dataset
* we can set how many samples per batch we want
* we can shuffle the dataset
* we can distribute the data/batches to multiple workers, meaning we can exploit parallel processing
* etc...
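The options in the list above map directly onto arguments of the ```DataLoader``` constructor. The following is a small sketch on random toy data; the tensor sizes are arbitrary assumptions, and ```TensorDataset``` is used only to keep the example self-contained.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy data: 10 samples with 3 features each, the targets are simply 0..9
features = torch.randn(10, 3)
targets = torch.arange(10)
toy_ds = TensorDataset(features, targets)

loader = DataLoader(
    toy_ds,
    batch_size=4,   # how many samples per batch
    shuffle=True,   # draw the samples in a new random order every epoch
    num_workers=0,  # 0 loads in the main process; values > 0 use worker processes
)

# iterate over the loader and record the size of every batch
batch_sizes = [len(batch_targets) for _, batch_targets in loader]
print(len(loader), batch_sizes)  # 3 [4, 4, 2]
```

The last batch is smaller because 10 is not divisible by 4. Passing `drop_last=True` would discard that incomplete batch.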

I created a ```dummy_ds.csv``` file to demonstrate how this works.
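How that file was generated is not shown here. A file with a similar layout could be produced roughly as follows; the sizes (40 samples, 20 features), the column names, the seed, and the file name `toy_ds.csv` are all assumptions made for illustration.

```python
import os
import tempfile

import numpy as np
import pandas as pd

rng = np.random.default_rng(seed=0)
n_samples, n_features = 40, 20  # assumed sizes

# random features and binary labels
features = rng.standard_normal((n_samples, n_features))
labels = rng.integers(0, 2, size=n_samples)

df = pd.DataFrame(features, columns=[f"feature_{i}" for i in range(n_features)])
df["label"] = labels  # one of the accepted label column names

# write to a temporary directory so that no existing file is overwritten
out_path = os.path.join(tempfile.mkdtemp(), "toy_ds.csv")
df.to_csv(out_path, index=False)  # index=False avoids an extra index column

reloaded = pd.read_csv(out_path)
print(reloaded.shape)  # (40, 21)
```

Writing with `index=False` matters here: otherwise the row index would be saved as an additional column and treated as a feature when the file is read back.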

In [4]:
# We now load the data stored within the dummy_ds.csv file
data = TabularDataset('dummy_ds.csv')

If we wanted to, we could now access each sample by using the ```__getitem__``` function and supply a single index as shown below. On the other hand if we wanted to get a batch, we could also supply a list of indices to the function.

In [5]:
# get a single sample with its label
sample, label = data.__getitem__(10)
print("Sample data: ", sample, "\nLabel: ", label)

# get a batch of 4 samples with their respective labels
sample, label = data.__getitem__([2, 4, 0, 10])
print("\nBatched Sample data: ", sample, "\nBatched Labels: ", label)

Sample data:  [ 0.34810022  1.10560502 -0.12978938  1.68536215 -0.01200362  2.12168985
 -0.53985801 -0.45726264 -0.57580312  0.96060891 -0.46569884  0.94239489
  1.63580326  0.39518677 -0.49509537 -0.4432387   1.01888881 -0.68882815
  0.05426619  0.03923666] 
Label:  0

Batched Sample data:  [[-1.25744765  0.99514179  0.76161641  1.27121089 -0.1042117  -0.42573839
  -0.66558342  0.88723275  0.1171774  -0.87590716  1.39366192  0.75422015
  -1.61686852  0.22071409 -0.89460092  0.63354489  1.32051097 -0.23074283
   0.99393088  0.35410665]
 [ 0.96989177  0.28143284 -0.35807049  1.37952275  0.07642126  0.26766915
  -0.86966526  0.64460088 -0.33722325 -0.35106016  1.05169211  0.79733317
  -1.98811722 -0.52962523  0.71545117  0.21856745  1.52769592  0.35716858
  -0.29743858  2.74152369]
 [ 0.62496938 -0.55165461  0.45645439  1.03613864  1.09365959 -2.26105163
   1.64820326  1.95124017 -0.03287803 -2.18782345 -0.59050062  1.41066745
  -0.93724356  0.01958021 -0.68527278  1.25714178  0.05581559

The thing is, we are lazy. We do not want to bother doing this ourselves, so we use the ```DataLoader``` class I mentioned above. This time, we will use a loop to see how it works:

In [6]:
# Let us define a dataloader that generates us a sample, label tuple
# with 4 samples. We also want the indices to be shuffled.
dl_tabular = DataLoader(data, batch_size=4, shuffle=True)

# in case you are interested in how many batches your dataloader contains
# simply use len(DataLoader)
print(f"Number of batches: {len(dl_tabular)}\n")

# now print the samples and labels:
for i, (samples, labels) in enumerate(dl_tabular):
    print("Batched Sample data: ", samples, "\nBatched Labels: ", labels, "\n")

Number of batches: 10

Batched Sample data:  tensor([[-0.8591, -1.7704, -0.1745, -0.9754,  0.5658,  0.1900, -0.7620, -0.7420,
          1.3585,  0.1110, -0.3965,  0.1332,  1.8672, -0.0978, -1.0952,  0.0330,
          1.8324,  0.0407, -1.0200, -0.5890],
        [ 0.1291, -1.7352,  0.7939, -1.7178,  0.6857, -0.9208, -0.3240, -1.6038,
          0.1892, -1.4704, -0.3705, -0.1547,  0.9224,  0.0558,  0.6375,  0.5315,
          1.4825,  0.5101, -1.3424,  1.1783],
        [-0.2652, -0.4826,  0.3204,  0.7555, -0.6018,  0.1626,  0.2298,  1.1416,
          1.0210,  1.1791, -0.6822, -0.4467, -0.7880, -0.5294, -0.6582,  0.6906,
          0.6552,  1.6260, -0.5466,  0.4498],
        [ 0.2328,  1.6465,  1.4599,  1.9211,  0.5543,  0.8886, -1.5124,  0.9497,
          0.0959,  2.7158, -0.9761,  0.5828, -0.0260, -0.2237, -0.1546,  1.9794,
          0.5556,  0.0892, -0.1464,  0.1042]], dtype=torch.float64) 
Batched Labels:  tensor([0, 0, 1, 0]) 

Batched Sample data:  tensor([[-4.5431e-01, -6.5284e-01,  5.

```{note}
Another cool thing about these DataLoaders is that you can pass a seeded ```torch.Generator``` via the ```generator``` argument, which makes the shuffling reproducible.
```
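As a minimal sketch of the reproducibility point in the note, two loaders whose generators are seeded identically yield the same shuffled order. The helper `batch_order` and the toy dataset are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy dataset whose samples are simply the numbers 0..11
toy_ds = TensorDataset(torch.arange(12))


def batch_order(seed):
    """Return the shuffled batches produced by a loader with a seeded generator."""
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(toy_ds, batch_size=4, shuffle=True, generator=g)
    return [batch[0].tolist() for batch in loader]


first = batch_order(seed=42)
second = batch_order(seed=42)
print(first == second)  # True
```

With a different seed the order will usually differ, while every sample still appears exactly once per epoch.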

Now, since this course is actually targeted at neuroimagers, I want to show you the code for the ```NiftiDataset```. In case you want to import some datasets from my code you can find them in ```_core.utils.datasets```.

As an exercise you could try to make sense of the code and see if you understand what it does. :)

In [7]:
import os # we are commonly working with paths, so importing os is helpful
import numpy as np # numpy for some numeric operations and better array structures
import nibabel as nib # nibabel is for loading nifti files
import glob
import torch # torch for tensors and devices
from torch.utils.data import Dataset # we need to inherit from the PyTorch Dataset class


class NiftiDataset(Dataset):
    """
      NiftiDataset has torch functionality to rapidly generate and load new
      batches for training and testing.
    """

    def __init__(
        self, 
        data_dir, 
        labels, 
        n, 
        device, 
        dims=3, 
        shuffle_labels=False, 
        transform=None
    ):
        """
        Constructor for the NiftiDataset class
        
        :param data_dir:        path to the data
        :param labels:          list of class names (directories within data_dir)
        :param n:               the number of samples to load. If "0" take every example in directory.
        :param device:          the device to use (cpu|gpu)
        :param dims:            3 to keep the dimension, 1 to flatten into vector
        :param shuffle_labels:  in case one wants to train a null-model enable label shuffling. Using this for training
                                should lead to a network that provides information if labels would not matter. I.e.,
                                it should perform only at chance level.
        :param transform:       A composition of transformation functions that should be applied to the data.
        """

        self.device = device
        self.classes = labels
        self.dims = dims
        self.transform = transform

        # get the file paths and labels
        for iLabel in range(len(labels)):
            # look for all files in alphanumerical order in the label directory
            file_names = sorted(glob.glob(os.path.join(data_dir, labels[iLabel], "*.nii.gz")))
            # select only the requested number of files if n > 0
            n_files = len(file_names[:n]) if n != 0 else len(file_names)
            
            if iLabel == 0:
                self.data = np.array(file_names[:n_files])
                self.labels = np.array(np.repeat(labels[iLabel], n_files))
            else:
                self.data = np.append(self.data, file_names[:n_files])
                self.labels = np.append(self.labels, np.repeat(labels[iLabel], n_files))

        if shuffle_labels:
            self.labels = np.random.permutation(self.labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx: int):
        """
        load a (batch) sample. This is usually done automatically by the 
        PyTorch DataLoader class.
        
        :param idx: the index of the sample to load
        :return: tuple(volume, label)
        """
        
        # make sure that there are no NaNs in the data. 
        volume = np.nan_to_num(nib.load(self.data[idx]).get_fdata())
        
        volume[np.isnan(volume)] = 0 # this one is in here because I am paranoid
        
        # sometimes nibabel retains the temporal dimension. (x, y, z, t)
        # we do not want that so we get rid of it.
        if len(volume.shape) > 3:
            volume = volume.squeeze()

        volume = np.expand_dims(volume, 0) if self.dims == 3 else volume.flatten()  # add the channel dimension
        label = np.squeeze(np.where(np.array(self.labels[idx]) == np.array(self.classes)))

        # In case you provide a set of transformations execute them here
        if self.transform:
            label = self.transform(label).to(self.device)
            volume = self.transform(volume).float().to(self.device)
        else:
            # numpy arrays have no .to() method, so convert them to tensors first
            label = torch.as_tensor(label).to(self.device)
            volume = torch.as_tensor(volume).float().to(self.device)

        return volume, label


That is pretty much it with the Datasets and DataLoaders. You will see many references to them in later chapters since they are one of the building blocks of PyTorch functionality. Make sure you understand what they do and how to build one for yourself.