In [1]:
import pandas as pd
import numpy as np
from IPython.display import Javascript, display
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import ipywidgets as widgets
from ipywidgets import interact, interactive, interactive_output, fixed, interact_manual, Layout
import io

import scipy.stats as st
import statsmodels.stats.api as sms
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import shap
from sklearn.preprocessing import KBinsDiscretizer
from copy import deepcopy

In [2]:
#### 0.0 Data Load ####

In [3]:
# NOTE(review): read_pickle can execute arbitrary code on load -- only open trusted files.
raw_data = pd.read_pickle('./top_60_data.pkl').reset_index(drop=True)
# Working copy: filtering rebinds `data`; `raw_data` stays pristine for filters/resets.
data = deepcopy(raw_data)

In [4]:
# NOTE(review): read_pickle can execute arbitrary code on load -- only open trusted files.
# SHAP rows are matched to feature rows by index label (`.loc[mask]`) throughout the notebook,
# and `top_features` pairs `shaps.values` with `data` positionally, so give the SHAP frame the
# same fresh 0..n-1 index as `raw_data`. Assumes both pickles store rows in the same order.
raw_shaps = pd.read_pickle('./top_60_shap_df.pkl').reset_index(drop=True)
assert len(raw_shaps) == len(raw_data), 'SHAP frame and feature frame have different row counts'
shaps = deepcopy(raw_shaps)

In [5]:
# Split feature columns by dtype. Some continuous features are stored with the 'category'
# dtype; the drop lists below prune unwanted columns from each group.

cont_idxs = data.select_dtypes(exclude='category').columns
cat_idxs = data.select_dtypes(include='category').columns

In [6]:
cat_drops = [
    'prev views COMME DES GARCONS COMME DES GARCONS',
    'prev views GROUND ZERO', 
    'prev views DIESEL RED TAG',
    'num checkout past week', 
    'prev views NIKE KIDS',
    'prev views AMIR SLAMA',
    'prev views Boat Shoes'
]

In [7]:
cat_idxs = cat_idxs.drop(cat_drops)

In [8]:
cont_drops = [
    'land basket qty',
    'avg prev unique action per view',
]

In [9]:
cont_idxs = cont_idxs.drop(cont_drops)

In [10]:
#### 1.0 Widgets ####

In [11]:
cont_dd = widgets.Dropdown(
    options=cont_idxs,
    value=cont_idxs[0],
    description='',
    disabled=False,
)

In [12]:
cat_dd = widgets.Dropdown(
    options=cat_idxs,
    value=cat_idxs[0],
    description='',
    disabled=False,
)

In [13]:
y_axis_checkbox = widgets.Checkbox(
    value=False,
    description='Same y-axis scales',
    disabled=False,
    indent=False

)

