In [10]:
import copy

import numpy as np
from scipy.special import binom

from games import ParameterizedSparseLinearModel
from approximators import SHAPIQEstimator, PermutationSampling
from approximators.regression import RegressionEstimator

### Set up the game
Here we use a sparse linear model because its ground-truth interaction values can be computed directly from the game.

In [11]:
# Set up the game: a parameterized sparse linear model over n players.
game = ParameterizedSparseLinearModel(
    # number of players
    n=30,
    # how the interactions should be distributed over the subset sizes
    weighting_scheme="uniform",
    # smallest and largest interaction size present in the model
    min_interaction_size=1,
    max_interaction_size=20,
    # number of interactions in the model
    n_interactions=100,
    # number of dummy (zero-weight) features, which are also not part of the interactions
    n_non_important_features=3,
)

# Handles used by the remaining cells of the notebook.
n = game.n
N = set(range(n))
game_name = game.game_name
game_fun = game.set_call

### Set up the interaction index
The interaction values will be computed for interactions of the order set below.

In [12]:
# Order of the interactions to estimate; passed to every approximator below
# and used to index their results.
interaction_order = 2

### Set up the baseline approximators
There are two baseline approaches for the interaction indices: permutation sampling, which is defined for SII and STI, and the weighted least-squares (regression) approach, which is defined for FSI.

In [13]:
# Permutation-sampling baselines for SII and STI; both share the same player
# set, interaction order and top-order setting and differ only in the index.
shapley_extractor_sii_permutation, shapley_extractor_sti_permutation = (
    PermutationSampling(
        N=N,
        order=interaction_order,
        interaction_type=index_name,
        top_order=True,
    )
    for index_name in ("SII", "STI")
)

# Regression-based baseline for FSI.
shapley_extractor_FSI_regression = RegressionEstimator(N=N, max_order=interaction_order)

# Look-up table: interaction index name -> baseline estimator.
baselines = dict(
    SII=shapley_extractor_sii_permutation,
    STI=shapley_extractor_sti_permutation,
    FSI=shapley_extractor_FSI_regression,
)

### Set up the SHAP-IQ approximators
SHAP-IQ is defined for all interaction indices that follow a general definition of Shapley interaction (SII, STI, and FSI). For more information, we refer to the full paper (the section about SI).

In [14]:
# One SHAP-IQ estimator per interaction index (SII, STI, FSI); the three
# instances are configured identically apart from the index they target.
shapley_extractor_sii, shapley_extractor_sti, shapley_extractor_FSI = (
    SHAPIQEstimator(
        N=N,
        order=interaction_order,
        interaction_type=index_name,
        top_order=True,
    )
    for index_name in ("SII", "STI", "FSI")
)

# Look-up table: interaction index name -> SHAP-IQ estimator.
approximators = dict(
    SII=shapley_extractor_sii,
    STI=shapley_extractor_sti,
    FSI=shapley_extractor_FSI,
)

### Run the approximators
Run the baseline and SHAP-IQ approximators. Also compute the ground truth interaction values.

In [15]:
# Budget handed to each approximator.
budget = 2**14

# Results keyed by interaction index name ("SII", "STI", "FSI").
baseline_results = {}
shap_iq_results = {}
ground_truth_results = {}

# Iterate over an ordered tuple rather than a set literal: set iteration order
# for strings can differ between kernel sessions (hash randomization), which
# would change the order in which the sampling-based estimators are run.
for interaction_type in ("SII", "STI", "FSI"):
    baseline = baselines[interaction_type]
    shap_iq = approximators[interaction_type]

    # run baseline method
    approx_value = baseline.approximate_with_budget(
        game_fun, budget
    )
    # deep-copy so the stored result is decoupled from the estimator's state
    baseline_results[interaction_type] = copy.deepcopy(approx_value)

    # run shap_iq method (same game callable as for the baseline)
    approx_value = shap_iq.compute_interactions_from_budget(
        game=game_fun,
        budget=budget
    )
    shap_iq_results[interaction_type] = copy.deepcopy(approx_value)

    # get ground truths (only possible this way with the sparse linear model,
    # otherwise we need to use brute force); the weights of the SHAP-IQ
    # estimator for the chosen order are passed as the gamma matrix
    ground_truth_results[interaction_type] = copy.deepcopy(
        game.exact_values(
            gamma_matrix=shap_iq.weights[interaction_order],
            min_order=interaction_order,
            max_order=interaction_order
        )
    )


Exact values: pre-computed weights: 100%|██████████| 60/60 [00:00<00:00, 7492.28it/s]

Exact values: Final computation: 100%|██████████| 43500/43500.0 [00:00<00:00, 874248.78it/s]

Exact values: pre-computed weights: 100%|██████████| 60/60 [00:00<00:00, 7332.27it/s]

