In [None]:
import sys
import os

# Put the directory two levels above the kernel's current working directory on
# the import path; presumably that is the repository root holding the `utils`
# and `algorithms` packages imported below.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import gpytorch
from tqdm.notebook import trange
import heapq
import math
import pickle
import itertools
from utils.class_imbalance import get_classes, class_proportion

from algorithms.cgm import *

## Dataset

In [None]:
def sample_GMM(means, covs, num_samples):
    """
    Draw points from an equal-weight mixture of Gaussians.

    For every point, a component is picked uniformly at random and the point is
    then drawn from that component's normal distribution. Components are
    therefore equally likely, but per-component counts are equal only in
    expectation.

    Args:
        means: array whose first two axes are (n, d): one mean per component.
        covs: array of shape (n, d, d): one covariance matrix per component.
        num_samples: number of points to draw.

    Returns:
        (samples, clusters): a float array of shape (num_samples, d) and an
        int32 array of shape (num_samples,) with each point's component index.

    Raises:
        AssertionError: if the shapes of ``means`` and ``covs`` disagree.
        ValueError: from numpy if a covariance is not positive semidefinite
            (``check_valid='raise'``).
    """
    assert means.shape[0] == covs.shape[0]
    assert means.shape[1] == covs.shape[1]
    assert covs.shape[1] == covs.shape[2]

    num_components, dim = means.shape[0], means.shape[1]
    samples = np.zeros((num_samples, dim))
    clusters = np.zeros(num_samples, dtype=np.int32)

    # One randint followed by one multivariate_normal per point. This
    # interleaving fixes the global RNG call order, and thus the seeded output.
    for row in range(num_samples):
        component = np.random.randint(num_components)
        clusters[row] = component
        samples[row] = np.random.multivariate_normal(
            means[component], covs[component], check_valid='raise'
        )

    return samples, clusters

In [None]:
# Problem size: number of Gaussian clusters, data dimensionality, and number of
# points drawn per cluster for each of the train and test sets.
num_clusters, d, num_samples = 5, 2, 1000

In [None]:
# numpy's global RNG drives all the sampling in this notebook. torch is seeded
# too, in case the imported CGM / gpytorch routines draw from it, so Restart &
# Run All stays reproducible. numpy is seeded last so the cell displays nothing
# (torch.manual_seed returns a Generator object).
SEED = 2
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Cluster centres drawn uniformly from [0, 1)^d.
means = np.random.uniform(size=(num_clusters, d))
# All clusters share the same isotropic covariance I / 200, stacked to
# shape (num_clusters, d, d).
covs = np.tile(np.eye(d) / 200, (num_clusters, 1, 1))

In [None]:
train_sets = np.zeros((num_clusters, num_samples, d))
test_sets = np.zeros((num_clusters, num_samples, d))

In [None]:
for i in range(num_clusters):
    train_sets[i] = np.random.multivariate_normal(means[i], covs[i], size=(num_samples), check_valid='raise')
    test_sets[i] = np.random.multivariate_normal(means[i], covs[i], size=(num_samples), check_valid='raise')

In [None]:
# Scatter the training points of each cluster, one colour per cluster.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in 3.7 and
# removed in 3.9. 'Set1' is a 9-colour qualitative map, so integer index i
# picks its i-th colour.
set1 = plt.get_cmap('Set1')
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
for i in range(num_clusters):
    ax.scatter(train_sets[i, :, 0], train_sets[i, :, 1], s=2, color=set1(i), label="{0}".format(i))
# Build the legend once, after every cluster has been drawn.
ax.legend(title="cluster")
ax.set(title="Training data by cluster", xlabel="$x_1$", ylabel="$x_2$");

## Equal disjoint split

In [None]:
num_parties = 5

In [None]:
disjoint_prop = np.eye(5)

In [None]:
party_datasets = split_proportions(train_sets, disjoint_prop)

In [None]:
# Check the split: plot each party's data in its own colour.
# Set e.g. parties_to_check = [4] to inspect a single party.
parties_to_check = range(num_parties)
set1 = plt.get_cmap('Set1')  # replaces the removed matplotlib.cm.get_cmap
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# Hand-picked fixed window; any points outside it are not shown.
ax.set_xlim(0, 0.8)
ax.set_ylim(-0.2, 1.0)
for i in parties_to_check:
    ax.scatter(party_datasets[i, :, 0], party_datasets[i, :, 1], s=2, color=set1(i), label="{0}".format(i))
ax.legend(title="party")
ax.set(title="Party datasets (equal disjoint split)", xlabel="$x_1$", ylabel="$x_2$");

In [None]:
# Scaled RBF kernel with one lengthscale per input dimension (ARD).
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
# One initial lengthscale per dimension. A hard-coded [0.5, 0.5] only matched
# ard_num_dims when d == 2.
kernel.base_kernel.lengthscale = [0.5] * d
kernel.outputscale = 1

In [None]:
#reference_dataset, reference_labels = sample_GMM(means, covs, num_samples=5000)

