In [3]:
import torch
import timm
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 1: Load the Model
num_output_classes = 24  # Same as used during training

model = timm.create_model(
    "vit_base_patch8_224",
    pretrained=True,
    in_chans=1,
    num_classes=num_output_classes,
    patch_size=224,
)

model.load_state_dict(torch.load("models/best_model.pt", map_location=torch.device('cpu')))
model.eval()
model = model.to(device)

https://www.kaggle.com/code/arnavs19/attention-rollout-for-vision-transformers

In [4]:
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(224, 224), stride=(224, 224))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Ident

In [3]:
# Step 2: Prepare the Image
def preprocess_image(image_path):
    image = Image.open(image_path).convert("L")  # Convert to grayscale
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to the model's input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),  # Normalize for grayscale
    ])
    input_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return input_tensor.to(device)


In [4]:

# Step 3: Hook for Attention Weights
attention_maps = []

def get_attention_maps(module, input, output):
    attention_maps.append(output)

# Register hook to all attention layers
for name, module in model.named_modules():
    if "attn_drop" in name:  # Look for attention layers
        module.register_forward_hook(get_attention_maps)


In [5]:

# Step 4: Forward Pass to Extract Attention Maps
image_path = r"C:\Users\avs20\Documents\GitHub\facemap\cam1_G7c1_1_img2041_rescale_augmented.jpg"  # Replace with your image path
input_tensor = preprocess_image(image_path)

with torch.no_grad():
    x = model(input_tensor)  # Run the model to get attention weights


  x = F.scaled_dot_product_attention(


In [10]:
import torch
import timm
from torchvision.models.feature_extraction import get_graph_node_names

# Step 1: Load the Model
num_output_classes = 24  # Same as used during training


timm.layers.set_fused_attn(False) # disable F.sdpa so softmax node is exposed

mm = timm.create_model(
    "vit_base_patch8_224",
    pretrained=True,
    in_chans=1,
    num_classes=num_output_classes,
    patch_size=224,
)
mm.load_state_dict(torch.load("models/best_model.pt"))
mm.to('cuda')
softmax_nodes = [n for n in get_graph_node_names(mm)[0] if 'softmax' in n]
ff = timm.models.create_feature_extractor(mm, softmax_nodes)
ff.to('cuda')
#with torch.no_grad():
#    output = ff(torch.randn(768, 1, 224, 224))





VisionTransformer(
  (patch_embed): Module(
    (proj): Conv2d(1, 768, kernel_size=(224, 224), stride=(224, 224))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Module(
    (0): Module(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Module(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Module(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
 

In [11]:
output = ff(input_tensor)
#print(output)

https://gist.github.com/rwightman/dbb5a8222df173687d734ad5e257908b

In [13]:
print(output)

output['blocks.11.attn.softmax']

{'blocks.0.attn.softmax': tensor([[[[1.0000e+00, 1.6815e-31],
          [1.0000e+00, 7.2125e-24]],

         [[1.0000e+00, 2.5425e-15],
          [1.0000e+00, 0.0000e+00]],

         [[1.0000e+00, 3.7143e-10],
          [3.5318e-09, 1.0000e+00]],

         [[9.9959e-01, 4.1461e-04],
          [1.1311e-22, 1.0000e+00]],

         [[1.0000e+00, 2.0082e-31],
          [1.0000e+00, 9.9221e-23]],

         [[1.0000e+00, 1.4859e-08],
          [2.5809e-09, 1.0000e+00]],

         [[1.0000e+00, 1.6077e-12],
          [1.5816e-21, 1.0000e+00]],

         [[1.0000e+00, 2.5624e-16],
          [2.5563e-31, 1.0000e+00]],

         [[1.0000e+00, 4.4911e-12],
          [1.4378e-03, 9.9856e-01]],

         [[1.0000e+00, 1.5622e-07],
          [4.4129e-01, 5.5871e-01]],

         [[1.0000e+00, 8.4027e-07],
          [0.0000e+00, 1.0000e+00]],

         [[1.0000e+00, 1.4518e-22],
          [8.1098e-20, 1.0000e+00]]]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>), 'blocks.1.attn.softmax': tensor(

In [15]:
type(output)

dict

In [None]:
last_layer_attention = output[:-1]  # Last attention layer
# Average across heads and reshape
attention_map = last_layer_attention.squeeze(0).mean(dim=0).detach().cpu().numpy()
attention_map = attention_map.reshape(14, 14)  # Reshape to (H, W)

# Visualize the attention map
plt.imshow(attention_map, cmap="viridis")
plt.colorbar()
plt.title("Attention Map from Last Layer")
plt.axis('off')  # Hide axes
plt.show()

