In [None]:
import torch
from transformers import ViTModel, ViTFeatureExtractor

# Single source of truth for the checkpoint so the preprocessor and the
# model are always loaded from the same pretrained weights.
MODEL_NAME = "google/vit-base-patch16-224-in21k"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)

model = ViTModel.from_pretrained(MODEL_NAME)
model = model.to(device)
# Put the module in evaluation mode; kept as the last statement so the
# cell still displays the model summary.
model.eval()


: 

In [None]:
outputs = model(pixel_values=inputs, output_attentions=True)
attentions = outputs.attentions  # A tuple/list of attention layers

In [None]:
from PIL import Image
import requests

url = "https://example.com/your_image.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = feature_extractor(images=image, return_tensors="pt").to(device)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

def visualize_attention_on_image(image, attention_map, alpha=0.4):
    """Overlay an attention map on an image as a JET heatmap and display it.

    Args:
        image: Image object exposing ``width`` and ``height`` and convertible
            with ``np.array`` (a PIL image in this notebook). Assumed to be
            8-bit RGB so that it can be blended with the 3-channel heatmap.
        attention_map: 2-D array of attention scores (e.g. 14x14). It is
            resized to the image resolution before being coloured.
        alpha: Weight of the heatmap in the blend; the image receives
            ``1 - alpha``. The default of 0.4 reproduces the previous fixed
            0.6 / 0.4 mix.

    Returns:
        None. The blended image is shown with matplotlib.
    """
    # Work in float32 so reduced-precision maps are accepted by cv2.resize.
    attention_map = np.asarray(attention_map, dtype=np.float32)
    # cv2.resize takes the target size as (width, height).
    attention_map = cv2.resize(attention_map, (image.width, image.height))

    # Min-max scale to [0, 1]. A constant map has zero range, which would
    # otherwise divide by zero and fill the overlay with NaNs.
    lowest = attention_map.min()
    value_range = attention_map.max() - lowest
    if value_range > 0:
        attention_map = (attention_map - lowest) / value_range
    else:
        attention_map = np.zeros_like(attention_map)

    img_np = np.array(image)
    heatmap = cv2.applyColorMap((attention_map * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # applyColorMap produces BGR, while the image and plt.imshow use RGB;
    # without this conversion high-attention regions show up blue, not red.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(img_np, 1.0 - alpha, heatmap, alpha, 0)

    plt.figure(figsize=(8, 8))
    plt.imshow(blended)
    plt.axis('off')
    plt.show()

layer_to_visualize = 5
head_to_visualize = 0

# Attention from token 0 (presumably the [CLS] token) to every other token,
# for the first image in the batch and the selected layer/head.
cls_to_patches = attentions[layer_to_visualize][0, head_to_visualize, 0, 1:]
# Derive the side of the square patch grid from the token count instead of
# hard-coding 14, so other input resolutions / patch sizes also work.
grid_side = int(round(cls_to_patches.numel() ** 0.5))
# detach() first: .numpy() raises on a tensor that still requires grad.
cls_attn = cls_to_patches.detach().reshape(grid_side, grid_side).cpu().numpy()
visualize_attention_on_image(image, cls_attn)


: 

In [None]:
from sklearn.cluster import KMeans

def extract_all_heads_attn(attentions, layer):
    """Collect, for every head of one layer, the attention from token 0 to the rest.

    Args:
        attentions: Sequence of attention tensors indexed by layer. Each
            tensor is indexed as ``[batch, head, query, key]``.
        layer: Index of the layer to read.

    Returns:
        A new ``np.ndarray`` of shape ``[num_heads, num_tokens - 1]`` holding,
        for the first batch element, the row of query token 0 with key
        token 0 dropped.
    """
    # One slice covers all heads at once: batch 0, every head, query 0, keys 1..
    first_token_rows = attentions[layer][0, :, 0, 1:]
    # detach() is required because .numpy() raises on tensors that still
    # require grad (e.g. when the forward pass ran outside torch.no_grad()).
    # np.array(..., order="C") returns an independent contiguous copy, as the
    # previous np.stack-based version did.
    return np.array(first_token_rows.detach().cpu().numpy(), order="C")

# One row per attention head of the selected layer.
head_attns = extract_all_heads_attn(attentions, layer_to_visualize)
# Fixed seed so the cluster labels are reproducible across runs; n_init is
# set explicitly because its default differs between scikit-learn releases.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
# clusters[i] is the cluster label assigned to head i.
clusters = kmeans.fit_predict(head_attns)

In [None]:
def mask_top_patches(image, attention_map, top_k=5):
    flat_attn = attention_map.flatten()
    top_indices = np.argsort(flat_attn)[-top_k:]