In [1]:
# Utils
import torch
import numpy as np
import pickle

# ML models
from openxai.LoadModel import LoadModel

# Data loaders
from openxai.dataloader import return_loaders

# Explanation models
from openxai.Explainer import Explainer

# Evaluation methods
from openxai.evaluator import Evaluator

# Perturbation methods required for the computation of the relative stability metrics
from openxai.explainers.catalog.perturbation_methods import NormalPerturbation
from openxai.explainers.catalog.perturbation_methods import NewDiscrete_NormalPerturbation

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
# Choose the model and the data set you wish to generate explanations for.
# `data_loader_batch_size` is handed to `return_loaders` as the loader batch size.
data_loader_batch_size = 32
# NOTE(review): the feature-type cell further down also has a branch for
# 'synthetic' — confirm whether it belongs in the list of supported names.
data_name = 'adult' # must be one of ['heloc', 'adult', 'german', 'compas']
model_name = 'ann'    # must be one of ['lr', 'ann']

### (0) Explanation method hyperparameters

In [27]:
# Hyperparameters for Lime.
# These values are copied into `param_dict_lime` in the custom-hyperparameter
# cell below; their exact meaning is defined by the OpenXAI LIME explainer.
lime_mode = 'tabular'
lime_sample_around_instance = True
lime_kernel_width = 0.75
lime_n_samples = 1000
lime_discretize_continuous = False
# Stored under the 'std' key; sqrt(0.03) is roughly 0.173
# (presumably a standard deviation for a variance of 0.03 — confirm).
lime_standard_deviation = float(np.sqrt(0.03))

### (1) Data Loaders

In [28]:
# Get training and test loaders for the selected data set
loader_train, loader_test = return_loaders(data_name=data_name,
                                           download=True,
                                           batch_size=data_loader_batch_size)

# Take one batch from the test loader; this batch is what gets explained below.
# Use the built-in next(): iterators implement __next__, and the Python 2-style
# `.next()` method is not available on current PyTorch DataLoader iterators.
data_iter = iter(loader_test)
inputs, labels = next(data_iter)

# Cast the labels to int64 before they are passed to the explainers
labels = labels.type(torch.int64)

In [29]:
# Get the full training data set as a float32 tensor; it is passed to every
# Explainer below as `dataset_tensor`.
data_all = torch.FloatTensor(loader_train.dataset.data)

### (2) Load a pretrained ML model

In [30]:
# Fetch the pretrained model matching the chosen data set and model type
model = LoadModel(data_name=data_name, ml_model=model_name, pretrained=True)

### (3) Choose an explanation method

#### I: Explanation method with particular hyperparameters (LIME)

In [31]:
# Custom hyperparameters are supplied to the explainer as a single dictionary:
param_dict_lime = {
    'dataset_tensor': data_all,
    'std': lime_standard_deviation,
    'mode': lime_mode,
    'sample_around_instance': lime_sample_around_instance,
    'kernel_width': lime_kernel_width,
    'n_samples': lime_n_samples,
    'discretize_continuous': lime_discretize_continuous,
}

# Build the LIME explainer with the custom hyperparameters
lime = Explainer(
    method='lime',
    model=model,
    dataset_tensor=data_all,
    param_dict_lime=param_dict_lime,
)

In [32]:
# Explain every instance of the sampled test batch with the custom LIME explainer
lime_custom = lime.get_explanation(inputs, label=labels)

100%|██████████| 32/32 [00:00<00:00, 99.30it/s] 


In [33]:
# Display the explanation of the first instance in the batch
lime_custom[0,:]
# (lime_custom.size() reports the shape of the full explanation tensor)

tensor([ 0.1770,  0.1251,  0.4666,  0.9785,  0.1941,  0.2594,  0.0423,  0.0342,
        -0.2683, -0.0188,  0.0434, -0.0044,  0.0232])

#### II: Explanation method with default hyperparameters (LIME)

In [34]:
# You can also use the default hyperparameters like so
# (no parameter dictionary is supplied):
lime = Explainer(
    method='lime',
    model=model,
    dataset_tensor=data_all,
    param_dict_lime=None,
)

# Explain the batch (inputs cast to float) and show the first instance
lime_default_exp = lime.get_explanation(inputs.float(), label=labels)
lime_default_exp[0,:]

100%|██████████| 32/32 [00:00<00:00, 92.28it/s] 


tensor([ 1.8319e-01,  9.3354e-02,  4.5634e-01,  9.3258e-01,  1.9481e-01,
         2.1715e-01,  5.4959e-02,  2.5337e-05, -2.3532e-01, -1.3000e-02,
         7.4688e-02,  2.1017e-02,  3.3147e-03])

#### III: Explanation method with default hyperparameters (IG)

In [35]:
# Position within the batch of the instance whose explanation is displayed
index = 0
# To use a different explanation method change the method name like so
# NOTE(review): `param_dict_lime=None` is passed although the method is 'ig' —
# this looks carried over from the LIME cell; confirm whether an IG-specific
# parameter-dict keyword should be used here instead.
ig = Explainer(method='ig',
               model=model,
               dataset_tensor=data_all,
               param_dict_lime=None)
# Explain the batch (inputs cast to float) and show the instance at `index`
ig_default_exp = ig.get_explanation(inputs.float(), 
                                    label=labels)
ig_default_exp[index,:]

tensor([-1.8331e-01, -1.1128e-01, -4.4522e-01, -9.3644e-01, -1.8038e-01,
        -2.2002e-01, -2.6008e-02,  3.7538e-03,  2.3225e-01,  3.3227e-02,
        -3.5954e-02,  3.3398e-04, -9.2020e-03], dtype=torch.float64)

