In [None]:
import sys
import os

# Put the parent directory on the import path, once, so that project-local
# packages one level up can be imported (presumably where `algorithms` and
# `metrics` live - confirm against the repository layout).
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import gpytorch
from tqdm.notebook import trange
import heapq
import math
from algorithms.cd import con_div
from algorithms.ccr import con_conv_rate
from metrics.class_imbalance import get_classes, class_proportion

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# On CUDA, also report the name of device 0 and its memory figures,
# each divided by 1024**3 and rounded to one decimal.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for caption, num_bytes in (('Allocated:', torch.cuda.memory_allocated(0)),
                               ('Cached:   ', torch.cuda.memory_reserved(0))):
        print(caption, round(num_bytes/1024**3,1), 'GB')

## Dataset

In [None]:
def sample_GMM(means, covs, num_samples):
    """
    Draw points from a mixture of Gaussians whose components are equally likely.

    For every sample a component index is picked uniformly at random and one
    point is then drawn from that component's multivariate normal, using
    NumPy's global random state.

    Parameters:
        means: array whose first axis indexes components and second axis
            indexes coordinates.
        covs: array of per-component covariance matrices; its leading axis must
            match `means` and its last two axes must be square with side equal
            to the coordinate count of `means`.
        num_samples: number of points to draw.

    Returns:
        (samples, clusters): a float array of shape (num_samples, d) holding the
        drawn points and an int32 array of shape (num_samples,) holding the
        component index each point was drawn from.

    Raises:
        AssertionError: if the shapes of `means` and `covs` are inconsistent.
        ValueError: from NumPy when a covariance matrix is not positive
            semi-definite (check_valid='raise').
    """
    assert means.shape[0] == covs.shape[0]
    assert means.shape[1] == covs.shape[1]
    assert covs.shape[1] == covs.shape[2]

    n_components, dim = means.shape[:2]
    points = np.zeros((num_samples, dim))
    labels = np.zeros(num_samples, dtype=np.int32)

    for idx in range(num_samples):
        # Component first, then the point: this order fixes how the RNG stream is consumed.
        k = np.random.randint(n_components)
        points[idx] = np.random.multivariate_normal(means[k], covs[k], check_valid='raise')
        labels[idx] = k

    return points, labels

In [None]:
# Experiment size: number of Gaussian clusters, dimensionality of each point,
# and number of points drawn per set.
num_clusters, d, num_samples = 5, 2, 1000

In [None]:
# Seed NumPy's global RNG so the np.random draws in the cells below are repeatable.
np.random.seed(2)

In [None]:
# Cluster centres: one uniform draw from [0, 1) per coordinate.
means = np.random.uniform(size=(num_clusters, d))
# Every cluster gets the same isotropic covariance, identity / 200,
# broadcast into each (d, d) slot.
covs = np.zeros((num_clusters, d, d))
covs[:] = np.eye(d)/200

In [None]:
# Pre-allocate one (num_samples, d) slab per cluster for the train and the test split.
train_sets, test_sets = (np.zeros((num_clusters, num_samples, d)) for _ in range(2))

In [None]:
# Fill both splits cluster by cluster: num_samples draws from that cluster's
# Gaussian for train, then another num_samples draws for test.
for c in range(num_clusters):
    mu, sigma = means[c], covs[c]
    train_sets[c] = np.random.multivariate_normal(mu, sigma, size=num_samples, check_valid='raise')
    test_sets[c] = np.random.multivariate_normal(mu, sigma, size=num_samples, check_valid='raise')

In [None]:
# Optional: uncomment to render figure text with LaTeX.
# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "sans-serif"})

# Scatter every cluster's training samples, one Set1 colour per cluster.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
set1 = plt.get_cmap('Set1')

plt.figure(figsize=(10, 6), dpi=300)
#plt.gca().set_aspect('equal', adjustable='box')
# Iterate over the clusters actually stored in train_sets, so the cell still works
# if the global num_clusters has been rebound by a later cell.
for i in range(len(train_sets)):
    plt.scatter(train_sets[i, :, 0], train_sets[i, :, 1], s=2, color=set1(i*(1/9)), label="{0}".format(i))

# Build the legend once, after all series exist, instead of on every iteration.
plt.legend()

## Controlled divergence (CD)

In [None]:
# One pool of 5000 mixture draws per party; sample_GMM returns (samples, labels),
# which are split into two stacked arrays.
candidates_clusters = [sample_GMM(means, covs, 5000) for _ in range(num_clusters)]
candidates = np.array([pair[0] for pair in candidates_clusters])
clusters = np.array([pair[1] for pair in candidates_clusters])

