In [51]:
import numpy as np
import torch
import gpytorch
from data.pipeline import get_data_features
from core.reward_calculation import get_v, shapley, get_vN, get_v_is, get_eta_q
from core.utils import norm
from tqdm import tqdm

In [8]:
# Experiment configurations. Each tuple is unpacked positionally into get_data_features.
# Positions read directly in this notebook:
#   [0]  dataset name ('gmm' / 'mnist')
#   [2]  feature dimension d (passed as `d` to get_se_kernel)
#   [3]  passed to shapley(); the coalition tables below have 5 parties
#   [-2] split mode ('equaldisjoint' / 'unequal')
#   [-1] gamma (printed as "gamma{}")
# The remaining positions are forwarded unchanged; see get_data_features for their meaning.
configs = [
    dataset_params + (split_mode, 0)
    for dataset_params in (
        ('gmm', 5, 2, 5, 1000, 10000),
        ('mnist', 10, 16, 5, 10000, 40000),
    )
    for split_mode in ('equaldisjoint', 'unequal')
]

In [32]:
def median_heuristic(data, max_points=None, seed=None):
    """
    Median of the pairwise Euclidean distances between the rows of `data`.

    Used in this notebook as the lengthscale of the SE kernel.

    :param data: array of shape (n, d); a 1-D array of length n is treated as (n, 1).
    :param max_points: if given and smaller than n, compute distances on a random subset
        of this many rows instead. The full computation stores n*(n-1)/2 float64 values
        (about 32 GB for n = 90000); a subset of m rows needs m*(m-1)/2.
        None (default) uses every row, exactly as before.
    :param seed: seed for the subset selection; only used when `max_points` subsamples.
    :return: (median, norms). `norms` is a float64 array of every pairwise distance
        ||x_i - x_j|| with i < j, in row-major pair order (0,1), (0,2), ..., (1,2), ...
    :raises ValueError: if fewer than two rows remain, since no pairwise distance exists.
    """
    data = np.asarray(data)
    if data.ndim == 1:
        data = data.reshape(-1, 1)
    n = data.shape[0]

    if max_points is not None and max_points < n:
        rng = np.random.default_rng(seed)
        # Sorted so the kept rows stay in their original relative order.
        keep = np.sort(rng.choice(n, size=max_points, replace=False))
        data = data[keep]
        n = max_points

    if n < 2:
        # Without this check np.median of an empty array silently returns nan.
        raise ValueError("median_heuristic needs at least 2 points, got {}".format(n))

    norms = np.zeros(n * (n - 1) // 2)
    idx = 0
    # Row i is paired with every later row; the last row has no later partner.
    for i in tqdm(range(n - 1)):
        current_norms = np.linalg.norm(data[i:i + 1] - data[i + 1:], axis=1)
        norms[idx:idx + len(current_norms)] = current_norms
        idx += len(current_norms)
    return np.median(norms), norms

In [46]:
def get_se_kernel(lengthscale, d, outputscale=1):
    """
    Build a scaled squared-exponential (RBF) gpytorch kernel with ARD over `d` dimensions.

    :param lengthscale: a scalar lengthscale shared by all `d` dimensions, or a sequence
        of `d` per-dimension lengthscales.
    :param d: number of input dimensions (`ard_num_dims` of the RBF kernel).
    :param outputscale: output scale of the wrapping ScaleKernel (default 1).
    :return: gpytorch.kernels.ScaleKernel wrapping an ARD RBFKernel.
    :raises ValueError: if `lengthscale` cannot be broadcast to length `d`.
    """
    kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=d))
    # A scalar broadcasts to "same lengthscale in every dimension".
    # A length-d sequence is used per dimension.
    lengthscales = np.broadcast_to(np.asarray(lengthscale, dtype=float), (d,))
    kernel.base_kernel.lengthscale = lengthscales.tolist()
    kernel.outputscale = outputscale
    return kernel

## SE kernel with the median-heuristic lengthscale

