<a href="https://colab.research.google.com/github/aycaaozturk/Image-Based-Detection-of-Nail-Melanoma-Using-Deep-Learning-Techniques/blob/main/XAI_ViT_b_16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook loads the trained Vision Transformer, runs it on a chosen image, collects the self-attention matrices of every encoder block, combines them with attention rollout, and produces a heatmap showing which image regions influenced the prediction.

In [None]:
# Mount Google Drive at /content/drive (Google Colab only) so that the
# "/content/drive/My Drive/..." paths used in the example at the bottom of
# this notebook can be opened like ordinary files.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
# Sends signal 9 to the current Python process (os.getpid()), i.e. this cell
# terminates the running kernel on purpose.
# NOTE(review): presumably used to force a Colab runtime restart - confirm.
os.kill(os.getpid(), 9)

In [None]:


import torch
import numpy as np
import cv2
from PIL import Image

from torch.nn import MultiheadAttention
from torchvision import models
from torchvision.models.vision_transformer import VisionTransformer, EncoderBlock


# 1. Monkey-patch MultiheadAttention -> replace its forward method so that the attention matrices can be captured

# Resolve the genuine PyTorch implementation. If a patch from this code is
# already installed (e.g. the notebook cell was executed twice), unwrap it via
# the marker attribute set below instead of wrapping the wrapper again.
_old_mha_forward = getattr(
    MultiheadAttention.forward, "_original_forward", MultiheadAttention.forward
)


def _patched_mha_forward(self, query, key, value, *args, **kwargs):
    """Run the original attention forward pass and keep the per-head weights.

    ``need_weights`` is forced to ``True`` and ``average_attn_weights`` to
    ``False`` (overriding whatever the caller passed as keywords), so the
    original forward returns one attention matrix per head. The weights are
    stored on the module as ``self.last_attn`` and are also returned.

    Returns:
        The ``(output, attention_weights)`` pair produced by the original
        ``MultiheadAttention.forward``.
    """
    kwargs['need_weights'] = True  # always return attention weights
    kwargs['average_attn_weights'] = False  # keep one matrix per head

    out, attn = _old_mha_forward(self, query, key, value, *args, **kwargs)
    self.last_attn = attn  # read later by the rollout code
    return out, attn


# Marker that lets a repeated execution find the unpatched implementation.
_patched_mha_forward._original_forward = _old_mha_forward

# From here on every MultiheadAttention layer uses the capturing forward.
MultiheadAttention.forward = _patched_mha_forward



# 2. Load full ViT model from the path

def load_vit_full(model_path, device):
    """Load a fully pickled model (not just a state_dict) ready for inference.

    Args:
        model_path: Path of a checkpoint that contains the whole model object.
        device: Device the checkpoint is mapped to and the model is moved onto.

    Returns:
        The loaded module, placed on ``device`` and switched to eval mode.

    Raises:
        TypeError: If the checkpoint does not contain a ``torch.nn.Module``
            (for example when only a ``state_dict`` was saved).
    """
    # Kept from the original code. The safe-globals allow-list is only
    # consulted by ``weights_only=True`` loads, so it does not restrict the
    # full-pickle load below.
    torch.serialization.add_safe_globals([VisionTransformer])

    # SECURITY: weights_only=False runs the regular pickle machinery, which
    # can execute arbitrary code stored in the file. Only load checkpoints
    # that come from a trusted source.
    model = torch.load(model_path, weights_only=False, map_location=device)

    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            f"Expected a pickled torch.nn.Module in {model_path!r}, "
            f"got {type(model).__name__}; was only a state_dict saved?"
        )

    model.to(device)
    model.eval()
    return model



# 3. Preprocessing
# reuse the preprocessing bundled with the pretrained ViT-B/16 weights (assumed to match the training pipeline)

def get_preprocess():
    """Return the transform shipped with the default ViT-B/16 weights."""
    return models.ViT_B_16_Weights.DEFAULT.transforms()



# 4. Attention Rollout implementation (Torchvision ViT)

# An explainability technique for Transformer-based models (e.g. ViT)
# that aggregates the attention information across all layers to produce a single, global importance map.