Exact values: Final computation: 100%|██████████| 43500/43500.0 [00:00<00:00, 877629.85it/s]

Exact values: pre-computed weights: 100%|██████████| 60/60 [00:00<00:00, 7338.47it/s]

Exact values: Final computation: 100%|██████████| 43500/43500.0 [00:00<00:00, 876318.90it/s]

Exact values: pre-computed weights: 100%|██████████| 60/60 [00:00<00:00, 7312.25it/s]

Exact values: Final computation: 100%|██████████| 43500/43500.0 [00:00<00:00, 841468.38it/s]

Exact values: pre-computed weights: 100%|██████████| 60/60 [00:00<00:00, 8576.43it/s]

Exact values: Final computation: 100%|██████████| 43500/43500.0 [00:00<00:00, 825695.23it/s]

Exact values: pre-computed weights: 100%|██████████| 60/60 [00:00<00:00, 8306.65it/s]

Exact v

In [16]:
def mse(gt, approx, n_interactions=None):
    """Return the squared error between `approx` and `gt`, averaged over interactions.

    Args:
        gt: Ground-truth interaction values (NumPy array).
        approx: Approximated interaction values, same shape as `gt`.
        n_interactions: Number of interactions to average over. Defaults to
            ``binom(n, interaction_order)``, read from the notebook globals
            `n` and `interaction_order`.

    Returns:
        The sum of squared differences divided by `n_interactions`.
    """
    if n_interactions is None:
        # Fall back to the notebook-level configuration.
        n_interactions = binom(n, interaction_order)
    return np.sum((approx - gt) ** 2) / n_interactions

def mae(gt, approx, n_interactions=None):
    """Return the absolute error between `approx` and `gt`, averaged over interactions.

    Args:
        gt: Ground-truth interaction values (NumPy array).
        approx: Approximated interaction values, same shape as `gt`.
        n_interactions: Number of interactions to average over. Defaults to
            ``binom(n, interaction_order)``, read from the notebook globals
            `n` and `interaction_order`.

    Returns:
        The sum of absolute differences divided by `n_interactions`.
    """
    if n_interactions is None:
        # Fall back to the notebook-level configuration.
        n_interactions = binom(n, interaction_order)
    return np.sum(np.abs(approx - gt)) / n_interactions

### Results for SII

In [17]:
# Exact SII values of the chosen order; display only the top-left 10x10 corner.
ground_truth = ground_truth_results["SII"][interaction_order]
ground_truth[:10, :10]

