In [5]:
import os
from collections import defaultdict
import random
import json

In [6]:
# Root folder of the raw CelebA download. It must contain three entries:
#   identity_CelebA.txt   img_align_celeba   list_attr_celeba.txt
# img_align_celeba is the unzipped folder from img_align_celeba.zip (containing jpeg files).
# Set the CELEBA_ROOT environment variable to point at your own copy of the data;
# when it is unset, the path below is used as the fallback.
celeba_root = os.environ.get('CELEBA_ROOT', '/home/oscarli/projects/leaf/data/celeba/data/raw/')

In [7]:
# Parse list_attr_celeba.txt into two lookup tables:
#   attribute_to_index: attribute name -> position in each image's attribute list
#   image_to_attribute: image file name -> list of integer attribute values
with open(os.path.join(celeba_root, 'list_attr_celeba.txt'), 'r') as file:
    file_iter = iter(file)
    _ = next(file_iter)  # throw away the first line (the number of images)
    attribute_names = next(file_iter)  # second line: whitespace-separated attribute names
    header_fields = attribute_names.split()
    # get the attribute name and construct a mapping from the attribute name to a unique location index
    attribute_to_index = {attribute: i for i, attribute in enumerate(header_fields)}
    # every data row is the image name followed by one value per header attribute
    expected_num_fields = len(header_fields) + 1

    image_to_attribute = {}
    # every remaining line is the name and attribute values for a specific image;
    # the two header lines were consumed above, so data rows start on line 3
    for line_number, line in enumerate(file_iter, start=3):
        sample = line.split()
        if not sample:
            # a blank line (e.g. at the end of the file) carries no data
            continue
        if len(sample) != expected_num_fields:
            raise RuntimeError(
                f"line {line_number} of list_attr_celeba.txt has {len(sample) - 1} attribute values, "
                f"expected {len(header_fields)}"
            )
        image_name = sample[0]
        image_to_attribute[image_name] = [int(value) for value in sample[1:]]

In [8]:
# report how many images were parsed from the attribute file
num_images = len(image_to_attribute)
print(f'the total number of images is {num_images}')

the total number of images is 202599


In [9]:
# define a labelling function based on the attributes of interest
def labelling_function(x, attribute='Smiling', index_map=None):
    """Map one image's attribute list to an integer class label.

    Parameters
    ----------
    x : sequence of int
        Attribute values of a single image, indexed by the positions stored
        in the attribute-name-to-index mapping.
    attribute : str, optional
        Name of the attribute that defines the label. Defaults to 'Smiling'.
    index_map : dict, optional
        Mapping from attribute name to its position in ``x``. When omitted,
        the notebook-level ``attribute_to_index`` is used.

    Returns
    -------
    int
        ``int((value + 1) / 2)`` of the selected attribute value, so that
        -1 becomes 0 and 1 becomes 1.

    Raises
    ------
    KeyError
        If ``attribute`` is not a key of the mapping.
    """
    if index_map is None:
        # resolved at call time so the function picks up the mapping built
        # from the attribute file header
        index_map = attribute_to_index
    return int((x[index_map[attribute]] + 1) / 2)

In [10]:
# sanity-check the labelling function on one known image
sample_image_name = '000019.jpg'
sample_label = labelling_function(image_to_attribute[sample_image_name])
print(f"the label for image '{sample_image_name}' is {sample_label}")

the label for image '000019.jpg' is 0


In [13]:
def _new_class_to_imagepathlist():
    """Create the per-client mapping: class label -> list of image paths."""
    return defaultdict(list)


# client id -> class label -> list of image paths; missing entries are created on first access
client_to_class_to_imagepathlist = defaultdict(_new_class_to_imagepathlist)

In [14]:
# Group every image path by celebrity (client) id and then by class label.
# The mapping is rebuilt from scratch here so that re-running this cell does
# not append the same image paths a second time.
client_to_class_to_imagepathlist = defaultdict(lambda: defaultdict(list))
with open(os.path.join(celeba_root, 'identity_CelebA.txt'), 'r') as file:
    for line in file:
        fields = line.split()
        if not fields:
            # a blank line carries no image/identity pair
            continue
        # every line is an image name and the corresponding celebrity's id
        image_name, celeb_id = fields
        label = labelling_function(image_to_attribute[image_name])
        full_image_path = os.path.join(celeba_root, 'img_align_celeba', image_name)
        client_to_class_to_imagepathlist[celeb_id][label].append(full_image_path)

