In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset
import csv
import os
from collections import namedtuple
import numpy as np
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg, download_file_from_google_drive, extract_archive
from torchvision.datasets.vision import VisionDataset
CSV = namedtuple("CSV", ["header", "index", "data"])

In [2]:
dataroot = '/home/fusang/Desktop/mcr2/data/celebA/celeba'
image_size = 32

In [3]:
def _load_csv(
    filename: str,
    header = None,
):
    with open(os.path.join(dataroot, filename)) as csv_file:
        data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

    if header is not None:
        headers = data[header]
        data = data[header + 1 :]
    else:
        headers = []

    indices = [row[0] for row in data]
    data = [row[1:] for row in data]
    data_int = [list(map(int, i)) for i in data]

    return CSV(headers, indices, torch.tensor(data_int))

In [4]:
split_map = {
    "train": 0,
    "valid": 1,
    "test": 2,
    "all": None,
}
split = 'train'
split_ = split_map[verify_str_arg(split.lower(), "   ", ("train", "valid", "test", "all"))]
# print("using data split:",split_)
splits = _load_csv("list_eval_partition.txt")
# print(splits.data.shape)
identity = _load_csv("identity_CelebA.txt")
# bbox = _load_csv("list_bbox_celeba.txt", header=1)
# landmarks_align = _load_csv("list_landmarks_align_celeba.txt", header=1)
attr = _load_csv("list_attr_celeba.txt", header=1)

In [5]:
mask1 = slice(None) if split_ is None else (splits.data == split_).squeeze() # mask for train valid and test data

In [6]:
identity = identity.data[:]
# bbox = bbox.data[:]
# andmarks_align = landmarks_align.data[:]
attr_names = attr.header
attr = attr.data[:]
attr = torch.div(attr + 1, 2, rounding_mode="floor") # map from {-1, 1} to {0, 1}

classes = np.array([19, 31, 34])
attr = attr.cpu().detach().numpy()
attr = attr[:, classes]
num_attrs = int(len(classes))
C = np.array([2 ** x for x in range(num_attrs)]).reshape(num_attrs,1)
class_list = np.dot(attr, C)
# print("class_list", class_list.shape)
targets_ = np.squeeze(class_list).tolist()

In [7]:
print("showing class information")
attr_names = np.array(attr_names)[classes]
print(attr_names)
selected_labels = [0,1,2,3,4,5,6,7]
num_imgs_per_class = 100000
selected_pos = np.array([])
mask2 = np.zeros(splits.data.shape[0],dtype=bool) # mask of label classes
print(f'selected label:{selected_labels}')
print(f'maxnumber of imgs per class:{num_imgs_per_class}')

print("##################################################")
print("statistic information for the whole celebA dataset")
for i in selected_labels:
    temp = class_list == i
    print(f"class {i}: number {sum(temp)} before masking")
    pos_temp,_ = np.where(class_list == i)
    pos_temp = np.array(pos_temp)
    np.random.shuffle(pos_temp)
    selected_pos = np.concatenate((selected_pos, pos_temp[:num_imgs_per_class]))
print("##################################################")

showing class information
['High_Cheekbones' 'Smiling' 'Wearing_Earrings']
selected label:[0, 1, 2, 3, 4, 5, 6, 7]
maxnumber of imgs per class:100000
##################################################
statistic information for the whole celebA dataset
class 0: number [81666] before masking
class 1: number [10187] before masking
class 2: number [17141] before masking
class 3: number [55329] before masking
class 4: number [9974] before masking
class 5: number [3103] before masking
class 6: number [1629] before masking
class 7: number [23570] before masking
##################################################


