In [5]:
import numpy as np
import torch
import pickle

from metrics.phi_div import dkl, average_dkl
from metrics.wasserstein import wasserstein_2
from core.kernel import get_kernel
from core.mmd import mmd_neg_unbiased_batched
from metrics.class_imbalance import get_classes, class_proportion
from scipy import stats

In [2]:
# Experiment grid: the cells below iterate over every (dataset, split, greed) combination.
datasets = ['gmm', 'mnist', 'cifar']
splits = ['equaldisjoint', 'unequal']
greeds = [1, 2, 4, 8]

# Tag that appears in the names of the result files loaded below.
condition = 'stable'

# Names given, in order, to the entries of each loaded result tuple.
keys = [
    'party_datasets',
    'party_labels',
    'reference_dataset',
    'candidate_datasets',
    'candidate_labels',
    'rewards',
    'deltas',
    'mus',
    'alpha',
    'class_props',
    'dkl_before',
    'dkl_after',
    'wass_before',
    'wass_after',
]

num_parties = 5

## Load data

In [3]:
# Empty nested mapping: results_dict[dataset][split][greed] -> {} (one distinct dict per leaf).
results_dict = {
    ds: {
        split: {greed: {} for greed in greeds}
        for split in splits
    }
    for ds in datasets
}

In [4]:
# Manually recorded run ids for gmm and mnist, listed in the order of `greed_levels`.
greed_levels = (1, 2, 4, 8)
run_ids = {
    ('gmm', 'equaldisjoint'): (78, 77, 76, 75),
    ('gmm', 'unequal'): (71, 72, 73, 74),
    ('mnist', 'equaldisjoint'): (15, 44, 46, 48),
    ('mnist', 'unequal'): (13, 14, 45, 47),
}
for (ds_name, split_name), ids_for_split in run_ids.items():
    for greed_level, run_id in zip(greed_levels, ids_for_split):
        results_dict[ds_name][split_name][greed_level]['run'] = run_id

# Manually recorded alpha for cifar; the same values are used for every greed level of a split.
cifar_alpha = {
    'equaldisjoint': [0.3794023848214341, 1.0, 0.6826898867992747, 0.906607589647515, 0.11702653001907139],
    'unequal': [1.0, 0.9706108673457912, 0.2638529401880037, 0.4769904835896915, 0.17301571953644704],
}
for split_name, alpha_values in cifar_alpha.items():
    for greed_level in greed_levels:
        # Copy so that each entry owns its own list, exactly like separate literals would.
        results_dict['cifar'][split_name][greed_level]['alpha'] = list(alpha_values)

In [5]:
# Manually recorded kernel lengthscales: one value per (dataset, split), shared by all greed levels.
lengthscales = {
    'gmm': {'equaldisjoint': 0.2646284961700439, 'unequal': 0.06721988677978516},
    'mnist': {'equaldisjoint': 3.358317041397094, 'unequal': 1.5272806644439696},
    'cifar': {'equaldisjoint': 2.8843456459045416, 'unequal': 1.2373665714263915},
}
for ds_name, per_split in lengthscales.items():
    for split_name, lengthscale in per_split.items():
        for greed_level in (1, 2, 4, 8):
            results_dict[ds_name][split_name][greed_level]['lengthscale'] = lengthscale

