In [None]:
import os
import pickle
import random
from typing import List

import torch
import torchvision
import tqdm
# Preprocessing applied to every MNIST image when it is read from the dataset:
# convert to a tensor, then normalize with mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
])

# MNIST training split, stored under ./data/ and downloaded if it is not already there.
dataset = torchvision.datasets.MNIST('./data/', train=True, download=True, transform=mnist_transform)


In [None]:
# Group every sample of `dataset` by its class label, then concatenate the groups
# so that `res_stimulis_all` / `res_labels_all` are ordered class 0 first, class 9 last
# (samples inside a class keep their dataset order).
n_classes: int = 10
max_per_class: int = 7000  # capacity of each preallocated per-class buffer

res_stimulis: torch.Tensor = torch.zeros(size=(n_classes, max_per_class, 1, 28, 28), dtype=torch.float32)
# Labels are kept as int64 from the start; no float round-trip is needed.
res_labels: torch.Tensor = torch.zeros(size=(n_classes, max_per_class), dtype=torch.int64)
# res_index[c] is the number of samples of class c stored so far (= next free slot).
res_index: torch.Tensor = torch.zeros(size=(n_classes,), dtype=torch.int64)
# NOTE(review): this shuffled DataLoader is not used by the loop below, which iterates
# `dataset` directly in its native order — confirm whether anything else needs it.
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True)

# Iterate through the tqdm wrapper itself so the progress bar tracks the real loop.
for image, label in tqdm.tqdm(dataset):
    slot = int(res_index[label])
    if slot >= max_per_class:
        raise IndexError(
            f'class {label} has more than max_per_class={max_per_class} samples; enlarge the buffer'
        )
    res_stimulis[label, slot] = image
    res_labels[label, slot] = label
    res_index[label] += 1

# Every sample must have landed in exactly one class bucket.
assert int(res_index.sum()) == len(dataset)

# Drop the unused tail of each per-class buffer and stack the classes back to back.
class_counts = res_index.tolist()
res_stimulis_all: torch.Tensor = torch.cat([res_stimulis[idx, :class_counts[idx], ...] for idx in range(n_classes)])
res_labels_all: torch.Tensor = torch.cat([res_labels[idx, :class_counts[idx]] for idx in range(n_classes)])

In [None]:
# Split the class-ordered tensors into `n_patches` equally sized contiguous patches,
# hand two randomly chosen patches to each client, and pickle one file per client.
# Because the tensors are ordered by class, each patch is a contiguous run of that ordering.
n_client: int = 100
n_patches: int = n_client * 2  # two patches per client
# Size patches from the data actually present rather than a hard-coded 60000;
# samples that do not fill a whole patch are left unassigned.
patch_sz = res_stimulis_all.shape[0] // n_patches
shuffle_seed: int = 0  # fixed seed so re-running the notebook reproduces the same assignment
data_assignment: List[int] = list(range(n_patches))
# A dedicated Random instance keeps the global `random` state untouched.
random.Random(shuffle_seed).shuffle(data_assignment)
export_dir = f'./export_mnist_{n_client}'
# open(..., 'wb') does not create missing directories, so make sure the target exists.
os.makedirs(export_dir, exist_ok=True)

for client_id in range(n_client):
    patch_idx_1 = data_assignment[client_id * 2]
    patch_idx_2 = data_assignment[client_id * 2 + 1]
    # Row ranges of the two patches owned by this client.
    patch_slices = [
        slice(patch_idx * patch_sz, (patch_idx + 1) * patch_sz)
        for patch_idx in (patch_idx_1, patch_idx_2)
    ]

    # torch.cat materializes fresh tensors, so each pickle holds only this client's rows.
    stimulis_tmp = torch.cat([res_stimulis_all[rows, ...] for rows in patch_slices])
    labels_tmp = torch.cat([res_labels_all[rows, ...] for rows in patch_slices]).to(torch.int64)

    data = {'stimulis': stimulis_tmp, 'labels': labels_tmp}
    with open(f'{export_dir}/client_{client_id}.pkl', 'wb') as f:
        pickle.dump(data, f)

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check on one exported sample. `data` is whatever dict the export
# loop assigned last, i.e. the shard of the final client.
sample_idx = 400
fig, ax = plt.subplots()
ax.imshow(data['stimulis'][sample_idx, 0, ...])
plt.show()
# Last expression of the cell: display the label belonging to the image above.
data['labels'][sample_idx]