In [1]:
import pickle
import pandas as pd
import shap
import numpy as np
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Load the trained model that the prediction and SHAP helpers in this notebook read as a global.
# NOTE: pickle.load can execute arbitrary code while deserializing, so only open pickle
# files that come from a trusted source.
with open("model.pkl", "rb") as f:
    model = pickle.load(f)

# Load a previously pickled explainer object.
# NOTE(review): the functions visible in this notebook build their own local
# shap.TreeExplainer and never read this global -- confirm whether it is still needed.
with open("explainer.pkl", "rb") as f:
    explainer = pickle.load(f)

def predict(data_list):
    """Return the model's prediction for a single sample.

    Args:
        data_list: Feature values of one sample. The sample is wrapped in a
            list so the global ``model`` receives a batch with exactly one row.

    Returns:
        Whatever ``model.predict`` returns for that one-row batch; the caller
        in this notebook reads the predicted label from element ``[0]``.
    """
    single_row_batch = [data_list]
    return model.predict(single_row_batch)

In [3]:
def prepare(data_list):
    """Compute SHAP values for one sample and bundle them with its features.

    Args:
        data_list: Feature values of one sample (the driver loop in this
            notebook passes one CSV row as a flat list).

    Returns:
        dict with three keys:
            'data_frame': first row of the transposed frame built from
                ``data_list`` (for a flat list, the feature values as a
                1-D array).
            'shap_values': transposed output of ``TreeExplainer.shap_values``.
            'feature_names': column labels of the transposed frame (for a
                flat list these are the positional indices 0..n-1).
    """
    # A flat list becomes a single column; transposing turns it into one row.
    as_single_row = pd.DataFrame(data_list).transpose()
    data_frame = as_single_row.values[0]

    # Iterating a DataFrame yields its column labels.
    feature_names = list(as_single_row)

    # A new explainer is built from the global model on every call.
    tree_explainer = shap.TreeExplainer(model)
    raw_shap_values = tree_explainer.shap_values(data_frame, check_additivity=False)
    # NOTE(review): .transpose() needs an ndarray here -- presumably true for
    # the installed SHAP version; confirm if SHAP is pinned to an older release.
    shap_values = raw_shap_values.transpose()

    return {'data_frame': data_frame, 'shap_values': shap_values, 'feature_names': feature_names}

In [4]:
def show_shap(data_list, row_index=None):
    """Print the feature with the largest mean absolute SHAP value for one sample.

    The previous version could not run: it formatted ``i`` into the message
    before its own ``for i`` loop had assigned it (``UnboundLocalError``), and
    it handed the explainer an untransposed frame instead of a single row.
    SHAP values now come from ``prepare`` like the other helpers.

    Args:
        data_list: Feature values of one sample, as accepted by ``prepare``.
        row_index: Label printed as the row number. When omitted, the
            notebook-level counter ``i`` maintained by the driver loop is
            used if it exists, otherwise ``'?'``.

    Returns:
        None. The explanation is printed on a single line.
    """
    prepared = prepare(data_list)
    feature_values = prepared['data_frame']
    # Same aggregation as get_feature_importance_ranking: average |SHAP| over
    # the first axis to obtain one importance per feature.
    feature_importances = np.abs(prepared['shap_values']).mean(axis=0)

    if row_index is None:
        # Fall back to the driver loop's global row counter.
        row_index = globals().get('i', '?')

    max_shap = feature_importances.max()
    # Every column tied for the maximum is reported, as before.
    top_columns = [
        column
        for column, importance in enumerate(feature_importances)
        if importance == max_shap
    ]

    print(f"row {row_index} is safe mainly because> ", end='')
    for column in top_columns:
        print(f"column index {column}", end=' ')

    # The value shown belongs to the last tied column; column 0 is the
    # fallback when nothing compares equal to the maximum (e.g. NaN).
    index = top_columns[-1] if top_columns else 0
    print(f"is: {feature_values[index]}")

