In [1]:
import sys
import os

# Put the notebook's parent directory on the import path so the project
# packages (`algorithms`, `utils`) can be imported from this folder.
module_path = os.path.abspath(os.path.join(os.curdir, os.pardir))
if module_path not in sys.path:  # re-running the cell must not add duplicates
    sys.path.append(module_path)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import gpytorch
from tqdm.notebook import trange
import heapq
import math
import pickle
import itertools
from algorithms.cd import con_div
from algorithms.ccr import con_conv_rate
from utils.class_imbalance import get_classes, class_proportion

from algorithms.cgm import *

## Dataset

In [3]:
def sample_GMM(means, covs, num_samples):
    """
    Draw ``num_samples`` points from a Gaussian mixture with uniform component weights.

    Each point independently picks a component uniformly at random and is then
    drawn from that component's multivariate normal. Components are equally
    likely, so per-component counts are equal only in expectation, not exactly
    ``num_samples / n``.

    Parameters
    ----------
    means : np.ndarray, shape (n, d)
        Component means.
    covs : np.ndarray, shape (n, d, d)
        Component covariance matrices.
    num_samples : int
        Number of points to draw.

    Returns
    -------
    samples : np.ndarray, shape (num_samples, d), float64
        The drawn points.
    clusters : np.ndarray, shape (num_samples,), int32
        Index of the component each point was drawn from.

    Raises
    ------
    AssertionError
        If ``means`` / ``covs`` have inconsistent ranks or shapes.
    ValueError
        Raised by NumPy (``check_valid='raise'``) if a covariance matrix is not
        positive semi-definite.
    """
    assert means.ndim == 2, "means must have shape (n, d)"
    assert covs.ndim == 3, "covs must have shape (n, d, d)"
    assert means.shape[0] == covs.shape[0], "means and covs disagree on the number of components"
    assert means.shape[1] == covs.shape[1], "means and covs disagree on the dimension"
    assert covs.shape[1] == covs.shape[2], "each covariance matrix must be square"

    n, d = means.shape
    samples = np.zeros((num_samples, d))
    clusters = np.zeros(num_samples, dtype=np.int32)

    # Deliberately one point at a time from the global NumPy RNG (component
    # first, then the point): a vectorised version would consume random numbers
    # in a different order and change results obtained under np.random.seed(...).
    for i in range(num_samples):
        cluster = np.random.randint(n)
        samples[i] = np.random.multivariate_normal(means[cluster], covs[cluster], check_valid='raise')
        clusters[i] = cluster

    return samples, clusters

In [4]:
# Synthetic-data configuration: number of Gaussian clusters, their
# dimensionality, and the number of points drawn per cluster for each of the
# train and test sets.
num_clusters, d, num_samples = 5, 2, 1000

In [5]:
# Seed NumPy's global RNG; every later np.random call (cluster means, data
# draws, candidate pools) depends on this.
np.random.seed(seed=2)

In [6]:
# Cluster centres drawn uniformly from [0, 1)^d; every cluster shares the same
# isotropic covariance I / 200 (standard deviation ~0.07 per axis).
means = np.random.uniform(size=(num_clusters, d))
covs = np.tile(np.eye(d) / 200, (num_clusters, 1, 1))

In [7]:
# Pre-allocate per-cluster train and test sets, each of shape
# (num_clusters, num_samples, d).
train_sets, test_sets = (np.zeros((num_clusters, num_samples, d)) for _ in range(2))

In [8]:
# Fill cluster k's train and test sets with independent draws from N(means[k], covs[k]).
# Train and test are drawn alternately per cluster, so the RNG order is fixed.
for k, (mu, sigma) in enumerate(zip(means, covs)):
    train_sets[k] = np.random.multivariate_normal(mu, sigma, size=num_samples, check_valid='raise')
    test_sets[k] = np.random.multivariate_normal(mu, sigma, size=num_samples, check_valid='raise')

