In [None]:
# default_exp datasets

In [None]:
#hide
# Reload edited modules before each cell executes, so changes to exported
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


# Dataset Classes

>Classes extending `torch.utils.data.Dataset` for the `torch.utils.data.DataLoader` objects to iterate over.

Two different classes are provided:

* `FaceClassificationDataset` iterates over the training/validation/test images, returning a tensor image and its corresponding label (index) in training/validation mode, and only the tensor image in test mode.

* `FaceVerificationDataset` iterates over the training/validation/test images, returning a tensor image and its label (index) in training mode, a pair of tensor images and their corresponding label in validation mode, and a pair of tensor images in test mode.

The `.jpg` images are loaded with the `Image` module from `Pillow` and converted to tensors with `torchvision.transforms.ToTensor`.





In [None]:
#export
# imports
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
from PIL import Image

In [None]:
#export
class FaceClassificationDataset(Dataset):
    """ Face Classification Dataset

    Class inheriting from the torch.utils.data.Dataset class. It indexes the
    image files of one split on disk and loads them lazily in __getitem__.

    Expected layout under ``data_root``:

    * ``train_data/<label>/<image>`` and ``val_data/<label>/<image>``, where
      ``<label>`` is a directory whose name parses as a non-negative int.
    * ``test_data/<id>.<ext>``, where ``<id>`` parses as an int (files sit
      directly in the split directory, with no per-label subdirectories).

    Parameters
    ----------
    sample : list or np.ndarray, optional
        Subset of integer labels to load (train/val only). If None, every
        label directory found in the split directory is used.
    mode : {'train', 'val', 'test'}
        Which split to load.
    data_root : str
        Directory that contains the ``train_data``/``val_data``/``test_data``
        folders. Defaults to ``'../nbs/data/s1'``.

    Fields
    ------
    mode : str
        The split being loaded.
    data_dir : str
        Directory of the chosen split (``data_root/<mode>_data``).
    map_files : list of (int, str)
        One ``(label_or_id, file_name)`` pair per image, in dataset order.
    labels : list
        Label of each image (train/val) or its numeric file id (test).
    X : list of str
        File name of each image.

    Methods
    -------
    __len__()
        Number of indexed images.
    __getitem__(idx)
        ``image_tensor`` in test mode, ``(image_tensor, label)`` otherwise.
    """
    def __init__(self,
                 sample=None,
                 mode='train',
                 data_root='../nbs/data/s1'):

        # Assertions to avoid wrong inputs
        assert mode in ['train', 'val', 'test']
        if sample is not None:
            assert isinstance(sample, (list, np.ndarray))
            # A label subset only makes sense for the labelled splits.
            assert mode != 'test'

        self.mode = mode

        # Directory setup: '<data_root>/train_data', '.../val_data', '.../test_data'
        self.data_dir = os.path.join(data_root, mode + '_data')

        if mode == 'test':
            # Test images are plain files "<id>.<ext>" in data_dir; order
            # them by their numeric id.
            files = sorted(os.listdir(self.data_dir),
                           key=lambda f: int(f.split('.')[0]))
            self.map_files = [(int(f.split('.')[0]), f) for f in files]
        else:
            if sample is not None:
                labels = np.sort(np.asarray(sample), axis=0)
                # size check first: .min() raises on an empty array.
                assert labels.size == 0 or labels.min() >= 0
            else:
                labels = np.sort(
                    np.array([int(d) for d in os.listdir(self.data_dir)]),
                    axis=0)
            # os.listdir order is arbitrary; sort file names so the
            # index -> image mapping is reproducible across runs/machines.
            self.map_files = [
                (label, fname)
                for label in labels
                for fname in sorted(
                    os.listdir(os.path.join(self.data_dir, str(label))))
            ]

        self.labels = [t[0] for t in self.map_files]
        self.X = [t[1] for t in self.map_files]

    def __len__(self):
        """Number of indexed images."""
        return len(self.X)

    def _image_path(self, idx):
        """Return the on-disk path of image ``idx``."""
        if self.mode == 'test':
            # Test files live directly in data_dir (no label subdirectory).
            return os.path.join(self.data_dir, self.X[idx])
        return os.path.join(self.data_dir, str(self.labels[idx]), self.X[idx])

    def __getitem__(self, idx):
        """Load image ``idx`` as a tensor.

        Returns ``image_tensor`` in test mode and ``(image_tensor, label)``
        in train/val mode.
        """
        # `with` closes the underlying file; ToTensor materialises the pixel
        # data into a new tensor while the image is still open.
        with Image.open(self._image_path(idx)) as image:
            image_tensor = torchvision.transforms.ToTensor()(image)
        if self.mode == 'test':
            return image_tensor
        return image_tensor, self.labels[idx]

In [None]:
# Fixed seed so Restart & Run All reloads the same label subset; sampling
# without replacement avoids indexing the same class directory twice.
rng = np.random.default_rng(0)
test_dataset = FaceClassificationDataset(
    rng.choice(np.arange(1, 100), size=20, replace=False), mode='train')


In [None]:
# Inspect the indexed dataset: split directory, size and its first entry.
for value in (test_dataset.data_dir,
              len(test_dataset.map_files),
              test_dataset.map_files[0],
              test_dataset.X[0],
              test_dataset.labels[0],
              len(test_dataset)):
    print(value)

../nbs/data/s1/train_data
1800
(2, '0004_01.jpg')
0004_01.jpg
2
1800


In [None]:
# Show the working directory: the dataset's data_dir paths are relative to it.
!pwd

/mnt/c/Users/aleja/google drive alejandro/cmu/spring_2021/idl/hw/hw2p2/nbs


In [None]:
# Subscripting a Dataset dispatches to __getitem__.
img, label = test_dataset[0]

../nbs/data/s1/train_data/2/0004_01.jpg


In [None]:
# Tensor.size() returns the same torch.Size as .shape.
print(img.size())

torch.Size([3, 64, 64])
