In [68]:
# Set up
import pandas as pd
import random
# Fix the seed of Python's built-in RNG so values drawn through `random` repeat across runs
SEED = 20
random.seed(SEED)

# Read the dataset index file; it is expected to list the image paths
# together with a 'label' column (the split below groups on it)
dataset_path = '../Datasets/CIFAR10'
data_df = pd.read_csv(f'{dataset_path}/data.csv')

In [69]:
# Split test dataset and private (data owner) dataset
# Hold out 20% of every label as the test set. pandas sampling draws from NumPy's
# RNG, which random.seed() does not affect, so the seed is passed explicitly to
# make the split reproducible on a fresh kernel.
test_dataset = data_df.groupby('label').sample(frac=.2, random_state=SEED)
# Every row that was not held out goes to the data owners. Dropping by index
# label (instead of handing labels to the positional .iloc) stays correct even
# if data_df does not carry a default 0..n-1 index, and avoids a Python-level
# membership loop over all rows.
private_dataset = data_df.drop(index=test_dataset.index)

In [70]:
# Generate a random multiple of 100, default: between 500 ~ 3000
def get_random_hundreds(low=500, high=3000):
    """Return a random multiple of 100 that lies within [low, high].

    The value is drawn with the module-level `random` generator, so it is
    reproducible after `random.seed(...)`.

    Args:
        low: Inclusive lower bound. Rounded *up* to the next multiple of 100
            so the result is never smaller than `low`.
        high: Inclusive upper bound. Rounded *down* to the previous multiple
            of 100 so the result is never larger than `high`.

    Returns:
        An int that is a multiple of 100 with low <= result <= high.

    Raises:
        ValueError: If no multiple of 100 exists between `low` and `high`.
    """
    first_hundred = -(-low // 100)  # ceiling division: smallest k with k*100 >= low
    last_hundred = high // 100      # floor division: largest k with k*100 <= high
    if first_hundred > last_hundred:
        raise ValueError(f'no multiple of 100 between {low} and {high}')
    return random.randint(first_hundred, last_hundred) * 100

In [71]:
# Generate data owner datasets by labels
def generate_data_owner_datasets(labels):
    """Build a {label: count} mapping with one random count per label.

    `get_random_hundreds()` is called once, with its default bounds, for each
    element of `labels` in iteration order.

    Args:
        labels: Iterable of label names.

    Returns:
        dict mapping every label to the count drawn for it.
    """
    counts_by_label = {}
    for name in labels:
        counts_by_label[name] = get_random_hundreds()
    return counts_by_label

In [89]:
# Set data owner label information
# Raw specification: owner id -> (role, comma-separated label names).
# Kept under its own name so the parsed mapping below does not rebind a
# variable to a differently shaped value.
data_owner_label_spec = {
    'A': ('trainer', 'plane, car, ship, boat, truck, other'),
    'B': ('trainer', 'truck, car, other'),
    'C': ('trainer', 'plane, truck, ship, other'),
    'D': ('trainer', 'ship, plane'),
    'E': ('trainer', 'plane, car, ship, other'),  # (overlapping dataset)
    'F': ('trainer', 'cat, dog, other'),  # (non existing labels)
    'G': ('trainer', 'truck, other'),  # (drop below baseline models)
    'T1': ('trainer', 'truck, car'),  # (all labels below baseline)
    'T2': ('trainer', 'plane, boat'),  # (one label below baseline)
    'T3': ('trainer', 'cat, dog'),  # (large model)
    'X': ('client', 'truck'),
    'Y': ('client', 'truck'), # (select models above baseline)
    'Z': ('client', 'plane, ship'), # (multiple labels)
    'T4': ('client', 'horse') # (non existing)
}

# Parsed form: owner id -> (role, [label, ...])
data_owner_label_info = {
    owner: (role, label_text.split(', '))
    for owner, (role, label_text) in data_owner_label_spec.items()
}

# Random sample count per (owner, label). Uses the shared helper instead of
# repeating its comprehension; counts are drawn in the same order as before
# (owners in definition order, labels in listed order).
data_owner_label_count_info = {
    owner: generate_data_owner_datasets(labels)
    for owner, (_role, labels) in data_owner_label_info.items()
}

# Print one summary line per owner and collect every label that is used.
all_labels = set()
for owner, label_counts in data_owner_label_count_info.items():
    all_labels.update(label_counts)
    print(owner, '\t', '\t'.join(f'{label}({count})' for label, count in label_counts.items()))

# 'other' is excluded from the label listing; discard() does not fail when no
# owner happens to list it.
all_labels.discard('other')
# Sets of strings iterate in a different order on every kernel start, so sort
# to keep the printed listing stable between runs.
print('All Labels:', ', '.join(sorted(all_labels)))

A 	 plane(1900)	car(2600)	ship(2800)	boat(1200)	truck(2000)	other(900)
B 	 truck(1600)	car(1500)	other(2500)
C 	 plane(2200)	truck(900)	ship(2500)	other(2800)
D 	 ship(2700)	plane(2800)
E 	 plane(2600)	car(2600)	ship(2200)	other(2900)
F 	 cat(800)	dog(2800)	other(1300)
G 	 truck(2500)	other(600)
T1 	 truck(2100)	car(1100)
T2 	 plane(1600)	boat(1700)
T3 	 cat(1100)	dog(2200)
X 	 truck(2100)
Y 	 truck(2700)
Z 	 plane(1400)	ship(2000)
T4 	 horse(2500)
All Labels: cat, boat, car, truck, horse, plane, dog, ship
