In [2]:
from torch.utils.data import Dataset
from torchvision.datasets import CelebA
from torchvision import transforms

In [4]:
class MyCelebA(CelebA):
    """
    A work-around to address issues with pytorch's celebA dataset class.

    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        return True


In [7]:
train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                        transforms.CenterCrop(148),
                                        transforms.Resize(64),
                                        transforms.ToTensor()])

val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                        transforms.CenterCrop(148),
                                        transforms.Resize(64),
                                        transforms.ToTensor()])

train_dataset = MyCelebA('../Data',
                        split='train',
                        transform=train_transforms,
                        download=False)

val_dataset = MyCelebA('../Data',
                        split='test',
                        transform=val_transforms,
                        download=False)


In [8]:
train_dataset

Dataset MyCelebA
    Number of datapoints: 162770
    Root location: ../Data
    Target type: ['attr']
    Split: train
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               CenterCrop(size=(148, 148))
               Resize(size=64, interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [9]:
val_dataset

Dataset MyCelebA
    Number of datapoints: 19962
    Root location: ../Data
    Target type: ['attr']
    Split: test
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               CenterCrop(size=(148, 148))
               Resize(size=64, interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
           )

In [3]:
from get_dataloader import get_train_val_dataloader


In [4]:
train_data_loader, test_data_loader = get_train_val_dataloader('../Data')

In [7]:
sample = next(iter(train_data_loader))

In [8]:
len(sample)

2

In [9]:
feature, label = sample

In [10]:
feature.shape

torch.Size([32, 3, 64, 64])

In [11]:
label.shape

torch.Size([32, 40])