In [4]:
%load_ext autoreload
%autoreload 2

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
from os.path import join as oj
from copy import deepcopy
import pandas as pd
import numpy.random as npr
import time
sys.path.append('../../src')
sys.path.append('../../interp')
import utils, lcp, train
from scipy.stats import ttest_ind, spearmanr
from typing import Dict
from compare_stats import compare_stats
import gen_data

# sklearn models
from sklearn.model_selection import train_test_split
from sklearn import metrics
from all_scores import get_scores
from style import style_tab

cred = (234 / 255, 51 / 255, 86 / 255)
cblue = (57 / 255, 138 / 255, 242 / 255)
out_dir = '../../results/interp_sim'
os.makedirs(out_dir, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
def get_data(seed=15, sim_num=1):
    np.random.seed(seed)
    # generate data
    d = 10
    n = 1000
    var_eps = 0.1

    if sim_num == 1:
        var_eps = 1
    elif sim_num == 2:
        var_eps = 10
    elif sim_num == 3:
        d = 100
    beta = np.zeros(d)
    beta[0] = 1
    beta[1] = 2       
        
        
    # pick beta
    X, y, _ = gen_data.gen_gaussian_linear_data(n=n, d=d, beta=beta, var_eps=var_eps, 
                 s=None, shift_type='None', shift_val=0.1, logistic=True)
    return train_test_split(X, y), beta # split defaults to 0.75: 0.25 split

In [None]:
sim_nums = list(range(2))
num_points = 1
tabs_list = []
class_weights = [0.5, 1.0, 2.0]
for point_num in tqdm(range(num_points)):
    
    vals_list = []
    for sim_num in sim_nums:
        (X_train, X_test, y_train, y_test), beta = get_data(sim_num=sim_num)

        # train and get importance scores
        ms = train.train_models(X_train, y_train, 
                                class_weights=class_weights, model_type='logistic')
        scores = get_scores(ms, X_train, X_test[point_num], mode='classification')
        # pd.DataFrame(scores).style.background_gradient(cmap='viridis')
        
        # dataset of statistics based on importance scores
        ds = compare_stats(beta, {k: scores[k] for k in scores if not 'std' in k})

        # record accuracy of model
        ds['r2'] = {'': metrics.r2_score(ms[1].predict(X_test), y_test)}
        ds['std'] = {k: scores[k] for k in scores if 'std' in k}

        ind_tuples = [(outerKey, innerKey) for outerKey, innerDict in sorted(ds.items()) for innerKey, values in sorted(innerDict.items())]
        ind = pd.MultiIndex.from_tuples(ind_tuples, names=['score', 'metric'])
        vals = np.array([values for outerKey, innerDict in sorted(ds.items()) for innerKey, values in sorted(innerDict.items())])
        vals_list.append(vals.reshape(-1, vals.size))
    vals = np.array(vals_list).squeeze()
    tab = pd.DataFrame(vals, columns=ind, index=[f'Sim {str(i)}' for i in sim_nums])
    tabs_list.append(tab)
tab = pd.concat(tabs_list).groupby(level=0).mean().round(decimals=2)

In [20]:
style_tab(tab)

score,ice-contrib,ice-contrib,ice-contrib,ice-sensitivity,ice-sensitivity,ice-sensitivity,lime,lime,lime,r2,shap,shap,shap,std,std,std,std
metric,Fraction Correct Signs,Fraction Intersect,Rank Corr,Fraction Correct Signs,Fraction Intersect,Rank Corr,Fraction Correct Signs,Fraction Intersect,Rank Corr,Unnamed: 10_level_1,Fraction Correct Signs,Fraction Intersect,Rank Corr,ice-contrib_std,ice-sensitivity_std,lime_std,shap_std
Sim 0,0.5,0.5,1,1,1,1,0.5,1,-1,0.31,0.5,1,1,0.17,0.05,0.17,0.18
Sim 1,0.5,0.5,1,1,1,1,0.5,1,-1,0.11,0.5,1,1,0.17,0.05,0.17,0.17


# convert to html/latex

In [None]:
vals = vals.set_properties(**{'text-align': 'center'})
html = vals.render()
with open(oj(out_dir, 'table_classification.html'), 'w') as f:
    f.write(html)

In [None]:
s = tab.transpose().to_latex(multicolumn_format='c')
s = s.replace('\\toprule', '').replace('\\bottomrule', '')
print(s)