In [1]:

import einops
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd

from urllib.request import urlretrieve

from PIL import Image
from torchvision import transforms

from models.modeling import VisionTransformer, CONFIGS

import cv2

In [2]:

# Prepare Model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=24, zero_head=False, img_size=224, vis=True)
checkpoint = torch.load("output/test_checkpoint.pth",map_location=torch.device('cpu'))  # Load the checkpoint #delete map_location=torch.device('cpu') if run on GPU
model.load_state_dict(checkpoint['state_dict'])
print(model)

model.eval()

transform = transforms.Compose([
    #transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

imageName = 'img6725_rotate_rescale_augmented.jpg'
im = Image.open("augmented_data_test/" + imageName)
x = transform(im)
x.size()

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0-11): 12 x Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=Tru

torch.Size([3, 224, 228])

In [None]:
#pull out the logits (i believe key points) from the model, and the attention maps. x.unsqueeze(0) adds a new dimension at position 0, effectively converting x into a batch with a single item. If x originally had the shape [C, H, W] (where C is channels, H is height, W is width), after unsqueeze(0), it would have the shape [1, C, H, W]. model(...) runs the input through the Vision Transformer model.
#The model returns two outputs: logits: The raw output from the model's head (likely KP scores). att_mat: The attention matrices from the attention layers in the model. This is usually a list of attention maps from different layers of the transformer.
logits, att_mat = model(x.unsqueeze(0))

#torch.stack(att_mat) converts this list of matrices into a single tensor by stacking them along a new dimension. If att_mat is a list of tensors with shapes [B, N, N] (where B is batch size, and N is the number of tokens), then torch.stack(att_mat) will have shape [L, B, N, N] where L is the number of layers.
att_mat = torch.stack(att_mat).squeeze(1)
#torch.Size([12, 12, 197, 197])

# Average the attention weights across all heads.
#second dimension(1) (12): This represents the number of attention heads in each transformer layer. Each layer in a multi-head attention mechanism typically has multiple heads (in this case, 12), and each head produces its own attention matrix.
att_mat = torch.mean(att_mat, dim=1)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
# This function creates a 2D identity matrix of size 197 x 197. In an identity matrix, all the diagonal elements are 1, and all off-diagonal elements are 0. This matrix represents a situation where each token attends only to itself with full attention (a residual connection).
residual_att = torch.eye(att_mat.size(1))


# When you add residual_att (which has shape [197, 197]) to att_mat (which has shape [12, 12, 197, 197]), PyTorch automatically broadcasts the residual_att matrix across the first two dimensions (12 layers and 12 heads).
# This means that residual_att is added to each [197, 197] attention matrix in att_mat.
# Result (aug_att_mat):
# Shape: [12, 12, 197, 197]
# The resulting aug_att_mat tensor now contains the original attention values from att_mat plus 1s along the diagonal for each of the attention matrices (due to the addition of the identity matrix from residual_att).
# This effectively means that each token has a stronger self-attention component since the diagonal elements (representing self-attention) have been incremented by 1.
# Why Do This?
# Adding residual_att to att_mat augments the attention matrices by reinforcing the self-attention mechanism. This can help stabilize training and ensure that each token's original information is preserved across layers. In essence, even after the attention mechanism mixes information from other tokens, each token retains some of its original identity (as reflected by the diagonal elements).
aug_att_mat = att_mat + residual_att


#THIS STEP I AM MORE IN DOUBT ABOUT
# Step-by-Step Explanation:
# aug_att_mat.sum(dim=-1):
# dim=-1 refers to the last dimension of aug_att_mat, which is 197 (the sequence length, or the number of tokens).
# The sum(dim=-1) operation sums the elements along this last dimension for each attention matrix.
# The result is a tensor of shape [12, 12, 197], where each element is the sum of the attention scores for a particular token across all other tokens (including itself).
# .unsqueeze(-1):
# This adds a new dimension at the end of the tensor resulting from the sum operation.
# The shape of the tensor changes from [12, 12, 197] to [12, 12, 197, 1].
# This is necessary for broadcasting during the division operation that follows.
# aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1):
# Here, element-wise division is performed between aug_att_mat and the summed attention scores.
# The broadcasting mechanism allows the [12, 12, 197, 1] tensor to be divided across the [12, 12, 197, 197] tensor.
# This operation normalizes each attention matrix along its last dimension so that the sum of attention scores for each token across all tokens (including itself) equals 1.

aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Recursively multiply the weight matrices
# joint_attentions = torch.zeros(aug_att_mat.size()):
# Creates a tensor joint_attentions initialized to zeros with the same size as aug_att_mat, which is [12, 12, 197, 197].
# This tensor will store the cumulative attention matrices for each layer.
joint_attentions = torch.zeros(aug_att_mat.size())

# joint_attentions[0] = aug_att_mat[0]:
# The first entry in joint_attentions (corresponding to the first layer) is set equal to the first augmented attention matrix (aug_att_mat[0]).
# This means that for the first layer, the joint attention is simply the attention of that layer.
joint_attentions[0] = aug_att_mat[0]



# . Recursive Multiplication:

# for n in range(1, aug_att_mat.size(0)):
# Iterates over each layer starting from the second one (n=1) to the last one (n=11).
# joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1]):
# For each layer n, the code multiplies the current layer's attention matrix (aug_att_mat[n]) by the cumulative attention matrix from the previous layer (joint_attentions[n-1]).
# torch.matmul(...): This performs matrix multiplication, which combines the current attention matrix with the joint attention from the previous layer.
# The result is stored in joint_attentions[n].
# 3. What This Achieves:

# Cumulative Attention Across Layers:
# The matrix multiplication accumulates the effects of attention across multiple layers.
# joint_attentions[n] captures how attention is distributed from the input to the current layer n through the entire sequence of preceding layers.
# By the time you reach the last layer, joint_attentions[-1] will represent the overall attention map that considers the cumulative effect of all layers.
# Summary:
# joint_attentions is a tensor that stores the cumulative attention maps for each layer in the Vision Transformer.
# The recursive multiplication effectively traces how attention flows through the network from the input layer to each subsequent layer, providing a way to visualize or analyze the overall attention mechanism across the entire model.

for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])
    
