In [None]:
import os
import sys
import pathlib
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import torch
import numpy as np
import matplotlib.pyplot as plt
import datasets
import config

In [None]:
# Build the run configuration object used by the cells below: they read
# `args.num_clients` and overwrite `args.alpha` before creating each partition.
# NOTE(review): `server=True, ipykernel=True` presumably select server-side paths
# and notebook-safe argument parsing — confirm against `config.init_args`.
args = config.init_args(server=True, ipykernel=True)

In [None]:
# 1) IID — shared helpers (label counting and plotting); also used by the non-IID cell

def get_unique_labels(arr, num_classes=21):
    """Return a multi-hot vector of the non-zero class indices present in ``arr``.

    Args:
        arr: Array-like of class indices of any shape (e.g. one label mask).
        num_classes: Length of the returned vector. Only indices in
            ``1 .. num_classes - 1`` can be marked. Defaults to 21.

    Returns:
        A 1-D float ``np.ndarray`` of length ``num_classes`` holding 1.0 at
        every index in ``1 .. num_classes - 1`` that occurs in ``arr`` and 0.0
        everywhere else. Index 0 is never set.
    """
    present = np.unique(np.asarray(arr))
    # Keep only values that are valid, non-zero positions in the output vector.
    # Out-of-range values (for example a 255 "void" marker in a mask) would
    # otherwise raise IndexError, and negative values would silently wrap
    # around and mark the wrong class.
    present = present[(present > 0) & (present < num_classes)]
    # Fancy indexing requires an integer dtype; this also makes empty or
    # float-typed inputs work.
    present = present.astype(np.int64)

    multi_hot = np.zeros(num_classes)
    multi_hot[present] = 1
    return multi_hot

def get_labels(trainset, num_classes=21):
    """Count, per class, how many samples of ``trainset`` contain that class.

    Args:
        trainset: Sized, indexable dataset. Each item must be a sequence whose
            element at position 1 exposes ``.numpy()`` (the label mask).
        num_classes: Length of the all-zero vector returned for an empty
            dataset. It should equal the length of the vector produced by
            ``get_unique_labels`` (21).

    Returns:
        A 1-D array: the element-wise sum over all samples of the vectors
        returned by ``get_unique_labels``. For an empty dataset, a zero vector
        of length ``num_classes``.
    """
    # np.sum over an empty list collapses to the scalar 0.0, which cannot be
    # stacked with the per-class vectors of other clients; return an explicit
    # zero row instead.
    if len(trainset) == 0:
        return np.zeros(num_classes)

    per_sample = [
        get_unique_labels(trainset[idx][1].numpy())
        for idx in range(len(trainset))
    ]
    return np.sum(per_sample, axis=0)

