In [None]:
import sys
import os

# Put the parent directory of the working directory on the import path so the
# project-local `algorithms` and `utils` packages can be imported below.
module_path = os.path.abspath(os.path.join('..'))
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import gpytorch
from tqdm.notebook import trange
import heapq
import math
import pickle
from algorithms.cd import con_div
from algorithms.ccr import con_conv_rate
from utils.class_imbalance import get_classes, class_proportion

## Dataset

In [None]:
def sample_GMM(means, covs, num_samples):
    """
    Draw points from a Gaussian mixture whose components are equally likely.

    :param means: array of shape (n, d), one mean per component.
    :param covs: array of shape (n, d, d), one covariance matrix per component.
    :param num_samples: number of points to draw.
    :return: tuple (samples, clusters); samples has shape (num_samples, d) and
        clusters is an int32 array of shape (num_samples,) holding the index of
        the component each point was drawn from.
    """
    assert means.shape[0] == covs.shape[0]
    assert means.shape[1] == covs.shape[1]
    assert covs.shape[1] == covs.shape[2]

    num_components, dim = means.shape[0], means.shape[1]
    points = np.zeros((num_samples, dim))
    labels = np.zeros(num_samples, dtype=np.int32)

    # One component draw followed by one Gaussian draw per point; this order of
    # calls on the global NumPy RNG is what makes seeded runs reproducible.
    for idx in range(num_samples):
        component = np.random.randint(num_components)
        labels[idx] = component
        points[idx] = np.random.multivariate_normal(
            means[component], covs[component], check_valid='raise')

    return points, labels

In [None]:
# Experiment configuration; these module-level names are read by later cells.
num_clusters = 5  # number of Gaussian components (one train/test set is built per component)
d = 2  # dimensionality of every sample
num_samples = 1000  # points drawn per component for each train/test set

In [None]:
# Seed NumPy's global RNG. Every np.random call below draws from this single
# stream, so results are only reproducible when cells run top to bottom.
np.random.seed(2)

In [None]:
# Component means are drawn uniformly from the unit square; every component
# shares the same isotropic covariance I/200.
means = np.random.uniform(size=(num_clusters, d))
shared_cov = np.eye(d)/200
covs = np.tile(shared_cov, (num_clusters, 1, 1))

In [None]:
# Pre-allocate one (num_samples, d) slab per cluster; both arrays are filled in by the next cell.
train_sets = np.zeros((num_clusters, num_samples, d))
test_sets = np.zeros((num_clusters, num_samples, d))

In [None]:
# For each component draw its training slab first, then its test slab, so the
# sequence of RNG calls is train_0, test_0, train_1, test_1, ...
for k, (mu, sigma) in enumerate(zip(means, covs)):
    train_sets[k] = np.random.multivariate_normal(mu, sigma, size=num_samples, check_valid='raise')
    test_sets[k] = np.random.multivariate_normal(mu, sigma, size=num_samples, check_valid='raise')

In [None]:
# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "sans-serif"})

# Scatter the training points of every cluster, one colour per cluster.
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cluster_cmap = plt.get_cmap('Set1')

plt.figure(figsize=(10, 6), dpi=300)
#plt.gca().set_aspect('equal', adjustable='box')
for i in range(num_clusters):
    plt.scatter(train_sets[i, :, 0], train_sets[i, :, 1], s=2, color=cluster_cmap(i*(1/9)), label="{0}".format(i))

# Build the legend once, after all clusters have been drawn, instead of
# rebuilding it on every loop iteration.
plt.legend();

## Data valuation 

In [None]:
from utils.mmd import mmd
from sympy.utilities.iterables import multiset_permutations