In [56]:
# Run on the first GPU when one is available; otherwise fall back to CPU.
# A hard-coded 'cuda:0' fails on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [57]:
# For each configuration, fit an SE kernel whose lengthscale is the median pairwise
# distance of the reference set. Print coalition values, Shapley values and their
# normalised form (alpha).
for config in configs:
    # Named views of the positional config fields used here; the full tuple is still
    # forwarded unchanged to get_data_features.
    dataset_name, d, n_parties = config[0], config[2], config[3]
    split_mode, gamma = config[-2], config[-1]
    print("Stats for {}-{}-gamma{}".format(dataset_name, split_mode, gamma))
    party_datasets, party_labels, reference_dataset, candidate_datasets, candidate_labels = get_data_features(*config)

    # Keep only the median. The second return value holds all n*(n-1)/2 pairwise
    # distances (about 4e9 floats for the 90000-point MNIST reference sets). Binding it
    # to a name kept it alive while the next iteration allocated another copy.
    med = median_heuristic(reference_dataset)[0]
    print("Median lengthscale: {}".format(med))

    kernel = get_se_kernel(med, d)

    v = get_v(party_datasets, reference_dataset, kernel, device=device, batch_size=128)
    print("Coalition values:\n{}".format(v))
    phi = shapley(v, n_parties)
    print("Shapley values:\n{}".format(phi))
    alpha = norm(phi)
    print("alpha:\n{}".format(alpha))

Stats for gmm-equaldisjoint-gamma0


100%|██████████| 15000/15000 [00:04<00:00, 3339.17it/s]