def plot_clients_class_distribution_scatter_v3(clients_counts, scale_factor=50, save_path=None, title=None, classes=None):
    """Draw a client-by-class bubble chart of per-client class counts.

    Each (client, class) pair is drawn as one marker at ``x = client index``,
    ``y = class index`` whose area is ``count * scale_factor``.

    Args:
        clients_counts: 2-D array-like of shape ``(num_clients, num_classes)``
            holding the count of each class for each client.
        scale_factor: Multiplier turning a count into a marker area.
        save_path: Optional ``str`` or path-like. When given, the figure is
            written there (300 dpi, tight bounding box).
        title: Figure title. When ``None``, the file name of ``save_path`` up
            to its first dot is used, or a generic title if ``save_path`` is
            also ``None``.
        classes: Optional sequence of at least ``num_classes`` labels used for
            the y-axis ticks and the legend. Defaults to ``'Class 1'``,
            ``'Class 2'``, ...

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    clients_counts = np.asarray(clients_counts)
    num_clients, num_classes = clients_counts.shape

    # Marker area is proportional to the raw count. (Normalising each row and
    # multiplying by the row total again yields the same value but emits
    # divide-by-zero warnings for clients without any data.)
    marker_areas = clients_counts * scale_factor

    fig, ax = plt.subplots()
    sample_points = np.linspace(0, 1, num_classes)
    if num_classes > 20:
        # tab20 only has 20 entries, so sampling more points from it assigns
        # the same colour to several classes; use a continuous map instead.
        colors = plt.cm.gist_rainbow(sample_points)
    elif num_classes > 10:
        colors = plt.cm.tab20(sample_points)
    else:
        colors = plt.cm.tab10(sample_points)

    if classes is None:
        classes = [f'Class {i+1}' for i in range(num_classes)]
    class_names = [f'{classes[i]}' for i in range(num_classes)]

    # Proxy artists: one fixed-size dot per class for the legend.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label=name, markerfacecolor=colors[i], markersize=10)
        for i, name in enumerate(class_names)
    ]

    # One scatter call per client draws that client's whole column of classes.
    class_positions = np.arange(num_classes)
    for client_idx in range(num_clients):
        ax.scatter(
            np.full(num_classes, client_idx),
            class_positions,
            s=marker_areas[client_idx],
            c=colors,
            alpha=0.6,
        )

    ax.set_yticks(range(num_classes))
    ax.set_yticklabels(class_names)
    ax.set_ylabel('Class Labels')
    ax.set_xticks(range(num_clients))
    ax.set_xlabel('Clients')
    # Reversed so the legend order matches the y-axis read from top to bottom.
    ax.legend(handles=legend_elements[::-1], loc='upper left', bbox_to_anchor=(1, 1))

    if title is None:
        if save_path is not None:
            # os.fspath accepts both str and pathlib.Path; calling .split()
            # directly on a Path object would raise AttributeError.
            title = os.path.basename(os.fspath(save_path)).split('.')[0]
        else:
            title = 'Class Distribution per Client (Size Proportional to Data Amount)'
    ax.set_title(title)

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
import pandas as pd

# IID run. NOTE(review): a negative alpha presumably tells the partitioner to
# split the data IID — confirm in datasets.PascalVocSegmentationPartition.
# This overwrites `alpha` on the shared `args` object.
args.alpha = -1.0
dataset = datasets.PascalVocSegmentationPartition(args=args)

# One row per client: what get_labels returns for that client's training set.
per_client_counts = []
for i in range(args.num_clients):
    trainset, testset = dataset.load_partition(i)
    labels = get_labels(trainset)
    print(labels)
    per_client_counts.append(labels)
class_counts = np.array(per_client_counts, dtype=np.int32)

# Save the (clients x classes) table as CSV in the current working directory.
df = pd.DataFrame(class_counts)
df.to_csv("./PASCAL_VOC_2012_IID_class_counts.csv", index=False)

title = 'PASCAL_VOC_2012_IID'
PASCAL_VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
# NOTE(review): machine-specific absolute path; consider a configurable
# output directory.
path = pathlib.Path('/home/suncheol/code/FedTest/0_FedMHAD_Seg/test/PASCAL_VOC_2012_IID.png')
# savefig does not create missing directories; make sure the target exists so
# the plot is not lost after the slow partition scan above.
path.parent.mkdir(parents=True, exist_ok=True)
plot_clients_class_distribution_scatter_v3(clients_counts=class_counts, scale_factor=30, save_path=path, title=title, classes=PASCAL_VOC_CLASSES)

In [None]:
import pandas as pd

# Non-IID run with alpha = 0.1 (the value also appears in the plot title).
# This overwrites `alpha` on the shared `args` object.
args.alpha = 0.1
dataset = datasets.PascalVocSegmentationPartition(args=args)

# One row per client: what get_labels returns for that client's training set.
per_client_counts = []
for i in range(args.num_clients):
    trainset, testset = dataset.load_partition(i)
    labels = get_labels(trainset)
    print(labels)
    per_client_counts.append(labels)
class_counts = np.array(per_client_counts, dtype=np.int32)

# Save the (clients x classes) table as CSV in the current working directory.
df = pd.DataFrame(class_counts)
df.to_csv("./PASCAL_VOC_2012_NIID_class_counts.csv", index=False)

title = f'PASCAL_VOC_2012_NIID (alpha={args.alpha})'
PASCAL_VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
# NOTE(review): machine-specific absolute path; consider a configurable
# output directory. The file name does not include alpha, so runs with
# different alpha values overwrite each other.
path = pathlib.Path('/home/suncheol/code/FedTest/0_FedMHAD_Seg/test/PASCAL_VOC_2012_NIID.png')
# savefig does not create missing directories; make sure the target exists so
# the plot is not lost after the slow partition scan above.
path.parent.mkdir(parents=True, exist_ok=True)
plot_clients_class_distribution_scatter_v3(clients_counts=class_counts, scale_factor=30, save_path=path, title=title, classes=PASCAL_VOC_CLASSES)