In [None]:
# Scaled RBF kernel with one lengthscale per input dimension (ard_num_dims=d).
# The lengthscale list is sized from d so it cannot drift out of sync with
# ard_num_dims; for d == 2 this is the same [1, 1] as before.
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
def shapley_mmd(parties_datasets, reference_dataset, kernel, verbose=False):
    """
    Exact Shapley values of the parties under a negative-MMD value function.

    The value of a coalition is ``-mmd(stacked datasets, reference_dataset, kernel)[0]``
    and the empty coalition has value 0. Every ordering of the parties is
    enumerated; each party is credited with the change in value caused by
    appending its dataset to those of the parties before it, and the credits
    are averaged over all orderings.

    Cost: ``num_parties! * num_parties`` calls to ``mmd``.
    NOTE(review): if ``mmd`` is deterministic and does not depend on row order,
    coalition values could be cached per party subset -- confirm before doing so.

    :param parties_datasets: sequence of per-party arrays that can be stacked
        along axis 0 with ``np.concatenate``.
    :param reference_dataset: reference sample, passed through to ``mmd``.
    :param kernel: kernel object, passed through to ``mmd``.
    :param verbose: if True, print each permutation as it is processed
        (this was previously unconditional debug output).
    :return: array of shape (num_parties,) with one Shapley value per party.
    """
    num_parties = len(parties_datasets)
    shapley_sums = np.zeros(num_parties)

    for perm in multiset_permutations(list(range(num_parties))):
        if verbose:
            print(perm)

        coalition = None
        prev_neg_mmd = 0  # value of the empty coalition
        for party in perm:
            party_data = parties_datasets[party]
            if coalition is None:
                coalition = party_data
            else:
                coalition = np.concatenate([coalition, party_data])

            current_neg_mmd = -mmd(coalition, reference_dataset, kernel)[0]
            # Marginal contribution of `party` given the parties before it.
            shapley_sums[party] += current_neg_mmd - prev_neg_mmd
            prev_neg_mmd = current_neg_mmd

    return (1/math.factorial(num_parties)) * shapley_sums

In [None]:
# Reference sample drawn from the full mixture: num_samples * num_clusters points
# together with the index of the component each point came from.
reference_datasets, reference_labels = sample_GMM(means, covs, num_samples * num_clusters)

In [None]:
# Shapley values when each party holds exactly one cluster's test set.
# Note: `shap` is overwritten by a later cell that uses the mixed-proportion split.
shap = shapley_mmd(test_sets, reference_datasets, kernel)

In [None]:
# proportions[i, j] is the fraction of cluster j's samples handed to party i
# (see split_proportions). Every row and every column sums to 1, so each party
# receives the same number of points and each cluster is used up exactly once.
proportions = np.array([[0.2, 0.2, 0.2, 0.2, 0.2],
                        [0.2, 0.2, 0.2, 0.2, 0.2],
                        [0.6, 0.1, 0.1, 0.1, 0.1],
                        [0.0, 0.5, 0.5, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.5, 0.5]])

In [None]:
def split_proportions(dataset, proportions):
    """
    Deal the samples of each class out to the parties according to `proportions`.

    Party i receives ``proportions[i, j] * N`` samples of class j (truncated to
    an integer), taken in order from the not-yet-used samples of that class.

    :param dataset: array of shape (num_classes, N, d).
    :param proportions: array of probability simplices of shape (num_classes, num_classes). Must sum to 1 along
    all rows and columns
    :return: ``np.array`` of the per-party sample lists; shape
        (num_classes, samples_per_party, d) when every party receives the same
        number of samples.
    :raises IndexError: if more than N samples of one class are requested in total.
    """
    # prop * N can land a hair below the intended integer (0.29 * 100 ==
    # 28.999999999999996), and plain int() would then drop a sample. A small
    # tolerance absorbs that error while genuine fractions still truncate.
    tolerance = 1e-9

    num_classes, N, d = dataset.shape
    split_datasets = [[] for _ in range(num_classes)]
    next_unused = [0] * num_classes  # per class: index of the first sample not yet handed out

    for i in range(num_classes):
        for j in range(num_classes):
            count = max(int(proportions[i, j] * N + tolerance), 0)
            start = next_unused[j]
            stop = start + count
            if stop > N:
                raise IndexError(
                    "class {} has only {} samples but {} were requested".format(j, N, stop))
            split_datasets[i].extend(dataset[j, start:stop])
            next_unused[j] = stop

    return np.array(split_datasets)

In [None]:
# Re-deal the per-cluster test sets into mixed-composition party datasets.
split_datasets = split_proportions(test_sets, proportions)

In [None]:
# Shapley values for the mixed-composition parties (replaces the earlier `shap`).
shap = shapley_mmd(split_datasets, reference_datasets, kernel)

