In [None]:
import torch
import numpy as np
from torchvision import transforms
import torchvision.models as  models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from utils import encode_labels
import os
import matplotlib.pyplot as plt
import pathlib

In [None]:
# The 20 PASCAL VOC object categories; a category's position in this list is
# its position in the encoded label vector.
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']

# NOTE: this definition shadows the `encode_labels` imported from `utils`
# in the import cell at the top of the notebook.
def encode_labels(target):
    """
    Encode the non-difficult objects of one annotation as a 1/0 vector.

    Args:
        target: parsed annotation. Only target['annotation']['object'] is
            read; it is either a single dict (one object) or a sequence of
            dicts, each providing 'name' and 'difficult'.
    Returns:
        torch tensor of length len(object_categories) holding 1 at the index
        of every category that has at least one object with difficult == 0,
        and 0 everywhere else.
    Raises:
        ValueError: if an object's name is not in object_categories.
    """
    objects = target['annotation']['object']

    # A lone object arrives as a bare dict rather than a one-element list.
    # isinstance (instead of an exact type comparison) also accepts dict
    # subclasses such as OrderedDict.
    if isinstance(objects, dict):
        objects = [objects]

    present = [object_categories.index(obj['name'])
               for obj in objects
               if int(obj['difficult']) == 0]

    encoded = np.zeros(len(object_categories))
    encoded[present] = 1

    return torch.from_numpy(encoded)

In [None]:
# Resize every image to 224x224 and convert it to a tensor
transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# The two splits are built identically apart from image_set
train_datasets, val_datasets = (
    datasets.voc.VOCDetection(root='~/.data/', year='2012', image_set=split,
                              download=False, transform=transformations,
                              target_transform=encode_labels)
    for split in ('train', 'val')
)

train_loader, val_loader = (
    DataLoader(split_dataset, batch_size=16, shuffle=True, num_workers=4)
    for split_dataset in (train_datasets, val_datasets)
)

In [None]:
def get_img_label(dataset):
    """
    Collect every (image, label) pair of a dataset into two numpy arrays.

    Args:
        dataset: iterable yielding (image, label) pairs whose members
            support .cpu().numpy()
    Returns:
        (imgs, labels): numpy arrays stacking the converted images and the
        converted labels in iteration order
    """
    img_list, label_list = [], []
    for sample in dataset:
        img_tensor, label_tensor = sample
        img_array = img_tensor.cpu().numpy()
        label_array = label_tensor.cpu().numpy()
        img_list.append(img_array)
        label_list.append(label_array)
    return np.array(img_list), np.array(label_list)

# Pull both splits fully into memory as numpy arrays (train first, then val)
(train_imgs, train_labels), (val_imgs, val_labels) = (
    get_img_label(split_dataset) for split_dataset in (train_datasets, val_datasets)
)

In [None]:
# PASCAL VOC 2012: write the image and label arrays of both splits to disk
path = pathlib.Path('../data/PASCAL_VOC_2012/')
path.mkdir(parents=True, exist_ok=True)
for file_name, array in (
    ('PASCAL_VOC_train_224_Img.npy', train_imgs),
    ('PASCAL_VOC_train_224_Label.npy', train_labels),
    ('PASCAL_VOC_val_224_Img.npy', val_imgs),
    ('PASCAL_VOC_val_224_Label.npy', val_labels),
):
    np.save(path.joinpath(file_name), array)

In [1]:
import sys 
sys.path.append("..")
import utils.utils as utils
import utils.dirichlet_split as dirichlet

In [None]:
import pathlib

# Reload the cached image/label arrays from disk (val first, then train)
path = pathlib.Path('../data/PASCAL_VOC_2012/')
val_imgs, val_labels, train_imgs, train_labels = (
    np.load(path.joinpath(file_name))
    for file_name in (
        'PASCAL_VOC_val_224_Img.npy',
        'PASCAL_VOC_val_224_Label.npy',
        'PASCAL_VOC_train_224_Img.npy',
        'PASCAL_VOC_train_224_Label.npy',
    )
)


In [None]:
# Convert multi-hot label vectors to integer indices and back (element i contributes 2**i)
'''
[0,0,0,0,0,0] = 0
[1,0,0,0,0,0] = 1
[0,1,0,0,0,0] = 2
[0,1,1,0,0,0] = 6

'''
def get_oct_num(list):
    """
    Collapse a 0/1 vector into one number, weighting element i by 2 ** i.

    Args:
        list: sequence of 0/1 values (first element is the lowest bit)
    Returns:
        sum of list[i] * 2 ** i over all positions; 0 for an empty sequence
    """
    return sum(bit * 2 ** position for position, bit in enumerate(list))

def get_bin_num(oct, nClass):
    """
    Expand a number into its nClass lowest binary digits, lowest digit first.

    Args:
        oct: number to expand
        nClass: how many digits to produce
    Returns:
        list of nClass values, each the remainder of a successive halving
    """
    digits = []
    remaining = oct
    for _ in range(nClass):
        digits.append(remaining % 2)
        remaining = remaining // 2
    return digits
     
def get_label_to_index(labels):
    """
    Turn each label vector into its single-number index via get_oct_num.

    Args:
        labels: sequence of 0/1 label vectors
    Returns:
        list with one index per label vector, in the same order
    """
    return [get_oct_num(labels[row]) for row in range(len(labels))]

def get_index_to_label(label_index, nClass):
    """
    Turn each index back into a 0/1 label vector via get_bin_num.

    Args:
        label_index: sequence of indices as produced by get_label_to_index
        nClass: length of every decoded label vector
    Returns:
        list with one nClass-long list per index, in the same order
    """
    return [get_bin_num(label_index[pos], nClass) for pos in range(len(label_index))]

