<a href="https://colab.research.google.com/github/RyosukeHanaoka/TechTeacher_New/blob/main/Transformer_Explainability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import the required modules
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torchvision.transforms as transforms
from baselines.ViT.ViT_LRP import vit_tiny_16_224 as vit_LRP
from baselines.ViT.ViT_LRP import vit_base_16_224 as vit_base_LRP
from baselines.ViT.ViT_LRP import vit_large_16_224 as vit_large_LRP
from baselines.ViT.ViT_explainer import vit_tiny_16_224 as vit_explainer
from baselines.ViT.ViT_explainer import vit_base_16_224 as vit_base_explainer
from baselines.ViT.ViT_explaination_generator import LRP_VIS

In [None]:
# Define "show_cam_on_image", which blends a colour-mapped mask with an image.
def show_cam_on_image(img, mask):
  """Overlay a JET-coloured rendering of ``mask`` on ``img``.

  ``mask`` is multiplied by 255 and cast to uint8 before being colour-mapped;
  the coloured map (rescaled by 1/255) is added to ``img`` cast to float32.
  The sum is divided by its own maximum, so the largest returned value is 1.

  NOTE(review): assumes ``mask`` is a 2-D array in [0, 1] and ``img`` is an
  (H, W, 3) float image in [0, 1] -- confirm against callers. Per the OpenCV
  documentation the colour map comes back in BGR channel order, so the caller
  is responsible for keeping channel orders consistent.
  """
  mask_u8 = np.uint8(255 * mask)
  colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
  overlay = np.float32(colored) / 255 + np.float32(img)
  return overlay / np.max(overlay)

In [None]:
# Load the ViT (the constructor is handed the path of the pretrained weights)
# and switch it to inference mode.
# The model is placed on the GPU when one is available: leaving it on the CPU
# while feeding it CUDA tensors raises a device-mismatch error.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = vit_LRP('/path/to/pretraned.pth').to(device)
model.eval()
# Wrap the model in the LRP_VIS generator that is used to produce attribution maps.
attribution_generator = LRP_VIS(model)

In [None]:
# Define "generate_visualization", which renders the attribution map of one image
# as a heatmap overlay.
def _min_max_normalize(values):
  """Rescale ``values`` linearly to [0, 1]; a constant array maps to all zeros."""
  low = values.min()
  span = values.max() - low
  if span == 0:
    # Avoid 0/0 -> NaN, which would corrupt the uint8 conversion downstream.
    return np.zeros_like(values)
  return (values - low) / span


def generate_visualization(original_image, class_index=None):
  """Return an RGB uint8 heatmap overlay explaining ``class_index``.

  Args:
    original_image: image tensor in (C, H, W) layout. The hard-coded sizes
      below require a 224x224 input (14x14 patches of 16 pixels).
    class_index: class to explain; forwarded as ``index`` to
      ``attribution_generator.generate_LRP``. NOTE(review): ``None`` presumably
      selects the predicted class -- confirm in LRP_VIS.

  Returns:
    A (224, 224, 3) uint8 array in RGB order, ready for ``plt.imshow``.
  """
  # Follow the model's device instead of hard-coding CUDA, so the input always
  # lives where the weights are (and CPU-only machines are not ruled out here).
  device = next(model.parameters()).device
  batch = original_image.unsqueeze(0).to(device)
  attribution = attribution_generator.generate_LRP(
      batch, method="transformer_attribution", index=class_index).detach()

  # The generator must yield one relevance value per patch (14 * 14 = 196).
  # Lay them out on the patch grid and upsample x16 with bilinear
  # interpolation to the 224x224 pixel grid.
  attribution = attribution.reshape(1, 1, 14, 14)
  attribution = torch.nn.functional.interpolate(
      attribution, scale_factor=16, mode='bilinear')
  attribution = _min_max_normalize(attribution.reshape(224, 224).cpu().numpy())

  # (C, H, W) tensor -> (H, W, C) array scaled to [0, 1].
  image = _min_max_normalize(
      original_image.permute(1, 2, 0).detach().cpu().numpy())

  # show_cam_on_image adds an OpenCV colour map (BGR order) to the image, so
  # hand it the image in BGR order as well and convert the blended result to
  # RGB afterwards. Swapping channels without flipping the image first would
  # leave the photo's red and blue channels exchanged when displayed.
  overlay = show_cam_on_image(image[..., ::-1], attribution)
  overlay = np.uint8(255 * overlay)
  return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)

# Preprocessing pipeline: resize (256), center-crop (224), then convert to a tensor.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)
# Load the image and visualize the attribution maps for the cat and dog classes.
# Force three RGB channels: a grayscale, palette or RGBA file would otherwise
# produce a tensor whose channel count does not fit the 3-channel overlay code.
# The context manager closes the underlying file once the pixels are converted.
with Image.open('/path/to/image.jpg') as image_file:
  image = image_file.convert('RGB')
dog_cat_image = transform(image)

# NOTE(review): 282 and 243 are presumably the ImageNet-1k indices for a cat and
# a dog class -- confirm they match the label set of the loaded checkpoint.
cat = generate_visualization(dog_cat_image, class_index=282)
plt.imshow(cat)
plt.show()

dog = generate_visualization(dog_cat_image, class_index=243)
plt.imshow(dog)
plt.show()