In [5]:
def get_feature_importance_ranking(data_list):
    """Print the features of one sample ranked by mean absolute SHAP value.

    Args:
        data_list: Feature values of one sample, as accepted by ``prepare``.

    Returns:
        None. A two-row table ('feature', 'importance'), sorted by descending
        importance and rounded to 4 decimals, is printed.
    """
    # Call prepare() once: every call builds a new explainer and recomputes
    # the SHAP values, so fetching each key separately doubled the work.
    prepared = prepare(data_list)
    feature_importances = np.abs(prepared['shap_values']).mean(axis=0)

    # Chained, non-mutating pipeline instead of an in-place sort.
    ranking = (
        pd.DataFrame({
            'feature': prepared['feature_names'],
            'importance': feature_importances,
        })
        .sort_values(by='importance', ascending=False)
        .round(4)
    )

    # Transposed so that each feature becomes a column of the printed table.
    print(ranking.transpose())

In [6]:
def plot_shap_value_distribution(data_list):
    """Show one histogram of SHAP values per feature for a single sample.

    Args:
        data_list: Feature values of one sample, as accepted by ``prepare``.

    Returns:
        None. One figure per column of the prepared SHAP array is displayed.
    """
    # Call prepare() once: every call builds a new explainer and recomputes
    # the SHAP values, so fetching each key separately doubled the work.
    prepared = prepare(data_list)
    shap_values = prepared['shap_values']
    feature_names = prepared['feature_names']

    for column in range(shap_values.shape[1]):
        # Each histogram gets its own figure that is closed explicitly,
        # rather than clearing the implicit current figure after show().
        fig, ax = plt.subplots()
        ax.hist(shap_values[:, column])
        ax.set_xlabel("SHAP Value")
        ax.set_ylabel("Count")
        ax.set_title(f"Distribution of SHAP Values for {feature_names[column]}")
        plt.show()
        plt.close(fig)


In [7]:
def explain_individual_datapoint(data_list):
    """Print each feature of one sample next to its SHAP value.

    Args:
        data_list: Feature values of one sample, as accepted by ``prepare``.

    Returns:
        None. One "Feature: ..., SHAP Value: ..." line is printed per feature.
    """
    # Call prepare() once: every call builds a new explainer and recomputes
    # the SHAP values, so fetching each key separately doubled the work.
    prepared = prepare(data_list)

    # Only the first row of the prepared SHAP array is reported.
    # NOTE(review): presumably this row belongs to the first model output
    # (class) -- confirm that it is the class that should be explained.
    first_row = prepared['shap_values'][0]

    for feature_name, shap_value in zip(prepared['feature_names'], first_row):
        print(f"Feature: {feature_name}, SHAP Value: {shap_value}")

In [8]:
# Score every row of the test set and, for each row predicted as class 1,
# print its SHAP-based feature ranking.
data = pd.read_csv('base files/creditcard_test.csv')
data_list = data.values.tolist()

i = 0  # 1-based row counter; remains 0 when the file contains no rows
for i, row in enumerate(data_list, start=1):
    res = predict(row)
    if res[0] != 1:
        continue
    try:
        # Other explanation helpers that can be enabled here instead:
        # show_shap(row)
        get_feature_importance_ranking(row)
        # plot_shap_value_distribution(row)
        # explain_individual_datapoint(row)
    except Exception as e:
        # Best effort: report the failure and carry on with the next row.
        print(e)

                 14       10       12      4       7        11      9   \
feature     14.0000  10.0000  12.0000  4.0000  7.0000  11.0000  9.0000   
importance   0.2249   0.1668   0.0445  0.0414  0.0343   0.0195  0.0159   

                0        16       27  ...      2        28      8        29  \
feature     0.0000  16.0000  27.0000  ...  2.0000  28.0000  8.0000  29.0000   
importance  0.0141   0.0129   0.0107  ...  0.0019   0.0018  0.0017   0.0016   

                 24       19       23       25       20       21  
feature     24.0000  19.0000  23.0000  25.0000  20.0000  21.0000  
importance   0.0015   0.0009   0.0008   0.0008   0.0005   0.0004  

[2 rows x 30 columns]
                12       14       17       11       10      4       9   \
feature     12.000  14.0000  17.0000  11.0000  10.0000  4.0000  9.0000   
importance   0.233   0.1846   0.1637   0.1047   0.0895  0.0422  0.0266   

                7        26     20  ...      5        22      13      8   \
feature     7.00

In [None]:
# Completion marker printed once the cells above have been run.
print('done')