In [None]:
%load_ext autoreload
%autoreload 2

# AI Lab tools/utils

In [None]:
import ailab as lab

# References

### For more info on SHAP (SHapley Additive exPlanations), see:

- https://github.com/slundberg/shap
- http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions
- https://christophm.github.io/interpretable-ml-book/shapley.html

### More about Plotly Dash

- https://dash.plot.ly/

# Read Data

In [None]:
import pandas as pd

# Column names for the Adult (census income) files; the last one, "Income",
# is the label column.
feature_names = [
    "Age", "Workclass", "Final Weight", "Education", "Education-Num",
    "Marital Status", "Occupation", "Relationship", "Race", "Sex",
    "Capital Gain", "Capital Loss", "Hours per week", "Country", "Income",
]

# Names are passed explicitly, so the files are read without a header row.
df_train = pd.read_csv("../../data/adult.data", names=feature_names)

# The first line of adult.test is skipped (presumably a non-data preamble line
# in the raw file - confirm against the file).
df_test = pd.read_csv("../../data/adult.test", skiprows=1, names=feature_names)



# Data prep & problem definition

In [None]:
def df_prep(df, label="Income"):
    """Split a raw Adult frame into a feature frame and a binary income target.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw frame holding the feature columns plus the ``label`` column.
    label : str, default "Income"
        Name of the target column.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
        ``df`` without the ``label`` column, and the target as int64
        (0 for "<=50K", 1 for ">50K"), sharing ``df``'s index.

    Raises
    ------
    ValueError
        If the label column holds a value other than the two known classes.
    """
    # Train labels look like " <=50K"; test labels carry a trailing "." (" <=50K.").
    # Normalise both spellings, then map explicitly so the result is always an
    # integer Series (Series.replace's implicit object->int downcast is deprecated).
    normalised = df[label].astype(str).str.strip().str.rstrip(".")
    y = normalised.map({"<=50K": 0, ">50K": 1})

    unknown = normalised[y.isna()].unique()
    if len(unknown):
        raise ValueError(f"Unexpected {label!r} values: {list(unknown)!r}")

    df_X = df.drop(label, axis=1)
    return df_X, y.astype("int64")

# Same preparation for both splits: features without the "Income" label, plus the 0/1 target.
df_X_train, y_train = df_prep(df_train)
df_X_test, y_test = df_prep(df_test)

# Load trained model

## Load with joblib

In [None]:
import os
import joblib

def load_model(model_id, models_dir="../../models"):
    """Load a model that was saved with joblib.

    Parameters
    ----------
    model_id : str
        File stem of the model; ``<models_dir>/<model_id>.pickle`` is loaded.
    models_dir : str, default "../../models"
        Directory holding the model files, relative to the notebook's working
        directory unless an absolute path is given.

    Returns
    -------
    object
        The unpickled object (in this notebook, a fitted pipeline).

    Notes
    -----
    joblib.load unpickles the file, which can execute arbitrary code; only load
    model files you produced yourself or otherwise trust.
    """
    load_path = os.path.join(models_dir, f"{model_id}.pickle")
    return joblib.load(load_path)


model = load_model("adult_randomforest")
model


## Check that the model is working

In [None]:
from sklearn.metrics import roc_curve, auc

# Sanity-check the loaded model on the held-out split before explaining it.
# Score the feature frame returned by df_prep (label column removed), not the raw
# df_test, so the "Income" label is never fed to the pipeline.
MIN_EXPECTED_AUC = 0.85

y_score = model.predict_proba(df_X_test)[:, 1]
fpr, tpr, _ = roc_curve(y_test, y_score)
roc_auc = auc(fpr, tpr)
print("roc_auc", roc_auc)

assert roc_auc >= MIN_EXPECTED_AUC, (
    f"something is wrong, check that score: auc={roc_auc:.3f} < {MIN_EXPECTED_AUC}"
)

# Calculate Shapley values

## Get transformed feature values/matrix

In [None]:
import sklearn
from copy import deepcopy

# Split the fitted pipeline into preprocessing steps and the final classifier,
# so SHAP can use its fast tree method on the classifier alone.
# *Work in progress* - could we bypass this? Let me know!
# deepcopy so that popping the last step leaves `model` itself untouched.
model_prep = deepcopy(model)
model_clf = model_prep.steps.pop()[1]

