# Visualize Attention Maps

In [1]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests

In [2]:
# Load pre-trained ViT model and feature extractor
model_name = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

model

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [4]:
# Load and preprocess an image
image_url = "https://www.southernliving.com/thmb/Rz-dYEhwq_82C5_Y9GLH2ZlEoYw=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/gettyimages-837898820-1-4deae142d4d0403dbb6cb542bfc56934.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = feature_extractor(images=image, return_tensors="pt")

In [6]:
# Make prediction
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
# Get attention maps from the last layer
# last_layer_attention = outputs.attentions[-1][0]  

Predicted class: LABEL_1


In [7]:
import torch
import numpy as np
import cv2
from torchvision import transforms

# Define helper functions for normalization and visualization
def normalize_attention(attention):
  """
  Normalize attention map values between 0 and 1.
  """
  return attention.softmax(dim=-1)

def visualize_attention(attention, image):
  """
  Overlay attention map on the original image for visualization.
  """
  # Resize attention map to match image size
  attention = torch.nn.functional.interpolate(attention, size=image.shape[2:], mode='bilinear')
  # Convert to numpy array and normalize
  attention = attention.cpu().numpy().squeeze()
  attention = normalize_attention(attention)
  # Apply heatmap effect
  heatmap = cv2.applyColorMap((attention * 255).astype(np.uint8), cv2.COLORMAP_JET)
  # Overlay heatmap on original image
  image = image.numpy().transpose((1, 2, 0))  # Convert to HWC format
  image = cv2.addWeighted(heatmap, 0.5, image, 0.5, 0)
  return image.astype(np.uint8)

# Load pre-trained ViT model
model = torch.hub.load('google/vit_base_patch16_224', 'base')

# Preprocess image
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
image = cv2.imread("image.png")  # Replace with your image path
image = transform(image)
image = image.unsqueeze(0)  # Add batch dimension

# Extract attention maps (modify based on specific model architecture)
with torch.no_grad():
  features = model.forward_features(image)
  attention_maps = features[1]  # Assuming attention maps are in the second element

# Visualize attention maps for each head in the first transformer block
for head in range(attention_maps.shape[1]):
  attention_map = attention_maps[0, head, ...]  # Select specific head
  visualized_image = visualize_attention(attention_map, image)
  cv2.imshow(f"Attention Map (Head {head+1})", visualized_image)
  cv2.waitKey(0)

cv2.destroyAllWindows()


