# Compute interpretability metrics (counterfactual, faithfulness, reliability, complexity, robustness) for a trained model

In [3]:
from TSInterpret.InterpretabilityModels.Saliency.TSR import TSR, Saliency_PTY
from TSInterpret.InterpretabilityModels.counterfactual.TSEvoCF import TSEvo
# from TSInterpret.InterpretabilityModels.counterfactual.SETSCF import SETSCF

import torch 
from XTSCBench.ClassificationModels.CNN_T import ResNetBaseline, UCRDataset,fit
from XTSCBench.ClassificationModels.LSTM import LSTM
from XTSCBench.CounterfactualEvaluation import CounterfactualEvaluation
from tslearn.datasets import UCR_UEA_datasets
import sklearn
import numpy as np 
import os


2024-12-01 22:27:23.569377: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-01 22:27:23.590670: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-01 22:27:23.706948: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-01 22:27:23.707038: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-01 22:27:23.707314: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

## Data Loading

In [72]:
# Dataset to evaluate. Alternatives tried previously: 'ECG5000', 'ptbxl'.
dataset = "Epilepsy"

In [73]:

# Load the raw train/test arrays for the selected dataset.
if dataset in ('ECG5000', 'ECG200', 'Epilepsy'):
    # NOTE(review): the original note here read "For use with CNN reverse Data
    # Dimensions"; no axis swap is performed in this cell.
    train_x, train_y, test_x, test_y = UCR_UEA_datasets().load_dataset(dataset)
elif dataset == 'ptbxl':
    train_x = np.load('./datasets/ptbxl/x_train.npy')
    train_y = np.load('./datasets/ptbxl/y_train.npy')
    test_x = np.load('./datasets/ptbxl/x_test.npy')
    test_y = np.load('./datasets/ptbxl/y_test.npy')
else:
    # Fail here instead of hitting a NameError on train_y further down.
    raise ValueError(f"Unsupported dataset: {dataset!r}")

# One-hot encode the labels. The encoder is fitted on the train and test labels
# together so both splits share the same class columns.
# `.toarray()` replaces the `sparse=False` constructor argument, which was
# renamed to `sparse_output` in scikit-learn 1.2 and removed in 1.4.
enc1 = sklearn.preprocessing.OneHotEncoder().fit(
    np.vstack((train_y.reshape(-1, 1), test_y.reshape(-1, 1)))
)
train_y = enc1.transform(train_y.reshape(-1, 1)).toarray()
test_y = enc1.transform(test_y.reshape(-1, 1)).toarray()

# Number of classes equals the number of one-hot columns.
n_pred_classes = train_y.shape[1]

train_dataset = UCRDataset(train_x.astype(np.float64), train_y.astype(np.int64))
test_dataset = UCRDataset(test_x.astype(np.float64), test_y.astype(np.int64))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)



## Load/Train Model

In [74]:
# Runtime and model hyper-parameters.
device = 'cpu'
hidden_size = 10
rnn = 0.1  # shows up as the "drop" tag in the default model name below

# Directory that holds the trained model files.
model_path = './trained_models'

# The last two axes of train_x are read as (time steps, features).
NumTimesteps, NumFeatures = train_x.shape[-2:]
input_size = NumFeatures  # univariate or multi?

# Default file name the model is saved under.
model_name = f'lstm_{dataset}_h{hidden_size}_drop{rnn}'

In [75]:
# The Epilepsy model was saved under a custom file name; any other dataset
# keeps the model_name that is already set.
model_name = 'lstm_Epilepsy_h200_drop0.3_acc58' if dataset == 'Epilepsy' else model_name

In [76]:

# Load the pre-trained classifier from disk.
# The file stores the whole pickled model object, so no LSTM instance has to be
# built beforehand (the LSTM class only needs to be importable for unpickling).
model_file = os.path.join(model_path, model_name)

if not os.path.isfile(model_file):
    # Stop here: continuing would compute every metric on an untrained model.
    raise FileNotFoundError(
        f"Model not found at {model_file}. Please train model using "
        "training_models.ipynb and provide in this notebook"
    )

# SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
# NOTE(review): PyTorch 2.6+ defaults torch.load to weights_only=True, which
# refuses fully pickled modules; if loading fails there, pass
# weights_only=False for this trusted file.
model = torch.load(model_file, map_location=device)
print(f"Model {model_name} successfully loaded")