# Using 100 rows, just for quick testing;
# more rows add more detail to shap explanations.
SAMPLE_ROWS = 100
# Fixed seed so the sampled rows (and every explanation below) are reproducible on re-run.
RANDOM_STATE = 42

df_explanation = df_test.sample(SAMPLE_ROWS, random_state=RANDOM_STATE)

# Get the transformed set from the raw data (shap needs it instead of the raw data).
# NOTE(review): df_explanation still holds the "Income" column; this relies on the
# preprocessing steps ignoring/dropping it - confirm.
X_explanation_rows = model_prep.transform(df_explanation)
X_explanation_rows


In [None]:
import shap

# Column names of the transformed matrix: the SHAP values refer to these.
feature_names = X_explanation_rows.columns.values

# Fast tree-based Shapley values on the bare classifier.
# Can take a while depending on cores/proc speed
# (100 rows on an i5 laptop, for this dataset, ~10 secs).
shap_explainer = shap.TreeExplainer(model_clf)
raw_shap_values = shap_explainer.shap_values(X_explanation_rows)

# Keep only the positive class (class index 1).
# Older shap releases return a list with one (rows, features) array per class;
# newer releases return a single (rows, features, classes) array, where a plain
# [1] would silently pick the second *row* instead of class 1.
if isinstance(raw_shap_values, list):
    shap_values = raw_shap_values[1]
elif getattr(raw_shap_values, "ndim", 2) == 3:
    shap_values = raw_shap_values[:, :, 1]
else:
    # Single-output model: the array already holds one value per (row, feature).
    shap_values = raw_shap_values

# Data prep for explanations & Dash

In [None]:
# Helper columns that make the app's table sortable/filterable by score.

# Positive-class probability from the bare classifier on the already-transformed rows.
score_prob = model_clf.predict_proba(X_explanation_rows)[:, 1]
df_explanation["SCORE_PROB"] = score_prob

# Ratio to the baseline probability reported by shap_explainer (class index 1):
# values above 1 mean the row scores above the base rate.
df_explanation["SCORE_RATIO"] = df_explanation["SCORE_PROB"] / shap_explainer.expected_value[1]

# Round both helper columns to 2 decimals for display.
score_cols = ["SCORE_PROB", "SCORE_RATIO"]
df_explanation[score_cols] = df_explanation[score_cols].round(2)
df_explanation


In [None]:
# Shap values as a DataFrame (one row per explained record, one column per transformed
# feature), sharing X_explanation_rows' index so .loc lookups line up with df_explanation.
df_shap_values = pd.DataFrame(shap_values, columns=feature_names)
df_shap_values.index = X_explanation_rows.index

# Add the top N positive/negative features to each row as comma-separated text
# (this allows table search, which is cool).
N = 5


def _top_features(row, n, positive):
    """Names of up to ``n`` features whose SHAP value has the requested sign, strongest first."""
    if positive:
        picked = row[row > 0].nlargest(n)
    else:
        picked = row[row < 0].nsmallest(n)
    return ",".join(picked.index.tolist())


# Only strictly positive (negative) contributions are listed, so a row with fewer than
# N positive features does not get zero/negative ones reported as "top positive".
topN = df_shap_values.apply(_top_features, axis=1, args=(N, True))
bottomN = df_shap_values.apply(_top_features, axis=1, args=(N, False))

df_explanation["TOP_POS"] = topN.values
df_explanation["TOP_NEG"] = bottomN.values

df_explanation

## Let's check: local reason codes

In [None]:
from IPython.display import display

# Spot-check the explanation of the first sampled row.
index = df_shap_values.index[0]

# Let long text cells (the comma-separated feature lists) show up to 200 characters.
pd.set_option('display.max_colwidth', 200)
display(df_explanation.loc[[index], ["TOP_POS", "TOP_NEG"]].T)

# A single-row summary plot shows just this row's local SHAP values.
shap.summary_plot(
    df_shap_values.loc[[index]].values,
    X_explanation_rows.loc[[index]].values,
    feature_names=feature_names,
)



## Global Shapley importances

In [None]:
# Global view: SHAP values of every explained row, summarised per feature.
shap.summary_plot(
    df_shap_values.values,
    X_explanation_rows,
    feature_names=feature_names,
)

# Prepare dataframe for Dash

In [None]:
# Copy the row index into a regular column. It is hidden from the table (visible_cols)
# but travels with every table row, so the Dash callbacks can map a selection back
# to the original record.
df_explanation["index"] = df_explanation.index.to_numpy()
df_explanation

