In [1]:
# Utils
import torch
import numpy as np
import pickle

# ML models
from openxai.LoadModel import LoadModel

# Data loaders
from openxai.dataloader import return_loaders

# Explanation models
from openxai.Explainer import Explainer

# Evaluation methods
from openxai.evaluator import Evaluator

# Perturbation methods required for the computation of the relative stability metrics
from openxai.explainers.catalog.perturbation_methods import NormalPerturbation
from openxai.explainers.catalog.perturbation_methods import NewDiscrete_NormalPerturbation

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Choose the model and the data set you wish to generate explanations for.
# These settings drive the later cells: the data loaders, the pretrained-model
# lookup, and the feature-type / perturbation selection all branch on them.
data_loader_batch_size = 32  # handed to return_loaders as batch_size; the first test batch is what gets explained
data_name = 'adult' # must be one of ['heloc', 'adult', 'german', 'compas']
model_name = 'lr'    # must be one of ['lr', 'ann']

### (0) Explanation method hyperparameters

In [3]:
# Hyperparameters for Lime
# These values are only collected here; they are copied into `param_dict_lime`
# in section (3) and passed to the custom-configured LIME explainer.
lime_mode = 'tabular'
lime_sample_around_instance = True
lime_kernel_width = 0.75
lime_n_samples = 1000
lime_discretize_continuous = False
# Stored under the 'std' key of param_dict_lime; sqrt(0.03) is converted from a
# numpy scalar to a plain Python float.
lime_standard_deviation = float(np.sqrt(0.03))

### (1) Data Loaders

In [4]:
# Build the train/test loaders for the selected data set, then pull the first
# test batch: these `inputs` / `labels` are the instances explained below.
loader_train, loader_test = return_loaders(
    data_name=data_name,
    download=True,
    batch_size=data_loader_batch_size,
)
data_iter = iter(loader_test)
inputs, labels = next(data_iter)
# Labels are cast to 64-bit integers before being handed to the explainers.
labels = labels.long()

In [5]:
# get full training data set
# The whole training set (not a single batch) is converted to a float tensor;
# it is passed to every Explainer below as `dataset_tensor`.
# NOTE(review): torch.FloatTensor(...) is the legacy constructor form;
# torch.tensor(..., dtype=torch.float32) is the documented modern spelling.
# Confirm the type of `loader_train.dataset.data` before switching.
data_all = torch.FloatTensor(loader_train.dataset.data)

### (2) Load a pretrained ML model

In [6]:
# Load pretrained ml model
# The model is selected by the data set / model-type pair chosen in the
# configuration cell (`data_name`, `model_name`).
# NOTE(review): pretrained=True presumably fetches stored weights — confirm
# whether this needs network access on first run.
model = LoadModel(data_name=data_name,
                  ml_model=model_name,
                  pretrained=True)

### (3) Choose an explanation method

#### I: Explanation method with particular hyperparameters (LIME)

In [7]:
# Custom LIME configuration: every hyperparameter defined in section (0) is
# gathered into one dictionary and handed to the Explainer factory.
param_dict_lime = {
    'dataset_tensor': data_all,
    'std': lime_standard_deviation,
    'mode': lime_mode,
    'sample_around_instance': lime_sample_around_instance,
    'kernel_width': lime_kernel_width,
    'n_samples': lime_n_samples,
    'discretize_continuous': lime_discretize_continuous,
}
lime = Explainer(
    method='lime',
    model=model,
    dataset_tensor=data_all,
    param_dict_lime=param_dict_lime,
)

In [8]:
# Explain the first test batch with the custom-configured LIME explainer.
# NOTE(review): the later cells pass `inputs.float()`; here `inputs` is passed
# as-is — confirm whether the cast is needed for consistency.
lime_custom = lime.get_explanation(inputs, 
                                   label=labels)

100%|██████████████████████████████████████████| 32/32 [00:00<00:00, 206.12it/s]


In [9]:
# Show the custom-LIME attribution vector of the first instance in the batch.
lime_custom[0,:]

