# Library imports and model configuration

In [None]:
import torch
import torchvision.transforms as transforms
from captum.attr import Occlusion, Saliency, InputXGradient,NoiseTunnel,IntegratedGradients,KernelShap
from captum.attr import visualization as viz

import numpy as np
from PIL import Image
from skimage.segmentation import slic
from skimage.util import img_as_float
import sys

# Project root: replace the `...` placeholder with the directory that contains
# the `medDerm` package as well as the `checkpoints/` and `datasets/` folders
# referenced further down in this notebook.
ROOT = ...
# Only add the root once, so re-executing this cell does not keep growing sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)


from medDerm.agent import *
from medDerm.tools import *
from medDerm.utils import *

## Model

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
# NOTE(review): assumes `load_checkpoint` can load onto the CPU (e.g. maps
# storages appropriately) — confirm if the checkpoint was saved on a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
config_path = f"{ROOT}/checkpoints/exp-HAM+Derm7pt-all+BCN+HAM-bin+DermNet+Fitzpatrick.yaml"

model = load_checkpoint(config_path).to(device)
# Evaluation mode (affects layers such as dropout / batch-norm); the
# attribution methods below still compute gradients through the model.
model.eval()
# Key of the output head whose predictions are classified and explained.
head = "HAM10k"

## Image transforms

In [None]:
# Geometric + dtype preprocessing only: resize to 224x224 and convert the PIL
# image to a float tensor with values in [0, 1]. Normalization is a separate
# step so the un-normalized image stays available for display.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Per-channel normalization with the standard ImageNet mean / std values.
norm_transform = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

## Load the test image


In [None]:
HAMdatasetDir = f"{ROOT}/datasets/ISIC_ImageNet/"
# Context manager so the underlying file handle is released deterministically.
with Image.open(f'{HAMdatasetDir}ISIC_0034597.jpg') as img_file:
    test_img = img_file.convert('RGB')

# (C, H, W) float tensor in [0, 1] produced by Resize + ToTensor.
transformed_img = transform(test_img)
# Normalized, batched copy on `device`: the tensor fed to the model and to
# every attribution method below.
input_img = norm_transform(transformed_img).unsqueeze(0).to(device)
# (H, W, C) numpy version of the un-normalized image, used as the background
# image in the visualizations (the tensor is already on the CPU without grad).
orig_img = np.transpose(transformed_img.numpy(), (1, 2, 0))
# Batched, un-normalized copy on `device`.
transformed_img = transformed_img.unsqueeze(0).to(device)
# Track gradients w.r.t. the input; the trailing ';' stops Jupyter from
# echoing the whole tensor as cell output.
input_img.requires_grad_();

Perform the classification


In [None]:
# Class names of the HAM10k head.
# NOTE(review): assumed to be in the same order as the model's output columns —
# confirm against the label mapping used in training.
classes = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']


outputs = model.forward_explanation_tasks(input_img)
# `torch.as_tensor(...).detach()` accepts either a tensor or a plain sequence.
# Unlike `torch.tensor(tensor)`, it does not trigger PyTorch's copy-construct
# warning and does not keep the autograd graph attached to the scores.
output = torch.as_tensor(outputs[head]['predictions']).detach()
# Top-1 score and class index; both are (1, 1) assuming `output` is
# (1, num_classes) for the single input image.
prediction_score, pred_label_idx = torch.topk(output, 1)

# Reduce the index to a 0-d tensor; it is reused as the `target` of every
# attribution method below.
pred_label_idx.squeeze_()
predicted_label = classes[pred_label_idx.item()]
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')
print(pred_label_idx)

In [None]:

def forward_model(image: torch.Tensor):
    """Forward function handed to the Captum attribution methods.

    Runs ``model.forward_explanation_tasks`` and keeps only the predictions of
    the output head named by the global ``head``.

    Args:
        image: Batched input tensor, preprocessed like ``input_img``
            (resized, converted to a tensor, normalized).

    Returns:
        The ``predictions`` entry of that head, with one row of class scores
        per sample, i.e. shape ``[B, num_classes]`` (presumably; confirm
        against the model). The class column is chosen by the ``target``
        argument passed to each attribution method.
    """
    outputs = model.forward_explanation_tasks(image)
    return outputs[head]['predictions']


Gradient-based methods
=============================================


## Vanilla Gradient



In [None]:
# Saliency (vanilla gradient) attributions for the predicted class.
vanilla_gradient = Saliency(forward_model)
attributions_vg = vanilla_gradient.attribute(input_img, target=pred_label_idx)

# Batched (1, C, H, W) attributions -> (H, W, C) array for Captum's visualizer.
vg_hwc = attributions_vg.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
_ = viz.visualize_image_attr_multiple(
    vg_hwc,
    orig_img,
    methods=["original_image", "masked_image", "blended_heat_map"],
    signs=["all", "positive", "positive"],
    titles=["Original", "Image masking", "Blended heat map"],
    alpha_overlay=0.5,
)


## Input x Gradient



In [None]:
 
# Input x Gradient: the input multiplied element-wise by its gradient.
# `InXgrad` is reused by the SmoothGrad cell below.
InXgrad = InputXGradient(forward_model)
attributions_ixg = InXgrad.attribute(input_img, target=pred_label_idx)

ixg_hwc = np.transpose(
    attributions_ixg.squeeze().detach().cpu().numpy(),
    axes=(1, 2, 0),
)
_ = viz.visualize_image_attr_multiple(
    ixg_hwc,
    orig_img,
    alpha_overlay=0.5,
    titles=["Original", "Image masking", "Blended heat map"],
    signs=["all", "positive", "positive"],
    methods=["original_image", "masked_image", "blended_heat_map"],
)

### With SmoothGrad

In [None]:
 
