In [2]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2

from models.modeling import VisionTransformer, CONFIGS
from baselines.ViT.ViT_explanation_generator import LRP

# Visualise LRP relevance maps of a ViT model as heat-map overlays on one test
# image: load the checkpoint, pre-process the image, compute one relevance map
# per explained output index, and plot the maps in a 3x4 grid.

# Load the model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=24, zero_head=False, img_size=224, vis=True)
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# machine without CUDA; the model and input stay on the CPU in this script.
checkpoint = torch.load("output/test_checkpoint.pth", map_location="cpu")  # Load the checkpoint
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Load the image
# Force three channels: the Normalize transform below is configured for
# exactly 3 channels and would fail on a grayscale, CMYK or RGBA file.
im = Image.open("augmented_data_test/img8504_flip_rescale_augmented.jpg").convert("RGB")

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Transform the image (adds a leading batch dimension of size 1)
x = transform(im).unsqueeze(0)

# Generate attention maps for each keypoint
attribution_generator = LRP(model)
keypoint_attention_maps = []

# NOTE(review): the model is built with num_classes=24 but only output indices
# 0-11 are explained here -- confirm how outputs map to the 12 keypoints.
for i in range(12):
    # Generate attention map for each keypoint
    transformer_attribution = attribution_generator.generate_LRP(
        x, method="transformer_attribution", index=i
    ).detach()
    # 14x14 patch grid upsampled by 16 gives the 224x224 input resolution.
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = F.interpolate(
        transformer_attribution, scale_factor=16, mode='bilinear', align_corners=False
    )
    transformer_attribution = transformer_attribution.reshape(224, 224).cpu().numpy()

    # Min-max normalise to [0, 1]; a constant map would divide by zero and
    # yield NaNs, so fall back to an all-zero map in that case.
    lowest = transformer_attribution.min()
    value_range = transformer_attribution.max() - lowest
    if value_range > 0:
        transformer_attribution = (transformer_attribution - lowest) / value_range
    else:
        transformer_attribution = np.zeros_like(transformer_attribution)
    keypoint_attention_maps.append(transformer_attribution)

# Plot the attention maps for each keypoint
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# Show the image at the same 224x224 resolution as the heat-maps; drawing the
# heat-map over the native-resolution image would misalign the overlay.
background = im.resize((224, 224))

for i, (ax, attention_map) in enumerate(zip(axes.flat, keypoint_attention_maps)):
    ax.set_title(f"Keypoint {i+1}")
    ax.imshow(background)
    ax.imshow(attention_map, alpha=0.7, cmap='hot')
    ax.axis('off')

plt.tight_layout()
plt.show()

ModuleNotFoundError: No module named 'models'