In [None]:
#perm_samp_dataset, perm_samp_labels = sample_GMM(means, covs, num_samples=5000)

In [None]:
perm_samp_dataset = np.concatenate(party_datasets)

In [None]:
reference_dataset = np.concatenate(party_datasets)

In [None]:
# Value function over the parties, computed by get_v (from the star import of
# algorithms.cgm) from the party datasets, the pooled reference set and kernel.
# Presumably v maps each coalition of parties to its value; TODO confirm there.
v = get_v(party_datasets, reference_dataset, kernel)

In [None]:
v

In [None]:
# Shapley value of each of the num_parties parties under v.
phi = shapley(v, num_parties)
print(phi)

In [None]:
# NOTE(review): `norm` is not imported explicitly here; presumably it comes from
# the star import of algorithms.cgm (not numpy/scipy) and rescales phi into
# per-party weights alpha. TODO confirm.
alpha = norm(phi)
print(alpha)

In [None]:
1e-03

In [None]:
# etas = [0.001, 0.01, 0.1, 0.25, 0.5]
# all_sorted_vX = []
# for eta in etas:
#     sorted_vX = perm_sampling_neg_biased(perm_samp_dataset, reference_dataset, kernel, num_perms=200, eta=eta)
#     all_sorted_vX.append(sorted_vX)
#     print("Eta = {} - Mean:{} \\ Median:{} \\ Min: {} \\ Max: {}".format(eta, np.mean(sorted_vX), np.median(sorted_vX), np.min(sorted_vX), np.max(sorted_vX)))

In [None]:
# all_sorted_vX_variant = []
# for eta in etas:
#     sorted_vX = perm_sampling_neg_biased_variant(perm_samp_dataset, reference_dataset, kernel, num_perms=200, eta=eta)
#     all_sorted_vX_variant.append(sorted_vX)
#     print("Eta = {} - Mean:{} \\ Median:{} \\ Min: {} \\ Max: {}".format(eta, np.mean(sorted_vX), np.median(sorted_vX), np.min(sorted_vX), np.max(sorted_vX)))

In [None]:
# Presumably the value of the grand coalition (all num_parties parties), read out
# of v by get_vN from algorithms.cgm. TODO confirm.
vN = get_vN(v, num_parties)
print(vN)

In [None]:
# Presumably each party's standalone value v({i}), one entry per party, from
# get_v_is in algorithms.cgm. TODO confirm.
v_is = get_v_is(v, num_parties)
print(v_is)

## Condition R6

In [None]:
q, rho = get_q_rho(alpha, v_is, vN, phi, v, cond="R6")

In [None]:
rho

In [None]:
v_is

In [None]:
#all condition
r = list(map(q, alpha))
print(r)

In [None]:
num_candidate_points = 8000
# num_clusters independent GMM draws, each a (samples, cluster_labels) pair.
# NOTE(review): outside the pickle, only the first draw (gmm[0], clusters[0]) is
# used. Dropping the others would change what is pickled and the global RNG
# state afterwards, so they are kept.
gmm_clusters = [sample_GMM(means, covs, num_candidate_points) for _ in range(num_clusters)]
gmm = np.array([samples for samples, _labels in gmm_clusters])
clusters = np.array([labels for _samples, labels in gmm_clusters])
# Every party shares the same candidate pool: the first GMM draw.
cand_datasets = np.array([gmm[0] for _ in range(num_parties)])

In [None]:
greeds = np.ones(num_parties) * 3

In [None]:
rewards, deltas, mus = reward_realization(cand_datasets, 
                                          reference_dataset, 
                                          r, 
                                          party_datasets, 
                                          kernel, 
                                          greeds=greeds,
                                          rel_tol=1e-5)

In [None]:
# Persist this run's inputs and outputs. The context manager closes (and flushes)
# the file deterministically instead of relying on garbage collection, and
# makedirs avoids FileNotFoundError when results/ does not exist yet.
# Only unpickle this file if you trust it: pickle.load can execute arbitrary code.
results_path = "results/CGM-GMM-rho-equaldisjoint-greed3-stable.p"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "wb") as results_file:
    pickle.dump((gmm, clusters, reference_dataset, cand_datasets, party_datasets, greeds, rewards, deltas, mus), results_file)

In [None]:
class_props = []
for result in rewards:
        class_props.append(class_proportion(get_classes(np.array(result), gmm[0], clusters[0]), num_clusters))

In [None]:
class_props

In [None]:
# For each party, print the first element returned by mmd_neg_biased (from
# algorithms.cgm) for the party's own data plus its reward points against the
# reference set.
for party_data, reward_points in zip(party_datasets, rewards):
    # Reshape to (k, dim) so a party with no reward points (an empty list) gives
    # a (0, dim) array. np.array([]) would be shape (0,), and concatenating it
    # with the (n, dim) party data raises a dimension-mismatch ValueError.
    reward_arr = np.reshape(reward_points, (-1, party_data.shape[1]))
    augmented = np.concatenate([party_data, reward_arr], axis=0)
    print(mmd_neg_biased(augmented, reference_dataset, kernel)[0])