In [None]:
import sys
import os
from os.path import join as oj

# Work from the notebook's parent directory (presumably the project root, so the
# `algorithms`, `utils` and `run` imports below resolve).
# The starting directory is remembered in a notebook global so that re-running
# this cell does not climb one more level each time.
NOTEBOOK_START_DIR = globals().get('NOTEBOOK_START_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

In [None]:
# Select the GPU through environment variables: order devices by PCI bus id
# and expose only device "1" to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import gpytorch
from tqdm.notebook import trange
import heapq
import math
import pickle
from algorithms.cd import con_div
from algorithms.ccr import con_conv_rate
from utils.class_imbalance import get_classes, class_proportion
from utils.mmd import mmd, perm_sampling

In [None]:
from ast import literal_eval
from utils.utils import tabulate_dict, prepare_loaders, evaluate, init_deterministic, load_dataset
from run import construct_kernel

In [None]:
# gpu_to_use = 1

# # setting device on GPU if available, else CPU
# device = torch.device('cuda:{}'.format(gpu_to_use) if torch.cuda.is_available() else 'cpu')
# print('Using device:', device)
# print()

# #Additional Info when using cuda
# if device.type == 'cuda':
#     print(torch.cuda.get_device_name(0))
#     print('Memory Usage:')
#     print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
#     print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

## Load parameters and dataset

In [None]:
# Experiment dir
load_dir = 'MNIST/N2000-E30-B832'


# Read the kernel architecture hyperparameters.
# Each non-empty line of settings_dict.txt is expected to look like "key : value".
args = {}
with open(oj(load_dir, 'settings_dict.txt'), 'r') as file:
    for line in file:
        line = line.strip()
        if not line:
            # Blank (e.g. trailing) lines carry no setting; unpacking them would raise.
            continue
        key, value = line.split(' : ', 1)

        try:
            args[key] = literal_eval(value)
        except Exception:
            # Not a Python literal (e.g. a bare string): keep the raw text.
            args[key] = value
init_deterministic(args['noise_seed']) # comment this out for faster runtime


# Initialize the kernel, including initializing and loading pretrained weights for the shared feature extractor
kernel, _ = construct_kernel(args)
kernel.eval()

# Load pretrained weights: including the individual MLP layers and the Gpytorch Hyperparameters
# NOTE: strict=False means keys that do not match the model are not treated as errors.
trained_kernel_dir = oj(load_dir, 'trained_kernels', 'model_-E25.pth')
kernel.load_state_dict(torch.load(trained_kernel_dir), strict=False)


# Construct data loaders for a quick evaluation
joint_loader, train_loaders, joint_test_loader, test_loaders = prepare_loaders(args, repeat=False)

test_logs_dir = oj("test_logs_dir", args['dataset'])
os.makedirs(test_logs_dir, exist_ok=True)
if args['include_joint']:
    # Put the joint loader at index 0 so that per-party loaders start at index 1.
    train_loaders = [joint_loader] + train_loaders
    test_loaders = [joint_test_loader] + test_loaders

In [None]:
num_parties = 6  # 5 + joint

# Drain every training loader once and keep each party's inputs and labels
# as single numpy arrays (batches concatenated in loader order).
all_party_datasets = []
all_party_labels = []
for party_idx in range(num_parties):
    image_batches, label_batches = [], []
    for batch_images, batch_labels in train_loaders[party_idx]:
        image_batches.append(batch_images.cpu().numpy())
        label_batches.append(batch_labels.cpu().numpy())

    all_party_datasets.append(np.concatenate(image_batches))
    all_party_labels.append(np.concatenate(label_batches))

In [None]:
# Sanity check: show one training sample of a party together with its label
party = 0
i = 21

sample_image = all_party_datasets[party][i]
# Move axis 0 to the end before handing the array to imshow
plt.imshow(sample_image.transpose(1, 2, 0))
all_party_labels[party][i]

In [None]:
# Same collection as for the training data, but over the test loaders.
# Index 0 of test_loaders is skipped, so entry p here comes from test_loaders[p + 1].
all_test_datasets = []
all_test_labels = []
for party_idx in range(1, num_parties):
    image_batches, label_batches = [], []
    for batch_images, batch_labels in test_loaders[party_idx]:
        image_batches.append(batch_images.cpu().numpy())
        label_batches.append(batch_labels.cpu().numpy())

    all_test_datasets.append(np.concatenate(image_batches))
    all_test_labels.append(np.concatenate(label_batches))

In [None]:
# Sanity check: show one test sample together with its label
party = 2
i = 219

sample_image = all_test_datasets[party][i]
# Move axis 0 to the end before handing the array to imshow
plt.imshow(sample_image.transpose(1, 2, 0))
all_test_labels[party][i]

In [None]:
def get_features(X, model):
    """
    Map raw inputs to the feature space of ``model``.

    The inputs are passed through ``model.get_vae_features`` and then through
    ``model.indi_feature_extractors`` with gradients disabled.

    Args:
        X: array-like accepted by ``torch.tensor`` (the notebook passes numpy arrays).
        model: object exposing ``get_vae_features`` and ``indi_feature_extractors``.

    Returns:
        numpy array holding the extracted features, moved back to the CPU.
    """
    with torch.no_grad():
        # Run on the device the model lives on instead of assuming CUDA.
        try:
            device = next(model.parameters()).device
        except (StopIteration, AttributeError):
            # Model exposes no parameters: fall back to the GPU when there is one.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        X_tens = torch.tensor(X).to(device)
        # Use the model that was passed in (previously the global `kernel` was
        # used for this step regardless of the `model` argument).
        X_feat = model.indi_feature_extractors(model.get_vae_features(X_tens))
    return X_feat.cpu().numpy()

In [None]:
# Push every party's raw inputs (training sets first, then test sets)
# through the trained feature extractor.
all_party_feats = []
for party_inputs in all_party_datasets:
    all_party_feats.append(get_features(party_inputs, kernel))

all_test_feats = []
for party_inputs in all_test_datasets:
    all_test_feats.append(get_features(party_inputs, kernel))

In [None]:
def kernel_wrapper(kernel):
    """
    Build a covariance function that works directly on raw input tensors.

    The returned callable ``wrapper(X, Y=None)`` embeds its arguments with
    ``kernel.get_vae_features`` followed by ``kernel.indi_feature_extractors``
    and evaluates ``kernel.gp_layer.covar_module`` on the embeddings. When
    ``Y`` is omitted the covariance of ``X`` with itself is returned.
    Everything runs under ``torch.no_grad()`` on the GPU.
    """
    def _embed(inputs):
        # Detached GPU copy of the inputs; it goes out of scope on return.
        on_gpu = inputs.clone().detach().cuda()
        return kernel.indi_feature_extractors(kernel.get_vae_features(on_gpu))

    def wrapper(X, Y=None):
        with torch.no_grad():
            X_feat = _embed(X)
            covar = kernel.gp_layer.covar_module
            Y_feat = X_feat if Y is None else _embed(Y)
            retval = covar(X_feat, Y_feat)

            # Release cached GPU blocks held for the temporary input copies.
            torch.cuda.empty_cache()
            return retval

    return wrapper

In [None]:
# Stand-alone scaled RBF kernel that reuses the trained hyperparameters:
# the lengthscale of the first component of the trained base kernel and the
# trained output scale.
trained_covar = kernel.gp_layer.covar_module
k = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
k.base_kernel.lengthscale = trained_covar.base_kernel.kernels[0].lengthscale
# k.base_kernel.lengthscale_prior = kernel.gp_layer.covar_module.base_kernel.kernels[0].lengthscale_prior
# k.base_kernel.distance_module = kernel.gp_layer.covar_module.base_kernel.kernels[0].distance_module
k.outputscale = trained_covar.outputscale

In [None]:
# for i in range(num_parties):
#     print("Party {}".format(i))
#     for j in range(num_parties):
#         print("With party {}: {}".format(j, mmd(all_party_feats[i], all_party_feats[j], k)[0]))  

In [None]:
def unison_shuffled_copies(a, b):
    """
    Shuffle two equally long arrays with one shared random permutation.

    The permutation is drawn from numpy's global random state. Both inputs are
    indexed with it, so element ``i`` of the first result still corresponds to
    element ``i`` of the second. The inputs themselves are not modified.
    """
    n = len(a)
    assert n == len(b)
    order = np.random.permutation(n)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b

# Use test dataset as reference and candidates:
# pool the test data of all parties, shuffle inputs and labels together,
# then convert the shuffled inputs to features.
pooled_test_inputs = np.concatenate(all_test_datasets)
pooled_test_labels = np.concatenate(all_test_labels)
ref_cand_dataset, ref_cand_labels = unison_shuffled_copies(pooled_test_inputs, pooled_test_labels)
ref_cand_feats = get_features(ref_cand_dataset, kernel)

## CD

### All clusters

In [None]:
# Number of experiments per party: one run for each value of phi.
num_exps = 10

In [None]:
# Size of the reference set: as many points as party 1's training set.
num_reference = len(all_party_datasets[1])

In [None]:
# Split the shuffled pool: the first `num_reference` points form the
# reference set, everything after them is the candidate set.
ref_dataset, cand_dataset = ref_cand_feats[:num_reference], ref_cand_feats[num_reference:]
ref_labels, cand_labels = ref_cand_labels[:num_reference], ref_cand_labels[num_reference:]

In [None]:
# One copy of the candidate pool per experiment (new leading axis of length num_exps).
# Built from ref_cand_feats rather than from cand_dataset itself, so re-running
# this cell does not tile an already tiled array.
cand_dataset = np.tile(ref_cand_feats[num_reference:], (num_exps, 1, 1))

In [None]:
# Experiment grid: num_exps evenly spaced phi values in [0.1, 1],
# a constant greed of 4 for every experiment, and eta.
eta = 0.01
phi = np.linspace(0.1, 1, num_exps)
greeds = np.full(num_exps, 4.0)

In [None]:
# Display the phi grid used for the experiments.
phi

In [None]:
# Per-party accumulators for the CD runs (selected points, deltas, mus).
cd_all_res, cd_all_deltas, cd_all_mus = [], [], []

In [None]:
# Start from empty accumulators so that re-running this cell does not append
# a second copy of every party's results.
cd_all_res, cd_all_deltas, cd_all_mus = [], [], []

# Run con_div for every party except index 0; entry p of each accumulator
# therefore belongs to party p + 1.
for i in range(1, num_parties):
    # The party's own features, replicated once per experiment.
    D = np.array([all_party_feats[i]] * num_exps)
    res, deltas, mus = con_div(candidates=cand_dataset, 
                               Y=ref_dataset, 
                               phi=phi, 
                               D=D, 
                               kernel=k,
                               perm_samp_dataset=np.concatenate(all_party_feats[1:]),
                               num_perms=200, 
                               greeds=greeds, 
                               eta=eta)
    cd_all_res.append(res)
    cd_all_deltas.append(deltas)
    cd_all_mus.append(mus)

In [None]:
# Persist the CD results so the analysis below can be redone without re-running con_div.
# The context manager makes sure the file is flushed and closed.
os.makedirs("experiments/results", exist_ok=True)
with open("experiments/results/MNIST-CD-allclusters.p", "wb") as results_file:
    pickle.dump((cd_all_res, cd_all_deltas, cd_all_mus), results_file)

In [None]:
# Reload previously saved CD results.
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open("experiments/results/MNIST-CD-allclusters.p", "rb") as results_file:
    cd_all_res, cd_all_deltas, cd_all_mus = pickle.load(results_file)

In [None]:
# Number of selected points as a function of phi, one figure per party.
for i in range(num_parties-1):
    plt.figure(figsize=(12, 6), dpi=300)
    plt.plot(phi, [len(result) for result in cd_all_res[i]])
    # Raw string: "\p" is not a valid escape sequence in a normal string literal.
    plt.xlabel(r"$\phi$")
    plt.ylabel("Number of points added")
    # cd_all_res[i] was computed for party i+1 (index 0 is skipped in the CD loop),
    # so number the title the same way as the other per-party plots.
    plt.title("Party {}".format(i+1))

In [None]:
# Number of classes passed to class_proportion (one per MNIST digit).
num_clusters=10

In [None]:
# For every party and every phi: class proportions of the party's own labels
# combined with (a) the labels of the points CD selected and (b) the same
# number of uniformly random labels as a baseline.
all_class_props, all_bad_props = [], []
for i in range(num_parties - 1):
    class_props, bad_props = [], []
    # cd_all_res[i] belongs to party i+1
    existing_classes = all_party_labels[i + 1]

    for result in cd_all_res[i]:
        selected_classes = get_classes(np.array(result), cand_dataset[0], cand_labels)
        class_props.append(
            class_proportion(np.concatenate([selected_classes, existing_classes]), num_clusters))

        random_classes = list(np.random.randint(0, num_clusters, len(result)))
        bad_props.append(
            class_proportion(np.concatenate([random_classes, existing_classes]), num_clusters))

    all_class_props.append(class_props)
    all_bad_props.append(bad_props)

In [None]:
# Serif text and a matching math font for all following figures.
plt.rcParams.update({
    "font.family": "serif",
    'mathtext.fontset': 'dejavuserif',
})

In [None]:
# Second component of class_proportion's result versus phi, CD against the random baseline.
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
spectral = plt.get_cmap('Spectral')
for i in range(num_parties-1):
    plt.figure(figsize=(3, 3), dpi=300)
    plt.plot(phi, [prop[1] for prop in all_class_props[i]], label="CD", color=spectral(0.9))
    plt.plot(phi, [prop[1] for prop in all_bad_props[i]], label="Random sampling", color=spectral(0.1))
    # Raw strings: "\p" is not a valid escape sequence in a normal string literal.
    plt.xlabel(r"$\phi$", fontsize=16)
    plt.ylabel(r"$\rho$", fontsize=16)
    plt.legend()
    plt.title("Party {} (MNIST digits {}, {})".format(i+1, i*2, i*2+1))

In [None]:
# Class proportions of the points CD added, without the party's own labels.
# The previous version also appended to `bad_props` here. That name is only a
# leftover from the earlier class-proportion cell (the same list object as
# all_bad_props[-1]), so each run silently grew that list and consumed random
# numbers for values nothing reads.
all_added_props = []
for i in range(num_parties-1):
    class_props = []
    res = cd_all_res[i]
    for result in res:
        class_props.append(class_proportion(get_classes(np.array(result), cand_dataset[0], cand_labels), num_clusters))
    all_added_props.append(class_props)

In [None]:
# Inspect the proportions of the added points for cd_all_res[1] (i.e. party 2).
all_added_props[1]

In [None]:
# Pearson correlation between phi and the plotted proportion (index 1 of each
# class_proportion result), per party, then averaged over parties.
all_corrcoef = []
for i in range(num_parties-1):
    rho_values = [entry[1] for entry in all_class_props[i]]
    all_corrcoef.append(np.corrcoef(phi, rho_values)[0, 1])
print("Average correlation coefficient: {}".format(np.mean(all_corrcoef)))

## CCR

In [None]:
# CCR experiment grid: reference size taken from party 1's training set,
# num_exps evenly spaced phi values in [0.1, 1].
num_reference = len(all_party_datasets[1])
num_exps = 10
phi = np.linspace(0.1, 1, num_exps)

In [None]:
# Reference set = first `num_reference` points of the shuffled pool;
# the remaining points are the candidates, replicated once per experiment.
ref_dataset, ref_labels = ref_cand_feats[:num_reference], ref_cand_labels[:num_reference]
cand_labels = ref_cand_labels[num_reference:]
cand_dataset = np.tile(ref_cand_feats[num_reference:], (num_exps, 1, 1))

In [None]:
# Per-party accumulators for the CCR runs (selected points, deltas, mus).
ccr_all_res, ccr_all_deltas, ccr_all_mus = [], [], []

In [None]:
# Start from empty accumulators so that re-running this cell does not append
# a second copy of every party's results.
ccr_all_res, ccr_all_deltas, ccr_all_mus = [], [], []

# Run con_conv_rate for every party except index 0; entry p of each
# accumulator therefore belongs to party p + 1.
for i in range(1, num_parties):
    # The party's own features, replicated once per experiment.
    D = np.array([all_party_feats[i]] * num_exps)
    res, deltas, mus = con_conv_rate(candidates=cand_dataset, 
                                     Y=ref_dataset, 
                                     phi=phi, 
                                     D=D, 
                                     kernel=k)
    ccr_all_res.append(res)
    ccr_all_deltas.append(deltas)
    ccr_all_mus.append(mus)

In [None]:
# Persist the CCR results so the analysis below can be redone without re-running con_conv_rate.
# The context manager makes sure the file is flushed and closed.
os.makedirs("experiments/results", exist_ok=True)
with open("experiments/results/MNIST-CCR-allclusters.p", "wb") as results_file:
    pickle.dump((ccr_all_res, ccr_all_deltas, ccr_all_mus), results_file)

In [None]:
# Reload previously saved CCR results.
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open("experiments/results/MNIST-CCR-allclusters.p", "rb") as results_file:
    (ccr_all_res, ccr_all_deltas, ccr_all_mus) = pickle.load(results_file)

In [None]:
# One figure per party: the recorded mus of each experiment against the
# 1-based position in the sequence, drawn for every second phi value only.
# plt.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.
spectral = plt.get_cmap('Spectral')
for j in range(num_parties-1):
    mus = ccr_all_mus[j]
    x = list(range(1, len(mus[0])+1))
    plt.figure(figsize=(3, 3), dpi=300)

    for i in range(len(mus)):
        # round() rather than int(): phi comes from linspace, and truncating a
        # value such as 3.9999999999999996 would select the wrong curves.
        if round(phi[i]*10) % 2 == 0:
            # No 'C0' format string: it also sets a colour and clashes with `color=`.
            # The label is derived from phi instead of a hard-coded list of values.
            plt.plot(x, mus[i], linewidth=1, color=spectral(phi[i]),
                     label=r"$\phi = ${:.1f}".format(phi[i]))

    plt.legend()
    plt.title("Party {} (MNIST digits {}, {})".format(j+1, j*2, j*2+1))
    # Raw strings: "\c" and "\p" are not valid escape sequences in normal string literals.
    plt.ylabel(r"$z(D \cup R_i)$", fontsize=16)
    plt.xlabel(r"$|R_i|$", fontsize=16)

In [None]:
# Number of classes passed to class_proportion (one per MNIST digit).
num_clusters=10

In [None]:
# For every party, compare two position-weighted sums across the experiments:
#   * Es              - built from the CCR deltas
#   * class_prop_AUCs - built from the class proportion (index 1 of
#                       class_proportion's result) after each added point
# Entry j of a sequence of length n is weighted by (n - 1 - j); the last entry
# is dropped. The per-party correlation between the two is averaged at the end.
all_Es = []
all_class_prop_AUCs = []
all_corrcoeff = []

for party in range(num_parties-1):
    R = ccr_all_res[party]
    deltas = ccr_all_deltas[party]
    num_candidate_points = cand_dataset.shape[1]
    # ccr_all_res[party] belongs to party + 1
    own_labels = all_party_labels[party+1]
    position_weights = np.arange(num_candidate_points-1, 0, -1)

    # Class proportion after adding the first j+1 selected points to the party's own labels
    class_props = []
    for exp in range(num_exps):
        classes = get_classes(np.array(R[exp]), cand_dataset[0], cand_labels)
        class_props.append([
            class_proportion(np.concatenate((classes[:j+1], own_labels)), num_clusters)[1]
            for j in range(num_candidate_points)
        ])

    Es = [np.sum(np.array(deltas[exp])[:-1] * position_weights)
          for exp in range(num_exps)]
    class_prop_AUCs = [np.sum(np.array(class_props[exp])[:-1] * position_weights)
                       for exp in range(num_exps)]

    all_Es.append(Es)
    all_class_prop_AUCs.append(class_prop_AUCs)

    all_corrcoeff.append(np.corrcoef(Es, class_prop_AUCs)[0, 1])

print("Average correlation coefficient: {}".format(np.mean(all_corrcoeff)))