In [9]:
# Disabled diagnostic: scatter-plot each cluster's training set in its own Set1
# colour. Uncomment to inspect the synthetic data.
# NOTE(review): `cm.get_cmap` is deprecated (Matplotlib 3.7) and removed in 3.9 —
# use `plt.get_cmap('Set1')` if re-enabled; `plt.legend()` also belongs after the loop.
# plt.figure(figsize=(10, 6), dpi=300)
# for i in range(num_clusters):
#     plt.scatter(train_sets[i, :, 0], train_sets[i, :, 1], s=2, color=cm.get_cmap('Set1')(i*(1/9)), label="{0}".format(i))
#     plt.legend()

## Unequal split

In [10]:
# Number of data-contributing parties in the collaboration.
num_parties = 5

In [11]:
# Mixing proportions for the "unequal" split, passed to split_proportions.
# The first two rows are uniform; the last three concentrate on a few columns.
# Every row and every column sums to 1.
# NOTE(review): row = party, column = cluster is assumed — the matrix is square,
# so the orientation cannot be told from the shape; confirm against split_proportions.
unequal_prop = np.array([
    [0.2, 0.2, 0.2, 0.2, 0.2],
    [0.2, 0.2, 0.2, 0.2, 0.2],
    [0.6, 0.4, 0.0, 0.0, 0.0],
    [0.0, 0.2, 0.6, 0.2, 0.0],
    [0.0, 0.0, 0.0, 0.4, 0.6],
])

In [12]:
# Distribute the per-cluster training data among the parties according to
# unequal_prop (split_proportions presumably comes from the algorithms.cgm
# star-import; its exact sampling rule is not visible here).
party_datasets = split_proportions(train_sets, unequal_prop)

In [13]:
# Disabled diagnostic: scatter-plot party data in a fixed window.
# NOTE(review): `if i == 0` restricts the plot to party 0 only (looks like
# leftover debugging), and `party_datasets[i, :, 0]` assumes party_datasets is a
# 3-D array — confirm what split_proportions returns. `cm.get_cmap` is removed
# in Matplotlib 3.9; use `plt.get_cmap('Set1')` if this is re-enabled.
# # Check
# plt.figure(figsize=(10, 6), dpi=300)
# plt.xlim(0, 0.8)
# plt.ylim(-0.2, 1.0)
# for i in range(num_parties):
#     if i == 0:
#         plt.scatter(party_datasets[i, :, 0], party_datasets[i, :, 1], s=2, color=cm.get_cmap('Set1')(i*(1/9)), label="{0}".format(i))

# plt.legend()

In [14]:
# Scaled RBF kernel with one lengthscale per input dimension (ARD), used for
# every MMD computation in this notebook.
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
# One 0.05 lengthscale per dimension. Sized from `d` so the kernel stays
# consistent with ard_num_dims if the data dimensionality changes.
kernel.base_kernel.lengthscale = [0.05] * d
kernel.outputscale = 1

In [15]:
# Pool every party's data. The pooled set serves both as the dataset for
# permutation sampling and as the reference distribution for the MMD-based
# valuation; two separate arrays are kept so neither aliases the other.
perm_samp_dataset = np.concatenate(party_datasets)
reference_dataset = perm_samp_dataset.copy()

In [16]:
# Coalition value function: a dict keyed by coalition strings of 1-indexed
# parties such as '{1, 2}' (see output below), computed by get_v from the
# parties' data, the pooled reference set and the kernel.
v = get_v(party_datasets, reference_dataset, kernel)

In [17]:
# Display the coalition values v(S) (the saved output below is truncated).
v