In [None]:
# Display the per-party Shapley values.
shap

In [None]:
# Sanity check: within each permutation the marginal contributions telescope to
# the value of the coalition of all parties, so this sum is the permutation
# average of -MMD(all party data, reference).
np.sum(shap)

In [None]:
# Negative MMD of the reference sample against itself, shown as a baseline
# (presumably the best attainable value -- confirm against utils.mmd).
-mmd(reference_datasets, reference_datasets, kernel)[0]

In [None]:
# Stand-alone value (negative MMD to the reference) of each party's dataset.
for ds in split_datasets:
    print(-mmd(ds, reference_datasets, kernel)[0])

In [None]:
# Value of the two-party coalition {0, 1}: negative MMD of their stacked data.
-mmd(np.concatenate([split_datasets[0], split_datasets[1]]), reference_datasets, kernel)[0]

In [None]:
# Inspect the raw samples handed to party 0.
split_datasets[0]

## Controlled divergence (CD)

### All clusters

In [None]:
num_candidate_points = 10000
num_parties = 10

# Arguments later passed to con_div as `phi` and `greeds=` (one entry per party).
phi = np.linspace(0.05, 1, num=num_parties)
greeds = np.ones(num_parties)

# num_clusters independent draws from the full mixture; only draw 0 is used below.
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([points for points, _labels in gmm_clusters])
clusters = np.array([labels for _points, labels in gmm_clusters])

# Reference sample is drawn after the candidate pools (RNG call order matters).
reference, _reference_labels = sample_GMM(means, covs, num_samples)

# Every party is offered the same candidate pool: shape (num_parties, num_candidate_points, d).
candidates = np.repeat(gmm[0][np.newaxis], num_parties, axis=0)

In [None]:
# Fresh scaled RBF kernel for the CD experiment. The lengthscale list is sized
# from d so it stays consistent with ard_num_dims (identical to [1, 1] when d == 2).
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Per-cluster accumulators for the three values returned by con_div;
# entry i corresponds to the run where every party holds test_sets[i].
cd_all_res = []
cd_all_deltas = []
cd_all_mus = []

In [None]:
# NOTE(review): `candidates` has shape (num_parties, num_candidate_points, d), so
# len(candidates) is the number of parties, not the number of candidate points.
# Confirm this is the intended size for eta before changing it; stored results
# depend on the current value.
mean_set_size = (len(candidates) + len(reference))/2
eta = 100/mean_set_size

# One con_div run per cluster, with every party holding that cluster's test set.
for cluster_idx in range(num_clusters):
    D = np.repeat(test_sets[cluster_idx][np.newaxis], num_parties, axis=0)
    res, deltas, mus = con_div(candidates, reference, phi, D, kernel, num_perms=1000, greeds=greeds, eta=eta)
    cd_all_res += [res]
    cd_all_deltas += [deltas]
    cd_all_mus += [mus]

In [None]:
# Persist the CD inputs and results. A context manager guarantees the file is
# flushed and closed even if pickling fails.
# NOTE(review): the load cell below reads "results/CDWS-allclusters-10000cands.p",
# which differs from this file name and directory -- confirm which is intended.
with open("CDWS-allclusters10000cands.p", "wb") as cd_out:
    pickle.dump((gmm, clusters, reference, candidates, test_sets, greeds, cd_all_res, cd_all_deltas, cd_all_mus), cd_out)

In [None]:
# Restore previously saved CD results; this overwrites the in-memory variables
# of the same names. The file handle is closed by the context manager.
# NOTE(review): this path differs from the one written by the dump cell above
# ("CDWS-allclusters10000cands.p"); the file must already exist under results/.
# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
with open("results/CDWS-allclusters-10000cands.p", "rb") as cd_in:
    (gmm, clusters, reference, candidates, test_sets, greeds, cd_all_res, cd_all_deltas, cd_all_mus) = pickle.load(cd_in)

In [None]:
# One figure per cluster: size of each party's reward set against its phi.
for i in range(num_clusters):
    plt.figure(figsize=(12, 6), dpi=300)
    plt.plot(phi, [len(result) for result in cd_all_res[i]])
    # Raw string: "\p" is an invalid escape sequence in a normal string literal.
    plt.xlabel(r"$\phi$")
    plt.ylabel("Number of points added")
    plt.title("Cluster {}".format(i))