array([[0.        , 1.23650319, 1.07479209, 1.05632822, 0.86215728,
        0.9446618 , 1.03340214, 0.88913369, 0.8498577 , 0.8805166 ],
       [0.        , 0.        , 0.74486269, 1.0310984 , 0.77919746,
        0.73013863, 1.01863339, 0.86770917, 0.5175948 , 0.93763025],
       [0.        , 0.        , 0.        , 0.77027028, 0.77482485,
        0.88655876, 0.93174988, 0.92494576, 0.49018877, 0.8109293 ],
       [0.        , 0.        , 0.        , 0.        , 0.66626615,
        0.8260678 , 0.92432102, 1.01415216, 0.698431  , 1.06286796],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.58425238, 0.73509963, 0.60408573, 0.54917033, 0.84351996],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.7493785 , 0.77388202, 0.82047558, 0.50759201],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.9099447 , 0.94495607, 0.81941716],
       [0.        , 0.        , 0.       

In [18]:
# SHAP-IQ estimate of the SII values; display only the top-left 10x10 corner.
shap_iq_approx = shap_iq_results["SII"][interaction_order]
shap_iq_approx[:10, :10]

array([[0.        , 1.31909109, 1.39808051, 0.84555952, 0.74241507,
        1.02096432, 0.97219863, 0.67994928, 0.68817321, 0.84853506],
       [0.        , 0.        , 0.86772691, 0.87638678, 0.52947907,
        0.83012873, 0.95561535, 0.65494701, 0.73001903, 0.91491088],
       [0.        , 0.        , 0.        , 0.77941076, 0.64953634,
        1.08944284, 1.06799816, 0.32607804, 0.63745497, 0.57655545],
       [0.        , 0.        , 0.        , 0.        , 0.65622055,
        0.816422  , 0.90826287, 0.96240934, 0.31597361, 1.10986421],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.7551389 , 0.22724817, 0.74913297, 0.68719454, 1.27239085],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.66057935, 0.91685942, 0.62835133, 0.41184365],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 1.06907602, 0.90650573, 0.77312821],
       [0.        , 0.        , 0.       

In [20]:
# Baseline estimate of the SII values; display only the top-left 10x10 corner.
baseline_approx = baseline_results["SII"][interaction_order]
baseline_approx[:10, :10]

array([[0.        , 1.34930083, 0.38059255, 0.13545948, 1.79239844,
        0.13459902, 2.15948525, 0.14697162, 1.52894947, 0.67460588],
       [0.        , 0.        , 1.10617419, 1.21184839, 0.5119647 ,
        0.        , 0.16858386, 1.62816078, 0.        , 1.40596464],
       [0.        , 0.        , 0.        , 0.12978002, 2.08612979,
        0.        , 2.8633546 , 3.07027763, 1.11649643, 3.07785069],
       [0.        , 0.        , 0.        , 0.        , 0.67228359,
        2.09505193, 0.87384828, 0.08886677, 3.12422282, 2.01698577],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.66761961, 0.04944379, 0.        , 1.63189271, 0.35350594],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 1.8945357 , 1.17836557, 1.88596473, 0.03671737],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 2.11391735, 0.60256652, 0.07621577],
       [0.        , 0.        , 0.       

In [21]:
# Error of both estimators against the exact SII values (lower is better).
print("average MSE (SHAP-IQ):", mse(ground_truth, shap_iq_approx))
print("average MSE (baseline):", mse(ground_truth, baseline_approx))
print("average MAE (SHAP-IQ):", mae(ground_truth, shap_iq_approx))
# Label fixed: this row reports the baseline, not SHAP-IQ.
print("average MAE (baseline):", mae(ground_truth, baseline_approx))

average MSE (SHAP-IQ): 0.04115883676512268
average MSE (baseline): 0.6971960137361491
average MAE (SHAP-IQ): 0.15984924373098103
average MAE (baseline-IQ): 0.5931086751047642


### Results for STI
The tables show only a 10×10 selection, as this is easier to render in the browser.

In [23]:
# Exact STI values of the chosen order; display only the top-left 10x10 corner.
ground_truth = ground_truth_results["STI"][interaction_order]
ground_truth[:10, :10]

array([[0.        , 0.25364211, 0.16501766, 0.18063368, 0.14334586,
        0.15933285, 0.15816146, 0.13829299, 0.12497818, 0.13344251],
       [0.        , 0.        , 0.09765813, 0.16487622, 0.11123641,
        0.11047134, 0.15556689, 0.12116054, 0.07010268, 0.13684074],
       [0.        , 0.        , 0.        , 0.11276116, 0.12220416,
        0.13138412, 0.13381662, 0.12042494, 0.05704975, 0.10707327],
       [0.        , 0.        , 0.        , 0.        , 0.09207049,
        0.12334399, 0.14200411, 0.16128499, 0.10545164, 0.18542799],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.09553591, 0.11748126, 0.07928619, 0.08074468, 0.13665937],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.09596795, 0.09767513, 0.11072452, 0.06194814],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.12175044, 0.14243971, 0.13006958],
       [0.        , 0.        , 0.       

In [24]:
# SHAP-IQ estimate of the STI values; display only the top-left 10x10 corner.
shap_iq_approx = shap_iq_results["STI"][interaction_order]
shap_iq_approx[:10, :10]

array([[ 0.        ,  0.16847185,  0.18929198,  0.21109203,  0.17205177,
         0.13903629,  0.2244925 ,  0.12221301,  0.11930438,  0.11887942],
       [ 0.        ,  0.        ,  0.13409132,  0.28177785,  0.06431715,
         0.14330756,  0.21200852,  0.109316  ,  0.08039629,  0.21402872],
       [ 0.        ,  0.        ,  0.        ,  0.06686338,  0.27406434,
         0.14885208,  0.06236145,  0.10721628, -0.00044986,  0.07138431],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.06974314,
         0.08096034,  0.17284042,  0.14457847,  0.14633362,  0.24658543],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.05864616,  0.0602123 ,  0.0583028 ,  0.09518174,  0.12971862],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.14592489,  0.0685213 ,  0.0630305 ,  0.13919349],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.09545364

In [25]:
# Baseline estimate of the STI values; display only the top-left 10x10 corner.
baseline_approx = baseline_results["STI"][interaction_order]
baseline_approx[:10, :10]

