In [1]:
"""
In this second example, we will take a closer look at the dataloader.
"""

'\nIn this second example, we will take a closer look at the dataloader.\n'

In [2]:
from config.efficient_net_b2_pretrained import Configuration
# Load the experiment configuration. It is later unpacked with
# `load_data_wrapper(**configuration)`, so it must provide that function's
# parameters (input paths, subset sizes, batch sizes and device).
configuration = Configuration()


In [3]:
import os
import glob

import torch
import torchvision

from random import shuffle

from PIL import Image

"""
What you see here is fairly standard in PyTorch.
We have created a custom dataset class that inherits from torch.utils.data.Dataset.
The dataset knows how many samples it contains (`__len__`) and how to fetch a single sample by index (`__getitem__`).
A dataloader is a wrapper around a dataset that lets us fetch not just one example at a time, but a whole batch of them.
The dataloader also provides a number of performance features that we won't get into now, but which you can read more about in the advanced section later.
"""

def load_data_wrapper(path_input_train, path_input_val, path_input_test,
                      n_train, n_val, n_test,
                      batch_size_train, batch_size_val, batch_size_test,
                      device, **kwargs):
    """
    Entry point for building dataloaders from a configuration mapping.

    Meant to be called as ``load_data_wrapper(**configuration)``: the named
    parameters are forwarded to ``load_cat_dog_data``, while any other
    configuration entries end up in ``kwargs`` and are deliberately ignored.
    Keeping this indirection lets a different dataset loader be plugged in
    later without touching the call site.

    Returns:
        dict mapping 'train', 'val' and 'test' to DataLoader objects.
    """
    return load_cat_dog_data(
        path_input_train, path_input_val, path_input_test,
        n_train, n_val, n_test,
        batch_size_train, batch_size_val, batch_size_test,
        device,
    )

def load_cat_dog_data(path_input_train,path_input_val,path_input_test,n_train,n_val,n_test,batch_size_train,batch_size_val,batch_size_test, device):
    """
    Build one DataLoader per split of the cats-vs-dogs data.

    Only the training loader reshuffles every epoch and drops an incomplete
    final batch; the validation and test loaders keep their order and every
    sample. Test images are loaded with ``training_data=False`` so they get the
    placeholder label instead of one parsed from the filename.

    Returns:
        dict with keys 'train', 'val' and 'test', each a torch DataLoader.
    """
    # (key, directory, subset size, batch size, labelled, shuffle, drop_last).
    # Splits are built in this order, which also fixes the order in which each
    # ImageDataset draws its random subset.
    split_specs = (
        ('train', path_input_train, n_train, batch_size_train, True, True, True),
        ('val', path_input_val, n_val, batch_size_val, True, False, False),
        ('test', path_input_test, n_test, batch_size_test, False, False, False),
    )
    loaders = {}
    for key, directory, n_samples, batch_size, labelled, reshuffle, drop_incomplete in split_specs:
        split_dataset = ImageDataset(directory, n_samples, device=device, training_data=labelled)
        loaders[key] = torch.utils.data.DataLoader(
            split_dataset,
            shuffle=reshuffle,
            batch_size=batch_size,
            drop_last=drop_incomplete,
        )
    return loaders