Median lengthscale: 0.2987956604080148
Coalition values:
{'{1}': 0.1612558601115266, '{2}': 0.3409984767494998, '{3}': 0.521881311035312, '{4}': 0.09681005060487369, '{5}': 0.4577612178991164, '{1, 2}': 0.492373820604909, '{1, 3}': 0.49437374319394256, '{1, 4}': 0.4823997120005754, '{1, 5}': 0.44695926834617095, '{2, 3}': 0.4847544716837512, '{2, 4}': 0.44012500181403935, '{2, 5}': 0.5478277583911825, '{3, 4}': 0.5066218362641931, '{3, 5}': 0.5255465447019498, '{4, 5}': 0.4744616131172803, '{1, 2, 3}': 0.5402080548926744, '{1, 2, 4}': 0.5622810833569096, '{1, 2, 5}': 0.5542919816232108, '{1, 3, 4}': 0.5726259937870141, '{1, 3, 5}': 0.5251799817691436, '{1, 4, 5}': 0.5443839161378425, '{2, 3, 4}': 0.5295906000731383, '{2, 3, 5}': 0.5457637881581776, '{2, 4, 5}': 0.5504541942262788, '{3, 4, 5}': 0.5500074885327094, '{1, 2, 3, 4}': 0.585043934077701, '{1, 2, 3, 5}': 0.5627217935060447, '{1, 2, 4, 5}': 0.5889335928979118, '{1, 3, 4, 5}': 0.5778771244496737, '{2, 3, 4, 5}': 0.56765292445699

100%|██████████| 15000/15000 [00:05<00:00, 2946.87it/s]


Median lengthscale: 0.2987956604080148


  0%|          | 12/90000 [00:00<13:04, 114.78it/s]

Coalition values:
{'{1}': 0.5938009897454339, '{2}': 0.5936672958403694, '{3}': 0.46333125969915145, '{4}': 0.5578179238115365, '{5}': 0.5015951779552444, '{1, 2}': 0.5937990109153629, '{1, 3}': 0.5615393294919617, '{1, 4}': 0.5845026031492045, '{1, 5}': 0.5707831780274703, '{2, 3}': 0.56035161449981, '{2, 4}': 0.5846519232584221, '{2, 5}': 0.5717182733615772, '{3, 4}': 0.5710063998999032, '{3, 5}': 0.5832093712699838, '{4, 5}': 0.5625113642328525, '{1, 2, 3}': 0.5791066971492863, '{1, 2, 4}': 0.5896142155438461, '{1, 2, 5}': 0.5840153761862878, '{1, 3, 4}': 0.5836936841009059, '{1, 3, 5}': 0.589266676417537, '{1, 4, 5}': 0.5797749433473216, '{2, 3, 4}': 0.5832470301423869, '{2, 3, 5}': 0.5891692558923022, '{2, 4, 5}': 0.5802717606449168, '{3, 4, 5}': 0.5937960200161139, '{1, 2, 3, 4}': 0.5878855366666043, '{1, 2, 3, 5}': 0.5913008539865158, '{1, 2, 4, 5}': 0.5861314148171489, '{1, 3, 4, 5}': 0.5938198926164228, '{2, 3, 4, 5}': 0.5938107794673495, '{1, 2, 3, 4, 5}': 0.5938261732508391}

100%|██████████| 90000/90000 [05:05<00:00, 294.92it/s] 


Median lengthscale: 52.74210044729543


  0%|          | 15/90000 [00:00<10:21, 144.78it/s]

Coalition values:
{'{1}': 0.5927954793882254, '{2}': 0.562948739960665, '{3}': 0.5902514606890453, '{4}': 0.5936531420291831, '{5}': 0.5858464681297961, '{1, 2}': 0.5945742229218517, '{1, 3}': 0.6072820788427044, '{1, 4}': 0.6040226027015335, '{1, 5}': 0.607673434841876, '{2, 3}': 0.6019023482008421, '{2, 4}': 0.6026792865493289, '{2, 5}': 0.604635674326433, '{3, 4}': 0.6034820822773167, '{3, 5}': 0.5975038974803051, '{4, 5}': 0.6024590783075494, '{1, 2, 3}': 0.6076709910915178, '{1, 2, 4}': 0.6061896763680882, '{1, 2, 5}': 0.6095491823202183, '{1, 3, 4}': 0.609160774797752, '{1, 3, 5}': 0.6089938040502756, '{1, 4, 5}': 0.6093694859840709, '{2, 3, 4}': 0.6094890583812289, '{2, 3, 5}': 0.6085690012500905, '{2, 4, 5}': 0.6107386451792889, '{3, 4, 5}': 0.604892129045851, '{1, 2, 3, 4}': 0.6110295526150055, '{1, 2, 3, 5}': 0.6119126456325369, '{1, 2, 4, 5}': 0.6121055962236617, '{1, 3, 4, 5}': 0.6102874748332906, '{2, 3, 4, 5}': 0.6115781154343584, '{1, 2, 3, 4, 5}': 0.6131349182083294}
Sh

100%|██████████| 90000/90000 [05:16<00:00, 283.97it/s] 


Median lengthscale: 54.80023367481321
Coalition values:
{'{1}': 0.61102814485441, '{2}': 0.6111770033964119, '{3}': 0.5971572090420568, '{4}': 0.6020125303596179, '{5}': 0.5975632462335663, '{1, 2}': 0.6112642782321124, '{1, 3}': 0.6073495831451806, '{1, 4}': 0.6088445225866453, '{1, 5}': 0.6084212567285863, '{2, 3}': 0.6075318125081322, '{2, 4}': 0.6088221163196598, '{2, 5}': 0.608133901873044, '{3, 4}': 0.6074352995524652, '{3, 5}': 0.6098769327360573, '{4, 5}': 0.6069650577632716, '{1, 2, 3}': 0.6094689264720928, '{1, 2, 4}': 0.6101673322159146, '{1, 2, 5}': 0.6101677058722879, '{1, 3, 4}': 0.6093688596534538, '{1, 3, 5}': 0.610760276701029, '{1, 4, 5}': 0.6095910474296021, '{2, 3, 4}': 0.6094233523025487, '{2, 3, 5}': 0.6106970144218785, '{2, 4, 5}': 0.6094368359815893, '{3, 4, 5}': 0.6111529082857703, '{1, 2, 3, 4}': 0.6101400421294887, '{1, 2, 3, 5}': 0.6110287408649727, '{1, 2, 4, 5}': 0.6103901677703281, '{1, 3, 4, 5}': 0.6112530218168443, '{2, 3, 4, 5}': 0.6112025315591993, '{

## Sum of three SE kernels with lengthscales (0.1, 1, 10) × median-heuristic lengthscale

In [58]:
# Multipliers applied to the median-heuristic lengthscale. One SE kernel is built per
# multiplier and the kernels are summed: k(0.1*med) + k(med) + k(10*med).
lengthscale_multipliers = (0.1, 1, 10)

for config in configs:
    # Named views of the positional config fields used here; the full tuple is still
    # forwarded unchanged to get_data_features.
    dataset_name, d, n_parties = config[0], config[2], config[3]
    split_mode, gamma = config[-2], config[-1]
    print("Stats for {}-{}-gamma{}".format(dataset_name, split_mode, gamma))
    party_datasets, party_labels, reference_dataset, candidate_datasets, candidate_labels = get_data_features(*config)

    # Keep only the median. The second return value holds all n*(n-1)/2 pairwise
    # distances (about 4e9 floats for the 90000-point MNIST reference sets). Binding it
    # to a name kept it alive while the next iteration allocated another copy.
    med = median_heuristic(reference_dataset)[0]
    print("Median lengthscale: {}".format(med))

    # Build the sum with gpytorch's `+` operator. Appending to an AdditiveKernel's
    # internal `kernels` ModuleList after construction relies on gpytorch internals.
    kernel = get_se_kernel(med * lengthscale_multipliers[0], d)
    for multiplier in lengthscale_multipliers[1:]:
        kernel = kernel + get_se_kernel(med * multiplier, d)

    v = get_v(party_datasets, reference_dataset, kernel, device=device, batch_size=128)
    print("Coalition values:\n{}".format(v))
    phi = shapley(v, n_parties)
    print("Shapley values:\n{}".format(phi))
    alpha = norm(phi)
    print("alpha:\n{}".format(alpha))

Stats for gmm-equaldisjoint-gamma0


100%|██████████| 15000/15000 [00:04<00:00, 3332.34it/s]


Median lengthscale: 0.2987956604080148
Coalition values:
{'{1}': 1.0996315447157834, '{2}': 1.299279148305191, '{3}': 1.4905293723748778, '{4}': 1.0371786879298037, '{5}': 1.4183878933601515, '{1, 2}': 1.4824361753246866, '{1, 3}': 1.4891948094277951, '{1, 4}': 1.469910321954365, '{1, 5}': 1.4370283439334173, '{2, 3}': 1.4758631782479885, '{2, 4}': 1.4294000679148489, '{2, 5}': 1.5466461255864488, '{3, 4}': 1.502292231266351, '{3, 5}': 1.5138446747679952, '{4, 5}': 1.4665342944336592, '{1, 2, 3}': 1.544504065178447, '{1, 2, 4}': 1.565655208869425, '{1, 2, 5}': 1.5607937769996747, '{1, 3, 4}': 1.5798055383970635, '{1, 3, 5}': 1.527969166896223, '{1, 4, 5}': 1.5487437461422249, '{2, 3, 4}': 1.5336927445674309, '{2, 3, 5}': 1.5485799444854993, '{2, 4, 5}': 1.5572751357938517, '{3, 4, 5}': 1.5538432053563536, '{1, 2, 3, 4}': 1.5964468518683015, '{1, 2, 3, 5}': 1.5727748319775814, '{1, 2, 4, 5}': 1.6011791729979894, '{1, 3, 4, 5}': 1.5889852316483184, '{2, 3, 4, 5}': 1.57797325530807, '{1, 

100%|██████████| 15000/15000 [00:04<00:00, 3318.85it/s]


Median lengthscale: 0.2987956604080148


  0%|          | 15/90000 [00:00<10:34, 141.78it/s]

Coalition values:
{'{1}': 1.6080256280789207, '{2}': 1.6078895454091156, '{3}': 1.4492971289991494, '{4}': 1.5591132244792905, '{5}': 1.4932294251585942, '{1, 2}': 1.608369469470193, '{1, 3}': 1.5694336925936625, '{1, 4}': 1.5955386169687644, '{1, 5}': 1.5797912921140458, '{2, 3}': 1.5677894769098586, '{2, 4}': 1.5958766729918175, '{2, 5}': 1.5810547169192608, '{3, 4}': 1.5799330976815131, '{3, 5}': 1.5943373418621825, '{4, 5}': 1.5704370744257046, '{1, 2, 3}': 1.5907953614897419, '{1, 2, 4}': 1.6026789599728632, '{1, 2, 5}': 1.5964130350408183, '{1, 3, 4}': 1.5959095163798218, '{1, 3, 5}': 1.602633013782543, '{1, 4, 5}': 1.5914110729241389, '{2, 3, 4}': 1.595344121271688, '{2, 3, 5}': 1.6024788936887044, '{2, 4, 5}': 1.5921379624777923, '{3, 4, 5}': 1.608576697471172, '{1, 2, 3, 4}': 1.6011945657831435, '{1, 2, 3, 5}': 1.6053887815115786, '{1, 2, 4, 5}': 1.599234732831707, '{1, 3, 4, 5}': 1.6086596030719726, '{2, 3, 4, 5}': 1.608665929691815, '{1, 2, 3, 4, 5}': 1.608703238054911}
Shap

100%|██████████| 90000/90000 [05:18<00:00, 282.39it/s] 


Median lengthscale: 52.74210044729543


  0%|          | 12/90000 [00:00<12:45, 117.59it/s]

Coalition values:
{'{1}': 1.5863275028838808, '{2}': 1.5570638677345585, '{3}': 1.584836800593175, '{4}': 1.5882941686140923, '{5}': 1.5803047884959034, '{1, 2}': 1.5891588556871314, '{1, 3}': 1.6020768789026998, '{1, 4}': 1.5987717506513253, '{1, 5}': 1.602469584337128, '{2, 3}': 1.5967027974381192, '{2, 4}': 1.597478457727335, '{2, 5}': 1.5994562082326977, '{3, 4}': 1.5982961860086748, '{3, 5}': 1.5922145191561932, '{4, 5}': 1.597251551403357, '{1, 2, 3}': 1.6026139952111293, '{1, 2, 4}': 1.6011056352256243, '{1, 2, 5}': 1.6045158259903893, '{1, 3, 4}': 1.6041245322399609, '{1, 3, 5}': 1.6039527597345684, '{1, 4, 5}': 1.6043383428414855, '{2, 3, 4}': 1.6044127697505186, '{2, 3, 5}': 1.603476515831599, '{2, 4, 5}': 1.605675782734332, '{3, 4, 5}': 1.5997348076188591, '{1, 2, 3, 4}': 1.6060559391256053, '{1, 2, 3, 5}': 1.6069530909750518, '{1, 2, 4, 5}': 1.6071478110436888, '{1, 3, 4, 5}': 1.605299710041461, '{2, 3, 4, 5}': 1.6065374768118768, '{1, 2, 3, 4, 5}': 1.6082010309285473}
Shap

100%|██████████| 90000/90000 [04:53<00:00, 307.12it/s] 


Median lengthscale: 54.80023367481321
Coalition values:
{'{1}': 1.6059355066039214, '{2}': 1.6060898659405474, '{3}': 1.5914607713292548, '{4}': 1.5967263115419406, '{5}': 1.592207031886669, '{1, 2}': 1.6062286712345362, '{1, 3}': 1.60217766076544, '{1, 4}': 1.603753041184474, '{1, 5}': 1.6033240059028304, '{2, 3}': 1.602360712063749, '{2, 4}': 1.6037309950865568, '{2, 5}': 1.6030333563551094, '{3, 4}': 1.602296024896106, '{3, 5}': 1.604776725735998, '{4, 5}': 1.6018063839065593, '{1, 2, 3}': 1.604398003597911, '{1, 2, 4}': 1.6051221273262082, '{1, 2, 5}': 1.6051235255043064, '{1, 3, 4}': 1.6043093686565526, '{1, 3, 5}': 1.6057233621996925, '{1, 4, 5}': 1.6045183193269918, '{2, 3, 4}': 1.604363775485991, '{2, 3, 5}': 1.6056583897181043, '{2, 4, 5}': 1.6043621924470821, '{3, 4, 5}': 1.606124713710754, '{1, 2, 3, 4}': 1.6051102193807594, '{1, 2, 3, 5}': 1.6060136360443709, '{1, 2, 4, 5}': 1.6053492739208868, '{1, 3, 4, 5}': 1.6062422579276328, '{2, 3, 4, 5}': 1.6061905519237185, '{1, 2, 