In [1]:
import argparse
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models import create_model
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Argument setup: the notebook's options are declared with argparse.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model',
    type=str,
    default='vit_tiny_patch16_224',
    choices=['vit_tiny_patch16_224', 'vit_small_patch16_18x2_224'],
    help='model name',
)
parser.add_argument(
    '--checkpoint',
    type=str,
    default='./ImageNet/tiny16/best_checkpoint.pth',
    help='checkpoint',
)

# Parse an explicit empty argument list (args=[]) instead of sys.argv, so every
# option takes its default value when this runs inside a notebook.
args = parser.parse_args(args=[])
print(args)


Namespace(checkpoint='./ImageNet/tiny16/best_checkpoint.pth', model='vit_tiny_patch16_224')


In [3]:
# Build the ViT architecture selected by `args.model`. pretrained=False is
# passed here; the trained weights are restored from the checkpoint below.
model = create_model(args.model, pretrained=False)

# Read the checkpoint file, mapping every stored tensor onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(args.checkpoint, map_location='cpu')

# The model weights are stored under the "model" key of the checkpoint.
trained_state = checkpoint["model"]
model.load_state_dict(trained_state)

# Switch to evaluation mode. As the last expression of the cell, the returned
# module is also displayed as the cell output.
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()


In [4]:
# Position embedding
# Read the learned position embedding from the model.
# N: number of patches + class token, D: embedding dimension.
pos_embed = model.state_dict()['pos_embed']  # shape: (1, N, D)

# Drop the class token and take the square root to get the side length of the
# patch grid. Assumes exactly one class token precedes the patch tokens.
num_patches = pos_embed.shape[1] - 1
H_and_W = int(math.sqrt(num_patches))
if H_and_W * H_and_W != num_patches:
    # Fail early with a clear message instead of an obscure reshape error below.
    raise ValueError(
        f"expected a square patch grid, but got {num_patches} patch tokens"
    )

# Embeddings of the patch tokens only; loop-invariant, so sliced once here.
patch_embed = pos_embed[0, 1:]

# Compute the cosine similarity between patches and visualize it: one subplot
# per patch, showing that patch's similarity to every patch as an
# (H_and_W x H_and_W) map.
fig = plt.figure(figsize=(10, 10))
for i in range(1, pos_embed.shape[1]):
    sim = F.cosine_similarity(pos_embed[0, i:i+1], patch_embed, dim=1)
    sim = sim.reshape((H_and_W, H_and_W)).detach().cpu().numpy()
    ax = fig.add_subplot(H_and_W, H_and_W, i)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.imshow(sim)

# Save through the figure object so the output does not depend on pyplot's
# implicit "current figure".
fig.savefig("./position_embedding.pdf")
# Close the figure rather than clearing it: plt.clf() leaves an empty figure
# registered with pyplot (rendered as "Figure ... with 0 Axes" in the notebook),
# whereas plt.close(fig) releases it.
plt.close(fig)

<Figure size 1000x1000 with 0 Axes>