In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data import RandomSampler, BatchSampler
from torch.distributions.categorical import Categorical

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

from copy import deepcopy
from tqdm.auto import tqdm
import itertools
import shap 
print(f"SHAP version: {shap.__version__}")

import numpy as np
import pickle
import os.path
import lightgbm as lgb
import matplotlib.pyplot as plt
import time
import random
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

import collections

import keras
print(keras.__version__)

from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Flatten, Activation, Input, Conv2D, MaxPooling2D
from keras import regularizers
from keras import backend as K

## Utils Functions

In [None]:
class MaskLayer1d(nn.Module):
    """Mask layer for inputs whose mask has the same layout as the input.

    Positions where the mask is 1 keep their input value; positions where it
    is 0 are replaced by ``value``.

    Args:
      value: replacement value for masked-out positions.
      append: if truthy, the mask is concatenated to the masked input
        along ``dim=1``.
    """

    def __init__(self, value, append):
        super().__init__()
        self.value = value
        self.append = append

    def forward(self, input_tuple):
        """Apply the mask to ``input_tuple = (x, S)`` and return the result."""
        features, mask = input_tuple
        masked = features * mask + self.value * (1 - mask)
        if not self.append:
            return masked
        return torch.cat((masked, mask), dim=1)


class MaskLayer2d(nn.Module):
    """Mask layer that accepts a mask with or without a channel axis.

    A 3-dimensional mask gets a singleton axis inserted at position 1 before
    it is applied. Positions where the mask is 1 keep their input value;
    positions where it is 0 are replaced by ``value``.

    Args:
      value: replacement value for masked-out positions.
      append: if truthy, the (possibly unsqueezed) mask is concatenated to
        the masked input along ``dim=1``.
    """

    def __init__(self, value, append):
        super().__init__()
        self.value = value
        self.append = append

    def forward(self, input_tuple):
        """Apply the mask to ``input_tuple = (x, S)`` and return the result."""
        features, mask = input_tuple
        # A mask without the axis at position 1 is expanded so it broadcasts.
        mask = mask.unsqueeze(1) if len(mask.shape) == 3 else mask
        masked = features * mask + self.value * (1 - mask)
        if not self.append:
            return masked
        return torch.cat((masked, mask), dim=1)


class KLDivLoss(nn.Module):
    """KL divergence loss that takes unnormalized predictions.

    Wraps ``nn.KLDivLoss`` and applies ``log_softmax`` over ``dim=1`` to the
    prediction before delegating to it.

    Args:
      reduction: forwarded to ``nn.KLDivLoss``.
      log_target: forwarded to ``nn.KLDivLoss``.
    """

    def __init__(self, reduction='batchmean', log_target=False):
        super().__init__()
        self.kld = nn.KLDivLoss(reduction=reduction, log_target=log_target)

    def forward(self, pred, target):
        """Return the wrapped KL loss of ``log_softmax(pred)`` against ``target``."""
        log_probs = pred.log_softmax(dim=1)
        return self.kld(log_probs, target)


class DatasetRepeat(Dataset):
    """Combine several datasets, cycling the shorter ones.

    The combined length is the length of the longest dataset. Item ``i`` is
    the concatenation of ``dset[i % len(dset)]`` over all datasets, so
    shorter datasets wrap around.

    Args:
      datasets: list of ``Dataset`` objects.
    """

    def __init__(self, datasets):
        assert all(isinstance(dset, Dataset) for dset in datasets)
        lengths = [len(dset) for dset in datasets]

        self.dsets = datasets
        self.items = lengths
        # The longest dataset defines how many items are exposed.
        self.num_items = np.max(lengths)

    def __getitem__(self, index):
        assert 0 <= index < self.num_items
        merged = []
        for dset, size in zip(self.dsets, self.items):
            # Wrap the index so shorter datasets repeat.
            merged.extend(dset[index % size])
        return tuple(merged)

    def __len__(self):
        return self.num_items


class DatasetInputOnly(Dataset):
    """View of a dataset that exposes only the first element of each item.

    Args:
      dataset: the wrapped ``Dataset``.
    """

    def __init__(self, dataset):
        assert isinstance(dataset, Dataset)
        self.dataset = dataset

    def __getitem__(self, index):
        # Keep the 1-tuple shape so loops can unpack with `for (x,) in loader`.
        first = self.dataset[index][0]
        return (first,)

    def __len__(self):
        return len(self.dataset)


class UniformSampler:
    """Sample binary subset masks with a uniformly drawn subset size.

    The number of included players is drawn per row from
    ``{0, ..., num_players}``; the included positions are then shuffled.

    Args:
      num_players: number of players (mask width).
    """

    def __init__(self, num_players):
        self.num_players = num_players

    def sample(self, batch_size):
        """Return a ``(batch_size, num_players)`` float32 tensor of 0/1 masks."""
        masks = torch.ones(batch_size, self.num_players, dtype=torch.float32)
        counts = (torch.rand(batch_size) * (self.num_players + 1)).int()
        # TODO ideally avoid for loops
        # TODO ideally pass buffer to assign samples in place
        for row in range(batch_size):
            # Keep the first `counts[row]` ones, then scatter them randomly.
            masks[row, counts[row]:] = 0
            shuffle = torch.randperm(self.num_players)
            masks[row] = masks[row, shuffle]

        return masks


class ShapleySampler:
    """Sample binary subset masks with size weights ``1 / (k * (n - k))``.

    Subset sizes ``k`` range over ``1 .. num_players - 1`` and are drawn from
    a categorical distribution with the normalized weights above; the
    included positions are then shuffled.

    Args:
      num_players: number of players (mask width).
    """

    def __init__(self, num_players):
        sizes = torch.arange(1, num_players)
        weights = 1 / (sizes * (num_players - sizes))
        self.categorical = Categorical(probs=weights / torch.sum(weights))
        self.num_players = num_players
        # Row k-1 holds k leading ones: a template mask for a subset of size k.
        self.tril = torch.tril(
            torch.ones(num_players - 1, num_players, dtype=torch.float32),
            diagonal=0)

    def sample(self, batch_size, paired_sampling):
        """Return a ``(batch_size, num_players)`` tensor of 0/1 masks.

        Args:
          batch_size: number of masks to draw.
          paired_sampling: if truthy, every odd row is the complement of the
            row before it instead of an independent shuffle.
        """
        sizes = 1 + self.categorical.sample([batch_size])
        masks = self.tril[sizes - 1]
        # TODO ideally avoid for loops
        for row in range(batch_size):
            if paired_sampling and row % 2 == 1:
                masks[row] = 1 - masks[row - 1]
                continue
            masks[row] = masks[row, torch.randperm(self.num_players)]
        return masks

## Surrogate - Code

In [None]:
def validate(surrogate, loss_fn, data_loader):
    """Compute the sample-weighted mean loss of a surrogate over a loader.

    Args:
      surrogate: callable ``surrogate(x, S)`` that also exposes the underlying
        network as ``surrogate.surrogate`` (used to find the device).
      loss_fn: callable ``loss_fn(pred, y)`` returning a batch-mean loss.
      data_loader: iterable yielding ``(x, y, S)`` batches.

    Returns:
      The running mean of the batch losses weighted by batch size. This is
      the integer ``0`` when ``data_loader`` yields no batches.
    """
    with torch.no_grad():
        # Setup.
        device = next(surrogate.surrogate.parameters()).device
        mean_loss = 0
        N = 0

        for x, y, S in data_loader:
            x = x.to(device)
            y = y.to(device)
            S = S.to(device)
            pred = surrogate(x, S)
            loss = loss_fn(pred, y)
            # Incremental mean update, so a smaller last batch is weighted correctly.
            N += len(x)
            mean_loss += len(x) * (loss - mean_loss) / N

    return mean_loss


def generate_labels(dataset, model, batch_size):
    """Run ``model`` over ``dataset`` in batches and return its outputs.

    Args:
      dataset: dataset whose items are 1-tuples ``(x,)``.
      model: callable applied to each batch. If it is a ``torch.nn.Module``
        the batch is moved to the device of its first parameter; otherwise
        the CPU is used.
      batch_size: batch size for the internal ``DataLoader``.

    Returns:
      All batch outputs moved to the CPU and concatenated along dim 0.
    """
    if isinstance(model, torch.nn.Module):
        device = next(model.parameters()).device
    else:
        device = torch.device('cpu')

    loader = DataLoader(dataset, batch_size=batch_size)
    with torch.no_grad():
        outputs = [model(x.to(device)).cpu() for (x,) in loader]

    return torch.cat(outputs)


class Surrogate:
    """Wrapper around a surrogate network evaluated on masked inputs.

    The wrapped network is called with an ``(x, S)`` tuple, where ``S`` is a
    mask over features. When ``groups`` is given, masks are defined over
    groups (players) and expanded to feature level before the call.

    Args:
      surrogate: network called as ``surrogate((x, S))``.
      num_features: number of input features.
      groups: optional list of index collections. Together they must cover
        ``0 .. num_features - 1`` exactly once each.
    """

    def __init__(self, surrogate, num_features, groups=None):
        # Store surrogate model.
        self.surrogate = surrogate

        # Store feature groups.
        if groups is None:
            self.num_players = num_features
            self.groups_matrix = None
        else:
            # Verify groups.
            inds_list = []
            for group in groups:
                inds_list += list(group)
            assert np.all(np.sort(inds_list) == np.arange(num_features))

            # Map groups to features: row i is the indicator vector of group i.
            self.num_players = len(groups)
            device = next(surrogate.parameters()).device
            self.groups_matrix = torch.zeros(
                len(groups), num_features, dtype=torch.float32, device=device)
            for i, group in enumerate(groups):
                self.groups_matrix[i, group] = 1

    def train_original_model(self,
                             train_data,
                             val_data,
                             original_model,
                             batch_size,
                             max_epochs,
                             loss_fn,
                             validation_samples=1,
                             validation_batch_size=None,
                             lr=1e-3,
                             min_lr=1e-5,
                             lr_factor=0.5,
                             weight_decay=0.01,
                             lookback=5,
                             training_seed=None,
                             validation_seed=None,
                             bar=False,
                             verbose=False):
        """Train the surrogate to match ``original_model`` on masked inputs.

        Each training batch is masked with subsets from ``UniformSampler``
        and the surrogate output is compared to ``original_model(x)`` with
        ``loss_fn``. Validation uses a fixed set of masks drawn once up
        front. The weights of the best validation epoch are restored at the
        end and the validation history is stored in ``self.loss_list``
        (its first entry is the initial sentinel value).

        Args:
          train_data: ``np.ndarray``, ``torch.Tensor`` or ``Dataset``.
          val_data: ``np.ndarray``, ``torch.Tensor`` or ``Dataset``.
          original_model: callable producing the training targets.
          batch_size: training batch size.
          max_epochs: maximum number of epochs.
          loss_fn: callable ``loss_fn(pred, target)``.
          validation_samples: number of masks per validation example.
          validation_batch_size: validation batch size; defaults to
            ``batch_size``.
          lr: initial AdamW learning rate.
          min_lr: lower bound for the learning-rate scheduler.
          lr_factor: multiplicative decay used by ``ReduceLROnPlateau``.
          weight_decay: AdamW weight decay.
          lookback: epochs without improvement before stopping early; the
            scheduler patience is ``lookback // 2``.
          training_seed: optional seed set right before the training loop.
          validation_seed: optional seed set before drawing validation masks.
          bar: whether to wrap the batch loop in a ``tqdm`` progress bar.
          verbose: whether to print per-epoch progress.

        Raises:
          ValueError: if ``train_data`` or ``val_data`` has an unsupported type.
        """
        # Set up train dataset.
        if isinstance(train_data, np.ndarray):
            train_data = torch.tensor(train_data, dtype=torch.float32)

        if isinstance(train_data, torch.Tensor):
            train_set = TensorDataset(train_data)
        elif isinstance(train_data, Dataset):
            train_set = train_data
        else:
            raise ValueError('train_data must be either tensor or a '
                             'PyTorch Dataset')

        # Set up train data loader. Sampling with replacement up to a whole
        # number of batches means every batch has exactly batch_size rows.
        random_sampler = RandomSampler(
            train_set, replacement=True,
            num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
        batch_sampler = BatchSampler(
            random_sampler, batch_size=batch_size, drop_last=True)
        train_loader = DataLoader(train_set, batch_sampler=batch_sampler)

        # Set up validation dataset.
        sampler = UniformSampler(self.num_players)
        if validation_seed is not None:
            torch.manual_seed(validation_seed)
        S_val = sampler.sample(len(val_data) * validation_samples)
        if validation_batch_size is None:
            validation_batch_size = batch_size

        if isinstance(val_data, np.ndarray):
            val_data = torch.tensor(val_data, dtype=torch.float32)

        if isinstance(val_data, torch.Tensor):
            # Generate validation labels.
            y_val = generate_labels(TensorDataset(val_data), original_model,
                                    validation_batch_size)
            y_val_repeat = y_val.repeat(
                validation_samples, *[1 for _ in y_val.shape[1:]])

            # Create dataset.
            val_data_repeat = val_data.repeat(validation_samples, 1)
            val_set = TensorDataset(val_data_repeat, y_val_repeat, S_val)
        elif isinstance(val_data, Dataset):
            # Generate validation labels.
            y_val = generate_labels(val_data, original_model,
                                    validation_batch_size)
            y_val_repeat = y_val.repeat(
                validation_samples, *[1 for _ in y_val.shape[1:]])

            # Create dataset; DatasetRepeat cycles val_data to match the
            # repeated labels and masks.
            val_set = DatasetRepeat(
                [val_data, TensorDataset(y_val_repeat, S_val)])
        else:
            raise ValueError('val_data must be either tuple of tensors or a '
                             'PyTorch Dataset')

        val_loader = DataLoader(val_set, batch_size=validation_batch_size)

        # Setup for training.
        surrogate = self.surrogate
        device = next(surrogate.parameters()).device
        optimizer = optim.AdamW(surrogate.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr,verbose=verbose)
        best_loss = 100000000
        best_epoch = 0
        best_model = deepcopy(surrogate)
        loss_list = [best_loss]
        if training_seed is not None:
            torch.manual_seed(training_seed)

        for epoch in range(max_epochs):
            # Batch iterable.
            if bar:
                batch_iter = tqdm(train_loader, desc='Training epoch')
            else:
                batch_iter = train_loader

            for (x,) in batch_iter:
                # Prepare data.
                x = x.to(device)

                # Get original model prediction.
                with torch.no_grad():
                    y = original_model(x)

                # Generate subsets.
                S = sampler.sample(batch_size).to(device=device)

                # Make predictions.
                pred = self.__call__(x, S)
                loss = loss_fn(pred, y)

                # Optimizer step.
                loss.backward()
                optimizer.step()
                surrogate.zero_grad()

            # Evaluate validation loss.
            self.surrogate.eval()
            val_loss = validate(self, loss_fn, val_loader).item()
            self.surrogate.train()

            # Print progress.
            if verbose:
                print('----- Epoch = {} -----'.format(epoch + 1))
                print('Val loss = {:.4f}'.format(val_loss))
                print('')
            scheduler.step(val_loss)
            loss_list.append(val_loss)

            # Check if best model.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(surrogate)
                best_epoch = epoch
                if verbose:
                    print('New best epoch, loss = {:.4f}'.format(val_loss))
                    print('')
            elif epoch - best_epoch == lookback:
                if verbose:
                    print('Stopping early')
                break

        # Clean up. load_state_dict restores buffers (e.g. normalization
        # running statistics) together with the parameters, so the restored
        # network matches the best checkpoint exactly.
        surrogate.load_state_dict(best_model.state_dict())
        self.loss_list = loss_list
        self.surrogate.eval()

    def __call__(self, x, S):
        """Evaluate the surrogate on ``x`` with mask ``S`` over players."""
        if self.groups_matrix is not None:
            # Expand the group-level mask to a feature-level mask.
            S = torch.mm(S, self.groups_matrix)

        return self.surrogate((x, S))

## FastSHAP - Code

In [None]:
def additive_efficient_normalization(pred, grand, null):
    """Shift predictions so their sum over dim 1 equals ``grand - null``.

    The gap between ``grand - null`` and the current sum is split equally
    across the ``pred.shape[1]`` entries of dim 1.
    """
    num_entries = pred.shape[1]
    residual = (grand - null) - pred.sum(dim=1)
    # residual = residual.detach()
    return pred + residual.unsqueeze(1) / num_entries


def multiplicative_efficient_normalization(pred, grand, null):
    """Rescale predictions so their sum over dim 1 equals ``grand - null``.

    Every entry is multiplied by ``(grand - null) / sum(pred, dim=1)``.
    """
    current_total = pred.sum(dim=1)
    scale = (grand - null) / current_total
    # scale = scale.detach()
    return pred * scale.unsqueeze(1)


def evaluate_explainer(explainer, normalization, x, grand, null, num_players, inference=False):
    """Run the explainer on ``x`` and arrange its output per player.

    A 4-dimensional explainer output is flattened to
    ``(len(x), -1, num_players)`` and permuted so players sit on dim 1; any
    other output is reshaped to ``(len(x), num_players, -1)``.

    Args:
      explainer: callable applied to ``x``.
      normalization: optional callable ``normalization(pred, grand, null)``.
      x: input batch.
      grand: passed through to ``normalization``.
      null: passed through to ``normalization``.
      num_players: number of players.
      inference: if true, return only the values (restored to the explainer's
        original 4-d shape when applicable).

    Returns:
      ``pred`` when ``inference`` is true, otherwise ``(pred, total)`` where
      ``total`` is the sum over players taken before normalization.
    """
    raw = explainer(x)
    batch = len(x)

    if len(raw.shape) == 4:
        # Image.
        image_shape = raw.shape
        values = raw.reshape(batch, -1, num_players).permute(0, 2, 1)
    else:
        # Tabular.
        image_shape = None
        values = raw.reshape(batch, num_players, -1)

    # For pre-normalization efficiency gap.
    total = values.sum(dim=1)

    if normalization:
        values = normalization(values, grand, null)

    if not inference:
        return values, total

    # Undo the permute/flatten for image-shaped outputs.
    if image_shape is not None:
        values = values.permute(0, 2, 1).reshape(image_shape)
    return values


def calculate_grand_coalition(dataset, imputer, batch_size, link, device, num_workers):
    """Evaluate the imputer with every player included, for all examples.

    Args:
      dataset: ``Dataset`` whose items are 1-tuples ``(x,)``, or a plain
        ``torch.Tensor`` of inputs (wrapped in a ``TensorDataset``).
      imputer: callable ``imputer(x, S)`` exposing ``num_players``.
      batch_size: batch size for the internal ``DataLoader``.
      link: callable applied to the imputer output.
      device: device the inputs and the all-ones mask are placed on.
      num_workers: forwarded to ``DataLoader``.

    Returns:
      Concatenated outputs for all examples; a 1-d result is expanded to a
      column of shape ``(N, 1)``.
    """
    # A bare tensor would be batched without the tuple wrapper, so the
    # `(x,)` unpacking below would split the batch dimension instead.
    if isinstance(dataset, torch.Tensor):
        dataset = TensorDataset(dataset)

    ones = torch.ones(batch_size, imputer.num_players, dtype=torch.float32, device=device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        pin_memory=True, num_workers=num_workers)
    with torch.no_grad():
        grand = []
        for (x,) in loader:
            # The last batch may be smaller, so trim the mask to match.
            grand.append(link(imputer(x.to(device), ones[:len(x)])))

        # Concatenate and return.
        grand = torch.cat(grand)
        if len(grand.shape) == 1:
            grand = grand.unsqueeze(1)

    return grand


def generate_validation_data(val_set, imputer, validation_samples, sampler, batch_size, link, device, num_workers):
    """Draw fixed validation coalitions and evaluate the imputer on them.

    Args:
      val_set: validation ``Dataset``.
      imputer: callable ``imputer(x, S)`` exposing ``num_players``.
      validation_samples: number of coalitions per validation example.
      sampler: object with ``sample(n, paired_sampling=...)``.
      batch_size: batch size for the internal ``DataLoader``.
      link: callable applied to the imputer output.
      device: device used for the imputer evaluation.
      num_workers: forwarded to ``DataLoader``.

    Returns:
      ``(val_S, val_values)``: the masks reshaped to
      ``(len(val_set), validation_samples, num_players)`` and the matching
      imputer outputs on the CPU, stacked along dim 1.
    """
    num_examples = len(val_set)

    # Generate coalitions.
    val_S = sampler.sample(
        validation_samples * num_examples, paired_sampling=True).reshape(
        num_examples, validation_samples, imputer.num_players)

    # Get values. These are fixed targets, so no autograd graph is needed.
    val_values = []
    with torch.no_grad():
        for i in range(validation_samples):
            # Set up data loader pairing each example with its i-th coalition.
            dset = DatasetRepeat([val_set, TensorDataset(val_S[:, i])])
            loader = DataLoader(dset, batch_size=batch_size, shuffle=False,
                                pin_memory=True, num_workers=num_workers)
            values = [
                link(imputer(x.to(device), S.to(device))).detach().cpu()
                for x, S in loader
            ]
            val_values.append(torch.cat(values))

    val_values = torch.stack(val_values, dim=1)
    return val_S, val_values


def validate_explainer(val_loader, imputer, explainer, null, link, normalization):
    """Compute the sample-weighted mean MSE of the explainer on validation data.

    For every batch the explainer output is combined with the coalition
    masks as ``null + S @ pred`` and compared with the precomputed values.

    Args:
      val_loader: iterable yielding ``(x, grand, S, values)`` batches.
      imputer: object exposing ``num_players``.
      explainer: module being validated (its first parameter sets the device).
      null: null-coalition value added to the approximation.
      link: unused here; kept for the call signature.
      normalization: optional normalization passed to ``evaluate_explainer``.

    Returns:
      Running mean of the batch MSE weighted by batch size.
    """
    device = next(explainer.parameters()).device
    mse = nn.MSELoss()
    running_mean = 0
    seen = 0

    with torch.no_grad():
        for x, grand, S, values in val_loader:
            # Move to device.
            x, S = x.to(device), S.to(device)
            grand, values = grand.to(device), values.to(device)

            # Evaluate explainer.
            pred, _ = evaluate_explainer(
                explainer, normalization, x, grand, null, imputer.num_players)

            # Calculate loss.
            approx = null + torch.matmul(S, pred)
            batch_loss = mse(approx, values)

            # Update average.
            seen += len(x)
            running_mean += len(x) * (batch_loss - running_mean) / seen

    return running_mean


class FastSHAP:
    """Trains an explainer network to output per-player attributions.

    Args:
      explainer: network mapping inputs to attribution values.
      imputer: callable ``imputer(x, S)`` exposing ``num_players``; it
        provides the value of each coalition ``S``.
      normalization: ``None``/``'none'``, ``'additive'`` or
        ``'multiplicative'``.
      link: ``None``/``'none'`` (identity) or an ``nn.Module`` applied to the
        imputer output.

    Raises:
      ValueError: on an unsupported ``link`` or ``normalization``.
    """

    def __init__(self,
                 explainer,
                 imputer,
                 normalization='none',
                 link=None):
        # Set up explainer, imputer and link function.
        self.explainer = explainer
        self.imputer = imputer
        self.num_players = imputer.num_players
        self.null = None
        if link is None or link == 'none':
            self.link = nn.Identity()
        elif isinstance(link, nn.Module):
            self.link = link
        else:
            raise ValueError('unsupported link function: {}'.format(link))

        # Set up normalization.
        if normalization is None or normalization == 'none':
            self.normalization = None
        elif normalization == 'additive':
            self.normalization = additive_efficient_normalization
        elif normalization == 'multiplicative':
            self.normalization = multiplicative_efficient_normalization
        else:
            raise ValueError('unsupported normalization: {}'.format(
                normalization))

    def train(self,
              train_data,
              val_data,
              batch_size,
              num_samples,
              max_epochs,
              lr=2e-4,
              min_lr=1e-5,
              lr_factor=0.5,
              weight_decay=0.01,
              eff_lambda=0,
              paired_sampling=True,
              validation_samples=None,
              lookback=5,
              training_seed=None,
              validation_seed=None,
              num_workers=0,
              bar=False,
              verbose=False):
        """Train the explainer against coalition values from the imputer.

        For each batch, ``num_samples`` coalitions per example are drawn from
        ``ShapleySampler``; the explainer is fitted so that
        ``null + S @ pred`` matches the imputer value of each coalition
        (MSE, scaled by ``num_players``). The validation loss per epoch is
        appended to ``self.loss_list`` and the best epoch's weights are
        restored at the end.

        Args:
          train_data: ``np.ndarray``, ``torch.Tensor`` or ``Dataset``.
          val_data: ``np.ndarray``, ``torch.Tensor`` or ``Dataset``.
          batch_size: number of examples per training batch.
          num_samples: coalitions drawn per example.
          max_epochs: maximum number of epochs.
          lr: initial AdamW learning rate.
          min_lr: lower bound for the learning-rate scheduler.
          lr_factor: multiplicative decay used by ``ReduceLROnPlateau``.
          weight_decay: AdamW weight decay.
          eff_lambda: weight of the extra penalty between the
            pre-normalization total and ``grand - null``; ``0`` disables it.
          paired_sampling: forwarded to the sampler during training.
          validation_samples: coalitions per validation example; defaults to
            ``num_samples``.
          lookback: epochs without improvement before stopping early; the
            scheduler patience is ``lookback // 2``.
          training_seed: optional seed set right before the training loop.
          validation_seed: optional seed set before drawing validation data.
          num_workers: forwarded to the data loaders.
          bar: whether to wrap the batch loop in a ``tqdm`` progress bar.
          verbose: whether to print per-epoch progress.

        Raises:
          ValueError: if ``train_data`` or ``val_data`` has an unsupported type.
        """
        # Set up explainer model.
        explainer = self.explainer
        num_players = self.num_players
        imputer = self.imputer
        link = self.link
        normalization = self.normalization
        explainer.train()
        device = next(explainer.parameters()).device

        # Verify other arguments.
        if validation_samples is None:
            validation_samples = num_samples

        # Set up train dataset.
        if isinstance(train_data, np.ndarray):
            x_train = torch.tensor(train_data, dtype=torch.float32)
            train_set = TensorDataset(x_train)
        elif isinstance(train_data, torch.Tensor):
            train_set = TensorDataset(train_data)
        elif isinstance(train_data, Dataset):
            train_set = train_data
        else:
            raise ValueError('train_data must be np.ndarray, torch.Tensor or Dataset')

        # Set up validation dataset.
        if isinstance(val_data, np.ndarray):
            x_val = torch.tensor(val_data, dtype=torch.float32)
            val_set = TensorDataset(x_val)
        elif isinstance(val_data, torch.Tensor):
            val_set = TensorDataset(val_data)
        elif isinstance(val_data, Dataset):
            val_set = val_data
        else:
            raise ValueError('val_data must be np.ndarray, torch.Tensor or Dataset')

        # Grand coalition value.
        grand_train = calculate_grand_coalition(
            train_set, imputer, batch_size * num_samples, link, device,
            num_workers).cpu()
        grand_val = calculate_grand_coalition(
            val_set, imputer, batch_size * num_samples, link, device,
            num_workers).cpu()

        # Null coalition: imputer value with every player masked out,
        # evaluated on the first training example.
        with torch.no_grad():
            zeros = torch.zeros(1, num_players, dtype=torch.float32, device=device)
            null = link(imputer(train_set[0][0].unsqueeze(0).to(device), zeros))
            if len(null.shape) == 1:
                null = null.reshape(1, 1)
        self.null = null

        # Set up train loader.
        train_set = DatasetRepeat([train_set, TensorDataset(grand_train)])
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=num_workers)

        # Generate validation data.
        sampler = ShapleySampler(num_players)
        if validation_seed is not None:
            torch.manual_seed(validation_seed)
        val_S, val_values = generate_validation_data(val_set, imputer, validation_samples, sampler, batch_size * num_samples, link, device, num_workers)

        # Set up val loader.
        val_set = DatasetRepeat(
            [val_set, TensorDataset(grand_val, val_S, val_values)])
        val_loader = DataLoader(val_set, batch_size=batch_size * num_samples, pin_memory=True, num_workers=num_workers)

        # Setup for training.
        loss_fn = nn.MSELoss()
        optimizer = optim.AdamW(explainer.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr, verbose=verbose)
        self.loss_list = []
        best_loss = np.inf
        best_epoch = -1
        best_model = None
        if training_seed is not None:
            torch.manual_seed(training_seed)

        for epoch in range(max_epochs):
            # Batch iterable.
            if bar:
                batch_iter = tqdm(train_loader, desc='Training epoch')
            else:
                batch_iter = train_loader

            for x, grand in batch_iter:
                # Sample S.
                S = sampler.sample(batch_size * num_samples, paired_sampling=paired_sampling)

                # Move to device.
                x = x.to(device)
                S = S.to(device)
                grand = grand.to(device)

                # Evaluate value function: each example is repeated
                # num_samples times so it lines up with its coalitions.
                x_tiled = x.unsqueeze(1).repeat(
                    1, num_samples, *[1 for _ in range(len(x.shape) - 1)]
                    ).reshape(batch_size * num_samples, *x.shape[1:])
                with torch.no_grad():
                    values = link(imputer(x_tiled, S))

                # Evaluate explainer.
                pred, total = evaluate_explainer(explainer, normalization, x, grand, null, num_players)

                # Calculate loss.
                S = S.reshape(batch_size, num_samples, num_players)
                values = values.reshape(batch_size, num_samples, -1)
                approx = null + torch.matmul(S, pred)
                loss = loss_fn(approx, values)
                if eff_lambda:
                    loss = loss + eff_lambda * loss_fn(total, grand - null)

                # Take gradient step.
                loss = loss * num_players
                loss.backward()
                optimizer.step()
                explainer.zero_grad()

            # Evaluate validation loss.
            explainer.eval()
            val_loss = num_players * validate_explainer(val_loader, imputer, explainer, null, link, normalization).item()
            explainer.train()

            # Save loss, print progress.
            if verbose:
                print('----- Epoch = {} -----'.format(epoch + 1))
                print(f'Val loss = {round(val_loss,6)}')
                print('')
            scheduler.step(val_loss)
            self.loss_list.append(val_loss)

            # Check for convergence.
            if self.loss_list[-1] < best_loss:
                best_loss = self.loss_list[-1]
                best_epoch = epoch
                best_model = deepcopy(explainer)
                if verbose:
                    print(f'New best epoch, loss = {round(val_loss,6)}')
                    print('')
            elif epoch - best_epoch == lookback:
                if verbose:
                    print('Stopping early at epoch = {}'.format(epoch))
                break

        # Copy best model. best_model stays None when no epoch improved on
        # the initial infinite loss (e.g. max_epochs == 0 or NaN losses); the
        # current weights are kept in that case. load_state_dict restores
        # buffers together with the parameters.
        if best_model is not None:
            explainer.load_state_dict(best_model.state_dict())
        explainer.eval()

    def shap_values(self, x):
        """Return the explainer's attributions for ``x`` as a NumPy array.

        Args:
          x: ``np.ndarray`` or ``torch.Tensor`` of inputs.

        Returns:
          The explainer output after optional normalization, as produced by
          ``evaluate_explainer(..., inference=True)``.

        Raises:
          ValueError: if ``x`` is neither an array nor a tensor.
        """
        # Data conversion.
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            raise ValueError('data must be np.ndarray or torch.Tensor')

        # Ensure null coalition is calculated.
        device = next(self.explainer.parameters()).device
        if self.null is None:
            with torch.no_grad():
                zeros = torch.zeros(1, self.num_players, dtype=torch.float32, device=device)
                null = self.link(self.imputer(x[:1].to(device), zeros))
            if len(null.shape) == 1:
                null = null.reshape(1, 1)
            self.null = null

        # Generate explanations.
        with torch.no_grad():
            # Calculate grand coalition (for normalization). The loader in
            # calculate_grand_coalition unpacks 1-tuples, so the tensor must
            # be wrapped in a dataset; otherwise the batch dimension itself
            # is unpacked, which fails for more than one row.
            if self.normalization:
                grand = calculate_grand_coalition(
                    TensorDataset(x.cpu()), self.imputer, len(x), self.link,
                    device, 0)
            else:
                grand = None

            # Evaluate explainer.
            x = x.to(device)
            pred = evaluate_explainer(
                self.explainer, self.normalization, x, grand, self.null,
                self.imputer.num_players, inference=True)

        return pred.cpu().data.numpy()

