In [1]:
import pandas as pd
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual, Layout
import io

import scipy.stats as st
import statsmodels.stats.api as sms
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import shap
from sklearn.preprocessing import KBinsDiscretizer


In [2]:
#### 0.0 Data Load ####

In [3]:
# File-upload control restricted to a single file per upload.
uploader = widgets.FileUpload(multiple=False)

In [4]:
# Render the upload widget in this cell's output.
display(uploader)

FileUpload(value={}, description='Upload')

In [None]:
# Candidate source tables (presumably BigQuery addresses -- TODO confirm):
#   `ff-marketing-analytics.analytics.SPV_top_60_dats`
#   ff-marketing-analytics.analytics.spv_shaps

In [None]:
def file_load(source):
    """Choose where the input data is loaded from (unfinished stub).

    Only the 'BigQuery' branch exists so far: it builds the prompt text that
    asks for the table address, but nothing is displayed or loaded yet.

    Parameters
    ----------
    source : str
        Name of the data source; only 'BigQuery' is recognised.

    Returns
    -------
    None
    """
    # Was `if source = 'BigQuery':` -- an assignment in a condition is a
    # SyntaxError, which stopped the whole cell from running.
    if source == 'BigQuery':
        # TODO: show this prompt and read the table. The message ends
        # mid-sentence, so the expected address format is still unspecified.
        disp = 'Please enter full address of data table in the form '

In [5]:
# Feature values: unpickle, then swap the stored index for a fresh 0..n-1 one.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
data = pd.read_pickle('./data/top_60_data.pkl').reset_index(drop=True)

In [6]:
# SHAP values table. Later cells index it with the same column names and row
# labels as `data` (e.g. shaps[col], shaps.loc[idxs, col]).
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
shaps = pd.read_pickle('./data/top_60_shap_df.pkl')

In [7]:
#### 1.0 Widgets ####

In [8]:
# Split the feature columns by dtype: 'category' columns feed the categorical
# dropdown, every other column feeds the continuous one.
# TODO (original note): check dtypes - some conts as cats

cat_idxs = data.columns[data.dtypes == 'category']
cont_idxs = data.columns[data.dtypes != 'category']

In [9]:
# Dropdown listing the non-categorical columns; starts on the first of them.
cont_dd = widgets.Dropdown(
    value=cont_idxs[0],
    options=cont_idxs,
    disabled=False,
    description='',
)

In [10]:
# Dropdown listing the 'category' columns; starts on the first of them.
cat_dd = widgets.Dropdown(
    value=cat_idxs[0],
    options=cat_idxs,
    disabled=False,
    description='',
)

In [11]:
# Toggle passed to cat_vis as `sharey`; unchecked by default.
y_axis_checkbox = widgets.Checkbox(
    description='Same y-axis scales',
    value=False,
    indent=False,
    disabled=False,
)

In [12]:
# Range slider for the x-axis limits of the continuous plot. It starts at
# [0, max of the initially selected column]; the lower bound is pinned to 0
# here rather than to data[cont_dd.value].min().
min_max_slider = widgets.FloatRangeSlider(
    min=0,
    max=data[cont_dd.value].max(),
    value=[0, data[cont_dd.value].max()],
    step=0.1,
    readout=True,
    readout_format='.1f',
    orientation='horizontal',
    continuous_update=False,
    description='',
    disabled=False,
)

In [32]:
# Text box holding the z-score cut-off used to drop outliers (default 3.0).
z_val_tb = widgets.FloatText(
    description='z score:',
    value=3.0,
    disabled=False,
    #layout=Layout(width='33%')
)

In [14]:
# need a limit on number of categories to display

In [15]:
#### 2.0 Functions ####

In [16]:
def min_max_slider_range(change):
    """Re-fit the x-axis range slider to a newly selected continuous feature.

    Meant to be registered as an observer of a widget's ``value`` trait.
    Rows whose absolute z-score is not below the threshold in ``z_val_tb``
    are ignored when working out the new slider bounds.

    Parameters
    ----------
    change : traitlets change notification
        ``change.new`` is the name of the newly selected column of ``data``.

    Returns
    -------
    None
        ``min_max_slider`` is updated in place.
    """
    col = change.new
    vals = data[col]
    # pd.Series(...) gives the mask a fresh 0..n-1 index, so this relies on
    # `data` carrying a default integer index.
    vals_outlier_mask = pd.Series(np.abs(st.zscore(vals)) < z_val_tb.value)
    vals = vals.loc[vals_outlier_mask]
    if vals.empty:
        # Nothing survives the filter (e.g. NaN z-scores); keep the current
        # slider settings instead of pushing NaN bounds into the widget.
        return

    new_min, new_max = vals.min(), vals.max()
    # The slider validates every single assignment and rejects min > max
    # (and max < min), so move first the bound that cannot clash with the
    # range left over from the previous column.
    if new_min > min_max_slider.max:
        min_max_slider.max = new_max
        min_max_slider.min = new_min
    else:
        min_max_slider.min = new_min
        min_max_slider.max = new_max
    # Same starting selection as before: from 0 up to the largest value
    # (the widget clamps ends that fall outside [min, max]).
    min_max_slider.value = (0, new_max)