In [15]:
# inspect the class label -> image path lists collected for client (celebrity id) '2880'
client_to_class_to_imagepathlist['2880']

defaultdict(list,
            {1: ['/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/000001.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/000404.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/003415.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/018062.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/025244.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/047978.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/049142.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/052623.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/053184.jpg',
              '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/053311.jpg',
              '/home/oscarli/projects/l

In [46]:
# Keep only the clients that have at least min_num_examples_per_class images
# for every single class in all_cl_list (a client cannot have a missing class).
min_num_examples_per_class = 3
all_cl_list = [0, 1] # this depends on the labelling function and should be changed accordingly
total_client_list = [
    client_id
    for client_id, class_to_imagepathlist in client_to_class_to_imagepathlist.items()
    # .get() is used instead of indexing: indexing the inner defaultdict would
    # insert an empty list for every class the client does not have
    if all(
        len(class_to_imagepathlist.get(cl, ())) >= min_num_examples_per_class
        for cl in all_cl_list
    )
]

In [47]:
# report how many clients passed the per-class minimum filter
num_eligible_clients = len(total_client_list)
print(f'the total number of clients with at least {min_num_examples_per_class} for each class is {num_eligible_clients}')

the total number of clients with at least 3 for each class is 7142


In [48]:
# put the client ids into sorted order (a new list is bound to the name)
total_client_list = list(total_client_list)
total_client_list.sort()

In [49]:
# Shuffle the clients reproducibly. The list is sorted first so that the result
# depends only on the seed and on the set of client ids, not on the order the
# list happens to be in: without the sort, re-running this cell on its own
# would shuffle an already shuffled list and produce a different split.
total_client_list = sorted(total_client_list)
random.seed(a=42)
random.shuffle(total_client_list)

In [50]:
# peek at the first five client ids of the shuffled list
total_client_list[:5]

['2584', '6858', '2219', '8424', '8644']

In [51]:
# 60% of the clients, truncated to an integer, are used for the first (train) slice
num_train = int(0.6 * len(total_client_list))

In [52]:
# 20% of the clients, truncated to an integer, are used for the validation slice
num_val = int(0.2 * len(total_client_list))

In [53]:
# carve the client list into three consecutive, non-overlapping slices;
# the last slice takes whatever remains after the train and validation slices
val_end = num_train + num_val
base_client_list = total_client_list[:num_train]
val_client_list = total_client_list[num_train:val_end]
novel_client_list = total_client_list[val_end:]

In [54]:
# output file name -> list of client ids that belong in that file
json_name_to_client_list = dict([
    ('base.json', base_client_list),
    ('val.json', val_client_list),
    ('novel.json', novel_client_list),
])

In [59]:
# report the number of clients in each of the three splits
for split_name, split_client_list in (
    ('meta-train', base_client_list),
    ('meta-val', val_client_list),
    ('meta-test', novel_client_list),
):
    print(f'{split_name} client number {len(split_client_list)}')

meta-train client number 4285
meta-val client number 1428
meta-test client number 1429


In [60]:
# Write one JSON file per split; the file names are relative, so the files are
# created in the current working directory. Each file maps a client id to that
# client's class label -> image path list mapping.
for json_name, client_list in json_name_to_client_list.items():
    split_payload = {
        client_id: client_to_class_to_imagepathlist[client_id]
        for client_id in client_list
    }
    with open(json_name, 'w') as f:
        json.dump(split_payload, f)


In [62]:
from PIL import Image

In [63]:
def load_image(image_path):
    """Load the image stored at ``image_path`` and return it in RGB mode.

    Parameters
    ----------
    image_path : str
        Path of the image file to open.

    Returns
    -------
    PIL.Image.Image
        The result of ``convert('RGB')`` on the opened image.
    """
    # the context manager closes the underlying file as soon as the converted
    # image has been produced, even if the conversion raises
    with Image.open(image_path) as source_img:
        img = source_img.convert('RGB')
    return img

In [66]:
# inspect the image paths collected under class label 0 for client '2'
client_to_class_to_imagepathlist['2'][0]

['/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/016188.jpg',
 '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/051523.jpg',
 '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/111618.jpg',
 '/home/oscarli/projects/leaf/data/celeba/data/raw/img_align_celeba/112468.jpg']

In [81]:
# load the image at index 5 of client '1''s class-0 list and display its .size attribute
load_image(client_to_class_to_imagepathlist['1'][0][5]).size

(178, 218)