def attention_rollout_torchvision(model, x, discard_ratio=0.9, head_fusion="mean"):
    """Compute an attention-rollout map for a torchvision-style ViT.

    The model is run once on ``x``; the per-head attention stored as
    ``last_attn`` on the ``self_attention`` module of every ``EncoderBlock``
    is then fused over heads, sparsified, mixed with the identity (residual
    connection) and multiplied across layers.

    Args:
        model: Model whose ``EncoderBlock`` modules record ``last_attn``
            during the forward pass.
        x: Input batch. Only the first sample of the batch is explained.
        discard_ratio: Fraction of the lowest fused attention values that is
            set to zero in every layer before the rollout.
        head_fusion: How heads are combined: ``"mean"``, ``"max"`` or ``"min"``.

    Returns:
        A 2-D NumPy array of shape ``(side, side)`` with the rolled-out
        attention of token 0 (the class token) towards the remaining tokens.

    Raises:
        ValueError: If ``head_fusion`` is not one of the supported modes.
        RuntimeError: If no block recorded an attention matrix.
    """
    if head_fusion not in ("mean", "max", "min"):
        raise ValueError(
            f"Unsupported head_fusion {head_fusion!r}; use 'mean', 'max' or 'min'."
        )

    device = x.device

    # 1) Forward pass so that .last_attn is populated on every block.
    with torch.no_grad():
        _ = model(x)

    # 2) Collect the attention of every encoder block, in module order.
    attn_blocks = []
    for module in model.modules():
        if isinstance(module, EncoderBlock):
            # getattr: the attribute does not exist at all when the capturing
            # forward was never installed; report that via the error below.
            attn = getattr(module.self_attention, "last_attn", None)
            if attn is not None:
                attn_blocks.append(attn[0].to(device))  # [heads, tokens, tokens]

    if not attn_blocks:
        raise RuntimeError("No attention matrices collected. Patch may not have applied.")

    # 3) Rollout calculation
    tokens = attn_blocks[0].size(-1)
    identity = torch.eye(tokens, device=device)
    result = identity.clone()

    for attn in attn_blocks:
        if head_fusion == "mean":
            attn_fused = attn.mean(dim=0)
        elif head_fusion == "max":
            attn_fused = attn.max(dim=0).values
        else:
            attn_fused = attn.min(dim=0).values

        # Zero out the k smallest entries of the fused matrix.
        flat = attn_fused.flatten()
        k = int(flat.numel() * discard_ratio)
        if k > 0:
            _, idxs = flat.topk(k, largest=False)
            attn_fused = attn_fused.reshape(-1)
            attn_fused[idxs] = 0
            attn_fused = attn_fused.view(tokens, tokens)

        # Re-normalise what is left (epsilon guards fully discarded rows).
        attn_fused = attn_fused / (attn_fused.sum(dim=-1, keepdim=True) + 1e-6)

        # Add the identity for the residual path and normalise again so that
        # every layer stays row-stochastic, including rows that were emptied.
        layer = attn_fused + identity
        layer = layer / layer.sum(dim=-1, keepdim=True)

        # Later layers are applied on the left: rollout = A_L @ ... @ A_1,
        # so row 0 ends up describing the last layer's class token.
        result = layer @ result

    mask = result[0, 1:]  # class-token row without the class-token column

    side = int(np.sqrt(mask.numel()))
    return mask[:side * side].reshape(side, side).cpu().numpy()



# 5. Full pipeline (image → rollout → heatmap)

def explain_image(model_path, image_path, save_path="rollout.jpg", discard_ratio=0.8):
    """Create and save an attention-rollout overlay for one image.

    Args:
        model_path: Checkpoint passed to ``load_vit_full``.
        image_path: Image to explain; it is converted to RGB.
        save_path: Where the overlay is written with ``cv2.imwrite``.
        discard_ratio: Forwarded to ``attention_rollout_torchvision``.

    Returns:
        The overlay as a BGR ``uint8`` array with the size of the input image.

    Raises:
        OSError: If OpenCV reports that the overlay could not be written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = load_vit_full(model_path, device)

    preprocess = get_preprocess()
    # Context manager closes the file handle as soon as the pixels are copied.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    x = preprocess(img).unsqueeze(0).to(device)

    print("Computing attention rollout...")
    mask = attention_rollout_torchvision(model, x, discard_ratio=discard_ratio)

    # PIL's size and OpenCV's dsize are both (width, height).
    # NOTE(review): the mask is stretched over the whole original image; if
    # the preprocessing crops the input, the overlay is only approximately
    # aligned - confirm against the transform in use.
    mask_resized = cv2.resize(mask, img.size)

    # Min-max scale to [0, 1]; a constant (or NaN) mask would otherwise divide
    # by zero and yield an undefined heatmap.
    low = mask_resized.min()
    span = mask_resized.max() - low
    if span > 0:
        mask_resized = (mask_resized - low) / span
    else:
        mask_resized = np.zeros_like(mask_resized)

    img_np = np.array(img)
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    heatmap = cv2.applyColorMap((mask_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(save_path, overlay):
        raise OSError(f"Could not write rollout visualization to {save_path!r}")
    print(f"Saved rollout visualization → {save_path}")

    return overlay




# 6. Example usage

if __name__ == "__main__":
    # Example run: trained nail ViT checkpoint and one cropped nail picture,
    # both read from the mounted Google Drive.
    example_args = dict(
        model_path="/content/drive/My Drive/models best/vision transformer/nail_vit_full_model.pth",
        image_path="/content/drive/My Drive/models best/vision transformer/pics/cropped_nail2.jpg",
        save_path="attention_rollout_vit_b16_nail.jpg",
        discard_ratio=0.8,
    )
    explain_image(**example_args)


Computing attention rollout...
Saved rollout visualization → attention_rollout_vit_b16_nail.jpg
