In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2

from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
from torchvision import utils

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
import numpy as np
import math
from tqdm.auto import tqdm
import math

from pytorch_grad_cam import GradCAM, \
                            ScoreCAM, \
                            GradCAMPlusPlus, \
                            AblationCAM, \
                            XGradCAM, \
                            EigenCAM, \
                            EigenGradCAM, \
                            LayerCAM, \
                            FullGrad
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

from Model import models_vit

IMG_SIZE = 448

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_cuda = torch.cuda.is_available()
device = torch.device('cpu')
use_cuda = False
device

device(type='cpu')

In [2]:
# image function

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def circle_crop(img_path, sigmaX=10, scale=1.0, img_size=224):   
    """
    Create circular crop around image centre 
    Scale(0~1) is a percentage of original image
    """    
    
    img = cv2.imread(img_path)
    #img = crop_image_from_gray(img)    
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (img_size, img_size))
    
    height, width, depth = img.shape    
    
    x = int(width/2)
    y = int(height/2)
    r = np.amin((x,y)) * scale
    
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1)
    # bitwise_and 来裁剪原始图像，得到一个圆形图像
    img = cv2.bitwise_and(img, img, mask=circle_img)
    #img = crop_image_from_gray(img)
    #img= cv2.addWeighted ( img,4, cv2.GaussianBlur( img , (0,0) , sigmaX) ,-4 ,128)
    return Image.fromarray(img)

def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')
    return

In [3]:
# get attention: visualize_predict(model, img_path, patch_size, device, attn_layer=-1)

def get_last_selfattention(model, x, patch_size=16, attn_layer=-1):
    B = x.shape[0]
    x = model.patch_embed(x)

    cls_tokens = model.cls_token.expand(B, -1, -1) 
    x = torch.cat((cls_tokens, x), dim=1)
    x = x + model.pos_embed
    x = model.pos_drop(x)

    if attn_layer < 0:
        attn_layer = len(model.blocks) + attn_layer

    # attention blocks
    for i in range(len(model.blocks)):
        if i != attn_layer:
            x = model.blocks[i](x)
        else:
            x0 = x
            blk = model.blocks[i]
            attn = blk.attn
            x = blk.norm1(x)

            B, N, C = x.shape
            qkv = attn.qkv(x).reshape(B, N, 3, attn.num_heads, 64).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            # q, k = attn.q_norm(q), attn.k_norm(k)
            q = q * attn.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v

            x = x.transpose(1, 2).reshape(B, N, C)
            #print(x.shape)
            break

    nh = attn.shape[1]  # number of head

    # keep only the output patch attention
    attentions = attn[0, :, 0, 1:].reshape(nh, -1)

    attentions = attentions.reshape(nh, 28, 28)
    attentions = nn.functional.interpolate(attentions.unsqueeze(
        0), scale_factor=patch_size, mode="nearest")[0].cpu().detach().numpy()
    return attentions


def plot_attention(img, attention):
    n_heads = attention.shape[0]
    n = math.ceil(math.sqrt(n_heads))
    plt.figure(figsize=(10, 10))
    text = ["Original Image", "Head Mean Attention"]
    mean_attention = np.mean(attention, 0)
    # img = show_cam_on_image(img, mean_attention)
    for i, fig in enumerate([img, mean_attention]):
        plt.subplot(1, 2, i+1)
        plt.imshow(fig, cmap='inferno')
        plt.title(text[i])
    plt.show()

    plt.figure(figsize=(10, 10))
    for i in range(n_heads):
        plt.subplot(n_heads//n, n, i+1)
        plt.imshow(attention[i], cmap='inferno')
        plt.title(f"Head n: {i+1}")
    plt.tight_layout()
    plt.show()



def visualize_predict(model, img_path, patch_size, device, attn_layer=-1):
    '''
    preprocess img and visualize the attention map
    '''
    # img = Image.open(img_path)
    img_no_crop = Image.open(img_path).resize((IMG_SIZE, IMG_SIZE))
    img = circle_crop(img_path, scale=0.9, img_size=IMG_SIZE)

    img = np.array(img) / 255.
    img0 = img_no_crop
    img = img - imagenet_mean
    img = img / imagenet_std
    x = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device).float()

    y = model(x)
    pred = torch.max(nn.Softmax(dim=1)(y), dim=1)[1].item()
    print(pred)
    
    attention = get_last_selfattention(model, x, patch_size=patch_size, attn_layer=attn_layer)
    plot_attention(img0, attention)
    return img0, np.mean(attention, 0)