#### IV: Explanation method with default hyperparameters (SHAP)

In [53]:
# Build a SHAP explainer without supplying a parameter dictionary
shap = Explainer(
    method='shap',
    model=model,
    dataset_tensor=data_all,
    param_dict_shap=None,
)

# Explain the batch (inputs cast to float) and show the instance at `index`
shap_default_exp = shap.get_explanation(inputs.float(), label=labels)
shap_default_exp[index,:]



tensor([-0.0268, -0.0031, -0.0452,  0.0038,  0.0047, -0.0204, -0.0060,  0.0134,
         0.1242, -0.0065, -0.0231, -0.0024, -0.0145])

### (4) Choose an evaluation metric

In [36]:
def generate_mask(explanation, top_k):
    """Build a boolean mask marking the ``top_k`` largest entries of ``explanation``.

    Entries are ranked by their signed value (``torch.topk`` on the raw
    tensor), not by magnitude.

    Args:
        explanation: tensor of attribution scores.
        top_k: number of entries to mark.

    Returns:
        A ``torch.bool`` tensor with the same shape as ``explanation`` that is
        ``True`` at the selected positions and ``False`` elsewhere.
    """
    top_indices = torch.topk(explanation, top_k).indices
    # Start from an all-False mask, then switch on the selected positions at once
    mask = torch.zeros(explanation.shape, dtype=torch.bool)
    mask[top_indices] = True
    return mask

In [37]:
# Perturbation class parameters (passed to the perturbation constructor below)
perturbation_mean = 0.0
perturbation_std = 0.10
perturbation_flip_percentage = 0.03

# Per-feature type codes for the selected data set; stored later as
# input_dict['feature_metadata'].
# ('c' / 'd' presumably stand for continuous / discrete — confirm.)
if data_name == 'compas':
    feature_types = ['c', 'd', 'c', 'c', 'd', 'd', 'd']
# Adult feature types
elif data_name == 'adult':
    feature_types = ['c'] * 6 + ['d'] * 7
# Gaussian feature types
elif data_name == 'synthetic':
    feature_types = ['c'] * 20
# Heloc feature types
elif data_name == 'heloc':
    feature_types = ['c'] * 23
# German credit feature types are read from a metadata file
elif data_name == 'german':
    # NOTE: pickle.load can execute arbitrary code — only load this file from a
    # trusted source. The context manager makes sure the handle is closed.
    with open('./data/German_Credit_Data/german-feature-metadata.p', 'rb') as metadata_file:
        feature_types = pickle.load(metadata_file)
else:
    # Fail here with a clear message instead of hitting a NameError on
    # `feature_types` in a later cell.
    raise ValueError(f"No feature types defined for data_name={data_name!r}")

In [38]:
# Perturbation method: the 'german' data set uses the special
# NewDiscrete_NormalPerturbation class, every other data set uses
# NormalPerturbation. Both take the same constructor arguments.
perturbation = (
    NewDiscrete_NormalPerturbation if data_name == 'german' else NormalPerturbation
)(
    "tabular",
    mean=perturbation_mean,
    std_dev=perturbation_std,
    flip_percentage=perturbation_flip_percentage,
)

In [54]:
# Collect everything the Evaluator needs for the instance at position `index`
input_dict = dict()
# Instance to evaluate (set explicitly so this cell does not depend on the
# value left behind by an earlier cell)
index = 0

# inputs and models
input_dict['x'] = inputs[index].reshape(-1)
input_dict['input_data'] = inputs
input_dict['explainer'] = shap
input_dict['explanation_x'] = shap_default_exp[index,:].flatten()
input_dict['model'] = model

# perturbation method used for the stability metric
# (the same object is stored under both keys)
input_dict['perturbation'] = perturbation
input_dict['perturb_method'] = perturbation
input_dict['perturb_max_distance'] = 0.4
input_dict['feature_metadata'] = feature_types
input_dict['p_norm'] = 2
input_dict['eval_metric'] = None

# true label, predicted label, and masks
input_dict['top_k'] = 3
input_dict['y'] = labels[index].detach().item()
# Predicted class = index of the largest model output for this single instance
input_dict['y_pred'] = torch.max(model(inputs[index].unsqueeze(0).float()), 1).indices.detach().item()
# Boolean mask over the top_k largest explanation entries
input_dict['mask'] = generate_mask(input_dict['explanation_x'].reshape(-1), input_dict['top_k'])

# required for the representation stability measure
input_dict['L_map'] = model

In [55]:
# Create the evaluator for the SHAP explainer on the sampled batch
evaluator = Evaluator(
    input_dict,
    inputs=inputs,
    labels=labels,
    model=model,
    explainer=shap,
)

In [58]:
# These metrics are only evaluated when the model exposes
# `return_ground_truth_importance`.
if hasattr(model, 'return_ground_truth_importance'):
    # RC: rank correlation, FA: feature agreement, RA: rank agreement,
    # SA: sign agreement, SRA: signed rank agreement
    for metric_name in ('RC', 'FA', 'RA', 'SA', 'SRA'):
        print(f'{metric_name}:', evaluator.evaluate(metric=metric_name))

In [57]:
# PGU: prediction gap on unimportant features
# PGI: prediction gap on important features
# RIS: relative input stability
# ROS: relative output stability
for metric_name in ('PGU', 'PGI', 'RIS', 'ROS'):
    print(f'{metric_name}:', evaluator.evaluate(metric=metric_name))

PGU: 0.8090787
PGI: 0.8894353
RIS: 39.43808705525204
ROS: 688.1789800956093