def eval_transform(crop_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Build the deterministic preprocessing pipeline applied to every image.

    The pipeline center-crops the image to ``crop_size`` x ``crop_size``
    pixels, converts it to a tensor and normalizes each channel with the
    given ``mean``/``std``. The defaults are the ImageNet statistics, so
    calling ``eval_transform()`` with no arguments behaves exactly as before.

    Args:
        crop_size: side length (int) or (height, width) of the center crop.
        mean: per-channel means subtracted after conversion to a tensor.
        std: per-channel standard deviations divided out after subtracting ``mean``.

    Returns:
        A ``torchvision.transforms.Compose`` mapping a PIL image to a
        normalized float tensor.
    """
    return torchvision.transforms.Compose([
        # Keeps the central region; torchvision zero-pads images that are
        # smaller than the crop size before cropping.
        torchvision.transforms.CenterCrop(crop_size),
        # PIL image (H x W x C, uint8 0-255) -> float32 tensor (C x H x W)
        # scaled to [0, 1], the layout torch vision models work with.
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=list(mean), std=list(std)),
    ])

class ImageDataset(torch.utils.data.Dataset):
    """
    Our custom dataset used for cats and dogs images.

    Every file matching ``*.{extension}`` in ``path`` is one sample. When
    ``training_data`` is True the label is derived from the filename
    (0 = cat, 1 = dog); otherwise every sample gets the placeholder label -1.
    Each item is returned as an ``(image, label)`` pair of tensors, both
    already moved to ``device``.

    Args:
        path: directory containing the image files.
        n_samples: size of the random subset to keep; a negative value (the
            default) or None keeps every file.
        extension: file extension to search for (without the dot).
        transform: callable applied to the RGB PIL image; defaults to
            ``eval_transform()``. That default object is created once, when the
            class is defined, and is shared by every instance that does not
            pass its own transform.
        training_data: whether the filenames carry a cat/dog label.
        device: device the returned tensors are moved to.
    """

    def __init__(self, path,n_samples=-1,extension='jpg',transform=eval_transform(),training_data=True,device='cpu'):
        self.path = path
        self.n_samples = n_samples
        self.transform = transform
        self.training_data = training_data
        self.device = device
        search_param = os.path.join(path, f"*.{extension}")
        # glob's order depends on the filesystem; sorting first makes the
        # shuffled subset reproducible once the `random` module is seeded.
        files = sorted(glob.glob(search_param))
        assert len(files)>0, f"Searching for images files in {search_param} gave zero hits. Make sure you already downloaded the dataset using /data/download_cats_and_dogs.py"
        shuffle(files) # Remember to shuffle the pictures before you extract a subset of them!
        # A negative n_samples means "keep everything"; slicing with
        # files[:-1] would silently drop the last file instead.
        if n_samples is None or n_samples < 0:
            self.files = files
        else:
            self.files = files[:n_samples]
        self.label_names = ['cat', 'dog']

    def __len__(self):
        return len(self.files)

    def __repr__(self):
        return f"{self.__class__.__name__}(path={self.path!r}, n_files={len(self.files)})"

    def __getitem__(self, index):
        """Load, transform and label the image at position ``index``."""
        file = self.files[index]
        # basename handles both '/' and the Windows '\\' separator.
        name = os.path.basename(file)
        if self.training_data:
            # "cat" is tested first, so a filename containing both words is labelled as a cat.
            if "cat" in name:
                label = 0
            elif "dog" in name:
                label = 1
            else:
                raise ValueError(f"{file} does not contain dog or cat label.")
        else:
            label = -1

        # Open inside a context manager so the file handle is released, and
        # force 3-channel RGB so grayscale/RGBA/CMYK files match the
        # 3-channel normalization of the default transform.
        with Image.open(file) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)
        if not isinstance(image, torch.Tensor):
            # No (tensor-producing) transform: fall back to a plain float
            # conversion so that `.to(device)` below works.
            image = torchvision.transforms.functional.to_tensor(image)

        # NOTE(review): moving tensors to a CUDA device inside __getitem__ only
        # works with the default num_workers=0 used by the loaders in this notebook.
        image = image.to(device=self.device)
        label = torch.tensor(label, device=self.device, dtype=torch.int64)
        return image, label





NameError: name 'data' is not defined

In [None]:
# Build the 'train', 'val' and 'test' dataloaders from the configuration
# (configuration keys that load_data_wrapper does not name are ignored) and
# keep the training loader for the questions below.
dataloaders = load_data_wrapper(**configuration)
dataloader_train = dataloaders['train']

# A few questions to consider:

1. How do we get a sample out of a dataloader?
2. How do we know how many samples are in a dataloader?
3. Why are we returning 3 different dataloaders?
4. What happens if an image's filename contains both "cat" and "dog"? What should happen?

# Some more advanced questions:

5. When we extract a random subset of our data (as our dataset class currently does), we will likely not end up with a perfectly balanced subset. Is this a problem? How could you fix it?
6. Our current setup handles images of different sizes by using a transformation that crops them to the centermost 256x256 pixels. When could this be a problem, and what are some alternatives?
7. Currently, we do not have any data augmentation. What kind of data augmentation might we do on images?

## Further considerations 
In our example we used accuracy as our metric, but for unbalanced datasets accuracy can be misleading.
For example, imagine an unbalanced dataset with 99 cat images and 1 dog image:
a model that always predicts "cat" reaches 99% accuracy on it without ever recognizing a dog.

Alternatively, our dataset could be balanced while the costs of different kinds of mistakes are skewed.
In cancer screening, for instance, a false negative means that someone's cancer goes undetected, whereas a false positive means that a healthy person undergoes additional, unnecessary testing.

So, depending on how much true positives, true negatives, false positives and false negatives matter for the problem at hand, different metrics are appropriate.
For more information, look up "F1-score", "recall", "precision" and "accuracy".


