In [25]:
import argparse
import torch
import lightning as L
from config import load_config
from models.classifier import Classifier
from datasets import instentiate_dataloader
from fvcore.nn import FlopCountAnalysis, flop_count_table
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image

from torchvision.transforms import transforms

# Artefacts of the training run being inspected.
RUN_DIR = "lightning_logs/version_166775"
CFG_PATH = f"{RUN_DIR}/hparams.yaml"
# Trained Lightning checkpoint; set to None / "" to fall back to a freshly built Classifier(cfg).
new_checkpoint = f"{RUN_DIR}/checkpoints/epoch=199-step=17000.ckpt"
# Whole pickled nn.Module (the model-loading cell calls `.parameters()` on what torch.load returns).
checkpoint = f"{RUN_DIR}/VIT_B_16.ckpt"
# Keep the path and the parsed config under separate names instead of rebinding `cfg`.
cfg = load_config(CFG_PATH)

In [26]:
# `checkpoint` holds a whole pickled nn.Module (not a state_dict). torch>=2.6 defaults to
# weights_only=True, which refuses to unpickle arbitrary objects, so opt out explicitly.
# NOTE: this executes pickle code -- only load checkpoint files you trust.
tm = torch.load(checkpoint, map_location="cpu", weights_only=False)
tm.eval()  # inference only (e.g. disables dropout) in case it is used for comparison
tm.requires_grad_(False)

# Prefer the trained Lightning checkpoint; only build a freshly initialised model from the
# config when no checkpoint path is set (avoids constructing a model just to discard it).
# map_location="cpu" keeps the weights on the same device as the CPU tensors used later.
if new_checkpoint:
    model = Classifier.load_from_checkpoint(new_checkpoint, map_location="cpu")
else:
    model = Classifier(cfg)
model.eval()
model.requires_grad_(False);  # trailing ';' suppresses the module repr in the cell output

In [34]:
# Example image from the Flowers-17 data. `.convert("RGB")` forces the lazy PIL load and
# guarantees three channels; non-RGB JPEGs (grayscale, CMYK) would not match the
# 3-channel Normalize statistics below.
img = Image.open("./data/flowers-17/jpg/image_0376.jpg").convert("RGB")

# Evaluation preprocessing: resize to 256x256 (aspect ratio not preserved), center-crop to
# 224x224, convert to a float tensor and normalise with the usual ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

x = transform(img)  # (3, 224, 224) tensor; batched with unsqueeze(0) before the forward pass

# Forward pass on a batch of one. The backbone returns a pair whose second item is a
# sequence of per-layer attention tensors. To inspect the pickled model `tm` instead:
#   out, layer_attentions = tm.get_submodule("backbone")(x.unsqueeze(0))
with torch.no_grad():  # inference only; no autograd graph needed
    out, layer_attentions = model.backbone(x.unsqueeze(0))

# Attention rollout: average over heads, add the identity to account for residual
# connections, re-normalise each row, then chain the layers by matrix product.
# Assumes each per-layer map is (heads, tokens, tokens) -- TODO confirm against the backbone.
att_mat = torch.stack(list(layer_attentions)).mean(dim=1)
assert att_mat.dim() == 3 and att_mat.size(1) == att_mat.size(2), (
    f"expected (layers, tokens, tokens) after averaging heads, got {tuple(att_mat.shape)}"
)

residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)  # each row sums to 1

# joint_attentions[n] = aug_att_mat[n] @ aug_att_mat[n - 1] @ ... @ aug_att_mat[0]
joint_attentions = torch.zeros_like(aug_att_mat)
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = aug_att_mat[n] @ joint_attentions[n - 1]

# Token 0 is treated as the class (output) token and left out of the spatial grid,
# so only the remaining tokens have to form a square patch grid.
num_patch_tokens = aug_att_mat.size(-1) - 1
grid_size = int(round(np.sqrt(num_patch_tokens)))
assert grid_size * grid_size == num_patch_tokens, (
    f"{num_patch_tokens} patch tokens do not form a square grid"
)

# The attention maps describe the resized + center-cropped image the model actually saw,
# so overlay them on that view (the first two, purely spatial, steps of `transform`)
# instead of stretching them over the full original image.
# Update the slice if the order of `transform` changes.
model_view = transforms.Compose(transform.transforms[:2])(img).convert("RGB")
model_view_arr = np.asarray(model_view, dtype=np.float32)

# One figure per layer; layer 0 is skipped.
for i in range(1, joint_attentions.size(0)):
    v = joint_attentions[i]
    # Attention from the output token to the input space.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    # Scale to [0, 1] and upsample to the model view; cv2 takes dsize as (width, height),
    # which is also the order of PIL's `.size`.
    mask = cv2.resize(mask / mask.max(), model_view.size)
    # Attention-weighted input; the clip guards the uint8 cast against wrap-around.
    result = np.clip(mask[..., np.newaxis] * model_view_arr, 0, 255).astype(np.uint8)

    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(18, 6))
    ax1.set_title("Model input (resized + center-cropped)")
    ax2.set_title("Attention Map_%d Layer" % (i + 1))
    ax3.set_title("Attention-weighted input")
    ax1.imshow(model_view)
    ax2.imshow(mask)
    ax3.imshow(result)