In [None]:
import json
import os
from tqdm import tqdm
from collections import defaultdict
import numpy as np

In [None]:
def relabel_class(c):
    '''
    maps hexadecimal class value (string) to a decimal number
    returns:
    - 0 through 9 for classes representing respective numbers
    - 10 through 35 for classes representing respective uppercase letters
    - 36 through 61 for classes representing respective lowercase letters
    raises:
    - ValueError if c is not a hex string, or does not encode an ASCII digit/letter
    '''
    # Parse once as hex; the hex value is the ASCII code of the character (e.g. "4a" -> 'J').
    code = int(c, 16)
    if ord('0') <= code <= ord('9'):
        return code - ord('0')
    if ord('A') <= code <= ord('Z'):
        return code - ord('A') + 10
    if ord('a') <= code <= ord('z'):
        return code - ord('a') + 36
    # Previously out-of-range codes silently produced negative/overlapping labels.
    raise ValueError(f"class hex {c!r} does not encode an ASCII digit or letter")

def add_class_images(class_root_path, user_to_class_to_imagepath):
    """
    Add every image of one class directory to a user -> class -> image-path mapping.

    class_root_path: path of a class directory; its basename is the class's hex code
                     (e.g. "4a") and it holds `.mit` metadata files next to the image
                     directories they describe
    user_to_class_to_imagepath: defaultdict(lambda: defaultdict(list)), mutated in place,
                                mapping user -> class label -> list of image paths

    Each `.mit` file is read as a first line holding a count, followed by one
    "<image_fname> <user_id>/<...>" entry per line. Images are resolved inside the
    directory named like the `.mit` file without its extension.
    """
    class_hex = os.path.basename(class_root_path)
    class_label = relabel_class(class_hex)
    print(f"Reading class hex {class_hex}, char {chr(int(class_hex, 16))}")

    # sorted() makes the order in which paths are appended independent of the
    # filesystem's listdir order; the seeded shuffle used later depends on that order.
    for hsf_fname in sorted(os.listdir(class_root_path)):
        mit_path = os.path.join(class_root_path, hsf_fname)
        hsf_stem, ext = os.path.splitext(hsf_fname)
        # Only real `.mit` files: a substring test would also pick up any entry that
        # merely contains "mit" (e.g. a directory), which open() cannot read.
        if ext != '.mit' or not os.path.isfile(mit_path):
            continue

        with open(mit_path) as f:
            rows = [line.split() for line in f]

        # The first line only holds the entry count; skip it, and skip blank lines
        # (which would otherwise break the two-field unpacking below).
        for row in rows[1:]:
            if not row:
                continue
            image_fname, user_info = row
            user_id = user_info.split("/")[0]

            # image lives in the directory named after the .mit file (extension removed)
            full_image_path = os.path.join(class_root_path, hsf_stem, image_fname)
            user_to_class_to_imagepath[user_id][class_label].append(full_image_path)

In [None]:
# Root of the FEMNIST checkout; '' makes the data path relative to the working directory.
femnist_root = ''
# writer id -> class label -> list of image paths, filled in by add_class_images
user_to_class_to_imagepath = defaultdict(lambda: defaultdict(list))
classes_home = os.path.join(femnist_root, "data/raw_data/by_class")
for class_name in sorted(os.listdir(classes_home)):
    class_root_path = os.path.join(classes_home, class_name)
    # Only class directories are parsed: stray files or hidden entries (e.g. OS or
    # editor metadata) would otherwise crash the whole load in add_class_images.
    if class_name.startswith('.') or not os.path.isdir(class_root_path):
        continue
    add_class_images(class_root_path, user_to_class_to_imagepath)

In [None]:
len(user_to_class_to_imagepath)  # number of writers found in the raw data

In [None]:
# Filtering / episode parameters
per_class_minimum = 2  # a (writer, class) entry is kept only if the writer has at least this many images of that class
per_user_minimum = 2   # a writer is kept only if at least this many classes survive the per-class filter
n_shot = 1             # support images per (writer, class); all remaining images of that class form the query set
# Every kept (writer, class) entry must leave at least one query image after taking n_shot support images.
assert 1 <= n_shot < per_class_minimum, "need 1 <= n_shot < per_class_minimum"

In [None]:
user_to_numcl = dict()  # writer id -> number of classes left after the per-class filter

In [None]:
# Per-class filter: drop every (writer, class) entry with fewer than per_class_minimum
# images (in place), then record how many classes each writer has left.
for writer, class_map in user_to_class_to_imagepath.items():
    too_small = [label for label, paths in class_map.items() if len(paths) < per_class_minimum]
    for label in too_small:
        del class_map[label]
    user_to_numcl[writer] = len(class_map)

In [None]:
# Writers after the per-class filter; unchanged from the raw count because that filter
# only removes classes, never writers.
len(user_to_class_to_imagepath)

In [None]:
# Mean number of classes per writer that survive the per-class filter.
np.mean(list(user_to_numcl.values()))

In [None]:
# Per-writer filter: drop writers left with fewer than per_user_minimum classes.
# pop(..., None) keeps the cell re-runnable; a plain `del` raises KeyError for
# writers already removed by a previous run.
for user, num_classes in user_to_numcl.items():
    if num_classes < per_user_minimum:
        user_to_class_to_imagepath.pop(user, None)

In [None]:
# Writers remaining after the per-writer filter.
len(user_to_class_to_imagepath)

In [None]:
import random

In [None]:
random.seed(100)  # fixed seed for the `random` module used later in the notebook
# Writers in sorted-ID order. The shuffle below is disabled, so the train/val/test split
# further down partitions writers by ID order rather than at random.
all_users = sorted(user_to_class_to_imagepath)
# random.shuffle(all_users)

In [None]:
# writer id -> class label -> {'support': [...], 'query': [...]}; every level is created on first access
user_to_class_to_sq_to_imagepath = defaultdict(
    lambda: defaultdict(
        lambda: defaultdict(list)
    )
)

In [None]:
# Split each writer's images of each class into a support set (the first n_shot after
# shuffling) and a query set (all remaining images).
# A dedicated RNG seeded like the setup cell yields the same permutations as before on a
# straight run, and makes this cell reproducible when re-run on its own.
split_rng = random.Random(100)
for user in all_users:
    for cl in sorted(user_to_class_to_imagepath[user].keys()):
        # Shuffle a copy so the source lists in user_to_class_to_imagepath keep their order.
        examples = list(user_to_class_to_imagepath[user][cl])
        split_rng.shuffle(examples)
        # Assign rather than extend, so re-running the cell does not duplicate entries.
        user_to_class_to_sq_to_imagepath[user][cl]['support'] = examples[:n_shot]
        user_to_class_to_sq_to_imagepath[user][cl]['query'] = examples[n_shot:]

In [None]:
# Inspect one writer's support/query split. .get() avoids the defaultdict side effect of
# inserting an empty entry when this writer was filtered out (it then shows nothing).
user_to_class_to_sq_to_imagepath.get('f0620_49')

In [None]:
# Output folder named after the split parameters. The "...query" number is
# per_class_minimum - n_shot, i.e. only the *minimum* query size per class: classes with
# more images put all of their remaining images into the query set.
folder_name = f'fixedsq_atleast{per_user_minimum}class{n_shot}shot{per_class_minimum - n_shot}query_split'
# exist_ok=True keeps Restart & Run All working; the split files written into this
# folder are overwritten when the notebook is re-run.
os.makedirs(folder_name, exist_ok=True)

In [None]:
# Split writers (not images) by position in all_users: the first 70% go to train,
# the next 15% (up to the 85% mark) to val, and the remaining 15% to test.
cut_off = 0.7, 0.85  # cumulative fractions: end of train, end of val
n_users = len(all_users)
train_end, val_end = int(n_users * cut_off[0]), int(n_users * cut_off[1])
train_users = all_users[:train_end]
val_users = all_users[train_end:val_end]
test_users = all_users[val_end:]

In [None]:
# Number of writers in each split, in the order train, val, test.
for split_users in (train_users, val_users, test_users):
    print(len(split_users))

In [None]:
# Write one JSON file per split (writer -> class -> {'support', 'query'} path lists).
# Note: json turns the integer class-label keys into strings.
split_files = (
    ("base.json", train_users),
    ("val.json", val_users),
    ("novel.json", test_users),
)
for split_fname, split_users in split_files:
    with open(f"{folder_name}/{split_fname}", 'w') as f:
        json.dump({user: user_to_class_to_sq_to_imagepath[user] for user in split_users}, f)

In [None]:
# Peek at the support/query split of the first training writer.
first_train_user = train_users[0]
user_to_class_to_sq_to_imagepath[first_train_user]

################################

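Checking the data statistics (note: the cells below run on `user_to_class_to_imagepath` *after* the per-class and per-writer filtering above, so they describe the filtered data, not the raw data)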

In [None]:
# The 10 smallest per-writer class counts, i.e. how many classes the writers with the
# fewest classes have (after filtering).
sorted(len(class_map) for class_map in user_to_class_to_imagepath.values())[:10]

In [None]:
# For every writer, print its 5 largest per-class image counts (ascending); one line per writer.
for class_map in user_to_class_to_imagepath.values():
    print(sorted(map(len, class_map.values()))[-5:])