In [1]:
import os
import copy
import numpy as np

from data.dataset import KuzushijiMNIST, get_subset, get_dataset, statstic_info
from utils.seed import set_seed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration: a base seed, how many random splits to generate,
# and what fraction of the data each split receives.
seed = 7            # base seed; each trial offsets it by the trial index
trials = 5          # number of independent random splits
request_rate = 0.2  # share of the full training set held out as the request pool
val_rate = 0.2      # share of the remaining (non-request) pool used for validation

dataset_name = 'mnist'
# %-template filled with (dataset_name, trial index): one output directory per trial.
prepared_data_save_path_template = './runs/prepared_data/%s/trial_%s/'

# Load the raw datasets and summarise the training split. statstic_info appears
# to return a (sample count, class count, per-class counts) tuple — confirm.
raw_train_set, raw_test_set = get_dataset(dataset_name)
num_classes = len(raw_train_set.classes)
raw_train_set_stat = statstic_info(raw_train_set)
print('Train sample num: {}, num_classes: {}, class_sample_num: {}'.format(*raw_train_set_stat))


Train sample num: 60000, num_classes: 10, class_sample_num: [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]


In [3]:
# For each trial: reseed, shuffle the training indices, split them into
# train / val / request partitions, save the index arrays, and report the
# class distribution of the request pool.


def _split_indices(n_samples, request_rate, val_rate):
    """Shuffle ``range(n_samples)`` and partition it into (train, val, request) index arrays.

    The shuffled indices are cut once at ``floor(n_samples * (1 - request_rate))``.
    The tail becomes the request pool. The head is cut again at
    ``floor(len(head) * (1 - val_rate))`` into train (head) and validation (tail).
    Randomness comes from NumPy's global RNG via ``np.random.shuffle``, so seed it
    before calling if you need reproducible splits.

    Args:
        n_samples: number of samples to split.
        request_rate: fraction of all samples placed in the request pool, in [0, 1].
        val_rate: fraction of the non-request samples placed in validation, in [0, 1].

    Returns:
        (train_idx, val_idx, request_idx): disjoint integer arrays whose union is
        ``range(n_samples)``.

    Raises:
        ValueError: if either rate lies outside [0, 1].
    """
    for rate_name, rate in (('request_rate', request_rate), ('val_rate', val_rate)):
        if not 0 <= rate <= 1:
            # An out-of-range rate yields a negative or oversized cut point, which
            # numpy slicing silently accepts and turns into a wrong split.
            raise ValueError(f'{rate_name} must be in [0, 1], got {rate}')

    def cut_point(n, keep_frac):
        # floor(n * keep_frac), but absorb float noise first: e.g. 1 - 0.9 is
        # 0.09999999999999998, so int(100 * (1 - 0.9)) would give 9 instead of 10.
        return int(round(n * keep_frac, 9))

    shuffled_idx = np.arange(n_samples)
    np.random.shuffle(shuffled_idx)
    split_at = cut_point(n_samples, 1 - request_rate)
    labelled_idx, request_idx = shuffled_idx[:split_at], shuffled_idx[split_at:]
    val_at = cut_point(len(labelled_idx), 1 - val_rate)
    return labelled_idx[:val_at], labelled_idx[val_at:], request_idx


for trial in range(trials):
    trial_seed = seed + trial
    print(f'\n{"-"*10} Trial {trial}, set seed as {trial_seed} {"-"*10}')
    set_seed(trial_seed)

    # Per-trial deep copies of the raw datasets. NOTE(review): presumably this
    # isolates trials from anything get_subset does to the dataset — confirm.
    # test_set is not used in this cell; drop its full-dataset deep copy if no
    # later cell relies on it.
    train_set = copy.deepcopy(raw_train_set)
    test_set = copy.deepcopy(raw_test_set)

    train_idx, val_idx, request_idx = _split_indices(len(train_set), request_rate, val_rate)
    # Sanity check: the three splits cover the training set exactly, with no overlap.
    assert np.array_equal(
        np.sort(np.concatenate([train_idx, val_idx, request_idx])),
        np.arange(len(train_set)),
    ), 'train/val/request indices do not partition the training set'
    print(f'{len(train_idx)}, {len(val_idx)}, {len(request_idx)}: {len(train_idx) + len(val_idx) + len(request_idx)}')
    # First few train indices: a quick fingerprint for checking reproducibility across runs.
    print(train_idx[:10])

    # Save this trial's index arrays under its own directory.
    save_path = prepared_data_save_path_template % (dataset_name, trial)
    os.makedirs(save_path, exist_ok=True)  # no exists()/makedirs race; fine if it already exists
    for split_name, split_idx in (('train', train_idx), ('val', val_idx), ('request', request_idx)):
        np.save(os.path.join(save_path, f'{split_name}_idx.npy'), split_idx)

    request_set = get_subset(train_set, request_idx)
    print(f'Statistic info of request set: {statstic_info(request_set)}')


---------- Trial 0, set seed as 7 ----------
38400, 9600, 12000: 60000
[25170 10935 43596 13970 24883  7581  1712  3004 12834 50176]
Statistic info of request set: (12000, 10, [1227, 1346, 1200, 1162, 1188, 1081, 1199, 1219, 1183, 1195])

---------- Trial 1, set seed as 8 ----------
38400, 9600, 12000: 60000
[32997 56059 13395 50681 27244 38348 38157  2062 27732 21545]
Statistic info of request set: (12000, 10, [1196, 1349, 1206, 1203, 1180, 1039, 1240, 1244, 1189, 1154])

---------- Trial 2, set seed as 9 ----------
38400, 9600, 12000: 60000
[13131 14170 37932 50105 40269 11929 52723 58524 46502 42902]
Statistic info of request set: (12000, 10, [1186, 1312, 1204, 1185, 1170, 1125, 1205, 1281, 1156, 1176])

---------- Trial 3, set seed as 10 ----------
38400, 9600, 12000: 60000
[ 4883 28477  8527 38347  3363 36370  6083 34418 18938 17773]
Statistic info of request set: (12000, 10, [1207, 1399, 1149, 1170, 1194, 1036, 1171, 1248, 1214, 1212])

---------- Trial 4, set seed as 11 -------