tensor([ 0.0904, -0.0068,  0.2557,  1.3355,  0.1716,  0.1180,  0.0382, -0.0041,
        -0.1398, -0.0170,  0.0126,  0.0073,  0.0148])

#### II: Explanation methods with default hyperparameters (LIME and SHAP)

In [10]:
# You can also use the default hyperparameters like so:
# (param_dict_lime=None instead of a custom dictionary). Note that this rebinds
# `lime`, so the custom-configured explainer created earlier is no longer
# reachable under that name.
lime = Explainer(method='lime',
                 model=model,
                 dataset_tensor=data_all,
                 param_dict_lime=None)
lime_default_exp = lime.get_explanation(inputs.float(), 
                                        label=labels)
# Display the default-LIME attribution vector of the first instance.
lime_default_exp[0,:]

100%|██████████████████████████████████████████| 32/32 [00:00<00:00, 265.18it/s]


tensor([ 0.0888,  0.0129,  0.2171,  1.2452,  0.1439,  0.1072,  0.0416,  0.0051,
        -0.1425, -0.0325,  0.0178, -0.0172,  0.0037])

In [11]:
# SHAP with default hyperparameters (param_dict_shap=None).
# The `shap` explainer and `shap_default_exp` created here are the ones that
# get evaluated in section (4).
shap = Explainer(method='shap',
                 model=model,
                 dataset_tensor=data_all,
                 param_dict_shap=None)
shap_default_exp = shap.get_explanation(inputs.float(), 
                                        label=labels)
# Display the SHAP attribution vector of the first instance.
shap_default_exp[0,:]



tensor([-0.0034, -0.0011, -0.0766, -0.0034,  0.0043, -0.0148, -0.0079, -0.0028,
         0.0818,  0.0294, -0.0052, -0.0026, -0.0063])

#### III: Explanation method with default hyperparameters (IG)

In [12]:
# Position (within the first test batch) of the instance that is displayed
# here and evaluated in section (4).
index = 5
# To use a different explanation method change the method name like so.
# The hyperparameter keyword follows the method (param_dict_ig for 'ig');
# None requests the default hyperparameters.
ig = Explainer(method='ig',
               model=model,
               dataset_tensor=data_all,
               param_dict_ig=None)
ig_default_exp = ig.get_explanation(inputs.float(),
                                    label=labels)
# Display the IG attribution vector of the selected instance.
ig_default_exp[index,:]

tensor([-1.5051e-01, -1.9686e-02, -3.8156e-01, -2.0774e+00, -2.3069e-01,
        -1.9626e-01, -2.3884e-02,  4.0561e-04,  2.3834e-01,  5.1439e-02,
        -2.4375e-02, -6.9914e-03, -4.3782e-03], dtype=torch.float64)

### (4) Choose an evaluation metric

In [13]:
def generate_mask(explanation, top_k):
    """Return a boolean mask marking the ``top_k`` largest attribution values.

    Args:
        explanation: tensor of attribution scores (used as a flat, 1-D vector
            by the caller in this notebook).
        top_k: number of entries to flag.

    Returns:
        A ``torch.bool`` tensor with the same shape as ``explanation`` that is
        ``True`` at the positions of the ``top_k`` largest values and ``False``
        elsewhere.

    Note:
        Ranking uses the signed values, not their magnitude, so strongly
        negative attributions are never selected ahead of positive ones.
    """
    top_indices = torch.topk(explanation, top_k).indices
    # Allocate the mask directly as booleans and set all selected positions in
    # one indexed assignment instead of looping over the indices.
    mask = torch.zeros(explanation.shape, dtype=torch.bool)
    mask[top_indices] = True
    return mask

In [14]:
# Perturbation class parameters (passed to the perturbation object built in
# the next cell).
perturbation_mean = 0.0
perturbation_std = 0.10
perturbation_flip_percentage = 0.03

# One type code per input feature of the selected data set.
# NOTE(review): 'c' / 'd' presumably stand for continuous / discrete — confirm
# against the perturbation classes that consume `feature_metadata`.
if data_name == 'compas':
    feature_types = ['c', 'd', 'c', 'c', 'd', 'd', 'd']