# Dash 

In [None]:
import sklearn.metrics as metrics
import plotly.graph_objs as go

def shap_local_importances(row_index, max_features=10):
    """Build a Plotly figure with one row's strongest local SHAP contributions.

    Parameters
    ----------
    row_index
        Index label of the row in ``X_explanation_rows`` / ``df_shap_values``.
    max_features : int, default 10
        Number of features to plot, chosen by largest absolute SHAP value.

    Returns
    -------
    plotly.graph_objs.Figure
        Scatter with the SHAP value on x and a "feature=value" label on y.
        Every point's ``customdata`` is ``"<feature>|<row_index>"`` so a click
        callback can tell which feature of which row was clicked.
    """
    X_row = X_explanation_rows.loc[[row_index], :]
    row_shap = df_shap_values.loc[[row_index]].values[0, :]

    # Friendlier y labels: feature=<feature value for that record> (rounded to 4 decimals).
    # Local names avoid shadowing the notebook-level `feature_names` / `shap_values`.
    labels = X_row.columns + "=" + X_row.iloc[0].values.round(4).astype("str")

    df_row_shap = pd.DataFrame({
        "feature": labels,
        "shapley": row_shap,
        "custom_data": X_row.columns + "|" + str(row_index),
    })
    df_row_shap["abs_shapley"] = df_row_shap["shapley"].abs()

    # Keep the strongest max_features, then order them ascending: Plotly lays out
    # categories bottom-up, so the strongest feature ends up at the top of the chart.
    df_top_features = (
        df_row_shap.sort_values("abs_shapley", ascending=False)
        .head(max_features)
        .sort_values("abs_shapley", ascending=True)
    )

    figure = go.Figure({
        'data': [
            {
                'x': df_top_features.shapley.values,
                'y': df_top_features.feature.values,
                'customdata': df_top_features.custom_data.values,
                'name': 'Local reason codes',
                'mode': 'markers',
                'marker': {'size': 10},
            }
        ],
        'layout': {
            "autosize": False,
            "width": 370,
            "height": 450,
            # Plain dict instead of go.Margin, a deprecated alias of go.layout.Margin.
            "margin": {"l": 150, "r": 50, "b": 40, "t": 0},
        },
    })
    return figure


# Test
shap_local_importances(X_explanation_rows.index[0])

In [None]:
import sklearn.metrics as metrics
import plotly.graph_objs as go

def shap_detail(row_index, feature):
    """Scatter of one feature's value vs. its SHAP value across all explained rows.

    The point belonging to ``row_index`` is highlighted with an annotation
    labelled ``"<feature>=<value>"``.

    Parameters
    ----------
    row_index
        Index label of the row to highlight.
    feature : str
        Column name present in both ``X_explanation_rows`` and ``df_shap_values``.

    Returns
    -------
    plotly.graph_objs.Figure
    """
    # All explained rows: feature value (x) vs. SHAP value (y).
    feature_values = X_explanation_rows[feature].values
    feature_shapley_values = df_shap_values[feature].values

    # Scalar lookups with .loc[row, col]. Chained .loc[row][col] first materialises
    # the whole row as one Series, which can upcast the value (e.g. 39 -> 39.0).
    row_value = X_explanation_rows.loc[row_index, feature]
    row_shapley = df_shap_values.loc[row_index, feature]

    # Friendly title: feature=<feature value for that row>
    title = "{0}={1}".format(feature, row_value)

    figure = go.Figure({
        # All data points
        'data': [
            {
                'x': feature_values,
                'y': feature_shapley_values,
                'name': 'Trace 1',
                'mode': 'markers',
                'marker': {'size': 5},
            },
        ],
        'layout': {
            # Arrow + red label marking the current row's value and position.
            "annotations": [
                dict(
                    x=row_value,
                    y=row_shapley,
                    xref='x',
                    yref='y',
                    text=title,
                    showarrow=True,
                    arrowhead=7,
                    ax=0,
                    ay=-40,
                    bgcolor="red",
                )
            ],
            "autosize": False,
            "width": 250,
            "height": 420,
            # Plain dict instead of go.Margin, a deprecated alias of go.layout.Margin.
            "margin": {"l": 20, "r": 10, "b": 80, "t": 20, "pad": 60},
        },
    })
    return figure


# Test
shap_detail(X_explanation_rows.index[0], "Age")

