In [None]:
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Download the Flickr8k dataset through kagglehub. The value returned by
# dataset_download is kept in `path` and printed as the dataset location.
import kagglehub
path = kagglehub.dataset_download("adityajn105/flickr8k")
print("Path to dataset files:", path)

In [None]:
!ln -sf $path/Images ./images

In [None]:
!ls -l ./images/

In [None]:
# Sample picture used by the cells below, addressed through the ./images link.
imagepath = "./images/978580450_e862715aba.jpg"

In [None]:
from google.colab.patches import cv2_imshow

# Preview the selected image. cv2.imread reports a missing or undecodable file
# by returning None instead of raising, so fail here with a clear message
# rather than with an obscure error inside cv2_imshow.
image = cv2.imread(imagepath)
if image is None:
    raise FileNotFoundError(f'Could not read image: {imagepath}')
cv2_imshow(image)

In [None]:
!wget http://www.agentspace.org/download/ViT-B_32.pth

In [None]:
# Load Model
# NOTE(security): weights_only=False unpickles arbitrary Python objects, which
# can execute code on load; only use checkpoints from a source you trust.
# map_location lets the load succeed even if the checkpoint was saved on a
# device this machine lacks, and .to(device) places the weights on the same
# device the input tensor is moved to later.
model = torch.load('ViT-B_32.pth', map_location=device, weights_only=False)
model = model.to(device)
model.eval()

In [None]:


# Test Image
# Open the picture selected in `imagepath`; nothing in this notebook creates
# an "img.jpg". convert('RGB') guarantees the three channels that the
# Normalize step and the later RGB-to-gray conversion expect.
im = Image.open(imagepath).convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# Add the batch dimension and move the tensor to the chosen device.
blob = transform(im).unsqueeze(0).to(device)
print('image', blob.shape)

# Call the transformer model

def embed(self, x):
    """Turn an image batch into the token sequence fed to the encoder.

    Args:
        self: embeddings module; this function reads its ``cls_token``,
            ``hybrid``, ``hybrid_model`` (only when ``hybrid`` is truthy),
            ``patch_embeddings``, ``position_embeddings`` and ``dropout``.
        x: image batch whose first dimension is the batch size.

    Returns:
        The class token followed by the flattened patch embeddings, with the
        position embeddings added and dropout applied. For the ViT-B/32
        checkpoint and a 224x224 input the original notes give 1 x 50 x 768.
    """
    batch_size = x.shape[0]
    # One copy of the learned class token per batch element.
    cls_prefix = self.cls_token.expand(batch_size, -1, -1)

    if self.hybrid:
        x = self.hybrid_model(x)

    # Patch projection yields B x C x H x W; merge the spatial grid into one
    # axis and move it ahead of the channels to get B x (H*W) x C.
    patch_tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
    tokens = torch.cat((cls_prefix, patch_tokens), dim=1)

    return self.dropout(tokens + self.position_embeddings)

# Inference only: no_grad avoids recording an autograd graph for the forward
# pass, which saves memory without changing the computed values.
with torch.no_grad():
    hidden_states = embed(model.transformer.embeddings, blob) # 1 x 50 x 768

def encode(self, hidden_states):
    """Run the hidden states through every encoder block, then normalise.

    Args:
        self: encoder module; this function reads its ``layer`` iterable,
            ``vis`` flag and ``encoder_norm`` callable.
        hidden_states: input passed to the first block.

    Returns:
        A pair ``(encoded, attn_maps)``: the output of ``encoder_norm`` applied
        to the last block's result, and the list of second values returned by
        the blocks (left empty when ``self.vis`` is falsy).
    """
    collected = []
    state = hidden_states
    for block in self.layer:
        # Each block returns its new hidden state plus attention coefficients.
        state, weights = block(state)
        if not self.vis:
            continue
        collected.append(weights)

    return self.encoder_norm(state), collected

# Inference only: no_grad avoids recording an autograd graph across the
# encoder blocks, which saves memory without changing the computed values.
with torch.no_grad():
    hidden_states, att_maps = encode(model.transformer.encoder, hidden_states) # 1 x 50 x 768, 12 x 50 x 50

def lmhead(self, hidden_states):
    """Apply the classification head to the first token of every sequence.

    Args:
        self: model object; only its ``head`` callable is used.
        hidden_states: encoder output indexed as ``[:, 0]`` here, so position 0
            along the second axis is what reaches the head (the class token
            placed first by the embedding step).

    Returns:
        Whatever ``self.head`` returns for that slice; the original notes give
        a Linear(768 -> 1000) head producing 1 x 1000 logits.
    """
    first_token = hidden_states[:, 0]
    return self.head(first_token)

# Inference only: no_grad avoids recording an autograd graph for the head.
with torch.no_grad():
    logits = lmhead(model, hidden_states)

print('logits',logits.shape)
print('att maps',[att_map.shape for att_map in att_maps])

# Present probabilities of categories
probs = torch.softmax(logits, dim=-1)
# Class indices ordered from most to least probable.
top5 = torch.argsort(probs, dim=-1, descending=True)
# Read the label file inside a context manager so the handle is closed.
# Line number -> label; every label keeps its trailing newline, which is why
# the prints below use end=''.
with open('checkpoint/ilsvrc2012_wordnet_lemmas.txt') as labels_file:
    imagenet_labels = dict(enumerate(labels_file))
print("Prediction:")
for idx in top5[0, :5]:
    class_id = idx.item()
    print(f'{probs[0, class_id]:.5f} : {imagenet_labels[class_id]}', end='')

# Present attention maps.
# The encoder only collects maps when its `vis` flag is set; fail clearly
# instead of letting torch.stack complain about an empty list.
if not att_maps:
    raise RuntimeError('No attention maps were collected; the encoder must run with vis enabled.')
# cv2.imwrite does not create missing directories, so make sure it exists.
os.makedirs('outputs', exist_ok=True)

att_mat = torch.stack(att_maps).squeeze(1)
att_mat = torch.mean(att_mat, dim=1) # average through heads
att_mat = att_mat / att_mat.sum(dim=-1).unsqueeze(-1) # normalize
# Grayscale copy of the input picture used as the background of every overlay.
base = cv2.cvtColor(np.array(im.convert('RGB')), cv2.COLOR_RGB2GRAY)
grid_size = int(np.sqrt(att_mat.size(-1)))
for i, v in enumerate(att_mat):
    # Row 0 is the first (class) token; drop its weight on itself and lay the
    # remaining weights out as the patch grid.
    mask = v[0, 1:].reshape(grid_size,grid_size).detach().cpu().numpy()
    # Scale to [0, 1] and blow the grid up to the picture size (width, height).
    mask = cv2.resize(mask / mask.max(), (base.shape[1],base.shape[0]), interpolation=cv2.INTER_NEAREST)
    mask = (mask*255).astype(np.uint8)
    # Strong attention is pushed into the red channel, weak attention into the
    # green channel; values below the midpoint are cleared in each.
    red = np.copy(mask)
    red[red < 127] = 0
    green = 255 - mask
    green[green < 127] = 0
    # OpenCV channel order is blue, green, red.
    disp = cv2.merge([base,base|green,base|red])
    out_path = f'outputs/att{i}.png'
    # imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(out_path, disp):
        raise IOError(f'Could not write {out_path}')