# Adult feature types
elif data_name == 'adult':
    feature_types = ['c'] * 6 + ['d'] * 7
# Gaussian feature types
elif data_name == 'synthetic':
    feature_types = ['c'] * 20
# Heloc feature types
elif data_name == 'heloc':
    feature_types = ['c'] * 23
elif data_name == 'german':
    # SECURITY: pickle.load can execute arbitrary code; only load this metadata
    # file from a trusted source.
    # The context manager closes the file handle, which the bare
    # pickle.load(open(...)) form left open.
    with open('./data/German_Credit_Data/german-feature-metadata.p', 'rb') as metadata_file:
        feature_types = pickle.load(metadata_file)
else:
    # Fail here with a clear message instead of a NameError on
    # `feature_types` further down the notebook.
    raise ValueError('No feature types defined for data_name={!r}'.format(data_name))

In [15]:
# Perturbation methods
# The 'german' data set uses the special NewDiscrete_NormalPerturbation class;
# every other data set uses NormalPerturbation. Both are constructed with the
# same arguments, so only the class choice depends on `data_name`.
perturbation = (
    NewDiscrete_NormalPerturbation if data_name == 'german' else NormalPerturbation
)(
    "tabular",
    mean=perturbation_mean,
    std_dev=perturbation_std,
    flip_percentage=perturbation_flip_percentage,
)

In [16]:
# Everything the evaluator needs for the single instance at position `index`
# of the first test batch (`index` is assigned in the IG cell of section (3)).
input_dict = {
    # inputs and models
    'x': inputs[index].reshape(-1),
    'input_data': inputs,
    'explainer': shap,
    'explanation_x': shap_default_exp[index,:].flatten(),
    'model': model,

    # perturbation method used for the stability metric
    'perturbation': perturbation,
    'perturb_method': perturbation,
    'perturb_max_distance': 0.4,
    'feature_metadata': feature_types,
    'p_norm': 2,
    'eval_metric': None,

    # true label and predicted label
    'top_k': 11,
    'y': labels[index].detach().item(),
    'y_pred': torch.max(model(inputs[index].unsqueeze(0).float()), 1).indices.detach().item(),
}

# mask over the top_k entries of the explanation (built from values stored above)
input_dict['mask'] = generate_mask(input_dict['explanation_x'].reshape(-1), input_dict['top_k'])

# required for the representation stability measure
input_dict['L_map'] = model

In [17]:
# Build the evaluator for the SHAP explanation: it receives the per-instance
# settings in `input_dict` plus the full first test batch, its labels, the
# model, and the SHAP explainer.
evaluator = Evaluator(input_dict,
                      inputs=inputs,
                      labels=labels, 
                      model=model, 
                      explainer=shap)

In [18]:
# These metrics are only run when the model exposes
# `return_ground_truth_importance`.
if hasattr(model, 'return_ground_truth_importance'):
    # RC: rank correlation, FA: feature agreement, RA: rank agreement,
    # SA: sign agreement, SRA: signed rank agreement
    for gt_metric in ('RC', 'FA', 'RA', 'SA', 'SRA'):
        print(gt_metric + ':', evaluator.evaluate(metric=gt_metric))

RC: (array([0.25274725]), 0.25274725274725274)
FA: (array([0.81818182]), 0.8181818181818182)
RA: (array([0.27272727]), 0.2727272727272727)
SA: (array([0.09090909]), 0.09090909090909091)
SRA: (array([0.]), 0.0)


In [19]:
# Metrics evaluated unconditionally, in this order:
#   PGU - prediction gap on unimportant features
#   PGI - prediction gap on important features
#   RIS - relative input stability
#   ROS - relative output stability
for metric_name in ('PGU', 'PGI', 'RIS', 'ROS'):
    print(metric_name + ':', evaluator.evaluate(metric=metric_name))

PGU: 0.16902333
PGI: 0.024730742
RIS: 1639.847305745055
ROS: 25136.389079091703