In [None]:
# For every cluster and every party, compare the class proportions obtained
# with the CD reward set against a baseline in which the same number of extra
# points get uniformly random class labels.
all_class_props = []
all_bad_props = []
for i in range(num_clusters):
    res = cd_all_res[i]
    own_labels = [i] * len(test_sets[i])  # the party's own data all belongs to cluster i
    class_props = []
    bad_props = []
    for result in res:
        reward_labels = get_classes(np.array(result), gmm[0], clusters[0])
        class_props.append(class_proportion(reward_labels + own_labels, num_clusters))

        random_labels = list(np.random.randint(0, num_clusters, len(result)))
        bad_props.append(class_proportion(own_labels + random_labels, num_clusters))
    all_class_props.append(class_props)
    all_bad_props.append(bad_props)

In [None]:
# Global Matplotlib styling: serif text and a matching serif math font.
# These settings apply to every figure created after this cell.
plt.rcParams["font.family"] = "serif"
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

In [None]:
# phi is not part of the CD pickle loaded above, so it is rebuilt here; the
# values must match the phi used when the CD results were computed.
phi = np.linspace(0.05, 1, 10)

In [None]:
# One small figure per cluster: element [1] of the class-proportion result
# against phi, for CD and for the random-label baseline.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
spectral = plt.get_cmap('Spectral')
for i in range(num_clusters):
    plt.figure(figsize=(3, 3), dpi=300)
    plt.plot(phi, [prop[1] for prop in all_class_props[i]], label="CD", color=spectral(0.9))
    plt.plot(phi, [prop[1] for prop in all_bad_props[i]], label="Random sampling", color=spectral(0.1))
    # Raw strings avoid invalid escape sequences such as "\p".
    plt.xlabel(r"$\phi$", fontsize=16)
    plt.ylabel(r"$\rho$", fontsize=16)
    plt.legend()
    plt.title("Party {} (GMM cluster {})".format(i+1, i+1))

In [None]:
# Pearson correlation between phi and element [1] of each party's class
# proportion result, computed per cluster and then averaged.
all_corrcoef = []
for i in range(num_clusters):
    class_props = all_class_props[i]
    props = [entry[1] for entry in class_props]
    paired = np.vstack([phi, props])  # row 0: phi, row 1: proportions
    all_corrcoef.append(np.corrcoef(paired)[0,1])
print("Average correlation coefficient: {}".format(np.mean(all_corrcoef)))

In [None]:
# For each cluster, show the reward set of the last party (index 9) on top of
# the data: other clusters in grey, the party's own cluster highlighted, and the
# added points fading out in the order they were added.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
set1 = plt.get_cmap('Set1')
for cluster in range(num_clusters):
    res = cd_all_res[cluster]
    party = 9  # Look at highest reward

    plt.figure(figsize=(12, 6), dpi=300)
    for i in range(num_clusters):
        # Skip the cluster being highlighted. The previous test `i != party`
        # compared a cluster index (0..num_clusters-1) with the party index 9
        # and was therefore always true.
        if i != cluster:
            plt.scatter(test_sets[i, :, 0], test_sets[i, :, 1], s=0.1, color='grey')
    plt.scatter(test_sets[cluster, :, 0], test_sets[cluster, :, 1], s=10, color=set1(0*(1/9)), label="Party {}".format(party))

    added = np.array(res[party])
    # Alpha decreases linearly with insertion order: first point opaque, last almost transparent.
    alphas = [1-i*(1/len(added)) for i in range(len(added))]
    rgba_colors = np.zeros((len(added),4))
    rgba_colors[:, 3] = alphas
    rgba_colors[:, :3] = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804)
    plt.scatter(added[:, 0], added[:, 1], s=10, color=rgba_colors, label="Added")

    plt.xlabel("$x_0$")
    plt.ylabel("$x_1$")

    # Raw string: "\p" is an invalid escape sequence in a normal string literal.
    plt.title(r"Party {}, $\phi = 1.0$".format(cluster))
    plt.legend()

## Controlled convergence rate (CCR)