# Dataset

## Census

In [None]:
# Load the Adult (Census) data once and split it into train/val/test.
dataset="Census"
X_full, Y_full = shap.datasets.adult()
X_train, X_test, Y_train, Y_test = train_test_split(
    X_full, Y_full, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=0)

# Data scaling
num_features = X_train.shape[1]
feature_names = X_train.columns.tolist()
ss = StandardScaler()
# Fit on the raw array, matching the arrays passed to transform below, so
# the scaler does not record column names it will not see again.
ss.fit(X_train.values)
X_train_s = ss.transform(X_train.values)
X_val_s = ss.transform(X_val.values)
X_test_s = ss.transform(X_test.values)

# Shape of the full, unsplit feature table.
print(*X_full.shape)

## Magic


In [None]:
# Load the MAGIC data: every column but the last is a numeric feature, the
# last column is the class letter.
data = np.loadtxt('data/magic04.data', dtype=object, delimiter=',')
X = data[:, :-1].astype(float)
mapper = {"h": 0, "g": 1}
Y = [mapper[label] for label in data[:, -1].astype(str)]

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
feature_names = ["fLength","fWidth","fSize","fConc","fConc1","fAsym","fM3Long", "fM3Trans", "fAlpha", "fDist"]
ss = StandardScaler()
ss.fit(X_train)
X_train_s, X_val_s, X_test_s = (
    ss.transform(split) for split in (X_train, X_val, X_test))
dataset = "Magic"

## Credit

In [None]:
# Load the credit-card data. The frame gets its own name so that `dataset`
# only ever holds the dataset label string set at the end of the cell.
credit_df = pd.read_csv('data/credit_card.csv', sep=",")
credit_df = credit_df.drop("ID", axis=1)

# Target column and feature matrix.
Y = credit_df["DEFAULT_PAYMENT"].values
X = credit_df.drop("DEFAULT_PAYMENT", axis=1)
columns = X.columns
X = X.values
print(X.shape, Y.shape)

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
feature_names = columns
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)

# Class balance per split.
print(collections.Counter(Y_train))
print(collections.Counter(Y_val))
print(collections.Counter(Y_test))

dataset="Credit"

## Jannis

In [None]:
# Load the Jannis data and name its columns V0..V{n-1} plus a final "Class".
# NOTE(review): read_csv uses the first row of the file as the header here
# before the names are overwritten; confirm the file really has a header row.
jannis_df = pd.read_csv('data/jannis.txt', sep=",")
columns=[f"V{i}" for i in range(0,jannis_df.shape[1]-1)]
columns.append("Class")
jannis_df.columns=columns

Y=jannis_df["Class"].values
X=jannis_df.drop("Class",axis=1)
X=X.values

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
# `columns` ends with the "Class" label, which is not a feature; drop it so
# feature_names has exactly num_features entries.
feature_names = columns[:-1]
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)

print(num_features)
print(collections.Counter(Y_train))
print(collections.Counter(Y_val))
print(collections.Counter(Y_test))

dataset="Jannis"

## Bank

In [None]:
# Load the banknote data: every column but the last is a feature, the last
# column is the integer class label.
data = np.loadtxt('data/data_banknote_authentication.txt', dtype=object, delimiter=',')
# Cast features to float explicitly (as the Magic cell does) so X and its
# splits are numeric arrays rather than object arrays of strings.
X = data[:,:-1].astype(float)
Y = data[:,-1]
Y = [int(el) for el in Y]

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
feature_names = ["variance","skewness","curtosis","entropy"]
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)
dataset="Bank"

## Diabetes

In [None]:
# Load the diabetes data: every column but the last is a feature, the last
# column is the textual test outcome.
data = np.loadtxt('data/diabetes_data.txt', dtype=object, delimiter=',')
# Cast features to float explicitly (as the Magic cell does) so X and its
# splits are numeric arrays rather than object arrays of strings.
X = data[:,:-1].astype(float)
Y = data[:,-1]
mapper={"tested_negative": 0, "tested_positive": 1}
Y = [mapper[el] for el in Y]

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
feature_names = ["preg","plas","pres","skin","insu","mass","pedi", "age"]
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)
dataset="Diabetes"

## Mozilla

In [None]:
# Load the Mozilla data: the first column is skipped, the last column is the
# integer class label, and everything in between is a feature.
data = np.loadtxt('data/mozzilla_data.txt', dtype=object, delimiter=',')
# Cast features to float explicitly (as the Magic cell does) so X and its
# splits are numeric arrays rather than object arrays of strings.
X = data[:,1:-1].astype(float)
Y = data[:,-1]
Y = [int(el) for el in Y]

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
feature_names = ["start","end","event","size"]
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)
dataset="Mozilla"

## Phoneme

In [None]:
# Load the Phoneme data: every column but the last is a feature, the last
# column is the class code ('1' or '2').
data = np.loadtxt('data/phoneme.txt', dtype=object, delimiter=',')
# Cast features to float explicitly (as the Magic cell does) so X and its
# splits are numeric arrays rather than object arrays of strings.
X = data[:,:-1].astype(float)
Y = data[:,-1]
mapper={'1': 0, '2': 1}
Y = [mapper[el] for el in Y]

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
feature_names = ["v1","v2","v3","v4","v5"]
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)
dataset="Phoneme"

## Click

In [None]:
# Load the Click data: column 0 is the label, the remaining columns are
# features. The frame gets its own name so that `dataset` only ever holds
# the dataset label string set at the end of the cell.
click_df = pd.read_csv('data/click.arff', sep=",", header=None)
# Relative frequency of each label value.
print(click_df[0].value_counts()/len(click_df))
data=click_df.values
X=data[:,1:]
Y=data[:,0]

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
# The file is read without a header, so use generic names. Without this,
# feature_names would silently keep the value left by whichever dataset
# cell ran before this one.
feature_names = [f"V{i}" for i in range(num_features)]
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)
dataset="Click"

## Bankruptcy

In [None]:
# Load the data from bank.csv. The frame gets its own name so that `dataset`
# only ever holds the dataset label string set at the end of the cell.
bank_df = pd.read_csv('data/bank.csv', sep=";")
X=bank_df.drop("y",axis=1)
columns=X.columns

# Binary target from the "y" column.
mapper={"no":0,"yes":1}
Y=[mapper[el] for el in bank_df["y"]]

# Encode categorical variables on X with LabelEncoder
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()

# The single encoder is refitted per column, so after the loop it only
# holds the mapping of the last column.
for col in ['job', 'marital', 'education', 'default', 'housing',
       'loan', 'contact', 'month', 'poutcome']:
    X[col] = encoder.fit_transform(X[col])

X=X.values

# Two successive 80/20 splits: test first, then validation out of the rest.
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.2, random_state=7)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, X_test.shape)

# Data scaling (statistics come from the training split only)
num_features = X_train.shape[1]
print(num_features)
feature_names = columns
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)
X_test_s = ss.transform(X_test)
dataset="Bankruptcy"

## KDD Cup 99

In [None]:
# Column names of the KDD Cup 99 data file, in file order (the CSV itself
# is read without a header row below).
cols="""duration,
protocol_type,
service,
flag,
src_bytes,
dst_bytes,
land,
wrong_fragment,
urgent,
hot,
num_failed_logins,
logged_in,
num_compromised,
root_shell,
su_attempted,
num_root,
num_file_creations,
num_shells,
num_access_files,
num_outbound_cmds,
is_host_login,
is_guest_login,
count,
srv_count,
serror_rate,
srv_serror_rate,
rerror_rate,
srv_rerror_rate,
same_srv_rate,
diff_srv_rate,
srv_diff_host_rate,
dst_host_count,
dst_host_srv_count,
dst_host_same_srv_rate,
dst_host_diff_srv_rate,
dst_host_same_src_port_rate,
dst_host_srv_diff_host_rate,
dst_host_serror_rate,
dst_host_srv_serror_rate,
dst_host_rerror_rate,
dst_host_srv_rerror_rate"""

# Split the block above into clean names and append the label column.
columns = [name.strip() for name in cols.split(',') if name.strip()]
columns.append('target')
print(len(columns))

# Raw attack label -> coarse attack family.
attacks_types = {
    'normal': 'normal',
    'back': 'dos',
    'buffer_overflow': 'u2r',
    'ftp_write': 'r2l',
    'guess_passwd': 'r2l',
    'imap': 'r2l',
    'ipsweep': 'probe',
    'land': 'dos',
    'loadmodule': 'u2r',
    'multihop': 'r2l',
    'neptune': 'dos',
    'nmap': 'probe',
    'perl': 'u2r',
    'phf': 'r2l',
    'pod': 'dos',
    'portsweep': 'probe',
    'rootkit': 'u2r',
    'satan': 'probe',
    'smurf': 'dos',
    'spy': 'r2l',
    'teardrop': 'dos',
    'warezclient': 'r2l',
    'warezmaster': 'r2l',
}

path = "data/KDDcup99/kddcup.data.corrected"
df = pd.read_csv(path, names=columns)

# r[:-1] removes the last character of the raw label before the lookup
# (presumably a trailing '.' in the file -- TODO confirm against the data).
df['Attack Type'] = df.target.apply(lambda r: attacks_types[r[:-1]])

# Finding categorical features: everything that is not numeric, except the
# two label columns.
num_cols = df._get_numeric_data().columns
cate_cols = list(set(df.columns) - set(num_cols))
cate_cols.remove('target')
cate_cols.remove('Attack Type')

# Drop columns that contain NaN. The axis is passed by keyword: the
# positional form `dropna('columns')` is rejected by pandas >= 2.0.
df = df.dropna(axis='columns')

# Keep only columns with more than one distinct value. `.copy()` gives an
# independent frame so the column assignments below do not write into a
# slice of the original (SettingWithCopyWarning).
df = df.loc[:, df.nunique() > 1].copy()

# Feature columns removed from the model input.
df = df.drop(
    ['num_root',
     'srv_serror_rate',
     'srv_rerror_rate',
     'dst_host_srv_serror_rate',
     'dst_host_serror_rate',
     'dst_host_rerror_rate',
     'dst_host_srv_rerror_rate',
     'dst_host_same_srv_rate'],
    axis=1)

# Integer-encode the two categorical columns that are kept.
pmap = {'icmp':0,'tcp':1,'udp':2}
df['protocol_type'] = df['protocol_type'].map(pmap)

fmap = {'SF':0,'S0':1,'REJ':2,'RSTR':3,'RSTO':4,'SH':5 ,'S1':6 ,'S2':7,'RSTOS0':8,'S3':9 ,'OTH':10}
df['flag'] = df['flag'].map(fmap)

# 'service' is dropped rather than encoded; 'target' is superseded by
# 'Attack Type'.
df = df.drop(['service', 'target'], axis=1)
print(df.shape)

In [None]:
# Cut the first 6 * 77753 = 466518 rows of `df` into six consecutive,
# equally sized row blocks (df_d1 ... df_d6).
KDD_CHUNK_ROWS = 77753
df_d1, df_d2, df_d3, df_d4, df_d5, df_d6 = (
    df[k * KDD_CHUNK_ROWS:(k + 1) * KDD_CHUNK_ROWS] for k in range(6)
)

# Block used by the following cells, and the matching dataset tag.
df_d = df_d1
dataset = "KDD_1"

In [None]:
# Work on an independent copy: `df_d` is a row slice of `df`, and adding a
# column to a slice triggers SettingWithCopyWarning.
df_d = df_d.copy()

# Binary target: 0 if the attack type is 'normal', 1 otherwise.
df_d['Label'] = df_d['Attack Type'].apply(lambda x: 0 if x == 'normal' else 1)

Y = df_d[['Label']]
# Remove BOTH label columns from the features. Dropping only 'Attack Type'
# (as before) left 'Label' -- the target itself -- inside X.
# NOTE: this reduces num_features by one, so surrogate/explainer checkpoints
# saved for this dataset tag with the leaked column no longer match.
X = df_d.drop(['Attack Type', 'Label'], axis=1)
feature_names = X.columns

X = X.values

# Single 70/30 train/validation split (no separate test split is created
# in this cell).
X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.3, random_state=0)
num_features = X_train.shape[1]
print(X_train.shape, X_val.shape, num_features, feature_names)

# Scaler statistics come from the training split only.
ss = StandardScaler()
ss.fit(X_train)
X_train_s = ss.transform(X_train)
X_val_s = ss.transform(X_val)

## BTC

In [None]:
# Read the BTC 2019 CSV.
data_dir = 'data/BitcoinHistoryDataset'
data_path = os.path.join(data_dir, 'BTC-2019min.csv')
data = pd.read_csv(data_path, engine='c')

# Parse the timestamp and expose its calendar parts as separate columns
# (Year, Month, Day, Hour, Minute -- added in that order).
data['Date'] = pd.to_datetime(data['date'])
for field in ('year', 'month', 'day', 'hour', 'minute'):
    data[field.capitalize()] = getattr(data['Date'].dt, field)

# Chronological order, then remove the raw identifier columns
# (the parsed 'Date' column is kept).
data = data.sort_values('Date').drop(['unix', 'date', 'symbol'], axis=1)

old_columns = data.columns
    
