# MODULES

Usage examples for the three `EnsembleXAI` modules: `Ensemble`, `Metrics`, and `Normalization`.

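## Ensemble

Compute IntegratedGradients, Lime, and Saliency attributions for a pretrained ResNet-50 on two random inputs, then aggregate them into a single explanation per image with `normEnsembleXAI`.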

In [1]:
import warnings
warnings.filterwarnings("ignore")  # Silence all warnings to keep the output readable (this also hides deprecation warnings)
import torch
from EnsembleXAI.Ensemble import normEnsembleXAI
from captum.attr import IntegratedGradients, GradientShap, Saliency, Lime
from torchvision.models import resnet50, ResNet50_Weights

# Configuration
SEED = 0          # seed for torch's global RNG: fixes the random inputs below (and, presumably, Lime's random sampling)
TARGET_CLASS = 3  # model output index that every attribution method explains

# Pretrained ResNet-50, switched to evaluation mode before computing attributions.
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()

# Seed right before drawing the inputs, so that they do not depend on how much
# randomness model construction may have consumed.
torch.manual_seed(SEED)
# Two random 3x32x32 images, batched into inputs of shape [2, 3, 32, 32].
input1 = torch.randn(1, 3, 32, 32)
input2 = torch.randn(1, 3, 32, 32)
inputs = torch.cat((input1, input2), dim=0)

# Per-method attributions. Each one must share the shape of `inputs`, because torch.stack below requires it.
ig = IntegratedGradients(model).attribute(inputs, target=TARGET_CLASS)
# GradientShap (imported above) is not part of this ensemble. Unlike the methods
# used here, its attribute() requires an explicit `baselines` argument.
lime = Lime(model).attribute(inputs, target=TARGET_CLASS)
sal = Saliency(model).attribute(inputs, target=TARGET_CLASS)

# Stack along a new dim=1 to get shape [num_images, num_methods, channels, height, width].
explanations = torch.stack([ig, lime, sal], dim=1)
# Combine the methods into one explanation per image. The printed shapes show
# that the method dimension is reduced away.
agg = normEnsembleXAI(explanations, aggregating_func='avg')

# Report the rank and shape of each stage. The output is identical to one print per tensor.
for label, value in [
    ('Inputs', inputs),
    ('IntegratedGradients method', ig),
    ('Ensemble XAI', explanations),
    ('Agg normXAI XAI', agg),
]:
    print('***{} \ndim:{} shape:{}'.format(label, value.dim(), value.shape))

***Inputs 
dim:4 shape:torch.Size([2, 3, 32, 32])
***IntegratedGradients method 
dim:4 shape:torch.Size([2, 3, 32, 32])
***Ensemble XAI 
dim:5 shape:torch.Size([2, 3, 3, 32, 32])
***Agg normXAI XAI 
dim:4 shape:torch.Size([2, 3, 32, 32])


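## Metrics

Measure, for each image, the consistency between the IntegratedGradients and Saliency explanations computed in the Ensemble section.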

In [2]:
from EnsembleXAI.Metrics import consistency

# Compare two of the attributions computed in the Ensemble section.
explanation1 = ig
explanation2 = sal
# Stack them along dim=1, giving shape [num_images, num_explanations, channels, height, width].
explanations = torch.stack([explanation1, explanation2], dim=1)

# consistency() receives one image at a time: that image's
# [num_explanations, channels, height, width] slice.
for image_idx, image_explanations in enumerate(explanations):
    con = consistency(image_explanations)
    print(f"Consistency of image {image_idx}: {con}")

Consistency of image 0: 0.5401846766471863
Consistency of image 1: 0.5810154676437378


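## Normalization

Rescale the stacked IntegratedGradients and Saliency explanations with each of the three normalization schemes.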

In [3]:
from EnsembleXAI.Normalization import mean_var_normalize,median_iqr_normalize,second_moment_normalize

# Reuse the IntegratedGradients and Saliency attributions from the Ensemble section.
explanation1 = ig
explanation2 = sal
# Stack along dim=1, giving shape [num_images, num_explanations, channels, height, width].
explanations = torch.stack([explanation1, explanation2], dim=1)

# Standardization based on mean and variance (as the function name indicates).
normal_std = mean_var_normalize(explanations)
# Robust standardization based on median and interquartile range (IQR).
robust_std = median_iqr_normalize(explanations)
# Second-moment scaling.
scaled_explanation = second_moment_normalize(explanations)