In [1]:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


import shap
from alibi.explainers import KernelShap
from alibi.datasets import fetch_adult

from src.mai.topk.direct import Direct
from src.mai.topk.halving import Halving
from src.mai.topk.lucb import LUCB
from src.mai.topk.kl_lucb import KL_LUCB
from src.distributions import Shap


%load_ext autoreload
%autoreload 2

## Constants

In [2]:
# Number of top features to report; also passed as `m` to every top-k algorithm below.
m = 5
# Passed as `eps` / `delta` to Direct, LUCB and KL_LUCB.
# NOTE(review): presumably the PAC tolerance and failure probability — confirm against src.mai.topk.
eps = delta = 0.1

## Dataset

In [3]:
# Switch between the synthetic benchmark (True) and the Adult census data (False).
# NOTE(review): the saved outputs further down show 12 arms, which looks like the
# Adult branch rather than the 150-feature synthetic one — confirm which is intended.
USE_SYNTHETIC_DATA = True

# Fixed seed so data generation and the train/test split are reproducible
# under Restart Kernel -> Run All.
DATA_SEED = 0

if USE_SYNTHETIC_DATA:
    n_samples = 10000
    n_features = 150
    # informative + redundant columns account for all 150 features
    n_informative = 100
    n_redundant = 50
    n_clusters_per_class = 10
    n_classes = 2

    X, y = make_classification(n_samples=n_samples,
                               n_features=n_features,
                               n_informative=n_informative,
                               n_redundant=n_redundant,
                               n_classes=n_classes,
                               hypercube=False,
                               n_clusters_per_class=n_clusters_per_class,
                               random_state=DATA_SEED)
else:
    # adult census dataset
    data = fetch_adult()
    X, y = data['data'], data['target']
    n_features = X.shape[-1]

# Same 80/20 split for either data source.
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.2,
                                                    random_state=DATA_SEED)

## Train model

In [4]:
# Fit a decision tree with default hyperparameters. random_state is pinned so
# the fitted tree (and the scores shown below) are the same on every re-run.
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)
# Displayed as (score on the training split, score on the test split).
clf.score(X_train, y_train), clf.score(X_test, y_test)

(0.9681357493857494, 0.8252725318593582)

In [5]:
# Smallest and largest hard label the tree predicts on the training split;
# these two values are later handed to the Shap arms as min_val / max_val.
# NOTE(review): the arms' predictor returns clipped probabilities rather than
# labels — confirm that label bounds are the intended value range.
train_labels = clf.predict(X_train)  # run inference once instead of once per bound
min_val, max_val = np.min(train_labels), np.max(train_labels)
min_val, max_val

(0, 1)

## Explain prediction

In [6]:
# First ten training rows; passed to explainer.fit() and to every Shap arm as `baseline`.
baseline = X_train[:10]
# Single instance to explain (slice keeps the leading batch axis).
# NOTE: this rebinds `X`, which held the full dataset until this cell.
X = X_test[:1]

### KernelShap

In [7]:
def predict_fn(x):
    """Return the classifier's class probabilities, clipped and renormalised.

    The output of ``clf.predict_proba(x)`` is clipped element-wise to
    [0.01, 0.99] and each row (last axis) is then divided by its sum so the
    row adds up to one again.
    """
    clipped = np.clip(clf.predict_proba(x), 0.01, 0.99)
    row_totals = clipped.sum(axis=-1, keepdims=True)
    return clipped / row_totals

# Wrap the probability predictor in a KernelShap explainer and fit it on the
# baseline rows in one chained expression.
kernel_shap_kwargs = dict(predictor=predict_fn,
                          link='identity',
                          task='classification')
explainer = KernelShap(**kernel_shap_kwargs).fit(baseline)

In [8]:
# Re-explain the same instance with growing sample budgets (2**i) and print
# which m features receive the largest attributions at each budget.
for i in (5, 10, 12, 14, 16):
    exp = explainer.explain(X=X, nsamples=2**i, l1_reg=False)

    # NOTE(review): shap_values[0][0] is presumably "first output class, first
    # instance" — confirm against the alibi Explanation layout.
    attributions = exp.shap_values[0][0]
    indices = np.argsort(attributions)[-m:]  # ascending sort: the last m are the largest
    vals = attributions[indices]  # not read again in the visible cells
    print(f'#samples: 2^{i}; indices={np.sort(indices)}')

  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^5; indices=[1 4 5 6 7]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^10; indices=[1 6 7 8 9]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^12; indices=[1 5 6 7 9]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^14; indices=[1 5 6 7 9]


  0%|          | 0/1 [00:00<?, ?it/s]

#samples: 2^16; indices=[1 5 6 7 9]


### MAI

In [9]:
def simple_predict_fn(x):
    """Return only column 0 of ``predict_fn(x)`` as a 1-D array per row."""
    return predict_fn(x)[:, 0]

# One Shap arm per feature index; every arm shares the same instance,
# baseline rows, predictor and value bounds.
shared_arm_kwargs = dict(X=X,
                         baseline=baseline,
                         predictor=simple_predict_fn,
                         min_val=min_val,
                         max_val=max_val)
arms = [Shap(feature=i, **shared_arm_kwargs) for i in range(n_features)]

In [10]:
# Run the Direct algorithm over the arms and display the returned indices in ascending order.
direct_kwargs = dict(arms=arms, m=m, eps=eps, delta=delta)
algo = Direct(**direct_kwargs)
ret_indices = algo.play()
np.sort(ret_indices)

100%|██████████| 12/12 [00:00<00:00, 37.76it/s]


array([1, 5, 6, 7, 9])

In [11]:
# NOTE(review): the Halving run is commented out and the reason is not recorded
# here — TODO either re-enable it or remove this cell.
# algo = Halving(arms=arms, m=m, eps=eps, delta=delta)
# ret_indices = algo.play()
# np.sort(ret_indices)

In [12]:
# Run LUCB with a batch size of 1000 and display the returned indices in ascending order.
# NOTE(review): the same `arms` objects are reused by every algorithm cell; if
# Shap arms keep sampling state between runs, results depend on cell order — confirm.
lucb_kwargs = dict(arms=arms, m=m, eps=eps, delta=delta, batch_size=1000)
algo = LUCB(**lucb_kwargs)
ret_indices = algo.play()
np.sort(ret_indices)

first iteration: 100%|██████████| 12/12 [00:00<00:00, 41.23it/s]
5it [00:00, 20.91it/s]


array([5, 6, 7, 8, 9])

In [16]:
# Run KL_LUCB with a batch size of 1000 and display the returned indices in ascending order.
kl_lucb_kwargs = dict(arms=arms, m=m, eps=eps, delta=delta, batch_size=1000)
algo = KL_LUCB(**kl_lucb_kwargs)
ret_indices = algo.play()
np.sort(ret_indices)

first iteration: 100%|██████████| 12/12 [00:00<00:00, 43.03it/s]
0it [00:00, ?it/s]


array([1, 5, 6, 7, 9])