In this notebook, we work with the look-up tables that were created for §4.2 in order to compute Shapley values for the four CatBoost models of §4.1, which were trained on public datasets. For each model, a sample of 100 points from the test set is provided in the folder "Samples".

The look-up tables were created with proprietary code of Discover Financial Services, a fast implementation of Algorithm 3.12. They are located in the folders "game_value_loc_1", "game_value_loc_2", etc. Each folder contains one .csv file per tree of the ensemble under consideration, holding all Shapley values arising from that tree: the rows are indexed by the leaves of the oblivious tree, and the columns correspond to the features on which the tree splits. Non-realizable leaves, which correspond to vacuous (empty) regions, are excluded. The .json file in each folder relates the local enumeration of the features appearing in a tree to their global indices in the training data.
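To illustrate why some leaves are non-realizable, here is a minimal sketch (not the proprietary code) that lists the leaves of an oblivious tree whose regions are non-empty. The function name realizable_leaves is hypothetical, and it assumes the leaf-index convention that retrieve_catboost below also relies on: the j-th entry of the tree's list of splits decides bit j of the leaf index, with the bit set when the feature value exceeds the border. A depth-d tree has 2**d leaves, but once a feature is split on at two levels, some combinations of branches describe empty intervals.

```python
import math


def realizable_leaves(splits):
    """Return the indices of the leaves of an oblivious tree with non-empty regions.

    splits: list of (feature_index, border) pairs, one per level; by assumption,
    splits[j] sets bit j of the leaf index when feature value > border.
    """
    depth = len(splits)
    leaves = []
    for leaf in range(2 ** depth):
        bounds = {}  # feature -> (lo, hi), describing the interval lo < value <= hi
        empty = False
        for j, (feature, border) in enumerate(splits):
            lo, hi = bounds.get(feature, (-math.inf, math.inf))
            if (leaf >> j) & 1:  # right branch: value > border raises the lower bound
                lo = max(lo, border)
            else:  # left branch: value <= border lowers the upper bound
                hi = min(hi, border)
            if lo >= hi:  # the interval (lo, hi] is empty, so the leaf is vacuous
                empty = True
                break
            bounds[feature] = (lo, hi)
        if not empty:
            leaves.append(leaf)
    return leaves


# Two distinct features: all 2**2 leaves are realizable.
print(realizable_leaves([(0, 1.0), (1, 5.0)]))  # [0, 1, 2, 3]
# Feature 0 split twice: "value <= 1.0 and value > 2.0" is empty, so leaf 2 drops out.
print(realizable_leaves([(0, 1.0), (0, 2.0)]))  # [0, 1, 3]
```

Consequently, the table of a tree with a repeated feature has fewer than 2**depth rows, which is why the check further down only uses tables whose row count equals 2**(number of columns).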

We verify these precomputed Shapley values by checking the
[efficiency axiom](https://christophm.github.io/interpretable-ml-book/shapley.html#the-shapley-value-in-detail). We choose a tree from one of the four ensembles at random; for every data sample, the difference between the tree's output (i.e. the leaf value) and the sum of the Shapley values associated with the corresponding leaf should be the same constant, namely the average of the tree's outputs over the whole training data:

$$\sum_i\varphi_i[g](\mathbf{x})=g(\mathbf{x})-\mathbb{E}[g]\quad \forall\mathbf{x}.$$
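A brute-force sketch shows why this difference must be constant. It assumes the marginal (interventional) game $v(S)=\mathbb{E}[g(\mathbf{x}_S,\mathbf{X}_{-S})]$, with the expectation replaced by an average over a background sample; this is an illustrative choice, not necessarily the game behind Algorithm 3.12, and marginal_shapley and toy_tree are hypothetical names. For this game $v(N)=g(\mathbf{x})$ and $v(\emptyset)=\mathbb{E}[g]$, so efficiency, $\sum_i\varphi_i=v(N)-v(\emptyset)$, is exactly the identity above.

```python
import itertools
import math

import numpy as np


def marginal_shapley(g, x, background):
    """Exact Shapley values of the game v(S) = mean over background rows b of g(x_S, b_rest).

    g: vectorized model, maps an (m, n) array to m outputs
    x: explained point, shape (n,)
    background: rows used to estimate the expectation, shape (m, n)
    """
    n = x.shape[0]

    def v(S):
        X = background.copy()
        idx = list(S)
        if idx:
            X[:, idx] = x[idx]  # features in S are fixed to the values of x
        return g(X).mean()

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
            for S in itertools.combinations(others, k):
                phi[i] += weight * (v(S + (i,)) - v(S))
    return phi


def toy_tree(X):
    # A depth-2 oblivious tree: bit 0 from feature 0, bit 1 from feature 1.
    leaf = (X[:, 0] > 0.5).astype(int) + 2 * (X[:, 1] > 0.3).astype(int)
    return np.array([-1.0, 0.5, 2.0, 0.25])[leaf]


rng = np.random.default_rng(0)
background = rng.random((200, 3))  # feature 2 is never used by the tree
x = np.array([0.9, 0.1, 0.4])
phi = marginal_shapley(toy_tree, x, background)
gap = toy_tree(x[None, :])[0] - phi.sum()  # output minus the sum of Shapley values
print(phi, gap, toy_tree(background).mean())  # gap equals the background mean up to rounding
```

The loop over subsets is exponential in the number of features, so this is only practical as a check on toy examples.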

In [1]:
import pandas as pd
import numpy as np
import os, os.path
import glob
import random
import pickle
import json

import catboost
from catboost import CatBoostClassifier, CatBoostRegressor

We first load the CatBoost model and the corresponding data sample. Only experiment_number needs to be set (an integer from 1 to 4).

In [2]:
experiment_number=4

#Experiments 1 and 2 involve regressors, experiments 3 and 4 involve classifiers.
if experiment_number in (1,2):
    model_type='Regressor'
elif experiment_number in (3,4):
    model_type='Classifier'
else:
    raise ValueError('experiment_number should be 1, 2, 3 or 4.')

sample_path='./Samples/Sample_'+str(experiment_number)+'.csv'
sample=pd.read_csv(sample_path)
n_samples=sample.shape[0]

model_path='./Models/'+model_type+'_CatBoost_'+str(experiment_number)
with open(model_path,'rb') as model_file:               #The file is closed once the model is unpickled.
    model_cat=pickle.load(model_file)

#There is one .csv look-up table per tree, so counting the .csv files gives the number of trees.
local_shapley_folder_path='./game_value_loc_'+str(experiment_number)
n_trees=len(glob.glob(os.path.join(local_shapley_folder_path,'*.csv')))
print(f'We consider the CatBoost ensemble from experiment {experiment_number} which has {n_trees} trees.')

We consider the CatBoost ensemble from experiment 1 which has 300 trees.


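The cell below computes and stores, for each tree of the ensemble, the average of its outputs over the training data: each leaf value is weighted by the fraction of the training weight that fell into that leaf (the leaf_weights stored in the JSON dump of the model). It uses the function retrieve_catboost from the Retrieve_splits notebook.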
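The averaging step amounts to a weighted mean: a tree's expected output over the training data is the sum over leaves of the leaf value times the fraction of training weight in that leaf, which retrieve_catboost stores as region['probability']. Below is a minimal sketch with a hypothetical helper tree_average, assuming, as tree_averages does, that the leaf_weights entries of CatBoost's JSON dump record the training weight that landed in each leaf. With unit sample weights these are counts, and the weighted mean equals the plain mean of the tree's outputs over the training rows.

```python
import numpy as np


def tree_average(leaf_values, leaf_weights):
    """Weighted mean of a tree's leaf values, weighting each leaf by its share of training weight."""
    values = np.asarray(leaf_values, dtype=float)
    weights = np.asarray(leaf_weights, dtype=float)
    probabilities = weights / weights.sum()  # empirical probability of landing in each leaf
    return float(np.dot(probabilities, values))


# Hypothetical training set of 10 rows with unit weights, landing in leaves 0..3.
leaf_values = np.array([-1.0, 0.5, 2.0, 0.25])
leaf_of_each_row = np.array([0, 0, 0, 0, 1, 2, 2, 2, 3, 3])
leaf_weights = np.bincount(leaf_of_each_row, minlength=4)  # [4, 1, 3, 2]

via_leaves = tree_average(leaf_values, leaf_weights)
direct = leaf_values[leaf_of_each_row].mean()  # average of the tree's outputs over the rows
print(via_leaves, direct)  # both are 0.3 up to floating-point rounding
```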

In [3]:
def tree_averages(model_cat):
    
    #Dumping the CatBoost model as a dictionary via a temporary JSON file.
    if not isinstance(model_cat,(CatBoostClassifier,CatBoostRegressor)):
        raise TypeError('The input should be a CatBoost classifier or regressor.')
    model_cat.save_model('temp',format='json')
    with open('temp') as model:                  #The file must be closed before it can be removed.
        dictionary_catboost=json.load(model)
    os.remove('temp')
    
    #The average associated with a tree is the sum of products region['value']*region['probability'] over all leaves.
    averages=[]
    for tree_structure in dictionary_catboost['oblivious_trees']:
        regions=retrieve_catboost(tree_structure)['regions']
        #Extracting the value and the probability associated with each leaf (or equivalently, with the corresponding region).
        probabilities=np.asarray([region['probability'] for region in regions])
        values=np.asarray([region['value'] for region in regions])
        averages+=[np.dot(probabilities,values)]
    
    return averages
##############################################################################        
#The function retrieve_catboost from the Retrieve_splits notebook.

def retrieve_catboost(tree_structure):
    #Initializing the dictionary:
    info={}
    
    #The first two keys are easy:
    info['depth']=len(tree_structure['splits'])
    info['n_leaves']=2**info['depth']
    
    #Initializing the next two keys:
    info['splits']=[]
    info['distinct_feature_indx']=[]
    
    for split in tree_structure['splits']:               #Each element of tree_structure['splits'] describes a splitting that takes place across an entire level.
        if split['float_feature_index'] not in info['distinct_feature_indx']:
            info['distinct_feature_indx']+=[split['float_feature_index']]
        info['splits']+=[(split['float_feature_index'],split['border'])]
        
    
    #It remains to compute info['regions'], a list consisting of one dictionary per region.
    #Initializing:
    info['regions']=[]
    
    for i in range(2**info['depth']):
        #Constructing the dictionary describing this region:
        region={}
        region['value']=tree_structure['leaf_values'][i]
        region['weight']=tree_structure['leaf_weights'][i]
        
        #Initializing the keys that describe bounds for each feature.
        for feature_index in info['distinct_feature_indx']:
            region[feature_index]=[-float('inf'),float('inf')]
            
        expansion='{0:b}'.format(i)                          #The binary expansion of i which is the index of the leaf/region under consideration.
        while len(expansion)<info['depth']:                  #(An integer from [0,2**depth-1], we want len(expansion)=depth.) 
            expansion='0'+expansion
                                                             
        for j in range(info['depth']):                       #The leftmost characters of the expansion are determined by top splits near 
            feature_index=info['splits'][-j-1][0]            #the root which are encoded by the rightmost entries of info['splits'].         
            threshold=info['splits'][-j-1][1]                #(Keep in mind that splits closer to the root appear at the end of tree_structure['splits']).
            
            if expansion[j]=='0':                            #Meaning we go to the left since feature_value<threshold.
                region[feature_index]=modify_interval(region[feature_index],threshold,'upper')
            else:                                            #Meaning we go to the right since feature_value>threshold.
                region[feature_index]=modify_interval(region[feature_index],threshold,'lower')
        
        #Adding the dictionary constructed for this region to info['regions'].
        info['regions']+=[region]
        
        
    #Adding a 'probability' key to each dictionary in info['regions'].
    total_weight=0
    for region in info['regions']:
        total_weight+=region['weight']
    for region in info['regions']:
        region['probability']=region['weight']/total_weight
    
    return info

############
#An auxiliary function 

def modify_interval(interval,bound,kind):
    if interval is None:                    #Nothing to modify if the interval is empty to begin with.
        return None
    if kind=='upper':
        if interval[0]>=bound:
            return None
        else:
            interval[1]=min(interval[1],bound)
    else:
        if interval[1]<=bound:
            return None
        else:
            interval[0]=max(interval[0],bound)
    return interval            

##############################################################################
#Saving the averages for all trees in the ensemble
averages=tree_averages(model_cat)
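The leaf bookkeeping in retrieve_catboost can be cross-checked with a short sketch. It assumes the convention encoded in the comments of that function: the j-th entry of a tree's splits list decides bit j of the leaf index, and a row goes right (bit set) when its feature value is strictly greater than the border. Under that assumption, building the index bit by bit and reading the binary expansion from the last split to the first give the same leaf; this is the enumeration that calc_leaf_indexes is expected to return in the cell below. Both helper names are hypothetical.

```python
import numpy as np


def oblivious_leaf_index(X, splits):
    """Leaf index of every row of X in an oblivious tree.

    splits: (feature_index, border) pairs in the order of the tree's 'splits' list;
    by assumption, split j contributes 2**j when the feature value exceeds the border.
    """
    X = np.asarray(X, dtype=float)
    index = np.zeros(X.shape[0], dtype=int)
    for j, (feature, border) in enumerate(splits):
        index += (X[:, feature] > border).astype(int) << j
    return index


def expansion_leaf_index(x, splits):
    """The same index read as in retrieve_catboost: a binary string of length depth
    whose j-th character (from the left) is decided by splits[-j-1]."""
    depth = len(splits)
    bits = ''.join('1' if x[splits[-j - 1][0]] > splits[-j - 1][1] else '0'
                   for j in range(depth))
    return int(bits, 2)


splits = [(0, 0.5), (2, 1.0), (1, -0.2)]  # hypothetical depth-3 tree
X = np.array([[0.9, 0.0, 2.0],
              [0.1, -1.0, 0.0],
              [0.6, 0.3, 0.5]])
print(oblivious_leaf_index(X, splits))  # [7 0 5]
print([expansion_leaf_index(row, splits) for row in X])  # [7, 0, 5]
```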

In [4]:
#We pick a random tree. Since leaves corresponding to degenerate (empty) regions are not included
#in the look-up tables, we only consider tables with (number of rows)=2**(number of columns),
#that is, trees without repeated features. 
#For such trees, the internal enumeration of leaves matches the order of rows. 
#(The loop assumes that the ensemble contains at least one such tree.)

while True:
    tree_index=random.randint(0,n_trees-1)
    local_shapley=pd.read_csv(os.path.join(local_shapley_folder_path,'game_value_tree_'+str(tree_index)+'.csv'),
                              header=None)
    if local_shapley.shape[0]==2**(local_shapley.shape[1]):
        break
print(f'The tree of index {tree_index} was chosen randomly from the CatBoost ensemble.')
print(f'The average of its outputs over the training data is {averages[tree_index]}.')
        
#The outputs of the chosen tree. These are leaf values; for classifiers they are on the log-odds (logit) scale.
outputs=model_cat.predict(sample,prediction_type='RawFormulaVal',
                          ntree_start=tree_index,ntree_end=tree_index+1)
        

#Determining leaves of the tree at which sample points land:
leaf_indices=model_cat.calc_leaf_indexes(sample,ntree_start=tree_index,ntree_end=tree_index+1).reshape(n_samples)

#Adding the sum of each row (i.e. of the Shapley values at each leaf) to the table of Shapley values.
local_shapley['sum']=local_shapley.sum(axis=1)

#Subtracting the sum of Shapley values at the leaf corresponding to a sample point from the leaf value:
difference=outputs-local_shapley['sum'].to_numpy()[leaf_indices]
print(f'\nVerifying the efficiency axiom: the output minus the sum of local Shapley values should be the same for all {n_samples} sample data points and should coincide with the average output of the tree.')
difference

The tree of index 120 was chosen randomly from the CatBoost ensemble.
The average of its outputs over the training data is -0.0022147689080705183.

Verifying the efficiency axiom: the output minus the sum of local Shapley values should be the same for all 100 sample data points and should coincide with the average output of the tree.


array([-0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00221477,
       -0.00221477, -0.00221477, -0.00221477, -0.00221477, -0.00