In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before each
# cell executes, so edits to the local project package are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
import sys
import os

# Show the working directory so a wrong launch location is obvious.
print(os.getcwd())
# you're in fl-heterogeneity/heterogeneity/notebooks
# Two levels up from the working directory is put on sys.path so the
# `heterogeneity` package can be imported by the cells below.
repo_root = os.path.abspath("./../..")
# Guard against duplicates: re-running this cell must not keep growing sys.path.
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [3]:
import itertools

import numpy as np
import pandas as pd
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner, ShardPartitioner, InnerDirichletPartitioner

from heterogeneity.metrics import compute_kl_divergence
from heterogeneity.utils import create_lognormal_partition_sizes

# KL

## IID

In [39]:
# Sample usage: one IID split of the CIFAR-10 "train" split into 10 partitions,
# with every partition loaded eagerly.
num_partitions = 10
iid_partitioner = IidPartitioner(num_partitions=num_partitions)
cifar_iid = FederatedDataset(dataset="cifar10", partitioners={"train": iid_partitioner})
cifar_iid_partitions = list(map(cifar_iid.load_partition, range(num_partitions)))

# Sweep: build one IID-partitioned FederatedDataset per partition count.
num_partitions_list = [3, 10, 30, 100, 300, 1000]
num_partitions_to_cifar_iid_partitions = {}  # kept for compatibility; not filled here
num_partitions_to_cifar_iid_fds = {}
for num_partitions in num_partitions_list:
    iid_partitioner = IidPartitioner(num_partitions=num_partitions)
    cifar_iid = FederatedDataset(dataset="cifar10", partitioners={"train": iid_partitioner})
    num_partitions_to_cifar_iid_fds[num_partitions] = cifar_iid

# Compute the KL-divergence metric for every partition count.
# NOTE(review): these dicts are named "hellinger_distance" but store the output of
# compute_kl_divergence; the names are kept because later cells read them.
num_partitions_to_cifar_iid_hellinger_distance_list = {}
num_partitions_to_cifar_iid_hellinger_distance = {}
for num_partitions, cifar_iid_fds in num_partitions_to_cifar_iid_fds.items():
    train_partitioner = cifar_iid_fds.partitioners["train"]
    # Presumably (per-partition values, average) — confirm against heterogeneity.metrics.
    metric_list, metric_avg = compute_kl_divergence(train_partitioner)
    num_partitions_to_cifar_iid_hellinger_distance[num_partitions] = metric_avg
    num_partitions_to_cifar_iid_hellinger_distance_list[num_partitions] = metric_list

In [47]:
# Tabulate the average metric per partition count as a one-column, colour-graded table.
# `.iloc[:-1]` drops the last entry of the sweep before display.
iid_kl_div_results = (
    pd.Series(num_partitions_to_cifar_iid_hellinger_distance, name="iid_kl")
    .iloc[:-1]
    .rename_axis("num_partitions")
    .to_frame()
    .style.background_gradient()
)
iid_kl_div_results

In [None]:
# NOTE(review): unfinished scratch code, never executed — the attribute access is
# truncated at `.loa` and the final call is left unclosed. Finish or delete this cell.
# labels = num_partitions_to_cifar_iid_fds[100].partitioners["train"].loa
# distributions = []
# for partition_id in num_partitions_to_cifar_iid_fds[100].partitioners["train"].num_partitions:
#     labels = num_partitions_to_cifar_iid_fds[100].partitioners["train"].loa
#     compute_distributions(

## Dirichlet

In [28]:
# Sample usage: a single Dirichlet split of CIFAR-10 "train" into 10 partitions,
# with a per-class alpha vector, and every partition loaded eagerly.
num_partitions = 10
alpha = [0.1] * 10
dirichlet_partitioner = DirichletPartitioner(num_partitions=num_partitions, alpha=alpha, partition_by="label")
cifar_dir = FederatedDataset(dataset="cifar10", partitioners={"train" : dirichlet_partitioner})
cifar_dir_partitions = [cifar_dir.load_partition(i) for i in range(num_partitions)]

