In [1]:
# Setting options for the plots:
#   - draw matplotlib figures inline in the notebook output
#   - request both the 'retina' and 'svg' inline figure formats
#   - set the inline backend's `savefig.dpi` rc parameter to 150
%matplotlib inline
%config InlineBackend.figure_formats={'retina', 'svg'}
%config InlineBackend.rc={'savefig.dpi': 150}

test


# Explanation Report

In [None]:
import shap
import pickle
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import os
import time
import pandas as pd
import itertools

from IPython import sys_info
from IPython.display import display, HTML, Image, Javascript, Markdown, SVG
from os.path import abspath, relpath, exists, join
from rsmtool.reader import DataReader
from rsmtool.writer import DataWriter
from rsmtool.utils.files import parse_json_with_comments
from rsmtool.utils.notebook import (float_format_func,
                                    int_or_float_format_func,
                                    compute_subgroup_plot_params,
                                    bold_highlighter,
                                    color_highlighter,
                                    show_thumbnail)

# Switch seaborn to its 'notebook' plotting context for the figures below.
sns.set_context('notebook')

In [None]:
# Resolve the directory that holds the report files: the `RSM_REPORT_DIR`
# environment variable wins, otherwise use the current working directory.
rsm_report_dir = os.environ['RSM_REPORT_DIR'] if 'RSM_REPORT_DIR' in os.environ else os.getcwd()

# The notebook cannot run without its `.environ.json` configuration file,
# so fail early with an actionable message when it is missing.
rsm_environ_config = join(rsm_report_dir, '.environ.json')
if not exists(rsm_environ_config):
    missing_config_message = ('The file {} cannot be located. '
                              'Please make sure that either (1) '
                              'you have set the correct directory with the `RSM_REPORT_DIR` '
                              'environment variable, or (2) that your `.environ.json` '
                              'file is in the same directory as your notebook.')
    raise FileNotFoundError(missing_config_message.format(rsm_environ_config))

environ_config = parse_json_with_comments(rsm_environ_config)

<style type="text/css">
  div.prompt.output_prompt { 
    color: white; 
  }
  
  span.highlight_color {
    color: red;
  }
  
  span.highlight_bold {
    font-weight: bold;  
  }
    
  @media print {
    @page {
      size: landscape;
      margin: 0cm 0cm 0cm 0cm;
    }

    * {
      margin: 0px;
      padding: 0px;
    }

    #toc {
      display: none;
    }

    span.highlight_color, span.highlight_bold {
        font-weight: bolder;
        text-decoration: underline;
    }

    div.prompt.output_prompt {
      display: none;
    }
    
    h3#Python-packages, div#packages {
      display: none;
    }
  }
</style>

In [None]:
# Report settings read from the environment configuration.
experiment_id = environ_config.get('EXPERIMENT_ID')
description = environ_config.get('DESCRIPTION')
explanation_path = environ_config.get('EXPLANATION')
background_size = environ_config.get('BACKGROUND_SIZE')
id_path = environ_config.get('IDs')
csv_path = environ_config.get('CSV_DIR')
fig_path = environ_config.get('FIG_DIR')

# `DISPLAY_NUM` may be absent from the configuration; `int(None)` would raise
# a TypeError, so fall back to the default of 15 features that this report
# documents in its "General SHAP Plots" section.
display_num = environ_config.get('DISPLAY_NUM')
display_features = int(display_num) if display_num is not None else 15

# Load the pickled objects needed for the plots: the shap explanations and
# the row ids.
# NOTE(security): unpickling can execute arbitrary code. Only open report
# directories that come from a trusted source.
with open(explanation_path, 'rb') as pickle_in:
    explanations = pickle.load(pickle_in, encoding='bytes')
with open(id_path, 'rb') as pickle_in:
    ids = pickle.load(pickle_in, encoding='bytes')

# Summary tables of the shap values; paths are built with `join` instead of
# string concatenation so they stay valid regardless of trailing separators.
abs_values = DataReader.read_from_file(join(csv_path, 'abs_shap_values.csv'))
mean_values = DataReader.read_from_file(join(csv_path, 'mean_shap_values.csv'))
max_values = DataReader.read_from_file(join(csv_path, 'max_shap_values.csv'))
min_values = DataReader.read_from_file(join(csv_path, 'min_shap_values.csv'))