{'{1}': 0.05570573732256889,
 '{2}': 0.05577566847205162,
 '{3}': -0.011796899139881134,
 '{4}': 0.022628188133239746,
 '{5}': 0.004606999456882477,
 '{1, 2}': 0.056074611842632294,
 '{1, 3}': 0.03952357918024063,
 '{1, 4}': 0.047915369272232056,
 '{1, 5}': 0.04281418025493622,
 '{2, 3}': 0.03890645503997803,
 '{2, 4}': 0.04757314175367355,
 '{2, 5}': 0.0437910333275795,
 '{3, 4}': 0.04332202672958374,
 '{3, 5}': 0.047700025141239166,
 '{4, 5}': 0.03924782574176788,
 '{1, 2, 3}': 0.04870377853512764,
 '{1, 2, 4}': 0.05246029049158096,
 '{1, 2, 5}': 0.050514526665210724,
 '{1, 3, 4}': 0.05072297900915146,
 '{1, 3, 5}': 0.052403904497623444,
 '{1, 4, 5}': 0.04855206608772278,
 '{2, 3, 4}': 0.05028881877660751,
 '{2, 3, 5}': 0.05255601555109024,
 '{2, 4, 5}': 0.04882633686065674,
 '{3, 4, 5}': 0.05618234723806381,
 '{1, 2, 3, 4}': 0.053039707243442535,
 '{1, 2, 3, 5}': 0.05416601523756981,
 '{1, 2, 4, 5}': 0.05201444774866104,
 '{1, 3, 4, 5}': 0.056237753480672836,
 '{2, 3, 4, 5}': 0.0562

In [18]:
# Per-party contribution values computed from v (presumably Shapley values,
# one float per party, per the printed output).
phi = shapley(v, num_parties)
print(phi)

[0.01871132633338372, 0.018728525067369144, 0.0021091196686029434, 0.01057593896985054, 0.006143640416363875]


In [19]:
# Rescale phi so the largest entry is 1 (the printed output equals phi / max(phi)).
# NOTE(review): `norm` comes from the cgm star-import, not numpy/scipy.
alpha = norm(phi)
print(alpha)

[0.9990816824109983, 1.0, 0.1126153640511542, 0.5646968424799815, 0.3280365322023136]


In [20]:
# Value of the coalition of all num_parties parties (presumably v({1, ..., n});
# that entry is cut off in the displayed v above).
vN = get_vN(v, num_parties)
print(vN)

0.05626855045557022


In [21]:
# Standalone value of each party: the printed list matches the singleton
# entries v({1}) ... v({5}) of the coalition values.
v_is = get_v_is(v, num_parties)
print(v_is)

[0.05570573732256889, 0.05577566847205162, -0.011796899139881134, 0.022628188133239746, 0.004606999456882477]


## Max condition

In [22]:
# Bisection search (see log) for eta under the "max" condition, with the
# upper bound capped at 0.5. The returned q is applied to each alpha below to
# produce per-party reward targets.
# NOTE(review): num_iters=5 is passed, yet the saved log shows iterations 0-9 —
# either the output is stale or num_iters means something else in get_eta_q; confirm.
best_eta, q = get_eta_q(vN, alpha, v_is, perm_samp_dataset, reference_dataset, kernel, high=0.5, num_iters=5, mode="max")

HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


Iteration 0
current_high=0.5, current_low=0.001
Evaluating for eta = 0.2505


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition not satisfied, setting current_low to 0.2505
Iteration 1
current_high=0.5, current_low=0.2505
Evaluating for eta = 0.37525


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition satisfied, setting current_high to 0.37525
Iteration 2
current_high=0.37525, current_low=0.2505
Evaluating for eta = 0.312875


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition not satisfied, setting current_low to 0.312875
Iteration 3
current_high=0.37525, current_low=0.312875
Evaluating for eta = 0.3440625


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition satisfied, setting current_high to 0.3440625
Iteration 4
current_high=0.3440625, current_low=0.312875
Evaluating for eta = 0.32846875


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition not satisfied, setting current_low to 0.32846875
Iteration 5
current_high=0.3440625, current_low=0.32846875
Evaluating for eta = 0.336265625


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition satisfied, setting current_high to 0.336265625
Iteration 6
current_high=0.336265625, current_low=0.32846875
Evaluating for eta = 0.3323671875


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition not satisfied, setting current_low to 0.3323671875
Iteration 7
current_high=0.336265625, current_low=0.3323671875
Evaluating for eta = 0.33431640625000003


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition not satisfied, setting current_low to 0.33431640625000003
Iteration 8
current_high=0.336265625, current_low=0.33431640625000003
Evaluating for eta = 0.335291015625


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition satisfied, setting current_high to 0.335291015625
Iteration 9
current_high=0.335291015625, current_low=0.33431640625000003
Evaluating for eta = 0.33480371093750005


HBox(children=(HTML(value='Permutation sampling'), FloatProgress(value=0.0, max=200.0), HTML(value='')))


max condition not satisfied, setting current_low to 0.33480371093750005


In [28]:
# eta chosen by the max-condition search: the last value the get_eta_q log
# reported as satisfying the condition.
best_eta

0.335291015625

In [24]:
# Standalone party values, for comparison with the reward targets r.
v_is

[0.05570573732256889,
 0.05577566847205162,
 -0.011796899139881134,
 0.022628188133239746,
 0.004606999456882477]

In [25]:
# Max condition: reward target for each party, r_i = q(alpha_i), using the q
# returned by get_eta_q(..., mode="max").
r = [q(alpha_i) for alpha_i in alpha]
print(r)

[0.056095633655786514, 0.05626855045557022, 0.055810458958148956, 0.05593058466911316, 0.055879123508930206]


In [26]:
num_candidate_points = 5000
# num_clusters independent GMM candidate pools. Only pool 0 is used as the
# candidate set, but all pools are kept (they are pickled with the results).
# Drawing every pool also keeps the global RNG stream unchanged.
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([points for points, _ in gmm_clusters])
clusters = np.array([labels for _, labels in gmm_clusters])
# Every party draws its reward from the same candidate pool, gmm[0].
cand_datasets = np.stack([gmm[0]] * num_parties)

In [27]:
# Same greed factor (2.0) for every party, passed to reward_realization.
greeds = np.full(num_parties, 2.0)

In [29]:
# Realize each party's reward from its candidate set: the log shows one
# weighted-sampling run per party, targeting -MMD^2 equal to that party's r_i.
# NOTE(review): the meaning of deltas and mus is not visible here — see reward_realization.
rewards, deltas, mus = reward_realization(cand_datasets, 
                                          reference_dataset, 
                                          r, 
                                          party_datasets, 
                                          kernel, 
                                          greeds=greeds)

Running weighted sampling algorithm with -MMD^2 target 0.056095633655786514Running weighted sampling algorithm with -MMD^2 target 0.05626855045557022Running weighted sampling algorithm with -MMD^2 target 0.055810458958148956


Running weighted sampling algorithm with -MMD^2 target 0.05593058466911316Running weighted sampling algorithm with -MMD^2 target 0.055879123508930206



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))