# Attention from the output token to the input space.
v = joint_attentions[-1]
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]


#mask * im: Multiplies the attention mask with the original image im, element-wise. This highlights the regions of the image that the model attends to, with more attention resulting in brighter areas.
result = (mask * im).astype("uint8")

In [None]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))

ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(im)
#_ = ax2.imshow(result)
_ = ax2.imshow(mask) #trying to use mask instead so it doesn't just show the bright areas



In [None]:
#iterate over attention layers

for i, v in enumerate(joint_attentions):
    # Attention from the output token to the input space.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    #result = (mask * im).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map_%d Layer' % (i+1))
    _ = ax1.imshow(im)
    #_ = ax2.imshow(result)
    _ = ax2.imshow(mask)

# Each KP individually



In [4]:
logits, att_mat = model(x.unsqueeze(0))
att_mat = torch.stack(att_mat).squeeze(1)

att_mat = torch.mean(att_mat, dim=1)
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]

for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])



Input tensor size: torch.Size([1, 197, 768])


In [6]:
# Assuming `im` is the image you want to visualize the attention on
grid_size = int(np.sqrt(aug_att_mat.size(-1)))  # Assuming grid size 14x14 for 197 tokens

for kp_index in range(12):  # 12 keypoints (each with x, y)
    key_point_name = f"Key Point {kp_index + 1}"
    for layer_index, v in enumerate(joint_attentions):
        # Attention from the output token to the input space
        mask_x = v[0, kp_index * 2 + 1].reshape(grid_size, grid_size).detach().cpu().numpy()
        mask_y = v[0, kp_index * 2 + 2].reshape(grid_size, grid_size).detach().cpu().numpy()
        
        # Average the attention masks for the x and y coordinates
        mask = (mask_x + mask_y) / 2.0
        mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
        result = (mask * im).astype("uint8")

        fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
        ax1.set_title(f'Original - {key_point_name}')
        ax2.set_title(f'Attention Map for {key_point_name} - Layer {layer_index + 1}')
        _ = ax1.imshow(im)
        _ = ax2.imshow(result)
        plt.show()

RuntimeError: shape '[14, 14]' is invalid for input of size 1