In [12]:
import torch
import torchvision
import quantus
from autoexplainer.utils import fix_relus_in_model

## Setup: insert your own model and data

### Example: ResNet-18

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading model to CPU/GPU device: {device}")

# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
model = torch.load('../models/resnet_18.pth', map_location=device)
# Switch to evaluation mode so layers such as BatchNorm and Dropout behave
# deterministically during prediction and attribution.
model.eval()

# Preprocessing applied to every image: fixed 256x256 size, tensor conversion,
# then per-channel normalisation (the standard ImageNet mean/std values).
transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# ImageFolder derives the labels from the sub-directory names under ../data/test.
dataset = torchvision.datasets.ImageFolder("../data/test", transform=transform)
# Small batches keep the attribution and metric computations cheap.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2)

Loading model to CPU/GPU device: cpu


Take a single batch from the data loader; this is the batch that will be explained.

In [4]:
# Take the first batch from the loader. The built-in next() is used because
# DataLoader iterators no longer expose a Python-2 style .next() method.
x_batch, y_batch = next(iter(data_loader))

# Only predictions are needed here, so skip autograd bookkeeping. The inputs
# are moved to the model's device for the forward pass; x_batch itself is
# left where the loader produced it.
with torch.no_grad():
    predicted_labels = model(x_batch.to(device)).argmax(axis=1)

# NumPy copies of the inputs and labels, used by the metric calls further down.
x_batch_np, y_batch_np = x_batch.cpu().numpy(), y_batch.cpu().numpy()

## Explanations

In [9]:
# Rebind `model` to whatever autoexplainer's fix_relus_in_model returns.
# NOTE(review): presumably this adjusts the ReLU layers so that the attribution
# methods below work on this architecture; confirm against autoexplainer.utils.
# This cell feeds `model` back into itself, so re-running it applies the helper
# to an already-processed model.
model = fix_relus_in_model(model)

Reference: Captum tutorial on LIME for image and text classification — https://captum.ai/tutorials/Image_and_Text_Classification_LIME

In [13]:
def _attributions_for(method_name):
    """Return normalised attributions for the current batch from the named quantus method."""
    return quantus.explain(model, x_batch, y_batch, method=method_name, normalise=True)


# One attribution batch per explanation method being compared.
a_batch_gradients = _attributions_for("IntegratedGradients")
a_batch_saliency = _attributions_for("Saliency")

Lime attribution: 100%|██████████| 25/25 [00:02<00:00,  8.45it/s]
Lime attribution: 100%|██████████| 25/25 [00:02<00:00,  9.99it/s]
  predictions = model(torch.tensor(inputs)).argmax(dim=1)


## Metrics


In [5]:
# Faithfulness Estimate: perturbs features by replacing them with the "mean"
# baseline and correlates (Pearson) the result with the attributions.
metric = quantus.FaithfulnessEstimate(
    perturb_func=quantus.baseline_replacement_by_indices,
    similarity_func=quantus.correlation_pearson,
    features_in_step=256,
    perturb_baseline="mean",
    # NOTE(review): `pixels_in_step` sits next to `features_in_step` and may be
    # redundant or ignored by this metric; confirm via metric.get_params.
    pixels_in_step=28,
)


def _faithfulness_of(attributions):
    """Score one batch of precomputed attributions with the metric configured above."""
    return metric(
        model=model,
        x_batch=x_batch_np,
        y_batch=y_batch_np,
        a_batch=attributions,
        device=device,
    )


faithfulness_grad = _faithfulness_of(a_batch_gradients)
faithfulness_saliency = _faithfulness_of(a_batch_saliency)

 (1) The Faithfulness Estimate metric is likely to be sensitive to the choice of baseline value 'perturb_baseline' and similarity function 'similarity_func'. 
 (2) If attributions are normalised or their absolute values are taken it may destroy or skew information in the explanation and as a result, affect the overall evaluation outcome.
 (3) Make sure to validate the choices for hyperparameters of the metric (by calling .get_params of the metric instance).
 (4) For further information, see original publication: Alvarez-Melis, David, and Tommi S. Jaakkola. 'Towards robust interpretability with self-explaining neural networks.' arXiv preprint arXiv:1806.07538 (2018).
[0m


In [None]:
# Iterative Removal Of Features: configured with SLIC segmentation and the
# "mean" baseline; return_aggregate=False disables the aggregated result.
irof = quantus.IterativeRemovalOfFeatures(
    segmentation_method="slic",
    perturb_baseline="mean",
    perturb_func=quantus.baseline_replacement_by_indices,
    return_aggregate=False,
)


def _irof_for(method_name):
    """Run IROF on the current batch for the named explanation method.

    No precomputed attributions are passed (a_batch=None); quantus.explain is
    supplied as explain_func together with the method name instead.
    NOTE(review): unlike the attributions computed earlier in the notebook,
    no normalise=True is forwarded here; confirm that this is intended.
    """
    return irof(
        model=model,
        x_batch=x_batch_np,
        y_batch=y_batch_np,
        a_batch=None,
        explain_func=quantus.explain,
        method=method_name,
        device=device,
    )


irof_grad = _irof_for("IntegratedGradients")
irof_saliency = _irof_for("Saliency")