In [9]:
def find_subset_mask(remove_label, class_list, attr_names, split):
    split_map = {
    "train": 0,
    "valid": 1,
    "test": 2,
    "all": None,
    }
    split = split
    split_ = split_map[verify_str_arg(split.lower(), "   ", ("train", "valid", "test", "all"))]
    splits = _load_csv("list_eval_partition.txt")

    mask1 = slice(None) if split_ is None else (splits.data == split_).squeeze() # mask for train valid and test data
    
    selected_labels = [0,1,2,3,4,5,6,7]
    selected_labels.remove(remove_label)
    num_imgs_per_class = 2000
    selected_pos = np.array([])
    mask2 = np.zeros(splits.data.shape[0],dtype=bool) # mask of label classes
    print(f'selected label:{selected_labels}')
    print(f'maxnumber of imgs per class:{num_imgs_per_class}')

    print("##################################################")
    print("statistic information for the whole celebA dataset")
    for i in selected_labels:
        temp = class_list == i
        print(f"class {i}: number {sum(temp)} before masking")
        pos_temp,_ = np.where(class_list == i)
        pos_temp = np.array(pos_temp)
        np.random.shuffle(pos_temp)
        selected_pos = np.concatenate((selected_pos, pos_temp[:num_imgs_per_class]))
    print("##################################################")

    selected_pos = np.array(selected_pos,dtype=int)
    mask2[selected_pos] = True
    # mask2 = torch.from_numpy(mask2)

    mask = mask1 * mask2
    filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
    targets = [targets_[i] for i in torch.squeeze(torch.nonzero(mask))]

    print("##################################################")
    print("statistic information for the CUNSTOM dataset")
    # show class information
    print("showing class information")
    attr_names = np.array(attr_names)[classes]
    print(attr_names)

    class_list = np.array(targets, dtype=int)
    for i in range(2**(len(classes))):
        temp = class_list == i
        # pos_temp,_ = np.where(class_list == i)
        # pos_temp = np.array(pos_temp)
        print(f"class {i}: number {sum(temp)}")
    print("##################################################")
    np.save(f'celeba_no{remove_label}_cls_{split}.npy',mask)
    print(f'mask saving to celeba_no{remove_label}_cls_{split_}.npy')

In [10]:
for i in range(8):
    for split in ['train','valid','test']:
        find_subset_mask(i, class_list, attr_names, split)


selected label:[1, 2, 3, 4, 5, 6, 7]
maxnumber of imgs per class:2000
##################################################
statistic information for the whole celebA dataset
class 1: number [10187] before masking
class 2: number [17141] before masking
class 3: number [55329] before masking
class 4: number [9974] before masking
class 5: number [3103] before masking
class 6: number [1629] before masking
class 7: number [23570] before masking
##################################################
##################################################
statistic information for the CUNSTOM dataset
showing class information


IndexError: index 19 is out of bounds for axis 0 with size 3

In [None]:
import imageio
from utils.image_transform import NumpyResize, pil_loader

def select_subset(mask_cls, inputPath, outputPath, maxNumber):
    splits = _load_csv("/home/fusang/Desktop/pytorch_GAN_zoo/data/list_eval_partition.txt")
    # mask = np.load(mask_npy)
    # mask = np.array(mask_cls, dtype=bool)
    # mask = torch.from_numpy(mask)
    imgList = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask_cls))]
    numImgs = len(imgList)
    print('Number of Images:', numImgs)

    if not os.path.isdir(outputPath):
        os.mkdir(outputPath)

    for index, item in enumerate(imgList[:maxNumber]):
        path = os.path.join(inputPath, item)
        img = np.array(pil_loader(path))
        path = os.path.join(outputPath, item)
        imageio.imwrite(path, img)
    print("Finished saving subdataset to", outputPath)

In [None]:
def create_cls_dataset(class_list, attr_names, split, inputPath, outputPath):
    if not os.path.isdir(outputPath):
        os.mkdir(outputPath)
    if not os.path.isdir(os.path.join(outputPath,split)):
        os.mkdir(os.path.join(outputPath,split))
    split_map = {
    "train": 0,
    "valid": 1,
    "test": 2,
    "all": None,
    }
    split = split
    split_ = split_map[verify_str_arg(split.lower(), " ", ("train", "valid", "test", "all"))]
    splits = _load_csv("list_eval_partition.txt")

    mask1 = slice(None) if split_ is None else (splits.data == split_).squeeze() # mask for train valid and test data
    
    selected_labels = [0,1,2,3,4,5,6,7]
    num_imgs_per_class = 10000
    selected_pos = np.array([])
    # mask2 = np.zeros(splits.data.shape[0],dtype=bool) # mask of label classes
    print(f'selected label:{selected_labels}')
    print(f'maxnumber of imgs per class:{num_imgs_per_class}')

    print("##################################################")
    print("statistic information for the whole celebA dataset")
    for i in selected_labels:
        temp = class_list == i
        temp = torch.from_numpy(np.squeeze(temp))
        # print(temp.shape)
        # print(mask1.shape)
        print(f"class {i}: number {sum(temp)} before masking")
        mask = mask1*temp
        print(mask.shape)
        ouputPathTemp= os.path.join(outputPath, split, str(i))
        print(outputPath)
        select_subset(mask ,inputPath, ouputPathTemp, num_imgs_per_class)
    print("##################################################")

In [None]:
inputPath = "data/celebA/celeba/img_align_celeba"
outputPath = 'celebA_attrs1_cls'
create_cls_dataset(class_list, attr_names, 'train', inputPath, outputPath)
create_cls_dataset(class_list, attr_names, 'valid', inputPath, outputPath)
create_cls_dataset(class_list, attr_names, 'test', inputPath, outputPath)

0: ['Black_Hair' 'Eyeglasses' 'Male'] 

# DRAFT