In [None]:
import numpy as np
import sys
sys.path.append('../tools')
from utils import get_shaps
from scipy.stats import sem

%load_ext autoreload
%autoreload 2

In [None]:
# Experiment identifiers. `dataset_name` and `num_datapoints` are interpolated
# into the result-file names that are loaded and the figure names that are saved.
task = 'classification'  # not referenced in the cells visible here
dataset_name = 'gaussian'
num_datapoints = 50

In [None]:
# Model / evaluation settings.
model_name = 'logistic'  # part of the result-file and figure names
metric = 'accuracy'  # presumably consumed by the experiment runner outside this view — TODO confirm
seed = 2022  # presumably consumed by the experiment runner outside this view — TODO confirm
repeat_num = 10  # number of repetitions the point addition/removal experiment iterates over

In [None]:
# Sampling settings. `num_samples` and `num_bootstrap` are part of the
# result-file names.
num_samples = 2000
num_bootstrap = 20
xi = 1e-3  # not referenced in the cells visible here
# Sampling methods to compare: one saved result set is loaded per entry and
# one curve is plotted per entry.
methods = ['random', 'stratified', 'owen', 'Sobol', 'kernel', 'active-0', 'active-2', 'active-5', 'active-100']

In [None]:
path = '../experiment_data/symmetry' # directory the .npy results are read from; reuses the symmetry experiment's outputs for simplicity

In [None]:
"""
Test out all sampling methods for data shapley
"""

# One (method, all_mcs, all_afs, all_min_afs) tuple per sampling method,
# read from the saved data-Shapley result files.
res_data_shap = []

for method in methods:
    # Shared tail of the three file names for this method / configuration.
    run_tag = f"data_shap_{method}_{dataset_name}_{model_name}_{num_datapoints}_{num_samples}_{num_bootstrap}"
    all_mcs, all_afs, all_min_afs = (
        np.load(f"{path}/small_all_{quantity}_{run_tag}.npy", allow_pickle=True)
        for quantity in ("mcs", "afs", "min_afs")
    )
    res_data_shap.append((method, all_mcs, all_afs, all_min_afs))

In [None]:
# Extend each method's data-Shapley result tuple with the mean and the
# standard error of `all_min_afs`, both taken along axis 0.
for i, item in enumerate(res_data_shap):
    # Read only the first four fields: after this cell has run once the tuples
    # hold six fields, and unpacking the whole tuple would raise on a re-run.
    method, all_mcs, all_afs, all_min_afs = item[:4]
    all_min_afs = np.asarray(all_min_afs)
    all_min_afs_mean = np.mean(all_min_afs, axis=0)
    all_min_afs_sem = sem(all_min_afs, axis=0)
    res_data_shap[i] = (method, all_mcs, all_afs, all_min_afs, all_min_afs_mean, all_min_afs_sem)

In [None]:
"""
Test out all sampling methods for beta shapley
"""
# One (method, all_mcs, all_afs, all_min_afs) tuple per sampling method,
# read from the saved beta-Shapley result files.
res_beta_shap = []

for method in methods:
    # Shared tail of the three file names for this method / configuration.
    run_tag = f"beta_shap_{method}_{dataset_name}_{model_name}_{num_datapoints}_{num_samples}_{num_bootstrap}"
    all_mcs, all_afs, all_min_afs = (
        np.load(f"{path}/small_all_{quantity}_{run_tag}.npy", allow_pickle=True)
        for quantity in ("mcs", "afs", "min_afs")
    )
    res_beta_shap.append((method, all_mcs, all_afs, all_min_afs))

In [None]:
# Extend each method's beta-Shapley result tuple with the mean and the
# standard error of `all_min_afs`, both taken along axis 0.
for i, item in enumerate(res_beta_shap):
    # Read only the first four fields: after this cell has run once the tuples
    # hold six fields, and unpacking the whole tuple would raise on a re-run.
    method, all_mcs, all_afs, all_min_afs = item[:4]
    all_min_afs = np.asarray(all_min_afs)
    all_min_afs_mean = np.mean(all_min_afs, axis=0)
    all_min_afs_sem = sem(all_min_afs, axis=0)
    res_beta_shap[i] = (method, all_mcs, all_afs, all_min_afs, all_min_afs_mean, all_min_afs_sem)

### Point Removal and Point Addition Experiment

In [None]:
# Put the parent directory on the import path so `vol_utils` can be imported
# (assumes the package lives there — TODO confirm).
sys.path.append('../')
from vol_utils.utils import set_up_plotting
# `plt` is the object every plotting call below goes through (plt.figure,
# plt.plot, plt.fill_between, plt.savefig, ...).
# NOTE(review): presumably matplotlib.pyplot with project styling applied —
# confirm in vol_utils.utils.
plt = set_up_plotting()