# One integer index per sample, for the validation split and then the training split
val_indices, train_indices = (
    get_label_to_index(split_labels) for split_labels in (val_labels, train_labels)
)

In [None]:
# Column-wise sum of the validation label vectors: one total per category
y_total = val_labels.sum(axis=0)

In [None]:
# Plot the class distribution of the validation split
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']

plt.figure(figsize=(10, 4))
for i in range(len(object_categories)):
    # One bar() call per category so each bar takes its own colour from the cycle
    plt.bar(object_categories[i], y_total[i])
    # The totals are whole numbers held as floats; print them without ".0"
    plt.text(object_categories[i], y_total[i], f'{int(y_total[i])}', ha='center', va='bottom')
# y_total is computed from val_labels, i.e. the PASCAL VOC 2012 val split
# (the previous title said 'MSCOCO').
plt.title('PASCAL VOC 2012 (val)')
plt.xlabel('Object Categories')
plt.ylabel('Number of Images')
plt.xticks(rotation=90)
# No plt.legend(): none of the bars carries a label, so the call only
# produced a "no artists with labels" warning and an empty legend.

In [None]:
# Distinct label indices occurring in the training set, and how often each occurs
unique, counts = np.unique(train_indices, return_counts=True)
# label index -> dense class id (position in `unique`), and the reverse lookup
dict_indices = {label_index: class_id for class_id, label_index in enumerate(unique)}
dict_indices_inv = dict(enumerate(unique))
X_train = train_imgs
y_train = np.array([dict_indices[label_index] for label_index in train_indices])

# Split parameters: number of dense classes, number of parties, Dirichlet alpha
N_class, N_parties, alpha = len(unique), 5, 1

# from sklearn.model_selection import train_test_split
# X_train, X_test, y_train, y_test = train_test_split(img, gt, test_size=0.2, random_state=42, stratify=gt)

# NOTE(review): this call runs before the seed is set on the next statement;
# if it samples randomly, dirchlet_arr is not reproducible across runs — confirm.
dirchlet_arr = dirichlet.get_dirichlet_distribution_count(N_class, N_parties, y_train, alpha)
# np.random.RandomState(1)
dirichlet.set_random_seed(0)
dirichlet.plot_dirichlet_distribution(N_class, N_parties, alpha)
dirichlet.plot_dirichlet_distribution_count(N_class, N_parties, y_train, alpha)
# This notebook never defines y_test (the train_test_split above is commented
# out), so stacking it with y_train raised NameError. Plot y_train on its own.
whole_y = y_train
dirichlet.plot_whole_y_distribution(whole_y)
dirichlet.plot_dirichlet_distribution_count_subplot(N_class, N_parties, y_train, alpha)
split_dirchlet_data = dirichlet.get_dirichlet_split_data(X_train, y_train, N_parties, N_class, alpha)
# Decode party 0 as a quick check: dense class id -> label index -> 0/1 vector
# of length 20 (the number of VOC categories).
y = split_dirchlet_data[0]['y']
y = np.array([dict_indices_inv[class_id] for class_id in y])
y = np.array(get_index_to_label(y, 20))
y.shape
# Output directory for the per-party arrays of this alpha
dirichlet_path = pathlib.Path(f'./../data/PASCAL_VOC_2012/dirichlet/alpha_{alpha}/')
dirichlet_path.mkdir(parents=True, exist_ok=True)

for i in range(N_parties):
    party = split_dirchlet_data[i]
    np.save(dirichlet_path.joinpath(f'Party_{i}_X_data.npy'), party['x'])
    # Dense class id -> original label index -> 0/1 vector of length 20
    label_indices = np.array([dict_indices_inv[class_id] for class_id in party['y']])
    y = np.array(get_index_to_label(label_indices, 20))
    np.save(dirichlet_path.joinpath(f'Party_{i}_y_data.npy'), y)
    
# Plot the class distribution of every party's saved split
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']

# Both directories depend only on alpha, so resolve them once, outside the loop
dirichlet_path = pathlib.Path(f'./../data/PASCAL_VOC_2012/dirichlet/alpha_{alpha}/')
save_path = pathlib.Path(f"../figures/dirichlet/alpha_{alpha}/")
save_path.mkdir(parents=True, exist_ok=True)

# Load each party's arrays back from disk and plot its label totals
for n_client in range(N_parties):
    print(f'n_client: {n_client}')
    plt.figure(figsize=(10, 4))
    # NOTE(review): X_train_0 is not used by the plot; the load is kept only
    # in case a later cell relies on the name — consider dropping it.
    X_train_0 = np.load(dirichlet_path.joinpath(f'Party_{n_client}_X_data.npy'))
    y_train_0 = np.load(dirichlet_path.joinpath(f'Party_{n_client}_y_data.npy'))
    y_total = np.sum(y_train_0, axis=0)
    for i in range(len(object_categories)):
        # One bar() call per category so each bar takes its own colour from the cycle
        plt.bar(object_categories[i], y_total[i])
        # Print the total as a whole number rather than e.g. "12.0"
        plt.text(object_categories[i], y_total[i], f'{int(y_total[i])}', ha='center', va='bottom')
    plt.title(f'PASCAL VOC 2012 (party {n_client}), alpha={alpha}')
    plt.xlabel('Object Categories')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=90)
    # No plt.legend(): none of the bars carries a label, so the call only
    # produced a "no artists with labels" warning and an empty legend.

    # Save the figure before show()
    plt.savefig(save_path.joinpath(f'Party_{n_client}_y_data.png'))
    plt.show()