def get_features(df):
    """Derive candle features and the binary 'buy/sell' target.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'open', 'high', 'low' and 'close'.
        Rows are used in their given order for the rolling and diff features.

    Returns
    -------
    pandas.DataFrame
        A new frame holding the input columns plus the derived features and
        a 'buy/sell' column (1 if 'close' rose versus the previous row,
        0 otherwise). Rows whose close-to-close difference is NaN (the first
        row) are removed.

    The input frame is left untouched: all columns are added to a copy.
    """
    # Copy up front so the caller's DataFrame is not mutated.
    df = df.copy()
    log_close = np.log(df['close'])

    df['price mean'] = df[['open', 'high', 'low', 'close']].mean(axis = 1)
    df['upper shadow'] = df['high'] - np.maximum(df['open'], df['close'])
    df['lower shadow'] = np.minimum(df['open'], df['close']) - df['low']
    df['spread'] = df['high'] - df['low']
    df['trade'] = df['close'] - df['open']
    df['open close LPC'] = np.log(df['close'] / df['open'])
    # Rolling means / log-returns are undefined for the first rows of the
    # window; those entries are filled with 0.
    df['10 period SMA'] = df['close'].rolling(10).mean().fillna(0)
    df['20 period SMA'] = df['close'].rolling(20).mean().fillna(0)
    df['5 period LR'] = log_close.diff(periods=5).fillna(0)
    df['10 period LR'] = log_close.diff(periods=10).fillna(0)
    df['log norm close'] = np.log(df['close'] + 1)/10

    # Target: sign of the close-to-close change (<= 0 -> 0, > 0 -> 1).
    df['buy/sell'] = df['close'].diff(periods=1)
    df = df.loc[df['buy/sell'].notna()].copy()
    df['buy/sell'] = df['buy/sell'].apply(lambda x: 0 if x<=0 else 1)
    return df

# Compute the derived features and the 'buy/sell' target.
data = get_features(data)

# Drop the raw price columns and the 'Year' column.
data = data.drop(['open', 'high', 'low', 'close','Year'], axis=1)

# 'Year' was just dropped, so filtering on data['Year'] would raise a
# KeyError; take the year from the 'Date' column, which is still present.
df_2019 = data[data['Date'].dt.year == 2019]

# One frame per calendar month (df_01 = January ... df_12 = December).
(df_01, df_02, df_03, df_04, df_05, df_06,
 df_07, df_08, df_09, df_10, df_11, df_12) = (
    data[data['Month'] == month] for month in range(1, 13)
)

In [None]:
# Month to run through the rest of the notebook, and the matching dataset tag.
df_y = df_01
dataset = "BTC_2019m_01"

# Target is the 'buy/sell' column; features are everything else except the
# raw 'Date' timestamp.
Y = df_y['buy/sell']
tmp = df_y.drop(['buy/sell', 'Date'], axis=1)
feature_names = tmp.columns
X = tmp.values

# Single 70/30 train/validation split (no separate test split is created
# in this cell).
X_train, X_val, Y_train, Y_val = train_test_split(
    X, Y, test_size=0.3, random_state=0)
num_features = X_train.shape[1]
print(X_train.shape, X_val.shape, num_features, feature_names)

# Scaler statistics come from the training split only.
ss = StandardScaler().fit(X_train)
X_train_s, X_val_s = (ss.transform(split) for split in (X_train, X_val))

# Black-Box Model - Train


In [None]:
# Reproducibility setup for the Keras/TensorFlow black-box model: fix every
# random source with the same seed and install a TensorFlow session limited
# to one intra-op and one inter-op thread.
# (Different seed values could be used at each stage; a single one is used here.)
seed_value= 0

# 1. Set the `PYTHONHASHSEED` environment variable at a fixed value
# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so
# setting it here most likely only affects child processes -- confirm.
import os
os.environ['PYTHONHASHSEED']=str(seed_value)

# 2. Set the `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)

# 3. Set the `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)

# 4. Set the `tensorflow` pseudo-random generator at a fixed value
# (`tf.set_random_seed` is the TensorFlow 1.x spelling; alternatives for
# other versions are kept below as comments.)
import tensorflow as tf
tf.set_random_seed(seed_value)
# tf.random.set_seed(seed_value)
# for later versions: 
# tf.compat.v1.set_random_seed(seed_value)

# 5. Configure a new global `tensorflow` session
# The session is built on the default graph with both thread pools limited
# to a single thread, then registered as the Keras backend session.
from keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
# for later versions:
# session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
# sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
# tf.compat.v1.keras.backend.set_session(sess)

In [None]:
# Black-box classifier: two 64-unit ReLU layers, each followed by 50%
# dropout, and a 2-way softmax output.
bbmodel = Sequential([
    Dense(64, activation='relu', input_shape=(num_features,)),
    Dropout(0.5),
    Dense(64, activation='relu'),
    Dropout(0.5),
    Dense(2),
    Activation('softmax'),
])

# Integer class labels -> sparse categorical cross-entropy.
bbmodel.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train silently, monitoring the validation split.
bbmodel.fit(
    X_train_s, Y_train,
    epochs=50,
    batch_size=32,
    validation_data=(X_val_s, Y_val),
    verbose=0,
)

# Evaluate on the test split.
# NOTE(review): X_test_s / Y_test must come from the dataset cell that was
# run; the KDD and BTC cells only create train/validation splits -- confirm
# what this evaluates on for those datasets.
loss, accuracy = bbmodel.evaluate(X_test_s, Y_test, verbose=0)
print('Accuracy: %.2f' % (accuracy * 100))

# Surrogate - Train

In [None]:
# Fix the Python, PyTorch and NumPy random generators for the surrogate /
# explainer training that follows.
SEED=42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Select device: first visible GPU when CUDA is usable, otherwise the CPU
# (a hard-coded 'cuda:0' fails as soon as a tensor is created on a
# CPU-only host).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def original_model(x):
    """Evaluate the Keras black-box `bbmodel` on a torch batch.

    The batch is moved to the CPU and converted to NumPy for
    `bbmodel.predict`; the prediction is returned as a float32 tensor on the
    same device as `x`.
    """
    keras_output = bbmodel.predict(x.cpu().numpy())
    return torch.tensor(keras_output, dtype=torch.float32, device=x.device)


# Train the surrogate, or reload it when a checkpoint for this dataset tag
# already exists.
surrogate_ckpt = f'checkpoints/{dataset}_surrogate.pt'

if os.path.isfile(surrogate_ckpt): 
    print('Loading saved surrogate model')
    surr = torch.load(surrogate_ckpt).to(device)
    surrogate = Surrogate(surr, num_features)
else:
    if dataset=="Census": # SAME PARAMETERS AS IN FASTSHAP
        print("Census!")
        # The mask layer appends the coalition mask to the masked input,
        # hence the 2 * num_features input width.
        surr = nn.Sequential(
                MaskLayer1d(value=0, append=True),
                nn.Linear(2 * num_features, 128),
                nn.ELU(inplace=True),
                nn.Linear(128, 128),
                nn.ELU(inplace=True),
                nn.Linear(128, 2)).to(device)

        # Set up surrogate object
        surrogate = Surrogate(surr, num_features)

        # Train (validation uses only the first 200 validation rows)
        start=time.time()
        surrogate.train_original_model(
            X_train_s,
            X_val_s[:200],
            original_model,
            batch_size=32,
            max_epochs=200,
            loss_fn=KLDivLoss(),
            lookback=10,
            validation_samples=128,
            validation_batch_size=10000,
            training_seed=SEED,
            verbose=True)
        end=time.time()
        print("Training Time:",(end-start))
    else:
        print(dataset)
        Layer_size=512
        # Same layout as above with wider layers and LeakyReLU activations.
        surr = nn.Sequential(
            MaskLayer1d(value=0, append=True), 
            nn.Linear( 2*num_features, Layer_size), 
            nn.LeakyReLU(inplace=True),
            nn.Linear(Layer_size, Layer_size),
            nn.LeakyReLU(inplace=True),
            nn.Linear(Layer_size, 2),
        ).to(device)

        surrogate = Surrogate(surr, num_features)

        start=time.time()
        surrogate.train_original_model(
            X_train_s,
            X_val_s,
            original_model,
            batch_size=32,
            max_epochs=200,
            loss_fn=KLDivLoss(),
            validation_samples=10,
            validation_batch_size=10000,
            verbose=True,
            lr=1e-4,
            min_lr=1e-8,
            lr_factor=0.5,
            weight_decay=0.01,
            training_seed=SEED,
            lookback=20
        )
        end=time.time()
        print((end-start))
    
    # Save from the CPU, then move the network back to the working device.
    # The checkpoint directory is created first: torch.save does not create
    # missing parent directories, so a fresh checkout would otherwise lose
    # the trained model to a FileNotFoundError.
    os.makedirs(os.path.dirname(surrogate_ckpt), exist_ok=True)
    surr.cpu()
    torch.save(surr, surrogate_ckpt)
    surr.to(device)

# FastSHAP - Train

In [None]:
# Re-seed so explainer training does not depend on how many random numbers
# the previous cells consumed.
SEED=42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
print(device)

# Train the FastSHAP explainer, or reload it when a checkpoint for this
# dataset tag already exists.
explainer_ckpt = f'checkpoints/{dataset}_explainer.pt'

if os.path.isfile(explainer_ckpt): 
    print('Loading saved explainer model')
    explainer = torch.load(explainer_ckpt).to(device)
    fastshap = FastSHAP(explainer, surrogate, normalization='additive', link=nn.Softmax(dim=-1))
else:
    if dataset=="Census": # SAME PARAMETERS AS IN FASTSHAP
        print("Census!")
        # Output width is 2 * num_features: one value per feature for each
        # of the two classes.
        explainer = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2 * num_features)).to(device)

        # Set up FastSHAP object
        fastshap = FastSHAP(explainer, surrogate, normalization='additive', link=nn.Softmax(dim=-1))

        # Train (validation uses only the first 200 validation rows)
        start=time.time()
        fastshap.train(
            X_train_s,
            X_val_s[:200],
            batch_size=32,
            num_samples=32,
            max_epochs=200,
            lookback=10,
            validation_samples=128,
            training_seed=SEED,
            verbose=True)
        end=time.time()
        print("Training Time:",(end-start))
    else:
        print(dataset)
        LAYER_SIZE=256
        explainer = nn.Sequential(
            nn.Linear(num_features, LAYER_SIZE),
            nn.LeakyReLU(inplace=True),
            nn.Linear(LAYER_SIZE, LAYER_SIZE),
            nn.LeakyReLU(inplace=True),
            nn.Linear(LAYER_SIZE, 2 * num_features)).to(device)

        # Set up FastSHAP object
        fastshap = FastSHAP(explainer, surrogate, normalization='additive', link=nn.Softmax(dim=-1))

        # Train (validation uses only the first 200 validation rows)
        start=time.time()
        fastshap.train(
            X_train_s,
            X_val_s[:200],
            batch_size=32,
            num_samples=32,
            max_epochs=400,
            validation_samples=32,
            verbose=True,
            paired_sampling=True,
            lr=2e-4,
            min_lr=1e-8,
            lr_factor=0.5,
            weight_decay=0.01,
            training_seed=SEED,
            lookback=20,
        )
        end=time.time()
        print((end-start))

    # Save from the CPU, then move the network back to the working device.
    # The checkpoint directory is created first: torch.save does not create
    # missing parent directories.
    os.makedirs(os.path.dirname(explainer_ckpt), exist_ok=True)
    explainer.cpu()
    torch.save(explainer, explainer_ckpt)
    explainer.to(device)

# LightningSHAP

## Code

In [None]:
def validate_STFS(model, loss_fn1, loss_fn2, data_loader, batch_size, num_samples, sampler, sampler_surr, paired_sampling, epoch, loss_used):
    """Compute the running-mean validation losses over `data_loader`.

    Parameters
    ----------
    model : callable
        Called as `model(x, S)`. Must expose `model.model` (an nn.Module,
        used only to find the device) and `model.num_players`. Its output is
        reshaped to (rows, num_players, -1) and summed over the player axis.
    loss_fn1 : callable
        Applied to the summed output for coalitions drawn from
        `sampler_surr`, against the targets `y` (loss1).
    loss_fn2 : callable
        Applied to the remaining terms (loss2 ... loss5).
    data_loader : DataLoader
        Yields `(x, y)` batches; `data_loader.dataset[0][0]` is used to
        evaluate the empty coalition.
    batch_size : int
        Kept for interface compatibility. Tensor sizes are taken from the
        actual batch (`len(x)`), so the function also works when the loader
        uses a different batch size than training.
    num_samples : int
        Number of coalitions drawn per input row for loss2.
    sampler, sampler_surr : objects with a `.sample(n, ...)` method
        `sampler.sample` also receives `paired_sampling`.
    paired_sampling : bool
        Forwarded to `sampler.sample`.
    epoch : int
        Kept for interface compatibility. The composite loss is now built
        the same way for every epoch (see below).
    loss_used : container of str
        Which of "L3", "L4", "L5" are added to `loss1 + loss2`.

    Returns
    -------
    tuple
        `(mean_loss, mean_loss1, mean_loss2, mean_loss3, mean_loss4,
        mean_loss5)`, each a size-weighted running mean over the batches.
    """
    with torch.no_grad():
        # Setup.
        device = next(model.model.parameters()).device
        num_players = model.num_players
        link = nn.Softmax(dim=-1)
        mean_loss = 0
        mean_loss1 = 0
        mean_loss2 = 0
        mean_loss3 = 0
        mean_loss4 = 0
        mean_loss5 = 0
        N = 0

        # Output for the empty coalition (all-zero mask), taken from the
        # first sample of the dataset.
        sample = data_loader.dataset[0][0].to(device)
        zeros = torch.zeros(1, num_players, device=device)
        null = model(sample, zeros)
        null = link(null.reshape(1, num_players, -1).sum(dim=1))
        if len(null.shape) == 1:
            null = null.reshape(1, 1)

        for x, y in data_loader:
            x = x.to(device)
            y = y.to(device)
            # Size everything from the real batch, not from the training
            # batch size: the loader may be built with a different
            # validation batch size, which previously caused shape errors.
            n_rows = len(x)

            # Generate subsets.
            S = sampler.sample(n_rows * num_samples, paired_sampling=paired_sampling).to(device=device)
            S_surr = sampler_surr.sample(n_rows).to(device=device)

            # loss1: summed output under the `sampler_surr` coalitions vs. y.
            pred_xs = model(x, S_surr)
            pred_xs_sum = pred_xs.reshape(n_rows, num_players, -1).sum(dim=1)
            loss1 = loss_fn1(pred_xs_sum, y)

            # Grand coalition (all-ones mask).
            ones = torch.ones_like(x).to(device)
            pred_reshape = model(x, ones).reshape(n_rows, num_players, -1)
            grand = link(pred_reshape.sum(dim=1))

            # Per-player values shifted so that they sum to y - null.
            pred_eff = additive_efficient_normalization(pred_reshape, y, null)
            total = pred_eff.sum(dim=1)

            # Each row repeated num_samples times, one copy per coalition.
            x_tiled = x.unsqueeze(1).repeat(
                1, num_samples, *[1 for _ in range(len(x.shape) - 1)]
                ).reshape(n_rows * num_samples, *x.shape[1:])

            val = model(x_tiled, S)
            values = link(val.reshape(len(x_tiled), num_players, -1).sum(dim=1))

            S = S.reshape(n_rows, num_samples, num_players)
            values = values.reshape(n_rows, num_samples, -1)

            # loss2: additive reconstruction from per-player values vs. the
            # coalition outputs.
            approx = null + torch.matmul(S, pred_eff)
            loss2 = loss_fn2(approx, values)
            # loss3: sum of normalised per-player values vs. y - null.
            loss3 = loss_fn2(total, y - null)

            # loss4: one feature at a time. Row r of the extended batch uses
            # the one-hot coalition for player r % num_players.
            eye = torch.eye(num_players).repeat(n_rows, 1).to(device)
            x_extended = x.unsqueeze(1).repeat(
                1, num_players, *[1 for _ in range(len(x.shape) - 1)]
                ).reshape(n_rows * num_players, *x.shape[1:])
            single_pred = model(x_extended, eye).reshape(n_rows * num_players, num_players, -1)
            single_pred_surr = link(single_pred.sum(dim=1))
            single_pred_eff = additive_efficient_normalization(single_pred, single_pred_surr, null)
            delta = single_pred_surr - null
            # For every row pick the value of the player that is switched on
            # (single indexing op instead of a Python loop over rows).
            rows = torch.arange(n_rows * num_players, device=device)
            tmp = single_pred_eff[rows, rows % num_players]
            loss4 = loss_fn2(tmp, delta)

            # loss5: targets vs. grand-coalition output.
            loss5 = loss_fn2(y, grand)

            # Composite loss, identical for every epoch. Previously epoch 0
            # reported loss1 alone while the training loop (which always
            # takes its `epoch>=0` branch) used the full sum; the smaller
            # epoch-0 figure was not comparable with later epochs and could
            # make best-model selection / early stopping stick to epoch 0.
            loss = loss1 + loss2
            if "L3" in loss_used:
                loss += loss3
            if "L4" in loss_used:
                loss += loss4
            if "L5" in loss_used:
                loss += loss5

            # Size-weighted running means.
            N += n_rows
            mean_loss += n_rows * (loss - mean_loss) / N
            mean_loss1 += n_rows * (loss1 - mean_loss1) / N
            mean_loss2 += n_rows * (loss2 - mean_loss2) / N
            mean_loss3 += n_rows * (loss3 - mean_loss3) / N
            mean_loss4 += n_rows * (loss4 - mean_loss4) / N
            mean_loss5 += n_rows * (loss5 - mean_loss5) / N

    return mean_loss, mean_loss1, mean_loss2, mean_loss3, mean_loss4, mean_loss5


def generate_labels_STFS(dataset, model, batch_size):
    """Run `model` over `dataset` in batches and return the stacked outputs.

    `dataset` must yield one-element tuples `(x,)`. Each batch is moved to
    the model's device (the device of its first parameter when `model` is an
    nn.Module, the CPU otherwise), evaluated without gradient tracking, and
    the result is brought back to the CPU. The per-batch outputs are
    concatenated along the first axis.
    """
    if isinstance(model, torch.nn.Module):
        target_device = next(model.parameters()).device
    else:
        target_device = torch.device('cpu')

    with torch.no_grad():
        outputs = [
            model(batch.to(target_device)).cpu()
            for (batch,) in DataLoader(dataset, batch_size=batch_size)
        ]

    return torch.cat(outputs)

def additive_efficient_normalization(pred, grand, null):
    """Shift `pred` so that its sum over axis 1 equals `grand - null`.

    The difference between `grand - null` and the current sum over axis 1 is
    split equally across the `pred.shape[1]` entries of that axis.
    """
    num_entries = pred.shape[1]
    residual = grand - null - pred.sum(dim=1)
    return pred + residual.unsqueeze(1) / num_entries


def multiplicative_efficient_normalization(pred, grand, null):
    """Rescale `pred` so that its sum over axis 1 equals `grand - null`.

    Every entry along axis 1 is multiplied by the same factor,
    `(grand - null) / sum(pred, axis 1)`.
    """
    scale = (grand - null) / pred.sum(dim=1)
    return pred * scale.unsqueeze(1)