# javascript path
javascript_path = environ_config.get("JAVASCRIPT_PATH")

In [None]:
# Read the table-sorting script and inject it into the notebook output.
sort_script_path = join(javascript_path, "sort.js")
with open(sort_script_path, "r", encoding="utf-8") as sort_script:
    sort_source = sort_script.read()
display(Javascript(data=sort_source))

In [None]:
# One-line introduction naming the experiment and its description.
report_intro = 'This report presents the shap explanations for **{}**: {}'.format(experiment_id, description)
Markdown(report_intro)

In [None]:
# Stamp the report with the local date and time at which it was rendered.
report_timestamp = time.strftime('%c')
HTML(report_timestamp)

In [None]:
%%html
<!-- Empty container; the javascript cell at the end of the notebook fills it via $('#toc').html(...) -->
<div id="toc"></div>

## Description of the Data Passed

In [None]:
# Report how many rows were handed to the explainer (one entry per row in `ids`).
row_count = len(ids.keys())
row_count_note = "A total of {} rows of data were passed to the explainer. \n".format(row_count)
display(Markdown(row_count_note))



In [None]:
# List the selected row ids inline only when there are at most 100 of them.
# The threshold is inclusive so that the "(>100)" wording of the fallback
# message is accurate for exactly 100 rows as well.
row_ids = list(ids.values())
if len(row_ids) <= 100:
    display(Markdown("The following row-ids were selected for explanation: \n {}".format(row_ids)))
else:
    display(Markdown("Too many rows (>100) were selected to display a list of all row-ids. Please refer to the 'ids.pkl' file if you need to check which rows were sampled for explanation."))

### Background Distribution

The background distribution is responsible for generating the base value of your shap explanation. The base value should represent the mean model prediction in the data that was passed as background.

By default, rsmexplain generates a kmeans-clustered representation of your background sample, the size of which you can specify with the "background_size" parameter. Generally, a kmeans sample of size 100 or higher is sufficient to generate an accurate base value.


In [None]:
if background_size:
    display(Markdown("A background sample of kmeans size {} was passed to the explainer. A smaller background sample will be faster but less accurate.".format(background_size)))
else:
    display(Markdown("No background size was specified for the background sample. This defaults the background sample to a size of kmeans=500. A smaller background sample will be faster but less accurate."))


## SHAP Values

### A brief introduction to shapley values

