In [2]:
from glob import glob
import os
from natsort import os_sorted

In [3]:
# Test split of the balanced Camelyon16 patch dataset, one sub-folder per class.
dataset_dir = 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/'
class_name = 'tumor'  # set to 'normal' to visualise the other class

# "*.jpg" (with the dot) so only real .jpg files match; natsort's os_sorted
# gives natural, file-browser-style ordering of the numbered patch names.
img_paths = os_sorted(glob(os.path.join(dataset_dir, class_name, '*.jpg')))

# Fail fast: an empty list would otherwise surface much later as a confusing
# random.sample / IndexError inside the heatmap cells.
assert img_paths, f"no .jpg images found under {os.path.join(dataset_dir, class_name)}"

In [4]:
from dino.vision_transformer import DINOHead, VisionTransformer
from dino.vim.models_mamba import VisionMamba
from dino.config import configurations
from dino.main import get_args_parser
from functools import partial
from dino.utils import load_pretrained_weights
from torchvision import transforms
from torch import nn
import torch


In [5]:
# Reuse DINO's training parser purely for its defaults. parse_known_args returns
# (namespace, leftover_argv), so extra argv entries the notebook kernel was
# started with are set aside instead of raising an error.
parser = get_args_parser()
args, _unknown_argv = parser.parse_known_args()

In [6]:
def get_model(args):
    """Construct a ViT or Vision-Mamba backbone from ``dino.config.configurations``.

    Args:
        args: namespace providing ``arch`` (a key of ``configurations``, e.g.
            ``'vit-s'`` / ``'vim-s'``), ``image_size``, ``patch_size`` and
            ``num_classes``; for ``vit*`` archs also ``n_last_blocks`` and
            ``avgpool_patchtokens`` (used only for the printed embedding dim).

    Returns:
        The freshly constructed model. Pretrained weights are loaded
        separately by the caller.

    Raises:
        ValueError: if ``args.arch`` is not a key of ``configurations`` or does
            not start with ``'vim'`` or ``'vit'``.
    """
    # Validate first: the old code indexed the dict before this check, so an
    # unknown arch raised a bare KeyError and the error branch was unreachable.
    if args.arch not in configurations:
        raise ValueError(f"Unknown architecture: {args.arch}")

    # Shallow copy so per-call overrides do not leak into the shared,
    # module-level configuration dict.
    config = dict(configurations[args.arch])
    config['img_size'] = args.image_size
    config['patch_size'] = args.patch_size
    config['num_classes'] = args.num_classes
    config['drop_path_rate'] = 0  # disable drop-path

    # The config names the norm layer as a string; build the actual callable.
    if config.get('norm_layer') == "nn.LayerNorm":
        config['norm_layer'] = partial(nn.LayerNorm, eps=config['eps'])

    if args.arch.startswith('vim'):
        model = VisionMamba(return_features=True, **config)
        embed_dim = model.embed_dim
    elif args.arch.startswith('vit'):
        model = VisionTransformer(**config)
        # Presumably mirrors DINO's linear-eval feature width (concatenated
        # last-N CLS tokens, optionally plus pooled patch tokens).
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    else:
        # Previously fell through to an UnboundLocalError on `model`.
        raise ValueError(f"Unsupported architecture family: {args.arch}")

    print('EMBEDDED DIM:', embed_dim)
    return model

In [7]:
# Inference settings shared by all three backbones (read by get_model and
# by the preprocessing / heatmap cells further down).
args.image_size = 224
args.patch_size = 16
args.num_classes = 2
args.n_last_blocks = 4
args.avgpool_patchtokens = False
args.checkpoint_key = 'teacher'

# NOTE(review): machine-specific absolute path; consider making it configurable.
CHECKPOINT_DIR = '/home/ubuntu/checkpoints'


def load_backbone(arch, weights_path):
    """Build ``arch`` with ``get_model``, put it on the GPU in eval mode, and load weights.

    ``get_model`` reads ``args.arch``, so this also records ``arch`` and
    ``weights_path`` on the global ``args`` (the last call's values persist).

    NOTE(review): the load log shows ``head.weight``/``head.bias`` are skipped,
    so the classification head keeps its initial (untrained) values.
    """
    args.arch = arch
    args.pretrained_weights = weights_path
    model = get_model(args)
    model.cuda()
    model.eval()
    load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    return model


# Camelyon16 checkpoints (224 px / 10x, judging by the directory names).
model_vim = load_backbone('vim-s', os.path.join(CHECKPOINT_DIR, 'camelyon16_224_10x/vim-s_224-96/checkpoint.pth'))
model_vit = load_backbone('vit-s', os.path.join(CHECKPOINT_DIR, 'camelyon16_224_10x/vit-s_224-96/checkpoint.pth'))
# ImageNet reference: DINO ViT-S/16 weights (per the file name).
model_vit2 = load_backbone('vit-s', os.path.join(CHECKPOINT_DIR, 'imagenet/dino_deitsmall16_pretrain.pth'))