### All clusters

In [None]:
num_candidate_points = 2000
num_parties = 10

# num_clusters independent draws from the full mixture; only draw 0 is used below.
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([points for points, _labels in gmm_clusters])
clusters = np.array([labels for _points, labels in gmm_clusters])

# Reference sample is drawn after the candidate pools (RNG call order matters).
reference, _reference_labels = sample_GMM(means, covs, num_samples)

# Every party is offered the same candidate pool.
candidates = np.repeat(gmm[0][np.newaxis], num_parties, axis=0)

# Passed to con_conv_rate as `phi`, one entry per party.
phi = np.linspace(0.1, 1, num=num_parties)

In [None]:
# Fresh scaled RBF kernel for the CCR experiment. The lengthscale list is sized
# from d so it stays consistent with ard_num_dims (identical to [1, 1] when d == 2).
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Per-cluster accumulators for the three values returned by con_conv_rate;
# entry i corresponds to the run where every party holds test_sets[i].
ccr_all_res = []
ccr_all_deltas = []
ccr_all_mus = []

In [None]:
# One con_conv_rate run per cluster, with every party holding that cluster's test set.
for cluster_idx in range(num_clusters):
    D = np.repeat(test_sets[cluster_idx][np.newaxis], num_parties, axis=0)
    res, deltas, mus = con_conv_rate(candidates, reference, phi, D, kernel)
    ccr_all_res += [res]
    ccr_all_deltas += [deltas]
    ccr_all_mus += [mus]

In [None]:
# Persist the CCR inputs and results; the context manager closes the file
# even if pickling fails.
# NOTE(review): the load cell below reads "results/CCR-allclusters.p", a
# different directory from the one written here -- confirm which is intended.
with open("CCR-allclusters.p", "wb") as ccr_out:
    pickle.dump((gmm, clusters, reference, candidates, phi, test_sets, ccr_all_res, ccr_all_deltas, ccr_all_mus), ccr_out)

In [None]:
# Restore previously saved CCR results; this overwrites the in-memory variables
# of the same names. The file handle is closed by the context manager.
# NOTE(review): the dump cell above writes "CCR-allclusters.p" in the working
# directory, not under results/ -- the file must already exist there.
# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
with open("results/CCR-allclusters.p", "rb") as ccr_in:
    (gmm, clusters, reference, candidates, phi, test_sets, ccr_all_res, ccr_all_deltas, ccr_all_mus) = pickle.load(ccr_in)

In [None]:
# One figure per cluster: the `mus` curve of every second party against the
# position in its reward set.
phi_labels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
spectral = plt.get_cmap('Spectral')

for j in range(num_clusters):
    mus = ccr_all_mus[j]
    x = list(range(1, len(mus[0])+1))
    plt.figure(figsize=(3, 3), dpi=300)

    for i in range(len(mus)):
        # Keep only parties whose phi is an even multiple of 0.1. round() is
        # used instead of int() so that a product falling just below an integer
        # is not truncated to the wrong parity.
        if round(phi[i]*10) % 2 == 0:
            # The former 'C0' format string was dropped: it was overridden by
            # `color=` and made Matplotlib warn about a redundant colour.
            plt.plot(x, mus[i], linewidth=1, color=spectral(phi[i]), label=r"$\phi = ${}".format(phi_labels[i]))

    plt.legend()
    plt.title("Party {} (GMM cluster {})".format(j+1, j+1))
    # Raw strings avoid invalid escape sequences such as "\c" and "\p".
    plt.ylabel(r"$z(D \cup R_i)$", fontsize=16)
    plt.xlabel("$|R_i|$", fontsize=16)

In [None]:
# For every cluster: correlate, across parties, a position-weighted sum of the
# CCR deltas with the same weighted sum of the running class proportion.
all_Es = []
all_class_prop_AUCs = []
all_corrcoeff = []

num_candidate_points = candidates.shape[1]
# Weights n-1, n-2, ..., 1: earlier positions in the reward set count more.
position_weights = np.arange(num_candidate_points-1, 0, -1)

