In [1]:
import json
import os
from tqdm import tqdm
from collections import defaultdict
import numpy as np

In [2]:
def relabel_class(c):
    '''
    maps hexadecimal class value (string) to a decimal number

    c is the hex code of the class's character, e.g. "30" -> '0', "4a" -> 'J',
    "7a" -> 'z' (upper- or lowercase hex digits are both accepted).

    returns:
    - 0 through 9 for classes representing respective numbers
    - 10 through 35 for classes representing respective uppercase letters
    - 36 through 61 for classes representing respective lowercase letters

    raises:
    - ValueError if c is not a hex string or does not encode an ASCII digit or
      letter. The value is parsed once as hex and range-checked; mixing decimal
      and hex parsing would silently map such codes onto valid, colliding
      labels (e.g. "40" / '@' -> 9, "5b" / '[' -> 30).
    '''
    code = int(c, 16)  # ValueError for non-hex input
    if ord('0') <= code <= ord('9'):
        return code - ord('0')        # 0..9
    if ord('A') <= code <= ord('Z'):
        return code - ord('A') + 10   # 10..35
    if ord('a') <= code <= ord('z'):
        return code - ord('a') + 36   # 36..61
    raise ValueError(f"class hex {c!r} does not encode an ASCII digit or letter")

def add_class_images(class_root_path, user_to_class_to_imagepath):
    """
    Add every image of one class to the user -> class -> imagepath mapping.

    class_root_path: root of a class directory. Its basename is the class hex
                     value (converted with relabel_class). It holds `<name>.mit`
                     metadata files, and the images listed in each one live in
                     the sibling directory `<name>/`.
    user_to_class_to_imagepath: defaultdict(lambda: defaultdict(list))
                                a dictionary mapping user->class->imagepath;
                                mutated in place.

    Each .mit file is read as a first line holding a count (not used here),
    followed by one whitespace-separated line per image:
    "<image_fname> <user_id>/<...>". Blank lines are skipped.

    Directory entries are visited in sorted order. The per-class path lists
    therefore do not depend on the arbitrary order of os.listdir, which keeps
    any seeded shuffle applied to them later reproducible.
    """
    class_hex = os.path.basename(class_root_path)
    class_label = relabel_class(class_hex)
    print(f"Reading class hex {class_hex}, char {chr(int(class_hex, 16))}")

    for hsf_fname in sorted(os.listdir(class_root_path)):
        # Only the .mit metadata files describe images. Match the extension,
        # not any name that merely contains "mit".
        hsf_dir, ext = os.path.splitext(hsf_fname)
        if ext != '.mit':
            continue

        with open(os.path.join(class_root_path, hsf_fname)) as f:
            class_images_details = [line.split() for line in f]

        # The first line only contains the count; skip it (an empty file yields
        # no rows). Blank lines would otherwise fail the 2-field unpack below.
        for fields in class_images_details[1:]:
            if not fields:
                continue
            image_fname, user_info = fields
            user_id = user_info.split("/")[0]

            # Images sit in a directory named after the .mit file, minus its extension.
            full_image_path = os.path.join(class_root_path, hsf_dir, image_fname)

            # add data to class/user dict
            user_to_class_to_imagepath[user_id][class_label].append(full_image_path)

In [3]:
# Root of the FEMNIST data directory. Set the FEMNIST_ROOT environment variable
# to point elsewhere instead of editing this machine-specific absolute path.
femnist_root = os.environ.get('FEMNIST_ROOT', '/home/oscarli/projects/leaf/data/femnist/')
user_to_class_to_imagepath = defaultdict(lambda: defaultdict(list))
classes_home = os.path.join(femnist_root, "data/raw_data/by_class")
for class_name in sorted(os.listdir(classes_home)):
    class_root_path = os.path.join(classes_home, class_name)
    # Only class sub-directories are expected here. Skip stray files, whose
    # names relabel_class could not parse as a class hex value.
    if not os.path.isdir(class_root_path):
        continue
    add_class_images(class_root_path, user_to_class_to_imagepath)

Reading class hex 30, char 0
Reading class hex 31, char 1
Reading class hex 32, char 2
Reading class hex 33, char 3
Reading class hex 34, char 4
Reading class hex 35, char 5
Reading class hex 36, char 6
Reading class hex 37, char 7
Reading class hex 38, char 8
Reading class hex 39, char 9
Reading class hex 41, char A
Reading class hex 42, char B
Reading class hex 43, char C
Reading class hex 44, char D
Reading class hex 45, char E
Reading class hex 46, char F
Reading class hex 47, char G
Reading class hex 48, char H
Reading class hex 49, char I
Reading class hex 4a, char J
Reading class hex 4b, char K
Reading class hex 4c, char L
Reading class hex 4d, char M
Reading class hex 4e, char N
Reading class hex 4f, char O
Reading class hex 50, char P
Reading class hex 51, char Q
Reading class hex 52, char R
Reading class hex 53, char S
Reading class hex 54, char T
Reading class hex 55, char U
Reading class hex 56, char V
Reading class hex 57, char W
Reading class hex 58, char X
Reading class 

In [4]:
# number of distinct users (writers) found in the raw data
len(user_to_class_to_imagepath)

3597

In [5]:
per_class_minimum = 2 # every class has to have at least this number of examples
per_user_minimum = 2 # every user has to have at least this number of classes
n_shot = 1 # support examples per (user, class); all remaining examples become queries

# A class with exactly per_class_minimum examples must still leave at least one query example.
assert 0 < n_shot < per_class_minimum, "n_shot must leave at least one query example per class"

In [6]:
user_to_numcl = dict()  # user -> number of classes left after the per-class filter

In [7]:
# Drop (in place) every class a user wrote fewer than per_class_minimum times,
# then record how many classes that user has left.
for user_id, class_map in user_to_class_to_imagepath.items():
    sparse_labels = [label for label, paths in class_map.items()
                     if len(paths) < per_class_minimum]
    for label in sparse_labels:
        del class_map[label]
    user_to_numcl[user_id] = len(class_map)

In [8]:
n_users_total = len(user_to_class_to_imagepath)  # the class filter above never removes users
print(n_users_total)

3597


In [9]:
# average number of classes per user that survived the per-class filter
np.mean(list(user_to_numcl.values()))

21.368084514873505

In [10]:
# Remove users left with fewer than per_user_minimum classes.
# pop(..., None) keeps this cell re-runnable. A plain `del` raised KeyError on a
# second run, because user_to_numcl still lists the already-removed users.
for user, numcl in user_to_numcl.items():
    if numcl < per_user_minimum:
        user_to_class_to_imagepath.pop(user, None)

In [11]:
n_users_kept = len(user_to_class_to_imagepath)  # users with at least per_user_minimum classes
print(n_users_kept)

3585


In [12]:
import random

In [13]:
random.seed(100)  # seeds the global RNG used by random.shuffle in the split cell below
# Writers are kept in sorted id order. The shuffle (`random.shuffle(all_users)`)
# is disabled, so the train/val/test split is an ordered slice of writer ids.
# NOTE(review): confirm that an id-ordered (non-random) user split is intended.
all_users = sorted(user_to_class_to_imagepath)

In [14]:
# user -> class label -> 'support' / 'query' -> list of image paths
user_to_class_to_sq_to_imagepath = defaultdict(
    lambda: defaultdict(
        lambda: defaultdict(list)
    )
)

In [15]:
# For each (user, class), take n_shot images as support and all remaining images
# as query, after a shuffle driven by the global RNG seeded in an earlier cell.
# - Shuffle a copy, so user_to_class_to_imagepath is not reordered in place.
#   The first-run permutation is identical, since shuffle only depends on the
#   RNG state and the list length.
# - Assign instead of extend, so re-running this cell does not append duplicate
#   entries to lists filled by a previous run. (A re-run without re-seeding
#   still draws a different split.)
for user in all_users:
    for cl in sorted(user_to_class_to_imagepath[user]):
        chosen_examples = list(user_to_class_to_imagepath[user][cl])
        random.shuffle(chosen_examples)
        user_to_class_to_sq_to_imagepath[user][cl]['support'] = chosen_examples[:n_shot]
        user_to_class_to_sq_to_imagepath[user][cl]['query'] = chosen_examples[n_shot:]

In [16]:
# Inspect one writer's split. .get avoids silently inserting an empty entry into
# the defaultdict if this writer id is absent (the cell then shows nothing).
user_to_class_to_sq_to_imagepath.get('f0620_49')

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {0: defaultdict(list,
                         {'support': ['/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01319.png'],
                          'query': ['/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01322.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01327.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01320.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01317.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01318.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/30/hsf_1/hsf_1_01328.png',
                           '/home/oscarli/projects

In [18]:
# NOTE: "{per_class_minimum - n_shot}query" is the *minimum* number of query
# examples per class. Classes with more images use all the rest as queries.
folder_name = f'fixedsq_atleast{per_user_minimum}class{n_shot}shot{per_class_minimum - n_shot}query_split'
# exist_ok keeps the cell re-runnable (os.mkdir raised FileExistsError on a
# second run). The JSON files written into it later are opened with 'w' and overwritten.
os.makedirs(folder_name, exist_ok=True)

In [19]:
# Split writers (not images) into train / val / test by position in all_users.
# cut_off holds *cumulative* fractions: the first 70% of users are train, users
# between 70% and 85% are validation, and the remaining 15% are test.
cut_off = 0.7, 0.85
n_users = len(all_users)  # slice bounds are derived from the list being sliced
train_end = int(n_users * cut_off[0])
val_end = int(n_users * cut_off[1])
train_users = all_users[:train_end]
val_users = all_users[train_end:val_end]
test_users = all_users[val_end:]

In [20]:
# sizes of the train, val and test user splits, one per line
for split_users in (train_users, val_users, test_users):
    print(len(split_users))

2509
538
538


In [21]:
# One JSON file per split: user -> class label -> {'support', 'query'} -> image paths.
# json.dump turns the integer class labels into string keys ("9", "35", ...).
split_to_users = {"base": train_users, "val": val_users, "novel": test_users}
for split_name, split_users in split_to_users.items():
    with open(f"{folder_name}/{split_name}.json", 'w') as f:
        json.dump({user: user_to_class_to_sq_to_imagepath[user] for user in split_users}, f)

In [24]:
first_train_user = train_users[0]
user_to_class_to_sq_to_imagepath[first_train_user]  # sanity-check one training user's split

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {9: defaultdict(list,
                         {'support': ['/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/39/hsf_1/hsf_1_01252.png'],
                          'query': ['/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/39/hsf_1/hsf_1_01261.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/39/hsf_1/hsf_1_01259.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/39/hsf_1/hsf_1_01260.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/39/hsf_1/hsf_1_01258.png',
                           '/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/39/hsf_1/hsf_1_01256.png']}),
             5: defaultdict(list,
                         {'support': ['/home/oscarli/projects/leaf/data/femnist/data/raw_data/by_class/35/hsf_1/hsf_1_01191.png'],
 

################################

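Checking the raw data statistics. Note: the cells in this section have execution counts (In [16], In [18]) that repeat earlier ones, so they were run in a different kernel session. Their outputs may not reflect the filtering performed above.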

In [16]:
# the writers with the smallest number of classes
sorted(map(len, user_to_class_to_imagepath.values()))[:10]

[9, 10, 10, 10, 10, 10, 10, 10, 10, 10]

In [18]:
# For every user, print the five largest per-class image counts (one line per user).
for class_map in user_to_class_to_imagepath.values():
    counts = sorted(len(paths) for paths in class_map.values())
    print(counts[-5:])

[20, 21, 22, 24, 31]
[20, 21, 21, 23, 29]
[20, 22, 22, 26, 36]
[22, 22, 23, 23, 28]
[16, 16, 17, 18, 22]
[18, 21, 22, 24, 33]
[14, 20, 21, 22, 24]
[17, 21, 22, 25, 31]
[20, 20, 23, 25, 38]
[16, 16, 21, 23, 26]
[21, 22, 23, 26, 39]
[21, 22, 26, 28, 35]
[15, 16, 19, 24, 25]
[15, 17, 18, 19, 20]
[19, 19, 22, 27, 31]
[18, 20, 20, 25, 25]
[14, 15, 17, 20, 26]
[19, 22, 25, 26, 38]
[13, 19, 20, 25, 31]
[17, 18, 21, 22, 29]
[20, 22, 22, 30, 32]
[13, 13, 13, 14, 14]
[17, 19, 19, 20, 24]
[20, 20, 26, 26, 32]
[17, 18, 18, 20, 31]
[16, 16, 19, 19, 27]
[16, 18, 22, 22, 25]
[18, 19, 22, 26, 38]
[15, 18, 22, 26, 26]
[12, 12, 13, 13, 13]
[12, 13, 13, 13, 17]
[20, 22, 23, 25, 39]
[13, 13, 14, 17, 18]
[16, 17, 17, 21, 28]
[14, 14, 18, 23, 35]
[19, 20, 22, 23, 34]
[16, 21, 23, 24, 30]
[15, 22, 24, 26, 34]
[15, 18, 19, 21, 31]
[20, 21, 22, 22, 29]
[13, 13, 20, 21, 24]
[13, 14, 20, 23, 25]
[15, 16, 17, 22, 25]
[14, 14, 18, 27, 28]
[17, 19, 22, 26, 29]
[13, 15, 18, 19, 25]
[20, 22, 23, 26, 36]
[21, 21, 23, 