class LightningSHAP:

    def __init__(self, model, om, num_features, groups=None):
        # Store surrogate model.
        self.model = model
        self.batch_size = None
        self.validation_batch_size = None
        self.num_samples = None
        self.link = None
        self.bbm=om

        # Store feature groups.
        if groups is None:
            self.num_players = num_features
            self.groups_matrix = None
        else:
            # Verify groups.
            inds_list = []
            for group in groups:
                inds_list += list(group)
            assert np.all(np.sort(inds_list) == np.arange(num_features))

            # Map groups to features.
            self.num_players = len(groups)
            device = next(surrogate.parameters()).device
            self.groups_matrix = torch.zeros(
                len(groups), num_features, dtype=torch.float32, device=device)
            for i, group in enumerate(groups):
                self.groups_matrix[i, group] = 1

                
    def train_original_model(self,
                             train_data,
                             val_data,
                             original_model,
                             batch_size,
                             max_epochs,
                             loss_fn1,
                             loss_fn2,
                             validation_samples=1,
                             validation_batch_size=None,
                             lr=None,
                             min_lr=None,
                             lr_factor=None,
                             weight_decay=None,
                             lookback=None,
                             num_samples=None,
                             training_seed=None,
                             validation_seed=None,
                             paired_sampling=False,
                             bar=False,
                             verbose=False,
                             loss_used=["L3","L4","L5"]):

        # Set up train dataset.
        if isinstance(train_data, np.ndarray):
            train_data = torch.tensor(train_data, dtype=torch.float32)
        y_tr = generate_labels_STFS(TensorDataset(train_data), original_model, batch_size)

        if isinstance(train_data, torch.Tensor):
            train_set = TensorDataset(train_data, y_tr)
        elif isinstance(train_data, Dataset):
            train_set = train_data
        else:
            raise ValueError('train_data must be either tensor or a PyTorch Dataset')

        # Set up train data loader.
        random_sampler = RandomSampler(train_set, replacement=True, num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
        batch_sampler = BatchSampler(random_sampler, batch_size=batch_size, drop_last=True)
        train_loader = DataLoader(train_set, batch_sampler=batch_sampler, num_workers=4)

        # Set up validation dataset.
        sampler_surr=UniformSampler(self.num_players)
        sampler = ShapleySampler(self.num_players)
        if validation_seed is not None:
            torch.manual_seed(validation_seed)
        # S_val = sampler.sample(len(val_data) * num_samples, paired_sampling=paired_sampling)
        # S_val_surr = sampler_surr.sample(len(val_data))
        if validation_batch_size is None:
            validation_batch_size = batch_size

        if isinstance(val_data, np.ndarray):
            val_data = torch.tensor(val_data, dtype=torch.float32)

        if isinstance(val_data, torch.Tensor):
            # Generate validation labels.
            y_val = generate_labels_STFS(TensorDataset(val_data), original_model, validation_batch_size)
            y_val_repeat = y_val.repeat(validation_samples, *[1 for _ in y_val.shape[1:]])

            # Create dataset.
            val_data_repeat = val_data.repeat(validation_samples, 1)
            val_set = TensorDataset(val_data_repeat, y_val_repeat)
        else:
            raise ValueError('val_data must be either tuple of tensors or a PyTorch Dataset')

        val_loader = DataLoader(val_set, batch_size=validation_batch_size, drop_last=True, num_workers=4)

        self.batch_size = batch_size
        self.validation_batch_size = validation_batch_size
        self.num_samples = num_samples 
        # self.bbm = original_model

        # Setup for training.
        link=nn.Softmax(dim=-1)
        model = self.model
        device = next(model.parameters()).device
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        #optimizer = MTAdam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=lr_factor, patience=int(lookback // 2), min_lr=min_lr,verbose=verbose)
        best_loss = 100000
        best_epoch = 0
        best_model = deepcopy(model)
        val_loss_list = []
        val_loss1_list = []
        val_loss2_list = []
        val_loss3_list = []
        val_loss4_list = []
        val_loss5_list = []
        train_loss_list = []
        train_loss1_list = []
        train_loss2_list = []
        train_loss3_list = []
        train_loss4_list = []
        train_loss5_list = []
        if training_seed is not None:
            torch.manual_seed(training_seed)

        #print('STFS_training')
        print("#"*50)
        print('Training surrogate model with LOSSES:',loss_used)
        print("#"*50)
        for epoch in range(max_epochs):
            # Batch iterable.
            if bar:
                batch_iter = tqdm(train_loader, desc='Training epoch')
            else:
                batch_iter = train_loader

            mean_loss = 0
            mean_loss1 = 0
            mean_loss2 = 0
            mean_loss3 = 0
            mean_loss4 = 0
            mean_loss5 = 0
            N = 0

            iter=0
            for (x,y) in batch_iter:
                iter+=1
                # Prepare data.
                x = x.to(device)
                y = y.to(device)

                # Generate subsets.
                S = sampler.sample(batch_size*num_samples, paired_sampling=paired_sampling).to(device=device)
                S_surr = sampler_surr.sample(batch_size).to(device=device)

                pred_xs = self.__call__(x, S_surr)
                pred_xs_reshape = pred_xs.reshape(len(x), self.num_players, -1)
                pred_xs_sum = pred_xs_reshape.sum(dim=1)

                loss1 = loss_fn1(pred_xs_sum, y)

                # COMPUTE NULL COALITION
                self.model.eval()
                with torch.no_grad():
                    zeros=torch.zeros(1, self.num_players, device=device)
                    null=self.__call__(x[:1], zeros)
                    null_reshape = null.reshape(1, self.num_players, -1)
                    null_sum = null_reshape.sum(dim=1)
                    null=link(null_sum)
                    if len(null.shape) == 1:
                        null = null.reshape(1, 1)
                self.model.train()

                ones=torch.ones_like(x).to(device)
                pred=self.__call__(x, ones)
                pred_reshape = pred.reshape(len(x), self.num_players, -1)
                grand_sum = pred_reshape.sum(dim=1)
                grand=link(grand_sum)
                
                pred_eff = additive_efficient_normalization(pred_reshape, y, null)
                total=pred_eff.sum(dim=1)


                x_tiled = x.unsqueeze(1).repeat(
                    1, num_samples, *[1 for _ in range(len(x.shape) - 1)]
                    ).reshape(batch_size * num_samples, *x.shape[1:])
                
                val = self.__call__(x_tiled, S)
                val_reshape = val.reshape(len(x_tiled), self.num_players, -1)
                val_sum = val_reshape.sum(dim=1)

                values = link(val_sum)
                
                S=S.reshape(batch_size, num_samples, self.num_players)
                values=values.reshape(batch_size, num_samples, -1)
                approx = null + torch.matmul(S, pred_eff)

                loss2 = loss_fn2(approx, values)
                loss2 = loss2 * 1
                loss3 = loss_fn2(total, y-null)


                ####################################################################################
                ################### FORCE LOSS ON SINGLE FEATURE AT TIME ###########################
                eye=torch.eye(self.num_players)
                eye=eye.repeat(batch_size, 1).to(device)
                x_extended=torch.cat([el.unsqueeze(0).repeat(self.num_players,1) for el in x])
                single_pred=self.__call__(x_extended, eye)
                single_pred=single_pred.reshape(batch_size*self.num_players, self.num_players, -1)
                single_pred_sum = single_pred.sum(dim=1)
                single_pred_surr=link(single_pred_sum)
                single_pred_eff=additive_efficient_normalization(single_pred, single_pred_surr, null)
                delta=single_pred_surr-null
                tmp=[]
                i=0
                for el in single_pred_eff:
                    tmp.append(el[i].unsqueeze(0))
                    i+=1
                    if i==self.num_players:
                        i=0
                tmp=torch.cat(tmp)
                loss4=loss_fn2(tmp, delta)
                ####################################################################################

                loss5=loss_fn2(y,grand)

                if iter<3:
                    print("-"*50)
                    print("Y",y)
                    print("GRAND",grand)
                    print("TOAL",total)

                if epoch>=0:
                    loss = loss1 + loss2*self.num_players #+ loss4*self.num_players + loss5*self.num_players
                    if "L3" in loss_used:
                        loss = loss + loss3*self.num_players
                    if "L4" in loss_used:
                        loss = loss + loss4*self.num_players
                    if "L5" in loss_used:
                        loss = loss + loss5*self.num_players
                else:
                    loss = loss1 
                
                lossprint = loss1.item() + loss2.item()  #+ loss4
                if "L3" in loss_used:
                    lossprint = lossprint + loss3.item() 
                if "L4" in loss_used:
                    lossprint = lossprint + loss4.item() 
                if "L5" in loss_used:
                    lossprint = lossprint + loss5.item() 

                N += len(x)
                mean_loss += len(x) * (lossprint - mean_loss) / N
                mean_loss1 += len(x) * (loss1.item()  - mean_loss1) / N
                mean_loss2 += len(x) * (loss2.item()  - mean_loss2) / N
                mean_loss3 += len(x) * (loss3.item()  - mean_loss3) / N
                mean_loss4 += len(x) * (loss4.item()  - mean_loss4) / N
                mean_loss5 += len(x) * (loss5.item()  - mean_loss5) / N

                # Optimizer step.
                loss.backward()
                optimizer.step()
                model.zero_grad()

                del loss, loss1, loss2, loss3, loss4, loss5
                del values, pred
            gc.collect()
            torch.cuda.empty_cache()

            if verbose:
                print('----- Epoch = {} -----'.format(epoch + 1))

                if epoch>=0:
                    print('Train loss = {:.6f}'.format(mean_loss))
                    print('Train loss1 = {:.6f}'.format(mean_loss1))
                    print('Train loss2 = {:.6f}'.format(mean_loss2))
                    print('Train loss3 = {:.10f}'.format(mean_loss3))
                    print('Train loss4 = {:.6f}'.format(mean_loss4))
                    print('Train loss5 = {:.6f}'.format(mean_loss5))
                else:
                    print('Train loss = {:.6f}'.format(mean_loss1))
                    print('Train loss1 = {:.6f}'.format(mean_loss1))
                    print('Train loss2 = -')
                    print('Train loss3 = -')
                    print('Train loss4 = -')
                    print('Train loss5 = -')
                    
                print('')

            # Evaluate validation loss.
            self.model.eval()
            val_loss, val_loss1, val_loss2, val_loss3, val_loss4, val_loss5 = validate_STFS(self, loss_fn1, loss_fn2, val_loader,  batch_size, num_samples, sampler, sampler_surr, paired_sampling, epoch, loss_used)#.item()
            self.model.train()

            # Print progress.
            if verbose:
                #print('----- Epoch = {} -----'.format(epoch + 1))
                if epoch>=0:
                    print('Val loss = {:.6f}'.format(val_loss))
                    print('Val loss1 = {:.6f}'.format(val_loss1))
                    print('Val loss2 = {:.6f}'.format(val_loss2))
                    print('Val loss3 = {:.10f}'.format(val_loss3))
                    print('Val loss4 = {:.6f}'.format(val_loss4))
                    print('Val loss5 = {:.6f}'.format(val_loss5))
                else:
                    print('Val loss = {:.6f}'.format(val_loss1))
                    print('Val loss1 = {:.6f}'.format(val_loss1))
                    print('Val loss2 = -')
                    print('Val loss3 = -')
                    print('Val loss4 = -')
                    print('Val loss5 = -')
                print('')

            scheduler.step(val_loss)
            val_loss_list.append(val_loss)
            val_loss1_list.append(val_loss1)
            val_loss2_list.append(val_loss2)
            val_loss3_list.append(val_loss3)
            val_loss4_list.append(val_loss4)
            val_loss5_list.append(val_loss5)
            train_loss_list.append(mean_loss)
            train_loss1_list.append(mean_loss1)
            train_loss2_list.append(mean_loss2)
            train_loss3_list.append(mean_loss3)
            train_loss4_list.append(mean_loss4)
            train_loss5_list.append(mean_loss5)

            # Check if best model.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(model)
                best_epoch = epoch
                if verbose:
                    print('\t=> New best epoch, loss = {:.4f}'.format(val_loss))
                    print('')
            elif epoch - best_epoch == lookback:
                if verbose:
                    print('Stopping early')
                break

        # Clean up.
        for param, best_param in zip(model.parameters(), best_model.parameters()):
            param.data = best_param.data
            
        self.val_loss_list = val_loss_list
        self.val_loss1_list = val_loss1_list
        self.val_loss2_list = val_loss2_list
        self.val_loss3_list = val_loss3_list
        self.val_loss4_list = val_loss4_list
        self.val_loss5_list = val_loss5_list
        self.train_loss_list = train_loss_list
        self.train_loss1_list = train_loss1_list
        self.train_loss2_list = train_loss2_list
        self.train_loss3_list = train_loss3_list
        self.train_loss4_list = train_loss4_list
        self.train_loss5_list = train_loss5_list
        self.model.eval()


    def __call__(self, x, S):
        '''
        Evaluate the wrapped network on inputs paired with a coalition mask.

        Args:
          x: model inputs.
          S: coalition mask that accompanies x.

        Returns:
          Whatever self.model returns for the single tuple argument (x, S).
        '''

        return self.model((x,S))
    

    def shap_values(self, x):
        '''
        Compute normalized attributions for one or more inputs.

        Args:
          x: np.ndarray or torch.Tensor of inputs. A 1-D input is treated as
            a single example.

        Returns:
          numpy array holding the result of additive_efficient_normalization
          applied to the network output for the full coalition (reshaped to
          (len(x), num_players, -1)), the black-box output self.bbm(x) and
          the null-coalition value.

        Raises:
          ValueError: if x is neither np.ndarray nor torch.Tensor.
        '''
        # Data conversion.
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        elif not isinstance(x, torch.Tensor):
            raise ValueError('data must be np.ndarray or torch.Tensor')

        # Treat a single unbatched example as a batch of one.
        if x.dim() == 1:
            x = x.unsqueeze(0)

        device = next(self.model.parameters()).device
        link = nn.Softmax(dim=-1)
        x = x.to(device)
        batch = len(x)

        # Generate explanations.
        with torch.no_grad():
            # Null coalition: every player masked out. It is evaluated on a
            # single row so that the one-row mask matches the input batch
            # size (the mask layers in this notebook concatenate x with S).
            # NOTE(review): assumes the fully masked output does not depend
            # on x - confirm for models without a mask layer.
            zeros = torch.zeros(1, self.num_players, device=device)
            null = self.__call__(x[:1], zeros)
            # Sum the per-player outputs, then apply the softmax link. The
            # result is always 2-D, so no extra reshape is required.
            null = link(null.reshape(1, self.num_players, -1).sum(dim=1))

            # Full coalition: one mask row per input so batches of any size
            # can pass through mask layers that append S to x.
            ones = torch.ones(batch, self.num_players, device=device)
            pred = self.__call__(x, ones)
            pred = pred.reshape(batch, self.num_players, -1)

            # Black-box model output used as the normalization target.
            y = self.bbm(x)

            values = additive_efficient_normalization(pred, y, null)

        return values.cpu().data.numpy()

## Train

In [None]:
# Loss terms enabled for explainer training. The list is passed to the
# trainer as `loss_used`, and the concatenated string is used in the
# checkpoint filename. Alternative configurations are kept commented out.
# LOSS_USED=["L1","L2"]
# LOSS_USED=["L1","L2","L4"]
# LOSS_USED=["L1","L2","L5"]
LOSS_USED=["L1","L2","L4", "L5"]
LOSS_STRING="".join(LOSS_USED)
print(LOSS_STRING)

In [None]:
# Seed every RNG used during training for reproducibility.
SEED=42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# The explainer is retrained on every run by default. Set FORCE_RETRAIN to
# False to reuse a saved explainer when the checkpoint file exists.
FORCE_RETRAIN = True

# Single source of truth for the checkpoint location, used for both the
# existence check and the load.
# NOTE(review): confirm this matches the path the explainer is saved under.
CHECKPOINT_PATH = f'checkpoints/{dataset}_lshap_{LOSS_STRING}.pt'


def original_model(x):
    '''Black-box model wrapper: tensor in, class-probability tensor out.'''
    pred = bbm_model.predict_proba(x.cpu().numpy())
    return torch.tensor(pred, dtype=torch.float32, device=x.device)


if not FORCE_RETRAIN and os.path.isfile(CHECKPOINT_PATH):
    print('Loading saved explainer model')
    net = torch.load(CHECKPOINT_PATH).to(device)
    lshap = LightningSHAP(net, original_model, num_features)
else:
    LAYER_SIZE = 512 #1024
    net = nn.Sequential(
            MaskLayer1d(value=0, append=True),
            nn.Linear(2*num_features, LAYER_SIZE),
            nn.LeakyReLU(inplace=True),
            nn.Linear(LAYER_SIZE, LAYER_SIZE),
            nn.LeakyReLU(inplace=True),
            nn.Linear(LAYER_SIZE, LAYER_SIZE),
            nn.LeakyReLU(inplace=True),
            # nn.Linear(LAYER_SIZE, LAYER_SIZE),
            # nn.LeakyReLU(inplace=True),
            # nn.Linear(LAYER_SIZE, LAYER_SIZE),
            # nn.LeakyReLU(inplace=True),
            nn.Linear(LAYER_SIZE, 2 * num_features)
        ).to(device)

    # Set up surrogate object
    lshap = LightningSHAP(net, original_model, num_features)

    # Train
    start=time.time()
    lshap.train_original_model(
        X_train_s,
        X_val_s[:200],
        original_model,
        batch_size=32,
        num_samples=32,
        paired_sampling=True,
        max_epochs=1,
        loss_fn1=KLDivLoss(), #KLDivLoss(),
        loss_fn2=nn.MSELoss(), #KLDivLoss(),
        validation_samples=32, #128
        validation_batch_size=None,
        lookback=10,
        lr=5e-4,#2e-4
        min_lr=1e-8,
        weight_decay=1e-2,
        lr_factor=0.5,
        verbose=True,
        training_seed=SEED,
        loss_used=LOSS_USED
        )
    end=time.time()
    print("Training Time:",(end-start))

# SHAPREG - Code

In [None]:
import warnings
import numpy as np
import matplotlib.pyplot as plt

import numpy as np


class DefaultExtension:
    '''
    Extend a model by replacing removed features with default values.

    Args:
      values: numpy array of default feature values with shape (dim,) or
        (1, dim).
      model: callable applied to the masked inputs.

    Raises:
      ValueError: if values has more than one row.
    '''
    def __init__(self, values, model):
        self.model = model
        if values.ndim == 1:
            values = values[np.newaxis]
        elif values.shape[0] != 1:
            # Check the leading dimension, not the array contents.
            raise ValueError('values shape must be (dim,) or (1, dim)')
        self.values = values
        # Defaults tiled to the most recently seen batch size.
        self.values_repeat = values

    def __call__(self, x, S):
        '''
        Evaluate the model with features outside the coalition replaced.

        Args:
          x: numpy array of inputs, one row per example.
          S: mask of the same shape as x (expected to be boolean, since it
            is inverted with ~ and used for indexing); True keeps a feature.

        Returns:
          Output of the wrapped model on the masked copy of x.
        '''
        # Re-tile the defaults only when the batch size changes.
        if len(x) != len(self.values_repeat):
            self.values_repeat = self.values.repeat(len(x), 0)

        # Replace removed entries on a copy so the caller's x is untouched.
        x_ = x.copy()
        x_[~S] = self.values_repeat[~S]

        # Make predictions.
        return self.model(x_)


class MarginalExtension:
    '''
    Extend a model by averaging its output over background rows that stand
    in for the removed features (marginal distribution).

    Args:
      data: 2-D numpy array of background rows; every row is used.
      model: callable applied to the imputed inputs.
    '''
    def __init__(self, data, model):
        self.model = model
        self.data = data
        self.data_repeat = data
        self.samples = len(data)

    def __call__(self, x, S):
        '''Return the model output averaged over all background rows.'''
        num_inputs = len(x)

        # Each input row (and its mask row) is repeated once per background
        # row, keeping copies of the same input adjacent.
        x_tiled = x.repeat(self.samples, 0)
        mask = S.repeat(self.samples, 0)

        # Tile the background block once per input; reuse it while the
        # batch size stays the same.
        if len(self.data_repeat) != self.samples * num_inputs:
            self.data_repeat = np.tile(self.data, (num_inputs, 1))

        # Fill removed features from the background rows.
        imputed = x_tiled.copy()
        removed = ~mask
        imputed[removed] = self.data_repeat[removed]

        # Average predictions over the background rows of each input.
        outputs = self.model(imputed)
        outputs = outputs.reshape((-1, self.samples) + outputs.shape[1:])
        return outputs.mean(axis=1)


class UniformExtension:
    '''
    Extend a model by averaging its output over replacements for removed
    features drawn from a uniform distribution.

    Args:
      values: per-feature support, indexed by column. For categorical
        columns an array of admissible values; otherwise a (low, high) pair.
      categorical_inds: collection of column indices treated as categorical.
      samples: number of draws per input.
      model: callable applied to the imputed inputs.
    '''
    def __init__(self, values, categorical_inds, samples, model):
        self.model = model
        self.values = values
        self.categorical_inds = categorical_inds
        self.samples = samples

    def __call__(self, x, S):
        '''Return the model output averaged over the uniform draws.'''
        total = len(x) * self.samples

        # Repeat each input and its mask once per draw.
        x_tiled = x.repeat(self.samples, 0)
        mask = S.repeat(self.samples, 0)

        # Draw replacement values column by column.
        num_cols = x_tiled.shape[1]
        draws = np.zeros((total, num_cols))
        for col in range(num_cols):
            support = self.values[col]
            if col in self.categorical_inds:
                # Uniform choice among the listed categorical values.
                picks = np.random.choice(len(support), total)
                draws[:, col] = support[picks]
            else:
                # Uniform draw on the continuous [low, high] range.
                draws[:, col] = np.random.uniform(
                    low=support[0], high=support[1], size=total)

        # Overwrite removed features with the draws.
        imputed = x_tiled.copy()
        removed = ~mask
        imputed[removed] = draws[removed]

        # Average predictions over the draws of each input.
        outputs = self.model(imputed)
        outputs = outputs.reshape((-1, self.samples) + outputs.shape[1:])
        return outputs.mean(axis=1)


class UniformContinuousExtension:
    '''
    Extend a model by averaging its output over replacements for removed
    features drawn uniformly between per-feature bounds. Specific to sets
    of continuous features.

    TODO: should we have caching here for repeating x?

    Args:
      min_vals: lower bounds, broadcastable against the repeated inputs.
      max_vals: upper bounds, broadcastable against the repeated inputs.
      samples: number of draws per input.
      model: callable applied to the imputed inputs.
    '''
    def __init__(self, min_vals, max_vals, samples, model):
        self.model = model
        self.min = min_vals
        self.max = max_vals
        self.samples = samples

    def __call__(self, x, S):
        '''Return the model output averaged over the uniform draws.'''
        # Repeat each input and its mask once per draw.
        x_tiled = x.repeat(self.samples, 0)
        mask = S.repeat(self.samples, 0)

        # Interpolate between the bounds with uniform weights.
        u = np.random.uniform(size=x_tiled.shape)
        draws = u * self.min + (1 - u) * self.max

        # Overwrite removed features with the draws.
        imputed = x_tiled.copy()
        removed = ~mask
        imputed[removed] = draws[removed]

        # Average predictions over the draws of each input.
        outputs = self.model(imputed)
        outputs = outputs.reshape((-1, self.samples) + outputs.shape[1:])
        return outputs.mean(axis=1)


class ProductMarginalExtension:
    '''
    Extend a model by averaging its output over replacements for removed
    features, with every column resampled independently from the
    background data (product of the marginal distributions).

    Args:
      data: 2-D numpy array of background rows.
      samples: number of draws per input.
      model: callable applied to the imputed inputs.
    '''
    def __init__(self, data, samples, model):
        self.model = model
        self.data = data
        self.data_repeat = data
        self.samples = samples

    def __call__(self, x, S):
        '''Return the model output averaged over the resampled rows.'''
        total = len(x) * self.samples

        # Repeat each input and its mask once per draw.
        x_tiled = x.repeat(self.samples, 0)
        mask = S.repeat(self.samples, 0)

        # Resample each column on its own so columns are decoupled.
        num_cols = x_tiled.shape[1]
        draws = np.zeros((total, num_cols))
        for col in range(num_cols):
            rows = np.random.choice(len(self.data), total)
            draws[:, col] = self.data[rows, col]

        # Overwrite removed features with the draws.
        imputed = x_tiled.copy()
        removed = ~mask
        imputed[removed] = draws[removed]

        # Average predictions over the draws of each input.
        outputs = self.model(imputed)
        outputs = outputs.reshape((-1, self.samples) + outputs.shape[1:])
        return outputs.mean(axis=1)


class SeparateModelExtension:
    '''
    Extend a model using a separate model for each subset of features.

    Args:
      model_dict: mapping from str(mask_row) to the callable that handles
        that feature subset.
    '''
    def __init__(self, model_dict):
        self.model_dict = model_dict

    def __call__(self, x, S):
        '''Evaluate each row with the model registered for its mask.'''
        # The mask row both selects the model (through its string form) and
        # picks the columns handed to it; rows are evaluated one at a time.
        per_row = [
            self.model_dict[str(mask_row)](x[idx:idx + 1, mask_row])
            for idx, mask_row in enumerate(S)
        ]
        return np.concatenate(per_row, axis=0)


class ConditionalExtension:
    '''
    Extend a model by marginalizing out removed features using a model of
    their conditional distribution.

    Args:
      conditional_model: callable taking the repeated inputs and repeated
        mask and returning replacement values indexable like the inputs.
      samples: number of draws per input.
      model: callable applied to the imputed inputs.
    '''
    def __init__(self, conditional_model, samples, model):
        self.model = model
        self.conditional_model = conditional_model
        self.samples = samples
        # Retained for backward compatibility only; no longer used. The
        # repeated input used to be cached by id(x), which served stale data
        # after in-place edits or when a new array reused a freed id.
        self.x_addr = None
        self.x_repeat = None

    def __call__(self, x, S):
        '''Return the model output averaged over the conditional draws.'''
        # Repeat each input and its mask once per draw. Recomputed on every
        # call, as the other extensions do.
        x = x.repeat(self.samples, 0)
        S = S.repeat(self.samples, 0)

        # Prepare samples.
        samples = self.conditional_model(x, S)

        # Replace specified indices.
        x_ = x.copy()
        x_[~S] = samples[~S]

        # Make predictions and average over the draws of each input.
        pred = self.model(x_)
        pred = pred.reshape(-1, self.samples, *pred.shape[1:])
        return np.mean(pred, axis=1)


class ConditionalSupervisedExtension:
    '''
    Extend a model using a supervised surrogate model.

    Args:
      surrogate: callable invoked as surrogate(x, S).
    '''
    def __init__(self, surrogate):
        self.surrogate = surrogate

    def __call__(self, x, S):
        # Delegate directly; the surrogate receives the mask unchanged.
        return self.surrogate(x, S)


def plot(shapley_values,
         feature_names=None,
         sort_features=True,
         max_features=np.inf,
         orientation='horizontal',
         error_bars=True,
         color='tab:green',
         title='Feature Importance',
         title_size=20,
         tick_size=16,
         tick_rotation=None,
         axis_label='',
         label_size=16,
         figsize=(10, 7),
         return_fig=False):
    '''
    Plot Shapley values.
    Args:
      shapley_values: ShapleyValues object.
      feature_names: list of feature names.
      sort_features: whether to sort features by their values.
      max_features: number of features to display.
      orientation: horizontal (default) or vertical.
      error_bars: whether to include standard deviation error bars.
      color: bar chart color.
      title: plot title.
      title_size: font size for title.
      tick_size: font size for feature names and numerical values.
      tick_rotation: tick rotation for feature names (vertical plots only).
      axis_label: label for the value axis.
      label_size: font size for label.
      figsize: figure size (if fig is None).
      return_fig: whether to return matplotlib figure object.
    Raises:
      ValueError: if orientation is not 'horizontal' or 'vertical', or if
        tick_rotation is given for a horizontal chart.
    '''
    # Reject unsupported options before any figure is created, so a bad
    # argument does not leave an empty figure open.
    if orientation not in ('horizontal', 'vertical'):
        raise ValueError('orientation must be horizontal or vertical')
    if orientation == 'horizontal' and tick_rotation is not None:
        raise ValueError('rotation not supported for horizontal charts')

    # Default feature names.
    if feature_names is None:
        feature_names = ['Feature {}'.format(i) for i in
                         range(len(shapley_values.values))]

    # Truncation keeps the largest values, so it requires sorting.
    if len(feature_names) > max_features:
        sort_features = True

    # Perform sorting (descending by value).
    values = shapley_values.values
    std = shapley_values.std
    if sort_features:
        argsort = np.argsort(values)[::-1]
        values = values[argsort]
        std = std[argsort]
        feature_names = np.array(feature_names)[argsort]

    # Collapse the tail into a single 'Remaining Features' bar: values are
    # summed and standard deviations combined in quadrature.
    if len(feature_names) > max_features:
        feature_names = (list(feature_names[:max_features])
                         + ['Remaining Features'])
        values = (list(values[:max_features])
                  + [np.sum(values[max_features:])])
        std = (list(std[:max_features])
               + [np.sum(std[max_features:] ** 2) ** 0.5])

    # Warn if too many features.
    if len(feature_names) > 50:
        warnings.warn('Plotting {} features may make figure too crowded, '
                      'consider using max_features'.format(
                        len(feature_names)), Warning)

    # Discard std if necessary.
    if not error_bars:
        std = None

    # Make plot.
    fig = plt.figure(figsize=figsize)
    ax = fig.gca()

    if orientation == 'horizontal':
        # Bar chart; positions are reversed so the first feature is on top.
        ax.barh(np.arange(len(feature_names))[::-1], values,
                color=color, xerr=std)

        # Feature labels.
        ax.set_yticks(np.arange(len(feature_names))[::-1])
        ax.set_yticklabels(feature_names, fontsize=label_size)

        # Axis labels and ticks.
        ax.set_ylabel('')
        ax.set_xlabel(axis_label, fontsize=label_size)
        ax.tick_params(axis='x', labelsize=tick_size)

    else:
        # Vertical bar chart.
        ax.bar(np.arange(len(feature_names)), values, color=color,
               yerr=std)

        # Feature labels.
        if tick_rotation is None:
            tick_rotation = 45
        if tick_rotation < 90:
            ha = 'right'
            rotation_mode = 'anchor'
        else:
            ha = 'center'
            rotation_mode = 'default'
        ax.set_xticks(np.arange(len(feature_names)))
        ax.set_xticklabels(feature_names, rotation=tick_rotation, ha=ha,
                           rotation_mode=rotation_mode,
                           fontsize=label_size)

        # Axis labels and ticks.
        ax.set_ylabel(axis_label, fontsize=label_size)
        ax.set_xlabel('')
        ax.tick_params(axis='y', labelsize=tick_size)

    # Remove spines.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.set_title(title, fontsize=title_size)
    plt.tight_layout()

    if return_fig:
        return fig
    else:
        return


def comparison_plot(comparison_values,
                    comparison_names=None,
                    feature_names=None,
                    sort_features=True,
                    max_features=np.inf,
                    orientation='vertical',
                    error_bars=True,
                    colors=('tab:green', 'tab:blue'),
                    title='Feature Importance Comparison',
                    title_size=20,
                    tick_size=16,
                    tick_rotation=None,
                    axis_label='',
                    label_size=16,
                    legend_loc=None,
                    figsize=(10, 7),
                    return_fig=False):
    '''
    Plot comparison between two to five ShapleyValues objects.
    Args:
      comparison_values: tuple of ShapleyValues objects to be compared.
      comparison_names: tuple of names for each ShapleyValues object.
      feature_names: list of feature names.
      sort_features: whether to sort features by their Shapley values
        (ordering follows the first object).
      max_features: number of features to display.
      orientation: vertical (default) or horizontal.
      error_bars: whether to include standard deviation error bars.
      colors: colors for each set of Shapley values. If fewer colors than
        objects are given, the remainder is filled from the default palette.
      title: plot title.
      title_size: font size for title.
      tick_size: font size for feature names and numerical values.
      tick_rotation: tick rotation for feature names (vertical plots only).
      axis_label: label for the value axis.
      label_size: font size for label.
      legend_loc: legend location.
      figsize: figure size (if fig is None).
      return_fig: whether to return matplotlib figure object.
    Raises:
      ValueError: if orientation is not 'horizontal' or 'vertical', if
        tick_rotation is given for a horizontal chart, or if the number of
        compared objects is outside 2-5.
    '''
    # Reject unsupported options before any figure is created, so a bad
    # argument does not leave an empty figure open.
    if orientation not in ('horizontal', 'vertical'):
        raise ValueError('orientation must be horizontal or vertical')
    if orientation == 'horizontal' and tick_rotation is not None:
        raise ValueError('rotation not supported for horizontal charts')

    # Default feature names.
    if feature_names is None:
        feature_names = ['Feature {}'.format(i) for i in
                         range(len(comparison_values[0].values))]

    # Default comparison names.
    num_comps = len(comparison_values)
    if num_comps not in (2, 3, 4, 5):
        raise ValueError('only support comparisons for 2-5 sets of values')
    if comparison_names is None:
        comparison_names = ['Shapley Values {}'.format(i) for i in
                            range(num_comps)]

    # Default colors. The bars are drawn from a zip over the colors, so a
    # color sequence shorter than the number of objects (such as the
    # two-color default with three objects) would silently drop bars; pad
    # it from the palette instead.
    palette = ['tab:green', 'tab:blue', 'tab:purple',
               'tab:orange', 'tab:pink']
    if colors is None:
        colors = palette[:num_comps]
    elif len(colors) < num_comps:
        unused = [c for c in palette if c not in colors]
        colors = list(colors) + unused[:num_comps - len(colors)]

    # Truncation keeps the largest values, so it requires sorting.
    if len(feature_names) > max_features:
        sort_features = True

    # Extract values.
    values = [shapley_values.values for shapley_values in comparison_values]
    std = [shapley_values.std for shapley_values in comparison_values]

    # Perform sorting (descending by the first object's values).
    if sort_features:
        argsort = np.argsort(values[0])[::-1]
        values = [shapley_values[argsort] for shapley_values in values]
        std = [stddev[argsort] for stddev in std]
        feature_names = np.array(feature_names)[argsort]

    # Collapse the tail into a single 'Remaining Features' bar: values are
    # summed and standard deviations combined in quadrature.
    if len(feature_names) > max_features:
        feature_names = (list(feature_names[:max_features])
                         + ['Remaining Features'])
        values = [
            list(shapley_values[:max_features])
            + [np.sum(shapley_values[max_features:])]
            for shapley_values in values]
        std = [list(stddev[:max_features])
               + [np.sum(stddev[max_features:] ** 2) ** 0.5]
               for stddev in std]

    # Warn if too many features.
    if len(feature_names) > 50:
        warnings.warn('Plotting {} features may make figure too crowded, '
                      'consider using max_features'.format(
                        len(feature_names)), Warning)

    # Discard std if necessary.
    if not error_bars:
        std = [None for _ in std]

    # Make plot. Each group of bars shares a total width of 0.8.
    width = 0.8 / num_comps
    fig = plt.figure(figsize=figsize)
    ax = fig.gca()

    if orientation == 'horizontal':
        # Bar chart.
        enumeration = enumerate(zip(values, std, comparison_names, colors))
        for i, (shapley_values, stddev, name, color) in enumeration:
            pos = - 0.4 + width / 2 + width * i
            ax.barh(np.arange(len(feature_names))[::-1] - pos,
                    shapley_values, height=width, color=color, xerr=stddev,
                    label=name)

        # Feature labels.
        ax.set_yticks(np.arange(len(feature_names))[::-1])
        ax.set_yticklabels(feature_names, fontsize=label_size)

        # Axis labels and ticks.
        ax.set_ylabel('')
        ax.set_xlabel(axis_label, fontsize=label_size)
        ax.tick_params(axis='x', labelsize=tick_size)

    else:
        # Vertical bar chart.
        enumeration = enumerate(zip(values, std, comparison_names, colors))
        for i, (shapley_values, stddev, name, color) in enumeration:
            pos = - 0.4 + width / 2 + width * i
            ax.bar(np.arange(len(feature_names)) + pos,
                   shapley_values, width=width, color=color, yerr=stddev,
                   label=name)

        # Feature labels.
        if tick_rotation is None:
            tick_rotation = 45
        if tick_rotation < 90:
            ha = 'right'
            rotation_mode = 'anchor'
        else:
            ha = 'center'
            rotation_mode = 'default'
        ax.set_xticks(np.arange(len(feature_names)))
        ax.set_xticklabels(feature_names, rotation=tick_rotation, ha=ha,
                           rotation_mode=rotation_mode,
                           fontsize=label_size)

        # Axis labels and ticks.
        ax.set_ylabel(axis_label, fontsize=label_size)
        ax.set_xlabel('')
        ax.tick_params(axis='y', labelsize=tick_size)

    # Remove spines.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    plt.legend(loc=legend_loc, fontsize=label_size)
    ax.set_title(title, fontsize=title_size)
    plt.tight_layout()

    if return_fig:
        return fig
    else:
        return

import pickle
import numpy as np


def crossentropyloss(pred, target):
    '''
    Cross entropy loss that does not average across samples.

    Args:
      pred: numpy array of predicted probabilities. A 1-D array is read as
        the probability of class 1 and expanded to two columns.
      target: numpy array of class indices, or an array with the same shape
        as pred holding soft labels.

    Returns:
      1-D numpy array with one loss value per sample.
    '''
    # Expand binary probabilities p into the two columns (1 - p, p).
    if pred.ndim == 1:
        pred = np.stack((1 - pred, pred), axis=1)

    if pred.shape != target.shape:
        # Standard cross entropy loss: pick the probability of each label.
        rows = np.arange(len(pred))
        return - np.log(pred[rows, target])

    # Soft cross entropy loss; clipping keeps the logarithm finite.
    clipped = np.clip(pred, a_min=1e-12, a_max=1-1e-12)
    return - np.sum(np.log(clipped) * target, axis=1)


def mseloss(pred, target):
    '''
    MSE loss that does not average across samples.

    Args:
      pred: numpy array of predictions, 1-D or with one row per sample.
      target: numpy array of targets, 1-D or with one row per sample.

    Returns:
      1-D numpy array with the summed squared error of each sample.
    '''
    # Promote flat inputs to a single column so axis 1 always exists.
    if pred.ndim == 1:
        pred = pred.reshape(-1, 1)
    if target.ndim == 1:
        target = target.reshape(-1, 1)

    residual = pred - target
    return np.sum(residual ** 2, axis=1)


class ShapleyValues:
    '''
    For storing and plotting Shapley values.

    Args:
      values: array of Shapley value estimates.
      std: array of standard deviations for the estimates.
    '''
    def __init__(self, values, std):
        self.values = values
        self.std = std

    def plot(self,
             feature_names=None,
             sort_features=True,
             max_features=np.inf,
             orientation='horizontal',
             error_bars=True,
             color='C0',
             title='Feature Importance',
             title_size=20,
             tick_size=16,
             tick_rotation=None,
             axis_label='',
             label_size=16,
             figsize=(10, 7),
             return_fig=False):
        '''
        Plot Shapley values.
        Args:
          feature_names: list of feature names.
          sort_features: whether to sort features by their Shapley values.
          max_features: number of features to display.
          orientation: horizontal (default) or vertical.
          error_bars: whether to include standard deviation error bars.
          color: bar chart color.
          title: plot title.
          title_size: font size for title.
          tick_size: font size for feature names and numerical values.
          tick_rotation: tick rotation for feature names (vertical plots only).
          axis_label: label for the value axis.
          label_size: font size for label.
          figsize: figure size (if fig is None).
          return_fig: whether to return matplotlib figure object.
        '''
        # The plotting helpers are defined at module level in this notebook
        # rather than in a separate `plotting` module. Inside a method the
        # bare name `plot` resolves to that module-level function, not to
        # this method.
        return plot(
            self, feature_names, sort_features, max_features, orientation,
            error_bars, color, title, title_size, tick_size, tick_rotation,
            axis_label, label_size, figsize, return_fig)

    def comparison(self,
                   other_values,
                   comparison_names=None,
                   feature_names=None,
                   sort_features=True,
                   max_features=np.inf,
                   orientation='vertical',
                   error_bars=True,
                   colors=None,
                   title='Shapley Value Comparison',
                   title_size=20,
                   tick_size=16,
                   tick_rotation=None,
                   axis_label='',
                   label_size=16,
                   legend_loc=None,
                   figsize=(10, 7),
                   return_fig=False):
        '''
        Plot comparison with another set of Shapley values.
        Args:
          other_values: another Shapley values object.
          comparison_names: tuple of names for each Shapley value object.
          feature_names: list of feature names.
          sort_features: whether to sort features by their Shapley values.
          max_features: number of features to display.
          orientation: vertical (default) or horizontal.
          error_bars: whether to include standard deviation error bars.
          colors: colors for each set of Shapley values.
          title: plot title.
          title_size: font size for title.
          tick_size: font size for feature names and numerical values.
          tick_rotation: tick rotation for feature names (vertical plots only).
          axis_label: label for the value axis.
          label_size: font size for label.
          legend_loc: legend location.
          figsize: figure size (if fig is None).
          return_fig: whether to return matplotlib figure object.
        '''
        # Module-level helper defined in this notebook.
        return comparison_plot(
            (self, other_values), comparison_names, feature_names,
            sort_features, max_features, orientation, error_bars, colors, title,
            title_size, tick_size, tick_rotation, axis_label, label_size,
            legend_loc, figsize, return_fig)

    def save(self, filename):
        '''
        Save Shapley values object with pickle.

        Raises:
          TypeError: if filename is not a str.
        '''
        if isinstance(filename, str):
            with open(filename, 'wb') as f:
                pickle.dump(self, f)
        else:
            raise TypeError('filename must be str')

    def __repr__(self):
        # Compact fixed-precision rendering; long arrays are summarized.
        with np.printoptions(precision=2, threshold=12, floatmode='fixed'):
            return 'Shapley Values(\n  (Mean): {}\n  (Std):  {}\n)'.format(
                self.values, self.std)


def load(filename):
    '''
    Load a pickled Shapley values object.

    Args:
      filename: path of the file to read.

    Returns:
      The unpickled ShapleyValues instance.

    Raises:
      ValueError: if the file holds anything other than a ShapleyValues.
    '''
    # NOTE: unpickling can execute arbitrary code; only open trusted files.
    with open(filename, 'rb') as handle:
        loaded = pickle.load(handle)

    if not isinstance(loaded, ShapleyValues):
        raise ValueError('object is not instance of ShapleyValues class')
    return loaded
        
import numpy as np


class CooperativeGame:
    '''
    Base class for cooperative games.

    Subclasses provide __init__ (setting self.players) and __call__.
    '''

    def __init__(self):
        raise NotImplementedError

    def __call__(self, S):
        '''Evaluate cooperative game.'''
        raise NotImplementedError

    def grand(self):
        '''Get grand coalition value (every player included).'''
        everyone = np.ones((1, self.players), dtype=bool)
        return self.__call__(everyone)[0]

    def null(self):
        '''Get null coalition value (no player included).'''
        nobody = np.zeros((1, self.players), dtype=bool)
        return self.__call__(nobody)[0]


class PredictionGame(CooperativeGame):
    '''
    Cooperative game for an individual example's prediction.
    Args:
      extension: model extension (see removal.py).
      sample: numpy array representing a single model input.
      groups: optional list of feature-index groups; each group acts as one
        player and the groups must jointly cover every feature once.
    '''

    def __init__(self, extension, sample, groups=None):
        # Work with a batch of exactly one example.
        if sample.ndim != 1:
            if sample.shape[0] != 1:
                raise ValueError('sample must have shape (ndim,) or (1,ndim)')
        else:
            sample = sample[np.newaxis]

        self.extension = extension
        self.sample = sample

        # Store feature groups.
        num_features = sample.shape[1]
        if groups is not None:
            # The groups must form a partition of the feature indices.
            flat_inds = [ind for group in groups for ind in group]
            assert np.all(np.sort(flat_inds) == np.arange(num_features))

            # Row i of the matrix marks the features owned by player i.
            self.players = len(groups)
            membership = np.zeros((len(groups), num_features), dtype=bool)
            for player, group in enumerate(groups):
                membership[player, group] = True
            self.groups_matrix = membership
        else:
            # Without groups every feature is its own player.
            self.players = num_features
            self.groups_matrix = None

        # Sample tiled to the most recently seen batch size.
        self.sample_repeat = sample

    def __call__(self, S):
        '''
        Evaluate cooperative game.
        Args:
          S: array of player coalitions with size (batch, players).
        '''
        # Re-tile the sample only when the batch size changes.
        batch = len(S)
        if batch != len(self.sample_repeat):
            self.sample_repeat = self.sample.repeat(batch, 0)

        # Translate player coalitions into feature masks.
        if self.groups_matrix is not None:
            S = np.matmul(S, self.groups_matrix)

        # Evaluate.
        return self.extension(self.sample_repeat, S)


class PredictionLossGame(CooperativeGame):
    '''
    Cooperative game for an individual example's loss value.
    Args:
      extension: model extension (see removal.py).
      sample: numpy array representing a single model input.
      label: the input's true label.
      loss: loss function (see utils.py).
      groups: optional list of feature-index groups; each group acts as one
        player and the groups must jointly cover every feature once.
    '''

    def __init__(self, extension, sample, label, loss, groups=None):
        # A flat example becomes a batch of one.
        if sample.ndim == 1:
            sample = sample[np.newaxis]

        # A scalar label becomes a length-one array.
        if np.isscalar(label):
            label = np.array([label])

        # With cross entropy and hard (non-soft) labels, float labels are
        # cast to int. The label is only inspected when the loss is
        # crossentropyloss.
        hard_labels = loss is crossentropyloss and (
            (label.ndim <= 1) or (label.shape[1] == 1))
        if hard_labels and np.issubdtype(label.dtype, np.floating):
            label = label.astype(int)

        self.extension = extension
        self.sample = sample
        self.label = label
        self.loss = loss

        # Store feature groups.
        num_features = sample.shape[1]
        if groups is not None:
            # The groups must form a partition of the feature indices.
            flat_inds = [ind for group in groups for ind in group]
            assert np.all(np.sort(flat_inds) == np.arange(num_features))

            # Row i of the matrix marks the features owned by player i.
            self.players = len(groups)
            membership = np.zeros((len(groups), num_features), dtype=bool)
            for player, group in enumerate(groups):
                membership[player, group] = True
            self.groups_matrix = membership
        else:
            # Without groups every feature is its own player.
            self.players = num_features
            self.groups_matrix = None

        # Sample and label tiled to the most recently seen batch size.
        self.sample_repeat = sample
        self.label_repeat = label

    def __call__(self, S):
        '''
        Evaluate cooperative game.
        Args:
          S: array of player coalitions with size (batch, players).
        '''
        # Re-tile sample and label only when the batch size changes.
        batch = len(S)
        if batch != len(self.sample_repeat):
            self.sample_repeat = self.sample.repeat(batch, 0)
            self.label_repeat = self.label.repeat(batch, 0)

        # Translate player coalitions into feature masks.
        if self.groups_matrix is not None:
            S = np.matmul(S, self.groups_matrix)

        # The game value is the negated loss, so a lower loss scores higher.
        prediction = self.extension(self.sample_repeat, S)
        return - self.loss(prediction, self.label_repeat)

import numpy as np


class CooperativeGame:
    '''Abstract interface shared by all cooperative games.'''

    def __init__(self):
        # Subclasses must provide their own constructor (and set `players`).
        raise NotImplementedError

    def __call__(self, S):
        '''Return the game's value for each coalition in S.'''
        raise NotImplementedError

    def grand(self):
        '''Value of the coalition that contains every player.'''
        everyone = np.ones((1, self.players), dtype=bool)
        return self.__call__(everyone)[0]

    def null(self):
        '''Value of the coalition that contains no player.'''
        nobody = np.zeros((1, self.players), dtype=bool)
        return self.__call__(nobody)[0]


class PredictionGame(CooperativeGame):
    '''
    Cooperative game for an individual example's prediction.
    Args:
      extension: model extension (see removal.py); called as
        extension(inputs, S) with a feature-level mask S.
      sample: numpy array representing a single model input, with or
        without a leading batch dimension of size one.
      groups: optional list of feature-index collections. When given, each
        group is one player and the groups must cover every feature index
        exactly once.
    Raises:
      ValueError: if sample has a leading dimension other than one.
      AssertionError: if groups do not partition the feature indices.
    '''

    def __init__(self, extension, sample, groups=None):
        # Add batch dimension to sample.
        if sample.ndim == 1:
            sample = sample[np.newaxis]
        elif sample.shape[0] != 1:
            raise ValueError('sample must have shape (ndim,) or (1,ndim)')

        self.extension = extension
        self.sample = sample

        # Store feature groups.
        num_features = sample.shape[1]
        if groups is None:
            self.players = num_features
            self.groups_matrix = None
        else:
            # Verify groups. array_equal (rather than ==) also copes with a
            # wrong number of indices: comparing arrays of different length
            # with == is NumPy-version dependent and can raise a broadcast
            # error instead of failing the assertion.
            inds_list = []
            for group in groups:
                inds_list += list(group)
            assert np.array_equal(
                np.sort(inds_list), np.arange(num_features)), (
                'groups must contain every feature index exactly once')

            # Map groups to features: row i marks the features of group i.
            self.players = len(groups)
            self.groups_matrix = np.zeros(
                (len(groups), num_features), dtype=bool)
            for i, group in enumerate(groups):
                self.groups_matrix[i, group] = True

        # Caching: tiled copy of the sample, rebuilt when batch size changes.
        self.sample_repeat = sample

    def __call__(self, S):
        '''
        Evaluate cooperative game.
        Args:
          S: array of player coalitions with size (batch, players).
        Returns:
          Whatever the extension returns for the tiled sample and mask.
        '''
        # Try to use caching for repeated data.
        if len(S) != len(self.sample_repeat):
            self.sample_repeat = self.sample.repeat(len(S), 0)
        input_data = self.sample_repeat

        # Apply group transformation (players -> features).
        if self.groups_matrix is not None:
            S = np.matmul(S, self.groups_matrix)

        # Evaluate.
        return self.extension(input_data, S)


class PredictionLossGame(CooperativeGame):
    '''
    Cooperative game for an individual example's loss value.
    Args:
      extension: model extension (see removal.py); called as
        extension(inputs, S) with a feature-level mask S.
      sample: numpy array representing a single model input, with or
        without a leading batch dimension of size one.
      label: the input's true label (scalar or numpy array).
      loss: loss function (see utils.py); called as loss(pred, label).
      groups: optional list of feature-index collections. When given, each
        group is one player and the groups must cover every feature index
        exactly once.
    Raises:
      ValueError: if sample has a leading dimension other than one.
      AssertionError: if groups do not partition the feature indices.
    '''

    def __init__(self, extension, sample, label, loss, groups=None):
        # Add batch dimension to sample. The shape check mirrors
        # PredictionGame: __call__ tiles the sample once per coalition, which
        # only lines up with S when the sample holds exactly one row.
        if sample.ndim == 1:
            sample = sample[np.newaxis]
        elif sample.shape[0] != 1:
            raise ValueError('sample must have shape (ndim,) or (1,ndim)')

        # Add batch dimension to label.
        if np.isscalar(label):
            label = np.array([label])

        # Convert label dtype if necessary.
        if loss is crossentropyloss:
            # Make sure not soft cross entropy.
            if (label.ndim <= 1) or (label.shape[1] == 1):
                # Only convert if float.
                if np.issubdtype(label.dtype, np.floating):
                    label = label.astype(int)

        self.extension = extension
        self.sample = sample
        self.label = label
        self.loss = loss

        # Store feature groups.
        num_features = sample.shape[1]
        if groups is None:
            self.players = num_features
            self.groups_matrix = None
        else:
            # Verify groups. array_equal (rather than ==) also copes with a
            # wrong number of indices: comparing arrays of different length
            # with == is NumPy-version dependent and can raise a broadcast
            # error instead of failing the assertion.
            inds_list = []
            for group in groups:
                inds_list += list(group)
            assert np.array_equal(
                np.sort(inds_list), np.arange(num_features)), (
                'groups must contain every feature index exactly once')

            # Map groups to features: row i marks the features of group i.
            self.players = len(groups)
            self.groups_matrix = np.zeros(
                (len(groups), num_features), dtype=bool)
            for i, group in enumerate(groups):
                self.groups_matrix[i, group] = True

        # Caching: tiled copies, rebuilt when the batch size changes.
        self.sample_repeat = sample
        self.label_repeat = label

    def __call__(self, S):
        '''
        Evaluate cooperative game.
        Args:
          S: array of player coalitions with size (batch, players).
        Returns:
          The negated loss, so that a lower loss is a higher game value.
        '''
        # Try to use caching for repeated data.
        if len(S) != len(self.sample_repeat):
            self.sample_repeat = self.sample.repeat(len(S), 0)
            self.label_repeat = self.label.repeat(len(S), 0)
        input_data = self.sample_repeat
        output_label = self.label_repeat

        # Apply group transformation (players -> features).
        if self.groups_matrix is not None:
            S = np.matmul(S, self.groups_matrix)

        # Evaluate.
        return - self.loss(self.extension(input_data, S), output_label)

import numpy as np
from tqdm.auto import tqdm


def default_min_variance_samples(game):
    '''Default for min_variance_samples; a constant that ignores `game`.'''
    default_count = 5
    return default_count


def default_variance_batches(game, batch_size):
    '''
    Default for variance_batches.
    Picks enough batches per intermediate estimate that the approximation of
    A is unlikely to be singular: 10 samples per player for a
    CooperativeGame, 25 per player for anything else (stochastic games).
    '''
    samples_per_player = 10 if isinstance(game, CooperativeGame) else 25
    return int(np.ceil(samples_per_player * game.players / batch_size))


def calculate_result(A, b, total):
    '''
    Solve for the regression coefficients whose sum is constrained to total.
    Raises ValueError when A is singular.
    '''
    num_players = A.shape[1]

    # Give the ones-vector the same rank as b so the results broadcast.
    ones_shape = (num_players, 1) if len(b.shape) == 2 else num_players

    try:
        A_inv_one = np.linalg.solve(A, np.ones(ones_shape))
        A_inv_vec = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        raise ValueError('singular matrix inversion. Consider using larger '
                         'variance_batches')

    # Amount by which the unconstrained solution misses the required total.
    excess = np.sum(A_inv_vec, axis=0, keepdims=True) - total
    return A_inv_vec - A_inv_one * excess / np.sum(A_inv_one)


def ShapleyRegression(game,
                      batch_size=512,
                      detect_convergence=True,
                      thresh=0.01,
                      n_samples=None,
                      paired_sampling=True,
                      return_all=False,
                      min_variance_samples=None,
                      variance_batches=None,
                      bar=True,
                      verbose=False):
    '''
    Estimate Shapley values by sampling coalitions and solving a constrained
    linear regression.

    Each iteration draws batch_size coalitions: the size k is sampled with
    probability proportional to 1 / (k * (num_players - k)) for
    k = 1..num_players-1 and the members are chosen without replacement.
    Running averages of A and b are updated and passed to calculate_result
    together with total = grand - null.

    Args:
      game: CooperativeGame or StochasticCooperativeGame instance.
      batch_size: number of coalitions sampled per iteration.
      detect_convergence: stop as soon as the std/range ratio falls below
        thresh. Forced on when n_samples is None.
      thresh: convergence threshold; must satisfy 0 < thresh < 1 when
        convergence detection is on.
      n_samples: upper bound on the number of sampled coalitions. None means
        effectively unbounded (1e20).
      paired_sampling: also evaluate the complement of every sampled
        coalition and average the two contributions.
      return_all: additionally return a dictionary of per-iteration values.
      min_variance_samples: number of intermediate estimates required before
        a variance is computed (int > 1). Defaults to
        default_min_variance_samples(game).
      variance_batches: number of batches pooled into each intermediate
        estimate (int >= 1). Defaults to
        default_variance_batches(game, batch_size).
      bar: whether to display a tqdm progress bar.
      verbose: whether to print the convergence ratio every iteration.

    Returns:
      ShapleyValues(values, std), or the tuple
      (ShapleyValues(values, std), tracking_dict) when return_all is set.
      tracking_dict has keys 'values', 'std', 'iters' (and 'N_est' when
      convergence detection is on), each with one entry per iteration run.

    Raises:
      ValueError: if game has an unsupported type, or (from
        calculate_result) if A is singular.
      AssertionError: if thresh, min_variance_samples or variance_batches
        fail the checks above.
    '''
    # Verify arguments.
    if isinstance(game, CooperativeGame):
        stochastic = False
    elif isinstance(game, StochasticCooperativeGame):
        stochastic = True
    else:
        raise ValueError('game must be CooperativeGame or '
                         'StochasticCooperativeGame')

    if min_variance_samples is None:
        min_variance_samples = default_min_variance_samples(game)
    else:
        assert isinstance(min_variance_samples, int)
        assert min_variance_samples > 1

    if variance_batches is None:
        variance_batches = default_variance_batches(game, batch_size)
    else:
        assert isinstance(variance_batches, int)
        assert variance_batches >= 1

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Weighting kernel (probability of each subset size).
    num_players = game.players
    weights = np.arange(1, num_players)
    weights = 1 / (weights * (num_players - weights))
    weights = weights / np.sum(weights)

    # Calculate null and grand coalitions for constraints.
    if stochastic:
        null = game.null(batch_size=batch_size)
        grand = game.grand(batch_size=batch_size)
    else:
        null = game.null()
        grand = game.grand()

    # Calculate difference between grand and null coalitions.
    total = grand - null

    # Set up bar. With convergence detection the bar tracks the estimated
    # fraction of work done (0..1) rather than a sample count.
    n_loops = int(np.ceil(n_samples / batch_size))
    if bar:
        if detect_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size)

    # Setup.
    n = 0
    b = 0
    A = 0
    estimate_list = []

    # For variance estimation.
    A_sample_list = []
    b_sample_list = []

    # For tracking progress.
    var = np.nan * np.ones(num_players)
    if return_all:
        N_list = []
        std_list = []
        val_list = []

    # Begin sampling.
    for it in range(n_loops):
        # Sample subsets.
        S = np.zeros((batch_size, num_players), dtype=bool)
        num_included = np.random.choice(num_players - 1, size=batch_size,
                                        p=weights) + 1
        for row, num in zip(S, num_included):
            inds = np.random.choice(num_players, size=num, replace=False)
            row[inds] = 1

        # Sample exogenous (if applicable).
        if stochastic:
            U = game.sample(batch_size)

        # Update estimators.
        if paired_sampling:
            # Paired samples.
            A_sample = 0.5 * (
                np.matmul(S[:, :, np.newaxis].astype(float),
                          S[:, np.newaxis, :].astype(float))
                + np.matmul(np.logical_not(S)[:, :, np.newaxis].astype(float),
                            np.logical_not(S)[:, np.newaxis, :].astype(float)))
            if stochastic:
                game_eval = game(S, U) - null
                S_comp = np.logical_not(S)
                comp_eval = game(S_comp, U) - null
                b_sample = 0.5 * (
                    S.astype(float).T * game_eval[:, np.newaxis].T
                    + S_comp.astype(float).T * comp_eval[:, np.newaxis].T).T
            else:
                game_eval = game(S) - null
                S_comp = np.logical_not(S)
                comp_eval = game(S_comp) - null
                b_sample = 0.5 * (
                    S.astype(float).T * game_eval[:, np.newaxis].T +
                    S_comp.astype(float).T * comp_eval[:, np.newaxis].T).T
        else:
            # Single sample.
            A_sample = np.matmul(S[:, :, np.newaxis].astype(float),
                                 S[:, np.newaxis, :].astype(float))
            if stochastic:
                b_sample = (S.astype(float).T
                            * (game(S, U) - null)[:, np.newaxis].T).T
            else:
                b_sample = (S.astype(float).T
                            * (game(S) - null)[:, np.newaxis].T).T

        # Welford's algorithm.
        n += batch_size
        b += np.sum(b_sample - b, axis=0) / n
        A += np.sum(A_sample - A, axis=0) / n

        # Calculate progress.
        values = calculate_result(A, b, total)
        A_sample_list.append(A_sample)
        b_sample_list.append(b_sample)
        if len(A_sample_list) == variance_batches:
            # Aggregate samples for intermediate estimate.
            A_sample = np.concatenate(A_sample_list, axis=0).mean(axis=0)
            b_sample = np.concatenate(b_sample_list, axis=0).mean(axis=0)
            A_sample_list = []
            b_sample_list = []

            # Add new estimate.
            estimate_list.append(calculate_result(A_sample, b_sample, total))

            # Estimate current var.
            if len(estimate_list) >= min_variance_samples:
                var = np.array(estimate_list).var(axis=0)

        # Convergence ratio. Stays NaN until enough intermediate estimates
        # exist, which keeps the convergence test below from firing early.
        std = np.sqrt(var * variance_batches / (it + 1))
        ratio = np.max(
            np.max(std, axis=0) / (values.max(axis=0) - values.min(axis=0)))

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Forecast number of iterations required.
        if detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2

        # Save intermediate quantities. This happens before the convergence
        # check so that the converging iteration is recorded too and every
        # tracked list has one entry per iteration, matching 'iters'.
        if return_all:
            val_list.append(values)
            std_list.append(std)
            if detect_convergence:
                N_list.append(N_est)

        # Check for convergence.
        if detect_convergence:
            if ratio < thresh:
                if verbose:
                    print('Detected convergence')

                # Skip bar ahead.
                if bar:
                    bar.n = bar.total
                    bar.refresh()
                break

        # Update bar.
        if detect_convergence:
            if bar and not np.isnan(N_est):
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()
        elif bar:
            bar.update(batch_size)

    # Release the progress bar so it is not left open after returning.
    if bar:
        bar.close()

    # Return results.
    if return_all:
        # Dictionary for progress tracking. 'iters' counts game evaluations,
        # which double under paired sampling.
        iters = (
            (np.arange(it + 1) + 1) * batch_size *
            (1 + int(paired_sampling)))
        tracking_dict = {
            'values': val_list,
            'std': std_list,
            'iters': iters}
        if detect_convergence:
            tracking_dict['N_est'] = N_list

        return ShapleyValues(values, std), tracking_dict
    else:
        return ShapleyValues(values, std)

