In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
from PIL import Image
import cv2
import os
import timm
# --- Run configuration: input artifacts, output location, image size ---
model_path = "best_vit_dino.pth"
indices_path = "test_indices.pth"
data_dir = "paddy-disease-classification"
output_dir = "dino_attention_outputs"
img_size = 224

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The overlay images are written here; create the folder up front.
os.makedirs(output_dir, exist_ok=True)

# Preprocessing applied to every image: resize to a square of img_size,
# convert to a tensor, then normalize each channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the full image folder and restrict it to the saved test split.
full_dataset = ImageFolder(data_dir, transform=transform)
class_names = full_dataset.classes
# NOTE(review): torch.load unpickles the file; only load index files you trust.
# If the file holds plain ints or a tensor, pass weights_only=True to silence
# the FutureWarning and disable arbitrary-code unpickling — confirm the payload type.
test_indices = torch.load(indices_path)
test_dataset = Subset(full_dataset, test_indices)

# Pick the first test-split sample of every class.
# Labels are read from ImageFolder.targets so no image is decoded or
# transformed just to learn its class; position `subset_pos` in the Subset
# corresponds to `test_indices[subset_pos]` in the full dataset.
chosen_indices = {}  # class index -> position inside test_dataset
for subset_pos, original_index in enumerate(test_indices):
    label = full_dataset.targets[int(original_index)]
    if label not in chosen_indices:
        chosen_indices[label] = subset_pos
    if len(chosen_indices) == len(class_names):
        break

# Report which files were selected (the split indices are already in memory,
# so they are not reloaded from disk here).
print("Images used in attention visualization:\n")
for class_idx, dataset_idx in chosen_indices.items():
    original_index = int(test_indices[dataset_idx])  # map back to original dataset
    image_path, _ = full_dataset.samples[original_index]
    class_name = class_names[class_idx]
    print(f"{class_name:25s} --> {image_path}")


  test_indices = torch.load(indices_path)


Images used in attention visualization:

bacterial_leaf_blight     --> paddy-disease-classification\bacterial_leaf_blight\PDD00536.jpg
normal                    --> paddy-disease-classification\normal\PDD10618.jpg
blast                     --> paddy-disease-classification\blast\PDD03633.jpg
black_stem_borer          --> paddy-disease-classification\black_stem_borer\PDD02034.jpg
leaf_roller               --> paddy-disease-classification\leaf_roller\PDD08975.jpg
bacterial_panicle_blight  --> paddy-disease-classification\bacterial_panicle_blight\PDD01329.jpg
tungro                    --> paddy-disease-classification\tungro\PDD13085.jpg
white_stem_borer          --> paddy-disease-classification\white_stem_borer\PDD14197.jpg
downy_mildew              --> paddy-disease-classification\downy_mildew\PDD05775.jpg
hispa                     --> paddy-disease-classification\hispa\PDD07407.jpg
yellow_stem_borer         --> paddy-disease-classification\yellow_stem_borer\PDD15703.jpg
brown_spot       

  test_indices = torch.load("test_indices.pth")


In [None]:
# Build the ViT-B/16 DINO architecture without downloading pretrained weights,
# drop the classification head, then restore the locally saved checkpoint.
model = timm.create_model('vit_base_patch16_224.dino', pretrained=False)
model.head = torch.nn.Identity()

# weights_only=True restricts unpickling to tensors/containers, which is all a
# state dict needs, and avoids executing arbitrary pickled code.
state_dict = torch.load(model_path, map_location=device, weights_only=True)

# strict=False tolerates key mismatches (e.g. head weights in the checkpoint
# that no longer have a destination). Surface what was skipped so that a wrong
# checkpoint does not silently leave the backbone randomly initialised.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model parameters were not found "
          f"in the checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Ignored {len(load_result.unexpected_keys)} checkpoint entries with no "
          f"matching model parameter: {load_result.unexpected_keys}")

model = model.to(device).eval()

  model.load_state_dict(torch.load(model_path, map_location=device),strict=False)


In [None]:
# Filled by get_qk_hook: one head-averaged (B, N, N) attention tensor per forward pass.
attention_scores = []


