In [1]:
import os
import sys
import pathlib
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import torch
import numpy as np
import matplotlib.pyplot as plt
import datasets
import config

In [2]:
# Build the shared experiment arguments through the project's `config` module.
# NOTE(review): `server=True` / `ipykernel=True` are project-specific switches;
# `ipykernel=True` presumably makes argument parsing notebook-safe — confirm in config.py.
args = config.init_args(server=True, ipykernel=True)

args.excluded_heads []


In [5]:
# 1) IID 

def get_unique_labels(arr, num_classes=21):
    """Return a multi-hot vector marking which non-zero classes occur in a label mask.

    Args:
        arr: Array-like of class ids, any shape.
        num_classes: Length of the returned vector. Only ids in
            ``[1, num_classes)`` are recorded. Defaults to 21, the value that
            used to be hard-coded here.

    Returns:
        A float64 array of shape ``(num_classes,)`` holding 1.0 at index ``c``
        when class ``c`` appears in ``arr`` and 0.0 elsewhere. Index 0 is
        always 0.0 because label 0 is deliberately excluded.
    """
    # np.unique already returns the ids sorted and de-duplicated.
    present_ids = np.unique(np.asarray(arr))

    # Keep only ids that fit in the output vector. Without the upper bound an
    # out-of-range id (e.g. a 255 "ignore" value) raised IndexError; without
    # the lower bound a negative id silently wrapped around to the last classes.
    in_range = (present_ids > 0) & (present_ids < num_classes)
    # Cast so that float-typed (or empty) inputs can still be used as indices.
    present_ids = present_ids[in_range].astype(np.intp)

    presence = np.zeros(num_classes)
    presence[present_ids] = 1
    return presence

def get_labels(trainset):
    """Count, per class, how many samples of ``trainset`` contain that class.

    Args:
        trainset: Indexable, sized dataset. Each item must expose its label
            mask at position 1 as an object with a ``.numpy()`` method
            (presumably a torch tensor — confirm against the dataset class).

    Returns:
        The element-wise sum over all samples of the per-sample presence
        vectors produced by ``get_unique_labels``.
    """
    presence_vectors = []
    for sample_idx in range(len(trainset)):
        mask = trainset[sample_idx][1]
        presence_vectors.append(get_unique_labels(mask.numpy()))
    return np.sum(presence_vectors, axis=0)