# Sweep: one Dirichlet-partitioned FederatedDataset per (num_partitions, alpha) pair.
num_partitions_to_cifar_dir_partitions = {}
num_partitions_to_cifar_dir_fds = {}
num_partitions_list = [3, 10, 30, 100, 300, 1000]
# 100. used to appear twice here; the duplicate only rebuilt the same (num_partitions, alpha)
# key and overwrote it. NOTE(review): a different value (e.g. 30.) may have been intended.
alpha_list = [0.1, 0.3, 1., 3., 10., 100.]
for num_partitions, alpha in itertools.product(num_partitions_list, alpha_list):
    dir_partitioner = DirichletPartitioner(num_partitions=num_partitions, alpha=alpha, partition_by="label")
    cifar_dir = FederatedDataset(dataset="cifar10", partitioners={"train" : dir_partitioner})
    num_partitions_to_cifar_dir_fds[(num_partitions, alpha)] = cifar_dir

# Compute the KL-divergence metric for every configuration; a failing configuration
# is recorded as NaN so the sweep can continue.
num_partitions_to_cifar_dir_metric_list = {}
num_partitions_to_cifar_dir_metric = {}
for (num_partitions, alpha), cifar_dir_fds in num_partitions_to_cifar_dir_fds.items():
    print((num_partitions, alpha))
    try:
        metric_list, avg_metric = compute_kl_divergence(cifar_dir_fds.partitioners["train"])
    except Exception as exc:
        # Catch Exception (not a bare except) so that interrupting the kernel still
        # stops the loop, and report the actual error instead of hiding it.
        print(f"Sampling failed for {(num_partitions, alpha)}: {exc!r}")
        metric_list, avg_metric = np.nan, np.nan
    num_partitions_to_cifar_dir_metric_list[(num_partitions, alpha)] = metric_list
    num_partitions_to_cifar_dir_metric[(num_partitions, alpha)] = avg_metric

In [43]:
# Pivot the (num_partitions, alpha) -> average-metric mapping into a table with
# partition counts as rows and alpha values as columns; infinite values become NaN.
kl_dir = (
    pd.Series(num_partitions_to_cifar_dir_metric)
    .unstack(level=1)
    .replace([np.inf, -np.inf], np.nan)
    .rename_axis(index="num_partitions", columns="alpha")
)
# axis=None: colour scale is shared across the whole table rather than per column.
kl_dir.style.background_gradient(axis=None)

In [30]:
# Same pivot as a frame: the second key level is moved into the columns.
metric_frame = pd.Series(num_partitions_to_cifar_dir_metric).to_frame()
results = metric_frame.unstack(level=1)
# Display only: infinities are masked as NaN on a copy, `results` itself is untouched.
results_finite = results.replace([np.inf, -np.inf], np.nan)
results_finite.style.background_gradient(axis=None)

## Shard

In [45]:
# Sweep: one shard-partitioned FederatedDataset per
# (num_partitions, num_shards_per_partition) pair.
# NOTE: despite its name, params_to_partitioner maps to FederatedDataset objects.
params_to_partitioner = {}
num_partitions_list = [3, 10, 30, 100, 300, 1000]
num_shards_per_partition_list = [2, 3, 4, 5]
for num_partitions, num_shards_per_partition in itertools.product(num_partitions_list, num_shards_per_partition_list):
    partitioner = ShardPartitioner(num_partitions=num_partitions, partition_by="label", num_shards_per_partition=num_shards_per_partition)
    fds = FederatedDataset(dataset="cifar10", partitioners={"train" : partitioner})
    params_to_partitioner[(num_partitions, num_shards_per_partition)] = fds

# Compute the KL-divergence metric for every configuration; a failing configuration
# is recorded as NaN so the sweep can continue.
parameters_to_shard_cifar_fds_metric_list = {}
parameters_to_shard_cifar_fds_metric = {}
for (num_partitions, num_shards_per_partition), fds in params_to_partitioner.items():
    print((num_partitions, num_shards_per_partition))
    try:
        metric_list, avg_metric = compute_kl_divergence(fds.partitioners["train"])
    except Exception as exc:
        # Catch Exception (not a bare except) so that interrupting the kernel still
        # stops the loop, and report the actual error instead of hiding it.
        print(f"Sampling failed for {(num_partitions, num_shards_per_partition)}: {exc!r}")
        metric_list, avg_metric = np.nan, np.nan
    parameters_to_shard_cifar_fds_metric_list[(num_partitions, num_shards_per_partition)] = metric_list
    parameters_to_shard_cifar_fds_metric[(num_partitions, num_shards_per_partition)] = avg_metric