SHAP values are generated through the [SHAP library](https://shap.readthedocs.io/en/latest/index.html) and are approximations of [Shapley Values](https://en.wikipedia.org/wiki/Shapley_value), a concept derived from game theory. A very abbreviated explanation of how these values are generated: for every model decision passed to the explainer, the explainer considers, for each feature in turn, how the model decision is impacted by removing that feature. For a more in-depth explanation consider this [summary article](https://towardsdatascience.com/understanding-how-ime-shapley-values-explains-predictions-d75c0fceca5a).

Rsmexplain by default uses the [Sampling](https://shap.readthedocs.io/en/latest/generated/shap.explainers.Sampling.html#shap.explainers.Sampling) explainer model, which computes shap values through random permutations of the features, a method described [here](https://link.springer.com/article/10.1007/s10115-013-0679-x).

The sampling explainer is model agnostic, meaning it should in principle work for any type of model. Rsmexplain currently only supports regressors. 


### How to read shap values

Shap values are additive representations of a feature's impact on a model decision. The sum of all shap values and the base value for a prediction should yield the actual model output.

A shap value for a feature can be considered that feature's contribution to the decision during that specific prediction. By calculating an absolute mean of all shap values of a feature, we can calculate an average impact for the data that was passed to the explainer. Absolute mean shap values are saved in "/output/mean_shap_values.csv".



### Things to consider

Rsmexplain can only generate shap values for the data passed in the "explainable_data" and "range" parameters. If the dataset passed is small, then the values derived cannot be considered representative of the model as a whole. Plots that display mean values for your shap values should be taken with a grain of salt if your passed data was small, or not representative of the typical data the model deals with.

As long as a sufficiently large background set was passed, the individual values for predictions can be considered trustworthy.

If you wish to investigate your shap values by hand, please refer to files in "/output/".

If you wish to use the generated shap Explanation object, you may unpickle "explanation.pkl". Your initial row ids are stored in "ids.pkl" in a dictionary format of \{array index: actual index\}.

### An overview over your shap values

This is a quick text overview over your shap values. Please refer to the Plots section for visualizations.

All values are rounded to $10^{-3}$ unless specified otherwise.

#### Absolute Mean Shap Values

The top 5 features in terms of absolute mean impact were:

In [None]:
# Show the first five rows of `abs_values`, or every row when there are five
# or fewer. The feature count is the number of ROWS (`shape[0]`); `shape[1]`
# is the number of columns and says nothing about how many features exist.
# NOTE(review): the slice relies on the CSV already being sorted by absolute
# mean shap value in descending order -- confirm against the writer.
if abs_values.shape[0] > 5:
    table = HTML(abs_values[0:5].to_html(classes=['sortable'], index=False, float_format=float_format_func))
else:
    display(Markdown("Your model has 5 or less features. Displaying all:"))
    table = HTML(abs_values.to_html(classes=['sortable'], index=False, float_format=float_format_func))
    

In [None]:
# Render the table assembled in the preceding cell.
display(table)


__The following features have an absolute mean shap value of 0:__

In [None]:
# Collect the rows whose absolute mean shap value is exactly 0 and show at
# most 10 of them. `table` stays None when there is nothing to show.
table = None
try:
    value0 = mean_values.loc[mean_values['abs. mean shap'] == 0]
    zero_count = value0.shape[0]
    if zero_count == 0:
        display(Markdown("No features with a mean value of 0 found."))
    elif zero_count <= 10:
        table = HTML(value0.to_html(classes=['sortable'], index=False, float_format=float_format_func))
    else:
        display(Markdown("You have over 10 features with an absolute mean shap value of 0. Displaying 10 only."
                         " Please check your mean_shap_values.csv file for all the features."))
        table = HTML(value0[0:10].to_html(classes=['sortable'], index=False, float_format=float_format_func))
except Exception as error:
    # Keep the report rendering (best effort), but surface the cause instead
    # of hiding it; `Exception` lets KeyboardInterrupt/SystemExit propagate.
    display(Markdown("Something went wrong with your feature table. ({}: {})".format(type(error).__name__, error)))
    table = None

In [None]:
# Render the table assembled in the preceding cell (no output when it is None).
table

If features appear in the above list with a mean shap value of 0, then those features did not contribute to the model decisions. If the data set passed was large and representative of the data the model usually encounters, then this may mean that those features are not useful for the model.

Before you draw conclusions, make sure that those features were not simply set to 0 in all data instances that were passed to the model. This might accidentally create this effect.

The following features are the __bottom 10 features that have an absolute mean shap value of >0__, ranked by abs. mean shap value. The table includes the absolute mean shap value of each feature and the absolute max and min values of that feature. 
__Rounding is disabled for this table in order to avoid values appearing as 0.__

In [None]:
# Keep only the rows with a non-zero absolute mean shap value.
value_nonzero = abs_values.loc[abs_values['abs. mean shap'] != 0]

# The feature count is the number of ROWS (`shape[0]`), not columns. With more
# than 10 features show the last 10 non-zero rows; `[-10:]` includes the final
# row, which the previous `[-11:-1]` slice silently dropped.
# NOTE(review): "bottom" relies on the CSV being sorted by absolute mean shap
# value in descending order -- confirm against the writer.
if abs_values.shape[0] > 10:
    table = HTML(value_nonzero[-10:].to_html(index=False, classes=['sortable']))
else:
    display(Markdown("Your model has 10 or less features. Displaying all:"))
    table = HTML(value_nonzero.to_html(index=False, classes=['sortable']))
    

In [None]:
# Render the table assembled in the preceding cell.
table


#### Absolute Max Shap Values

Here are the top 10 features in terms of absolute maximal impact:

In [None]:
# Rank the rows by absolute maximum shap value, largest first.
ranked_by_max = abs_values.sort_values(by=['abs. max shap'], ascending=False)

# The feature count is the number of ROWS (`shape[0]`), not columns. With more
# than 10 features show the top 10, matching both the section text and the
# "10 or less" message; otherwise show every row.
if ranked_by_max.shape[0] > 10:
    table = HTML(ranked_by_max[0:10].to_html(classes=['sortable'],
                                             float_format=float_format_func,
                                             index=False))
else:
    display(Markdown("Your model has 10 or less features. Displaying all:"))
    table = HTML(ranked_by_max.to_html(index=False,
                                       classes=['sortable'],
                                       float_format=float_format_func))

In [None]:
# Render the table assembled in the preceding cell.
table

If the features in the above list do not overlap with the top 5 in terms of absolute mean impact, then these features have high outlier values, but less overall average impact.

## General SHAP Plots

These are general shap plots that cover all the data passed. By default, these plots only display the top 15 features according to their ranking metric. This number can be adjusted in the config file by adding a `"display_num"` parameter with an integer value.

In [None]:
# Echo the number of features that the plots below are limited to.
display_num_note = "`display_num` set to: {}".format(display_features)
display(Markdown(display_num_note))

### Heatmap Plot

This plot offers a condensed high-level overview over the data passed. It presents a plot with data instances on the x-axis, the model decisions on the y-axis, and the SHAP values encoded on a color scale. By default the samples are ordered based on hierarchical clustering by their explanation similarity. This results in samples that have similar model output for similar reasons getting grouped together.

Features are ranked by mean absolute impact, meaning the highest feature in this plot has the highest average impact on the model decisions given the data passed.


In [None]:
# Draw the heatmap with show=False so the figure is still open when it is
# written to the figure directory.
path = join(fig_path, 'heatmap_cluster')
shap.plots.heatmap(explanations, max_display=display_features, show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

#### Prediction value ordered heatmap plot

This heatmap plot has its x-axis sorted in descending order of the model output value. Starting at the highest output value down to the lowest output value.

This plot can be useful to spot features that display counter-intuitive behaviors or clustering. We expect the feature colors (which represent the shap value) to be on a gradient if they correlate with the model output. If the colors instead display clusters, then the feature does not necessarily correlate with the output.

In [None]:
# Order the instances by the sum of their shap values along axis 1, then draw
# and save the heatmap.
output_order = explanations.sum(1)
path = join(fig_path, 'heatmap_output_ordered')
shap.plots.heatmap(explanations,
                   instance_order=output_order,
                   max_display=display_features,
                   show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

### Global Bar Plot

This plot gives a quick overview over the shap values of the data passed. Features are ranked by mean absolute impact.

The number to the right of the bar represents the mean absolute shap value of that feature.

The higher the mean shap value of your feature is, the higher the average contribution of that feature to a model decision is.


In [1]:
# Draw the global bar plot with show=False so it can be saved to disk.
path = join(fig_path, 'global_bar')
shap.plots.bar(explanations, max_display=display_features, show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

NameError: name 'shap' is not defined

### Beeswarm Plot

The beeswarm plot gives an information-dense overview over your shap-values. Each row of data (i.e. model-decision) is represented by a dot on the given feature row in the plot.  The x-axis position of the dot is determined by the shap-value of that feature in that given decision. The further away from the 0-value a dot is, the higher the impact of that feature was for that decision. This impact can be negative (to the left) or positive (to the right).

The feature value (not shap value!) is marked by the color on the plot. Red signifies a high feature value, blue signifies a low feature value. Features are ranked by the mean-absolute impact they have on a model decision. The top feature in this plot will have the highest mean absolute impact. 


In [None]:
# Draw the default beeswarm plot with show=False so it can be saved to disk.
path = join(fig_path, 'beeswarm')
shap.plots.beeswarm(explanations, max_display=display_features, show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

#### Beeswarm ranked by maximum impact

This beeswarm plot is ranked by the absolute max-impact of your features. The highest ranked feature in this plot will have the highest maximum impact on a model decision. This can be relevant if you want to catch features that on average do not have a high impact, but have high maximum impact instead.

In [None]:
# Rank the features by the maximum of their absolute shap values (taken along
# axis 0), then draw and save the beeswarm plot.
max_impact_order = explanations.abs.max(0)
path = join(fig_path, 'beeswarm_max_impact')
shap.plots.beeswarm(explanations,
                    order=max_impact_order,
                    max_display=display_features,
                    show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

#### Absolute mean beeswarm

This plot is equivalent to the first beeswarm plot, but has the values transformed for absolute impact. This is useful if you want to see how much impact a feature has on average while also displaying where those impact values are clustered. This is a more information-rich version of the simple bar plot in the Bar Plot section.

**Disclaimer**: The beeswarm plot is known to have some ordering issues due to a rounding effect. If the feature order here differs from the order in the bar plot, then assume the bar plot to display the correct feature order.

In [None]:
# Plot the absolute shap values, ranked by their mean along axis 0, and save
# the figure.
mean_abs_order = explanations.abs.mean(0)
path = join(fig_path, 'beeswarm_abs_impact')
shap.plots.beeswarm(explanations.abs,
                    order=mean_abs_order,
                    max_display=display_features,
                    show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

## Auto cohort bar plots

This plot represents a bar plot split into 2 cohorts that optimally separate the SHAP values of the data instances using a sklearn DecisionTreeRegressor. 

The features are plotted as absolute mean impact.

The plot shows the two cohorts in differently shaded bars, the legend informs us about the metric along which the cohorts are chosen.

This plot can be useful to detect interaction effects between cohorts and features. If a cohort shows a high feature value, then there may be an interaction between that cohort and the feature.

In [None]:

# The cohort plot seems to break if the feature names in the explanation
# object are not defined as a list, so they are coerced to a list first.
# The figure is saved as 'auto_cohort_2': the previous 'auto_cohort_3' name
# collided with the three-cohort plot, which then overwrote this figure.
feature_names_were_list = isinstance(explanations.feature_names, list)
try:
    if not feature_names_were_list:
        explanations.feature_names = list(explanations.feature_names)
    shap.plots.bar(explanations.cohorts(2).abs.mean(0), show=False)
    path = join(fig_path, 'auto_cohort_2')
    plt.savefig(path, dpi=300, bbox_inches='tight')
except Exception:
    # Errors are only downgraded to a notice when the feature names had to be
    # converted; with well-formed names a failure is unexpected and is
    # re-raised so that it is not hidden.
    if feature_names_were_list:
        raise
    display("There was an error generating your cohort plot. Likely cause: feature_names not in correct format.")


### Auto cohort plot with 3 cohorts

In [None]:
# Split the explanations into 3 automatic cohorts, take the mean of the
# absolute shap values along axis 0, then draw and save the bar plot.
three_cohort_means = explanations.cohorts(3).abs.mean(0)
path = join(fig_path, 'auto_cohort_3')
shap.plots.bar(three_cohort_means, show=False)
plt.savefig(path, bbox_inches='tight', dpi=300)

### Auto cohort plot with 4 cohorts

In [None]:
# Bar plot of the mean absolute shap values for 4 automatic cohorts.
# Saved as 'auto_cohort_4': the previous 'auto_cohort_3' name overwrote the
# figure produced by the three-cohort plot.
shap.plots.bar(explanations.cohorts(4).abs.mean(0), show=False)
path = join(fig_path, 'auto_cohort_4')
plt.savefig(path, dpi=300, bbox_inches='tight')

In [None]:
%%javascript

// Code to dynamically generate table of contents at the top of the HTML file.
// Each H2 heading anchor becomes a top-level entry and each H3 heading anchor
// goes into a sub-list after the preceding entry; other heading levels are
// ignored.
var tocEntries = ['<ul>'];
var subList = false;

$('a.anchor-link').each(function(i, anch) {
    var heading = $(anch).parent();
    var hType = heading.prop('tagName');
    // Drop the last character of the heading text (presumably the anchor-link
    // marker rendered inside the heading -- confirm against the exported HTML).
    var hText = heading.text().slice(0, -1);
    // The link text must sit INSIDE the <a> element; the previous markup
    // ('<a href="..."</a>text') never closed the opening tag.
    var entry = '<li><a href="' + anch.href + '">' + hText + '</a></li>';

    if (hType == 'H2') {
        if (subList) {
            tocEntries.push('</ul>');
            subList = false;
        }
        tocEntries.push(entry);
    }
    else if (hType == 'H3') {
        if (!subList) {
            subList = true;
            tocEntries.push('<ul>');
        }
        tocEntries.push(entry);
    }
});

// Close a sub-list that is still open after the last heading.
if (subList) {
    tocEntries.push('</ul>');
}
tocEntries.push('</ul>');
$('#toc').html(tocEntries.join(' '));