def plot_clients_class_distribution_scatter_v3(clients_counts, scale_factor=50, save_path=None, title=None, classes=None):
    """Draw a client-by-class bubble chart of per-client class counts.

    Each (client, class) pair is drawn as a circle at ``x=client``, ``y=class``
    whose marker area is ``count * scale_factor``; pairs with a zero count are
    therefore invisible.

    Args:
        clients_counts: 2-D array-like of shape ``(num_clients, num_classes)``
            holding the count of each class for each client.
        scale_factor: Multiplier turning a count into a marker area.
        save_path: Optional ``str`` or path-like. When given, the figure is
            written there at 300 dpi.
        title: Figure title. When omitted, it falls back to the file name of
            ``save_path`` up to its first dot, or to a generic title.
        classes: Optional sequence of ``num_classes`` display names. Defaults
            to ``'Class 1'`` ... ``'Class N'``.

    Returns:
        None. The figure is shown (and optionally saved) as a side effect.
    """
    clients_counts = np.asarray(clients_counts)
    num_clients, num_classes = clients_counts.shape

    # The marker area used to be computed as (count / row_sum) * row_sum * scale,
    # which is just count * scale but emitted divide-by-zero warnings (and needed
    # a NaN clean-up) for clients without any data.
    sizes = clients_counts.astype(float) * scale_factor

    # Pick one distinct colour per class. Sampling the 20-entry tab20 map at more
    # than 20 points gives several classes the same colour, so larger class
    # counts get a bigger palette instead.
    if num_classes <= 10:
        colors = plt.cm.tab10(np.linspace(0, 1, num_classes))
    elif num_classes <= 20:
        colors = plt.cm.tab20(np.linspace(0, 1, num_classes))
    elif num_classes <= 40:
        # Integer inputs index the colormap's lookup table directly.
        palette = np.vstack([plt.cm.tab20(np.arange(20)), plt.cm.tab20b(np.arange(20))])
        colors = palette[:num_classes]
    else:
        colors = plt.cm.nipy_spectral(np.linspace(0, 1, num_classes))

    if classes is None:
        classes = [f'Class {i+1}' for i in range(num_classes)]

    fig, ax = plt.subplots()

    # Proxy artists so the legend shows one fixed-size dot per class.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label=f'{classes[i]}',
                   markerfacecolor=colors[i], markersize=10)
        for i in range(num_classes)
    ]

    # One scatter call per client (a column of bubbles) instead of one per point.
    class_positions = np.arange(num_classes)
    for client_idx in range(num_clients):
        ax.scatter(np.full(num_classes, client_idx), class_positions,
                   s=sizes[client_idx], c=colors, alpha=0.6)

    ax.set_yticks(range(num_classes))
    ax.set_yticklabels([f'{classes[i]}' for i in range(num_classes)])
    ax.set_ylabel('Class Labels')
    ax.set_xticks(range(num_clients))
    ax.set_xlabel('Clients')
    # Reverse the handles so the legend order matches the y axis (top = last class).
    ax.legend(handles=legend_elements[::-1], loc='upper left', bbox_to_anchor=(1, 1))

    if title is None:
        if save_path is not None:
            # os.fspath lets this work for pathlib.Path as well as str; the
            # previous str-only `.split('/')` raised AttributeError on a Path.
            title = os.path.basename(os.fspath(save_path)).split('.')[0]
        else:
            title = 'Class Distribution per Client (Size Proportional to Data Amount)'
    ax.set_title(title)

    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


In [6]:
import pathlib

import pandas as pd

# alpha = -1.0: the partitioner reported "iid" for this setting when the cell was last run.
args.alpha = -1.0
dataset = datasets.PascalVocSegmentationPartition(args=args)

# For every client, collect the per-class sample counts of its training split.
per_client_counts = []
for client_id in range(args.num_clients):
    trainset, testset = dataset.load_partition(client_id)
    labels = get_labels(trainset)
    print(labels)
    per_client_counts.append(labels)
class_counts = np.array(per_client_counts, dtype=np.int32)

# Persist the (clients x classes) table as CSV in the current working directory.
df = pd.DataFrame(class_counts)
df.to_csv("./PASCAL_VOC_2012_IID_class_counts.csv", index=False)

# Display names for the 21 label ids, in id order.
PASCAL_VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
    'tvmonitor',
]

title = 'PASCAL_VOC_2012_IID'
path = pathlib.Path('/home/suncheol/code/FedTest/0_FedMHAD_Seg/test/PASCAL_VOC_2012_IID.png')
plot_clients_class_distribution_scatter_v3(
    clients_counts=class_counts,
    scale_factor=30,
    save_path=path,
    title=title,
    classes=PASCAL_VOC_CLASSES,
)

malicious clients: []
iid
103 144
[ 0.  7.  2.  9.  6.  3.  4.  6.  9.  2.  4.  2.  7.  5.  6. 21.  2.  5.
  1.  7.  2.]
iid
102 144
[ 0.  7.  3.  9.  6.  3.  4.  4.  9.  3.  4.  3.  8.  5.  5. 20.  2.  4.
  2.  7.  2.]
iid
103 144
[ 0.  7.  3.  9.  5.  3.  5.  7.  9.  3.  5.  1.  8.  5.  6. 23.  2.  5.
  1.  7.  2.]
iid
104 144
[ 0.  7.  3.  9.  6.  3.  4.  6.  9.  4.  5.  2.  8.  5.  6. 23.  3.  5.
  1.  7.  1.]
iid
103 144
[ 0.  8.  2.  7.  6.  3.  5.  6.  9.  3.  5.  3.  8.  5.  6. 23.  1.  5.
  1.  7.  2.]
