Attention-based Explainable AI (XAI)
========================

In [None]:
import torch
import sys
from torchvision import transforms
from PIL import Image

# Placeholder: replace the Ellipsis with the filesystem path of the project
# checkout before running. ROOT is appended to sys.path ahead of the medDerm
# imports below (presumably so the package resolves without being installed)
# and is interpolated into the checkpoint and dataset paths in later cells.
ROOT = ...
sys.path.append(ROOT)


from medDerm.agent import *
from medDerm.tools import *
from medDerm.utils import *

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
config_path = f"{ROOT}/checkpoints/exp-HAM+Derm7pt-all+BCN+HAM-bin+DermNet+Fitzpatrick.yaml"

# Release cached GPU blocks left over from earlier runs of this cell before
# loading the model again; nothing to release when running on the CPU.
if device == "cuda":
    torch.cuda.empty_cache()

# `load_checkpoint` is a project helper (star-imported from medDerm); it is
# given the experiment YAML and its return value is moved to `device`.
model = load_checkpoint(config_path).to(device)
model.eval()  # inference mode for the explanation pass below

# Name of the prediction head. Not referenced by the other cells visible in
# this notebook; kept for cells or helpers that may read it.
head = "HAM10k"


In [None]:
# Containers filled by the hooks below; each holds a single "value" entry
# that is overwritten on every forward / backward pass.
activations = {}
gradients = {}

# Index of the backbone stage whose last block is explained.
level = 0

model_info = model.model
target_layer = model_info.layers[level].blocks[-1].norm2  # or attn.proj if you prefer


def forward_hook(module, input, output):
    """Store the hooked layer's forward output, detached from the graph."""
    activations["value"] = output.detach()


def backward_hook(module, grad_input, grad_output):
    """Store the gradient w.r.t. the hooked layer's output, detached."""
    gradients["value"] = grad_output[0].detach()


handle_f = target_layer.register_forward_hook(forward_hook)
# `register_backward_hook` is deprecated in PyTorch; the full backward hook is
# its documented replacement and takes the same (module, grad_input,
# grad_output) callback signature.
handle_b = target_layer.register_full_backward_hook(backward_hook)




In [None]:

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = transforms.Compose(preprocess_steps)

image_path = f"{ROOT}/datasets/ISIC2018_Task3_Test_input/ISIC_0035859.jpg"  # Replace with your image path
image = Image.open(image_path).convert("RGB")

# Add a leading batch axis of size one and move the tensor to the model's device.
input_tensor = transform(image)[None].to(device)

# Forward pass through the project-specific entry point.
# NOTE(review): the argmax over dim=1 and the [0, class_idx] indexing assume
# the result is a (batch, num_classes) score tensor — confirm.
output = model.forward_attention_tasks(input_tensor)

# Explain the top-scoring class of the single image in the batch.
class_idx = torch.argmax(output, dim=1).item()
score = output[0][class_idx]

# Clear stale gradients, then back-propagate the chosen class score so the
# backward hook can capture the gradient at the target layer.
model.zero_grad()
score.backward()


In [None]:
import math  # needed for the square-grid reshape; not imported explicitly elsewhere

# The captured tensors are already detached copies, so the hooks can be
# released straight away; doing it first guarantees they do not stay attached
# to the model if the reshape below raises.
handle_f.remove()
handle_b.remove()

# Grad-CAM channel weights: average the gradient over dim=1 so that every
# feature channel receives a single importance weight.
# NOTE(review): assumes the hooked tensors are laid out as
# (batch, num_tokens, channels), i.e. dim=1 is the token axis — confirm.
weights = gradients["value"].mean(dim=1, keepdim=True)

# Token-wise importance: channel-weighted sum of the activations.
cam = (weights * activations["value"]).sum(dim=-1).squeeze()  # (num_tokens,)
cam = torch.relu(cam)
# Normalise to [0, 1]. The clamp keeps an all-zero map (nothing survived the
# ReLU) at zero instead of turning it into NaNs through a 0/0 division.
cam = cam / cam.max().clamp(min=1e-8)

print("cam shape", cam.shape)

# Fold the flat token sequence back into its square spatial grid. The side is
# derived from the token count, so every level is handled by the same code.
num_tokens = cam.numel()
side = int(round(math.sqrt(num_tokens)))
if side * side != num_tokens:
    raise ValueError(
        f"Cannot arrange {num_tokens} CAM values into a square grid (level={level})"
    )
cam_image = cam.reshape(side, side).cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Min-max rescale the CAM grid to [0, 1]; the small epsilon guards against a
# zero range when the map is constant.
cam_lo, cam_hi = cam_image.min(), cam_image.max()
cam_np = (cam_image - cam_lo) / (cam_hi - cam_lo + 1e-6)

# Upsample the coarse grid to the 224x224 resolution used for display.
display_size = (224, 224)
cam_resized = cv2.resize(cam_np, display_size)

# Colour-code the map with the JET palette. OpenCV returns BGR, so convert to
# RGB before mixing with the PIL image and handing it to matplotlib.
heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

# The original (un-normalised) image at the same resolution as the heatmap.
img_np = np.array(image.resize(display_size))

# Blend: 60% image, 40% heatmap.
overlayed_img = np.uint8(0.6 * img_np + 0.4 * heatmap)

plt.figure(figsize=(6, 6))
plt.imshow(overlayed_img)
plt.title("Grad-CAM Heatmap Overlay")
plt.axis('off')
plt.show()