In [17]:
%load_ext autoreload
%autoreload 2

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
from os.path import join as oj
from copy import deepcopy
import pandas as pd
import numpy.random as npr
import time
sys.path.append('../../src')
sys.path.append('../../interp')
import utils, lcp, train
from scipy.stats import ttest_ind, spearmanr
from typing import Dict
from compare_stats import compare_stats
import gen_data

# sklearn models
from sklearn.model_selection import train_test_split
from sklearn import metrics
from all_scores import get_scores

cred = (234/255, 51/255, 86/255)
cblue = (57/255, 138/255, 242/255)
out_dir = '../../results/interp_sim'
os.makedirs(out_dir, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
def get_data(seed=15, sim_num=1):
    np.random.seed(seed)
    # generate data
    d = 10
    n = 100
    var_eps = 0.1

    if sim_num == 0:
        n = 1000
    elif sim_num == 1:
        var_eps = 10
    elif sim_num == 2:
        var_eps = 10
    if 0 <= sim_num <= 2:
        beta = np.zeros(d)
        beta[0] = 1
        beta[1] = 2

    if sim_num > 2:
        d = 100
        n = 1000 
        
        beta = np.zeros(d)
        beta[0] = 1
        beta[1] = 2
        
        
        
    # pick beta
    X, y, _ = gen_data.gen_gaussian_linear_data(n=n, d=d, beta=beta, var_eps=var_eps, 
                 s=None, shift_type='None', shift_val=0.1)
    return train_test_split(X, y), beta # split defaults to 0.75: 0.25 split

In [47]:
sim_nums = list(range(2))
num_points = 3
tabs_list = []
for point_num in range(num_points):
    
    vals_list = []
    for sim_num in sim_nums:
        (X_train, X_test, y_train, y_test), beta = get_data(sim_num=sim_num)

        # train and get importance scores

        m = train.regress(X_train, y_train, model_type='linear')
        scores = get_scores(m, X_train, X_test[point_num])

        # pd.DataFrame(scores).style.background_gradient(cmap='viridis')

        # dataset of statistics based on importance scores
        ds = compare_stats(beta, scores)

        # record accuracy of model
        ds['r2'] = {'': metrics.r2_score(m.predict(X_test), y_test)}

        ind_tuples = [(outerKey, innerKey) for outerKey, innerDict in sorted(ds.items()) for innerKey, values in sorted(innerDict.items())]
        ind = pd.MultiIndex.from_tuples(ind_tuples, names=['score', 'metric'])
        vals = np.array([values for outerKey, innerDict in sorted(ds.items()) for innerKey, values in sorted(innerDict.items())])
        vals_list.append(vals.reshape(-1, vals.size))
    vals = np.array(vals_list).squeeze()
    tab = pd.DataFrame(vals, columns=ind, index=[f'dset {str(i)}' for i in sim_nums])
    tabs_list.append(tab)
tab = pd.concat(tabs_list).groupby(level=0).mean().round(decimals=2)

W0802 01:32:58.997391 140437223470912 kernel.py:108] Using 750 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.
  "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \
  "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \
W0802 01:33:04.247501 140437223470912 kernel.py:108] Using 750 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.
  "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \
  "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \
W0802 01:33:09.564643 140437223470912 kernel.py:108] Using 750 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K we

In [48]:
tab.style.background_gradient(cmap='viridis')

score,ice-contrib,ice-contrib,ice-contrib,ice-sensitivity,ice-sensitivity,ice-sensitivity,lime,lime,lime,r2,shap,shap,shap
metric,frac_correct_signs,frac_intersect,rank_corr,frac_correct_signs,frac_intersect,rank_corr,frac_correct_signs,frac_intersect,rank_corr,Unnamed: 10_level_1,frac_correct_signs,frac_intersect,rank_corr
dset 0,0.33,1.0,-1.0,1,1,1,0.0,1,1.0,0.98,0.0,1.0,-1.0
dset 1,1.0,0.33,0.33,1,1,1,0.67,1,0.33,-0.04,0.67,0.5,-0.33


# convert to latex

In [49]:
s = tab.transpose().to_latex(multicolumn_format='c')
s = s.replace('\\toprule', '').replace('\\bottomrule', '')
print(s)

\begin{tabular}{llrr}

     &           &    dset 0 &    dset 1 \\
score & metric &           &           \\
\midrule
ice-contrib & frac\_correct\_signs &  0.333333 &  1.000000 \\
     & frac\_intersect &  1.000000 &  0.333333 \\
     & rank\_corr & -1.000000 &  0.333333 \\
ice-sensitivity & frac\_correct\_signs &  1.000000 &  1.000000 \\
     & frac\_intersect &  1.000000 &  1.000000 \\
     & rank\_corr &  1.000000 &  1.000000 \\
lime & frac\_correct\_signs &  0.000000 &  0.666667 \\
     & frac\_intersect &  1.000000 &  1.000000 \\
     & rank\_corr &  1.000000 &  0.333333 \\
r2 &           &  0.980631 & -0.039831 \\
shap & frac\_correct\_signs &  0.000000 &  0.666667 \\
     & frac\_intersect &  1.000000 &  0.500000 \\
     & rank\_corr & -1.000000 & -0.333333 \\

\end{tabular}