iid
104 144
[ 0.  8.  3.  8.  5.  2.  5.  6.  9.  4.  5.  3.  8.  5.  4. 19.  3.  5.
  2.  7.  2.]
iid
102 144
[ 0.  8.  2.  9.  6.  3.  5.  7.  9.  4.  4.  3.  8.  5.  5. 17.  2.  5.
  2.  7.  2.]
iid
102 144
[ 0.  7.  2.  9.  6.  3.  5.  5.  9.  4.  5.  2.  8.  5.  5. 19.  1.  5.
  2.  7.  2.]
iid
102 144
[ 0.  7.  2.  8.  5.  2.  5.  7.  9.  4.  5.  2.  7.  5.  5. 17.  3.  5.
  2.  7.  2.]
iid
102 144
[ 0.  7.  3.  8.  6.  2.  4.  5.  8.  3.  5.  3.  7.  5.  6. 22.  1.  5.
  

In [7]:
import pathlib

import pandas as pd

# alpha = 0.1: the partitioner reported "dirichlet" for this setting when the cell was last run.
args.alpha = 0.1
dataset = datasets.PascalVocSegmentationPartition(args=args)

# For every client, collect the per-class sample counts of its training split.
per_client_counts = []
for client_id in range(args.num_clients):
    trainset, testset = dataset.load_partition(client_id)
    labels = get_labels(trainset)
    print(labels)
    per_client_counts.append(labels)
class_counts = np.array(per_client_counts, dtype=np.int32)

# Persist the (clients x classes) table as CSV in the current working directory.
df = pd.DataFrame(class_counts)
df.to_csv("./PASCAL_VOC_2012_NIID_class_counts.csv", index=False)

# Display names for the 21 label ids, in id order.
PASCAL_VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
    'tvmonitor',
]

title = f'PASCAL_VOC_2012_NIID (alpha={args.alpha})'
path = pathlib.Path('/home/suncheol/code/FedTest/0_FedMHAD_Seg/test/PASCAL_VOC_2012_NIID.png')
plot_clients_class_distribution_scatter_v3(
    clients_counts=class_counts,
    scale_factor=30,
    save_path=path,
    title=title,
    classes=PASCAL_VOC_CLASSES,
)

malicious clients: []
dirichlet
137 144
[ 0. 15.  2. 19.  1.  4.  6.  8.  2. 11.  0.  6. 20.  1.  5. 23.  9.  8.
  1.  4.  2.]
dirichlet
140 144
[ 0.  1.  0.  1.  1.  6.  4.  5. 24.  0. 14.  0.  1.  0. 16. 67.  3.  0.
  1.  5.  0.]
dirichlet
53 144
[ 0.  2.  0.  0.  0.  0.  3.  1.  4. 10.  5.  4.  8.  0.  2. 25.  0.  0.
  3.  1.  1.]
dirichlet
130 144
[ 0.  0. 13.  0. 18.  5.  8.  3. 20.  5.  0.  7. 19.  3.  0. 30.  1.  0.
  7. 21.  1.]
dirichlet
129 144
[ 0. 40.  1. 15.  3.  1.  3. 20.  0.  2.  2.  5. 26.  2.  0. 18.  1.  0.
  1.  0.  3.]
dirichlet
118 144
[ 0.  4.  2.  0.  7.  4.  6.  5.  8. 13.  6.  7.  1. 17.  2.  8.  1. 27.
  3. 10.  0.]
dirichlet
117 144
[ 0.  4.  2. 10. 11.  2.  1.  1. 15.  2. 11.  1.  1.  4. 10.  7.  1.  8.
  4.  2. 18.]
dirichlet
68 144
[ 0.  4.  3. 13.  9.  3.  3.  1.  6.  0.  1.  0.  1.  0.  5. 19.  1.  0.
  5.  2.  0.]
dirichlet
153 144
[ 0.  6.  5. 25.  1.  1. 15. 18. 23.  3.  8.  0.  0. 16.  2. 19. 13.  0.
  1. 17.  3.]
dirichlet
82 144
[ 0.  0.  0.  1.  