In [3]:
import os
import numpy as np
import torchvision
from PIL import Image

import settings

## Load Dataset

In [4]:
# Directory for CIFAR-10, looked up in the project settings.
root = settings.DATA_HOME['cifar10']

# Both splits share the same location and download behaviour; only the
# `train` flag differs between them.
cifar10_kwargs = dict(root=root, download=True)
train_dataset = torchvision.datasets.CIFAR10(train=True, **cifar10_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **cifar10_kwargs)

Files already downloaded and verified
Files already downloaded and verified


## Merge Dataset

In [5]:
# Pool the train and test splits into one collection so the samples can be
# re-partitioned per client afterwards. Each dataset item is unpacked as an
# (image, label) pair.
pooled_images = []
pooled_labels = []
for split in (train_dataset, test_dataset):
    for PIL_image, label in split:
        pooled_images.append(PIL_image)
        pooled_labels.append(label)

# Build a 1-D object array holding one image object per slot. The array is
# pre-allocated and filled element by element because np.array(list, dtype=object)
# may unpack array-like objects (PIL images expose the array interface) into
# extra dimensions on newer NumPy releases instead of storing them as objects.
all_data = np.empty(len(pooled_images), dtype=object)
for index, PIL_image in enumerate(pooled_images):
    all_data[index] = PIL_image

# np.int64 instead of np.long: np.long is not available in every NumPy release
# and its width is platform-dependent where it does exist.
all_labels = np.array(pooled_labels, dtype=np.int64)

all_data.shape, all_labels.shape

((60000,), (60000,))

## Divide Data

In [6]:
# Group the pooled samples by class: entry i holds every sample labelled i.
n_classes = np.unique(all_labels).shape[0]
classified_data = [all_data[all_labels == i] for i in range(n_classes)]
# dtype=object is spelled out explicitly; passing an arbitrary Python class as
# dtype only works because NumPy maps unknown types to the object dtype.
classified_data = np.array(classified_data, dtype=object)

n_clients = 10
clients = ['Client-{}'.format(i) for i in range(n_clients)]

# Three class orderings, each rotated left by one more position than the
# previous one. Position k of the three sequences gives the three classes
# assigned to client k. They are derived from n_classes rather than a
# hard-coded 10 so they stay consistent with the data.
label_seq_0 = np.arange(0, n_classes, 1)
label_seq_1 = np.concatenate([np.arange(1, n_classes, 1), np.arange(0, 1, 1)])
label_seq_2 = np.concatenate([np.arange(2, n_classes, 1), np.arange(0, 2, 1)])
classified_data.shape, label_seq_0, label_seq_1, label_seq_2

((10, 6000),
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
 array([2, 3, 4, 5, 6, 7, 8, 9, 0, 1]))

In [7]:
client_data = []
client_labels = []

# Number of samples taken from each of the first two classes of a client; the
# third class contributes everything from index 2 * shard_size onwards.
shard_size = 2000

for client_name, label_0, label_1, label_2 in zip(clients, label_seq_0, label_seq_1, label_seq_2):
    # Disjoint slices of three different classes, so no sample index range of a
    # class is handed to the same slice position twice.
    shards = [
        classified_data[label_0][:shard_size],
        classified_data[label_1][shard_size:2 * shard_size],
        classified_data[label_2][2 * shard_size:],
    ]
    data = np.concatenate(shards)
    # One label per sample, built from the real shard lengths so labels and
    # data cannot drift out of sync if a class has an unexpected sample count.
    labels = np.array([
        shard_label
        for shard_label, shard in zip((label_0, label_1, label_2), shards)
        for _ in range(len(shard))
    ])

    # Per-client archive with the samples exactly as stored in `data`.
    # NOTE(review): if `data` is an object array of PIL images, NumPy pickles it,
    # so reading this archive back would need allow_pickle=True.
    np.savez_compressed(os.path.join(root, '{}_dataset'.format(client_name)), data=data, labels=labels)

    # Convert each image to an array explicitly before reordering the axes from
    # (H, W, C) to (C, H, W). Calling np.transpose directly on a PIL image first
    # dispatches to PIL's own Image.transpose (a flip/rotate operation) and only
    # works through NumPy's fallback path.
    client_data.append([np.asarray(image).transpose(2, 0, 1) for image in data])
    client_labels.append(labels)

client_data = np.array(client_data, np.uint8)
# np.int64 instead of np.long: np.long is not available in every NumPy release
# and its width is platform-dependent where it does exist.
client_labels = np.array(client_labels, np.int64)
print(client_data.shape)

# Combined archive: one row per client for both the image tensor and the labels.
np.savez_compressed(os.path.join(root, 'CIFAR10_dataset'),
                    client_names=clients, data=client_data, labels=client_labels)

(10, 6000, 3, 32, 32)


In [8]:
# Sanity check: reload the combined archive and inspect the image tensor shape.
# allow_pickle is left at its safe default (False): the combined archive is
# written from numeric arrays and a list of client-name strings, none of which
# need pickling, so there is no reason to enable unpickling here.
f = np.load(os.path.join(root, 'CIFAR10_dataset.npz'))
f['data'].shape


(10, 6000, 3, 32, 32)