# Keep only the samples: sample_GMM returns a (samples, labels) tuple, and the
# weighted-sampling section further down passes the bare sample array as `reference`.
reference = sample_GMM(means, covs, num_samples)[0]
# One phi value per party.
# NOTE(review): phi's exact role is defined in algorithms.cd - confirm there.
phi = [0, 0.25, 0.5, 0.75, 1]

In [None]:
# RBF kernel with one lengthscale per input dimension (ard_num_dims=d), wrapped in a ScaleKernel.
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
# Unit lengthscale for each of the d dimensions; sized from d instead of a
# hard-coded two-element list so it always matches ard_num_dims.
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Run con_div (imported from algorithms.cd) on the per-party candidate pools, the
# reference, the phi values, the per-cluster test sets and the kernel.
# NOTE(review): the meaning of the three returned values is defined in algorithms.cd;
# the plotting cell below treats res[party] as a sequence of 2-D points.
res, deltas, mus = con_div(candidates, reference, phi, test_sets, kernel)

In [None]:
# One figure per party: that party's test cluster highlighted, every other cluster
# in grey, and the points in res[party] drawn with opacity that decreases along the
# list (presumably the order in which they were added - confirm in algorithms.cd).
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
party_color = plt.get_cmap('Set1')(0*(1/9))
added_rgb = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804)

for party in range(5):
    plt.figure(figsize=(12, 6), dpi=300)
    #plt.gca().set_aspect('equal', adjustable='box')
    # Bound the loop by the clusters stored in test_sets rather than by the mutable
    # global num_clusters, so re-running this cell cannot index past the end.
    for i in range(len(test_sets)):
        if i != party:
            plt.scatter(test_sets[i, :, 0], test_sets[i, :, 1], s=0.1, color='grey')
    plt.scatter(test_sets[party, :, 0], test_sets[party, :, 1], s=10, color=party_color, label="Party {}".format(party))

    added = np.array(res[party])
    # Opacity runs from 1 for the first point down to 1/len(added) for the last.
    alphas = [1-i*(1/len(added)) for i in range(len(added))]
    rgba_colors = np.zeros((len(added),4))
    rgba_colors[:, 3] = alphas
    rgba_colors[:, :3] = added_rgb
    plt.scatter(added[:, 0], added[:, 1], s=10, color=rgba_colors, label="Added")
    # The series carry labels, so draw the legend that displays them.
    plt.legend()

### Weighted sampling

In [None]:
# Rebinds num_clusters: from here on it controls how many sample sets are drawn
# below, while means, covs and test_sets still hold the 5 clusters created earlier.
num_clusters = 20

In [None]:
# Draw num_clusters independent sets of 8000 mixture samples, then separate the
# points from their component labels into two stacked arrays.
gmm_clusters = [sample_GMM(means, covs, 8000) for _ in range(num_clusters)]
gmm = np.array([points for points, _labels in gmm_clusters])
clusters = np.array([labels for _points, labels in gmm_clusters])

In [None]:
# Reference: the sample array only (element 0 of what sample_GMM returns).
reference = sample_GMM(means, covs, num_samples)[0]
# All five parties get the same candidate pool (the first gmm set) ...
candidates = np.array([gmm[0] for _ in range(5)])
# ... the same phi value of 0 ...
phi = [0] * 5
# ... and the same dataset, the test samples of cluster 2.
D = np.array([test_sets[2] for _ in range(5)])

In [None]:
# 18 values log-spaced between exp(-2) and exp(2), with 0 prepended and -1 appended.
# The original cell called insert/append on an undefined name `ls` (NameError);
# the list being built is `greeds`.
# NOTE(review): 0 and -1 are presumably special values interpreted by con_div -
# confirm in algorithms.cd.
greeds = [0, *np.exp(np.linspace(-2, 2, 18)), -1]

In [None]:
# Fresh RBF kernel with one lengthscale per input dimension (ard_num_dims=d), wrapped in a ScaleKernel.
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
# Unit lengthscale for each of the d dimensions; sized from d instead of a
# hard-coded two-element list so it always matches ard_num_dims.
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Run con_div again, this time with identical candidate pools and datasets D for all
# five parties, and with num_perms=100 and the greeds list passed through.
# NOTE(review): the semantics of num_perms, greeds and the three returned values are
# defined in algorithms.cd - confirm there.
res, deltas, mus = con_div(candidates, reference, phi, D, kernel, num_perms=100, greeds=greeds)

In [None]:
# One figure per party: the party's dataset D[party] highlighted, the other test
# clusters in grey, and the points in res[party] drawn with opacity that decreases
# along the list (presumably the order in which they were added - confirm in algorithms.cd).
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
party_color = plt.get_cmap('Set1')(0*(1/9))
added_rgb = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804)