In [32]:
# Persist the max-condition run. The context manager closes (and flushes) the
# file as soon as the dump finishes, and the results directory is created if missing.
os.makedirs("results", exist_ok=True)
with open("results/CGM-GMM-unequal-greed2-max.p", "wb") as results_file:
    pickle.dump((gmm, clusters, reference_dataset, cand_datasets, party_datasets, greeds, rewards, deltas, mus), results_file)

In [33]:
# Cluster composition of each party's reward. get_classes presumably maps each
# reward point to its label via the candidate pool (gmm[0], clusters[0]); each
# entry is what class_proportion returns (a (proportions, scalar) tuple in the run above).
class_props = [
    class_proportion(get_classes(np.array(result), gmm[0], clusters[0]), num_clusters)
    for result in rewards
]

In [34]:
# Per-party cluster proportions of the realized rewards (max condition).
class_props

[(array([0.21576763, 0.2033195 , 0.19087137, 0.21576763, 0.17427386]),
  0.04025068438904289),
 (array([0.19641577, 0.19856631, 0.19641577, 0.2       , 0.20860215]),
  0.04002034917331484),
 (array([0.12568306, 0.17516697, 0.22222222, 0.23649059, 0.24043716]),
  0.0419200481897686),
 (array([0.27412281, 0.16447368, 0.12664474, 0.20833333, 0.22642544]),
  0.042581010503231764),
 (array([0.26326227, 0.26722856, 0.22607833, 0.13138324, 0.1120476 ]),
  0.04432915164089867)]