# Switch to inference mode; as the last expression this also displays the model.
model.eval()


Model lstm_Epilepsy_h200_drop0.3_acc58 successfully loaded


LSTM(
  (drop): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=200, out_features=4, bias=True)
  (rnn): LSTM(3, 200, batch_first=True)
)

# Make explainers

In [77]:
# TSEvo counterfactual explainer built on the training data.
# TSEvo is already imported in the notebook's first cell, so it is not
# re-imported here.
tsevo_exp = TSEvo(
    model=model,
    data=(train_x, train_y),
    mode='time',
    backend='PYT',
    epochs=30,
)

y was one Hot Encoded


In [78]:
# Temporal Saliency
#
# Methods
# * Gradients (GRAD)
# * Integrated Gradients (IG)
# * Gradient Shap (GS)
# * DeepLift (DL)
# * DeepLiftShap (DLS)
# * SmoothGrad (SG)
# * Shapley Value Sampling(SVS)
# * Feature Ablation (FA)
# * Occlusion (FO)

# Saliency_PTY is already imported in the notebook's first cell, so it is not
# re-imported here. Dimensions come from the configuration cell.
tsr_GRAD_exp = Saliency_PTY(
    model,
    NumTimeSteps=NumTimesteps,
    NumFeatures=NumFeatures,
    method='GRAD',
    mode='time',
    tsr=True,
)


In [79]:
# Feature Ablation (FA) saliency explainer with the `tsr` flag enabled.
tsr_FA_exp = Saliency_PTY(
    model,
    NumTimeSteps=NumTimesteps,
    NumFeatures=NumFeatures,
    method='FA',
    mode='time',
    tsr=True,
)


In [80]:
# Occlusion (FO) saliency explainer with the `tsr` flag enabled.
tsr_FO_exp = Saliency_PTY(
    model,
    NumTimeSteps=NumTimesteps,
    NumFeatures=NumFeatures,
    method='FO',
    mode='time',
    tsr=True,
)

In [81]:
# # NativeGuideCF
# from TSInterpret.InterpretabilityModels.counterfactual.NativeGuideCF import NativeGuideCF

# ng_exp = NativeGuideCF(model,(train_x,train_y), backend='PYT', mode='feat',method='NUN_CF')

In [82]:
# Explainers handed to every evaluation below.
# The occlusion explainer (tsr_FO_exp) is currently left out.
explainer = [tsevo_exp, tsr_FA_exp, tsr_GRAD_exp]


In [15]:

# Counterfactual benchmark over the selected explainers.
bm = CounterfactualEvaluation(
    explainer=explainer,
)


In [34]:
# Evaluate on the first two test samples; argmax turns the one-hot label rows
# back into class indices.
SummaryTable = bm.evaluate(
    test_x[:2],
    np.argmax(test_y[:2], axis=1),
    model,
    mode='time',
    aggregate=True,
)


No Target
No Target


In [35]:
# Preview the first rows of the aggregated counterfactual summary.
SummaryTable.head()

Unnamed: 0,d1_mean,d2_mean,d3_mean,d4_mean,validty_mean,d1_std,d2_std,d3_std,d4_std,validty_std,method,normalize,tsr,transformer,epochs
0,0.998382,0.6239,386.745868,2.6,1.0,0.0,0.003408,10.046387,0.226274,0.0,GRAD,True,True,,
1,0.999191,0.877832,805.530702,4.02,1.0,0.001144,0.044349,79.354777,0.325269,0.0,,,,authentic_opposing_information,30.0


# Metric Settings

In [83]:
# Shape of the one-hot encoded test labels, to see how many samples are available.
test_y.shape

(138, 4)

In [84]:
# Number of leading test samples passed to each metric evaluation.
num_test_samples = 2

# Folder that receives the per-metric CSV summaries.
interp_folder = './interp_metrics'
# Create it up front so the later to_csv calls do not fail on a fresh checkout.
os.makedirs(interp_folder, exist_ok=True)

# Counterfactual Metrics

In [19]:
# Fresh counterfactual benchmark for the run that is written to disk.
bm = CounterfactualEvaluation(
    explainer=explainer,
)


