In [2]:
import torch
import timm
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 1: Load the Model
num_output_classes = 24  # Same as used during training

model = timm.create_model(
    "vit_base_patch8_224",
    pretrained=True,
    in_chans=1,
    num_classes=num_output_classes,
    patch_size=224,
)

model.load_state_dict(torch.load("models/best_model.pt"))
model.eval()
model = model.to(device)

In [3]:
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(224, 224), stride=(224, 224))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Ident

In [7]:
# Step 2: Prepare the Image
def preprocess_image(image_path):
    image = Image.open(image_path).convert("L")  # Convert to grayscale
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to the model's input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),  # Normalize for grayscale
    ])
    input_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return input_tensor.to(device)


In [8]:

# Step 3: Hook for Attention Weights
attention_maps = []

def get_attention_maps(module, input, output):
    attention_maps.append(output)

# Register hook to all attention layers
for name, module in model.named_modules():
    if "attn_drop" in name:  # Look for attention layers
        module.register_forward_hook(get_attention_maps)


In [11]:

# Step 4: Forward Pass to Extract Attention Maps
image_path = r"C:\Users\avs20\Documents\GitHub\facemap\cam1_G7c1_1_img2041_rescale_augmented.jpg"  # Replace with your image path
input_tensor = preprocess_image(image_path)

with torch.no_grad():
    x = model(input_tensor)  # Run the model to get attention weights


In [19]:
import torch
import timm
from torchvision.models.feature_extraction import get_graph_node_names

timm.layers.set_fused_attn(False) # disable F.sdpa so softmax node is exposed

mm = timm.create_model(
    "vit_base_patch8_224",
    pretrained=True,
    in_chans=1,
    num_classes=num_output_classes,
    patch_size=224,
)
mm.load_state_dict(torch.load("models/best_model.pt"))
softmax_nodes = [n for n in get_graph_node_names(mm)[0] if 'softmax' in n]
ff = timm.models.create_feature_extractor(mm, softmax_nodes)
with torch.no_grad():
    output = ff(torch.randn(768, 1, 224, 224))





In [20]:
#output = ff(input_tensor)
print(output)

{'blocks.0.attn.softmax': tensor([[[[7.0444e-01, 2.9556e-01],
          [9.9988e-01, 1.2252e-04]],

         [[9.9998e-01, 1.7074e-05],
          [2.4978e-18, 1.0000e+00]],

         [[1.0000e+00, 2.5649e-10],
          [3.3278e-07, 1.0000e+00]],

         ...,

         [[1.6865e-05, 9.9998e-01],
          [8.8699e-03, 9.9113e-01]],

         [[9.9995e-01, 4.6160e-05],
          [9.9918e-01, 8.1812e-04]],

         [[1.0000e+00, 2.7169e-06],
          [5.2816e-07, 1.0000e+00]]],


        [[[3.9648e-11, 1.0000e+00],
          [1.0000e+00, 2.5367e-13]],

         [[9.9883e-01, 1.1749e-03],
          [7.8896e-08, 1.0000e+00]],

         [[1.0000e+00, 5.2692e-08],
          [2.1705e-04, 9.9978e-01]],

         ...,

         [[2.0639e-07, 1.0000e+00],
          [1.0000e+00, 9.1595e-12]],

         [[9.9832e-01, 1.6774e-03],
          [3.1447e-15, 1.0000e+00]],

         [[9.9999e-01, 6.0735e-06],
          [8.8293e-09, 1.0000e+00]]],


        [[[1.0000e+00, 1.3219e-07],
          [1.000

In [27]:
print(output[[]])

TypeError: unhashable type: 'list'

In [22]:
last_layer_attention = output[:-1]  # Last attention layer
# Average across heads and reshape
attention_map = last_layer_attention.squeeze(0).mean(dim=0).detach().cpu().numpy()
attention_map = attention_map.reshape(14, 14)  # Reshape to (H, W)

# Visualize the attention map
plt.imshow(attention_map, cmap="viridis")
plt.colorbar()
plt.title("Attention Map from Last Layer")
plt.axis('off')  # Hide axes
plt.show()



KeyError: slice(None, -1, None)