In [None]:
import dash
from dash.dependencies import Input, Output, State
import dash_core_components as dcc
import dash_html_components as html
import dash_table_experiments as dt
import plotly


APP_NAME = 'SHAP Dash! Explanations on Dash - DevScope AI Lab'
app = dash.Dash(name=APP_NAME)

# Callbacks target summary_<i> / detail_<i> components that only exist after
# "Explain" is clicked, so ids missing from the initial layout must be allowed.
app.config['suppress_callback_exceptions'] = True

# Columns shown in the table: everything except the hidden helper "index" column.
visible_cols = df_explanation.columns.drop(["index"]).values

# Stylesheet set borrowed from https://github.com/plotly/dash-svm/
# NOTE(review): rawgit.com / cdn.rawgit.com announced their shutdown, so the last two
# URLs may no longer resolve - confirm, or host the base/custom styles elsewhere.
external_css = [
    "https://cdnjs.cloudflare.com/ajax/libs/normalize/7.0.0/normalize.min.css",  # CSS normalize
    "https://fonts.googleapis.com/css?family=Open+Sans|Roboto",  # fonts
    "https://maxcdn.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css",  # fonts (icons)
    # Base stylesheet - swap in your own base-styles.css
    "https://rawgit.com/xhlulu/9a6e89f418ee40d02b637a429a876aa9/raw/f3ea10d53e33ece67eb681025cedc83870c9938d/base-styles.css",
    # Custom stylesheet - swap in your own custom-styles.css
    "https://cdn.rawgit.com/plotly/dash-svm/bb031580/custom-styles.css",
    # Alternative custom stylesheet:
    # "https://gist.githubusercontent.com/rquintino/f67a1e9f2c13b9b3e0dae35ac6477295/raw/aa6d5c068882584ab2a019c510c32df80fd8c352/shap-dash-custom-styles.css"
]

for stylesheet_url in external_css:
    app.css.append_css(dict(external_url=stylesheet_url))
    
    
# Page layout: a banner (title link + Dash logo) on top, then one row split into the
# score table with its buttons ("five columns") and the scrollable explanation
# panel ("seven columns").
app.layout = html.Div([
    # Reference: https://github.com/plotly/dash-svm/
    # .container class is fixed, .container.scalable is scalable
    html.Div(className="banner", children=[
        html.Div(className='container scalable', children=[
            html.H2(html.A(
                APP_NAME,
                href='https://github.com/DevScope/ai-lab',
                style={
                    'text-decoration': 'none',
                    'color': 'inherit'
                }
            )),

            html.A(
                html.Img(src="https://s3-us-west-1.amazonaws.com/plotly-tutorials/logo/new-branding/dash-logo-by-plotly-stripe-inverted.png"),
                href='https://plot.ly/products/dash/'
            )
        ]),
    ]),
    html.Div(id='body', className='container scalable', children=[
         html.Div([
            html.Div(
                [
                    # One table row per explained record (the helper "index" column is
                    # hidden via visible_cols). Rows are selectable so the "Explain"
                    # callback can read the selection from this table's state.
                    # NOTE(review): dash_table_experiments is an experimental package that
                    # was later superseded by dash_table - confirm the pinned Dash versions.
                    dt.DataTable(
                        rows=df_explanation.to_dict('records'),
                        editable=False,
                        sortable=True,
                        columns=visible_cols,
                        row_selectable=True,
                        filterable=True,
                        id='score_table'
                        ),
                    # "Explain" renders charts for the selected rows into the "output" div;
                    # "Clear" resets the table selection.
                    html.Button('Explain', id='explain'),
                    html.Button('Clear', id='clear'),
                ],className="five columns"),
             html.Div(
                 [
                        # Explanation charts are written into this div's children by the
                        # explain_selected_rows callback.
                        html.Div(id="output",children=[
                            # Blank placeholder chart: per the original author, without a
                            # Graph present at startup the dynamically added charts didn't
                            # load (Dash tables showed a similar issue).
                            dcc.Graph(
                                id='example-graph',
                                figure={'data': [], 'layout': { }
                            }
                        ),
                        ],style={"padding":"20px"})
                   ],id="results",className="seven columns",style={'height':'500px','overflow-y': 'scroll'})
        ],className="row")
     ])
])

@app.callback(
    Output('score_table', 'selected_row_indices'),
    [Input("clear", "n_clicks")])