In [20]:
# Counterfactual metrics on the first `num_test_samples` test samples; argmax
# turns the one-hot label rows back into class indices.
SummaryTable_counterfact = bm.evaluate(
    test_x[:num_test_samples],
    np.argmax(test_y[:num_test_samples], axis=1),
    model,
    mode='time',
    aggregate=True,
)


No Target
No Target


In [22]:
# Preview the first rows of the aggregated counterfactual metrics.
SummaryTable_counterfact.head()

Unnamed: 0,d1_mean,d2_mean,d3_mean,d4_mean,validty_mean,d1_std,d2_std,d3_std,d4_std,validty_std,method,normalize,tsr,transformer,epochs
0,0.998382,0.6239,386.745868,2.6,1.0,0.0,0.003408,10.046387,0.226274,0.0,GRAD,True,True,,
1,0.999191,0.657026,427.236924,2.624741,1.0,0.001144,0.0429,67.733426,0.260842,0.0,FA,True,True,,
2,0.998382,0.874167,773.04605,3.55,1.0,0.002288,0.040424,67.197135,0.084853,0.0,,,,authentic_opposing_information,30.0


In [23]:
# Persist the counterfactual summary next to the other metric files.
SummaryTable_counterfact.to_csv(
    f"{interp_folder}/{model_name}_CF.csv",
    index=False,
)

# Faithfulness Metrics

In [25]:
from XTSCBench.FaithfulnessEvaluation import FaithfulnessEvaluation

# Faithfulness benchmark over the selected explainers; no mlmodel is supplied
# here, the model is passed to evaluate() instead.
bm = FaithfulnessEvaluation(
    explainer=explainer,
    mlmodel=None,
)


In [26]:
# Faithfulness metrics on the first `num_test_samples` test samples; exp=None
# means no precomputed explanations are handed in.
SummaryTable_faith = bm.evaluate(
    test_x[:num_test_samples],
    np.argmax(test_y[:num_test_samples], axis=1),
    model,
    exp=None,
    mode='time',
    aggregate=True,
)


No Target
No Target
GET METRICS
Original (2, 206, 3)
EXP (2, 206, 3)
GET METRICS
Original (2, 206, 3)
EXP (2, 206, 3)
GET METRICS
Original (2, 206, 3)
EXP (2, 206, 3)


In [27]:
# Preview the first rows of the aggregated faithfulness metrics.
SummaryTable_faith.head()

Unnamed: 0,method,normalize,tsr,transformer,epochs
0,GRAD,True,True,,
1,FA,True,True,,
2,,,,authentic_opposing_information,30.0


In [28]:
# Persist the faithfulness summary next to the other metric files.
SummaryTable_faith.to_csv(
    f"{interp_folder}/{model_name}_faith.csv",
    index=False,
)

# Reliability Metrics

In [29]:
from XTSCBench.ReliabilityEvaluation import ReliabilityEvaluation
from quantus.metrics.localisation.auc import AUC

# Reliability benchmark over the selected explainers, with quantus' AUC passed
# in as an additional metric.
bm = ReliabilityEvaluation(
    explainer=explainer,
    mlmodel=None,
    metrics=[AUC()],
)


 (1) The AUC metric is likely to be sensitive to the choice of ground truth mask i.e., the 's_batch' input as well as if absolute values 'abs' are taken of the attributions .  
 (2) If attributions are normalised or their absolute values are taken it may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome.
 (3) Make sure to validate the choices for hyperparameters of the metric (by calling .get_params of the metric instance).
 (4) For further information, see original publication: Fawcett, Tom. 'An introduction to ROC analysis' Pattern Recognition Letters Vol 27, Issue 8, (2006).



In [30]:
# CAREFUL: the ground-truth mask below is an ASSUMPTION. It marks indices
# 10..19 along axis 1 as relevant for every sample and everything else as
# irrelevant; it is not derived from the data.
meta = np.zeros_like(test_x[:num_test_samples])
meta[:, 10:20] = 1

SummaryTable_reiable = bm.evaluate(
    test_x[:num_test_samples],
    np.argmax(test_y[:num_test_samples], axis=1),
    model,
    meta=meta,
    exp=None,
    mode='time',
    aggregate=True,
)


No Target
No Target


In [31]:
# Preview the first rows of the aggregated reliability metrics.
SummaryTable_reiable.head()

