# Implementation of TreeSHAP
*Understanding TreeSHAP: A complete tutorial*

## Libraries

In [1]:
pip install shap

Collecting shap
  Downloading shap-0.43.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (532 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m532.9/532.9 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Collecting slicer==0.0.7 (from shap)
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.43.0 slicer-0.0.7


In [2]:
import numpy as np
import pandas as pd
import shap

from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

## Code of TreeSHAP from the article

In [3]:
def UNWIND_weights(weights_wound, ones_ki, zeros_ki):
  """Invert the weight update performed by EXTEND for one feature of the path.

  Args:
    weights_wound: the D_k + 1 path weights obtained after the feature was added.
    ones_ki: the "one" fraction stored for that feature on the path.
    zeros_ki: the "zero" fraction stored for that feature on the path.

  Returns:
    A list with the D_k path weights the path would have without that feature.
  """
  depth = len(weights_wound) - 1
  unwound = np.zeros(depth)

  # First case: the feature only contributed through its "zero" term,
  # so every weight can be inverted on its own.
  if ones_ki == 0:
    for h in range(depth):
      unwound[h] = (depth + 1)/(depth - h) * weights_wound[h] / zeros_ki
    return list(unwound)

  # Second case: solve the recurrence from the last weight down to the first.
  # `carry` plays the role of omega[h+1] in the article.
  carry = weights_wound[depth]
  for h in reversed(range(depth)): # h = depth-1, ..., 1, 0
    unwound[h] = carry * (depth + 1)/(h + 1) / ones_ki
    carry = weights_wound[h] - (depth - h)/(depth + 1) * unwound[h] * zeros_ki
  return list(unwound)

def UNWIND(features_wound, ones_wound, zeros_wound, weights_wound, i):
  """Remove the i-th feature from the path, undoing the EXTEND that added it.

  Returns (features, ones, zeros, weights) as lists, each one element shorter
  than the corresponding input.
  """
  # Path weights without the i-th feature
  weights_unwound = UNWIND_weights(weights_wound, ones_wound[i], zeros_wound[i])

  # Drop position i from the three per-feature columns
  features_unwound, ones_unwound, zeros_unwound = (
      list(np.delete(column, i)) for column in (features_wound, ones_wound, zeros_wound)
  )

  return features_unwound, ones_unwound, zeros_unwound, weights_unwound

def EXTEND(features_P, ones_P, zeros_P, weights_P, feature_e, one_e, zero_e):
  """Append the edge (feature_e, one_e, zero_e) to the path P and update its weights.

  Args:
    features_P, ones_P, zeros_P, weights_P: the current path, as lists.
    feature_e: feature of the new edge (-1 for the dummy first edge).
    one_e: "one" fraction of the edge (treeSHAP passes 1 on the hot side, 0 on the cold side).
    zero_e: "zero" fraction of the edge (treeSHAP passes child.n_samples / node.n_samples).

  Returns:
    (features, ones, zeros, weights) of the extended path, as new lists;
    the input lists are left untouched.
  """
  # We save the encountered feature, the condition and the proportion
  features_Pe = features_P + [feature_e]
  ones_Pe = ones_P + [one_e]
  zeros_Pe = zeros_P + [zero_e]

  # Empty path: we leave the 'All Data' node and the single weight is 1
  if not weights_P:
    return features_Pe, ones_Pe, zeros_Pe, [1]

  D = len(weights_P) - 1 # Number of distinct features met up till now
  padded = weights_P + [0] # w[D+1, P] = 0

  # New weight k mixes the "one" term of old weight k-1 and the "zero" term of old weight k
  weights_Pe = np.zeros(D + 2)
  weights_Pe[0] = (D + 1)/(D + 2) * padded[0] * zero_e
  for k in range(1, D + 2):
    weights_Pe[k] = k/(D + 2) * padded[k - 1] * one_e + (D + 1 - k)/(D + 2) * padded[k] * zero_e

  return features_Pe, ones_Pe, zeros_Pe, list(weights_Pe)


def treeSHAP(tree, x):
  """Compute the exact SHAP values of one decision tree for the sample x (TreeSHAP).

  Args:
    tree: root of the custom tree. This is a `Node`, or a `Leaf` for a tree reduced to a
      single leaf. Internal nodes provide `feature`, `threshold`, `left_child`,
      `right_child` and `n_samples`; leaves provide `value` and `n_samples`.
    x: the sample to explain, indexable by feature index (e.g. a tuple or a 1-D array).

  Returns:
    np.ndarray of length len(x): the SHAP value of each feature.
  """
  phi = np.zeros(len(x)) # We initialize the SHAP values at 0

  def RECURSE(node, features_, ones_, zeros_, weights_, feature_e, one_e, zero_e):
    # Add the edge leading to `node` to the path
    args = (features_, ones_, zeros_, weights_, feature_e, one_e, zero_e)
    features_, ones_, zeros_, weights_ = EXTEND(*args)

    if node.is_leaf: # Case 1: the node is a leaf
      # Position 0 of the path holds the dummy feature -1 of the first edge: skip it
      for i in range(1, len(weights_)):
        phi[features_[i]] += sum(UNWIND_weights(weights_, ones_[i], zeros_[i])) * (ones_[i] - zeros_[i]) * node.value
      return

    # Case 2: the node is internal
    # We establish which child is hot (followed by x) and which one is cold
    if x[node.feature] <= node.threshold:
      hot_child, cold_child = node.left_child, node.right_child
    else:
      cold_child, hot_child = node.left_child, node.right_child

    feature_e = node.feature # Identical for both hot and cold children
    hot_one_e = 1 # By definition of the hot child
    cold_one_e = 0
    hot_zero_e = hot_child.n_samples / node.n_samples
    cold_zero_e = cold_child.n_samples / node.n_samples

    # We check whether feature_e has already been encountered on the path.
    # Plain Python comparison on purpose: `np.where(features_ == feature_e)` only works
    # when feature_e is a NumPy scalar, which then broadcasts over the list. With a plain
    # int, `list == int` is simply False and a repeated feature would silently be missed.
    i = next((k for k, f in enumerate(features_) if f == feature_e), None)
    if i is not None: # the feature has already been encountered
      hot_one_e *= ones_[i]
      hot_zero_e *= zeros_[i]
      cold_zero_e *= zeros_[i]

      # UNWIND to delete feature_e: the recursive calls below re-add it with the combined fractions
      features_, ones_, zeros_, weights_ = UNWIND(features_, ones_, zeros_, weights_, i)

    # Final recursive calls
    RECURSE(hot_child, features_, ones_, zeros_, weights_, feature_e, hot_one_e, hot_zero_e)
    RECURSE(cold_child, features_, ones_, zeros_, weights_, feature_e, cold_one_e, cold_zero_e)

  # We call RECURSE on the root of the tree,
  # by setting feature_e = -1 to indicate there is no feature on the first edge.
  # A tree reduced to a single Leaf may have no `root` attribute, so we start from the
  # leaf itself; every SHAP value then stays 0.
  root = getattr(tree, 'root', tree)
  RECURSE(root, [], [], [], [], -1, 1, 1)

  return phi

## Tests on different decision trees

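We first implement a minimal custom decision tree (the `Leaf` and `Node` classes). We also write `tree_from_dict`, which builds such a tree from the same dictionary format that we pass to `shap.TreeExplainer`.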

In [4]:
class Leaf():
  """Terminal node of the custom decision tree: it holds a constant prediction."""

  def __init__(self, value=None, n_samples=None, name=None):
    self.value = value          # prediction returned for every sample reaching the leaf
    self.n_samples = n_samples  # training samples (or their total weight) in the leaf
    self.is_leaf = True
    self.name = name # Useful for debugging

    # Same convention as Node: expose `root`, so that a tree reduced to a single
    # leaf can be passed to treeSHAP, which starts its recursion from `tree.root`.
    self.root = self

  def predict(self, X):
    """Return the leaf value, whatever the sample X."""
    return self.value

class Node():
  """Internal (splitting) node of the custom decision tree.

  A sample X is routed to `left_child` when X[feature] <= threshold and to
  `right_child` otherwise.
  """

  def __init__(self, left_child=None, right_child=None, feature=None, threshold=None, value=None, name=None):
    self.is_leaf = False
    self.name = name # Useful for debugging

    # Split rule and the two subtrees it leads to
    self.feature, self.threshold = feature, threshold
    self.left_child, self.right_child = left_child, right_child

    self.value = value
    self.n_samples = None # not a constructor argument; tree_from_dict assigns it afterwards

    self.root = self # To keep along with the notation of the article

  def predict(self, X):
    """Send X down the child chosen by the split and return that child's prediction."""
    child = self.left_child if X[self.feature] <= self.threshold else self.right_child
    return child.predict(X)


def tree_from_dict(tree_dict, i=0): # Useful later on.
  """Recursively build the custom tree (Node / Leaf objects) rooted at node index i.

  tree_dict holds parallel arrays indexed by node ('children_left', 'children_right',
  'features', 'thresholds', 'values', 'node_sample_weight'), in the same dictionary
  format that is handed to shap.TreeExplainer. A feature equal to -2 marks a leaf.
  """
  feature = tree_dict['features'][i]
  value = tree_dict["values"][i][0] # We unpack the value
  n_samples = tree_dict["node_sample_weight"][i]

  if feature == -2: # We are on a leaf
    return Leaf(value=value, n_samples=n_samples)

  node = Node(feature=feature, threshold=tree_dict['thresholds'][i], value=value)
  node.n_samples = n_samples
  node.left_child = tree_from_dict(tree_dict, tree_dict['children_left'][i])
  node.right_child = tree_from_dict(tree_dict, tree_dict['children_right'][i])
  return node

### Decision tree from the article

In [5]:
 # Tree from the article, one row per node (node indexes 0..10).
# Columns: (left child, right child, feature, threshold, value, number of samples).
# A child index of -1 means "no child"; a feature of -2 means there is no condition (hence the node is a leaf).
# We adopted the following indexation: {'age': 0, 'weight': 1, 'height': 2, 'bodyfat': 3}.
# Values of internal nodes are weighted averages of the values of their children.
(children_left, children_right, features,
 thresholds, values, node_sample_weight) = map(np.array, zip(
    ( 1,  6,  0,  50, 0.41,                100),  # node 0:  age <= 50
    ( 2,  3,  1, 100, 0.32666666666666667,  75),  # node 1:  weight <= 100
    (-1, -1, -2,  -2, 0.2,                  60),  # node 2:  leaf
    ( 4,  5,  2, 180, 0.8333333333333334,   15),  # node 3:  height <= 180
    (-1, -1, -2,  -2, 0.9,                  10),  # node 4:  leaf
    (-1, -1, -2,  -2, 0.7,                   5),  # node 5:  leaf
    ( 7, 10,  3,  30, 0.66,                 25),  # node 6:  bodyfat <= 30
    ( 8,  9,  0,  70, 0.4666666666666667,   15),  # node 7:  age <= 70
    (-1, -1, -2,  -2, 0.4,                  10),  # node 8:  leaf
    (-1, -1, -2,  -2, 0.6,                   5),  # node 9:  leaf
    (-1, -1, -2,  -2, 0.95,                 10),  # node 10: leaf
))

# Formatting the custom tree in the dictionary format of the package shap
tree_dict = dict(
    children_left=children_left,
    children_right=children_right,
    children_default=children_right.copy(), # Required by the shap format; copied from children_right
    features=features,
    thresholds=thresholds,
    values=values.reshape(-1, 1), # One column: a single output value per node
    node_sample_weight=node_sample_weight,
)
model = dict(trees=[tree_dict])

In [6]:
# Custom samples: x = (age in years, weight in kg, height in cm, bodyfat in %)
L = [
    (56, 93, 187, 37),   # x_1: Gérard, from the article
    (43, 25, 157, 16),   # x_2
    (47, 117, 175, 24),  # x_3
    (34, 101, 187, 23),  # x_4
    (71, 75, 178, 27),   # x_5
    (60, 100, 160, 25),  # x_6
]
x_1, x_2, x_3, x_4, x_5, x_6 = L

In [7]:
# Our implementation of TreeSHAP
tree = tree_from_dict(tree_dict) # Tree from the article
homemade_shap = treeSHAP(tree, x_1)
print(f'Homemade SHAP values: {homemade_shap}')

# Implementation of TreeSHAP from package shap
explainer = shap.TreeExplainer(model)
actual_shap = explainer.shap_values(np.array(x_1))
print(f'Actual SHAP values: {actual_shap}')

# Fail loudly under Restart & Run All, instead of relying on a visual comparison
np.testing.assert_allclose(homemade_shap, actual_shap, rtol=1e-6, atol=1e-9)

Homemade SHAP values: [ 0.38958333 -0.04416667 -0.00666667  0.20125   ]
Actual SHAP values: [ 0.38958333 -0.04416667 -0.00666667  0.20125   ]


### Another example, not from the article: a scikit-learn regression tree fitted on the diabetes dataset

In [8]:
# Depth-7 regression tree on scikit-learn's diabetes dataset (fixed random_state for a reproducible fit)
X, y = load_diabetes(return_X_y=True)
regressor = DecisionTreeRegressor(max_depth=7, random_state=0).fit(X, y)
regressor

In [9]:
# Formatting the fitted scikit-learn tree as a custom tree for the package shap
tree_dict_bis = dict(
    children_left=regressor.tree_.children_left,
    children_right=regressor.tree_.children_right,
    children_default=regressor.tree_.children_right.copy(), # Required by the shap format; copied from children_right
    features=regressor.tree_.feature,
    thresholds=regressor.tree_.threshold,
    values=np.reshape(regressor.tree_.value, (-1, 1)), # One column, as in tree_dict (assumes a single-output regressor)
    node_sample_weight=regressor.tree_.n_node_samples,
)
model_bis = dict(trees=[tree_dict_bis])

In [10]:
# Our implementation of TreeSHAP
tree_bis = tree_from_dict(tree_dict_bis) # Tree fitted above on the diabetes dataset
homemade_shap_bis = treeSHAP(tree_bis, X[0])
print(f'Homemade SHAP values: {homemade_shap_bis}')

# Implementation of TreeSHAP from package shap
explainer = shap.TreeExplainer(model_bis)
actual_shap_bis = explainer.shap_values(X[0])
print(f'Actual SHAP values: {actual_shap_bis}')

# Fail loudly under Restart & Run All, instead of relying on a visual comparison
np.testing.assert_allclose(homemade_shap_bis, actual_shap_bis, rtol=1e-6, atol=1e-9)

Homemade SHAP values: [-0.05576818 -1.58100166 26.26286909 16.29865298  3.54881159  0.09810506
  0.29857304  0.36819254 17.79555359 -2.85977991]
Actual SHAP values: [-0.05576818 -1.58100166 26.26286909 16.29865298  3.54881159  0.09810506
  0.29857304  0.36819254 17.79555359 -2.85977991]
