## Using PyTorch Dataset Loading Utilities for Custom Datasets (Face Images from CelebA)

This notebook provides an example of how to load a custom image dataset, stored as individual JPEG files with labels in a separate text file, using PyTorch's data loading utilities. For a more in-depth discussion, please see the official

- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)
- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation

In this example, we use the CelebA face image dataset, which is available at http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. To run the following examples, download the "img_align_celeba.zip" file (1.34 GB) from the website and unzip it into the directory where this notebook is located. Similarly, download the attribute list "list_attr_celeba.txt" (25.48 MB) into the same directory.

In [None]:
import pandas as pd
import numpy as np
import os

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image

### Preprocessing the Dataset
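
The cells in this section operate on a DataFrame named `df`, indexed by image file name, with a non-negative integer `Male` column. A minimal sketch of how such a frame could be built from the attribute list follows. It assumes that `list_attr_celeba.txt` starts with a line holding the image count, followed by a header line of attribute names and one row per image with -1/1 flags. The in-memory `sample` is a hypothetical stand-in so the snippet runs without the download; it is not real CelebA data.

```python
import io

import pandas as pd

# Hypothetical three-image stand-in for 'list_attr_celeba.txt'
# (assumed layout: count line, header line, then one row per image).
sample = io.StringIO(
    "3\n"
    "Male Smiling\n"
    "000001.jpg -1  1\n"
    "000002.jpg  1 -1\n"
    "000003.jpg  1  1\n"
)

# Skip the count line. Because the header has one field fewer than the
# data rows, pandas uses the first column (file names) as the index.
df = pd.read_csv(sample, sep=r"\s+", skiprows=1)
df = df[['Male']]

# Recode -1 as 0 so that the labels are 0 (female) and 1 (male);
# np.bincount only accepts non-negative integers.
df.loc[df['Male'] == -1, 'Male'] = 0

print(df)
```

With the real file, `sample` would be replaced by the path `'list_attr_celeba.txt'`.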

Split the dataset into a training set (80%) and a test set (20%), and save the image names and their corresponding labels to separate text files.

In [None]:
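# `df` is assumed to be a DataFrame indexed by image file name
# with an integer 'Male' label column (0 = female, 1 = male)
num_train_examples = int(df.shape[0]*0.8)  # first 80% of the rows
df_train = df.iloc[:num_train_examples, :]
df_test = df.iloc[num_train_examples:, :]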

df_train.to_csv('celeba_gender_attr_train.txt', sep=" ")
df_test.to_csv('celeba_gender_attr_test.txt', sep=" ")

In [None]:
# np.bincount counts label 0 (female) first, then label 1 (male)
print('Number of female and male images in training dataset:')
np.bincount(df_train['Male'].values)

In [None]:
# np.bincount counts label 0 (female) first, then label 1 (male)
print('Number of female and male images in test dataset:')
np.bincount(df_test['Male'].values)

In [None]:
pd.read_csv("celeba_gender_attr_train.txt", nrows=3, sep=" ", index_col=0)

In [None]:
img = Image.open('img_align_celeba/000001.jpg')
print(np.asarray(img, dtype=np.uint8).shape)
plt.imshow(img)

### Custom Dataset



In [None]:
class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images"""

    def __init__(self, txt_path, img_dir, transform=None):
    
        df = pd.read_csv(txt_path, sep=" ", index_col=0)
        self.img_dir = img_dir
        self.txt_path = txt_path
        self.img_names = df.index.values
        self.y = df['Male'].values
        self.transform = transform

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_dir,
                                      self.img_names[index]))
        
        if self.transform is not None:
            img = self.transform(img)
        
        label = self.y[index]
        return img, label

    def __len__(self):
        return self.y.shape[0]


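#### Transforms
1. Convert the images to grayscale.
2. Convert the images into PyTorch tensors. `transforms.ToTensor()` also scales the pixel values to the [0, 1] range (equivalent to dividing by 255), so no separate normalization step is needed.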

In [None]:
custom_transform = transforms.Compose([transforms.Grayscale(),
                                       # ToTensor already scales pixel values to [0, 1],
                                       # so an explicit division by 255 is not needed:
                                       # transforms.Lambda(lambda x: x/255.),
                                       transforms.ToTensor()])

train_dataset = CelebaDataset(txt_path='celeba_gender_attr_train.txt',
                              img_dir='img_align_celeba/',
                              transform=custom_transform)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=128,
                          shuffle=True,
                          num_workers=4)

### Iterator

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

num_epochs = 2
for epoch in range(num_epochs):

    for batch_idx, (x, y) in enumerate(train_loader):
        
        print('Epoch:', epoch+1, end='')
        print(' | Batch index:', batch_idx, end='')
        print(' | Batch size:', y.size()[0])
        
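        # move the batch to the selected device
        x = x.to(device)
        y = y.to(device)
        # only the first batch of each epoch is needed for this demonstration
        break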
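
To see what a full pass over a `DataLoader` yields without the CelebA files, here is a minimal sketch on synthetic tensors. The `TensorDataset`, the tensor shapes, and the batch size of 4 are illustrative assumptions and not part of this notebook's pipeline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 10 single-channel 8x8 "images" with 0/1 labels.
images = torch.rand(10, 1, 8, 8)
labels = torch.randint(0, 2, (10,))

loader = DataLoader(TensorDataset(images, labels),
                    batch_size=4,
                    shuffle=True)

batch_sizes = []
for batch_idx, (x, y) in enumerate(loader):
    batch_sizes.append(y.size(0))
    print('Batch index:', batch_idx,
          '| x shape:', tuple(x.shape),
          '| Batch size:', y.size(0))
```

With 10 examples and `batch_size=4`, the last batch holds only the remaining 2 examples; passing `drop_last=True` to `DataLoader` would discard it.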

In [None]:
x.shape

In [None]:
one_image = x[0].permute(1, 2, 0)
one_image.shape
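
The `permute` call reorders the axes from PyTorch's channels-first layout (C, H, W) to the channels-last layout (H, W, C). A minimal sketch with a zero-filled stand-in batch, assuming grayscale images of height 218 and width 178 (the size of the aligned CelebA images):

```python
import torch

# Hypothetical stand-in for a batch from the loader: (batch, channels, height, width).
x = torch.zeros(2, 1, 218, 178)

one_image = x[0].permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
print(one_image.shape)

# For a single-channel image, squeeze() drops the trailing axis of size 1,
# leaving a 2D array that plt.imshow can display directly.
print(one_image.squeeze().shape)
```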

In [None]:
# note that imshow also works fine with scaled
# images in [0, 1] range.

plt.imshow(one_image.to(torch.device('cpu')).squeeze(), cmap='gray');