# EXACT - Code

In [None]:
import scipy
import numpy as np
import pandas as pd

from itertools import chain, combinations

class Estimator:
    """Common base for attribution estimators; subclasses supply the logic."""

    def __init__(self, model, num_features):
        """Remember the model to explain and how many features it takes."""
        self.num_features = num_features
        self.model = model

    def _explain(self, explicand, num_evals):
        # Subclass hook: compute the attributions for one explicand.
        raise NotImplementedError()

    def __call__(self, explicand, num_evals=100):
        # Subclass hook: public entry point.
        raise NotImplementedError()

class Exact(Estimator):
    """Exact Shapley values by full subset enumeration (exponential in the number of features)."""

    def __init__(self, model, num_features):
        super().__init__(model, num_features)

        # Index set {0, ..., num_features - 1}.
        self.features = set(range(num_features))

    def _powerset(self, iterable):
        """Lazily yield every subset of `iterable` as a tuple, by increasing size."""
        items = list(iterable)
        subsets_by_size = (
            combinations(items, size) for size in range(len(items) + 1)
        )
        return chain.from_iterable(subsets_by_size)

    def _subset_on_off(self, subset, feature):
        """Return boolean masks for `subset` with `feature` added and for `subset` alone."""
        without_feature = np.zeros(self.num_features, dtype="bool")
        if subset:
            without_feature[np.array(subset)] = True

        with_feature = without_feature.copy()
        with_feature[feature] = True

        return (with_feature, without_feature)

    def _single_feature(self, feature, explicand, baselines, y):
        """Attribution of `feature` for output column `y`, averaged over `baselines`.

        For every baseline and every subset of the other features, two masked
        inputs are built from the baseline: one taking the explicand's values
        on the subset plus `feature`, one on the subset only. The weighted
        difference of the model outputs is summed and divided by the number
        of baselines.
        """
        half = 2 ** (self.num_features - 1)
        on_blocks = []
        off_blocks = []
        coalition_sizes = []

        for baseline in baselines:
            if baseline.ndim == 1:
                baseline = baseline[np.newaxis]

            # One baseline copy per masked input: first `half` rows hold the
            # feature-on inputs, the remaining rows the feature-off inputs.
            masked = np.repeat(baseline, 2 ** self.num_features, 0)

            others = self.features - set([feature])
            for row, subset in enumerate(self._powerset(others)):
                with_feature, without_feature = self._subset_on_off(subset, feature)
                coalition_sizes.append(without_feature.sum())

                masked[row, with_feature] = explicand[with_feature]
                masked[half + row, without_feature] = explicand[without_feature]

            on_blocks.append(masked[:half])
            off_blocks.append(masked[half:])

        stacked_on = np.concatenate(on_blocks, axis=0)
        stacked_off = np.concatenate(off_blocks, axis=0)
        coalition_sizes = np.array(coalition_sizes)

        # Weight of a subset of size s: 1 / (num_features * C(num_features - 1, s)).
        weights = 1 / scipy.special.comb(
            self.num_features - 1, coalition_sizes
        )
        weights /= self.num_features

        preds_on = self.model(stacked_on)[:, y]
        preds_off = self.model(stacked_off)[:, y]
        contributions = weights * (preds_on - preds_off)

        return contributions.sum()/len(baselines)

    def _explain(self, explicand, baselines,y):
        """Compute the attribution of every feature, one feature at a time."""
        phi = np.zeros(explicand.shape)

        for index in range(self.num_features):
            phi[index] = self._single_feature(index, explicand, baselines, y)

        return phi

    def __call__(self, explicand, baselines, y):
        """Return the attributions of `explicand` against `baselines` for output `y`."""
        return self._explain(explicand, baselines, y)