EMBEDDED DIM: 384
Take key teacher in provided checkpoint dict
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.
Pretrained weights found at /home/ubuntu/checkpoints/camelyon16_224_10x/vim-s_224-96/checkpoint.pth and loaded with msg: _IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
EMBEDDED DIM: 1536
Take key teacher in provided checkpoint dict
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.
Pretrained weights found at /home/ubuntu/checkpoints/camelyon16_224_10x/vit-s_224-96/chec

In [8]:
# Standard ImageNet channel statistics used for normalisation.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Deterministic evaluation preprocessing: resize the shorter side to
# image_size (bicubic), center-crop to a square, convert, normalise.
val_transform = transforms.Compose([
    # BICUBIC is what the legacy integer code 3 mapped to; the enum avoids
    # torchvision's deprecated int-interpolation path.
    transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(args.image_size),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [9]:
from PIL import Image
import torchvision
import matplotlib.pyplot as plt
import numpy as np

In [10]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


In [11]:
def reshape_transform_vit(tensor, height=14, width=14):
    """Lay ViT token activations out as a CNN-style feature map for Grad-CAM.

    The first token (the class token in a standard ViT) is dropped; the
    remaining ``height * width`` tokens are arranged row-major on a grid.

    Args:
        tensor: ``(batch, 1 + height * width, channels)`` activations.
        height, width: patch-grid size; 14x14 matches 224 px input with 16 px patches.

    Returns:
        ``(batch, channels, height, width)`` view of the patch tokens.
    """
    batch = tensor.size(0)
    channels = tensor.size(2)
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)

In [12]:
def reshape_transform_vim(tensor, height=14, width=14, token_position=98):
    """Lay Vision-Mamba activations out as a CNN-style feature map for Grad-CAM.

    Removes the token at index 0 and the token at ``token_position``
    (presumably the middle class token; 98 = 196 // 2 for a 14x14 grid, TODO
    confirm against the model), then arranges the rest on a ``height x width`` grid.

    NOTE(review): this only reshapes cleanly when the sequence holds
    ``height * width + 2`` tokens; verify what index 0 is for this Vim variant
    (if it is a patch token, one patch is being discarded).

    Returns:
        ``(batch, channels, height, width)`` tensor.
    """
    leading = tensor[:, 1:token_position, :]
    trailing = tensor[:, token_position + 1:, :]
    patch_tokens = torch.cat((leading, trailing), dim=1)
    grid = patch_tokens.reshape(patch_tokens.size(0), height, width, patch_tokens.size(2))
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)

In [13]:
# Output folder for the Grad-CAM figures saved below; reusing an existing folder is fine.
os.makedirs(name='heatmaps', exist_ok=True)

In [22]:
from tqdm import tqdm
import random

['dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_090_742.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_026_123.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_090_635.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_021_1345.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_105_2053.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_021_1228.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_105_6.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_105_628.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_105_1446.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_021_694.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_090_361.jpg',
 'dataset/Camelyon16/Extracted_labels/cam16_balanced/test/tumor/test_105_2162.jpg',
 

In [27]:
# How many random test patches to render; the fixed seed makes the subset
# reproducible across runs.
N_HEATMAPS = 600
SAMPLE_SEED = 0

# Inverse of val_transform's Normalize, used to recover display colours.
_DISPLAY_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_DISPLAY_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Iterate over the sample itself: the old loop used only len(sample) and so
# rendered the first 600 sorted paths. min() avoids a ValueError on small sets.
sampled_paths = random.Random(SAMPLE_SEED).sample(img_paths, min(N_HEATMAPS, len(img_paths)))

# One Grad-CAM per backbone, built once (GradCAM attaches hooks to the target
# layers when constructed). List order = subplot positions 2, 3, 4.
# No `targets` are passed, so the library explains the argmax of each model's
# output. NOTE(review): the load log shows `head.*` was not loaded from the
# checkpoints, so that argmax comes from an untrained head — confirm intent.
cams = [
    ('Grad-CAM-ViT', GradCAM(model=model_vit, target_layers=[model_vit.blocks[-1].norm1],
                             reshape_transform=reshape_transform_vit)),
    ('Grad-CAM-Vim', GradCAM(model=model_vim, target_layers=[model_vim.layers[-1].drop_path],
                             reshape_transform=reshape_transform_vim)),
    ('Grad-CAM-ViT-ImageNet', GradCAM(model=model_vit2, target_layers=[model_vit2.blocks[-1].norm1],
                                      reshape_transform=reshape_transform_vit)),
]

try:
    for idx, path in enumerate(tqdm(sampled_paths)):
        with Image.open(path) as pil_img:
            img_transformed = val_transform(pil_img.convert('RGB')).unsqueeze(0)

        # Undo Normalize per channel so the panel shows true colours (a global
        # min-max rescale distorts them); clip absorbs float round-off.
        img_show = img_transformed[0].permute(1, 2, 0).cpu().numpy()
        img_show = np.clip(img_show * _DISPLAY_STD + _DISPLAY_MEAN, 0, 1).astype(np.float32)

        fig, axes = plt.subplots(2, 2, figsize=(10, 10))
        axes = axes.ravel()
        axes[0].imshow(img_show)
        axes[0].set_title('Original Image')
        for ax, (title, cam) in zip(axes[1:], cams):
            grayscale_cam = cam(input_tensor=img_transformed)[0, :]
            ax.imshow(show_cam_on_image(img_show, grayscale_cam, use_rgb=True))
            ax.set_title(title)
        for ax in axes:
            ax.axis('off')

        # Keep the source patch name in the file name so figures are traceable.
        stem = os.path.splitext(os.path.basename(path))[0]
        fig.savefig(f'heatmaps/{idx}_{stem}.jpg', bbox_inches='tight')
        plt.close(fig)
finally:
    # Detach the Grad-CAM hooks so later cells run on unhooked models.
    for _, cam in cams:
        cam.activations_and_grads.release()

100%|█████████████████████████████████████████| 600/600 [01:48<00:00,  5.51it/s]


In [None]:
def vit_compute_heatmap(model, img, device, threshold=None, patch_size=None, return_th_attn=False):
    """Per-head class-token -> patch attention of the last ViT block, upsampled to pixels.

    Args:
        model: ViT exposing ``get_last_selfattention(x)``; assumed to return
            ``(batch, heads, tokens, tokens)`` with the class token at index 0 — TODO confirm.
        img: ``(batch, 3, H, W)`` normalised images; only sample 0 is used for the maps.
        device: device the model lives on; ``img`` is copied there for the forward pass.
        threshold: optional fraction in (0, 1]. When given, also build binary
            masks keeping each head's highest-attention patches that together
            hold about ``threshold`` of its attention mass.
        patch_size: ViT patch size; defaults to the global ``args.patch_size``.
        return_th_attn: if True, additionally return the threshold masks
            (``None`` when ``threshold`` is None).

    Returns:
        ``(grid, attentions)`` where ``grid`` is ``img`` rescaled by
        ``torchvision.utils.make_grid(normalize=True, scale_each=True)`` and
        ``attentions`` is a numpy array ``(heads, H', W', 3)`` min-max scaled to
        [0, 1], with ``H' = (H // patch_size) * patch_size`` (same for W').
        With ``return_th_attn=True`` a third element, the ``(heads, H', W')``
        mask array (or None), is appended.
    """
    if patch_size is None:
        patch_size = args.patch_size
    # img is (..., H, W): rows first, then columns.
    h_featmap = img.shape[-2] // patch_size
    w_featmap = img.shape[-1] // patch_size

    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of heads

    # Keep only sample 0's class-token attention to the patch tokens.
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    th_attn = None
    if threshold is not None:
        # Sort ascending, normalise to a distribution, and mark the tail whose
        # cumulative mass exceeds 1 - threshold (i.e. the top patches).
        val, idx = torch.sort(attentions)
        val /= torch.sum(val, dim=1, keepdim=True)
        cumval = torch.cumsum(val, dim=1)
        th_attn = cumval > (1 - threshold)
        # Undo the sort so the mask lines up with the original patch order.
        idx2 = torch.argsort(idx)
        for head in range(nh):
            th_attn[head] = th_attn[head][idx2[head]]
        th_attn = th_attn.reshape(nh, h_featmap, w_featmap).float()
        th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size,
                                            mode="nearest")[0].cpu().detach().numpy()

    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size,
                                           mode="nearest")[0].cpu().detach().numpy()
    # Replicate to 3 channels so the maps can be concatenated with an RGB image.
    attentions = np.repeat(attentions[:, :, :, np.newaxis], 3, -1)

    # True min-max scaling to [0, 1] (the old code divided by max, not by
    # max - min); a constant map becomes all zeros instead of dividing by 0.
    attn_min = attentions.min()
    span = attentions.max() - attn_min
    if span > 0:
        attentions = (attentions - attn_min) / span
    else:
        attentions = np.zeros_like(attentions)

    grid = torchvision.utils.make_grid(img, normalize=True, scale_each=True)
    if return_th_attn:
        return grid, attentions, th_attn
    return grid, attentions

# Attention maps of the Camelyon16 ViT for one test patch, shown beside the image.
SAMPLE_INDEX = 40

with Image.open(img_paths[SAMPLE_INDEX]) as pil_img:
    img_batch = val_transform(pil_img.convert('RGB')).unsqueeze(0)

# Use whatever device the model actually lives on instead of hard-coding 'cuda'.
device = next(model_vit.parameters()).device
img_grid, attentions = vit_compute_heatmap(model_vit, img_batch, device)
img = img_grid.permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C)

# Left: the image; right: per-pixel maximum over heads of the scaled attention.
plt.figure(figsize=(30, 30))
plt.imshow(np.concatenate([img, attentions.max(0)], axis=1))
plt.axis('off')
plt.show()