In [17]:
# Re-fit the range slider whenever a different continuous feature is selected.
cont_dd.observe(min_max_slider_range, names='value')

In [18]:
def cat_summary(col):
    """Summarise the SHAP values of one categorical feature, per category.

    For every distinct value of ``data[col]`` the SHAP values of the rows
    holding that value are compared with the SHAP values of the whole column.

    Parameters
    ----------
    col : str
        Column name present in both ``data`` and ``shaps``.

    Returns
    -------
    pandas.DataFrame
        One row per category, sorted by ascending 'Mean SHAP', with columns
        'Mean SHAP', 'Mean Difference in SHAP to Total' and the lower/upper
        confidence bounds of that difference. All values are rounded to 4
        decimals.

    Notes
    -----
    The confidence interval comes from ``CompareMeans.tconfint_diff`` with
    ``usevar='unequal'``. The "total" sample is the full column and therefore
    contains the category's own rows.
    The p-value of the matching t-test used to be computed here but was never
    reported; that unused call has been dropped.
    """
    # Everything that does not depend on the category is computed once.
    column_values = data.loc[:, col]
    column_shaps = shaps[col]
    mean_shaps = column_shaps.mean()
    total_stats = sms.DescrStatsW(column_shaps)

    stat_dict = {}
    for cat in column_values.unique():
        # indexes for instances where the category is present
        idxs = column_values[column_values == cat].index
        # shaps for instances where the category is present
        shap_subset = shaps.loc[idxs, col]
        # mean shap for instances where the category is present
        mean_shap_subset = shap_subset.mean()
        # difference between mean shap when present to total mean shap
        mean_diff_shaps = mean_shap_subset - mean_shaps
        # confidence interval for the difference in means
        cm_total = sms.CompareMeans(sms.DescrStatsW(shap_subset), total_stats)
        total_diff_conf_int = cm_total.tconfint_diff(usevar='unequal')
        # Dictionary of Summary Stats
        stat_dict[cat] = {
            'Mean SHAP': round(mean_shap_subset, 4),
            'Mean Difference in SHAP to Total': round(mean_diff_shaps, 4),
            'Mean Difference in SHAP to Total (CI Lower)': round(float(total_diff_conf_int[0]), 4),
            'Mean Difference in SHAP to Total (CI Upper)': round(float(total_diff_conf_int[1]), 4),
        }

    summary = pd.DataFrame(stat_dict).T.sort_values('Mean SHAP', ascending=True)

    return summary

In [19]:
def display_results(func):
    """Decorator that shows the wrapped function's return value via ``display``.

    The wrapper forwards all positional and keyword arguments to ``func``,
    displays whatever it returns, and itself returns ``None``.

    Parameters
    ----------
    func : callable
        Function whose result should be displayed.

    Returns
    -------
    callable
        Wrapper carrying ``func``'s name, docstring and signature metadata.
    """
    import functools  # local import keeps this cell self-contained

    # Without wraps() the wrapper was just an anonymous "inner" to tooling
    # that inspects names, docstrings or signatures.
    @functools.wraps(func)
    def inner(*args, **kwargs):
        results = func(*args, **kwargs)
        display(results)
    return inner

In [20]:
@display_results
def display_cat_summary(*args, **kwargs):
    """Build the per-category SHAP table with ``cat_summary``; the decorator displays it."""
    summary_table = cat_summary(*args, **kwargs)
    return summary_table