# SmoothGrad wrapper around Input x Gradient: attributions are aggregated over
# 30 noisy copies of the input (noise std 0.3, in normalized-input units),
# processed 5 copies per batch.
smoothGrad = NoiseTunnel(InXgrad)
attributions_sg = smoothGrad.attribute(
    input_img,
    target=pred_label_idx,
    stdevs=0.3,
    nt_samples=30,
    nt_samples_batch_size=5,
)

sg_hwc = attributions_sg.squeeze().detach().cpu().numpy().transpose((1, 2, 0))
_ = viz.visualize_image_attr_multiple(
    sg_hwc,
    orig_img,
    methods=["original_image", "masked_image", "blended_heat_map"],
    signs=["all", "positive", "positive"],
    titles=["Original", "Image masking", "Blended heat map"],
    alpha_overlay=0.5,
)


## Integrated Gradients



In [None]:
 
# Integrated Gradients from Captum's default baseline (zero in normalized
# space) to the input: 200 integration steps, evaluated 10 per batch.
# `ig` is reused by the SmoothGrad cell below.
ig = IntegratedGradients(forward_model)
attributions_ig = ig.attribute(
    input_img,
    target=pred_label_idx,
    internal_batch_size=10,
    n_steps=200,
)

ig_hwc = attributions_ig.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
_ = viz.visualize_image_attr_multiple(
    ig_hwc,
    orig_img,
    titles=["Original", "Image masking", "Blended heat map"],
    methods=["original_image", "masked_image", "blended_heat_map"],
    signs=["all", "positive", "positive"],
    alpha_overlay=0.5,
)

### With SmoothGrad



In [None]:

 
# SmoothGrad on top of Integrated Gradients: 30 noisy input copies (std 0.3),
# 5 per batch. `n_steps` and `internal_batch_size` are passed through to the
# wrapped IntegratedGradients.attribute call.
smoothGrad_ig = NoiseTunnel(ig)
attributions_sg_ig = smoothGrad_ig.attribute(
    input_img,
    target=pred_label_idx,
    nt_samples=30,
    nt_samples_batch_size=5,
    stdevs=0.3,
    internal_batch_size=10,
    n_steps=100,
)

sg_ig_hwc = np.transpose(
    attributions_sg_ig.squeeze().detach().cpu().numpy(),
    (1, 2, 0),
)
_ = viz.visualize_image_attr_multiple(
    sg_ig_hwc,
    orig_img,
    signs=["all", "positive", "positive"],
    methods=["original_image", "masked_image", "blended_heat_map"],
    titles=["Original", "Image masking", "Blended heat map"],
    alpha_overlay=0.5,
)



Perturbation-based methods
==================================

## Occlusion

In [None]:
 
occlusion = Occlusion(forward_model)

# Hide 3x15x15 patches (all channels at once), moved with an 8-pixel stride,
# and attribute the resulting change of the target output to the hidden pixels.
attributions_occ = occlusion.attribute(
    input_img,
    target=pred_label_idx,
    sliding_window_shapes=(3, 15, 15),
    strides=(3, 8, 8),
)

occ_hwc = attributions_occ.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
_ = viz.visualize_image_attr_multiple(
    occ_hwc,
    orig_img,
    methods=["original_image", "masked_image", "blended_heat_map"],
    signs=["all", "positive", "positive"],
    titles=["Original", "Image masking", "Blended heat map"],
    alpha_overlay=0.5,
)

## Kernel SHAP

In [None]:
def generate_feature_mask(input_img):
    """Segment an image into SLIC superpixels to use as a Captum feature mask.

    Args:
        input_img: Channels-last RGB image (``orig_img`` in this notebook);
            anything ``skimage.util.img_as_float`` accepts.

    Returns:
        A ``torch.long`` tensor with the image's spatial shape, on ``device``,
        holding one superpixel id per pixel. Ids start at 0 because Captum
        expects feature-mask values in ``0 .. num_features - 1``. Recent
        scikit-image versions start SLIC labels at 1 (``start_label=1``),
        which would leave feature 0 empty and add a spurious interpretable
        feature for Kernel SHAP.
    """
    orig_img_float = img_as_float(input_img)
    segments = slic(orig_img_float, n_segments=100, compactness=10)
    # Shift labels to start at 0, independent of the installed skimage default.
    segments = segments - segments.min()
    feature_mask = torch.tensor(segments, dtype=torch.long).to(device)
    return feature_mask

# One interpretable feature per superpixel of the displayed image. The
# spatial-only mask broadcasts over the batch and channel dims of `input_img`.
feature_mask = generate_feature_mask(orig_img)

# Optional uniform mid-gray baseline. It must be in the same normalized space
# as `input_img`, so the 0.5 raw-pixel value is passed through
# `norm_transform` (the previous version added 0.5 to the un-normalized
# tensor, which did not match `input_img`). It is not passed below, so Kernel
# SHAP uses Captum's default zero baseline, i.e. the normalization mean color.
baseline = norm_transform(torch.full_like(transformed_img, 0.5))

kernel_shap = KernelShap(forward_model)


attributions_ks = kernel_shap.attribute(
    input_img,
    target=pred_label_idx,
    # baselines=baseline,  # uncomment to use the mid-gray baseline above
    n_samples=400,
    perturbations_per_eval=32,
    show_progress=True,
    feature_mask=feature_mask,
)


_ = viz.visualize_image_attr_multiple(
    np.transpose(attributions_ks.squeeze().cpu().detach().numpy(), (1, 2, 0)),
    orig_img,
    methods=["original_image", "masked_image", "blended_heat_map"],
    alpha_overlay=0.5,
    signs=["all", "positive", "positive"],
    titles=["Original", "Image masking", "Blended heat map"],
)