In [39]:
%load_ext autoreload
%autoreload 2

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
from os.path import join as oj
from copy import deepcopy
import pandas as pd
import numpy.random as npr
import time
sys.path.append('../../src')
sys.path.append('../../interp')
import utils, lcp, train
from scipy.stats import ttest_ind, spearmanr
from typing import Dict
from compare_stats import compare_stats
import gen_data
import imgkit 

# sklearn models
from sklearn.model_selection import train_test_split
from sklearn import metrics
from all_scores import get_scores

cred = (234/255, 51/255, 86/255)
cblue = (57/255, 138/255, 242/255)
out_dir = '../../results/interp_sim'
os.makedirs(out_dir, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
def get_data(seed=15, sim_num=1):
    np.random.seed(seed)
    # generate data
    d = 10
    n = 1000
    var_eps = 0.1

    if sim_num == 1:
        var_eps = 1
    elif sim_num == 2:
        var_eps = 10
    elif sim_num == 3:
        d = 100
    beta = np.zeros(d)
    beta[0] = 1
    beta[1] = 2       
        
        
    # pick beta
    X, y, _ = gen_data.gen_gaussian_linear_data(n=n, d=d, beta=beta, var_eps=var_eps, 
                 s=None, shift_type='None', shift_val=0.1)
    return train_test_split(X, y), beta # split defaults to 0.75: 0.25 split

In [None]:
sim_nums = list(range(3))
num_points = 1
tabs_list = []
for point_num in range(num_points):
    
    vals_list = []
    for sim_num in sim_nums:
        (X_train, X_test, y_train, y_test), beta = get_data(sim_num=sim_num)

        # train and get importance scores
        m = train.regress(X_train, y_train, model_type='linear')
        scores = get_scores(m, X_train, X_test[point_num])

        # pd.DataFrame(scores).style.background_gradient(cmap='viridis')

        # dataset of statistics based on importance scores
        ds = compare_stats(beta, scores)

        # record accuracy of model
        ds['r2'] = {'': metrics.r2_score(m.predict(X_test), y_test)}

        ind_tuples = [(outerKey, innerKey) for outerKey, innerDict in sorted(ds.items()) for innerKey, values in sorted(innerDict.items())]
        ind = pd.MultiIndex.from_tuples(ind_tuples, names=['score', 'metric'])
        vals = np.array([values for outerKey, innerDict in sorted(ds.items()) for innerKey, values in sorted(innerDict.items())])
        vals_list.append(vals.reshape(-1, vals.size))
    vals = np.array(vals_list).squeeze()
    tab = pd.DataFrame(vals, columns=ind, index=[f'Sim {str(i)}' for i in sim_nums])
    tabs_list.append(tab)
tab = pd.concat(tabs_list).groupby(level=0).mean().round(decimals=2)

In [27]:
tab

score,ice-contrib,ice-contrib,ice-contrib,ice-sensitivity,ice-sensitivity,ice-sensitivity,lime,lime,lime,r2,shap,shap,shap
metric,Fraction Correct Signs,Fraction Intersect,Rank Corr,Fraction Correct Signs,Fraction Intersect,Rank Corr,Fraction Correct Signs,Fraction Intersect,Rank Corr,Unnamed: 10_level_1,Fraction Correct Signs,Fraction Intersect,Rank Corr
Sim 0,0.59,0.98,-0.24,1.0,1.0,1.0,0.5,1.0,-0.06,0.98,0.52,1.0,0.04
Sim 1,0.59,0.95,-0.24,1.0,1.0,1.0,0.5,1.0,-0.06,0.81,0.52,0.93,0.04
Sim 2,0.59,0.86,-0.22,1.0,1.0,1.0,0.5,1.0,-0.06,-0.8,0.52,0.82,0.04


In [48]:
from visualize import background_gradient, cm

def style_tab(tab):
    vals = tab
    # vals = vals.drop(('r2', ''), axis=1)
    vals = vals.style.applymap(lambda val : 'color: black')

    importances = ['ice-contrib', 'ice-sensitivity', 'lime', 'shap']
    vals = vals.apply(background_gradient, axis=None, 
                            cmap=cm, cmin=0, cmax=1)
    vals = vals.apply(background_gradient, axis=None, 
                            cmap=cm, cmin=-1, cmax=1,
                            subset=[(imp, 'Rank Corr') for imp in importances])
    return vals

# convert to html/latex

In [49]:
# filter out things
s = tab

s = s.drop(('r2', ''), axis=1)
s = s.drop(('Sim 2'), axis=0)
s = style_tab(s)

In [50]:
vals = s.set_properties(**{'text-align': 'center'})
html = vals.render()
with open(oj(out_dir, 'table_regression.html'), 'w') as f:
    f.write(html)

In [None]:
# print table
s = s.to_latex(multicolumn_format='c')
s = s.replace('\\toprule', '').replace('\\bottomrule', '')