In [None]:
import numpy as np
import sys
sys.path.append('../tools')
from utils import get_shaps
from scipy.stats import sem
import seaborn as sns

%load_ext autoreload
%autoreload 2

In [None]:
# Dataset configuration.
task, dataset_name = 'classification', 'breast_cancer'
num_datapoints = 20
# Dataset split settings (presumably the ones used by the experiment scripts
# that produced the result files loaded further down — confirm).
dargs = dict(
    n_data_to_be_valued=num_datapoints + 2,
    n_val=100,
    n_test=1000,
    seed=2020,
)

In [None]:
# Model / estimator configuration; model_name and num_samples are part of the
# result file names loaded in the next cell.
model_name, metric = 'SVC', 'accuracy'
seed, repeat_num = 2022, 10
num_samples = 300

In [None]:
# Location of the precomputed experiment outputs and the estimators to compare.
path = '../experiment_data/nullity'
methods = ['random', 'stratified', 'owen', 'Sobol', 'kernel', 'active-0', 'active-2', 'active-5', 'active-100']

# Exact marginal contributions, used below as the ground truth.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only use it on trusted, locally generated files.
true_mcs = np.load(f"{path}/nullity_exact_{dataset_name}_{model_name}_{num_samples}.npy", allow_pickle=True)

# One (method, estimated marginal contributions, afs arrays) triple per method.
est_mcs_list = []
for method in methods:
    est_mcs, all_afs = (
        np.load(f"{path}/nullity_mcs_{method}_{dataset_name}_{model_name}_{num_samples}.npy", allow_pickle=True),
        np.load(f"{path}/nullity_afs_{method}_{dataset_name}_{model_name}_{num_samples}.npy", allow_pickle=True),
    )
    est_mcs_list.append((method, est_mcs, all_afs))

In [None]:
# Exact Shapley values: one vector per entry of `true_mcs`, averaged over the
# entries, then rescaled so that the values sum to 1.
true_shaps = np.asarray([np.asarray(get_shaps(mc)) for mc in true_mcs]).mean(axis=0)
true_shaps = true_shaps / sum(true_shaps)
print(true_shaps)

In [None]:
# Boolean mask of "near-null" players: those whose normalized exact Shapley
# value has magnitude below NULL_TOL. The nullity error in the next cell is
# measured only on the players selected here.
NULL_TOL = 1e-2
idx = np.abs(true_shaps) < NULL_TOL
# An empty mask would make every nullity error exactly 0 and the comparison meaningless.
assert idx.any(), f"no player has |Shapley value| < {NULL_TOL}; nullity error would be identically 0"
print(true_shaps[idx])

In [None]:
# Nullity error: for every method and repetition, the L1 distance between the
# normalized estimated and exact Shapley values, restricted to the near-null
# players selected by the boolean mask `idx`.
# Guard against re-running this cell after `idx` has been reassigned to
# something other than the per-player boolean mask.
assert np.asarray(idx).dtype == bool and len(idx) == len(true_shaps), \
    "`idx` must be the boolean null-player mask over true_shaps"

n_repeats = len(est_mcs_list[0][1])
all_errors = np.zeros((len(est_mcs_list), n_repeats))
for row, (method, est_mcs, _) in enumerate(est_mcs_list):
    # A mismatched repetition count would leave zeros in the row or overflow it.
    assert len(est_mcs) == n_repeats, \
        f"{method}: expected {n_repeats} repetitions, got {len(est_mcs)}"
    errors = all_errors[row]  # row view: writes go straight into all_errors
    for rep, est_mc in enumerate(est_mcs):
        shaps = np.asarray(get_shaps(est_mc))
        shaps = shaps / sum(shaps)  # normalize to sum 1, same as true_shaps
        errors[rep] = np.abs(shaps[idx] - true_shaps[idx]).sum()
    print(f"evaluated {method} => deviation: {np.mean(errors)} \\pm {sem(errors)}")

In [None]:
# Project plotting helpers live one directory up. Append the path only once so
# re-running this cell does not keep growing sys.path.
if '../' not in sys.path:
    sys.path.append('../')
from vol_utils.utils import set_up_plotting
# The returned object is used below through the pyplot-style API
# (figure, xticks, ylabel, savefig).
plt = set_up_plotting()

In [None]:
# exclude KernelSHAP
idx = [i for i in range(len(methods)) if i != methods.index('kernel')]
all_errors_rest = all_errors[idx]

In [None]:
# Box plot of the nullity error per estimator, with KernelSHAP excluded.
# Each method is colored by its position in `methods` (C0, C1, ...). C4,
# KernelSHAP's slot, is skipped, and every plotted box gets its own color.
palette = [f"C{i}" for i, method in enumerate(methods) if method != 'kernel']
xlabels = ['MC', 'strat.', 'Owen', 'Sobol', r'$\alpha=0$', r'$\alpha=2$', r'$\alpha=5$', r'$\alpha=100$']
assert len(xlabels) == len(palette) == all_errors_rest.shape[0], \
    "need exactly one tick label and one color per plotted method"

plt.figure(figsize=(8, 6))
sns.set_style(style='white')
sns.boxplot(data=all_errors_rest.T, showfliers=False, palette=palette)
plt.xticks(range(len(xlabels)), xlabels, rotation=30)
plt.ylabel("Error")
plt.xlabel("   ")  # blank label; presumably reserves space below the rotated ticks
plt.savefig(f"../figs/nullity_{dataset_name}_{model_name}_{num_datapoints}.pdf", format="pdf", dpi=300, bbox_inches='tight')

### Pigou–Dalton Principle

In [None]:
# Pair each method name with the afs array from its first repetition.
afs_methods = [(method, afs_runs[0]) for method, _, afs_runs in est_mcs_list]

In [None]:
def log_nash_social_welfare(fs_list):
    """Return the *negated* log Nash social welfare of an allocation vector.

    The allocation is first rescaled to have mean 1, so the result does not
    change when every entry is multiplied by the same positive constant. The
    function then returns ``-sum(log(f_i))`` over the rescaled entries.

    For strictly positive inputs, the AM-GM inequality makes the result >= 0.
    It equals 0 exactly when all entries are equal, so lower values mean a
    more equal allocation.

    Parameters
    ----------
    fs_list : array-like of float
        Allocation values, one per player (list, tuple or ndarray). Entries
        should be strictly positive. Zero or negative entries make the
        logarithm -inf/nan; numpy warns rather than raising.

    Returns
    -------
    float
        ``-sum(log(len(fs) * f_i / sum(fs)))``.
    """
    # asarray lets plain Python sequences work; `list / float` would raise TypeError.
    fs = np.asarray(fs_list)
    normalized = fs / sum(fs) * len(fs)
    return -sum(np.log(normalized))

In [None]:
# compute log nash social welfare
for method, all_mcs, all_afs_list in est_mcs_list:
    log_nsw = np.asarray([log_nash_social_welfare(all_afs) for all_afs in all_afs_list])
    print("%s, %.3f (%.3f)" % (method, log_nsw.mean(), sem(log_nsw)))