In [6]:
# Load data
# NOTE(review): pickle.load can execute arbitrary code; only load result files you trust.
def _load_pickle(path):
    """Return the object unpickled from `path`, closing the file afterwards."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# gmm / mnist: the file name carries the run id recorded in results_dict.
for ds in ['gmm', 'mnist']:
    for split in splits:
        for greed in greeds:
            run_id = results_dict[ds][split][greed]['run']
            tup = _load_pickle("data/{}/cgm-results/CGM-{}-{}-greed{}-{}-run{}.p".format(ds,
                                                                                 ds,
                                                                                 split,
                                                                                 greed,
                                                                                 condition,
                                                                                 run_id))
            # Every name in `keys` is expected to have a matching tuple entry.
            for i in range(len(keys)):
                results_dict[ds][split][greed][keys[i]] = tup[i]

# cifar: the file name has no run id, and only as many keys as the tuple has entries are filled.
ds = 'cifar'
for split in splits:
    for greed in greeds:
        tup = _load_pickle("data/{}/cgm-results/CGM-{}-{}-greed{}-{}.p".format(ds,
                                                                         ds,
                                                                         split,
                                                                         greed,
                                                                         condition))
        for i in range(len(tup)):
            results_dict[ds][split][greed][keys[i]] = tup[i]

In [7]:
# Report, for every party, whether its `mus` sequence peaks at the last entry
# (printed as "No early stopping") or earlier. Nothing is truncated here: the cell only prints.
for ds in datasets:
    for split in splits:
        for greed in greeds:
            dic = results_dict[ds][split][greed]
            # Use the configured party count instead of a hard-coded 5.
            for party in range(num_parties):
                mus = dic['mus'][party]
                max_mu_idx = np.argmax(mus)
                if max_mu_idx == len(mus) - 1:
                    print("No early stopping for {}-{}-{} party {}".format(ds, split, greed, party+1))
                else:
                    print('{}-{}-{} party {}: max at {}, total length is {}'.format(ds, split, greed, party+1, max_mu_idx, len(mus)))

No early stopping for gmm-equaldisjoint-1 party 1
No early stopping for gmm-equaldisjoint-1 party 2
No early stopping for gmm-equaldisjoint-1 party 3
No early stopping for gmm-equaldisjoint-1 party 4
No early stopping for gmm-equaldisjoint-1 party 5
No early stopping for gmm-equaldisjoint-2 party 1
No early stopping for gmm-equaldisjoint-2 party 2
No early stopping for gmm-equaldisjoint-2 party 3
No early stopping for gmm-equaldisjoint-2 party 4
No early stopping for gmm-equaldisjoint-2 party 5
No early stopping for gmm-equaldisjoint-4 party 1
No early stopping for gmm-equaldisjoint-4 party 2
No early stopping for gmm-equaldisjoint-4 party 3
No early stopping for gmm-equaldisjoint-4 party 4
No early stopping for gmm-equaldisjoint-4 party 5
No early stopping for gmm-equaldisjoint-8 party 1
No early stopping for gmm-equaldisjoint-8 party 2
No early stopping for gmm-equaldisjoint-8 party 3
No early stopping for gmm-equaldisjoint-8 party 4
No early stopping for gmm-equaldisjoint-8 party 5


In [8]:
# Persist the assembled results; the context manager guarantees the file is flushed and closed.
with open('data/results_dict.p', 'wb') as results_file:
    pickle.dump(results_dict, results_file)

## Load metrics

In [9]:
# Names given, in order, to the entries of each pre-computed metrics tuple.
metrics_keys = [
    'class_props',
    'wass_before',
    'wass_after',
    'dkls_before',
    'dkls_after',
]

### Wass-2 and average DKL

In [10]:
# Load pre-computed Wass-2 and average DKL
# NOTE(review): pickle.load can execute arbitrary code; only load metric files you trust.
for ds in datasets:
    for split in splits:
        for greed in greeds:
            # Context manager closes each metrics file once it has been read.
            with open("data/metrics-{}-{}-{}.p".format(ds, split, greed), "rb") as metrics_file:
                tup = pickle.load(metrics_file)
            for i in range(len(metrics_keys)):
                results_dict[ds][split][greed][metrics_keys[i]] = tup[i]

In [11]:
# Count how many reward points each party received in every configuration.
for ds in datasets:
    for split in splits:
        for greed in greeds:
            entry = results_dict[ds][split][greed]
            party_rewards = entry['rewards']
            counts = []
            for idx in range(len(party_rewards)):
                counts.append(len(party_rewards[idx]))
            entry['num_rewards'] = counts

In [12]:
# Print, per configuration, the correlation of alpha with each post-reward metric and
# the per-party change (before minus after) in Wass-2 and average DKL.
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        for greed in greeds:
            print("Greed: {}".format(greed))
            entry = results_dict[ds][split][greed]
            alpha = entry['alpha']
            wass_after = entry['wass_after']
            # Named dkl_after (not dkl) so the `dkl` function imported from metrics.phi_div is not overwritten.
            dkl_after = entry['dkls_after']
            wass_before = entry['wass_before']
            dkl_before = entry['dkls_before']
            num_rewards = entry['num_rewards']

            print("Correlation between alpha and Wass-2: {}".format(np.corrcoef(alpha, wass_after)[0, 1]))
            print("Correlation between alpha and average DKL: {}".format(np.corrcoef(alpha, dkl_after)[0, 1]))
            print("Correlation between alpha and number of rewards: {}".format(np.corrcoef(alpha, num_rewards)[0, 1]))
            print("Wass-2 before minus Wass-2 after: {}".format(np.array(wass_before) - np.array(wass_after)))
            print("Average DKL before minus average DKL after: {}".format(np.array(dkl_before) - np.array(dkl_after)))
            print("--")
        print("========")
    print("################################")

Dataset: gmm
Split: equaldisjoint
Greed: 1
Correlation between alpha and Wass-2: -0.9846220273672672
Correlation between alpha and average DKL: -0.9235107096967035
Correlation between alpha and number of rewards: 0.7340356609029092
Wass-2 before minus Wass-2 after: [0.09625895 0.04958604 0.02072756 0.10384659 0.03020694]
Average DKL before minus average DKL after: [1.16293447 1.02392497 0.8122535  1.20304905 0.99337203]
--
Greed: 2
Correlation between alpha and Wass-2: -0.9866509579811745
Correlation between alpha and average DKL: -0.8763971437955874
Correlation between alpha and number of rewards: 0.7117470908381857
Wass-2 before minus Wass-2 after: [0.09652848 0.04949487 0.02072677 0.10415698 0.03011462]
Average DKL before minus average DKL after: [1.15053109 1.02395831 0.80836672 1.19521533 0.96964937]
--
Greed: 4
Correlation between alpha and Wass-2: -0.9857439191541563
Correlation between alpha and average DKL: -0.8558120769819207
Correlation between alpha and number of rewards: 0

In [13]:
# Print the raw alpha, reward counts and post-reward Wass-2 / average DKL per configuration.
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        for greed in greeds:
            print("Greed: {}".format(greed))
            entry = results_dict[ds][split][greed]
            # Only the values that are printed are read; the post-reward DKL is kept under
            # dkl_after so the imported `dkl` function is not overwritten.
            alpha = entry['alpha']
            num_rewards = entry['num_rewards']
            wass_after = entry['wass_after']
            dkl_after = entry['dkls_after']

            print("Alpha: {}".format(alpha))
            print("Num rewards: {}".format(num_rewards))
            print("Wass: {}".format(wass_after))
            print("DKL: {}".format(dkl_after))

        print("========")
    print("################################")

Dataset: gmm
Split: equaldisjoint
Greed: 1
Alpha: [0.4166184216522995, 0.7104218201257749, 1.0, 0.32040511909698494, 0.8806314309296782]
Num rewards: [1510, 1854, 6123, 1437, 2071]
Wass: [0.008951572897354692, 0.002763305334298526, 1.4735212896280413e-06, 0.011440534783063071, 0.0006241270603567486]
DKL: [0.14310273 0.08884614 0.00223889 0.17127795 0.09502109]
Greed: 2
Alpha: [0.4166184216522995, 0.7104218201257749, 1.0, 0.32040511909698494, 0.8806314309296782]
Num rewards: [1239, 1451, 4412, 1253, 1560]
Wass: [0.008682041061354645, 0.002854480999022661, 2.263883461769906e-06, 0.011130140182510547, 0.000716445923395083]
DKL: [0.15550611 0.0888128  0.00612566 0.17911167 0.11874374]
Greed: 4
Alpha: [0.4166184216522995, 0.7104218201257749, 1.0, 0.32040511909698494, 0.8806314309296782]
Num rewards: [1080, 1189, 3139, 1083, 1189]
Wass: [0.008539384610517065, 0.0031040448751797596, 5.084407974095273e-06, 0.011490243297486855, 0.0005604233125868868]
DKL: [0.22204413 0.14730658 0.00976006 0.21

### MMD unbiased

In [14]:
# Compute, for every configuration and party, the unbiased MMD to the reference dataset
# before and after the party's rewards are appended to its own data.
# Fall back to the CPU when CUDA is unavailable so the cell also runs on CPU-only machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
for ds in datasets:
    # Dimension passed to the kernel constructor.
    # NOTE(review): 8 is presumably the feature dimension of the mnist/cifar data - confirm.
    if ds == 'gmm':
        d = 2
    else:
        d = 8
    for split in splits:
        for greed in greeds:
            entry = results_dict[ds][split][greed]
            rewards = entry['rewards']
            reference_dataset = entry['reference_dataset']
            party_datasets = entry['party_datasets']
            reference_dataset_tens = torch.tensor(reference_dataset, device=device, dtype=torch.float32)
            ls = entry['lengthscale']
            k = get_kernel('se', d, ls, device)
            mmd_unbiased_before = []
            mmd_unbiased_after = []
            for i in range(num_parties):
                party_dataset_tens = torch.tensor(party_datasets[i], device=device, dtype=torch.float32)
                party_dataset_with_rewards_tens = torch.tensor(np.concatenate([party_datasets[i], rewards[i]], axis=0),
                                                               device=device,
                                                               dtype=torch.float32)
                # The helper's result is negated before being stored.
                mmd_unbiased_before.append(-mmd_neg_unbiased_batched(party_dataset_tens,
                                                                     reference_dataset_tens,
                                                                     k))
                mmd_unbiased_after.append(-mmd_neg_unbiased_batched(party_dataset_with_rewards_tens,
                                                                    reference_dataset_tens,
                                                                    k))
            entry['mmd_unbiased_before'] = mmd_unbiased_before
            entry['mmd_unbiased_after'] = mmd_unbiased_after

In [15]:
# Convert the stored MMD tensors to plain Python numbers.
def _as_scalar(value):
    """Return `value.item()` for a tensor; pass anything else through unchanged."""
    return value.item() if torch.is_tensor(value) else value


# Values that are already plain numbers are left alone, so re-running this cell is harmless.
for ds in datasets:
    for split in splits:
        for greed in greeds:
            entry = results_dict[ds][split][greed]
            entry['mmd_unbiased_before'] = [_as_scalar(value) for value in entry['mmd_unbiased_before']]
            entry['mmd_unbiased_after'] = [_as_scalar(value) for value in entry['mmd_unbiased_after']]

In [16]:
# Per configuration: correlation of alpha with the post-reward unbiased MMD, and the
# per-party MMD change (before minus after).
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        for greed in greeds:
            print("Greed: {}".format(greed))
            entry = results_dict[ds][split][greed]
            alpha = entry['alpha']
            mmd_unbiased_before = entry['mmd_unbiased_before']
            mmd_unbiased_after = entry['mmd_unbiased_after']

            alpha_mmd_corr = np.corrcoef(alpha, mmd_unbiased_after)[0, 1]
            print("Correlation between alpha and unbiased MMD: {}".format(alpha_mmd_corr))
            mmd_change = np.array(mmd_unbiased_before) - np.array(mmd_unbiased_after)
            print("unbiased MMD before minus unbiased MMD after: {}".format(mmd_change))
            print("--")
        print("========")
    print("################################")

Dataset: gmm
Split: equaldisjoint
Greed: 1
Correlation between alpha and unbiased MMD: -0.9930324723223307
unbiased MMD before minus unbiased MMD after: [0.42412418 0.26664156 0.09190416 0.47962302 0.16090262]
--
Greed: 2
Correlation between alpha and unbiased MMD: -0.9930251543431068
unbiased MMD before minus unbiased MMD after: [0.42412496 0.26669693 0.09192497 0.47967356 0.16094679]
--
Greed: 4
Correlation between alpha and unbiased MMD: -0.9931351225630042
unbiased MMD before minus unbiased MMD after: [0.42418116 0.26669127 0.09195203 0.47977161 0.16095859]
--
Greed: 8
Correlation between alpha and unbiased MMD: -0.9930004354368872
unbiased MMD before minus unbiased MMD after: [0.42417091 0.26671934 0.09194672 0.47969413 0.16098809]
--
Split: unequal
Greed: 1
Correlation between alpha and unbiased MMD: -0.9357005278399146
unbiased MMD before minus unbiased MMD after: [-1.65328383e-05 -2.14450061e-04  8.80095959e-02  4.05100137e-02
  7.44933337e-02]
--
Greed: 2
Correlation between a

In [17]:
# Print the post-reward unbiased MMD of every party for each configuration.
# Only the printed value is read; the unused lookups (one of which rebound the imported
# `dkl` function) were dropped.
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        for greed in greeds:
            print("Greed: {}".format(greed))
            mmd_unbiased_after = results_dict[ds][split][greed]['mmd_unbiased_after']

            print("MMD unbiased: {}".format(mmd_unbiased_after))
            print("--")
        print("========")
    print("################################")

Dataset: gmm
Split: equaldisjoint
Greed: 1
MMD unbiased: [0.04220765829086304, 0.01683443784713745, -5.8650970458984375e-05, 0.054311156272888184, 0.006238877773284912]
--
Greed: 2
MMD unbiased: [0.04220688343048096, 0.016779065132141113, -7.94529914855957e-05, 0.05426061153411865, 0.006194710731506348]
--
Greed: 4
MMD unbiased: [0.042150676250457764, 0.016784727573394775, -0.00010651350021362305, 0.054162561893463135, 0.00618290901184082]
--
Greed: 8
MMD unbiased: [0.04216092824935913, 0.016756653785705566, -0.00010120868682861328, 0.05424004793167114, 0.006153404712677002]
--
Split: unequal
Greed: 1
MMD unbiased: [-6.861239671707153e-05, -3.2708048820495605e-05, 0.0007284432649612427, 1.2218952178955078e-05, 0.0005043745040893555]
--
Greed: 2
MMD unbiased: [-0.00012006610631942749, -5.5439770221710205e-05, 0.0006596073508262634, -6.047636270523071e-05, 0.0004238709807395935]
--
Greed: 4
MMD unbiased: [-0.00019358843564987183, -8.587539196014404e-05, 0.000611230731010437, -0.000143587

In [20]:
#pickle.dump(results_dict, open('data/results_dict.p', 'wb'))

In [3]:
#results_dict = pickle.load(open('data/results_dict.p', 'rb'))

In [6]:
# For each dataset and split: correlate the greed values with each party's reward count and
# post-reward unbiased MMD, then report the mean and standard error over parties.
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        entries_by_greed = [results_dict[ds][split][greed] for greed in greeds]
        corrs_num_rewards = []
        corrs_mmd = []
        for party in range(num_parties):
            # One value per greed level for this party.
            reward_counts = [entry['num_rewards'][party] for entry in entries_by_greed]
            mmd_values = [entry['mmd_unbiased_after'][party] for entry in entries_by_greed]
            corrs_num_rewards.append(np.corrcoef(greeds, reward_counts)[0, 1])
            corrs_mmd.append(np.corrcoef(greeds, mmd_values)[0, 1])
        mean_rewards_corr = np.mean(corrs_num_rewards)
        sem_rewards_corr = stats.sem(corrs_num_rewards)
        print("Correlation between greeds and num_rewards mean and std err: {}, {}".format(mean_rewards_corr,
                                                                                           sem_rewards_corr))
        mean_mmd_corr = np.mean(corrs_mmd)
        sem_mmd_corr = stats.sem(corrs_mmd)
        print("Correlation between greeds and mmd_unbiased mean and std err: {}, {}".format(mean_mmd_corr,
                                                                                            sem_mmd_corr))
        print("--")
    print("========")
print("################################")

Dataset: gmm
Split: equaldisjoint
Correlation between greeds and num_rewards mean and std err: -0.8505166460124054, 0.023281930215888723
Correlation between greeds and mmd_unbiased mean and std err: -0.7366033487399511, 0.07942287359198327
--
Split: unequal
Correlation between greeds and num_rewards mean and std err: -0.8348782068325195, 0.006490503788429742
Correlation between greeds and mmd_unbiased mean and std err: -0.9280641402120571, 0.027417744481543358
--
Dataset: mnist
Split: equaldisjoint
Correlation between greeds and num_rewards mean and std err: -0.86932429690349, 0.005061197203443362
Correlation between greeds and mmd_unbiased mean and std err: -0.6785470743652321, 0.10008327852506307
--
Split: unequal
Correlation between greeds and num_rewards mean and std err: -0.8372149050592732, 0.01638469337773892
Correlation between greeds and mmd_unbiased mean and std err: -0.867841118835855, 0.018335382613781516
--
Dataset: cifar
Split: equaldisjoint
Correlation between greeds and

### Class imbalance

In [99]:
### Keep at most the first 5000 labels of every party, then print the resulting lengths
for ds in datasets:
    for split in splits:
        for greed in greeds:
            entry = results_dict[ds][split][greed]
            trimmed_labels = []
            for labels in entry['party_labels']:
                trimmed_labels.append(labels[:5000])
            entry['party_labels'] = trimmed_labels
            print([len(labels) for labels in entry['party_labels']])

[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[1000, 1000, 1000, 1000, 1000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]
[5000, 5000, 5000, 5000, 5000]


In [112]:
# Dataset whose (equaldisjoint, greed 1) run is inspected in the following cells.
ds = 'cifar'

In [113]:
# Pool the data and labels of all parties for the (equaldisjoint, greed 1) run of `ds`.
selected_run = results_dict[ds]['equaldisjoint'][1]
party_datasets = selected_run['party_datasets']
party_labels = selected_run['party_labels']
all_dataset = np.concatenate(party_datasets)
all_labels = np.concatenate(party_labels)

In [114]:
# Class proportions of each party's initial dataset, with classes looked up in the pooled data.
# The loop variable is not called `ds`, so the dataset name set earlier is not overwritten.
# NOTE(review): num_classes=5 is used here although `ds` is 'cifar'; the class-imbalance cell
# below uses 10 classes for non-gmm datasets - confirm which value is intended.
init_props = []
for party_dataset in party_datasets:
    init_props.append(class_proportion(get_classes(np.array(party_dataset), all_dataset, all_labels), num_classes=5))

In [117]:
# For every configuration and party, take element [1] of class_proportion(...) evaluated on
# the party data plus its rewards, and store the per-party list under 'imba_after'.
for ds in datasets:
    num_classes = 5 if ds == 'gmm' else 10
    for split in splits:
        for greed in greeds:
            entry = results_dict[ds][split][greed]
            party_datasets = entry['party_datasets']
            party_labels = entry['party_labels']
            rewards = entry['rewards']
            candidate_dataset = entry['candidate_datasets'][0]
            candidate_labels = entry['candidate_labels']

            imba_after = []
            for i in range(num_parties):
                party_ds_with_rewards = np.concatenate([party_datasets[i], rewards[i]], axis=0)
                # Points and labels handed to get_classes: party data followed by candidates.
                all_dataset = np.concatenate([party_datasets[i], candidate_dataset], axis=0)
                all_labels = np.concatenate([party_labels[i], candidate_labels], axis=0)
                classes = get_classes(party_ds_with_rewards, all_dataset, all_labels)
                proportions = class_proportion(classes, num_classes)
                imba_after.append(proportions[1])

            entry['imba_after'] = imba_after

In [120]:
# Per configuration: correlation between alpha and the stored 'imba_after' values.
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        for greed in greeds:
            print("Greed: {}".format(greed))
            entry = results_dict[ds][split][greed]
            alpha = entry['alpha']
            imba_after = entry['imba_after']

            alpha_imba_corr = np.corrcoef(alpha, imba_after)[0, 1]
            print("Correlation between alpha and class imbalance: {}".format(alpha_imba_corr))
            print("--")
        print("========")
    print("################################")

Dataset: gmm
Split: equaldisjoint
Greed: 1
Correlation between alpha and class imbalance: -0.9055507118345965
--
Greed: 2
Correlation between alpha and class imbalance: -0.7970402614929165
--
Greed: 4
Correlation between alpha and class imbalance: -0.6600381822338242
--
Greed: 8
Correlation between alpha and class imbalance: -0.5816821336536531
--
Split: unequal
Greed: 1
Correlation between alpha and class imbalance: -0.7868825598065387
--
Greed: 2
Correlation between alpha and class imbalance: -0.7095696869692971
--
Greed: 4
Correlation between alpha and class imbalance: -0.601985721293666
--
Greed: 8
Correlation between alpha and class imbalance: -0.5065659762747298
--
################################
Dataset: mnist
Split: equaldisjoint
Greed: 1
Correlation between alpha and class imbalance: -0.9852283067811822
--
Greed: 2
Correlation between alpha and class imbalance: -0.9744813304356874
--
Greed: 4
Correlation between alpha and class imbalance: -0.9601079817399606
--
Greed: 8
Corre

### Get mean and standard error of correlations across greeds

In [128]:
# For each dataset and split, collect the correlation between alpha and every metric over
# the greed levels, store the lists on the split-level dict, and print mean +/- standard error.
for ds in datasets:
    print("Dataset: {}".format(ds))
    for split in splits:
        print("Split: {}".format(split))
        all_wass = []
        all_dkl = []
        all_num_rewards = []
        all_mmd_u = []
        all_imba = []
        for greed in greeds:
            entry = results_dict[ds][split][greed]
            alpha = entry['alpha']
            wass_after = entry['wass_after']
            # Named dkl_after (not dkl) so the `dkl` function imported from metrics.phi_div is not overwritten.
            dkl_after = entry['dkls_after']
            num_rewards = entry['num_rewards']
            mmd_u = entry['mmd_unbiased_after']
            imba = entry['imba_after']

            all_wass.append(np.corrcoef(alpha, wass_after)[0, 1])
            all_dkl.append(np.corrcoef(alpha, dkl_after)[0, 1])
            all_num_rewards.append(np.corrcoef(alpha, num_rewards)[0, 1])
            all_mmd_u.append(np.corrcoef(alpha, mmd_u)[0, 1])
            all_imba.append(np.corrcoef(alpha, imba)[0, 1])

        # These lists sit next to the integer greed keys in the split-level dict.
        results_dict[ds][split]['all_wass'] = all_wass
        results_dict[ds][split]['all_dkl'] = all_dkl
        results_dict[ds][split]['all_num_rewards'] = all_num_rewards
        results_dict[ds][split]['all_mmd_u'] = all_mmd_u
        results_dict[ds][split]['all_imba'] = all_imba

        print("MMD: {} +/- {} \n DKL: {} +/- {} \n Wass: {} +/- {} \n Class imbalance: {} +/- {} \n num rewards: {} +/- {} |".format(
        np.mean(all_mmd_u),
        stats.sem(all_mmd_u),
        np.mean(all_dkl),
        stats.sem(all_dkl),
        np.mean(all_wass),
        stats.sem(all_wass),
        np.mean(all_imba),
        stats.sem(all_imba),
        np.mean(all_num_rewards),
        stats.sem(all_num_rewards)
        ))

        print("========")
    print("################################")

Dataset: gmm
Split: equaldisjoint
MMD: -0.9930482961663323 +/- 2.9742520913144933e-05 
 DKL: -0.8366315972769915 +/- 0.050631133971426094 
 Wass: -0.9859351155331861 +/- 0.0004911574711359458 
 Class imbalance: -0.7360778223037476 +/- 0.07191287252266271 
 num rewards: 0.693774366740024 +/- 0.01887535052671798 |
Split: unequal
MMD: -0.918162171275308 +/- 0.007324203558815758 
 DKL: -0.5116205140004344 +/- 0.2132378159023349 
 Wass: -0.8116895724506624 +/- 0.021350500857664236 
 Class imbalance: -0.6512509860860579 +/- 0.0613445277658431 
 num rewards: 0.7192342221552249 +/- 0.018753857210088758 |
################################
Dataset: mnist
Split: equaldisjoint
MMD: -0.9996314027847917 +/- 1.6053033197138224e-06 
 DKL: 0.30500713025134274 +/- 0.12534292123855745 
 Wass: -0.9632966325765053 +/- 0.00049757114224472 
 Class imbalance: -0.9672968513814587 +/- 0.007885711400719825 
 num rewards: 0.7615549477870012 +/- 0.00883237032004705 |
Split: unequal
MMD: -0.9916158114544429 +/- 0.00