In [1]:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


import shap
from alibi.explainers import KernelShap
from alibi.datasets import fetch_adult

from src.mai.topk.direct import Direct
from src.mai.topk.halving import Halving
from src.mai.topk.lucb import LUCB
from src.mai.topk.kl_lucb import KL_LUCB
from src.distributions import Shap


%load_ext autoreload
%autoreload 2

## Constants

In [2]:
m = 5
eps = 0.1
delta = 0.1

## Dataset

In [3]:
if True:
    n_samples = 10000
    n_features = 150
    n_informative = 100
    n_redundant = 50
    n_clusters_per_class=10
    n_classes = 2
    
    X, y = make_classification(n_samples=n_samples, 
                               n_features=n_features, 
                               n_informative=n_informative, 
                               n_redundant=n_redundant,
                               n_classes=n_classes,
                               hypercube=False,
                               n_clusters_per_class=n_clusters_per_class)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
else:
    # adult census dataset
    data = fetch_adult()
    X, y = data['data'], data['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    n_features = X.shape[-1]

## Train model

In [4]:
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
clf.score(X_train, y_train), clf.score(X_test, y_test)

(1.0, 0.5345)

In [5]:
max_val = np.max(clf.predict(X_train))
min_val = np.min(clf.predict(X_train))
min_val, max_val

(0, 1)

## Explain prediction

In [6]:
X = X_test[:1]
baseline = X_train[:10]

### KernelShap

In [7]:
def predict_fn(x):
    y = np.clip(clf.predict_proba(x), 0.01, 0.99)
    return y / np.sum(y, axis=-1, keepdims=True)

explainer = KernelShap(predictor=predict_fn,
                       link='identity',
                       task='classification')

explainer = explainer.fit(baseline)

In [8]:
for i in [5, 10, 12, 14, 16]:
    exp = explainer.explain(X=X,
                            nsamples=2**i,
                            l1_reg=False)

    indices = np.argsort(exp.shap_values[0][0])[-m:]
    vals = exp.shap_values[0][0][indices]
    print(f'#samples: 2^{i}; indices={np.sort(indices)}')

  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^5; indices=[ 34  65  97 118 125]


Linear regression equation is singular, Moore-Penrose pseudoinverse is used instead of the regular inverse.
To use regular inverse do one of the following:
1) turn up the number of samples,
2) turn up the L1 regularization with num_features(N) where N is less than the number of samples,
3) group features together to reduce the number of inputs that need to be explained.


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^10; indices=[ 34  71  74  97 141]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^12; indices=[ 7 34 44 60 97]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^14; indices=[ 13  34  44  97 139]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^16; indices=[13 34 44 93 97]


### MAI

In [9]:
simple_predict_fn = lambda x: predict_fn(x)[:, 0]

arms = [Shap(feature=i,
             X=X,
             baseline=baseline,
             predictor=simple_predict_fn,
             min_val=min_val,
             max_val=max_val) for i in range(n_features)]

In [10]:
algo = Direct(arms=arms, m=m, eps=eps, delta=delta)
ret_indices = algo.play()
np.sort(ret_indices)

100%|██████████| 150/150 [00:11<00:00, 13.02it/s]


array([ 7, 13, 34, 44, 97])

In [11]:
# algo = Halving(arms=arms, m=m, eps=eps, delta=delta)
# ret_indices = algo.play()
# np.sort(ret_indices)

In [12]:
algo = LUCB(arms=arms, m=m, eps=eps, delta=delta, batch_size=100)
ret_indices = algo.play()
np.sort(ret_indices)

first iteration: 100%|██████████| 150/150 [00:00<00:00, 204.01it/s]
208it [00:01, 107.57it/s]


array([ 13,  34,  93,  97, 139])

In [15]:
algo = KL_LUCB(arms=arms, m=m, eps=eps, delta=delta, batch_size=100)
ret_indices = algo.play()
np.sort(ret_indices)

first iteration: 100%|██████████| 150/150 [00:00<00:00, 223.20it/s]
36it [00:00, 93.35it/s]


array([13, 34, 60, 93, 97])