array([[0.00000000e+00, 2.66085921e-01, 0.00000000e+00, 8.84525177e-02,
        0.00000000e+00, 1.50431459e-01, 3.19864888e-01, 9.73711874e-02,
        0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.69313037e-01,
        0.00000000e+00, 0.00000000e+00, 4.23671248e-01, 6.13649174e-01,
        1.01382373e-01, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.82271917e-01,
        8.39732575e-02, 8.39732575e-02, 2.05374027e-01, 0.00000000e+00,
        5.76947806e-02, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.54417833e-03, 1.54417833e-03, 1.48044666e+00, 1.25802889e-01,
        1.13552678e+00, 2.94052305e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 8.55174358e-02, 6.04312964e-02, 1.54417833e-03,
        9.83072561e-03, 2.02773711e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   

In [26]:
# Error of both estimators against the exact STI values (lower is better).
print("average MSE (SHAP-IQ):", mse(ground_truth, shap_iq_approx))
print("average MSE (baseline):", mse(ground_truth, baseline_approx))
print("average MAE (SHAP-IQ):", mae(ground_truth, shap_iq_approx))
# Label fixed: this row reports the baseline, not SHAP-IQ.
print("average MAE (baseline):", mae(ground_truth, baseline_approx))

average MSE (SHAP-IQ): 0.00234061225556173
average MSE (baseline): 0.075226097137372
average MAE (SHAP-IQ): 0.03890075205285072
average MAE (baseline-IQ): 0.1404959078910604


### Results for FSI

In [27]:
# Exact FSI values of the chosen order; display only the top-left 10x10 corner.
ground_truth = ground_truth_results["FSI"][interaction_order]
ground_truth[:10, :10]

array([[0.        , 0.58783788, 0.41535807, 0.44487845, 0.35589828,
        0.39182199, 0.39928809, 0.34674818, 0.31743008, 0.33901099],
       [0.        , 0.        , 0.25508103, 0.41222169, 0.28599455,
        0.27777504, 0.39528357, 0.3134065 , 0.18240645, 0.35194218],
       [0.        , 0.        , 0.        , 0.28479718, 0.30497129,
        0.33234033, 0.34149881, 0.31559734, 0.15211227, 0.27959836],
       [0.        , 0.        , 0.        , 0.        , 0.23795063,
        0.31161093, 0.3591254 , 0.39819042, 0.26742663, 0.45002624],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.23647315, 0.2935484 , 0.20702232, 0.20703851, 0.34285473],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.25212313, 0.25734284, 0.2886504 , 0.16381003],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.31755982, 0.36103627, 0.32587184],
       [0.        , 0.        , 0.       

In [28]:
# SHAP-IQ estimate of the FSI values; display only the top-left 10x10 corner.
shap_iq_approx = shap_iq_results["FSI"][interaction_order]
shap_iq_approx[:10, :10]

array([[0.        , 0.52719264, 0.38171535, 0.48494666, 0.35093591,
        0.41711513, 0.51720815, 0.46615193, 0.48833947, 0.37168011],
       [0.        , 0.        , 0.37884569, 0.15488513, 0.04884552,
        0.32595457, 0.50497713, 0.31397978, 0.16980087, 0.30211813],
       [0.        , 0.        , 0.        , 0.42111884, 0.41527104,
        0.28706555, 0.06126737, 0.23446596, 0.25289375, 0.00808755],
       [0.        , 0.        , 0.        , 0.        , 0.10257318,
        0.55369743, 0.42650147, 0.57821122, 0.20276433, 0.48030049],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.57012871, 0.28763865, 0.13650743, 0.29649346, 0.17831586],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.34387209, 0.26858226, 0.20827967, 0.19111303],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.15963745, 0.08708184, 0.39641345],
       [0.        , 0.        , 0.       

In [29]:
# Baseline estimate of the FSI values; display only the top-left 10x10 corner.
baseline_approx = baseline_results["FSI"][interaction_order]
baseline_approx[:10, :10]

array([[ 0.        ,  0.96057871,  0.38928932,  0.3497072 ,  0.44075145,
         0.7909039 ,  0.40573244,  0.68927468, -0.13948704,  0.11889659],
       [ 0.        ,  0.        ,  0.4803226 ,  0.34242715,  0.39062489,
         0.22965734,  0.98934986,  0.38092139,  0.25894158,  0.54886865],
       [ 0.        ,  0.        ,  0.        ,  0.25660292, -0.05825013,
        -0.02631917, -0.16970365,  0.15707   ,  0.13217898,  0.29723115],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.62652127,
         0.4640503 ,  0.68408583,  0.73809408,  0.03622637,  0.54570719],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.08361825,  0.06729974, -0.42904969,  0.44440527,  0.99778368],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.10151423,  0.15934159,  0.4103331 , -0.14598689],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.11811955

In [30]:
# Error of both estimators against the exact FSI values (lower is better).
print("average MSE (SHAP-IQ):", mse(ground_truth, shap_iq_approx))
print("average MSE (baseline):", mse(ground_truth, baseline_approx))
print("average MAE (SHAP-IQ):", mae(ground_truth, shap_iq_approx))
# Label fixed: this row reports the baseline, not SHAP-IQ.
print("average MAE (baseline):", mae(ground_truth, baseline_approx))

average MSE (SHAP-IQ): 0.012294752408365191
average MSE (baseline): 0.10143086426647954
average MAE (SHAP-IQ): 0.08697834994125356
average MAE (baseline-IQ): 0.2514210905402107
