In [1]:
import os
import pandas
import matplotlib.pyplot as plt
import seaborn
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CelebA
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.utils import make_grid

plt.rcParams["savefig.bbox"] = 'tight'

In [2]:
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(148),
        transforms.Resize(64),
        transforms.ToTensor()
    ]
)

root_dir = '/media/mountHDD2/data'

trainset = CelebA(root=root_dir, split='train', download=True, transform=transform)
train_dl = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=24)
validset = CelebA(root=root_dir, split='valid', download=True, transform=transform)
valid_dl = DataLoader(validset, batch_size=64, shuffle=False, num_workers=24)
testset = CelebA(root=root_dir, split='test', download=True, transform=transform)
test_dl = DataLoader(testset, batch_size=64, shuffle=False, num_workers=24)

print(len(trainset), len(validset), len(testset))
print(len(train_dl), len(valid_dl), len(test_dl))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
162770 19867 19962
2544 311 312


In [3]:
trainset.attr_names[:5]

['5_o_Clock_Shadow',
 'Arched_Eyebrows',
 'Attractive',
 'Bags_Under_Eyes',
 'Bald']

In [4]:
class CustomCeleb(CelebA):
    def __init__(self, root=root_dir, split='train', download=True, transform=transform):
        super().__init__(root=root_dir, split=split, download=download, transform=transform,
                        target_type = ['attr', 'identity'])
        self.attr_names = self.attr_names[:-1]

    def __len__(self):
        return super().__len__()

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)

        attr = target[0]
        identity = target[1]

        target = {
            "rec" : img,
            "identity" : identity
        }

        attr_dict = {
            f"attr_{self.attr_names[idx]}" : attr[idx] for idx in range(attr.shape[0])
        }

        target.update(attr_dict)

        return img, target

In [5]:
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(148),
        transforms.Resize(64),
        transforms.ToTensor()
    ]
)

root_dir = '/media/mountHDD2/data'

trainset = CustomCeleb(root=root_dir, split='train', download=True, transform=transform)
train_dl = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=24)
validset = CustomCeleb(root=root_dir, split='valid', download=True, transform=transform)
valid_dl = DataLoader(validset, batch_size=64, shuffle=False, num_workers=24)
testset = CustomCeleb(root=root_dir, split='test', download=True, transform=transform)
test_dl = DataLoader(testset, batch_size=64, shuffle=False, num_workers=24)

print(len(trainset), len(validset), len(testset))
print(len(train_dl), len(valid_dl), len(test_dl))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
162770 19867 19962
2544 311 312


In [6]:
sample_img, sample_target = trainset[0]

print(sample_img.shape)
print(sample_target)

torch.Size([3, 64, 64])
{'rec': tensor([[[0.6353, 0.4471, 0.4118,  ..., 0.9882, 0.9843, 0.9843],
         [0.6471, 0.5020, 0.3882,  ..., 0.9882, 0.9843, 0.9882],
         [0.6706, 0.4980, 0.4980,  ..., 0.9882, 0.9882, 0.9882],
         ...,
         [0.5373, 0.2549, 0.2510,  ..., 0.5765, 0.5490, 0.5333],
         [0.5490, 0.2275, 0.2471,  ..., 0.6275, 0.5765, 0.5294],
         [0.5137, 0.2627, 0.2902,  ..., 0.6078, 0.5804, 0.5137]],

        [[0.4824, 0.2549, 0.2196,  ..., 0.9373, 0.9333, 0.9255],
         [0.5059, 0.3098, 0.1922,  ..., 0.9373, 0.9333, 0.9373],
         [0.5216, 0.3137, 0.3098,  ..., 0.9373, 0.9373, 0.9333],
         ...,
         [0.3608, 0.1216, 0.0863,  ..., 0.3490, 0.3137, 0.3059],
         [0.3686, 0.0980, 0.0706,  ..., 0.3843, 0.3373, 0.3020],
         [0.3373, 0.1216, 0.0980,  ..., 0.3569, 0.3412, 0.2863]],

        [[0.3255, 0.1373, 0.1176,  ..., 0.8039, 0.8000, 0.7961],
         [0.3137, 0.1882, 0.1020,  ..., 0.8078, 0.8039, 0.8078],
         [0.3373, 0.1922, 

In [7]:
for img, target in train_dl:
    print(img.shape)
    print(target["rec"].shape)
    for key in target:
        if key == "rec":
            continue
        else:
            print(target[key].shape)
    break

torch.Size([64, 3, 64, 64])
torch.Size([64, 3, 64, 64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


In [8]:
for img, target in tqdm(train_dl):
    pass
for img, target in tqdm(valid_dl):
    pass
for img, target in tqdm(test_dl):
    pass

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2544/2544 [07:13<00:00,  5.87it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 311/311 [00:04<00:00, 67.10it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 312/312 [00:04<00:00, 69.53it/s]
