In [9]:
import math
import os
import pickle

import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [1]:
def partition_CIFAR_dataset(
    dataset,
    file_name: str,
    balanced: bool,
    matrix,
    n_clients: int,
    n_classes: int,
    train: bool,
    folder: str = "./data/",
):
    """Split `dataset` across clients by per-client class proportions and pickle the result.

    Row ``i`` of ``matrix`` is client ``i``'s class mix: the client receives
    ``int(matrix[i, k] * n_sample)`` examples of class ``k`` for every class
    but the last. The last class absorbs the rounding remainder, so every
    client ends up with exactly ``n_sample`` examples.

    Indices are drawn per class with ``np.random.choice`` *with replacement*,
    so a client may hold duplicate examples. Seed NumPy's global RNG for
    reproducible splits.

    Args:
        dataset: object exposing indexable ``data`` and ``targets`` (e.g. a
            torchvision MNIST/CIFAR dataset).
        file_name: name of the pickle file written inside ``folder``.
        balanced: if True, every client gets the same size (500 for train,
            100 for test). Otherwise a fixed 100-client size schedule is used.
        matrix: array indexed ``[client, class]`` holding class proportions.
            Rows should sum to 1 (as Dirichlet samples do) so the floored
            counts never exceed the client's size.
        n_clients: number of clients. Must be 100 when ``balanced`` is False.
        n_classes: number of classes. Labels are taken to be ``0..n_classes-1``.
        train: selects the train-size (True) or test-size (False) schedule.
        folder: output directory, created if missing.

    Writes:
        ``(list_clients_X, list_clients_y, list_clients_shannon)`` pickled to
        ``folder/file_name``. These are the per-client data arrays, the
        per-client label lists, and the Shannon index of each client's class
        mix.

    Raises:
        ValueError: if the size schedule does not have ``n_clients`` entries.
    """
    n_samples = _samples_per_client(balanced, train, n_clients)
    if len(n_samples) != n_clients:
        raise ValueError(
            f"size schedule covers {len(n_samples)} clients but n_clients={n_clients}"
        )

    # Convert targets once; class_indices[k] holds the dataset indices labelled k.
    targets = np.asarray(dataset.targets)
    class_indices = [np.where(targets == k)[0] for k in range(n_classes)]

    list_clients_X = []
    list_clients_y = []
    list_clients_shannon = []

    for idx_client, n_sample in enumerate(n_samples):
        counts = _class_counts(matrix[idx_client], n_sample, n_classes)

        # Shannon diversity index of this client's class proportions.
        client_shannon = 0.0
        for count in counts:
            p_k = count / n_sample
            if p_k != 0:
                client_shannon -= p_k * math.log(p_k)

        # Draw class by class (same order as before) so the RNG stream is unchanged.
        client_idx = np.concatenate(
            [np.random.choice(class_indices[k], count) for k, count in enumerate(counts)]
        ).astype(int)

        list_clients_X.append(np.array([dataset.data[i] for i in client_idx]))
        list_clients_y.append([dataset.targets[i] for i in client_idx])
        list_clients_shannon.append(client_shannon)

    os.makedirs(folder, exist_ok=True)
    with open(os.path.join(folder, file_name), "wb") as output:
        pickle.dump((list_clients_X, list_clients_y, list_clients_shannon), output)


def _samples_per_client(balanced: bool, train: bool, n_clients: int) -> list:
    """Return the per-client sample counts for the chosen mode.

    Balanced mode: ``n_clients`` copies of 500 (train) or 100 (test).
    Unbalanced mode: a fixed 100-entry schedule, with a 30-client group
    followed by a 70-client group, each mixing five sizes. ``n_clients``
    is not consulted in this mode.
    """
    if balanced:
        return [500 if train else 100] * n_clients
    if train:
        return (
            [100] * 2 + [250] * 10 + [500] * 10 + [750] * 5 + [1000] * 3
            + [100] * 8 + [250] * 20 + [500] * 20 + [750] * 15 + [1000] * 7
        )
    return (
        [20] * 2 + [50] * 10 + [100] * 10 + [150] * 5 + [200] * 3
        + [20] * 8 + [50] * 20 + [100] * 20 + [150] * 15 + [200] * 7
    )


def _class_counts(proportions, n_sample: int, n_classes: int) -> list:
    """Split `n_sample` into per-class counts.

    Each class except the last gets ``int(proportions[k] * n_sample)``. The
    last class gets the remainder, so the counts always sum to `n_sample`.
    """
    counts = [int(proportions[k] * n_sample) for k in range(n_classes - 1)]
    counts.append(n_sample - sum(counts))
    return counts
        
def create_MNIST_dirichlet(
    dataset_name: str,
    balanced: bool,
    alpha: float,
    n_clients: int,
    n_classes: int,
    n_iid_clients: int = 30,
    seed=None,
):
    """Create Dirichlet-skewed MNIST client splits and pickle them under ./data/.

    One ``(n_clients, n_classes)`` class-proportion matrix is sampled and reused
    for both the train and the test split, so each client has the same class
    mix in both files:

    * the first ``min(n_iid_clients, n_clients)`` rows are drawn from
      Dir(alpha * 100). This is a 100x larger concentration, so those clients
      are much closer to uniform over classes;
    * the remaining rows are drawn from Dir(alpha).

    With the defaults and ``n_clients == 100`` this gives 30 near-uniform and
    70 skewed clients.

    Args:
        dataset_name: prefix of the output files
            ``{dataset_name}_train_{n_clients}.pkl`` and
            ``{dataset_name}_test_{n_clients}.pkl``.
        balanced: forwarded to ``partition_CIFAR_dataset``. Its unbalanced
            size schedule covers exactly 100 clients.
        alpha: Dirichlet concentration for the skewed clients.
        n_clients: number of clients (matrix rows).
        n_classes: number of classes (matrix columns).
        n_iid_clients: how many leading clients use the ``alpha * 100``
            concentration.
        seed: if given, seeds NumPy's global RNG. This makes both the matrix
            and the per-client index sampling reproducible.

    Raises:
        ValueError: if ``n_iid_clients`` is negative.
    """
    if n_iid_clients < 0:
        raise ValueError(f"n_iid_clients must be non-negative, got {n_iid_clients}")
    if seed is not None:
        np.random.seed(seed)

    n_iid = min(n_iid_clients, n_clients)
    matrix = np.concatenate(
        (
            np.random.dirichlet([alpha * 100] * n_classes, size=n_iid),
            np.random.dirichlet([alpha] * n_classes, size=n_clients - n_iid),
        )
    )

    # `transform` only applies in __getitem__. The partitioning reads `.data` directly.
    MNIST_train = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
    MNIST_test = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())

    for split_dataset, split_name, is_train in (
        (MNIST_train, "train", True),
        (MNIST_test, "test", False),
    ):
        partition_CIFAR_dataset(
            split_dataset,
            f"{dataset_name}_{split_name}_{n_clients}.pkl",
            balanced,
            matrix,
            n_clients,
            n_classes,
            is_train,
        )

In [18]:
# Scratch check: print the integers 0-9, one per line.
for i in list(range(10)):
    print(f"{i}")

0
1
2
3
4
5
6
7
8
9
