In [None]:
import os
import torch
import torchvision

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys

from fedlab.utils.dataset.partition import CIFAR10Partitioner
from fedlab.utils.functional import partition_report

In [None]:
# --- Experiment configuration ---
num_clients = 10   # number of clients handed to each CIFAR10Partitioner
num_clusters = 10  # number of clusters the clients are grouped into
# Floor division: any remainder of num_clients / num_clusters is discarded.
num_clients_per_cluster = num_clients // num_clusters
num_classes = 10   # passed as class_num to partition_report; also sizes the class<i> columns
seed = 2021        # seed forwarded to the partitioners
hist_color = '#4169E1'  # NOTE(review): not referenced in the visible cells

# --- Paths (all resolved relative to the parent of the current working directory) ---
PROJECT_DIR = os.path.dirname(os.getcwd())
print(PROJECT_DIR)
CIFAR10_DIR = os.path.join(PROJECT_DIR, "data", "CIFAR10", "raw")
REPORT_DIR = os.path.join(PROJECT_DIR, "result", "notebook", "cifar10_partition", "report")
# exist_ok=True makes re-running this cell safe without a separate (racy) existence check.
os.makedirs(REPORT_DIR, exist_ok=True)

In [None]:
# Load the CIFAR10 training split, using the raw-data directory under data/
# as the dataset root (download=True is passed to torchvision).
trainset = torchvision.datasets.CIFAR10(
    root=CIFAR10_DIR,
    train=True,
    download=True,
)

In [None]:
def plot_distribution(csv_file_name, max_clients=10):
    """Plot per-client class counts from a saved partition report.

    Reads ``<REPORT_DIR>/<csv_file_name>.csv``, multiplies every ``class<i>``
    column by the ``Amount`` column to obtain integer sample counts, draws a
    horizontal stacked bar chart of the first ``max_clients`` rows and saves
    it to ``<REPORT_DIR>/<csv_file_name>.png`` at 300 dpi.

    Args:
        csv_file_name: Report file name without the ``.csv`` extension. The
            same stem is reused for the output ``.png``.
        max_clients: Maximum number of rows (clients) to draw. Defaults to 10,
            the limit that was previously hard-coded.

    Returns:
        None. The figure is left open so the notebook can display it.
    """
    report_path = os.path.join(REPORT_DIR, f"{csv_file_name}.csv")
    # header=1: column names are read from the second line of the file.
    df = pd.read_csv(report_path, header=1).set_index("client")
    col_names = [f"class{i}" for i in range(num_classes)]

    # NOTE(review): class<i> values are presumably per-client fractions written
    # by partition_report -- confirm. Round before casting: a bare astype(int)
    # truncates, so a product such as 0.29 * 100 == 28.999... would become 28.
    counts = df[col_names].mul(df["Amount"], axis=0).round().astype(int)

    ax = counts.iloc[:max_clients].plot.barh(stacked=True)
    ax.set_xlabel("Number of samples")
    # Run tight_layout only after the label exists so it is included in the
    # layout computation and not clipped in the saved image.
    ax.figure.tight_layout()
    ax.figure.savefig(os.path.join(REPORT_DIR, f"{csv_file_name}.png"), dpi=300)

In [None]:
def combine_partition_indices(partition_indices, n_clusters=None, clients_per_cluster=None):
    """Merge per-client sample indices into per-cluster index lists.

    Cluster ``c`` receives, in client order, the indices of clients
    ``c * clients_per_cluster`` through ``(c + 1) * clients_per_cluster - 1``.

    Args:
        partition_indices: Container indexable by integer client id; each
            entry is an iterable of sample indices.
        n_clusters: Number of clusters to build. Defaults to the notebook-level
            ``num_clusters``.
        clients_per_cluster: Number of consecutive clients merged into each
            cluster. Defaults to the notebook-level ``num_clients_per_cluster``.

    Returns:
        dict mapping each cluster id in ``range(n_clusters)`` to a new list of
        the concatenated indices. The input is not modified.

    Raises:
        KeyError / IndexError: if a required client id is missing from
            ``partition_indices``.
    """
    # Fall back to the notebook configuration only when not given explicitly,
    # so the function can also be used without relying on global state.
    if n_clusters is None:
        n_clusters = num_clusters
    if clients_per_cluster is None:
        clients_per_cluster = num_clients_per_cluster

    new_partition_indices = {}
    for cluster_id in range(n_clusters):
        first_client = cluster_id * clients_per_cluster
        merged = []
        for client_id in range(first_client, first_client + clients_per_cluster):
            merged.extend(partition_indices[client_id])
        new_partition_indices[cluster_id] = merged
    return new_partition_indices

## Dirichlet

### None_Dirichlet

In [None]:
# Dirichlet partition of the training labels with balance=None, dir_alpha=0.3
# and min_require_size=100; the report is written to CSV and then plotted.
cifar10_None_Dirichelet_partitioner = CIFAR10Partitioner(
    trainset.targets,
    num_clients,
    balance=None,
    partition="dirichlet",
    dir_alpha=0.3,
    min_require_size=100,
    seed=seed,
)
report_file_name = "cifar10_None_Dirichelet_partitioner"
report_file_path = os.path.join(REPORT_DIR, f"{report_file_name}.csv")
partition_report(
    trainset.targets,
    cifar10_None_Dirichelet_partitioner.client_dict,
    class_num=num_classes,
    verbose=False,
    file=report_file_path,
)
plot_distribution(report_file_name)

### False_Dirichlet

In [None]:
# Dirichlet partition of the training labels with balance=False, dir_alpha=0.3
# and unbalance_sgm=0.3; the report is written to CSV and then plotted.
cifar10_False_Dirichelet_partitioner = CIFAR10Partitioner(
    trainset.targets,
    num_clients,
    balance=False,
    partition="dirichlet",
    dir_alpha=0.3,
    unbalance_sgm=0.3,
    seed=seed,
)
report_file_name = "cifar10_False_Dirichelet_partitioner"
report_file_path = os.path.join(REPORT_DIR, f"{report_file_name}.csv")
partition_report(
    trainset.targets,
    cifar10_False_Dirichelet_partitioner.client_dict,
    class_num=num_classes,
    verbose=False,
    file=report_file_path,
)
plot_distribution(report_file_name)

### True_Dirichlet

In [None]:
# Dirichlet partition of the training labels with balance=True and
# dir_alpha=0.3; the report is written to CSV and then plotted.
cifar10_True_Dirichelet_partitioner = CIFAR10Partitioner(
    trainset.targets,
    num_clients,
    balance=True,
    partition="dirichlet",
    dir_alpha=0.3,
    seed=seed,
)
report_file_name = "cifar10_True_Dirichelet_partitioner"
report_file_path = os.path.join(REPORT_DIR, f"{report_file_name}.csv")
# Alternative (disabled): to report per-cluster instead of per-client, pass
# combine_partition_indices(cifar10_True_Dirichelet_partitioner.client_dict)
# as the second argument below.
partition_report(
    trainset.targets,
    cifar10_True_Dirichelet_partitioner.client_dict,
    class_num=num_classes,
    verbose=False,
    file=report_file_path,
)
plot_distribution(report_file_name)