In [1]:
from sklearn.datasets import fetch_openml
from tqdm import trange
import numpy as np
import random

import os

WAY = 2
SHOT = 300
NUM_LABELS = 10
NUM_USERS = 100

DATA_ROOT = './mnist/type3/'
if not os.path.exists(DATA_ROOT):
    os.makedirs(DATA_ROOT)

# Get MNIST data, normalize, and divide by level
data, target = fetch_openml('mnist_784', version=1, return_X_y=True)
mu = np.mean(data.astype(np.float32), 0)
sigma = np.std(data.astype(np.float32), 0)
data = (data.astype(np.float32) - mu)/(sigma+0.001)
print('total number of data', len(data))

# data_tr, target_tr = data[:60000], target[:60000]
# data_ts, target_ts = data[60000:], target[60000:]

mnist_data_tr, mnist_data_ts = [], []
mnist_target_tr, mnist_target_ts = [], []
print(type(target))
for i in trange(10):
    idx = target==str(i)
    mnist_data_tr.append(data[idx])
    mnist_target_tr.append(target[idx])

    idx = target==str(i)
    mnist_data_ts.append(data[idx])
    mnist_target_ts.append(target[idx])

# 各labels的样本数量
print([len(v) for v in mnist_data_tr])
print([len(v) for v in mnist_data_ts])

total number of data 70000
<class 'numpy.ndarray'>


100%|██████████| 10/10 [00:00<00:00, 10.11it/s]

[6903, 7877, 6990, 7141, 6824, 6313, 6876, 7293, 6825, 6958]
[6903, 7877, 6990, 7141, 6824, 6313, 6876, 7293, 6825, 6958]





In [2]:
# Setup directory for train/test data

# test_path = './mnist/type3/all_data_niid_test.json'
# dir_path = os.path.dirname(os.path.join('all_data_niid_train.json'))

# mu = np.mean(mnist.data.astype(np.float32), 0)
# sigma = np.std(mnist.data.astype(np.float32), 0)
# mnist.data = (mnist.data.astype(np.float32) - mu)/(sigma+0.001)

将原数据集合进行分割

1. 共`num_users`个clients，取90%作为train clients，剩下的作为test clients；
2. 用一个长度为10的指针数组`idx`，记录每个label已经被“assign”的样本数（或者说下次取samples的时候从哪个位置开始取）
3. 每个clients首先分到`n-way k-shot`个samples；
4. 进一步，每个clients将会分到`n-way`随机数量的samples，这个随机数服从mu=0,sigma=2的对数正态分布；

In [3]:
###### CREATE USER DATA SPLIT #######
# Assign 10 samples to each user
# 为每个user分配10个样本，X,y分别代表1000个users各自的data及label
X = [[] for _ in range(NUM_USERS)]
y = [[] for _ in range(NUM_USERS)]

# 长度为10的数组: 指示每个标签已分配出去的样本数
idx = np.zeros(NUM_LABELS, dtype=np.int64)
num_shards = NUM_USERS * WAY # shards 表示要将全部数据分成多少份，假设每个user会应分到2份（2个不同labels），故共应划分为 2*num_users 份
# assign
shards_index = list(range(num_shards))
user_labels_count = {i: [] for i in range(NUM_USERS)}
# label_count = [0] * NUM_LABELS

np.random.seed(10) # seed 很重要
for user in range(NUM_USERS):
    while(True):
        rand_set = set(np.random.choice(shards_index, WAY, replace=False))
        labels = [idx // (NUM_LABELS * WAY) for idx in rand_set] # 
        if len(set(labels)) == WAY:
            break
#     print(labels)
    
#     for idx_lbl, lbl in enumerate(labels):
#         user_labels_count[user].append((idx_lbl, lbl, label_count[lbl]))
#         label_count[lbl] += 1

    shards_index = list(set(shards_index) - rand_set) # update shards_index
    # print(labels)
    for idx_lbl, lbl in enumerate(labels): # assign samples for the current user
#         print(lbl, idx[lbl], idx[lbl]+SHOT)
        X[user] += mnist_data_tr[lbl][idx[lbl]:idx[lbl]+SHOT].values.tolist()
        y[user] += (lbl*np.ones(SHOT, dtype=np.int8)).tolist() 
        idx[lbl] += SHOT # 记录当前标签已分配的样本数
    
    print(user, y[user])
#     print(idx)

0 [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [4]:
from tqdm import tqdm
# Create data structure
train_data = {'users': [], 'user_data':{}, 'num_samples':[], 'classes':{}}
test_data = {'users': [], 'user_data':{}, 'num_samples':[], 'classes':{}}

# Setup 100 users
for user in trange(NUM_USERS, ncols=120):
    uname = 'f_{0:05d}'.format(user) # type(X[i]), type(y[i]): list, list
    combined = list(zip(X[user], y[user])) # 如果外面加一个np.array()会出错！
    random.shuffle(combined)# 打乱
    
#     count = [0] * NUM_LABELS
#     for (x_, y_) in combined:
#         count[y_] += 1
#     print(count)
    
    # # 不再拆分成 X and y
    # X[i][:], y[i][:] = zip(*combined)
    # 取 80% 作为 train samples，剩下的作为 test samples
    train_len = int(0.5*len(combined))
    test_len = len(combined) - train_len
    
    classes = {idx_lbl:int(lbl) for idx_lbl, lbl, _ in user_labels_count[i]}
    
    train_data['users'].append(uname) 
    train_data['user_data'][uname] = combined[:train_len]
    train_data['num_samples'].append(train_len)
    train_data['classes'][uname] = classes
    
    test_data['users'].append(uname)
    test_data['user_data'][uname] = combined[train_len:]
    test_data['num_samples'].append(test_len)
    train_data['classes'][uname] = classes

print(type(train_data))
# npy 优点：读写更快、文件更小；缺点：无法在文本编辑器中直接查看
train_path = DATA_ROOT + '/all_data_niid_train_55_6000.npy'
test_path = DATA_ROOT + '/all_data_niid_test_55_6000.npy'
np.save(train_path, train_data)
np.save(test_path, test_data)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 151.10it/s]


<class 'dict'>


In [5]:
train_path = DATA_ROOT + '/all_data_niid_train_55_6000.npy'
test_path = DATA_ROOT + '/all_data_niid_test_55_6000.npy'
train_data = np.load(train_path, allow_pickle=True).item()
test_data = np.load(test_path, allow_pickle=True).item()
type(train_data)

dict