In [14]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.cluster import KMeans

from shapkit_nbdev.shapley_values import ShapleyValues
from shapkit_nbdev.inspector import inspector
from shapkit_nbdev.monte_carlo_shapley import MonteCarloShapley
from shapkit_nbdev.sgd_shapley import SGDshapley

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load dataset

In [2]:
from sklearn.datasets import load_wine
wine = load_wine()
columns = wine.feature_names
target_names = wine.target_names
X = pd.DataFrame(wine.data, columns=columns)
y = wine.target
print("Classes: {0}".format(np.unique(y)))
X.head(3)

Classes: [0 1 2]


Unnamed: 0,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od280/od315_of_diluted_wines,proline
0,14.23,1.71,2.43,15.6,127.0,2.8,3.06,0.28,2.29,5.64,1.04,3.92,1065.0
1,13.2,1.78,2.14,11.2,100.0,2.65,2.76,0.26,1.28,4.38,1.05,3.4,1050.0
2,13.16,2.36,2.67,18.6,101.0,2.8,3.24,0.3,2.81,5.68,1.03,3.17,1185.0



# Train a ML model

In [3]:
clustering = DBSCAN(eps=3, min_samples=2).fit(X)
np.unique(clustering.labels_)

array([-1,  0,  1,  2])

In [4]:
kmeans = KMeans(n_clusters=3, random_state=0).fit(X)

# Define the game

In [5]:
d = X.shape[1]
n = 2**d - 2
d, n

(13, 8190)

In [6]:
fc = lambda x: kmeans.predict(x.reshape(1,-1))[0]

In [7]:
r_class, x_class = 0, 0
while x_class == r_class:
    idx_r, idx_x = np.random.choice(np.arange(len(X)), size=2, replace=False)
    r = X.iloc[idx_r,:]
    x = X.iloc[idx_x,:]
    r_class = fc(r.values)
    x_class = fc(x.values)
fc_class = lambda x: 1 if int(fc(x)) == int(x_class) else 0

In [8]:
print(r)
print()
print("Group Prediction for r: {0:.0f}".format(fc(r.values)))

alcohol                           14.75
malic_acid                         1.73
ash                                2.39
alcalinity_of_ash                 11.40
magnesium                         91.00
total_phenols                      3.10
flavanoids                         3.69
nonflavanoid_phenols               0.43
proanthocyanins                    2.81
color_intensity                    5.40
hue                                1.25
od280/od315_of_diluted_wines       2.73
proline                         1150.00
Name: 13, dtype: float64

Group Prediction for r: 1


In [9]:
print(x)
print()
print("Group Prediction for x: {0:.0f}".format(fc(x.values)))

alcohol                          12.70
malic_acid                        3.87
ash                               2.40
alcalinity_of_ash                23.00
magnesium                       101.00
total_phenols                     2.83
flavanoids                        2.55
nonflavanoid_phenols              0.43
proanthocyanins                   1.95
color_intensity                   2.57
hue                               1.19
od280/od315_of_diluted_wines      3.13
proline                         463.00
Name: 79, dtype: float64

Group Prediction for x: 2


# Exact Shapley Values

In [10]:
true_shap = ShapleyValues(x=x, fc=fc_class, r=r)

100%|██████████| 13/13 [00:25<00:00,  1.93s/it]


In [11]:
true_shap

alcohol                         0.0
malic_acid                      0.0
ash                             0.0
alcalinity_of_ash               0.0
magnesium                       0.0
total_phenols                   0.0
flavanoids                      0.0
nonflavanoid_phenols            0.0
proanthocyanins                 0.0
color_intensity                 0.0
hue                             0.0
od280/od315_of_diluted_wines    0.0
proline                         1.0
dtype: float64

# Approximation methods

## Monte Carlo 

In [12]:
mc_shap = MonteCarloShapley(x=x, fc=fc_class, r=r, n_iter=100)
mc_shap

100%|██████████| 100/100 [00:00<00:00, 254.03it/s]


alcohol                         0.0
malic_acid                      0.0
ash                             0.0
alcalinity_of_ash               0.0
magnesium                       0.0
total_phenols                   0.0
flavanoids                      0.0
nonflavanoid_phenols            0.0
proanthocyanins                 0.0
color_intensity                 0.0
hue                             0.0
od280/od315_of_diluted_wines    0.0
proline                         1.0
dtype: float64

## SGD

In [13]:
sgd_est = SGDshapley(d, C=y.max())
sgd_shap = sgd_est.sgd(x=x, fc=fc_class, r=r, n_iter=1000, step=.1, step_type="sqrt")
sgd_shap

100%|██████████| 1000/1000 [00:00<00:00, 2164.31it/s]


alcohol                         0.014268
malic_acid                      0.012615
ash                            -0.023293
alcalinity_of_ash               0.040183
magnesium                       0.012993
total_phenols                   0.007581
flavanoids                      0.005783
nonflavanoid_phenols           -0.001543
proanthocyanins                -0.006194
color_intensity                 0.007330
hue                            -0.009825
od280/od315_of_diluted_wines   -0.020466
proline                         0.960569
dtype: float64