Unnamed: 0,<quantus.metrics.localisation.auc.AUC object at 0x7f3d464639d0>_mean,Pointing_mean,Relevance Rank_mean,Relevance Mass_mean,AuC_mean,<quantus.metrics.localisation.auc.AUC object at 0x7f3d464639d0>_std,Pointing_std,Relevance Rank_std,Relevance Mass_std,AuC_std,method,normalize,tsr,transformer,epochs
0,0.420493,0.0,0.0,4.369506e-07,0.420493,0.129997,0.0,0.0,2.224092e-07,0.129997,GRAD,True,True,,
1,0.542092,0.0,0.166667,0.1224261,0.542092,0.137894,0.0,0.235702,0.1702358,0.137894,FA,True,True,,
2,0.698583,0.0,0.0,1.767379,0.698583,0.013348,0.0,0.0,2.601492,0.013348,,,,authentic_opposing_information,30.0


In [32]:
# Persist the reliability summary next to the other metric files.
SummaryTable_reiable.to_csv(
    f"{interp_folder}/{model_name}_reliable.csv",
    index=False,
)

# Complexity Metrics

In [85]:
from XTSCBench.ComplexityEvaluation import ComplexityEvaluation
from quantus.metrics.complexity.effective_complexity import EffectiveComplexity

# Complexity benchmark over the selected explainers, with quantus'
# EffectiveComplexity passed in as an additional metric.
bm = ComplexityEvaluation(
    explainer=explainer,
    metrics=[EffectiveComplexity()],
)


 (1) The Effective Complexity metric is likely to be sensitive to the choice of normalising 'normalise' (and 'normalise_func') and if taking absolute values of attributions 'abs' and the choice of threshold 'eps'.  
 (2) If attributions are normalised or their absolute values are taken it may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome.
 (3) Make sure to validate the choices for hyperparameters of the metric (by calling .get_params of the metric instance).
 (4) For further information, see original publication: Nguyen, An-phi, and María Rodríguez Martínez. 'On quantitative aspects of model interpretability.' arXiv preprint arXiv:2007.07584 (2020)..



In [86]:
# Complexity metrics on the first `num_test_samples` test samples; argmax turns
# the one-hot label rows back into class indices.
SummaryTable_complex = bm.evaluate(
    test_x[:num_test_samples],
    np.argmax(test_y[:num_test_samples], axis=1),
    model,
    mode='time',
    aggregate=True,
)


No Target
No Target


In [87]:
# Display the aggregated complexity metrics.
SummaryTable_complex

Unnamed: 0,complexity_mean,<quantus.metrics.complexity.effective_complexity.EffectiveComplexity object at 0x7f3d4616ba90>_mean,complexity_std,<quantus.metrics.complexity.effective_complexity.EffectiveComplexity object at 0x7f3d4616ba90>_std,method,normalize,tsr,transformer,epochs
0,3.377835,158.5,0.259202,30.405592,GRAD,True,True,,
1,4.392748,519.0,0.580744,4.242641,FA,True,True,,
2,6.155328,614.0,0.065276,4.242641,,,,authentic_opposing_information,30.0


In [88]:
# Persist the complexity summary next to the other metric files.
SummaryTable_complex.to_csv(
    f"{interp_folder}/{model_name}_complexity.csv",
    index=False,
)

# Robustness Metrics

In [45]:
from XTSCBench.RobustnessEvaluation import RobustnessEvaluation

# Robustness benchmark over the selected explainers; no mlmodel is supplied
# here, the model is passed to evaluate() instead.
bm = RobustnessEvaluation(
    explainer=explainer,
    mlmodel=None,
)


In [46]:
# Robustness metrics on the first `num_test_samples` test samples, the same
# sample count used by the other metric sections (previously hard-coded to 2).
# NOTE(review): the recorded output of this cell ends in a ValueError raised
# during evaluation and shows shapes from a different dataset; re-run and
# confirm before relying on SummaryTable_robust.
SummaryTable_robust = bm.evaluate(
    test_x[:num_test_samples],
    np.argmax(test_y[:num_test_samples], axis=1),
    model,
    exp=None,
    mode='time',
    aggregate=True,
)

No Target
No Target
Robustness Shapes
(2, 96, 1)
[1, 1]
X1  (1, 96)
y1  1
No Target


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [None]:
# Preview the first rows of the aggregated robustness metrics.
# NOTE(review): only defined if the robustness evaluation cell above completed.
SummaryTable_robust.head()