# SHAPSAMPLING - Code

In [None]:
import numpy as np
# from shapreg import utils, games, stochastic_games
from tqdm.auto import tqdm


def ShapleySampling(game,
                    batch_size=512,
                    detect_convergence=True,
                    thresh=0.01,
                    n_samples=None,
                    antithetical=False,
                    return_all=False,
                    bar=True,
                    verbose=False):
    '''
    Estimate Shapley values by sampling player permutations.

    Each iteration shuffles batch_size permutations, adds the players to an
    initially empty coalition in permutation order, and records for every
    player the change in game value caused by adding it. The running mean of
    these changes is the estimate.

    Args:
      game: CooperativeGame or StochasticCooperativeGame instance.
      batch_size: number of permutations per iteration.
      detect_convergence: stop as soon as the std/range ratio falls below
        thresh. Forced on when n_samples is None.
      thresh: convergence threshold; must satisfy 0 < thresh < 1 when
        convergence detection is on.
      n_samples: upper bound on the number of sampled permutations. None
        means effectively unbounded (1e20).
      antithetical: when set, every odd-indexed permutation in a batch is the
        reverse of the permutation before it instead of a fresh shuffle.
      return_all: additionally return a dictionary of per-iteration values.
      bar: whether to display a tqdm progress bar.
      verbose: whether to print the convergence ratio every iteration.

    Returns:
      ShapleyValues(values, std), or the tuple
      (ShapleyValues(values, std), tracking_dict) when return_all is set.
      tracking_dict has keys 'values', 'std', 'iters' (and 'N_est' when
      convergence detection is on), each with one entry per iteration run.

    Raises:
      ValueError: if game has an unsupported type.
      AssertionError: if thresh is outside (0, 1) with convergence detection.
    '''
    # Verify arguments.
    if isinstance(game, CooperativeGame):
        stochastic = False
    elif isinstance(game, StochasticCooperativeGame):
        stochastic = True
    else:
        raise ValueError('game must be CooperativeGame or '
                         'StochasticCooperativeGame')

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Calculate null coalition value.
    if stochastic:
        null = game.null(batch_size=batch_size)
    else:
        null = game.null()

    # Set up bar. With convergence detection the bar tracks the estimated
    # fraction of work done (0..1) rather than a sample count.
    n_loops = int(np.ceil(n_samples / batch_size))
    if bar:
        if detect_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size)

    # Setup. An array-valued null coalition means the game has several
    # outputs, so every accumulator gets a trailing output dimension.
    num_players = game.players
    if isinstance(null, np.ndarray):
        values = np.zeros((num_players, len(null)))
        sum_squares = np.zeros((num_players, len(null)))
        deltas = np.zeros((batch_size, num_players, len(null)))
    else:
        values = np.zeros((num_players))
        sum_squares = np.zeros((num_players))
        deltas = np.zeros((batch_size, num_players))
    permutations = np.tile(np.arange(game.players), (batch_size, 1))
    arange = np.arange(batch_size)
    n = 0

    # For tracking progress.
    if return_all:
        N_list = []
        std_list = []
        val_list = []

    # Begin sampling.
    for it in range(n_loops):
        for i in range(batch_size):
            if antithetical and i % 2 == 1:
                permutations[i] = permutations[i - 1][::-1]
            else:
                np.random.shuffle(permutations[i])
        S = np.zeros((batch_size, game.players), dtype=bool)

        # Sample exogenous (if applicable).
        if stochastic:
            U = game.sample(batch_size)

        # Unroll permutations: add the i-th player of every permutation and
        # credit it with the resulting change in value.
        prev_value = null
        for i in range(num_players):
            S[arange, permutations[:, i]] = 1
            if stochastic:
                next_value = game(S, U)
            else:
                next_value = game(S)
            deltas[arange, permutations[:, i]] = next_value - prev_value
            prev_value = next_value

        # Welford's algorithm.
        n += batch_size
        diff = deltas - values
        values += np.sum(diff, axis=0) / n
        diff2 = deltas - values
        sum_squares += np.sum(diff * diff2, axis=0)

        # Calculate progress.
        var = sum_squares / (n ** 2)
        std = np.sqrt(var)
        ratio = np.max(
            np.max(std, axis=0) / (values.max(axis=0) - values.min(axis=0)))

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Forecast number of iterations required.
        if detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2

        # Save intermediate quantities. This happens before the convergence
        # check so that the converging iteration is recorded too and every
        # tracked list has one entry per iteration, matching 'iters'.
        if return_all:
            val_list.append(np.copy(values))
            std_list.append(np.copy(std))
            if detect_convergence:
                N_list.append(N_est)

        # Check for convergence.
        if detect_convergence:
            if ratio < thresh:
                if verbose:
                    print('Detected convergence')

                # Skip bar ahead.
                if bar:
                    bar.n = bar.total
                    bar.refresh()
                break

        # Update bar.
        if detect_convergence:
            if bar and not np.isnan(N_est):
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()
        elif bar:
            bar.update(batch_size)

    # Release the progress bar so it is not left open after returning.
    if bar:
        bar.close()

    # Return results.
    if return_all:
        # Dictionary for progress tracking. 'iters' counts game evaluations:
        # one per player for every sampled permutation.
        iters = (np.arange(it + 1) + 1) * batch_size * num_players
        tracking_dict = {
            'values': val_list,
            'std': std_list,
            'iters': iters}
        if detect_convergence:
            tracking_dict['N_est'] = N_list

        return ShapleyValues(values, std), tracking_dict
    else:
        return ShapleyValues(values, std)

# DeepExplainer - Code (TF support)

In [None]:
class Explainer:
    """ Root class of every explainer; subclasses override `shap_values`.
    """

    def shap_values(self, X):
        # Placeholder until a subclass provides a real implementation.
        raise Exception("SHAP values not implemented for this explainer!")

    def attributions(self, X):
        """ Alias that forwards to `shap_values`.
        """
        values = self.shap_values(X)
        return values
    