def click_clear(n_clicks):
    """Clear the table selection whenever the "Clear" button is clicked.

    ``n_clicks`` only triggers the callback; its value is not used.
    """
    no_rows_selected = []
    return no_rows_selected

@app.callback(
    Output('output', 'children'),
    [Input("explain", "n_clicks")],
    state=[
        State('score_table', 'rows'),
        State('score_table', 'selected_row_indices')
    ])
def explain_selected_rows(n_clicks, rows, selected_row_indices):
    """Build the explanation panel for the rows selected in the score table.

    For each selected row (in ascending position order) this emits a header line
    ("Age <age> (#<n>)" plus the predicted probability), followed by the local SHAP
    chart (id ``summary_<n>``) and an empty placeholder (id ``detail_<n>``) that a
    per-feature detail callback fills on click. ``<n>`` is 1-based.

    Parameters
    ----------
    n_clicks : int or None
        Click count of the "Explain" button (trigger only; value unused).
    rows : list of dict
        Current table rows, each carrying the hidden "index" column.
    selected_row_indices : list of int
        Positions into ``rows`` of the selected rows.

    Returns
    -------
    list
        Dash children for the ``output`` div; empty when nothing is selected.
    """
    # note: currently, the only way I know to get selected indexes:
    # we have to send the whole table back from the browser to the python kernel...
    # https://github.com/plotly/dash-table-experiments/issues/15
    if not rows or not selected_row_indices:
        return []

    # Reconstruct the dataframe with the original indexes (hidden "index" column).
    df_rows = pd.DataFrame(rows)
    df_rows.index = df_rows["index"]

    # sorted() rather than list.sort(): don't mutate the list Dash handed us.
    indexes = df_rows.iloc[sorted(selected_row_indices)].index

    MAX_FEATURES = 12

    children = []
    for i, index in enumerate(indexes, start=1):
        original_index = int(index)
        original_row = df_explanation.loc[[original_index]]

        # The old template " Age {0} (#)" had no placeholder, so the row number was dropped.
        title = " Age {0} (#{1})".format(original_row.Age.iloc[0], i)
        score = "{0:.2%} Probability".format(original_row.SCORE_PROB.iloc[0])

        children.append(html.Div([
            html.H6(title, className="six columns"),
            html.H6(score, className="six columns", style={"text-align": "right"}),
        ], className="row"))

        fig_shap = shap_local_importances(original_index, MAX_FEATURES)

        # summary_<i> / detail_<i> ids are wired to the pre-registered detail callbacks.
        children.append(html.Div([
            dcc.Graph(id='summary_' + str(i), figure=fig_shap,
                      className="six columns",
                      config={'displayModeBar': False}),
            html.Div(id="detail_" + str(i), className="six columns"),
        ], className="row"))
    return children

# Dash needs callbacks registered up front, but the summary_<i> / detail_<i> components
# are created dynamically when "Explain" is clicked, numbered from 1. So one callback is
# pre-registered per possible panel.
# Ok... selecting more than MAX_DETAIL_PANELS rows leaves the extra detail panels
# unwired; have to dig deeper...
MAX_DETAIL_PANELS = 100

for i in range(1, MAX_DETAIL_PANELS + 1):
    @app.callback(
        Output('detail_' + str(i), 'children'),
        [Input('summary_' + str(i), 'clickData')])
    def detail(clickData):
        """Render the per-feature detail chart for a point clicked in a summary chart.

        ``clickData`` is Dash's click payload. The first point's ``customdata`` is
        expected to be ``"<feature>|<row_index>"``, as produced by
        shap_local_importances. Returns an empty list until a point is clicked.
        """
        if not clickData:
            return []

        points = clickData.get("points") or []
        if not points:
            return []

        # Grab feature and row index from customdata, better way? let me know please.
        # rsplit: only the last "|" separates the row index, so a feature name that
        # itself contains "|" stays intact.
        feature, row_index = points[0]["customdata"].rsplit("|", 1)
        row_index = int(row_index)

        fig_shap = shap_detail(row_index, feature)

        return [dcc.Graph(id='detail_chart_' + str(row_index), figure=fig_shap,
                          config={'displayModeBar': False})]



In [None]:
# Serve the Dash app from the notebook through the ailab helper (presumably a
# blocking call while the server runs - confirm in ailab.show_app).
# To stop it and get control back to Jupyter, interrupt the kernel:
# press <esc>, then i, i.
lab.show_app(app)