def get_qk_hook(module, input, output, num_heads=12):
    """Forward hook for a fused QKV projection that records attention weights.

    The fused output is split into per-head queries and keys, scaled
    dot-product attention is computed for every head, and the softmax maps are
    averaged over heads before being appended to ``attention_scores``.

    Args:
        module: The hooked layer (unused; required by the hook signature).
        input: The layer's inputs (unused; required by the hook signature).
        output: Tensor of shape (B, N, 3 * num_heads * head_dim).
        num_heads: Number of attention heads used to split the projection.
            Defaults to 12, the head count of ViT-Base; bind another value
            (e.g. with functools.partial) when hooking a different model.

    Raises:
        ValueError: If the last dimension of ``output`` cannot be split into
            3 * num_heads equal parts.
    """
    B, N, fused_dim = output.shape  # (B, N, 3 * H * head_dim)
    if fused_dim % (3 * num_heads) != 0:
        raise ValueError(
            f"QKV width {fused_dim} is not divisible by 3 * num_heads ({3 * num_heads})"
        )
    head_dim = fused_dim // (3 * num_heads)

    # (B, N, 3, H, head_dim) -> (3, B, H, N, head_dim)
    qkv = output.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k = qkv[0], qkv[1]  # each (B, H, N, head_dim)

    # Scale by sqrt(head_dim), not sqrt(embed_dim): each head attends independently.
    # NOTE(review): assumes the attention layer applies no q/k normalisation
    # between this projection and the dot product — confirm for other models.
    attn = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)  # (B, H, N, N)
    attn = F.softmax(attn, dim=-1)

    # Average over heads so consumers still receive a single (B, N, N) map.
    attention_scores.append(attn.mean(dim=1).detach())
# Attach the capture hook to the fused QKV projection of the last transformer
# block; the returned handle is kept so the hook can be removed after use.
hook = model.blocks[-1].attn.qkv.register_forward_hook(get_qk_hook)


In [None]:
# For each selected sample: run the model, take the CLS-token attention captured
# by the hook, upsample it to image size and save a heatmap overlay.
try:
    for class_idx, dataset_idx in chosen_indices.items():
        image_tensor, true_label = test_dataset[dataset_idx]
        input_tensor = image_tensor.unsqueeze(0).to(device)

        attention_scores.clear()
        with torch.no_grad():  # inference only; no autograd graph needed
            _ = model(input_tensor)
        if not attention_scores:
            raise RuntimeError(
                "No attention was captured; the forward hook is not registered "
                "(re-run the cell that registers it)."
            )
        attn = attention_scores[0][0]  # (N, N)

        # Row 0 is the CLS query; dropping column 0 leaves its attention over the
        # patch tokens. NOTE(review): assumes CLS is the only non-patch token.
        cls_attn = attn[0, 1:]  # CLS → patches
        grid_size = int(round(cls_attn.numel() ** 0.5))
        if grid_size * grid_size != cls_attn.numel():
            raise ValueError(
                f"Cannot arrange {cls_attn.numel()} patch tokens into a square grid"
            )
        cls_attn = cls_attn.reshape(grid_size, grid_size).cpu().numpy()
        cls_attn = cv2.resize(cls_attn, (img_size, img_size))

        # Min-max normalise to [0, 1]; a constant map would otherwise divide by zero.
        attn_min = cls_attn.min()
        attn_range = cls_attn.max() - attn_min
        cls_attn = (cls_attn - attn_min) / (attn_range if attn_range > 0 else 1.0)

        # Undo Normalize(mean=0.5, std=0.5) to recover a displayable RGB uint8 image.
        rgb_img = image_tensor.permute(1, 2, 0).cpu().numpy()
        rgb_img = ((rgb_img * 0.5) + 0.5) * 255
        rgb_img = np.uint8(np.clip(rgb_img, 0, 255))

        # applyColorMap returns BGR; convert to RGB so it matches rgb_img and the
        # channel order PIL expects when saving.
        heatmap = cv2.applyColorMap(np.uint8(255 * cls_attn), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(heatmap, 0.5, rgb_img, 0.5, 0)

        filename = f"{output_dir}/{class_names[true_label]}_attn.png"
        Image.fromarray(overlay).save(filename)
finally:
    # Always detach the hook, even if a sample fails, so it does not keep
    # firing on later forward passes.
    hook.remove()
print(f"Saved attention maps to: {output_dir}")

Saved attention maps to: dino_attention_outputs