def reshape_transform(tensor, height=28, width=28):
    result = tensor[:, 1:, :].reshape(tensor.size(0),
                                      height, width, tensor.size(2))
    # Bring the channels to the first dimension,
    # like in CNNs.
    result = result.transpose(2, 3).transpose(1, 2)
    return result


# visualize_predict(model, 
#                 '/data/home/liuchunyu/code/uwf/wuhan/whxh_011775-20210306@152058-R1-S.png',
#                 patch_size=16, device=device,
#                 attn_layer=-1)

In [2]:
# load model weight
checkpoint = torch.load('./checkpoints/finetune_pplhk_pretrain_ffm_checkpoint-best.pth', map_location='cpu')

model = models_vit.__dict__['vit_large_patch16'](
        img_size=448,
        num_classes=2,
        drop_path_rate=0.15,
        global_pool=True
    )
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(


In [3]:
len(model.state_dict())

296

In [7]:
# test partial parameter update {key: params}
import random
random.seed(0)
shared_weight_id = random.sample(range(len(model.state_dict())), int(len(model.state_dict()) * 0.1))
params = {key:val.cpu().numpy() for key, val in model.state_dict().items() if list(model.state_dict()).index(key) in shared_weight_id}
len(params.keys()), params

(29,
 {'blocks.1.attn.proj.weight': array([[ 4.4117227e-02, -1.0261547e-02,  3.6659241e-02, ...,
          -1.4531501e-02,  1.9236078e-03, -2.7613488e-03],
         [-1.3469950e-02,  2.1114601e-02, -1.5358692e-05, ...,
          -1.8817597e-03, -1.7974429e-02, -5.6418002e-02],
         [ 3.1448521e-02,  1.7557865e-02, -3.1185860e-02, ...,
           1.6667198e-03, -3.4617085e-02, -2.4910484e-02],
         ...,
         [-2.8154265e-02,  1.7749771e-02, -1.3607264e-02, ...,
          -1.6493179e-02,  3.3368211e-02, -1.3745355e-02],
         [-2.7519224e-02, -5.3474687e-02, -1.5414399e-01, ...,
           9.1188490e-02,  1.7152630e-02,  2.7798945e-02],
         [-1.7259818e-02, -4.6782430e-02, -9.9382447e-03, ...,
          -5.0568726e-02,  4.5884568e-02, -7.7889143e-03]], dtype=float32),
  'blocks.2.mlp.fc1.bias': array([-0.42142227, -1.3955742 , -0.3453499 , ..., -0.5576897 ,
         -1.1525309 , -2.6160698 ], dtype=float32),
  'blocks.3.mlp.fc1.weight': array([[-0.02059307, -0.0003176

In [27]:
# test partial parameter update
import random
random.seed(0)
shared_weight_id = random.sample(range(len(model.state_dict())), int(len(model.state_dict()) * 0.1))
params = [val.cpu().numpy() for _, val in model.state_dict().items()]
parameters = [params[id]*10 for id in shared_weight_id]
parameters

[array([-0.10773506,  0.31134847, -0.00285576, ..., -0.5876403 ,
         0.12584244, -0.01486699], dtype=float32),
 array([-0.5522769 , -1.3769426 , -0.45492682, ...,  1.1340352 ,
         0.7825753 ,  0.64382696], dtype=float32),
 array([[ 4.4117227e-01, -1.0261547e-01,  3.6659241e-01, ...,
         -1.4531501e-01,  1.9236078e-02, -2.7613489e-02],
        [-1.3469951e-01,  2.1114601e-01, -1.5358691e-04, ...,
         -1.8817598e-02, -1.7974429e-01, -5.6418002e-01],
        [ 3.1448519e-01,  1.7557865e-01, -3.1185859e-01, ...,
          1.6667198e-02, -3.4617084e-01, -2.4910483e-01],
        ...,
        [-2.8154266e-01,  1.7749771e-01, -1.3607264e-01, ...,
         -1.6493179e-01,  3.3368212e-01, -1.3745356e-01],
        [-2.7519223e-01, -5.3474689e-01, -1.5414399e+00, ...,
          9.1188490e-01,  1.7152630e-01,  2.7798945e-01],
        [-1.7259818e-01, -4.6782431e-01, -9.9382445e-02, ...,
         -5.0568724e-01,  4.5884567e-01, -7.7889144e-02]], dtype=float32),
 array([[ 0.261139

In [28]:
for i,id in enumerate(shared_weight_id):
    tmp = model.state_dict()[list(model.state_dict().keys())[id]].data
    model.state_dict()[list(model.state_dict().keys())[id]].data = torch.tensor(parameters[i])
    print(model.state_dict()[list(model.state_dict().keys())[id]].data - tmp)

tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([0., 0., 0.,  ..., 0., 0., 0.])
tens

In [44]:
from collections import OrderedDict
state_dict = {}
key_list = list(model.state_dict().keys())
for i,id in enumerate(key_list):
    if i in shared_weight_id:
        state_dict[id] = torch.tensor(parameters[shared_weight_id.index(i)])
    else:
        state_dict[id] = model.state_dict()[id].data
state_dict = OrderedDict(state_dict)
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [4]:
from Model import models_mae
model = models_mae.__dict__['mae_vit_large_patch16'](
        img_size=448,
    )
checkpoint = torch.load('./checkpoints/pretrain_distill_ffm_checkpoint-best.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [5]:
len(model.state_dict())

398

In [4]:
for i, id in enumerate(model.state_dict().keys()):
    print(i, id, model.state_dict()[id].shape)

0 cls_token torch.Size([1, 1, 1024])
1 pos_embed torch.Size([1, 785, 1024])
2 patch_embed.proj.weight torch.Size([1024, 3, 16, 16])
3 patch_embed.proj.bias torch.Size([1024])
4 blocks.0.norm1.weight torch.Size([1024])
5 blocks.0.norm1.bias torch.Size([1024])
6 blocks.0.attn.qkv.weight torch.Size([3072, 1024])
7 blocks.0.attn.qkv.bias torch.Size([3072])
8 blocks.0.attn.proj.weight torch.Size([1024, 1024])
9 blocks.0.attn.proj.bias torch.Size([1024])
10 blocks.0.norm2.weight torch.Size([1024])
11 blocks.0.norm2.bias torch.Size([1024])
12 blocks.0.mlp.fc1.weight torch.Size([4096, 1024])
13 blocks.0.mlp.fc1.bias torch.Size([4096])
14 blocks.0.mlp.fc2.weight torch.Size([1024, 4096])
15 blocks.0.mlp.fc2.bias torch.Size([1024])
16 blocks.1.norm1.weight torch.Size([1024])
17 blocks.1.norm1.bias torch.Size([1024])
18 blocks.1.attn.qkv.weight torch.Size([3072, 1024])
19 blocks.1.attn.qkv.bias torch.Size([3072])
20 blocks.1.attn.proj.weight torch.Size([1024, 1024])
21 blocks.1.attn.proj.bias torc

In [14]:
params = [val.cpu().numpy() for _, val in model.state_dict().items()]


[array([[[0.00709271, 0.00508971, 0.03333836, ..., 0.00427388,
          0.00245943, 0.00645396]]], dtype=float32),
 array([[[ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ],
         [-0.08874889, -0.08668015, -0.0845819 , ...,  1.        ,
           1.        ,  1.        ],
         [ 0.18826124,  0.18217435,  0.17620367, ...,  1.        ,
           1.        ,  1.        ],
         ...,
         [-0.27093548, -0.6430575 , -0.891936  , ...,  0.999999  ,
           0.99999905,  0.9999991 ],
         [ 0.21969765, -0.20620184, -0.58193254, ...,  0.999999  ,
           0.9999991 ,  0.99999917],
         [ 0.52107316,  0.05971598, -0.39567512, ...,  0.999999  ,
           0.99999905,  0.9999991 ]]], dtype=float32),
 array([[[-1.66367907e-02, -3.89367044e-02,  1.14553720e-02,
           1.10435886e-02, -1.39265591e-02,  5.09065995e-03,
          -1.64080150e-02, -4.33897693e-03, -1.35082752e-02,
          -2.06270032e-02, -9.86519456e-02, -

In [7]:
for i, x in enumerate(model.state_dict()):
    print(i, x)

0 cls_token
1 pos_embed
2 patch_embed.proj.weight
3 patch_embed.proj.bias
4 blocks.0.norm1.weight
5 blocks.0.norm1.bias
6 blocks.0.attn.qkv.weight
7 blocks.0.attn.qkv.bias
8 blocks.0.attn.proj.weight
9 blocks.0.attn.proj.bias
10 blocks.0.norm2.weight
11 blocks.0.norm2.bias
12 blocks.0.mlp.fc1.weight
13 blocks.0.mlp.fc1.bias
14 blocks.0.mlp.fc2.weight
15 blocks.0.mlp.fc2.bias
16 blocks.1.norm1.weight
17 blocks.1.norm1.bias
18 blocks.1.attn.qkv.weight
19 blocks.1.attn.qkv.bias
20 blocks.1.attn.proj.weight
21 blocks.1.attn.proj.bias
22 blocks.1.norm2.weight
23 blocks.1.norm2.bias
24 blocks.1.mlp.fc1.weight
25 blocks.1.mlp.fc1.bias
26 blocks.1.mlp.fc2.weight
27 blocks.1.mlp.fc2.bias
28 blocks.2.norm1.weight
29 blocks.2.norm1.bias
30 blocks.2.attn.qkv.weight
31 blocks.2.attn.qkv.bias
32 blocks.2.attn.proj.weight
33 blocks.2.attn.proj.bias
34 blocks.2.norm2.weight
35 blocks.2.norm2.bias
36 blocks.2.mlp.fc1.weight
37 blocks.2.mlp.fc1.bias
38 blocks.2.mlp.fc2.weight
39 blocks.2.mlp.fc2.bias
40

In [23]:
import random
indices = random.sample(range(296), 32)

In [26]:
for i, val in enumerate(model.state_dict().items()):

    print(i, val)

0 ('cls_token', tensor([[[0.0071, 0.0051, 0.0333,  ..., 0.0043, 0.0025, 0.0065]]]))
1 ('pos_embed', tensor([[[-1.0481e-06, -1.1606e-06, -5.6500e-07,  ..., -8.9698e-07,
          -4.9777e-07,  7.4613e-07],
         [-8.8749e-02, -8.6680e-02, -8.4582e-02,  ...,  1.0000e+00,
           1.0000e+00,  1.0000e+00],
         [ 1.8826e-01,  1.8217e-01,  1.7620e-01,  ...,  1.0000e+00,
           1.0000e+00,  1.0000e+00],
         ...,
         [-2.7094e-01, -6.4306e-01, -8.9194e-01,  ...,  1.0000e+00,
           1.0000e+00,  1.0000e+00],
         [ 2.1970e-01, -2.0620e-01, -5.8193e-01,  ...,  1.0000e+00,
           1.0000e+00,  1.0000e+00],
         [ 5.2107e-01,  5.9716e-02, -3.9568e-01,  ...,  1.0000e+00,
           1.0000e+00,  1.0000e+00]]]))
2 ('patch_embed.proj.weight', tensor([[[[-1.2844e-02, -1.7784e-03, -1.4400e-02,  ..., -2.7521e-03,
            6.2393e-03,  3.7619e-03],
          [ 4.3625e-03, -1.9412e-03, -1.0231e-02,  ...,  2.1845e-03,
            1.4907e-03, -3.0791e-03],
         

In [None]:
# show attention map
dir = '/data/home/liuchunyu/code/UWFound/labeled_data/uwf_dr_labeled_senior/train/class_3/'
import pathlib
j = 0
for file in pathlib.Path(dir).iterdir():
        model.zero_grad()
        
        img_path = os.path.join(dir, file.name)
        img0, attn = visualize_predict(model, img_path, patch_size=16, device=device, attn_layer=-1)
        
        j += 1
        if j > 2:
            break

        

In [None]:
# prepare CAM
def reshape_transform(tensor, height=28, width=28):
    result = tensor[:, 1:, :].reshape(tensor.size(0),
                                      height, width, tensor.size(2))
    # Bring the channels to the first dimension,
    # like in CNNs.
    result = result.transpose(2, 3).transpose(1, 2)
    return result

# 创建 GradCAM 对象
cam = GradCAM(model=model,
                target_layers=[model.blocks[-1].norm1],
              #target_layers=[model.blocks[18].norm1],
                # 这里的target_layer要看模型情况，
                # 比如还有可能是：target_layers = [model.blocks[-1].ffn.norm]
                use_cuda=False,
                reshape_transform=reshape_transform
            )

In [None]:
#img_path = '/data/home/liuchunyu/code/UWFound/labeled_data/uwf_dr_labeled_junior_japan/test/class_1/004442_01.jpg'
img = circle_crop(img_path, img_size=IMG_SIZE)
img = np.array(img) / 255.
img0 = np.array(Image.open(img_path).resize((IMG_SIZE, IMG_SIZE))) / 255.
img = img - imagenet_mean
img = img / imagenet_std
x = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device).float()

# 计算 grad-cam

target_category = 1 # 可以指定一个类别，或者使用 None 表示最高概率的类别
grayscale_cam = cam(input_tensor=x)
grayscale_cam = grayscale_cam[0, :]



In [None]:
# 将 grad-cam 的输出叠加到原始图像上
visualization = show_cam_on_image(img0, grayscale_cam, image_weight=0.6, colormap=cv2.COLORMAP_INFERNO)

# 保存可视化结果
plt.title('Class Activation Map')
plt.imshow(visualization)