Attention-based Explainable AI (XAI)
========================

In [None]:
import torch
import sys

# Absolute path to the project root: presumably the directory that contains the
# `medDerm` package as well as the `checkpoints/` and `datasets/` folders used
# below. Replace the `...` placeholder with that path before running.
ROOT = ...
# Make the project-local `medDerm` package importable in the next lines.
sys.path.append(ROOT)


from medDerm.agent import *
from medDerm.tools import *
from medDerm.utils import *

In [None]:
# Prefer CUDA, then Apple-silicon MPS, then CPU, so the notebook also runs on
# machines without an NVIDIA GPU. On a CUDA machine this is still "cuda".
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
config_path = f"{ROOT}/checkpoints/exp-HAM+Derm7pt-all+BCN+HAM-bin+DermNet+Fitzpatrick.yaml"

# Free cached allocator memory only on the backend actually in use; clearing
# the MPS cache while running on CUDA/CPU is pointless and may raise on
# PyTorch builds without MPS support.
if device == "cuda":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

model = load_checkpoint(config_path).to(device)
model.eval()  # inference mode (e.g. dropout disabled)
head = "HAM10k"  # presumably selects the HAM10000 task head; not referenced in the cells shown


In [None]:
# Attention maps recorded by the forward hooks below, in the order the hooked
# modules run during a forward pass.
attentions = []


def hook_fn(module, inputs, output):
    """Record the output of a hooked ``attn_drop`` module.

    ``attn_drop`` is presumably the dropout applied to the attention
    probabilities; with the model in eval mode dropout is the identity, so
    ``output`` should be the attention map itself. Its shape is presumably
    (batch_size, num_heads, tokens, tokens); for windowed attention the first
    dimension may be batch_size * num_windows -- TODO confirm.
    """
    # Detach so the stored maps never keep an autograd graph alive.
    attentions.append(output.detach())


layer = 2  # 1-based index into model.model.layers

# Remove hooks left over from a previous run of this cell; otherwise each
# re-run adds another set and every forward pass records each map again.
for _old_handle in globals().get("hook_handles", []):
    _old_handle.remove()
hook_handles = []

# Register the hook on every transformer block of the selected layer.
model_info = model.model
block = model_info.layers[layer - 1]
for transformer_block in block.blocks:
    hook_handles.append(
        transformer_block.attn.attn_drop.register_forward_hook(hook_fn)
    )

In [None]:
# Preprocess image
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])
image_path = f"{ROOT}/datasets/ISIC2018_Task3_Test_input/ISIC_0035859.jpg"  # Replace with your image path
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(device)


with torch.no_grad():
    _ = model.forward_explaination_tasks(input_tensor)


In [None]:
# Report the shape of every captured attention map, in capture order.
shape_report = (
    f"Layer {idx}: attention shape = {attn_map.shape}"
    for idx, attn_map in enumerate(attentions)
)
for line in shape_report:
    print(line)

In [None]:
import torch.nn.functional as F


def compute_rollout_attention(attentions, start_layer=4, end_layer=22, skip_layers=False):
    """Combine captured attention maps into a single attention-rollout matrix.

    Implements attention rollout (Abnar & Zuidema, 2020). For each selected map:
    - the heads are averaged;
    - the identity is added (the standard rollout treatment of the residual
      connection, which also avoids all-zero rows);
    - rows are re-normalised to sum to 1.
    The per-layer matrices are then multiplied together in capture order.

    Args:
        attentions: sequence of tensors shaped (batch, heads, tokens, tokens),
            e.g. as recorded by forward hooks. It is not modified.
        start_layer, end_layer: slice bounds (Python slice semantics, so
            negative values and ``None`` work) selecting the maps to combine.
            The selected maps must share the same first dimension; pick the
            interval accordingly.
        skip_layers: if True, maps whose first dimension differs from the
            first selected map are skipped (with a message) instead of
            aborting.

    Returns:
        The (tokens, tokens) rollout matrix for the first entry of the first
        dimension, or ``None`` if a size mismatch was found and
        ``skip_layers`` is False.

    Raises:
        ValueError: if the slice selects no attention maps.
    """
    # Absolute indices of the selected maps, so messages name the real layer.
    indices = range(len(attentions))[start_layer:end_layer]
    if len(indices) == 0:
        raise ValueError(
            f"No attention maps in the interval [{start_layer}:{end_layer}]; "
            f"{len(attentions)} maps were captured."
        )

    batch_size = attentions[indices[0]].size(0)
    print(f"Batch size: {batch_size}")

    # Build a filtered list instead of popping while iterating, which would
    # silently skip the element following every removed one.
    kept = []
    for idx in indices:
        att = attentions[idx]
        if att.size(0) == batch_size:
            kept.append(att)
            continue
        if not skip_layers:
            print(f"Layer {idx} has different batch size: {att.size(0)} select the right interval of layers or call the function with the parameter skip_layers=True")
            return None
        print(f"Layer {idx} has different batch size: {att.size(0)}, skipped")

    reference = kept[0]
    # Identity (N x N, N = number of tokens) on the same device/dtype as the
    # attention maps, so no global device is needed and half precision works.
    eye = torch.eye(reference.size(-1), device=reference.device, dtype=reference.dtype)
    result = eye
    for attention in kept:
        # (batch, heads, N, N) -> mean over heads -> (batch, N, N), plus identity.
        fused = attention.mean(dim=1) + eye
        # Normalise each row so every token's attention sums to 1.
        fused = fused / fused.sum(dim=-1, keepdim=True)
        result = torch.matmul(fused, result)
    return result[0]

rollout = compute_rollout_attention(attentions, start_layer=4, end_layer=22)
# compute_rollout_attention returns None on a batch-size mismatch; fail here
# with a clear message instead of a confusing TypeError in the plotting cell.
if rollout is None:
    raise RuntimeError(
        "Attention rollout failed: the selected layer interval mixes attention maps "
        "with different batch sizes. Adjust start_layer/end_layer or pass skip_layers=True."
    )

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Average the rollout over all query tokens: for every key token this gives
# the attention it receives. All N columns are spatial tokens (they are
# reshaped to a square grid below), so there is no CLS token to drop.
token_scores = rollout.mean(dim=0)
num_tokens = token_scores.numel()
side = int(round(num_tokens ** 0.5))  # 7 for 7x7 patches
if side * side != num_tokens:
    raise ValueError(f"Cannot reshape {num_tokens} tokens into a square patch grid")
# NOTE(review): rollout is taken from the first entry of the hooked maps' first
# dimension; if that dimension is batch * num_windows, this grid covers a single
# attention window rather than the whole image -- confirm before interpreting.
heatmap = token_scores.reshape(side, side).detach().float().cpu().numpy()
heatmap = cv2.resize(heatmap, (224, 224))
# Min-max normalise to [0, 1]; a constant map would otherwise divide by zero.
value_range = heatmap.max() - heatmap.min()
if value_range > 0:
    heatmap = (heatmap - heatmap.min()) / value_range
else:
    heatmap = np.zeros_like(heatmap)

# Overlay on image
img_np = np.array(image.resize((224, 224)))  # RGB, as loaded by PIL
heatmap_color = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
# applyColorMap returns BGR; convert to RGB so the colours blend correctly
# with the RGB image and are displayed correctly by matplotlib.
heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
overlay = heatmap_color * 0.3 + img_np * 0.7

plt.imshow(overlay.astype(np.uint8))
plt.axis('off')
plt.show()