In [30]:
# Sanity check: -MMD^2 between (party data + its reward) and the reference set,
# to compare against the party's target in r.
for party_idx in range(num_parties):
    augmented = np.concatenate([party_datasets[party_idx], np.array(rewards[party_idx])], axis=0)
    print(mmd_neg_biased(augmented, reference_dataset, kernel)[0])

0.056041162461042404
0.05621238797903061
0.055756837129592896
0.05587471276521683
0.0558234266936779


In [31]:
# Reward targets, for comparison with the achieved values printed above.
r

[0.056095633655786514,
 0.05626855045557022,
 0.055810458958148956,
 0.05593058466911316,
 0.055879123508930206]

## All condition

In [None]:
# Repeat the eta search under the "all" condition, using get_eta_q's default
# search range and iteration count (the max-condition run passed high=0.5, num_iters=5).
# NOTE: this overwrites best_eta and q from the max-condition section.
best_eta, q = get_eta_q(vN, alpha, v_is, perm_samp_dataset, reference_dataset, kernel, mode="all")

In [None]:
# eta chosen by the all-condition search.
best_eta

In [None]:
# Standalone party values, for comparison with the all-condition reward targets.
v_is

In [None]:
# All condition: reward target for each party, r_i = q(alpha_i).
r = [q(alpha_i) for alpha_i in alpha]
print(r)

In [None]:
num_candidate_points = 8000
# num_clusters independent GMM candidate pools. Only pool 0 is used as the
# candidate set, but all pools are kept (they are pickled with the results).
# Drawing every pool also keeps the global RNG stream unchanged.
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([points for points, _ in gmm_clusters])
clusters = np.array([labels for _, labels in gmm_clusters])
# Every party draws its reward from the same candidate pool, gmm[0].
cand_datasets = np.stack([gmm[0]] * num_parties)

In [None]:
# Same greed factor (2.0) for every party, passed to reward_realization.
greeds = np.full(num_parties, 2.0)

In [None]:
# Realize each party's reward from its candidate set, targeting the
# all-condition values in r (one weighted-sampling run per party).
# NOTE(review): the meaning of deltas and mus is not visible here — see reward_realization.
rewards, deltas, mus = reward_realization(cand_datasets, 
                                          reference_dataset, 
                                          r, 
                                          party_datasets, 
                                          kernel, 
                                          greeds=greeds)

In [None]:
# Persist the all-condition run. The context manager closes (and flushes) the
# file as soon as the dump finishes, and the results directory is created if missing.
os.makedirs("results", exist_ok=True)
with open("results/CGM-GMM-unequal-greed2-all.p", "wb") as results_file:
    pickle.dump((gmm, clusters, reference_dataset, cand_datasets, party_datasets, greeds, rewards, deltas, mus), results_file)

In [None]:
# Cluster composition of each party's reward. get_classes presumably maps each
# reward point to its label via the candidate pool (gmm[0], clusters[0]).
class_props = [
    class_proportion(get_classes(np.array(result), gmm[0], clusters[0]), num_clusters)
    for result in rewards
]

In [None]:
# Per-party cluster proportions of the realized rewards (all condition).
class_props

In [None]:
# Sanity check: -MMD^2 between (party data + its reward) and the reference set,
# to compare against the party's all-condition target in r.
for party_idx in range(num_parties):
    augmented = np.concatenate([party_datasets[party_idx], np.array(rewards[party_idx])], axis=0)
    print(mmd_neg_biased(augmented, reference_dataset, kernel)[0])