import numpy as np
import warnings
# from shap.explainers.explainer import Explainer
from distutils.version import LooseVersion
# Lazy-import placeholders. TFDeepExplainer.__init__ declares these names
# `global` and replaces each None with the real module on first use, so
# TensorFlow is only imported once an explainer is actually constructed.
# NOTE(review): rebinding `keras` to None here shadows the `import keras` from
# the notebook's first cell until a TFDeepExplainer has been created.
keras = None
tf = None
tf_ops = None
tf_gradients_impl = None

class TFDeepExplainer(Explainer):
    """
    Using tf.gradients to implement the backpropagation was
    inspired by the gradient based implementation approach proposed by Ancona et al, ICLR 2018. Note
    that this package does not currently use the reveal-cancel rule for ReLU units proposed in DeepLIFT.
    """

    def __init__(self, model, data, session=None, learning_phase_flags=None):
        """ An explainer object for a deep model using a given background dataset.

        Note that the complexity of the method scales linearly with the number of background data
        samples. Passing the entire training dataset as `data` will give very accurate expected
        values, but be unreasonably expensive. The variance of the expectation estimates scale by
        roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate,
        and 1000 samples a very good estimate of the expected values.

        Parameters
        ----------
        model : keras.Model or (input : [tf.Operation], output : tf.Operation)
            A keras model object or a pair of TensorFlow operations (or a list and an op) that
            specifies the input and output of the model to be explained. Note that SHAP values
            are specific to a single output value, so you get an explanation for each element of
            the output tensor (which must be a flat rank one vector).

        data : [numpy.array] or [pandas.DataFrame] or function
            The background dataset to use for integrating out features. DeepExplainer integrates
            over all these samples for each explanation. The data passed here must match the input
            operations given to the model. If a function is supplied, it must be a function that
            takes a particular input example and generates the background dataset for that example
        session : None or tensorflow.Session
            The TensorFlow session that has the model we are explaining. If None is passed then
            we do our best to find the right session, first looking for a keras session, then
            falling back to the default TensorFlow session.

        learning_phase_flags : None or list of tensors
            If you have your own custom learning phase flags pass them here. When explaining a prediction
            we need to ensure we are not in training mode, since this changes the behavior of ops like
            batch norm or dropout. If None is passed then we look for tensors in the graph that look like
            learning phase flags (this works for Keras models). Note that we assume all the flags should
            have a value of False during predictions (and hence explanations).
            
        """

        # try and import keras and tensorflow
        global tf, tf_ops, tf_gradients_impl
        if tf is None:
            from tensorflow.python.framework import ops as tf_ops # pylint: disable=E0611
            from tensorflow.python.ops import gradients_impl as tf_gradients_impl # pylint: disable=E0611
            if not hasattr(tf_gradients_impl, "_IsBackpropagatable"):
                from tensorflow.python.ops import gradients_util as tf_gradients_impl
            import tensorflow as tf
            if LooseVersion(tf.__version__) < LooseVersion("1.4.0"):
                warnings.warn("Your TensorFlow version is older than 1.4.0 and not supported.")
        global keras
        if keras is None:
            try:
                import keras
                if LooseVersion(keras.__version__) < LooseVersion("2.1.0"):
                    warnings.warn("Your Keras version is older than 2.1.0 and not supported.")
            except Exception:
                # standalone keras is optional; without it the session lookup
                # below falls back to tf.keras
                pass

        # determine the model inputs and outputs
        model_type = str(type(model))
        keras_model_suffixes = (
            "keras.engine.sequential.Sequential'>",
            "keras.models.Sequential'>",
            "keras.engine.training.Model'>",
        )
        if model_type.endswith(keras_model_suffixes):
            self.model_inputs = model.inputs
            self.model_output = model.layers[-1].output
        elif model_type.endswith("tuple'>"):
            self.model_inputs = model[0]
            self.model_output = model[1]
        else:
            assert False, str(type(model)) + " is not currently a supported model type!"
        assert type(self.model_output) != list, "The model output to be explained must be a single tensor!"
        assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"
        self.multi_output = True
        if len(self.model_output.shape) == 1:
            self.multi_output = False

        # check if we have multiple inputs
        self.multi_input = True
        if type(self.model_inputs) != list or len(self.model_inputs) == 1:
            self.multi_input = False
            if type(self.model_inputs) != list:
                self.model_inputs = [self.model_inputs]
        if type(data) != list and (hasattr(data, '__call__')==False):
            data = [data]
        self.data = data
        
        self._vinputs = {} # used to track what op inputs depends on the model inputs
        self.orig_grads = {}
        
        # if we are not given a session find a default session
        if session is None:
            # if keras is installed and already has a session then use it.
            # Every lookup goes through getattr so that a missing keras (the
            # global may still be None) or a keras build without
            # `backend.tensorflow_backend` falls through to the tf.keras
            # session instead of raising AttributeError.
            ksess = None
            keras_tf_backend = getattr(
                getattr(keras, "backend", None), "tensorflow_backend", None)
            if keras_tf_backend is not None:
                if hasattr(keras_tf_backend, "_SESSION"):
                    ksess = keras_tf_backend._SESSION
                else:
                    session_holder = getattr(
                        getattr(keras_tf_backend, "tf_keras_backend", None),
                        "_SESSION", None)
                    if hasattr(session_holder, "session"):
                        ksess = session_holder.session
            if keras is not None and ksess is not None:
                session = keras.backend.get_session()
            else:
                try:
                    session = tf.compat.v1.keras.backend.get_session()
                except Exception:
                    # TF builds without the compat.v1 namespace
                    session = tf.keras.backend.get_session()
        self.session = tf.get_default_session() if session is None else session

        # if no learning phase flags were given we go looking for them
        # ...this will catch the one that keras uses
        # we need to find them since we want to make sure learning phase flags are set to False
        if learning_phase_flags is None:
            self.learning_phase_ops = []
            for op in self.session.graph.get_operations():
                if 'learning_phase' in op.name and op.type == "Const" and len(op.outputs[0].shape) == 0:
                    if op.outputs[0].dtype == tf.bool:
                        self.learning_phase_ops.append(op)
            self.learning_phase_flags = [op.outputs[0] for op in self.learning_phase_ops]
        else:
            self.learning_phase_ops = [t.op for t in learning_phase_flags]
            # keep the tensors as well so the attribute exists on both paths
            # NOTE(review): self.run (defined outside this view) presumably
            # reads learning_phase_flags - confirm
            self.learning_phase_flags = list(learning_phase_flags)

        # save the expected output of the model
        # if self.data is a function, set self.expected_value to None
        if (hasattr(self.data, '__call__')):
            self.expected_value = None
        else:
            if self.data[0].shape[0] > 5000:
                warnings.warn("You have provided over 5k background samples! For better performance consider using smaller random sample.")
            self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0)

        # find all the operations in the graph between our inputs and outputs
        tensor_blacklist = tensors_blocked_by_false(self.learning_phase_ops) # don't follow learning phase branches
        dependence_breakers = [k for k in op_handlers if op_handlers[k] == break_dependence]
        back_ops = backward_walk_ops(
            [self.model_output.op], tensor_blacklist,
            dependence_breakers
        )
        self.between_ops = forward_walk_ops(
            [op for input in self.model_inputs for op in input.consumers()],
            tensor_blacklist, dependence_breakers,
            within_ops=back_ops
        )

        # save what types are being used
        self.used_types = {}
        for op in self.between_ops:
            self.used_types[op.type] = True

        # make a blank array that will get lazily filled in with the SHAP value computation
        # graphs for each output. Lazy is important since if there are 1000 outputs and we
        # only explain the top 5 it would be a waste to build graphs for the other 995
        if not self.multi_output:
            self.phi_symbolics = [None]
        else:
            noutputs = self.model_output.shape.as_list()[1]
            if noutputs is not None:
                self.phi_symbolics = [None for i in range(noutputs)]
            else:
                raise Exception("The model output tensor to be explained cannot have a static shape in dim 1 of None!")

    def _variable_inputs(self, op):
        """ Return which inputs of this operation are variable (i.e. depend on the model inputs).
        """
        if op.name not in self._vinputs:
            self._vinputs[op.name] = np.array([t.op in self.between_ops or t in self.model_inputs for t in op.inputs])
        return self._vinputs[op.name]

    def phi_symbolic(self, i):
        """ Get the SHAP value computation graph for a given model output.

        The graph for output `i` is built on first request and cached in
        `self.phi_symbolics[i]`. While it is being built, the TensorFlow gradient
        registry entries of every op type in `op_handlers` (except passthrough ones)
        are temporarily redirected to `self.custom_grad`; the original entries are
        always restored before returning, also when building the graph fails.

        Raises an Exception when an op type used by the model has a handler but no
        entry in the gradient registry.
        """
        if self.phi_symbolics[i] is None:

            # replace the gradients for all the non-linear activations
            # we do this by hacking our way into the registry (TODO: find a public API for this if it exists)
            reg = tf_ops._gradient_registry._registry

            # validate first, so that a failure can never leave the registry half patched
            for n in op_handlers:
                if n not in reg and n in self.used_types:
                    raise Exception(n + " was used in the model but is not in the gradient registry!")

            # In TensorFlow 1.10 they started pruning out nodes that they think can't be backpropped
            # unfortunately that includes the index of embedding layers so we disable that check here
            has_backprop_check = hasattr(tf_gradients_impl, "_IsBackpropagatable")
            if has_backprop_check:
                orig_IsBackpropagatable = tf_gradients_impl._IsBackpropagatable

            patched = []  # registry keys whose original gradient must be reinstated
            try:
                for n in op_handlers:
                    if n in reg:
                        self.orig_grads[n] = reg[n]["type"]
                        patched.append(n)
                        if op_handlers[n] is not passthrough:
                            reg[n]["type"] = self.custom_grad
                if has_backprop_check:
                    tf_gradients_impl._IsBackpropagatable = lambda tensor: True

                # define the computation graph for the attribution values using a custom gradient-like computation
                out = self.model_output[:,i] if self.multi_output else self.model_output
                self.phi_symbolics[i] = tf.gradients(out, self.model_inputs)

            finally:

                # reinstate the backpropagatable check
                if has_backprop_check:
                    tf_gradients_impl._IsBackpropagatable = orig_IsBackpropagatable

                # restore the original gradient definitions
                for n in patched:
                    reg[n]["type"] = self.orig_grads[n]
        return self.phi_symbolics[i]

    def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True):
        """ Estimate SHAP values for every sample in X.

        Parameters
        ----------
        X : array or list of arrays
            The samples to explain. A single array is accepted for single-input models,
            a list with one array per model input is required for multi-input models.
        ranked_outputs : None or int
            When given for a multi-output model, only the first `ranked_outputs` outputs
            of each sample (in the order chosen by `output_rank_order`) are explained.
        output_rank_order : "max", "min" or "max_abs"
            Rank outputs by largest value, smallest value or largest absolute value.
        check_additivity : bool
            When True, compare the attributions against the difference between the model
            output and the expected value. The check is skipped when the background
            `self.data` is a callable, since no fixed expected value exists then.

        Returns
        -------
        For single-output models the attributions of that output. For multi-output
        models a list with one entry per explained output; when `ranked_outputs` is
        given, the tuple `(attributions, model_output_ranks)` is returned instead.
        Each entry is an array shaped like X (or a list of such arrays for
        multi-input models).
        """

        # check if we have multiple inputs
        if not self.multi_input:
            if type(X) == list and len(X) != 1:
                assert False, "Expected a single tensor as model input!"
            elif type(X) != list:
                X = [X]
        else:
            assert type(X) == list, "Expected a list of model inputs!"
        assert len(self.model_inputs) == len(X), "Number of model inputs (%d) does not match the number given (%d)!" % (len(self.model_inputs), len(X))

        # rank and determine the model outputs that we will explain
        if ranked_outputs is not None and self.multi_output:
            model_output_values = self.run(self.model_output, self.model_inputs, X)
            if output_rank_order == "max":
                model_output_ranks = np.argsort(-model_output_values)
            elif output_rank_order == "min":
                model_output_ranks = np.argsort(model_output_values)
            elif output_rank_order == "max_abs":
                # argsort is ascending, so negate to put the largest magnitudes first
                model_output_ranks = np.argsort(-np.abs(model_output_values))
            else:
                assert False, "output_rank_order must be max, min, or max_abs!"
            model_output_ranks = model_output_ranks[:,:ranked_outputs]
        else:
            model_output_ranks = np.tile(np.arange(len(self.phi_symbolics)), (X[0].shape[0], 1))

        background_is_callable = hasattr(self.data, '__call__')

        # compute the attributions
        output_phis = []
        for i in range(model_output_ranks.shape[1]):
            phis = []
            for k in range(len(X)):
                phis.append(np.zeros(X[k].shape))
            for j in range(X[0].shape[0]):
                if background_is_callable:
                    # a callable background produces the reference samples for this particular input
                    bg_data = self.data([X[l][j] for l in range(len(X))])
                    if type(bg_data) != list:
                        bg_data = [bg_data]
                else:
                    bg_data = self.data
                # tile the inputs to line up with the background data samples
                tiled_X = [np.tile(X[l][j:j+1], (bg_data[l].shape[0],) + tuple([1 for k in range(len(X[l].shape)-1)])) for l in range(len(X))]
                # the first half of the batch is the current sample (tiled), the second half the references
                joint_input = [np.concatenate([tiled_X[l], bg_data[l]], 0) for l in range(len(X))]
                # run attribution computation graph
                feature_ind = model_output_ranks[j,i]
                sample_phis = self.run(self.phi_symbolic(feature_ind), self.model_inputs, joint_input)

                # assign the attributions to the right part of the output arrays
                for l in range(len(X)):
                    phis[l][j] = (sample_phis[l][bg_data[l].shape[0]:] * (X[l][j] - bg_data[l])).mean(0)

            output_phis.append(phis[0] if not self.multi_input else phis)

        # a callable background cannot be fed to the model directly, so there is nothing to check against
        if check_additivity and not background_is_callable:
            self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0)
            model_output = self.run(self.model_output, self.model_inputs, X)
            # NOTE(review): this loop runs over the number of model *inputs*, so for a single-input
            # model only output column 0 is verified, and it does not account for ranked_outputs or
            # multi-input attribution lists - confirm whether checking every explained output is intended
            for l in range(len(X)):
                diffs = model_output[:, l] - self.expected_value[l] - output_phis[l].sum(axis=tuple(range(1, output_phis[l].ndim)))
                assert np.abs(diffs).max() < 1e-4, "Explanations do not sum up to the model's output! Please post as a github issue."
        if not self.multi_output:
            return output_phis[0]
        elif ranked_outputs is not None:
            return output_phis, model_output_ranks
        else:
            return output_phis

    def run(self, out, model_inputs, X):
        """ Evaluate `out` in the session with `X` fed to `model_inputs`.

        Every tensor in `self.learning_phase_flags` is fed the value False.
        """
        feeds = {placeholder: value for placeholder, value in zip(model_inputs, X)}
        feeds.update((flag, False) for flag in self.learning_phase_flags)
        return self.session.run(out, feeds)

    def custom_grad(self, op, *grads):
        """ Dispatch a gradient op creation request to the handler registered for the op's type.

        The handler is looked up in `op_handlers` by `op.type` and called with this
        explainer, the op and the incoming gradients.
        """
        handler = op_handlers[op.type]
        return handler(self, op, *grads)


def tensors_blocked_by_false(ops):
    """ Collect the tensors that become unreachable when the given ops evaluate to False.

    The graph is followed downstream from each op in `ops` (depth first, in consumer
    order). Whenever a "Switch" op is reached its second output (the true branch) is
    recorded and the walk does not continue past it. Ops reachable along several paths
    are visited once per path, so the result may contain duplicates.
    """
    blocked = []
    # explicit stack; pushed in reverse so ops are popped in their original order
    pending = list(ops)[::-1]
    while pending:
        current = pending.pop()
        if current.type == "Switch":
            # the true path is blocked since we assume the ops we trace are False
            blocked.append(current.outputs[1])
            continue
        downstream = [consumer for out in current.outputs for consumer in out.consumers()]
        pending.extend(reversed(downstream))

    return blocked

def backward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist):
    """ Collect the ops reachable by walking upstream (through op inputs) from `start_ops`.

    Parameters
    ----------
    start_ops : iterable of ops the walk starts from.
    tensor_blacklist : tensors that must not be followed.
    op_type_blacklist : op type names; ops of these types are neither collected nor expanded.

    Returns
    -------
    list of the ops found, each once, in the order they were first visited.
    """
    found_ops = []
    seen = set()  # same content as found_ops, kept for constant-time membership tests
    op_stack = list(start_ops)
    while op_stack:
        op = op_stack.pop()
        if op.type in op_type_blacklist or op in seen:
            continue
        seen.add(op)
        found_ops.append(op)
        for input in op.inputs:
            if input not in tensor_blacklist:
                op_stack.append(input.op)
    return found_ops

def forward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist, within_ops):
    """ Collect the ops reachable by walking downstream (through consumers) from `start_ops`.

    Parameters
    ----------
    start_ops : iterable of ops the walk starts from.
    tensor_blacklist : tensors that must not be followed.
    op_type_blacklist : op type names; ops of these types are neither collected nor expanded.
    within_ops : only ops contained in this collection are collected and expanded.

    Returns
    -------
    list of the ops found, each once, in the order they were first visited.
    """
    allowed = set(within_ops)  # constant-time membership instead of scanning a list
    found_ops = []
    seen = set()
    op_stack = list(start_ops)
    while op_stack:
        op = op_stack.pop()
        if op.type in op_type_blacklist or op not in allowed or op in seen:
            continue
        seen.add(op)
        found_ops.append(op)
        for out in op.outputs:
            if out not in tensor_blacklist:
                op_stack.extend(out.consumers())
    return found_ops


def softmax(explainer, op, *grads):
    """ Attribution handler for Softmax ops.

    The softmax is re-expressed with elementary ops (max-shift, exp, sum, divide) on
    `op.inputs[0]` and the attribution is obtained with a nested `tf.gradients` call
    through those ops, seeded with `grads[0]`. The result is rescaled to undo the
    max-shift, except where the sample and the reference input are (almost) equal.
    """

    in0 = op.inputs[0]
    # subtract the per-row maximum before exponentiating (numerical stability, see below)
    in0_max = tf.reduce_max(in0, axis=-1, keepdims=True, name="in0_max")
    in0_centered = in0 - in0_max
    evals = tf.exp(in0_centered, name="custom_exp")
    rsum = tf.reduce_sum(evals, axis=-1, keepdims=True)
    div = evals / rsum
    # the new ops are only registered as "between" for the duration of the nested
    # gradient call and removed again right after it
    explainer.between_ops.extend([evals.op, rsum.op, div.op, in0_centered.op]) # mark these as in-between the inputs and outputs
    out = tf.gradients(div, in0_centered, grad_ys=grads[0])[0]
    del explainer.between_ops[-4:]

    # rescale to account for our shift by in0_max (which we did for numerical stability)
    # each tensor is split into two halves along axis 0: sample half first, reference half second
    xin0,rin0 = tf.split(in0, 2)
    xin0_centered,rin0_centered = tf.split(in0_centered, 2)
    delta_in0 = xin0 - rin0
    # tiling multiples that repeat a half-batch tensor twice along axis 0
    dup0 = [2] + [1 for i in delta_in0.shape[1:]]
    # where sample and reference differ by less than 1e-6 the unscaled value is kept,
    # which avoids dividing by a (near) zero delta
    return tf.where(
        tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
        out,
        out * tf.tile((xin0_centered - rin0_centered) / delta_in0, dup0)
    )

def maxpool(explainer, op, *grads):
    """ Attribution handler for MaxPool ops.

    Input and output are split into two halves along axis 0 (sample half first,
    reference half second). The op's original gradient is applied to the incoming
    gradient weighted by the output differences relative to the element-wise maximum
    of both outputs; the two halves of the result are summed and divided by the input
    delta (zero where that delta is below 1e-7) and tiled back to the full batch.
    """
    sample_in, ref_in = tf.split(op.inputs[0], 2)
    sample_out, ref_out = tf.split(op.outputs[0], 2)
    in_delta = sample_in - ref_in
    tile_twice = [2] + [1 for _ in in_delta.shape[1:]]

    joint_max = tf.maximum(sample_out, ref_out)
    out_diffs = tf.concat([joint_max - ref_out, sample_out - joint_max], 0)
    pooled = explainer.orig_grads[op.type](op, grads[0] * out_diffs)
    sample_part, ref_part = tf.split(pooled, 2)

    multiplier = tf.where(
        tf.abs(in_delta) < 1e-7,
        tf.zeros_like(in_delta),
        (sample_part + ref_part) / in_delta
    )
    return tf.tile(multiplier, tile_twice)

def gather(explainer, op, *grads):
    """ Attribution handler for GatherV2 / ResourceGather ops.

    Exactly one of the first two inputs (params, indices) may vary:
    - only params vary: the op is linear, so its original gradient is used;
    - only indices vary: the output difference between sample and reference halves,
      weighted by the incoming gradient and summed over the trailing axes, is divided
      by the index delta (zero where that delta is below 1e-6).
    Any other combination trips an assertion.
    """
    indices = op.inputs[1]
    var = explainer._variable_inputs(op)

    if var[0] and not var[1]:
        return [explainer.orig_grads[op.type](op, grads[0]), None] # linear in this case
    elif var[1] and not var[0]:
        assert len(indices.shape) == 2, "Only scalar indices supported right now in GatherV2!"

        sample_idx, ref_idx = tf.split(tf.to_float(op.inputs[1]), 2)
        sample_out, ref_out = tf.split(op.outputs[0], 2)
        tile_idx = [2] + [1 for _ in sample_idx.shape[1:]]
        tile_out = [2] + [1 for _ in sample_out.shape[1:]]
        idx_delta = tf.tile(sample_idx - ref_idx, tile_idx)
        trailing_axes = list(range(len(indices.shape), len(grads[0].shape)))
        out_sum = tf.reduce_sum(grads[0] * tf.tile(sample_out - ref_out, tile_out), trailing_axes)
        index_attribution = tf.where(
            tf.abs(idx_delta) < 1e-6,
            tf.zeros_like(idx_delta),
            out_sum / idx_delta
        )
        # ResourceGather gets two entries, the other gather type a third (None) entry
        if op.type == "ResourceGather":
            return [None, index_attribution]
        return [None, index_attribution, None]
    else:
        assert False, "Axis not yet supported to be varying for gather op!"

def linearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
    """ Build a handler for ops that are linear when one of two inputs varies.

    The returned handler uses `linearity_1d_handler` when exactly one of the two inputs
    varies, `nonlinearity_2d_handler` (with `op_func`) when both vary, and returns a
    list of None (one per op input) when neither varies.
    """
    def handler(explainer, op, *grads):
        var = explainer._variable_inputs(op)
        first_varies, second_varies = var[input_ind0], var[input_ind1]
        if first_varies and second_varies:
            return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
        if first_varies:
            return linearity_1d_handler(input_ind0, explainer, op, *grads)
        if second_varies:
            return linearity_1d_handler(input_ind1, explainer, op, *grads)
        # no inputs vary, we must be hidden by a switch function
        return [None for _ in op.inputs]
    return handler

def nonlinearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
    """ Build a handler for ops that are nonlinear in each of two inputs.

    The returned handler uses `nonlinearity_1d_handler` when exactly one of the two
    inputs varies, `nonlinearity_2d_handler` (with `op_func`) when both vary, and
    returns a list of None (one per op input) when neither varies.
    """
    def handler(explainer, op, *grads):
        var = explainer._variable_inputs(op)
        first_varies, second_varies = var[input_ind0], var[input_ind1]
        if first_varies and second_varies:
            return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
        if first_varies:
            return nonlinearity_1d_handler(input_ind0, explainer, op, *grads)
        if second_varies:
            return nonlinearity_1d_handler(input_ind1, explainer, op, *grads)
        # no inputs vary, we must be hidden by a switch function
        return [None for _ in op.inputs]
    return handler

def nonlinearity_1d(input_ind):
    """ Build a handler that treats an op as nonlinear in its `input_ind`-th input only. """
    def nonlinear_in_one_input(explainer, op, *grads):
        # the index is bound here; everything else is forwarded untouched
        return nonlinearity_1d_handler(input_ind, explainer, op, *grads)
    return nonlinear_in_one_input

def nonlinearity_1d_handler(input_ind, explainer, op, *grads):
    """ Attribution for an op that is nonlinear in exactly one varying input.

    The varying input and the op output are split into sample and reference halves
    along axis 0. The incoming gradient is multiplied by (output delta / input delta);
    where the input delta is below 1e-6 the op's original gradient is used instead.

    Returns a list with one entry per op input: the attribution at `input_ind` and
    None everywhere else. Asserts that no other input of the op varies.
    """

    # make sure only the given input varies
    for i in range(len(op.inputs)):
        if i != input_ind:
            assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"

    xin0,rin0 = tf.split(op.inputs[input_ind], 2)
    # The output index is unrelated to which input varies. The op types registered with this
    # handler in op_handlers are single-output TensorFlow ops, so the result is always
    # output 0; indexing by input_ind failed for e.g. Maximum with only its second input varying.
    xout,rout = tf.split(op.outputs[0], 2)
    delta_in0 = xin0 - rin0
    dup0 = [2] + [1 for i in delta_in0.shape[1:]]
    out = [None for _ in op.inputs]
    orig_grads = explainer.orig_grads[op.type](op, grads[0])
    out[input_ind] = tf.where(
        tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
        # the original gradient is indexed by input when the op has several inputs
        orig_grads[input_ind] if len(op.inputs) > 1 else orig_grads,
        grads[0] * tf.tile((xout - rout) / delta_in0, dup0)
    )
    return out

def nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads):
    """ Attribution for an op whose first two inputs both vary.

    `op_func` re-evaluates the op on mixed (sample, reference) input pairs. Each input's
    share of the output difference is the average of its marginal effect with the other
    input at its reference and at its sample value. The share is divided by the input's
    delta and multiplied by the incoming gradient `grads[0]`; entries whose delta is
    below 1e-7 are set to zero.

    Only `input_ind0 == 0` and `input_ind1 == 1` are supported (asserted).
    Returns `[out0, out1]`, the attributions for the two inputs.
    """
    assert input_ind0 == 0 and input_ind1 == 1, "TODO: Can't yet handle double inputs that are not first!"
    # split every tensor into sample half (x...) and reference half (r...) along axis 0
    xout,rout = tf.split(op.outputs[0], 2)
    xin0,rin0 = tf.split(op.inputs[input_ind0], 2)
    xin1,rin1 = tf.split(op.inputs[input_ind1], 2)
    delta_in0 = xin0 - rin0
    delta_in1 = xin1 - rin1
    dup0 = [2] + [1 for i in delta_in0.shape[1:]]
    # outAB: op evaluated with input 0 from (A=1 sample, A=0 reference) and input 1 likewise for B
    out10 = op_func(xin0, rin1)
    out01 = op_func(rin0, xin1)
    out11,out00 = xout,rout
    out0 = 0.5 * (out11 - out01 + out10 - out00)
    out0 = grads[0] * tf.tile(out0 / delta_in0, dup0)
    out1 = 0.5 * (out11 - out10 + out01 - out00)
    out1 = grads[0] * tf.tile(out1 / delta_in1, dup0)

    # see if due to broadcasting our gradient shapes don't match our input shapes
    # NOTE(review): because of the if/elif only one of out1/out0 is reduced, and only along
    # the first mismatching axis - confirm this covers the broadcasting cases that occur
    if (np.any(np.array(out1.shape) != np.array(delta_in1.shape))):
        broadcast_index = np.where(np.array(out1.shape) != np.array(delta_in1.shape))[0][0]
        out1 = tf.reduce_sum(out1, axis=broadcast_index, keepdims=True)
    elif (np.any(np.array(out0.shape) != np.array(delta_in0.shape))):
        broadcast_index = np.where(np.array(out0.shape) != np.array(delta_in0.shape))[0][0]
        out0 = tf.reduce_sum(out0, axis=broadcast_index, keepdims=True)

    # Avoid divide by zero nans
    out0 = tf.where(tf.abs(tf.tile(delta_in0, dup0)) < 1e-7, tf.zeros_like(out0), out0)
    out1 = tf.where(tf.abs(tf.tile(delta_in1, dup0)) < 1e-7, tf.zeros_like(out1), out1)

    return [out0, out1]

def linearity_1d(input_ind):
    """ Build a handler that treats an op as linear in its `input_ind`-th input only. """
    def linear_in_one_input(explainer, op, *grads):
        # the index is bound here; everything else is forwarded untouched
        return linearity_1d_handler(input_ind, explainer, op, *grads)
    return linear_in_one_input

def linearity_1d_handler(input_ind, explainer, op, *grads):
    """ Attribution for an op that is linear in its only varying input.

    Asserts that no input other than `input_ind` varies, then delegates to the op's
    original gradient function with all incoming gradients.
    """
    for position, _ in enumerate(op.inputs):
        if position == input_ind:
            continue
        assert not explainer._variable_inputs(op)[position], str(position) + "th input to " + op.name + " cannot vary!"
    original_gradient = explainer.orig_grads[op.type]
    return original_gradient(op, *grads)

def linearity_with_excluded(input_inds):
    """ Build a handler for a linear op whose inputs listed in `input_inds` must not vary. """
    def linear_except_excluded(explainer, op, *grads):
        # the excluded indices are bound here; everything else is forwarded untouched
        return linearity_with_excluded_handler(input_inds, explainer, op, *grads)
    return linear_except_excluded

def linearity_with_excluded_handler(input_inds, explainer, op, *grads):
    """ Attribution for a linear op in which some inputs are not allowed to vary.

    `input_inds` lists the excluded input positions; negative values count from the
    end of the op's inputs. Asserts that none of the excluded inputs varies, then
    delegates to the op's original gradient function with all incoming gradients.
    """
    input_count = len(op.inputs)
    for position in range(input_count):
        is_excluded = position in input_inds or position - input_count in input_inds
        if is_excluded:
            assert not explainer._variable_inputs(op)[position], str(position) + "th input to " + op.name + " cannot vary!"
    original_gradient = explainer.orig_grads[op.type]
    return original_gradient(op, *grads)

def passthrough(explainer, op, *grads):
    """ Use the op's original gradient function unchanged (looked up by `op.type`). """
    original_gradient = explainer.orig_grads[op.type]
    return original_gradient(op, *grads)

def break_dependence(explainer, op, *grads):
    """ Handler that passes no attribution to any input of the op.

    The function object itself also serves as a marker: op types registered with it
    are treated as dependence breakers during graph traversal. Such ops may sit above
    the input data in the graph although their outputs do not depend on the input
    values (for example they only depend on the shape).
    """
    return [None] * len(op.inputs)


# Registry mapping a TensorFlow op type name to the attribution handler used for it.
# Entries are added group by group, in the same order as before.
op_handlers = {}

# ops that are always linear
op_handlers.update({
    "Identity": passthrough,
    "StridedSlice": passthrough,
    "Squeeze": passthrough,
    "ExpandDims": passthrough,
    "Pack": passthrough,
    "BiasAdd": passthrough,
    "Unpack": passthrough,
    "Add": passthrough,
    "Sub": passthrough,
    "Merge": passthrough,
    "Sum": passthrough,
    "Mean": passthrough,
    "Cast": passthrough,
    "Transpose": passthrough,
    "Enter": passthrough,
    "Exit": passthrough,
    "NextIteration": passthrough,
    "Tile": passthrough,
    "TensorArrayScatterV3": passthrough,
    "TensorArrayReadV3": passthrough,
    "TensorArrayWriteV3": passthrough,
})

# ops that don't pass any attributions to their inputs
op_handlers.update({
    "Shape": break_dependence,
    "RandomUniform": break_dependence,
    "ZerosLike": break_dependence,
    # "StopGradient": break_dependence,  # this allows us to stop attributions when we want to (like softmax re-centering)
})

# ops that are linear and only allow a single input to vary
op_handlers.update({
    "Reshape": linearity_1d(0),
    "Pad": linearity_1d(0),
    "ReverseV2": linearity_1d(0),
    "ConcatV2": linearity_with_excluded([-1]),
    "Conv2D": linearity_1d(0),
    "Switch": linearity_1d(0),
    "AvgPool": linearity_1d(0),
    "FusedBatchNorm": linearity_1d(0),
})

# ops that are nonlinear and only allow a single input to vary
op_handlers.update({
    "Relu": nonlinearity_1d(0),
    "Elu": nonlinearity_1d(0),
    "Sigmoid": nonlinearity_1d(0),
    "Tanh": nonlinearity_1d(0),
    "Softplus": nonlinearity_1d(0),
    "Exp": nonlinearity_1d(0),
    "ClipByValue": nonlinearity_1d(0),
    "Rsqrt": nonlinearity_1d(0),
    "Square": nonlinearity_1d(0),
    "Max": nonlinearity_1d(0),
})

# ops that are nonlinear and allow two inputs to vary
op_handlers.update({
    "SquaredDifference": nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: (x - y) * (x - y)),
    "Minimum": nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.minimum(x, y)),
    "Maximum": nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.maximum(x, y)),
})

# ops that allow up to two inputs to vary and are linear when only one input varies
op_handlers.update({
    "Mul": linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x * y),
    "RealDiv": linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x / y),
    "MatMul": linearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.matmul(x, y)),
})

# ops that need their own custom attribution functions
op_handlers.update({
    "GatherV2": gather,
    "ResourceGather": gather,
    "MaxPool": maxpool,
    "Softmax": softmax,
})

class DeepExplainer(Explainer):
    """ Thin wrapper that delegates to a `TFDeepExplainer`.

    `expected_value` is copied from the wrapped explainer when the wrapper is created.
    """

    def __init__(self, model, data, session=None, learning_phase_flags=None):
        self.explainer = TFDeepExplainer(model, data, session, learning_phase_flags)
        self.expected_value = self.explainer.expected_value

    def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
        """ Return the wrapped explainer's SHAP values for X.

        All arguments are forwarded unchanged. `check_additivity` (default True, the
        wrapped explainer's own default) lets callers switch off the additivity
        assertion, which was previously not reachable through this wrapper.
        """
        return self.explainer.shap_values(
            X, ranked_outputs, output_rank_order, check_additivity=check_additivity
        )

# Monte Carlo - Code

In [None]:
def MonteCarloIter(IND, DATA, MODEL, y, M=1000):
    """Estimate the Shapley values of one row of DATA by Monte Carlo sampling.

    For every feature j, M random (coalition, background row) pairs are drawn. Each
    pair yields two synthetic instances that differ only in feature j (taken from the
    explained row in one, from the background row in the other); the difference of
    the model outputs for class/output `y` is one marginal contribution.

    Parameters
    ----------
    IND : int
        Positional index (iloc) of the row of DATA to explain.
    DATA : pandas.DataFrame
        Supplies both the explained row and the background rows (via DATA.sample).
    MODEL : object
        Must offer predict(array of shape (n, n_features)) returning a 2-D array.
    y : int
        Column of the model output to explain.
    M : int, default 1000
        Number of Monte Carlo draws per feature (must be at least 1).

    Returns
    -------
    svs : list
        Final estimate for each feature.
    svs_iter : numpy.ndarray of shape (M, n_features)
        Running mean of the estimate after each draw.

    Notes
    -----
    The coalition size is clamped to [int(0.2 * n_features), int(0.8 * n_features)],
    so very small and very large coalitions are never sampled.
    NOTE(review): that differs from uniform Shapley sampling - confirm it is intended.
    """
    if M < 1:
        raise ValueError("M must be at least 1")

    # positional values; indexing the Series by label would depend on the column names
    x = DATA.iloc[IND].to_numpy()
    n_features = len(x)
    min_size, max_size = int(0.2 * n_features), int(0.8 * n_features)

    svs = []
    svs_iter = []
    for j in range(n_features):
        feature_idxs = list(range(n_features))
        feature_idxs.remove(j)

        with_j, without_j = [], []
        for _ in range(M):
            # the order of the random draws is unchanged, so seeded runs draw the same samples
            z = DATA.sample(1).values[0]
            if feature_idxs:
                coalition_size = min(max(min_size, random.choice(feature_idxs)), max_size)
            else:
                coalition_size = 0  # single-feature data: only the empty coalition exists
            x_idx = random.sample(feature_idxs, coalition_size)

            # True where the value comes from the explained row, False where it comes from z
            from_x = np.zeros(n_features, dtype=bool)
            from_x[np.asarray(x_idx, dtype=int)] = True
            without_j.append(np.where(from_x, x, z))
            from_x[j] = True
            with_j.append(np.where(from_x, x, z))

        # one batched prediction per feature instead of 2 * M single-row calls
        preds = np.asarray(MODEL.predict(np.array(with_j + without_j)))
        contributions = (preds[:M, y] - preds[M:, y]).astype(np.float64)

        running_mean = np.cumsum(contributions) / np.arange(1, M + 1)
        svs.append(running_mean[-1])  # the Shapley value estimate after all M draws
        svs_iter.append(running_mean)

    # transpose so that the result has shape (M, n_features)
    return svs, np.array(svs_iter).T

# Evaluate Explainers

In [None]:
from DASP.dasp.dasp import DASP

In [None]:
# Per-sample attribution vectors, one list per explainer (filled by the evaluation loop)
ts=[]
fs=[]
ls=[]
ksv=[]
psv=[]
ssv=[]
msv=[]
esv=[]
daspsv=[]
deepsv=[]
gradsv=[]

df_X_Test_s=pd.DataFrame(X_test_s)
df_X_train_s=pd.DataFrame(X_train_s)

# Explainers that are built once and reused for every sample.
# (Each of `exact` and `imputer` used to be constructed twice with identical arguments.)
eexplainer = shap.explainers.Exact(bbmodel.predict, X_train_s)
imputer = MarginalExtension(X_test_s[:128], bbmodel.predict)
exact = Exact(bbmodel.predict, num_features)

deepexpl = DeepExplainer(bbmodel, X_train_s[:1000])
gradexpl = shap.GradientExplainer(bbmodel, X_train_s[:1000])
# DASP is given the model truncated at its second-to-last layer
daspex = DASP(Model(bbmodel.inputs, bbmodel.layers[-2].output))

# Wall-clock time per sample, one list per explainer
time_fs=[]
time_ls=[]  
time_sr=[]
time_mc=[]
time_psv=[]
time_ssv=[]
time_ex=[]
time_dasp=[]
time_deep=[]
time_grad=[]

unbiased_kernelshap_curves = []
kernelshap_curves = []
UPPER=800
if UPPER==800:
    TAKE=25
    TAKE2=80
else:
    TAKE=32
    TAKE2=103

list_iters=np.arange(32, UPPER+1, 32)
perm_iters=np.arange(10, UPPER+1, 10)
mc_iters=np.arange(1, UPPER+1, 1)
print(num_features)
print(len(list_iters), len(perm_iters), len(mc_iters))

thresh = 0.01
samples = 12228

# L2 distance to the reference attribution per sample, one list per explainer
l2_distance_fs=[]
l2_distance_ls=[]
l2_distance_ks=[]
l2_distance_mc=[]
l2_distance_uks=[]
l2_distance_psv=[]
l2_distance_dasp=[]
l2_distance_deep=[]
l2_distance_grad=[]

# the KernelSHAP estimate is read off after this many iterations
kernelshap_iters=128

# Explain the first 25 test samples with every explainer, recording run time and the
# L2 distance of each attribution vector to the reference attribution `ev`.
for ind in tqdm(range(25)):
    x = X_test_s[ind:ind+1]
    y = int(Y_test[ind])

    # set all seeds (reset for every sample, so each sample sees the same random streams)
    SEED = 42
    np.random.seed(SEED)
    random.seed(SEED)
    torch.manual_seed(SEED)

    # Run FastSHAP
    start = time.time()
    fastshap_values = fastshap.shap_values(x)[0]
    end = time.time()
    time_fs.append(end-start)
    fsv=fastshap_values[:, y]

    # Run LightningSHAP
    start = time.time()
    lshap_values=lshap.shap_values(x)[0]
    end = time.time()
    time_ls.append(end-start)
    lsv=lshap_values[:, y]
    
    # Run KernelSHAP and Unbiased KernelSHAP
    game = PredictionGame(imputer, x)
    start = time.time()
    uks, ks = ShapleyRegression(game, batch_size=32, thresh=thresh, bar=False, paired_sampling=False, return_all=True)
    end = time.time()
    time_sr.append(end-start)
    target_uks=uks.values[:,y]
    # KernelSHAP estimate after exactly `kernelshap_iters` iterations
    target_ks=ks['values'][list(ks['iters']).index(kernelshap_iters)][:, y]

    # Run Permutation
    start = time.time()
    results = ShapleySampling(game, batch_size=1, n_samples=int(np.ceil(samples / num_features)), detect_convergence=False, bar=False, return_all=True)
    end = time.time()
    time_psv.append(end-start)
    permutation = results[1]['values'][-1][:,y]

    # Run MonteCarlo
    # The test frame is passed so that row `ind` is the same sample `x` that every other
    # explainer in this loop explains; passing the training frame explained an unrelated
    # training row. Background rows are consequently drawn from the test frame as well.
    start = time.time()
    mc, mc_iter= MonteCarloIter(ind, df_X_Test_s, bbmodel, y)
    end = time.time()
    time_mc.append(end-start)

    # Run DASP
    start = time.time()
    dasp_values = daspex.run(x, num_features)
    end = time.time()
    time_dasp.append(end-start)
    daspv=dasp_values[0][y]

    # Run DeepExplainer
    start = time.time()
    deep_values = deepexpl.shap_values(x)
    end = time.time()
    time_deep.append(end-start)
    deepv=deep_values[y][0]

    # Run GradientExplainer
    start = time.time()
    grad_values = gradexpl.shap_values(x)
    end = time.time()
    time_grad.append(end-start)
    gradv=grad_values[y][0]

    # Run Exact SHAP
    # with 15 or more features the unbiased KernelSHAP estimate stands in as the reference
    # (time_ex then stays empty and the unbiased-KernelSHAP distance is 0 by construction)
    if num_features>=15:
        ev=target_uks
    else:
        start = time.time()
        ev = exact(x[0], X_test_s[:128], y)
        end = time.time()
        time_ex.append(end-start)

    ts.append(target_uks)
    ksv.append(target_ks)
    fs.append(fsv)
    ls.append(lsv)
    psv.append(permutation)
    msv.append(mc)
    esv.append(ev)
    daspsv.append(daspv)
    deepsv.append(deepv)
    gradsv.append(gradv)

    l2_distance_fs.append(np.linalg.norm(ev-fsv))
    l2_distance_ls.append(np.linalg.norm(ev-lsv))

    l2_distance_ks.append(np.linalg.norm(ev-target_ks))

    l2_distance_uks.append(np.linalg.norm(ev-target_uks))

    l2_distance_psv.append(np.linalg.norm(ev-permutation))

    l2_distance_mc.append(np.linalg.norm(ev-mc))

    l2_daspshap=np.linalg.norm(ev-daspv)
    l2_distance_dasp.append(l2_daspshap)

    l2_deepshap=np.linalg.norm(ev-deepv)
    l2_distance_deep.append(l2_deepshap)

    l2_gradshap=np.linalg.norm(ev-gradv)
    l2_distance_grad.append(l2_gradshap)

In [None]:
# Display the value of `dataset` (defined earlier in the notebook).
# Presumably identifies the dataset the results in this section refer to - TODO confirm.
dataset

In [None]:
# print Times
# Mean wall-clock time per explained sample for every explainer. A timing list can be
# empty (time_ex is only filled when Exact SHAP is run, i.e. for fewer than 15 features);
# np.mean of an empty list would emit a RuntimeWarning and print nan, so such rows are
# reported as "not measured" instead.
print("\n".join(
    template.format(np.mean(durations)) if len(durations) > 0
    else template.split(":")[0] + ": not measured"
    for template, durations in [
        ("Time taken by SR: {:.4f}", time_sr),
        ("Time taken by Permutation: {:.4f}", time_psv),
        ("Time taken by MonteCarlo: {:.4f}", time_mc),
        ("Time taken by Exact: {:.4f}", time_ex),
        ("Time taken by FastSHAP: {:.4f}", time_fs),
        ("Time taken by LightningSHAP: {:.4f}", time_ls),
        ("Time taken by DASP: {:.4f}", time_dasp),
        ("Time taken by DeepExplainer: {:.4f}", time_deep),
        ("Time taken by GradientExplainer: {:.4f}", time_grad),
    ]
))

In [None]:
# print L2 distances
# Mean L2 distance to the reference attribution over all explained samples, one line
# per explainer, in the same order and format as before.
print("\n".join(
    template.format(np.mean(distances))
    for template, distances in [
        ("L2 distance between Exact and KernelSHAP: {:.4f}", l2_distance_ks),
        ("L2 distance between Exact and Unbiased KernelSHAP: {:.4f}", l2_distance_uks),
        ("L2 distance between Exact and Permutation: {:.4f}", l2_distance_psv),
        ("L2 distance between Exact and MonteCarlo: {:.4f}", l2_distance_mc),
        ("L2 distance between Exact and FastSHAP: {:.4f}", l2_distance_fs),
        ("L2 distance between Exact and LightningSHAP: {:.4f}", l2_distance_ls),
        ("L2 distance between Exact and DASP: {:.4f}", l2_distance_dasp),
        ("L2 distance between Exact and DeepExplainer: {:.4f}", l2_distance_deep),
        ("L2 distance between Exact and GradientExplainer: {:.4f}", l2_distance_grad),
    ]
))
