In [1]:
import pandas as pd
import numpy as np
from skimage import io, transform
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image

In [3]:
class FaceDataSet(Dataset):
    """Map-style dataset yielding ``(image, age_cluster)`` pairs described by a CSV file.

    The CSV is loaded with its first column as the DataFrame index. Fetching a
    sample reads only the ``urls`` and ``age_cluster`` columns.
    """

    def __init__(self, csv_file, root_dir, transform=None, target_transform=None):
        """
            csv_file (string): path to total.csv
            root_dir (string): prefix prepended verbatim (plain string concatenation)
                to every entry of the ``urls`` column
                (empty if total.csv contains full path)
            transform (callable, optional): optional transform to be applied to sample
                (convert image to torch.Tensor by default)
            target_transform (callable, optional): optional transform to be applied to target

        """
        self.meta = pd.read_csv(csv_file, index_col=0)
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.meta)

    def __getitem__(self, idx):
        """Load one image and its target.

        idx: looked up as a *label* of ``self.meta``'s index (the first CSV
            column), not as a 0-based position.
            NOTE(review): samplers that produce positions ``0..len-1`` will
            only work if the CSV index is itself ``0..len-1`` -- confirm
            against the callers before switching to positional lookup.

        Returns a tuple ``(sample, target)`` where ``sample`` is the image
        after ``transform`` (or ``transforms.ToTensor()`` when no transform was
        given) and ``target`` is the ``age_cluster`` value after the optional
        ``target_transform``.
        """
        img_name = self.root_dir + self.meta.urls[idx]
        # Decode the image exactly once; no second file handle is opened
        # and left unclosed.
        sample = io.imread(img_name)
        target = self.meta.age_cluster[idx]

        # Compare against None so that a falsy-but-callable transform is still applied.
        if self.transform is not None:
            sample = self.transform(sample)
        else:
            sample = transforms.ToTensor()(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return (sample, target)

In [12]:
# Build the dataset from the metadata CSV; no transforms are passed here.
ds = FaceDataSet(csv_file='../../csv/total.csv', root_dir='../')

In [13]:
# Show the metadata DataFrame that FaceDataSet loaded from the CSV.
ds.meta

Unnamed: 0,urls,face_coords,age_cluster
4012,../data/imdb_crop/00/nm0000100_rm1001569280_19...,[547.33361035 79.47623005 896.72664558 428.86...,3
28610,../data/imdb_crop/00/nm0000200_rm2120191744_19...,[214.07789842 48.77991255 275.49664313 110.19...,3
135422,../data/imdb_crop/00/nm0002100_rm985837568_197...,[ 303.17238157 453.73457236 978.65424011 112...,2
185827,../data/imdb_crop/01/nm1107001_rm3321071360_19...,[ 85.38616999 85.38616999 237.25567597 237.25...,2
224698,../data/imdb_crop/00/nm4415900_rm2517889024_19...,[1417.76028854 710.51214427 1993.15386112 128...,1
350199,../data/imdb_crop/01/nm4652001_rm4133345280_20...,[1427.456 460.8 1572.864 606.208],0
369045,../data/imdb_crop/01/nm0365501_rm2791093760_19...,[433.95876134 69.77048476 521.67509889 157.48...,2
455183,../data/imdb_crop/01/nm2692301_rm3067461376_19...,[115.87385706 85.21256184 183.9677713 153.30...,2


In [25]:
# Fetch every sample by its index label and print only the target part of the pair.
for row_label in ds.meta.index:
    _, target = ds[row_label]
    print(target)

3
3
2
2
1
0
2
2