In [None]:
# Point addition / removal experiment.
#
# For each valuation family ('data' or 'beta') and each combination of
# low/high value and addition/removal, every repetition ranks the points by
# the values returned by `get_shaps`, refits `runner.model` on the selected
# points plus the initial set (X_init, y_init) and records `runner.value()`.
# One figure per combination shows, for every sampling method, the mean over
# repetitions with a +/- standard-error band.

# Fail fast on hidden notebook state: without this check a missing name would
# be swallowed by the broad `except` below and the cell would save empty plots.
required_names = ('X', 'y', 'X_init', 'y_init', 'runner')
missing_names = [name for name in required_names if name not in globals()]
if missing_names:
    raise NameError(f"{missing_names} must be defined before running the point addition/removal experiment")

# `shap_kind` (not `metric`) so the loop does not overwrite the `metric`
# setting defined in the configuration cell.
for shap_kind in ['data', 'beta']:
    res_shap = res_data_shap if shap_kind == 'data' else res_beta_shap
    for value_low in [True, False]:
        for addition in [True, False]:
            steps = range(1, 51) if addition else range(1, 91)
            # vals_all[k][rep][step] -> value for method k, repetition rep.
            vals_all = [[] for _ in res_shap]

            for rep in range(repeat_num):
                for k, entry in enumerate(res_shap):
                    # entry[1] is `all_mcs`; indexing (rather than unpacking a
                    # fixed-length tuple) works whether or not the summary
                    # cells have extended the tuples.
                    all_mcs = entry[1]
                    shaps = np.asarray(get_shaps(all_mcs[rep]))
                    idx = np.argsort(shaps)  # ascending: lowest value first
                    n = len(idx)
                    vals = []

                    for j in steps:
                        # Clamp so that j > n never produces a negative slice
                        # bound, which would silently select the wrong points.
                        count = min(j, n)
                        if addition:
                            # keep only the `count` lowest / highest valued points
                            truncated_idx = idx[:count] if value_low else idx[n - count:]
                        else:
                            # drop the `count` lowest / highest valued points
                            truncated_idx = idx[count:] if value_low else idx[:n - count]
                        try:
                            X_trunc = np.concatenate([X[truncated_idx], X_init])
                            y_trunc = np.concatenate([y[truncated_idx], y_init])
                            runner.model.fit(X_trunc, y_trunc)
                            val = runner.value()
                        except Exception as e:
                            # Best effort: report the failure but keep a NaN
                            # placeholder so later steps stay aligned on the
                            # x-axis and all rows have the same length.
                            print(e)
                            val = np.nan
                        vals.append(val)
                    vals_all[k].append(vals)
            vals_all = np.asarray(vals_all, dtype=float)

            plt.figure(figsize=(8,6))
            for k in range(len(res_shap)):
                vals = vals_all[k]
                # NaN-aware statistics so a failed fit does not blank the curve.
                vals_mean = np.nanmean(vals, axis=0)
                vals_sem = sem(vals, axis=0, nan_policy='omit')
                method = res_shap[k][0]
                # Display names used in the figures.
                if method == 'random':
                    method = 'MC'
                if method == 'owen':
                    method = 'Owen'
                if method.startswith('active'):
                    alpha = int(method.split('-')[-1])
                    method = rf'Ours ($\alpha$ = {alpha})'
                plt.plot(vals_mean, color=f'C{k}', label=method)
                plt.fill_between(np.arange(len(vals_mean)), vals_mean - vals_sem, vals_mean + vals_sem, color=f'C{k}', alpha=0.3)
            # Tick positions are 0-based step indices; labels are point counts.
            if not addition:
                plt.xticks([0,9,19,29,39,49,59,69,79,89], [1,10,20,30,40,50,60,70,80,90])
            else:
                plt.xticks([0,4,9,14,19,24,29,34,39,44,49], [1,5,10,15,20,25,30,35,40,45,50])
            # No legend is drawn: these are the "_no_legend" variants of the figures.
            plt.ylabel("Accuracy")
            plt.xlabel("Number of {} Value Data {}".format("Low" if value_low else "High", "Added" if addition else "Removed"))
            plt.savefig("../figs/point_{}_{}_{}_shap_{}_{}_{}_no_legend.pdf".format(
                            "addition" if addition else "removal", "low" if value_low else "high", shap_kind, dataset_name, model_name, num_datapoints
                            ), format='pdf', dpi=300, bbox_inches='tight')