for party in range(5):
    plt.figure(figsize=(12, 6), dpi=300)
    #plt.gca().set_aspect('equal', adjustable='box')
    # num_clusters is 20 in this section but test_sets only holds the 5 clusters it
    # was allocated with, so iterate over test_sets itself to avoid an IndexError.
    for i in range(len(test_sets)):
        if i != party:
            plt.scatter(test_sets[i, :, 0], test_sets[i, :, 1], s=0.1, color='grey')
    plt.scatter(D[party, :, 0], D[party, :, 1], s=10, color=party_color, label="Party {}".format(party))

    added = np.array(res[party])
    # Opacity runs from 1 for the first point down to 1/len(added) for the last.
    alphas = [1-i*(1/len(added)) for i in range(len(added))]
    rgba_colors = np.zeros((len(added),4))
    rgba_colors[:, 3] = alphas
    rgba_colors[:, :3] = added_rgb
    plt.scatter(added[:, 0], added[:, 1], s=10, color=rgba_colors, label="Added")
    # The series carry labels, so draw the legend that displays them.
    plt.legend()

## Controlled convergence rate (CCR)

In [None]:
# sample_GMM returns a (samples, labels) pair; keep only the samples so gmm is a
# regular float array with one (500, d) block per set. Stacking the raw pairs
# builds a ragged array, which NumPy >= 1.24 rejects with a ValueError.
gmm = np.array([sample_GMM(means, covs, 500)[0] for _ in range(num_clusters)])
# All five parties share the first sample set as their candidate pool ...
candidates = np.array([gmm[0]]*5)
# ... get a different phi value each ...
phi = [1, 0.75, 0.5, 0.25, 0]
# ... and own the same dataset, the test samples of cluster 0.
D = np.array([test_sets[0]] * 5)

In [None]:
# Fresh RBF kernel with one lengthscale per input dimension (ard_num_dims=d), wrapped in a ScaleKernel.
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
# Unit lengthscale for each of the d dimensions; sized from d instead of a
# hard-coded two-element list so it always matches ard_num_dims.
kernel.base_kernel.lengthscale = [1] * d
kernel.outputscale = 1

In [None]:
# Run con_conv_rate (imported from algorithms.ccr). `reference` is not assigned in
# this section, so it is whatever the most recent earlier cell bound it to.
# NOTE(review): the meaning of R, deltas and mus is defined in algorithms.ccr; the
# cells below plot mus[i] as curves and R[reward] as lists of 2-D points.
R, deltas, mus = con_conv_rate(candidates, reference, phi, D, kernel)

In [None]:
# Plot every sequence in mus against its 1-based step index, one Set1 colour per curve.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
set1 = plt.get_cmap('Set1')
x = list(range(1, len(mus[0])+1))
plt.figure(figsize=(12, 6), dpi=300)

for i in range(len(mus)):
    # The 'C0' format string is dropped: it only set a colour that the color= keyword
    # overrode, and Matplotlib warns about the redundant definition.
    # The raw string keeps LaTeX "\phi" from being parsed as an invalid "\p" escape.
    plt.plot(x, mus[i], linewidth=1, color=set1(i*(1/9)), label=r"$\phi = ${}".format(1 - i*0.25))

# Build the legend once, after all curves exist, instead of on every iteration.
plt.legend()

In [None]:
# One figure per entry of R: the party's dataset D[0] highlighted, all test clusters
# in grey, and the points in R[reward] drawn with opacity that decreases along the
# list (presumably the order in which they were added - confirm in algorithms.ccr).
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
party_color = plt.get_cmap('Set1')(0*(1/9))
added_rgb = (0.21568627450980393, 0.49411764705882355, 0.7215686274509804)

for reward in range(len(R)):
    plt.figure(figsize=(12, 6), dpi=300)
    plt.scatter(D[0, :, 0], D[0, :, 1], s=20, color=party_color, label="Party")
    # num_clusters is 20 by this point but test_sets only holds the 5 clusters it was
    # allocated with, so iterate over test_sets itself to avoid an IndexError.
    for i in range(len(test_sets)):
        plt.scatter(test_sets[i, :, 0], test_sets[i, :, 1], s=0.1, color='grey')

    added = np.array(R[reward])
    count = len(R[reward])
    # Opacity runs from 1 for the first point down to 1/count for the last.
    alphas = [1-i*(1/count) for i in range(count)]
    rgba_colors = np.zeros((count,4))
    rgba_colors[:, 3] = alphas
    rgba_colors[:, :3] = added_rgb
    plt.scatter(added[:, 0], added[:, 1], s=20, color=rgba_colors, label="Added")
    plt.legend()
    # The raw string keeps LaTeX "\phi" from being parsed as an invalid "\p" escape.
    plt.title(r"$\phi = {}$".format(1 - reward*0.25))