for cluster in range(num_clusters):
    R = ccr_all_res[cluster]
    deltas = ccr_all_deltas[cluster]
    own_labels = [cluster] * len(test_sets[cluster])  # the party's own data is all `cluster`

    # class_props[i][j]: element [1] of class_proportion for party i's own data
    # plus the first j+1 points of its reward set.
    class_props = []
    for i in range(num_parties):
        classes = get_classes(np.array(R[i]), gmm[0], clusters[0])
        class_props.append([
            class_proportion(classes[:j+1] + own_labels, num_clusters)[1]
            for j in range(num_candidate_points)
        ])

    Es = [np.sum(np.array(deltas[i])[:-1] * position_weights) for i in range(num_parties)]
    class_prop_AUCs = [np.sum(np.array(class_props[i])[:-1] * position_weights) for i in range(num_parties)]

    all_Es.append(Es)
    all_class_prop_AUCs.append(class_prop_AUCs)

    all_corrcoeff.append(np.corrcoef(np.vstack([Es, class_prop_AUCs]))[0, 1])

print("Average correlation coefficient: {}".format(np.mean(all_corrcoeff)))

In [None]:
# For each cluster, show the reward set of the last party (index 9): the
# party's own cluster highlighted, all clusters in grey, and the added points
# fading out in the order they were added.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
set1 = plt.get_cmap('Set1')
for cluster in range(num_clusters):
    reward = 9
    R = ccr_all_res[cluster]
    plt.figure(figsize=(12, 6), dpi=300)
    plt.scatter(test_sets[cluster, :, 0], test_sets[cluster, :, 1], s=20, color=set1(0*(1/9)), label="Party")
    for i in range(num_clusters):
        plt.scatter(test_sets[i, :, 0], test_sets[i, :, 1], s=0.1, color='grey')

    added = np.array(R[reward])
    # Alpha decreases linearly with insertion order.
    alphas = [1-i*(1/len(R[reward])) for i in range(len(R[reward]))]
    rgba_colors = np.zeros((len(R[reward]),4))
    rgba_colors[:, 3] = alphas
    rgba_colors[:, :3] = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804)
    plt.scatter(added[:, 0], added[:, 1], s=20, color=rgba_colors, label="Added")
    plt.legend()
    # Raw string: "\p" is an invalid escape sequence in a normal string literal.
    plt.title(r"Cluster {}, $\phi = {}$".format(cluster, phi[reward]))

    plt.xlabel("$x_0$")
    plt.ylabel("$x_1$")

## CD: varying precision hyperparameter $\eta$ (permutation sampling)

In [None]:
from utils.mmd import perm_sampling
import scipy.stats as stats

In [None]:
num_candidate_points = 10000
num_parties = 10

# num_clusters independent draws from the full mixture. X and Y are the sample
# arrays of draws 0 and 1, i.e. two independent samples of the same mixture.
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
first_draw, second_draw = gmm_clusters[0], gmm_clusters[1]
X = first_draw[0]
Y = second_draw[0]

In [None]:
# Fresh scaled RBF kernel for the permutation-sampling experiment. The
# lengthscale list is sized from d so it stays consistent with ard_num_dims
# (identical to [1, 1] when d == 2).
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Ten eta values spaced evenly on a log scale from 0.025 up to 1.
eta_min, eta_max = 0.025, 1.
log_etas = np.linspace(np.log(eta_min), np.log(eta_max), num=10)
etas = np.exp(log_etas)

In [None]:
# Display the eta grid.
etas

In [None]:
# Run perm_sampling once per eta on the first 4000 points of each sample.
X_head, Y_head = X[:4000], Y[:4000]
all_samps = []
for eta in etas:
    all_samps.append(perm_sampling(X_head, Y_head, kernel, eta=eta))

In [None]:
# Standard deviation of the permutation samples for each eta, next to a
# 0.0001 * sqrt(1/eta) reference curve.
plt.figure(figsize=(12, 6), dpi=300)

plt.plot(etas, [np.std(samp) for samp in all_samps], label="Permutation sampling")
# Raw strings avoid invalid escape sequences such as "\s" and "\e".
plt.plot(etas, 0.0001*np.sqrt(1/etas), label=r"$O(\sqrt {1/\eta})$")

plt.xlabel(r"$\eta$")
plt.ylabel("Standard deviation")

plt.legend()