In [21]:
#@cat_vis_interactive.capture(clear_output=True, wait=True)
def cat_vis(col, sharey):
    """Plot one SHAP-value histogram per category of a categorical feature.

    Each category of ``data[col]`` gets its own facet row, ordered like the
    rows of ``cat_summary(col)``. All SHAP values are drawn in blue and the
    positive ones are drawn again in red on top, with a dashed line at 0.

    Parameters
    ----------
    col : str
        Column name present in both ``data`` and ``shaps``.
    sharey : bool
        Whether all facets share one y-axis scale.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    col_shaps = shaps[col]
    col_data = data[col]
    row_order = cat_summary(col).index

    joined_instance = pd.concat([col_data, col_shaps], axis=1)
    joined_instance.columns = [col, col+' SHAP']
    g = sns.FacetGrid(joined_instance, row=col, aspect=5, height=2, sharey=sharey, row_order=row_order)

    for ax, row in zip(g.axes, g.row_names):
        ax[0].axvline(0, lw=1, ls='dashed', c='black')
        ax[0].set_title('{} = {}'.format(col, row), )
        shap_subset = col_shaps.loc[col_data==row]
        if shap_subset.empty:
            # Nothing to draw (min/max of an empty selection would fail).
            continue
        is_positive = (shap_subset > 0)
        # Square-root rule for the number of bins. linspace returns bin
        # *edges*, so n_bins bins need n_bins + 1 edges; the previous code
        # passed int(sqrt(n)) edges, i.e. one bin too few and no usable bins
        # at all for categories with fewer than 4 rows.
        n_bins = max(int(np.sqrt(shap_subset.size)), 1)
        lowest, highest = shap_subset.min(), shap_subset.max()
        if lowest < highest:
            bins = np.linspace(lowest, highest, n_bins + 1)
        else:
            # All values identical: explicit edges would have zero width, so
            # let matplotlib pick a range around the single value.
            bins = n_bins
        # hist() returns the edges it used; reuse them so both layers line up.
        _, bins, _ = ax[0].hist(shap_subset, bins=bins, histtype='bar',color=sns.xkcd_rgb["windows blue"],lw=0)
        ax[0].hist(shap_subset[is_positive], bins=bins, histtype='bar',color=sns.xkcd_rgb["light red"],lw=0)

    g.set_xlabels('SHAP values for {}'.format(col), **label_kwargs)
    #g.set_ylabels('Count of visits', **label_kwargs)
    #g.set_title("Relationship between {} & SHAP Values".format(col), **title_kwargs)
    #g.set(yticks=[])
    #g.despine(left=True)
    plt.show()


In [22]:
def cont_vis(
    col,
    z_val=3,
    min_max=None
    ):
    """Scatter a continuous feature against its SHAP values, with marginals.

    Rows whose absolute z-score (of the feature value) is not below ``z_val``
    are dropped first, then the optional ``min_max`` range is applied. Points
    are coloured by SHAP value on a diverging palette centred on 0.

    Parameters
    ----------
    col : str
        Column name present in both ``data`` and ``shaps``.
    z_val : float, default 3
        Outlier cut-off on the absolute z-score of the feature values.
    min_max : tuple of (min, max) or None, default None
        Inclusive range of feature values to keep; ``None`` keeps everything.
        Either element may itself be ``None`` to leave that side open.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``; if no rows remain after
        filtering, a short message is printed instead.

    Notes
    -----
    Changes from the previous version:
    * ``min_max=None`` (the default) no longer raises on tuple unpacking;
    * a bound of 0 is applied as a bound (it used to be ignored because the
      test was on truthiness), and both bounds are inclusive so points sitting
      exactly on a slider end are kept;
    * SHAP values lying entirely on one side of 0 no longer break the colour
      normalisation.
    """
    vals = data[col]
    col_shaps = shaps[col]

    # pd.Series(...) gives the mask a fresh 0..n-1 index, so this relies on
    # `data`/`shaps` carrying a default integer index.
    vals_outlier_mask = pd.Series(np.abs(st.zscore(vals)) < z_val)
    vals = vals.loc[vals_outlier_mask]
    col_shaps = col_shaps.loc[vals_outlier_mask]

    min_val, max_val = min_max if min_max else (None, None)

    if max_val is not None:
        ulim = (vals <= max_val)
        vals = vals.loc[ulim]
        col_shaps = col_shaps.loc[ulim]

    if min_val is not None:
        llim = (vals >= min_val)
        vals = vals.loc[llim]
        col_shaps = col_shaps.loc[llim]

    if vals.empty:
        # Nothing left to plot; the grid and colour norm below need data.
        print('No data points left for {} after filtering.'.format(col))
        return

    sns.set(**sns_set_kwargs)

    g = sns.JointGrid(x=vals, y=col_shaps, height=8, space=0.4)
    g.fig.set_figwidth(16)

    # TwoSlopeNorm requires vmin < vcenter < vmax; force the limits to straddle
    # 0 so one-sided SHAP values still get a valid, 0-centred colour scale.
    tiny = 1e-12
    offset = mcolors.TwoSlopeNorm(
        vmin=min(col_shaps.min(), -tiny),
        vcenter=0.,
        vmax=max(col_shaps.max(), tiny)
    )

    # Hand the norm to scatter so colours stay centred on 0 instead of being
    # re-scaled to the observed range of pre-normalised values.
    g = g.plot_joint(
        plt.scatter,
        s=10,
        c=col_shaps.values,
        norm=offset,
        cmap=sns.diverging_palette(h_neg=240, h_pos=10, as_cmap=True)
    )

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases.
    try:
        g = g.plot_marginals(sns.distplot, kde=True, color=".5")
    except RuntimeError as err:
        # Fall back to plain histograms when the KDE cannot be estimated.
        if str(err).startswith("Selected KDE bandwidth is 0. Cannot estimate density."):
            g = g.plot_marginals(sns.distplot, kde=False, color=".5")
        else:
            raise

    g.ax_joint.axhline(0, lw=1, ls='dashed', c='black')

    g.ax_joint.set_xlabel(col+' value', **label_kwargs)

    g.ax_joint.set_ylabel("SHAP Values for "+col, **label_kwargs)

    g.ax_joint.set_title("Relationship between "+col+" & SHAP Values", **title_kwargs)

    plt.show()
    

In [23]:
# Text styling for axis labels.
label_kwargs = dict(
    fontstyle='normal',
    color='#585858',
    fontsize=16,
    fontweight='bold',
)

# Text styling for plot titles.
title_kwargs = dict(
    fontstyle='normal',
    fontsize=16,
    fontweight='bold',
)

# Keyword arguments passed to sns.set() before drawing the continuous plot.
sns_set_kwargs = dict(
    font_scale=1.2,
    font='Arial',
    rc={'figure.figsize':(21.7,12.7)},
)

In [24]:
#### 3.0 Output ####

# Most Important Features

In [48]:
@interact(max_display=widgets.IntText(value=10, description='Plot top:', disabled=False))
def top_features(max_display):
    """Draw the bar-type SHAP summary plot, showing at most ``max_display`` features."""
    shap.summary_plot(shaps.values, data, plot_type='bar', max_display=max_display)

interactive(children=(IntText(value=10, description='Plot top:'), Output()), _dom_classes=('widget-interact',)…

In [25]:
# Output area that calls cat_vis with the current dropdown / checkbox values.
cat_vis_interactive = widgets.interactive_output(cat_vis, {'col':cat_dd, 'sharey':y_axis_checkbox})

In [26]:
# Output area that shows the per-category summary table for the selected feature.
cat_table_interactive = widgets.interactive_output(display_cat_summary, {'col':cat_dd})

# Categorical Feature Importances

In [27]:
# Categorical panel: controls on top, then the summary table, then the plots.
widgets.VBox([
    widgets.HBox([
        widgets.HBox([widgets.Label('Categorical variables: '), cat_dd]),
        y_axis_checkbox,
    ]),
    cat_table_interactive,
    cat_vis_interactive,
])

VBox(children=(HBox(children=(HBox(children=(Label(value='Categorical variables: '), Dropdown(options=('Market…

In [28]:
#cont_vis_interactive = widgets.interactive(cont_vis, col=cont_dd, min_max=min_max_slider, z_val=fixed(5), z_shap=fixed(5))


In [33]:
# Output area that calls cont_vis with the selected feature, x-axis range and z-score cut-off.
cont_vis_interactive = widgets.interactive_output(cont_vis, {'col':cont_dd, 'min_max':min_max_slider, 'z_val':z_val_tb})


# Continuous Feature Importances

In [34]:
# Continuous panel: one row of controls above the joint plot.
widgets.VBox([
    widgets.HBox([
        widgets.HBox([
            widgets.Label('Continuous variables: '),
            cont_dd,
            widgets.Label('x-axis range:'),
            min_max_slider,
            z_val_tb,
        ]),
    ]),
    cont_vis_interactive,
])

VBox(children=(HBox(children=(HBox(children=(Label(value='Continuous variables: '), Dropdown(options=('user sp…

In [31]:

#design builds: 
# slice data by subsets according to feature values
# config class to avoid global vars
