In [1]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
import import_ipynb
from glob import glob
import os.path
from PIL import Image
from torchvision import transforms

In [2]:
class UTKFace(Dataset):
    """Map-style dataset over face images whose labels are encoded in the filename.

    Each file's base name is expected to look like
    ``<age>_<gender>_<race>_<rest>``: exactly four ``_``-separated fields, the
    first three being integers. Paths whose base name does not match that
    pattern are skipped, so ``len(self)`` may be smaller than
    ``len(image_paths)``.

    Args:
        image_paths: Iterable of image file paths (str or path-like).
        subset: Accepted for interface compatibility; not used by this class.
        transform: Callable applied to each RGB ``PIL.Image``. When ``None``,
            a default pipeline resizes to 32x32, converts to a tensor, and
            normalizes with fixed per-channel mean/std (the commonly used
            ImageNet statistics).
    """

    def __init__(self, image_paths, subset='train', transform=None):
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        else:
            self.transform = transform

        # The raw input is kept as given; only parseable paths land in self.images.
        self.image_paths = image_paths
        self.images = []
        self.ages = []
        self.genders = []
        self.races = []

        for path in image_paths:
            labels = self._parse_labels(path)
            if labels is None:
                continue
            age, gender, race = labels
            self.images.append(path)
            self.ages.append(age)
            self.genders.append(gender)
            self.races.append(race)

    @staticmethod
    def _parse_labels(path):
        """Return ``(age, gender, race)`` parsed from *path*'s base name, or ``None``."""
        # Parse the base name instead of slicing off a fixed-length directory
        # prefix, so the dataset works regardless of where the files live.
        fields = os.path.basename(path).split("_")
        if len(fields) != 4:
            return None
        try:
            return int(fields[0]), int(fields[1]), int(fields[2])
        except ValueError:
            # e.g. an empty field from a doubled underscore: treat it like any
            # other malformed name and skip it rather than abort construction.
            return None

    def __getitem__(self, index):
        """Return a dict with keys ``'image'``, ``'age'``, ``'gender'``, ``'race'``.

        The image is loaded, converted to RGB, and passed through
        ``self.transform``; the labels are the integers parsed at construction.
        """
        # The context manager closes the underlying file; convert() returns an
        # independent in-memory copy, so it stays valid after the file closes.
        with Image.open(self.images[index]) as im:
            img = im.convert('RGB')
        img = self.transform(img)

        return {
            'image': img,
            'age': self.ages[index],
            'gender': self.genders[index],
            'race': self.races[index],
        }

    def __len__(self):
        """Return the number of images whose filenames parsed successfully."""
        return len(self.images)