In [None]:
# For each eta: build an evaluation grid from the sample's histogram bin edges
# (extended by 7 extra bins on the left) and fit a Gaussian KDE to the sample.
all_x = []
all_density = []
for i in range(len(all_samps)):
    samp = all_samps[i]
    bins = np.histogram(samp, bins=50)[1]
    interval = bins[1] - bins[0]
    left_pad = [bins[0] - interval*k for k in range(7, 0, -1)]
    bins = np.concatenate((left_pad, bins))
    density = stats.gaussian_kde(samp)
    # Histogram the sample belonging to this eta. This previously always drew
    # all_samps[0]; the returned bin edges `x` are unaffected because `bins`
    # is passed explicitly.
    n, x, _ = plt.hist(samp, bins=bins,
                   histtype=u'step', density=True)
    all_x.append(x)
    all_density.append(density)

In [None]:
# Overlay the fitted KDE of the permutation samples for every eta.
plt.figure(figsize=(12, 6), dpi=300)
# Raw strings avoid invalid escape sequences such as "\e" and "\w".
plt.title(r"Effect of $\eta$ on variance of $\widehat{MMD}^2$ distribution")

# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
spectral = plt.get_cmap('Spectral')
for i in range(len(all_samps)):
    x = all_x[i]
    density = all_density[i]
    plt.plot(x, density(x), label=r"$\eta = {}$".format(etas[i]), color=spectral(i*0.1), linewidth=2)

# Legend and axis labels are set once, after all curves are drawn.
plt.legend()
plt.ylabel("Density")
plt.xlabel(r"$\widehat{MMD}^2$");

### Effect on number of points distributed

In [None]:
# Single-party setup for the eta sweep below.
num_candidate_points = 10000
num_parties = 1

In [None]:
# num_clusters independent draws from the full mixture, split into the sample
# arrays (gmm) and the matching component labels (clusters).
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([points for points, _labels in gmm_clusters])
clusters = np.array([labels for _points, labels in gmm_clusters])

In [None]:
# Fresh reference sample; every party gets candidate pool gmm[0] and holds
# cluster 0's test set.
reference, _reference_labels = sample_GMM(means, covs, num_samples)
candidates = np.repeat(gmm[0][np.newaxis], num_parties, axis=0)
D = np.repeat(test_sets[0][np.newaxis], num_parties, axis=0)

In [None]:
# Per-party arguments later passed to con_div as `phi` and `greeds=`.
phi = np.zeros(num_parties)
greeds = np.ones(num_parties)

In [None]:
# Fresh scaled RBF kernel for the eta sweep. The lengthscale list is sized from
# d so it stays consistent with ard_num_dims (identical to [1, 1] when d == 2).
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Same log-spaced eta grid as in the permutation-sampling section: ten values
# from 0.025 up to 1.
eta_min, eta_max = 0.025, 1.
log_etas = np.linspace(np.log(eta_min), np.log(eta_max), num=10)
etas = np.exp(log_etas)

In [None]:
# Display the eta grid.
etas

In [None]:
# Run con_div once per eta and keep the three returned values of each run.
eta_all_res, eta_all_deltas, eta_all_mus = [], [], []

for eta in etas:
    res, deltas, mus = con_div(candidates, reference, phi, D, kernel, num_perms=1000, greeds=greeds, eta=eta)
    eta_all_res += [res]
    eta_all_deltas += [deltas]
    eta_all_mus += [mus]

In [None]:
# Size of the (single) party's reward set for each eta.
[len(res[0]) for res in eta_all_res]

In [None]:
# Persist the eta-sweep inputs and results; the context manager closes the
# file even if pickling fails.
with open("CD-alletas.p", "wb") as eta_out:
    pickle.dump((gmm, clusters, reference, candidates, phi, test_sets, etas, eta_all_res, eta_all_deltas, eta_all_mus), eta_out)

## CD: varying greed hyperparameter $\gamma$

In [None]:
# Ten-party setup for the greed sweep; each party gets its own greed factor below.
num_candidate_points = 10000
num_parties = 10

In [None]:
# Every party gets the same phi of 0, so only the greed factor differs between parties.
phi = [0] * num_parties

