In [71]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda
from torchvision.io import read_image
import os
import pandas as pd

In [72]:
class CustomImageDataset(Dataset):
    """CelebA images paired with a single binary attribute label.

    Reads a whitespace-separated attribute annotation file and exposes, for
    each annotated image, the value of one attribute column mapped so that
    -1 becomes 0.

    Args:
        annotations_file: Path to the attribute file (e.g. ``list_attr_celeba.txt``).
            Its first line is skipped. The next line must contain the attribute
            names, and each following row must contain an image filename followed
            by one value per attribute.
        img_dir: Directory containing the image files named in the annotation rows.
        transform: Optional callable applied to the image returned by ``read_image``.
        target_transform: Optional callable applied to the label.
        attribute: Name of the attribute column used as the label.
            Defaults to ``"Smiling"``, the column the original code always used.
    """

    def __init__(self, annotations_file, img_dir, transform=None,
                 target_transform=None, attribute="Smiling"):
        # Raw string: '\s' in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+).
        # When the header row has one field fewer than the data rows, pandas
        # uses the first column as the row index. __getitem__ relies on that
        # index holding the image filenames.
        dataframe = pd.read_csv(annotations_file, sep=r'\s+', skiprows=1)
        # -1 becomes 0; other values (presumably 1) are kept, giving 0/1 labels.
        self.ds_img_labels = dataframe[attribute].replace(-1, 0)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.ds_img_labels)

    def __getitem__(self, idx):
        """Load the idx-th image and its label.

        Returns:
            A tuple ``(image, label)``. Each element has its optional
            transform applied.
        """
        # The index entry is the image filename from the annotation row.
        img_path = os.path.join(self.img_dir, self.ds_img_labels.index[idx])
        image = read_image(img_path)
        label = self.ds_img_labels.iloc[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [73]:
# Root of the local CelebA download. Set the CELEBA_ROOT environment variable
# to point elsewhere instead of editing this cell; the default is the
# original author's local path.
CELEBA_ROOT = os.environ.get("CELEBA_ROOT", r"D:\SSD_Optimization\User\Desktop\Celeba Dataset")

annotations_file = os.path.join(CELEBA_ROOT, "Anno", "list_attr_celeba.txt")
img_dir = os.path.join(CELEBA_ROOT, "Img", "align_cutted", "img_align_celeba")

# Side length images are resized to (presumably chosen to match the LEAF
# CelebA CNN input size referenced below; confirm).
IMAGE_SIZE = 84
# Length of the one-hot label vector (binary attribute: 0 or 1).
NUM_CLASSES = 2


def to_one_hot(y):
    """Return a float vector of length NUM_CLASSES with a 1 at index y."""
    return torch.zeros(NUM_CLASSES).scatter_(dim=0, index=torch.tensor(y, dtype=torch.int64), value=1)


test_customDataset = CustomImageDataset(
    annotations_file,
    img_dir,
    transform=transforms.Compose([
        # read_image yields a tensor; convert to PIL for Resize, then back.
        transforms.ToPILImage(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]),
    target_transform=Lambda(to_one_hot),
)

In [74]:
# Quick sanity check: shape of the first transformed image, then the
# transformed label of the third sample.
first_image, _ = test_customDataset[0]
print(first_image.shape)
_, third_label = test_customDataset[2]
print(third_label)

torch.Size([3, 84, 84])
tensor([1., 0.])


In [None]:
# TODO: Remove celebrities with fewer than 5 photos, following LEAF's CelebA preprocessing:
# https://github.com/TalwalkarLab/leaf/blob/master/data/celeba/preprocess.sh

In [None]:
# TODO: Add the CNN model, based on LEAF's CelebA model:
# https://github.com/TalwalkarLab/leaf/blob/master/models/celeba/cnn.py
# TODO: Review LEAF's full CelebA pipeline, then train the model.