In [46]:
# Pivot the (num_partitions, num_shards_per_partition) -> average-metric mapping into
# a table: partition counts as rows, shards per partition as columns.
# NOTE(review): the variable says "emd" but the values come from the KL-divergence sweep.
shard_emd_results = (
    pd.Series(parameters_to_shard_cifar_fds_metric)
    .unstack(level=1)
    .rename_axis(index="num_partitions", columns="num_shards")
)
# axis=None: colour scale is shared across the whole table rather than per column.
shard_emd_results.style.background_gradient(axis=None)

## InnerDirichlet

In [None]:
dataset_name = "cifar10"
# Single-configuration example, kept for reference (not executed):
# num_partitions = 10
# sigma = 0.3
# partition_sizes = create_lognormal_partition_sizes(dataset_name, num_partitions, sigma)
# 
# alpha = 0.1
# dirichlet_partitioner = InnerDirichletPartitioner(partition_sizes=partition_sizes, partition_by="label", alpha=0.1)
# cifar_dir = FederatedDataset(dataset="cifar10", partitioners={"train" : dirichlet_partitioner})
# cifar_dir_partitions = [cifar_dir.load_partition(i) for i in range(num_partitions)]

# Sweep: one InnerDirichlet-partitioned FederatedDataset per (num_partitions, alpha, sigma).
# NOTE(review): this reuses (and overwrites) the `num_partitions_to_cifar_dir_*` names of
# the Dirichlet section, now with 3-tuple keys; re-running the Dirichlet result cells
# after this one will pick up these values instead.
num_partitions_to_cifar_dir_fds = {}
num_partitions_list = [3, 10, 30, 100, 300, 1000]
# 100. used to appear twice here; the duplicate only rebuilt the same key and overwrote it.
# NOTE(review): a different value (e.g. 30.) may have been intended.
alpha_list = [0.1, 0.3, 1., 3., 10., 100.]
sigma_list = [0.1, 0.3, 1., 3.]
partition_sizes_dict = {}
print("Data Generation")
for num_partitions, alpha, sigma in itertools.product(num_partitions_list, alpha_list, sigma_list):
    print(num_partitions, alpha, sigma)
    # Partition sizes come from the project helper, parameterised by sigma.
    partition_sizes = create_lognormal_partition_sizes(dataset_name, num_partitions, sigma)
    dir_partitioner = InnerDirichletPartitioner(partition_sizes=partition_sizes, partition_by="label", alpha=alpha)
    # Use dataset_name (not a second hard-coded literal) so the partition sizes and the
    # dataset that is actually loaded can never disagree.
    cifar_dir = FederatedDataset(dataset=dataset_name, partitioners={"train" : dir_partitioner})
    num_partitions_to_cifar_dir_fds[(num_partitions, alpha, sigma)] = cifar_dir
    partition_sizes_dict[(num_partitions, alpha, sigma)] = partition_sizes

# Compute the KL-divergence metric for every configuration; a failing configuration
# is recorded as NaN so the sweep can continue.
num_partitions_to_cifar_dir_metric_list = {}
num_partitions_to_cifar_dir_metric = {}
print("Metrics calculation")
for (num_partitions, alpha, sigma), cifar_dir_fds in num_partitions_to_cifar_dir_fds.items():
    print((num_partitions, alpha, sigma))
    try:
        metric_list, avg_metric = compute_kl_divergence(cifar_dir_fds.partitioners["train"])
    except Exception as exc:
        # Catch Exception (not a bare except) so that interrupting the kernel still
        # stops the loop, and report the actual error instead of hiding it.
        print(f"Sampling failed for {(num_partitions, alpha, sigma)}: {exc!r}")
        metric_list, avg_metric = np.nan, np.nan
    num_partitions_to_cifar_dir_metric_list[(num_partitions, alpha, sigma)] = metric_list
    num_partitions_to_cifar_dir_metric[(num_partitions, alpha, sigma)] = avg_metric