# PyTorch datasets and dataloaders

In [1]:
import os
import sys
# Register the local fmriDEEP checkout (two directory levels up) on the import
# path so its modules can be imported from this notebook; only add it once.
module_path = os.path.abspath(os.path.join('..', '..', 'fmriDEEP'))
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

The PyTorch [dataset and dataloader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) classes make batching, shuffling, and similar tasks easy. Their official tutorial is a good starting point for understanding how they work. Read it alongside this tutorial if anything here is unclear.

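You can use them by importing the `DataLoader` and `Dataset` classes from `torch.utils.data`. We also import `Subset`, which wraps a dataset so that only a given list of its indices is exposed.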

In [2]:
from torch.utils.data import DataLoader, Dataset, Subset

Unfortunately, I cannot provide Dataset classes that cover every use case. A Dataset only works if the data follows the conventions it expects, such as the ones we use in the Biomedical Imaging Group (here: one sub-directory per class containing `.nii.gz` files). 
Nevertheless, let us look at a Dataset that I wrote for loading NIfTI files:

In [None]:
import glob
import os  # we are commonly working with paths, so importing os is helpful

import nibabel as nib  # nibabel is for loading nifti files
import numpy as np  # numpy for some numeric operations and better array structures
import torch  # torch is needed to turn the loaded numpy arrays into tensors
from torch.utils.data import Dataset  # we need to inherit from the PyTorch Dataset class


class NiftiDataset(Dataset):
    """
    PyTorch Dataset that lists NIfTI (``*.nii.gz``) files stored in one
    sub-directory per class and loads them lazily, one sample per ``__getitem__``.

    Expected layout::

        data_dir/
            <label_0>/*.nii.gz
            <label_1>/*.nii.gz
            ...

    The integer target of a sample is the position of its label in ``labels``.
    """

    def __init__(self, data_dir, labels, n, device, dims=3, shuffle_labels=False, transform=None):
        """
        Constructor for the NiftiDataset class

        :param data_dir:        path to the data
        :param labels:          list of class names (directories within data_dir)
        :param n:               the number of samples to load per class. If "0" (or negative) take every
                                example in the directory.
        :param device:          the device the returned tensors are moved to (e.g. "cpu" or "cuda")
        :param dims:            3 to keep the volume shape (a leading channel axis is added), any other value
                                flattens the volume into a vector
        :param shuffle_labels:  in case one wants to train a null-model enable label shuffling. Using this for training
                                should lead to a network that provides information if labels would not matter. I.e.,
                                it should perform only at chance level.
        :param transform:       optional callable applied to both the label and the volume in ``__getitem__``.
                                Its results must support ``.to(device)`` (and ``.float()`` for the volume),
                                e.g. a function returning torch tensors.
        """
        self.device = device
        self.classes = labels
        self.dims = dims
        self.transform = transform

        # Collect paths and labels in plain lists first: repeatedly calling np.append
        # copies the arrays every time, and left the attributes undefined when
        # `labels` was empty.
        file_paths = []
        file_labels = []
        for label in labels:
            # look for all files in alphanumerical order in the label directory
            file_names = sorted(glob.glob(os.path.join(data_dir, label, "*.nii.gz")))
            # select only the requested number of files if n > 0
            if n > 0:
                file_names = file_names[:n]
            file_paths.extend(file_names)
            file_labels.extend([label] * len(file_names))

        self.data = np.array(file_paths)
        self.labels = np.array(file_labels)

        if shuffle_labels:
            # permute labels across files (uses numpy's global RNG) to build a null model
            self.labels = np.random.permutation(self.labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx: int):
        """
        Load one sample. This is usually done automatically by the PyTorch DataLoader class.

        :param idx: the index of the sample to load
        :return: tuple(volume, label) as tensors on ``self.device``. Without a transform the volume is
                 float32 and the label is the 0-d integer index of the sample's class in ``self.classes``.
        """
        # make sure that there are no NaNs in the data (nan_to_num also replaces +/-inf
        # by large finite values), so a second NaN pass is not needed
        volume = np.nan_to_num(nib.load(self.data[idx]).get_fdata())

        # sometimes nibabel retains the temporal dimension, e.g. (x, y, z, 1).
        # Drop only those trailing singleton axes so that a spatial axis of size 1 is kept.
        extra_axes = tuple(ax for ax in range(3, volume.ndim) if volume.shape[ax] == 1)
        if extra_axes:
            volume = np.squeeze(volume, axis=extra_axes)

        # add the channel dimension, or flatten into a 1D feature vector
        volume = np.expand_dims(volume, 0) if self.dims == 3 else volume.flatten()
        # index of this sample's label within the list of classes (0-d array)
        label = np.squeeze(np.where(np.array(self.labels[idx]) == np.array(self.classes)))

        # In case you provide a set of transformations execute them here
        if self.transform:
            # NOTE(review): the same transform is applied to the label and the volume;
            # it must therefore accept a 0-d integer array as well as the volume.
            label = self.transform(label).to(self.device)
            volume = self.transform(volume).float().to(self.device)
        else:
            # numpy arrays have no `.to`, so convert them to tensors first
            label = torch.as_tensor(label).to(self.device)
            volume = torch.as_tensor(volume).float().to(self.device)

        # NOTE(review): moving tensors to a CUDA device here is problematic with
        # DataLoader(num_workers > 0); consider keeping device="cpu" in that case.
        return volume, label


In [5]:
%%script echo skipping
# the general setup of a Dataset class

class MyDataset(Dataset):
    """
    Skeleton of a custom PyTorch Dataset: fill in the three methods below.
    """

    def __init__(self):
        # do your initializations here (e.g. collect file paths and labels)
        pass

    def __len__(self):
        # returns the length of the dataset
        # this usually is done by taking the len() of the labels
        # return len(self.labels)
        pass

    def __getitem__(self, idx):
        # this is where we actually load the data and labels
        # commonly the data and labels are returned as a tuple(data, label)
        # The method must be called exactly `__getitem__` (two trailing underscores),
        # otherwise indexing (dataset[idx]) and the DataLoader will not use it.
        # return loaded_data, loaded_label
        pass

skipping
