In [1]:
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import json

## Dataset

### MNIST

In [2]:
# Preprocessing shared by both MNIST splits: convert each image to a tensor,
# then normalize its single channel with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Build the train split first, then the test split. Both read from ./data with
# download=False, so nothing is fetched here.
# Note: this cell binds the notebook-wide names `training_data` / `testing_data`.
training_data, testing_data = (
    datasets.MNIST(
        root="./data",
        train=is_train,
        download=False,
        transform=mnist_transform,
    )
    for is_train in (True, False)
)

### CIFAR10

In [3]:
# Keyword arguments common to both CIFAR10 splits: same root folder, no
# download, and a ToTensor + per-channel Normalize pipeline (first tuple is the
# mean, second the std, one value per RGB channel).
cifar10_kwargs = dict(
    root="./data",
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]),
)

# Note: this cell rebinds the notebook-wide names `training_data` / `testing_data`.
training_data = datasets.CIFAR10(train=True, **cifar10_kwargs)
testing_data = datasets.CIFAR10(train=False, **cifar10_kwargs)

Files already downloaded and verified
Files already downloaded and verified


### CIFAR100

In [2]:
# Root folder and preprocessing shared by both CIFAR100 splits.
cifar100_root = "./benchmark/cifar100/data/"

# NOTE(review): these mean/std values are identical to the ones this notebook
# uses for CIFAR10 — confirm they are the intended statistics for CIFAR100.
cifar100_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Both splits now pass download=True. Previously only the test split did, so on
# a machine without the data the train constructor (which runs first) would
# fail before the test constructor had a chance to fetch the files.
# Note: this cell rebinds the notebook-wide names `training_data` / `testing_data`.
training_data = datasets.CIFAR100(
    root=cifar100_root,
    train=True,
    download=True,
    transform=cifar100_transform,
)

testing_data = datasets.CIFAR100(
    root=cifar100_root,
    train=False,
    download=True,
    transform=cifar100_transform,
)

Files already downloaded and verified


## Division algorithm

**Each client:**

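1. Holds between `min_label_per_client` and `max_label_per_client` distinct labels (2 to 5 as configured below)
2. Receives between `min_sample_per_client` and `max_sample_per_client` training samples for each of its labels (10 to 20 as configured below), plus a quarter as many test samples
3. There are `num_clients` clients (400 as configured below; `5 * number_of_classes` is the commented-out alternative), and every label must be assigned to at least one client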

In [3]:
import numpy as np
import os
import math

In [10]:
# ---- Partition configuration and bookkeeping tables ----

# Seed NumPy's global RNG so that re-running the notebook reproduces the same
# partition. The sampling code in this cell draws from the global np.random
# state, so the seed has to be set before any of it runs.
SEED = 42
np.random.seed(SEED)

# Distinct class ids found in the training split.
total_labels = np.unique(training_data.targets).tolist()

# Each client gets a number of labels drawn from this inclusive range.
min_label_per_client = 2
max_label_per_client = 5

# For each of its labels, a client gets a number of training samples drawn
# from this inclusive range.
min_sample_per_client = 10
max_sample_per_client = 20

# Alternative sizing rule: num_clients = 5 * len(total_labels)
num_clients = 400

total_label = len(total_labels)
# assumes class ids are exactly 0 .. total_label - 1 — TODO confirm before
# pointing this cell at a dataset with a different label encoding
label_list = list(range(total_label))

# Two-column integer tables, one row per sample:
#   column 0 = position of the sample inside the dataset,
#   column 1 = its class id.
labels = training_data.targets
idxs = range(len(training_data))
training_idxs_labels = np.vstack((idxs, labels)).T

labels = testing_data.targets
idxs = range(len(testing_data))
testing_idxs_labels = np.vstack((idxs, labels)).T

# client id -> list of dataset positions assigned to that client.
training_dict_client = {client_id: [] for client_id in range(num_clients)}
testing_dict_client = {client_id: [] for client_id in range(num_clients)}

# client_labels[i] will hold the labels given to client i;
# not_passed_label_list tracks labels that no client has received yet.
client_labels = []
not_passed_label_list = label_list.copy()

# Give every client a random set of distinct labels and keep track of which
# labels have not been handed to anyone yet.
uncovered_labels = set(not_passed_label_list)

for client_id in range(num_clients):
    # randint's upper bound is exclusive, hence the +1 to make the range inclusive.
    label_per_client = np.random.randint(min_label_per_client, max_label_per_client + 1)
    # replace=False: a client never receives the same label twice.
    this_set = np.random.choice(label_list, label_per_client, replace=False)
    client_labels.append(list(this_set))
    uncovered_labels.difference_update(this_set.tolist())

not_passed_label_list = sorted(uncovered_labels)

print("Uncover", len(not_passed_label_list), "labels !")
if not_passed_label_list:
    # Raise instead of calling exit(0): exit would end the session (reporting
    # success) and throw away the notebook state, whereas an exception stops
    # execution here and shows which labels are missing.
    raise RuntimeError(
        f"{len(not_passed_label_list)} label(s) were not assigned to any client: "
        f"{not_passed_label_list}"
    )

# samples_details[i][j] = number of training samples drawn for the j-th label
# of client i (same ordering as client_labels[i]).
samples_details = []

# Sentinel written into the label column of a row once its sample has been
# handed out. The labels looked up below come from label_list, which is built
# from range(...) and is therefore never negative, so a used row can never
# match again.
USED_LABEL = -1

for client_idx, client_label in enumerate(client_labels):
    sample_this_client = []

    for label in client_label:
        # Training-sample count for this (client, label) pair; the test split
        # gets a quarter of it, rounded down.
        sample_per_client = np.random.randint(min_sample_per_client, max_sample_per_client + 1)
        test_sample_count = sample_per_client // 4
        sample_this_client.append(sample_per_client)

        # Dataset positions of the still-unassigned samples carrying this label.
        idxes_1 = training_idxs_labels[training_idxs_labels[:, 1] == label][:, 0]
        idxes_2 = testing_idxs_labels[testing_idxs_labels[:, 1] == label][:, 0]

        if len(idxes_1) < sample_per_client or len(idxes_2) < test_sample_count:
            raise ValueError(
                f"Label {label} has too few unassigned samples left for client {client_idx} "
                f"(train: {len(idxes_1)} < {sample_per_client} or "
                f"test: {len(idxes_2)} < {test_sample_count})"
            )

        label_1_idxes = np.random.choice(idxes_1, sample_per_client, replace=False)
        label_2_idxes = np.random.choice(idxes_2, test_sample_count, replace=False)

        training_dict_client[client_idx] += label_1_idxes.tolist()
        testing_dict_client[client_idx] += label_2_idxes.tolist()

        # Mark the drawn rows as used by overwriting only their label column.
        # Column 0 was filled from range(len(dataset)), so a dataset position is
        # also the row number in the table. The previous `-= 100` on whole rows
        # also shifted column 0 and only hid rows when every class id was < 100.
        training_idxs_labels[label_1_idxes, 1] = USED_LABEL
        testing_idxs_labels[label_2_idxes, 1] = USED_LABEL

    samples_details.append(sample_this_client)


# Client-by-label count matrix: entry [c, l] is the number of training samples
# of label l drawn for client c; labels a client does not hold stay at 0.
dis_mtx = np.zeros((num_clients, total_label))
for client_id, client_label in enumerate(client_labels):
    client_samples = samples_details[client_id]
    client_row = dis_mtx[client_id]  # a view, so writes land in dis_mtx itself
    for label, num_samples in zip(client_label, client_samples):
        client_row[label] = num_samples

Uncover 0 labels !


In [11]:
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars and arrays into plain Python values."""

    def default(self, obj):
        """Convert NumPy integers, floats and arrays; defer everything else.

        Objects of any other type are passed to the base encoder's ``default``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
    
# Output folder for this partition, named after the number of clients.
savepath = f"./dataset_idx/cifar100/sparse/{num_clients}client"
# Create the folder (and any missing parents); a no-op when it already exists.
Path(savepath).mkdir(parents=True, exist_ok=True)

# Write the client -> sample-position mappings. The `with` blocks make sure each
# file is flushed and closed; previously the handles were never closed.
# Note: JSON object keys are strings, so the integer client ids come back as
# strings when these files are loaded.
with open(f"{savepath}/cifar100_sparse.json", "w") as train_file:
    json.dump(training_dict_client, train_file, cls=NumpyEncoder)

with open(f"{savepath}/cifar100_sparse_test.json", "w") as test_file:
    json.dump(testing_dict_client, test_file, cls=NumpyEncoder)

# Per-client label counts: one row per client, one column per label.
np.savetxt(f"{savepath}/cifar100_sparse_stat.csv", dis_mtx, fmt="%d", delimiter=",")