In [14]:
# x-axis range for the continuous plot, starting at [0, max] of the first continuous feature;
# `min_max_slider_range` rescales it when another feature is chosen.
initial_cont_max = data[cont_dd.value].max()
min_max_slider = widgets.FloatRangeSlider(
    value=[0, initial_cont_max],
    min=0,
    max=initial_cont_max,
    step=0.1,
    description='',
    disabled=False,
    continuous_update=False,  # redraw only when the handle is released
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

In [15]:
# Outlier cut-off (in standard deviations) for the continuous plot; passed as `z_val`.
z_val_tb = widgets.FloatText(value=3.0, description='z score:', disabled=False)

In [16]:
cat_subsetter = widgets.Dropdown(
    options=cat_idxs,
    value=cat_idxs[0],
    description='',
    disabled=False,
)

In [17]:
subsetter_values = widgets.Dropdown(
    options=raw_data[cat_subsetter.value].unique(),
    description='',
    disabled=False,
)

In [18]:
subset_button = widgets.Button(
    description='Filter Results',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
)

In [19]:
reset_button = widgets.Button(
    description='Reset Data',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
)

In [20]:
#### 2.0 Functions ####

In [65]:
def reset_data(event):
    
    global data
    global shaps
    data = deepcopy(raw_data)
    shaps = deepcopy(raw_shaps)
    size_out.clear_output()
    
    with size_out:
        print('No subset selected')
        
    with important_out:
        important_out.clear_output()
        interact(top_features, max_display=widgets.IntText(value=10, description='Plot top:', disabled=False))
        
    with cat_table_output:
        cat_table_output.clear_output()
        cat_table_interactive = widgets.interactive_output(display_cat_summary, {'col':cat_dd})
        display(cat_table_interactive)
        
    with cat_vis_output:
        cat_vis_output.clear_output()
        cat_vis_interactive = widgets.interactive_output(cat_vis, {'col':cat_dd, 'sharey':y_axis_checkbox})
        display(cat_vis_interactive)
    
    with cont_output:
        cont_output.clear_output()
        cont_vis_interactive = widgets.interactive_output(
            cont_vis, 
            {'col':cont_dd, 'min_max':min_max_slider, 'z_val':z_val_tb}
        )
        display(cont_vis_interactive)

In [66]:
reset_button.on_click(reset_data)

In [21]:
def subsetter_options_update(change):
    cat = change.new
    subsetter_values.options = raw_data[cat].unique()

In [22]:
cat_subsetter.observe(subsetter_options_update, names='value')

In [60]:
def filter_results(event):
    mask = raw_data[cat_subsetter.value]==subsetter_values.value
    data_subset = raw_data.loc[mask, :]
    shap_subset = raw_shaps.loc[mask, :]
    global data
    data = data_subset
    global shaps
    shaps = shap_subset
    
    # if latency becomes an issue:
    # try without routing through additional outputs via context managers and just redefining interactive_outputs,
    # i.e. cat_table_interactive.clear_output(), cat_table_interactive = ...
    
    with size_out:
        size_out.clear_output()
        print('{}'.format(data.shape[0]))
    
    with important_out:
        important_out.clear_output()
        interact(top_features, max_display=widgets.IntText(value=10, description='Plot top:', disabled=False))
        
    with cat_table_output:
        cat_table_output.clear_output()
        cat_table_interactive = widgets.interactive_output(display_cat_summary, {'col':cat_dd})
        display(cat_table_interactive)
        
    with cat_vis_output:
        cat_vis_output.clear_output()
        cat_vis_interactive = widgets.interactive_output(cat_vis, {'col':cat_dd, 'sharey':y_axis_checkbox})
        display(cat_vis_interactive)
    
    with cont_output:
        cont_output.clear_output()
        cont_vis_interactive = widgets.interactive_output(
            cont_vis, 
            {'col':cont_dd, 'min_max':min_max_slider, 'z_val':z_val_tb}
        )
        display(cont_vis_interactive)

    

In [61]:
subset_button.on_click(filter_results)

In [27]:
def min_max_slider_range(change):
    #col = cont_dd.value
    col = change.new
    vals = data[col]
    vals_outlier_mask = np.abs(st.zscore(vals)) < z_val_tb.value
    vals = vals.loc[vals_outlier_mask]
    #min_max_slider.min = vals.min()
    min_max_slider.max = vals.max()
    min_max_slider.value = (0, vals.max())

In [28]:
cont_dd.observe(min_max_slider_range, names='value')

In [29]:
def cat_summary(col):
    """Summarise how each category of a categorical feature shifts that feature's SHAP values.

    For every distinct value of `data[col]` (the current, possibly filtered, data), the SHAP
    values of rows in that category are compared with those of all other rows. The comparison
    uses statsmodels' `CompareMeans` with unequal variances (Welch).

    Parameters
    ----------
    col : str
        Column present in both `data` and `shaps`.

    Returns
    -------
    pandas.DataFrame
        Indexed by category and sorted by 'Mean SHAP' ascending. All values are rounded to
        4 decimals. Columns:
        - 'Mean SHAP': mean SHAP value within the category.
        - 'Average Effect of Feature Value on SHAP': in-category mean minus out-of-category mean.
        - 'CI Lower' / 'CI Upper': confidence interval of that difference (statsmodels' default
          alpha).
        - 'p value': p value of the t-test on that difference.
        The interval and p value are NaN when either side has fewer than two rows, e.g. when
        the data was filtered on `col` itself.
    """
    summary_columns = [
        'Mean SHAP',
        'Average Effect of Feature Value on SHAP',
        'CI Lower',
        'CI Upper',
        'p value',
    ]
    var = data[col]
    col_shaps = shaps[col]

    labels = []
    rows = []
    for cat in var.unique():
        # `==` never matches NaN, so a missing-value category needs an explicit isna() mask.
        in_cat = var.isna() if pd.isna(cat) else var == cat
        cat_shaps = col_shaps.loc[in_cat]
        other_shaps = col_shaps.loc[~in_cat]

        cat_mean = cat_shaps.mean()
        # effect size: shift of this category's mean SHAP relative to everyone else
        effect = cat_mean - other_shaps.mean()

        # statistical inference (confidence interval and p value) for the difference in means;
        # it needs at least two observations on each side to estimate the variances
        if len(cat_shaps) > 1 and len(other_shaps) > 1:
            cm_total = sms.CompareMeans(sms.DescrStatsW(cat_shaps), sms.DescrStatsW(other_shaps))
            ci_lower, ci_upper = cm_total.tconfint_diff(usevar='unequal')
            p_value = cm_total.ttest_ind(usevar='unequal')[1]
        else:
            ci_lower = ci_upper = p_value = np.nan

        labels.append(cat)
        rows.append([round(float(v), 4) for v in (cat_mean, effect, ci_lower, ci_upper, p_value)])

    # Explicit columns keep `sort_values` working even when there are no categories.
    summary = pd.DataFrame(rows, index=labels, columns=summary_columns)
    return summary.sort_values('Mean SHAP', ascending=True)

In [30]:
def display_results(func):
    """Decorator: call `func` and show its return value with IPython's `display`.

    The wrapper returns None (the value is displayed, not returned). `functools.wraps` keeps
    the wrapped function's `__name__`, `__doc__` and `__wrapped__`, so help() and
    introspection still describe the original function.
    """
    import functools

    @functools.wraps(func)
    def inner(*args, **kwargs):
        results = func(*args, **kwargs)
        display(results)
    return inner

In [31]:
@display_results
def display_cat_summary(*args, **kwargs):
    """Render the `cat_summary` table inline; all arguments are forwarded unchanged."""
    summary_table = cat_summary(*args, **kwargs)
    return summary_table

In [32]:
def cat_vis(col, sharey):
    """Plot one histogram of SHAP values per category of a categorical feature.

    Panels are ordered like the `cat_summary` table (ascending mean SHAP). In each panel the
    full histogram is drawn in blue, the positive SHAP values are overdrawn in red, and a dashed
    line marks SHAP = 0. The -999 placeholder category is titled 'NULL'.

    Parameters
    ----------
    col : str
        Categorical column present in both `data` and `shaps`.
    sharey : bool
        Forwarded to `sns.FacetGrid`; True puts every panel on the same y-axis scale.
    """
    col_shaps = shaps[col]
    col_data = data[col]
    row_order = cat_summary(col).index

    joined_instance = pd.concat([col_data, col_shaps], axis=1)
    joined_instance.columns = [col, col + ' SHAP']
    g = sns.FacetGrid(joined_instance, row=col, aspect=5, height=2, sharey=sharey, row_order=row_order)

    for ax_row, row in zip(g.axes, g.row_names):
        ax = ax_row[0]
        ax.axvline(0, lw=1, ls='dashed', c='black')
        title_value = 'NULL' if row == -999 else row
        ax.set_title('{} = {}'.format(col, title_value))

        # `==` never matches NaN, so a NaN category needs an explicit isna() mask.
        in_row = col_data.isna() if pd.isna(row) else col_data == row
        shap_subset = col_shaps.loc[in_row]
        if shap_subset.empty:
            continue

        lo, hi = shap_subset.min(), shap_subset.max()
        if lo == hi:
            # All values identical: a single unit-wide bin centred on them.
            bins = np.array([lo - 0.5, hi + 0.5])
        else:
            # ~sqrt(n) bin edges, but never fewer than 2 (i.e. at least one bin); the original
            # int(sqrt(n)) gave a single edge for groups of 1-3 rows.
            bins = np.linspace(lo, hi, max(int(np.sqrt(shap_subset.size)), 2))
        ax.hist(shap_subset, bins=bins, histtype='bar', color=sns.xkcd_rgb["windows blue"], lw=0)
        ax.hist(shap_subset[shap_subset > 0], bins=bins, histtype='bar', color=sns.xkcd_rgb["light red"], lw=0)

    g.set_xlabels('SHAP values for {}'.format(col), **label_kwargs)
    plt.show()


In [33]:
def cont_vis(
    col,
    z_val=3,
    min_max=None
    ):
    """Scatter a continuous feature's values against its SHAP values, with marginal plots.

    Points are coloured on a diverging palette centred on SHAP = 0, and a dashed line marks
    SHAP = 0. If nothing is left after filtering, a message is printed instead of a plot.

    Parameters
    ----------
    col : str
        Continuous column present in both `data` and `shaps`.
    z_val : float
        Points whose absolute z-score is >= this value are not plotted. The z-scores are
        computed after removing the -999 NULL placeholder.
    min_max : (float, float) or None
        (lower, upper) x-axis range, both ends inclusive. None means no range limit. A lower
        bound of 0 (the range slider's floor) is treated as "no lower bound", so negative
        feature values are not hidden.
    """
    vals = data[col]
    # -999 is the NULL placeholder; remove it first so it does not distort the z-scores.
    vals = vals.loc[vals != -999]
    if not vals.empty:
        z_scores = np.abs(np.asarray(st.zscore(vals), dtype=float))
        # `~(z >= limit)` keeps a constant feature (all z-scores NaN) instead of dropping it.
        vals = vals.loc[~(z_scores >= z_val)]

    # Previously `... else None` raised TypeError when no range was given.
    min_val, max_val = min_max if min_max else (None, None)
    if max_val is not None:
        # inclusive, so the maximum itself survives the default (0, max) slider position
        vals = vals.loc[vals <= max_val]
    if min_val:
        vals = vals.loc[vals >= min_val]

    if vals.empty:
        print('No values of {} to plot for the current selection.'.format(col))
        return

    # SHAP values for the surviving rows, aligned by index label.
    col_shaps = shaps.loc[vals.index, col]

    sns.set(**sns_set_kwargs)

    g = sns.JointGrid(x=vals, y=col_shaps, height=8, space=0.4)
    g.fig.set_figwidth(16)

    # TwoSlopeNorm requires vmin < 0 < vmax strictly; pad past zero so a selection whose SHAP
    # values all share one sign no longer raises ValueError.
    shap_norm = mcolors.TwoSlopeNorm(
        vmin=min(col_shaps.min(), -1e-12),
        vcenter=0.,
        vmax=max(col_shaps.max(), 1e-12),
    )

    g = g.plot_joint(
        plt.scatter,
        s=10,
        c=col_shaps.to_numpy(),
        norm=shap_norm,
        cmap=sns.diverging_palette(h_neg=240, h_pos=10, as_cmap=True)
    )

    # NOTE(review): sns.distplot is deprecated in recent seaborn; consider histplot(kde=True).
    try:
        g = g.plot_marginals(sns.distplot, kde=True, color=".5")
    except RuntimeError as err:
        # A zero KDE bandwidth (e.g. identical values) cannot be estimated; draw histograms only.
        if str(err).startswith("Selected KDE bandwidth is 0. Cannot estimate density."):
            g = g.plot_marginals(sns.distplot, kde=False, color=".5")
        else:
            raise

    g.ax_joint.axhline(0, lw=1, ls='dashed', c='black')

    g.ax_joint.set_xlabel(col+' value', **label_kwargs)

    g.ax_joint.set_ylabel("SHAP Values for "+col, **label_kwargs)

    g.ax_joint.set_title("Relationship between "+col+" & SHAP Values", **title_kwargs)

    plt.show()
    

In [34]:
def top_features(max_display):
    """Draw SHAP's bar-style summary plot for the current `shaps`/`data` selection.

    `max_display` caps how many features appear in the chart.
    """
    shap.summary_plot(
        shaps.values,
        features=data,
        max_display=max_display,
        plot_type='bar',
    )

In [35]:
# Shared text styling for axis labels and titles in the plotting functions.
label_kwargs = dict(
    fontstyle='normal',
    color='#585858',
    fontsize=16,
    fontweight='bold',
)

title_kwargs = dict(
    fontstyle='normal',
    fontsize=16,
    fontweight='bold',
)

# Passed to sns.set(...) before each continuous plot.
sns_set_kwargs = dict(
    font_scale=1.2,
    font='Arial',
    rc={'figure.figsize': (21.7, 12.7)},
)

In [36]:
#### 3.0 Output ####

In [37]:
### output capturing ###

In [70]:
# One Output area per dashboard region; the handlers write into them with `with <area>:`.
(
    data_size,
    size_out,
    important_out,
    cat_vis_output,
    cat_table_output,
    cont_output,
) = (widgets.Output() for _ in range(6))

In [39]:
### interactive output widgets ###

In [40]:
cat_vis_interactive = widgets.interactive_output(cat_vis, {'col':cat_dd, 'sharey':y_axis_checkbox})

In [41]:
cat_table_interactive = widgets.interactive_output(display_cat_summary, {'col':cat_dd})

In [42]:
cont_vis_interactive = widgets.interactive_output(
    cont_vis, 
    {'col':cont_dd, 'min_max':min_max_slider, 'z_val':z_val_tb}
)

In [56]:
### initialisation values ###

In [58]:
with size_out:
    print('No subset selected')

In [71]:
with data_size:
    print('{}'.format(raw_data.shape[0]))

In [44]:
with important_out:
    interact(top_features, max_display=widgets.IntText(value=10, description='Plot top:', disabled=False))

In [45]:
with cat_vis_output:
    display(cat_vis_interactive)

In [46]:
with cat_table_output:
    display(cat_table_interactive)

In [47]:
with cont_output:
    display(cont_vis_interactive)

# Subset

Use the dropdowns below to choose a categorical feature and the value to filter the data on.

Click 'Filter Results' once you have made your selection to update the plots.

You can return to analysing the whole dataset using the 'Reset Data' button.

In [55]:
# Filter controls: column picker, value picker and the button that applies them.
filter_row = [
    widgets.Label('Filter data by:'),
    cat_subsetter,
    widgets.Label('Is equal to:'),
    subsetter_values,
    subset_button,
]
widgets.HBox(filter_row)

HBox(children=(Label(value='Filter data by:'), Dropdown(options=('Marketing Channel', 'Navigate Photos', 'Devi…

In [75]:
# Row counts of the full dataset and of the current filtered subset.
size_row = [
    widgets.Label('Raw dataset size:'), data_size,
    widgets.Label('Selected subset size:'), size_out,
]
widgets.HBox(size_row)

HBox(children=(Label(value='Raw dataset size:'), Output(outputs=({'output_type': 'stream', 'text': '83152\n', …

In [64]:
# Display the reset control; clicking it runs `reset_data`, which restores the full dataset.
reset_button

Button(description='Reset Data', style=ButtonStyle())

# Most Important Features

The most important features overall for the selected data.

In [49]:
# Display the feature-importance panel (drawn by `top_features` through `interact`).
important_out

Output()

# Categorical Feature Importances

Use the dropdown to select a categorical feature for plotting. SHAP values above zero increase the predicted likelihood of a single page visit (SPV), whilst those below zero decrease it.

Tick the 'Same y-axis scales' box to plot distributions on the same scale if it aids comparisons.

In [50]:
# Categorical panel: feature picker + shared-y toggle above the summary table and histograms.
cat_controls = widgets.HBox([
    widgets.HBox([widgets.Label('Categorical variables: '), cat_dd]),
    y_axis_checkbox,
])
widgets.VBox([cat_controls, cat_table_output, cat_vis_output])

VBox(children=(HBox(children=(HBox(children=(Label(value='Categorical variables: '), Dropdown(options=('Market…

# Continuous Feature Importances

Continuous features are selected for plotting using the dropdown box; it's possible to 'zoom in' on a specific area of the plot using the range slider.

The z-score sets how much of the distribution is plotted. For example, a z-score of 3 means points more than three standard deviations from the mean are not plotted.

In [51]:
# Continuous panel: feature picker, x-axis range slider and z-score box above the joint plot.
cont_control_row = widgets.HBox([
    widgets.Label('Continuous variables: '),
    cont_dd,
    widgets.Label('x-axis range:'),
    min_max_slider,
    z_val_tb,
])
widgets.VBox([widgets.HBox([cont_control_row]), cont_output])

VBox(children=(HBox(children=(HBox(children=(Label(value='Continuous variables: '), Dropdown(options=('user sp…

In [52]:

# TODO (design):
# - slice data by subsets according to feature values
# - config class to avoid global vars
