### Dependencies

In [23]:
'''
Visualization tools are adapted from https://github.com/facebookresearch/dino.
'''

# Base Dependencies
import argparse
import colorsys
import os
import random
import sys
import requests
from io import BytesIO

# LinAlg / Stats / Plotting Dependencies
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import skimage.io
from skimage.measure import find_contours
from tqdm import tqdm

# Torch Dependencies
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

# Utils
import nn_encoder_arch.vision_transformer as vits
from attention_visualization_utils import create_256x256_map_concat

### Loading Pretrained ViT-S/16

In [24]:
# Model / checkpoint configuration for this notebook.
arch = 'vit_small'
patch_size = 16
pretrained_weights = './ckpts/vits_tcga_brca_dino.pt'
# Sub-dictionary of the checkpoint to load when that key is present.
checkpoint_key = 'teacher'
# NOTE(review): `threshold` is not referenced again in the visible cells; the
# file-name patterns further down hard-code the text 'th0.5'.
threshold = 0.5

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Build model
model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
# Inference only: freeze every parameter and switch to eval mode.
for p in model.parameters():
    p.requires_grad = False
model.eval()
model.to(device)

if os.path.isfile(pretrained_weights):
    # NOTE: torch.load unpickles the file; only point this at checkpoints you trust.
    state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        print(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix (str.replace drops every occurrence of the text)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates keys that do not line up between checkpoint and
    # model; the returned message lists any missing / unexpected keys.
    msg = model.load_state_dict(state_dict, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
else:
    # Previously this case was silent, so the long-running cells below would
    # happily produce attention maps from a model that never saw the checkpoint.
    print('WARNING: no checkpoint file at {}; the model keeps the weights its '
          'constructor produced, so the attention maps will not reflect the '
          'pretrained network.'.format(pretrained_weights))

Take key teacher in provided checkpoint dict
Pretrained weights found at ./ckpts/vits_tcga_brca_dino.pt and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


In [25]:
# Image geometry used by the tiling / joining cells below.
IMG_SIZE_DINO = 256  # image size of dino model (do not change)
IMG_SIZE_MINE = 1024  # image size of your images (change)

# Head indices 0..5: passed as `which_concat` and used in output file names.
heads = [*range(6)]

In [26]:
import glob

# Every *.png tile in the training folder. glob.glob returns matches in
# arbitrary order, so sort them to make each run (and any partial run)
# walk the tiles in the same sequence.
images = sorted(glob.glob('/gpfs/space/projects/PerkinElmer/testis/tiles/train2017/*.png'))

In [5]:
# np.array(Image.open('attention_visualization_results/image_4k/image_0_0.png_mask_th.png')).shape

In [6]:
# np.array(Image.open('attention_visualization_results/image_4k/image_0_0.png_mask_th0.5_head0.png')).shape

In [7]:
# np.array(Image.open('attention_visualization_results/oneshot/18H14294I_lvl0_'
#                     '103424_208896/18H14294I_lvl0_103424_208896.png_mask_th0.5_head0.png')).shape

In [8]:
# np.array(Image.open('attention_visualization_results/oneshot/18H14294I_lvl0_'
#                     '103424_208896/18H14294I_lvl0_103424_208896.png_mask_th0.5_head0_annotated.png')).shape

In [9]:
# np.array(Image.open('attention_visualization_results/oneshot/18H14294I_lvl0_'
#                     '103424_208896/18H14294I_lvl0_103424_208896.png_mask_th.png')).shape

In [10]:
# np.array(Image.open('attention_visualization_results/'
#            'joined_patches/18H14294II_lvl0_70656_167936/18H14294II_lvl0_70656_167936_head0.png')).shape

In [11]:
# np.array(Image.open('attention_visualization_results/'
#            'joined_patches/18H14294II_lvl0_70656_167936/18H14294II_lvl0_70656_167936_head0_annotated.png')).shape

In [12]:
# np.array(Image.open('attention_visualization_results/'
#            'per_patch/18H14294I_lvl0_103424_208896/image_0_0.png_mask_th0.5_head0.png')).shape

In [13]:
# np.array(Image.open('attention_visualization_results/'
#            'per_patch/18H14294I_lvl0_103424_208896/image_256_0.png_mask_th.png')).shape

In [5]:
# run with the original image size
# save to attention_visualization_results/oneshot

for img_path in tqdm(images):
    img_fname = os.path.basename(img_path)
    # `images` only holds '*.png' matches, so [:-4] drops exactly the extension.
    output_dir = os.path.join('attention_visualization_results', 'oneshot', img_fname[:-4])
    os.makedirs(output_dir, exist_ok=True)
    # Open inside a context manager so the file handle is released after each
    # image instead of lingering until garbage collection.
    with Image.open(img_path) as img:
        create_256x256_map_concat(model, img, img_fname, output_dir,
                                  image_size=(IMG_SIZE_MINE, IMG_SIZE_MINE), display=False,
                                  which_concat=heads)

100%|██████████| 257/257 [1:15:15<00:00, 17.57s/it]


In [31]:
# run with patches
# save to attention_visualization_results/per_patch

# The tile grid below only covers the image exactly when the sizes divide
# evenly. This does not depend on the image, so check it once up front, with
# integer arithmetic rather than a float division.
assert IMG_SIZE_MINE % IMG_SIZE_DINO == 0

for img_path in tqdm(images):
    # Read the pixels and release the file handle straight away.
    with Image.open(img_path) as img_file:
        img = np.array(img_file)

    img_fname = os.path.basename(img_path)
    # `images` only holds '*.png' matches, so [:-4] drops exactly the extension.
    output_dir = os.path.join('attention_visualization_results', 'per_patch', img_fname[:-4])
    os.makedirs(output_dir, exist_ok=True)

    # (i, j) is the top-left corner (row, column) of each tile; it is encoded
    # in the tile's file name so the tiles can be re-assembled later.
    for i in range(0, IMG_SIZE_MINE, IMG_SIZE_DINO):
        for j in range(0, IMG_SIZE_MINE, IMG_SIZE_DINO):
            patch = Image.fromarray(img[i:(i+IMG_SIZE_DINO),j:(j+IMG_SIZE_DINO),:])
            subimg_fname = 'image_%d_%d.png' % (i,j)
            create_256x256_map_concat(model, patch, subimg_fname, output_dir,
                                      image_size=(IMG_SIZE_DINO,IMG_SIZE_DINO), display=False,
                                      which_concat=heads)

100%|██████████| 58/58 [29:32<00:00, 30.55s/it]


In [33]:
# join patches into one image
# save to attention_visualization_results/joined_patches


def join_patches(files, full_size=None, tile_size=None):
    """Paste per-tile 4-channel images into one square canvas.

    The position of each tile is parsed from its file name: the part of the
    base name before the first '.' must split on '_' into exactly three
    fields, ``<anything>_<row>_<col>``, where row/col are the integer pixel
    offsets of the tile's top-left corner (e.g. ``image_256_0.png...``).

    Args:
        files: iterable of paths to tile images.
        full_size: side length of the output canvas; defaults to the
            module-level ``IMG_SIZE_MINE``.
        tile_size: side length every tile must have; defaults to the
            module-level ``IMG_SIZE_DINO``.

    Returns:
        Integer array of shape ``(full_size, full_size, 4)``. Regions for
        which no tile was supplied are zero (previously they held whatever
        uninitialised memory ``np.empty`` handed back).

    Raises:
        AssertionError: if a tile does not have shape
            ``(tile_size, tile_size, 4)``.
        ValueError: if a file name does not follow the pattern above.
    """
    if full_size is None:
        full_size = IMG_SIZE_MINE
    if tile_size is None:
        tile_size = IMG_SIZE_DINO

    # Zero-initialised so that missing tiles give a deterministic result.
    img_full = np.zeros((full_size, full_size, 4), dtype='int')

    for img_path in files:
        # Load the pixels and close the file handle immediately.
        with Image.open(img_path) as tile_file:
            img = np.array(tile_file)
        assert img.shape == (tile_size, tile_size, 4), str(img.shape)

        img_name = os.path.basename(img_path)
        _, i, j = img_name.split('.')[0].split('_')

        i, j = int(i), int(j)
        img_full[i:i+tile_size, j:j+tile_size] = img

    return img_full


for img_path in tqdm(images):
    img_fname = os.path.basename(img_path)
    # `images` only holds '*.png' matches, so [:-4] drops exactly the extension.
    img_stem = img_fname[:-4]

    # Both directories depend only on the image, not on the head.
    patches_dir = os.path.join('attention_visualization_results', 'per_patch', img_stem)
    full_img_path = os.path.join('attention_visualization_results', 'joined_patches', img_stem)

    for head in heads:
        # Escape the directory part so only the leading '*' acts as a wildcard.
        patches_pat = os.path.join(glob.escape(patches_dir), f'*th0.5_head{head}.png')
        patch_files = glob.glob(patches_pat)
        if not patch_files:
            # Nothing was produced for this image/head (e.g. the per-patch
            # cell has not processed it). Joining an empty list would only
            # save a meaningless canvas, so skip it and say so.
            print(f'No patch files match {patches_pat}; skipping')
            continue

        full_img = join_patches(patch_files)

        os.makedirs(full_img_path, exist_ok=True)

        pil_img = Image.fromarray(np.uint8(full_img))
        pil_img.save(os.path.join(full_img_path, img_stem+f'_head{head}.png'))

100%|██████████| 257/257 [38:22<00:00,  8.96s/it]


In [34]:
def get_image_id(file_name):
    """Look up an image id by file name in the global ``coco`` dict.

    Scans ``coco['images']`` in order and returns the ``'id'`` of the first
    entry whose ``'file_name'`` equals `file_name`; returns ``None`` when no
    entry matches.
    """
    matching_ids = (
        entry['id'] for entry in coco['images'] if entry['file_name'] == file_name
    )
    return next(matching_ids, None)
        
def get_img_annotations(img_id):
    """Collect the annotations of one image from the global ``coco`` dict.

    Returns a new list holding, in their original order, every entry of
    ``coco['annotations']`` whose ``'image_id'`` equals `img_id`. The list is
    empty when nothing matches.
    """
    return [entry for entry in coco['annotations'] if entry['image_id'] == img_id]

import json

# Annotation JSON for the training tiles. The helpers in this notebook read
# its 'images' and 'annotations' lists through the global `coco`.
ANNOTATION_FILE = '/gpfs/space/projects/PerkinElmer/testis/tiles/annotations/instances_train2017.json'
with open(ANNOTATION_FILE, 'r') as annotations_fh:
    coco = json.loads(annotations_fh.read())

In [35]:
# draw bboxes over produced images (original image size, at attention_visualization_results/oneshot)

from PIL import ImageDraw

for dir_path in tqdm(glob.glob('attention_visualization_results/oneshot/*')):
    dir_name = os.path.basename(dir_path)

    # The annotations depend only on the source image (named after the
    # directory), not on the head, so do the two linear lookups once here
    # instead of once per head.
    anns = get_img_annotations(get_image_id(dir_name+'.png'))

    for head in heads:
        img_path = os.path.join(dir_path, dir_name+f'.png_mask_th0.5_head{head}.png')

        # Context manager releases the file handle after each image.
        with Image.open(img_path) as img:
            img_shape = np.array(img).shape
            assert img_shape == (IMG_SIZE_MINE, IMG_SIZE_MINE, 4), str(img_shape)

            # One drawing context per image is enough for all of its boxes.
            img_draw = ImageDraw.Draw(img)
            for ann in anns:
                # bbox is read as (x, y, width, height); rectangle() wants
                # two opposite corners.
                x, y, w, h = ann['bbox']
                img_draw.rectangle((x, y, x+w, y+h), width=2, outline='green')

            img.save(os.path.join('attention_visualization_results', 'oneshot',
                                  dir_name, dir_name+f'.png_mask_th0.5_head{head}_annotated.png'))

100%|██████████| 257/257 [23:25<00:00,  5.47s/it]


In [36]:
# draw bboxes over produced images (joined patches, at attention_visualization_results/joined_patches)

from PIL import ImageDraw

for dir_path in tqdm(glob.glob('attention_visualization_results/joined_patches/*')):
    dir_name = os.path.basename(dir_path)

    # The annotations depend only on the source image (named after the
    # directory), not on the head, so do the two linear lookups once here
    # instead of once per head.
    anns = get_img_annotations(get_image_id(dir_name+'.png'))

    for head in heads:
        img_path = os.path.join(dir_path, dir_name+f'_head{head}.png')

        # Context manager releases the file handle after each image.
        with Image.open(img_path) as img:
            # One drawing context per image is enough for all of its boxes.
            img_draw = ImageDraw.Draw(img)
            for ann in anns:
                # bbox is read as (x, y, width, height); rectangle() wants
                # two opposite corners.
                x, y, w, h = ann['bbox']
                img_draw.rectangle((x, y, x+w, y+h), width=8, outline='green')

            img.save(os.path.join('attention_visualization_results', 'joined_patches',
                                  dir_name, dir_name+f'_head{head}_annotated.png'))

100%|██████████| 257/257 [28:14<00:00,  6.59s/it]


In [38]:
# Shell escape: print the total on-disk size of the results directory.
!du -sh attention_visualization_results/

15G	attention_visualization_results/


In [None]:
# Shell escape: list the top-level entries of the results directory.
!ls attention_visualization_results