In [None]:
# num_clusters independent draws from the full mixture, split into the sample
# arrays (gmm) and the matching component labels (clusters).
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([points for points, _labels in gmm_clusters])
clusters = np.array([labels for _points, labels in gmm_clusters])

In [None]:
# Fresh reference sample; every party gets candidate pool gmm[0] and holds
# cluster 2's test set.
reference, _reference_labels = sample_GMM(means, covs, num_samples)
candidates = np.repeat(gmm[0][np.newaxis], num_parties, axis=0)
D = np.repeat(test_sets[2][np.newaxis], num_parties, axis=0)

In [None]:
# One greed factor per party: 0 first, then num_parties-2 values log-spaced
# between exp(-2) and exp(2), and finally -1 (plotted as "Greedy" further down).
log_spaced_greeds = list(np.exp(np.linspace(-2, 2, num_parties-2)))
greeds = [0] + log_spaced_greeds + [-1]

In [None]:
# Fresh scaled RBF kernel for the greed sweep. The lengthscale list is sized
# from d so it stays consistent with ard_num_dims (identical to [1, 1] when d == 2).
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# NOTE(review): `candidates` is stacked per party, so len(candidates) is the
# number of parties rather than the number of candidate points -- confirm this
# is the intended size for eta.
mean_set_size = (len(candidates) + len(reference))/2
eta = 100/mean_set_size

# Single con_div run in which the parties differ only in their greed factor.
res, deltas, mus = con_div(
    candidates, reference, phi, D, kernel,
    num_perms=1000, greeds=greeds, eta=eta)

In [None]:
# Reward-set size against greed factor for all parties except the last one,
# whose size is drawn as a horizontal "Greedy" reference line.
plt.figure(figsize=(12, 6), dpi=300)
plt.plot(greeds[:-1], [len(result) for result in res[:-1]])
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
plt.hlines(len(res[-1]), xmin=0, xmax=greeds[-2], color=plt.get_cmap('Set1')(0*(1/9)), label="Greedy")
plt.legend()
plt.xlabel("Greed factor")
plt.ylabel("Number of points added")

In [None]:
# Class proportions of each party's reward set combined with its own data;
# the own data is labelled 2 because D was built from test_sets[2].
own_labels = [2] * len(D[0])
class_props = [
    class_proportion(get_classes(np.array(result), gmm[0], clusters[0]) + own_labels, num_clusters)
    for result in res
]

In [None]:
# Element [1] of the class-proportion result against greed factor, for all
# parties except the last one.
plt.figure(figsize=(12, 6), dpi=300)

# Plot the curve once. The previous `for i in range(num_parties)` loop never
# used `i` and drew the identical line num_parties times.
plt.plot(greeds[:-1], [prop[1] for prop in class_props[:-1]])

# Raw string: "\g" is an invalid escape sequence in a normal string literal.
plt.xlabel(r"$\gamma$")
plt.ylabel("Class imbalance")

In [None]:
# One figure per party: the clusters in grey, the party's own data highlighted,
# and its reward set fading out in the order the points were added.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
set1 = plt.get_cmap('Set1')
for party in range(num_parties):
    plt.figure(figsize=(12, 6), dpi=300)
    #plt.gca().set_aspect('equal', adjustable='box')
    for i in range(num_clusters):
        # NOTE(review): this compares a cluster index with a party index, yet
        # every party's data D comes from the same cluster here -- confirm
        # which cluster, if any, should be skipped.
        if i != party:
            plt.scatter(test_sets[i, :, 0], test_sets[i, :, 1], s=0.1, color='grey')
    plt.scatter(D[party, :, 0], D[party, :, 1], s=10, color=set1(0*(1/9)), label="Party {}".format(party))

    added = np.array(res[party])
    # An empty reward set becomes a 1-D array, for which added[:, 0] raises
    # IndexError; there is nothing to draw in that case.
    if added.size == 0:
        continue
    # Alpha decreases linearly with insertion order.
    alphas = [1-i*(1/len(added)) for i in range(len(added))]
    rgba_colors = np.zeros((len(added),4))
    rgba_colors[:, 3] = alphas
    rgba_colors[:, :3] = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804)
    plt.scatter(added[:, 0], added[:, 1], s=10, color=rgba_colors, label="Added")