In [None]:
import glob
import matplotlib.pyplot as plt
import torch
import pickle
import numpy as np
import os

In [None]:
def show_hist(results, category_names, n, total_samples=60000):
    """
    Draw a horizontal stacked bar chart with one bar per entry of *results*.

    Every bar is split into coloured segments, one per category, whose
    widths are the corresponding values in *results*.

    Parameters
    ----------
    results : dict
        A mapping from question labels to a list of answers per category.
        It is assumed all lists contain the same number of entries and that
        it matches the length of *category_names*.
    category_names : list of str
        The category labels.
    n : int
        Number of clients. It is shown in the title, sets the title's
        vertical offset (``-0.5 / n``) and scales the x-axis. Must be
        non-zero.
    total_samples : int, optional
        Sample count used to scale the x-axis; the upper x-limit is
        ``total_samples // n * e``. Defaults to 60000, the value that was
        previously hard-coded (presumably the size of the full dataset —
        confirm).

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
        The figure and axes the chart was drawn on.
    """
    labels = list(results.keys())
    data = np.array(list(results.values()))
    # Running totals along each row give the right edge of every segment.
    data_cum = data.cumsum(axis=1)
    category_colors = plt.get_cmap('RdYlGn')(
        np.linspace(0.15, 0.85, data.shape[1]))

    fig, ax = plt.subplots(figsize=(10, len(labels) * 0.8))
    ax.invert_yaxis()  # first entry of *results* is drawn at the top
    ax.xaxis.set_visible(False)

    for i, (colname, color) in enumerate(zip(category_names, category_colors)):
        widths = data[:, i]
        # Each segment starts where the previous category's segment ended.
        starts = data_cum[:, i] - widths
        ax.barh(labels, widths, left=starts, height=0.5,
                label=colname, color=color)

    ax.legend(ncol=len(category_names), bbox_to_anchor=(0, 1),
              loc='lower left', fontsize='small')
    # A negative y places the title below the axes.
    ax.set_title(f'Discrete distribution of realworld dataset (client={n})', y=-0.5 / n)
    # The x-range depends only on n and total_samples, not on the data, so a
    # bar longer than this limit runs past the visible axis.
    ax.set_xlim(0, total_samples // n * np.exp(1))
    return fig, ax

In [None]:
def show_dataset(n):
    """
    Plot how many samples of each label every client holds in an n-client split.

    Reads the first ``n`` ``*.pkl`` files (in sorted file-name order) from
    ``./export_realworld/mnist_{n}``, counts the entries of each file's
    ``'labels'`` field per class (classes 0-9), draws the counts with
    ``show_hist`` and displays the figure.

    Parameters
    ----------
    n : int
        Number of clients. Selects the dataset directory and how many
        pickle files are read.

    Raises
    ------
    FileNotFoundError
        If the dataset directory contains fewer than ``n`` ``.pkl`` files.
    """
    n_classes = 10
    dataset_path = f'./export_realworld/mnist_{n}'
    # NOTE(review): this is a plain lexicographic sort, so e.g. 'x_10.pkl'
    # sorts before 'x_2.pkl'. The 'client_<i>' labels below only match the
    # files if the names are zero-padded — confirm the naming scheme.
    pkl_paths = sorted(glob.glob(os.path.join(dataset_path, "*.pkl")))
    if len(pkl_paths) < n:
        raise FileNotFoundError(
            f'expected at least {n} .pkl files in {dataset_path!r}, '
            f'found {len(pkl_paths)}')

    client_stats = dict()
    for client_id in range(n):
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # at dataset files from a trusted source.
        with open(pkl_paths[client_id], 'rb') as f:
            dataset = pickle.load(f)

        # Count per file so only one client's dataset is in memory at a time.
        counts = [0] * n_classes
        for label in dataset['labels']:
            counts[int(label)] += 1
        client_stats[f'client_{client_id}'] = counts

    category_names = [str(i) for i in range(n_classes)]
    show_hist(client_stats, category_names, n)
    plt.show()

In [None]:
# Label distribution per client for the 10-client split (./export_realworld/mnist_10).
show_dataset(10)

In [None]:
# Label distribution per client for the 20-client split (./export_realworld/mnist_20).
show_dataset(20)

In [None]:
# Label distribution per client for the 50-client split (./export_realworld/mnist_50).
show_dataset(50)

In [None]:
# Label distribution per client for the 100-client split (./export_realworld/mnist_100).
show_dataset(100)