# BEiT self-attention visualization

In [1]:
# Copyright (c) Facebook, Inc. and its affiliates.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
# import cv2
import random
import colorsys
import requests
from io import BytesIO

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon, Circle
from PIL import Image

import skimage.io
from skimage.measure import find_contours
from scipy import interpolate

import torch
import torch.nn as nn
import torchvision.transforms.functional as F
import torchvision
from torchvision import transforms as pth_transforms

from timm.models import create_model

import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

import utils

import modeling_finetune

In [2]:
def apply_mask(image, mask, color, alpha=0.5):
    """Alpha-blend ``color`` into the first three channels of ``image``.

    The blend is written back into ``image`` (in place) and the same object
    is returned.

    Args:
        image: array indexed as ``image[:, :, c]`` for ``c`` in 0..2.
        mask: per-pixel blend weight, combined element-wise with one channel.
        color: indexable with three entries; each entry is multiplied by 255
            (so presumably given in the 0..1 range).
        alpha: global opacity multiplied into ``mask``.

    Returns:
        ``image`` itself, after modification.
    """
    for channel in range(3):
        weight = alpha * mask
        plane = image[:, :, channel]
        image[:, :, channel] = plane * (1 - weight) + weight * color[channel] * 255
    return image


def random_colors(N, bright=True):
    """Return ``N`` RGB tuples with evenly spaced hues, in shuffled order.

    Colors are built in HSV space with full saturation; the value component
    is 1.0 when ``bright`` is true and 0.7 otherwise. The order is randomised
    with the global ``random`` module state.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette


def display_instances(image, mask, figsize=(5, 5), blur=False, contour=True, alpha=0.5):
    """Overlay one mask on ``image`` and show the result in a borderless figure.

    Args:
        image: array whose first two dimensions are (height, width); it is
            copied and converted to ``uint32`` before blending, so the caller's
            array is left untouched.
        mask: 2-D mask; a leading axis is added internally so it is treated
            as a stack containing a single mask.
        figsize: figure size forwarded to ``plt.figure``.
        blur: when true, the mask is smoothed with a 10x10 ``cv2.blur`` before
            blending. This is the only code path that needs OpenCV.
        contour: when true, the 0.5 iso-contours of the mask are outlined.
        alpha: opacity forwarded to ``apply_mask``.
    """
    fig = plt.figure(figsize=figsize, frameon=False)
    # A single axes covering the whole figure, with no frame or ticks.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax = plt.gca()

    N = 1
    mask = mask[None, :, :]
    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    margin = 0
    ax.set_ylim(height + margin, -margin)
    ax.set_xlim(-margin, width + margin)
    ax.axis('off')
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]
        _mask = mask[i]
        if blur:
            # The module-level `import cv2` is commented out in this notebook,
            # so the original code raised NameError here. Import lazily so
            # OpenCV is only required when blurring is actually requested.
            import cv2
            _mask = cv2.blur(_mask, (10, 10))
        # Mask
        masked_image = apply_mask(masked_image, _mask, color, alpha)
        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        if contour:
            padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = _mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8), aspect='auto')
    fig.show()


In [64]:
def show(imgs, marker=None):
    """Plot one or more images in a gap-free grid with at most five columns.

    Args:
        imgs: a single image or a list of images. Each one is converted with
            ``F.to_pil_image`` before being drawn.
        marker: optional ``(row, col)`` pixel position. It is reversed to
            ``(x, y)`` and drawn as a red dot on every panel.

    Raises:
        ValueError: if ``imgs`` is an empty list.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    if not imgs:
        raise ValueError("show() needs at least one image")

    # Create subplot grid. Ceiling division gives exactly as many rows as
    # needed; the previous `len // n_cols + 1` added an extra row whenever
    # the image count was a multiple of the column count (including a single
    # image, which then crashed on axes indexing).
    n_cols = min(len(imgs), 5)
    n_rows = (len(imgs) + n_cols - 1) // n_cols

    # Fill the remaining grid slots with copies of the first image so no
    # panel is left empty: copy i is inserted at index (i + 1) * n_cols, i.e.
    # at the start of the following rows, or appended at the end once that
    # index is past the current list length. A new list is built each time,
    # so the caller's list is not modified.
    missing_imgs = n_rows * n_cols - len(imgs)
    for i in range(missing_imgs):
        insertion_index = (i + 1) * n_cols
        imgs = imgs[:insertion_index] + [imgs[0]] + imgs[insertion_index:]

    # squeeze=False keeps `axs` two-dimensional for every grid shape, so the
    # (row, col) indexing below is always valid.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, dpi=150)

    for i, img in enumerate(imgs):
        ax = axs[i // n_cols, i % n_cols]
        ax.imshow(np.asarray(F.to_pil_image(img)))
        if marker is not None:
            ax.add_patch(Circle(marker[::-1], radius=8, color='red'))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    # Adjust spacing on the figure that was just created (calling
    # plt.subplots_adjust before plt.subplots acted on a different figure).
    fig.subplots_adjust(wspace=0, hspace=0)
    plt.show()
# draw_attentions(attentions_normalized[:, :-5], (0,0), THRESHOLD)

In [4]:
def get_index_attentions(attentions, threshold, index=0):
    """Return per-head attention maps for one query token, upsampled to pixels.

    Relies on the notebook globals ``w_featmap``, ``h_featmap`` and
    ``PATCH_SIZE`` being defined before the call.

    Args:
        attentions: tensor indexed as ``[batch, head, query, key]``; only
            batch element 0 is used.
        threshold: when not ``None``, each head additionally gets a binary
            mask marking the largest attention values whose normalised
            cumulative mass exceeds ``1 - threshold``.
        index: query token whose attention row is extracted.

    Returns:
        ``(attentions, th_attn)``: NumPy arrays with one map per head, each
        enlarged by ``PATCH_SIZE`` with nearest-neighbour interpolation.
        ``th_attn`` is ``None`` when ``threshold`` is ``None``.
    """
    num_heads = attentions.shape[1]

    # Attention of the chosen query towards every key except key 0.
    attentions = attentions[0, :, index, 1:].reshape(num_heads, -1)

    th_attn = None
    if threshold is not None:
        # Sort ascending, normalise each head to sum to one, and keep the
        # tail of the cumulative distribution.
        sorted_vals, sort_order = torch.sort(attentions)
        sorted_vals /= torch.sum(sorted_vals, dim=1, keepdim=True)
        kept_sorted = torch.cumsum(sorted_vals, dim=1) > (1 - threshold)
        # Undo the sort so every flag lines up with its original key position.
        restore_order = torch.argsort(sort_order)
        th_attn = torch.gather(kept_sorted, 1, restore_order)
        th_attn = th_attn.reshape(num_heads, w_featmap, h_featmap).float()
        # interpolate
        th_attn = nn.functional.interpolate(
            th_attn.unsqueeze(0), scale_factor=PATCH_SIZE, mode="nearest"
        )[0].cpu().numpy()

    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=PATCH_SIZE, mode="nearest"
    )[0].cpu().numpy()
    return attentions, th_attn


In [5]:
def create_beit_model():
    """Build the BEiT model and, if configured, load checkpoint weights into it.

    Reads the notebook-level globals ``INPUT_SIZE`` (input resolution passed
    as ``img_size``), ``WEIGHT_PATH`` (checkpoint URL or local path; falsy to
    skip loading) and ``MODEL_KEY`` (``|``-separated candidate keys under
    which the state dict may be nested inside the checkpoint).

    Before loading, checkpoint tensors whose size differs from the freshly
    built model are adapted: a classification head of the wrong shape is
    dropped, a shared relative position bias table is copied to every block,
    relative position bias tables are re-sampled to the model's grid, and the
    absolute position embedding (only if the checkpoint has one) is
    bicubically resized.

    Returns:
        The constructed model. When ``WEIGHT_PATH`` is falsy the model is
        returned with its initial weights.

    Raises:
        NotImplementedError: if the model's patch grid is not square while a
            relative position bias table is being processed.
    """
    model = create_model(
        'beit_base_patch16_224',
        pretrained=False,
        num_classes=21841,
        drop_rate=0.0,
        drop_path_rate=0.1,
        attn_drop_rate=0.0,
        drop_block_rate=None,
        use_mean_pooling=False,
        init_scale=0.001,
        use_rel_pos_bias=True,
        use_shared_rel_pos_bias=False,
        use_abs_pos_emb=False,
        init_values=0.1,
        img_size=INPUT_SIZE
    )

    print("Patch size = %s" % str(model.patch_embed.patch_size))

    if WEIGHT_PATH:
        if WEIGHT_PATH.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                WEIGHT_PATH, map_location='cpu', check_hash=True)
        else:
            # NOTE: torch.load unpickles the file; only point WEIGHT_PATH at
            # checkpoints from a trusted source.
            checkpoint = torch.load(WEIGHT_PATH, map_location='cpu')

        print("Load ckpt from %s" % WEIGHT_PATH)
        # The state dict may be nested under one of the MODEL_KEY candidates;
        # otherwise the checkpoint itself is taken as the state dict.
        checkpoint_model = None
        for model_key in MODEL_KEY.split('|'):
            if model_key in checkpoint:
                checkpoint_model = checkpoint[model_key]
                print("Load state_dict by model_key = %s" % model_key)
                break
        if checkpoint_model is None:
            checkpoint_model = checkpoint

        # Drop the classification head when its shape does not match.
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
            print("Expand the shared relative position embedding to each transformer block. ")
            num_layers = model.get_num_layers()
            rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"]
            for i in range(num_layers):
                checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()

            checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

        # Iterate over a snapshot of the keys because entries are removed and
        # replaced inside the loop.
        all_keys = list(checkpoint_model.keys())
        for key in all_keys:
            if "relative_position_index" in key:
                checkpoint_model.pop(key)

            if "relative_position_bias_table" in key:
                rel_pos_bias = checkpoint_model[key]
                src_num_pos, num_attn_heads = rel_pos_bias.size()
                dst_num_pos, _ = model.state_dict()[key].size()
                dst_patch_shape = model.patch_embed.patch_shape
                if dst_patch_shape[0] != dst_patch_shape[1]:
                    raise NotImplementedError()
                # Rows beyond the (2h-1) x (2w-1) relative offsets are extra
                # entries that are carried over unchanged.
                num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
                src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
                dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
                if src_size != dst_size:
                    print("Position interpolate for %s from %dx%d to %dx%d" % (
                        key, src_size, src_size, dst_size, dst_size))
                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r ** n) / (1.0 - r)

                    # Bisect for the ratio q at which a geometric series of
                    # src_size // 2 terms reaches dst_size // 2.
                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    # if q > 1.090307:
                    #     q = 1.090307

                    # Source sample positions: geometrically spaced distances,
                    # mirrored around zero.
                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q ** (i + 1)

                    r_ids = [-_ for _ in reversed(dis)]

                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    # Target positions: unit-spaced integer offsets.
                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    print("Original positions = %s" % str(x))
                    print("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []

                    # NOTE(review): scipy.interpolate.interp2d is deprecated and
                    # reportedly removed in recent SciPy releases - confirm
                    # against the SciPy version this notebook is run with.
                    for i in range(num_attn_heads):
                        z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                        f = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(
                            torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))

                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                    checkpoint_model[key] = new_rel_pos_bias

        # interpolate position embedding
        # The model above is built with use_abs_pos_emb=False, so a matching
        # checkpoint need not contain `pos_embed`; indexing it unconditionally
        # raised KeyError. Only resize it when it is actually present.
        if 'pos_embed' in checkpoint_model:
            pos_embed_checkpoint = checkpoint_model['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_patches = model.patch_embed.num_patches
            num_extra_tokens = model.pos_embed.shape[-2] - num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
            # height (== width) for the new position embedding
            new_size = int(num_patches ** 0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model['pos_embed'] = new_pos_embed

        utils.load_state_dict(model, checkpoint_model, prefix='')
        # model.load_state_dict(checkpoint_model, strict=False)

    # Returned outside the `if` so that a model is produced even when no
    # checkpoint is configured (previously the function returned None then).
    return model


In [91]:
# Notebook-wide configuration; several functions in this notebook read these
# names as globals.
PATCH_SIZE = 16
MODEL = f'beit_base_patch{PATCH_SIZE}_224'
WEIGHT_PATH = f'https://unilm.blob.core.windows.net/beit/{MODEL}_pt22k.pth'
INPUT_SIZE = 320
MODEL_KEY = 'model|module'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model, freeze all of its parameters and switch it to eval mode.
model = create_beit_model()
for param in model.parameters():
    param.requires_grad = False
model.eval()
model.to(device)
print("done")


Patch size = (16, 16)
Load ckpt from https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k.pth
Load state_dict by model_key = model
Expand the shared relative position embedding to each transformer block. 
Position interpolate for blocks.0.attn.relative_position_bias_table from 27x27 to 39x39
Original positions = [-19.00000568303374, -16.965636986366434, -15.048173102640158, -13.240896107365035, -11.537474120973691, -9.931939124784734, -8.418666051769893, -6.992353078868327, -5.648003051801297, -4.380905977308203, -3.1866225214646713, -2.06096845626831, -1, 0, 1, 2.06096845626831, 3.1866225214646713, 4.380905977308203, 5.648003051801297, 6.992353078868327, 8.418666051769893, 9.931939124784734, 11.537474120973691, 13.240896107365035, 15.048173102640158, 16.965636986366434, 19.00000568303374]
Target positions = [-19. -18. -17. -16. -15. -14. -13. -12. -11. -10.  -9.  -8.  -7.  -6.
  -5.  -4.  -3.  -2.  -1.   0.   1.   2.   3.   4.   5.   6.   7.   8.
   9.  10.  11.  12.  

In [92]:
# Threshold handed to draw_attentions / get_index_attentions by the
# interactive cell; None disables the thresholded masks.
THRESHOLD = None
# Example (row, col) query pixel; only referenced by the commented-out call
# at the end of this cell. Use None to query the class token instead.
ATTENTION_PIXEL = (100,100)  # None

def draw_attentions(attentions, attention_pixel, threshold):
    """Plot the input image next to every head's attention map for one query.

    Relies on the notebook globals ``PATCH_SIZE``, ``IMAGE_SIZE`` (as
    ``(height, width)``) and ``img`` (the preprocessed input batch).

    Args:
        attentions: attention tensor forwarded to ``get_index_attentions``.
        attention_pixel: ``(row, col)`` pixel whose patch token is used as the
            query, or a falsy value (e.g. ``None``) to use the class token.
        threshold: forwarded to ``get_index_attentions``.
    """
    if attention_pixel:
        patch_row = attention_pixel[0] // PATCH_SIZE
        patch_col = attention_pixel[1] // PATCH_SIZE
        # Patch tokens are laid out row by row, so stepping one patch row
        # skips as many tokens as there are patch columns (image width //
        # patch size). The previous expression used the image height and
        # relied on `a * b // c` precedence, which only matched for square
        # inputs divisible by the patch size.
        patches_per_row = IMAGE_SIZE[1] // PATCH_SIZE
        # +1 because token 0 is the class token.
        feature_index = patch_row * patches_per_row + patch_col + 1
    else:
        feature_index = 0  # class token

    index_attentions, th_attn = get_index_attentions(attentions, threshold, index=feature_index)

    # show attentions heatmaps
    plottable_img = torchvision.utils.make_grid(img, normalize=True, scale_each=True)
    show([plottable_img] + list(index_attentions), marker=attention_pixel)

#     if threshold is not None:
#     #     image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
#         maskable_img = plottable_img.permute(1,2,0).numpy()
#         for j in range(nh):
#             display_instances(maskable_img, th_attn[j], blur=False)

# draw_attentions(ATTENTION_PIXEL, THRESHOLD)

In [93]:
# Directory holding the optional local sample images.
IMAGE_DIR = os.path.join('..', '..', 'imagenet-sample-images')
# Only offer files that look like images; a plain `ls` also listed entries
# such as README.md, which cannot be opened as an image.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif')

# Pure-Python listing instead of the `!ls` shell magic, so the cell also works
# where `ls` is unavailable and yields an empty list when the directory is
# missing (instead of turning the shell error text into dropdown options).
images = sorted(
    name
    for name in (os.listdir(IMAGE_DIR) if os.path.isdir(IMAGE_DIR) else [])
    if name.lower().endswith(IMAGE_EXTENSIONS)
)
# None (the default selection) means "use the fallback image downloaded from IMAGE_URL".
images = [None] + images
image_selector = widgets.Dropdown(options=images, value=None, description="Image to use")
display(image_selector)

Dropdown(description='Image to use', options=(None, 'README.md', 'beach.webp', 'coco_person.jpeg', 'dashcam.jp…

In [96]:
IMAGE_URL = "https://dl.fbaipublicfiles.com/dino/img.png"
# IMAGE_URL = 'https://raw.githubusercontent.com/EliSchwartz/imagenet-sample-images/master/n02676566_acoustic_guitar.JPEG'
# (height, width) the input image is resized to.
IMAGE_SIZE = (INPUT_SIZE, INPUT_SIZE)

image_path = os.path.join('..', '..', 'imagenet-sample-images', image_selector.value) if image_selector.value else None
# open image
if image_path is None:
    # user has not specified any image - we use our own image
    print("Please pick an image in the `Image to use` dropdown to indicate the image you wish to visualize.")
    print("Since no image path have been provided, we take the first image in our paper.")
    # Bound the wait and fail loudly on HTTP errors, so a bad download does
    # not hang the kernel or surface later as an obscure PIL decoding error.
    response = requests.get(IMAGE_URL, timeout=30)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    img = img.convert('RGB')
elif os.path.isfile(image_path):
    with open(image_path, 'rb') as f:
        img = Image.open(f)
        # convert() reads the pixel data while the file is still open.
        img = img.convert('RGB')
else:
    # Raise instead of calling sys.exit(), which is meant for scripts and
    # does not give a useful error inside a notebook kernel.
    raise FileNotFoundError(f"Provided image path {image_path} is non valid.")

transform = pth_transforms.Compose([
    pth_transforms.Resize(IMAGE_SIZE),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img = transform(img)

# make the image divisible by the patch size
# (`w` is cut from dimension 1 and `h` from dimension 2 of the C x H x W tensor)
w, h = img.shape[1] - img.shape[1] % PATCH_SIZE, img.shape[2] - img.shape[2] % PATCH_SIZE
img = img[:, :w, :h].unsqueeze(0)

# Patch-grid size; get_index_attentions reads these two names as globals.
w_featmap = img.shape[-2] // PATCH_SIZE
h_featmap = img.shape[-1] // PATCH_SIZE

attentions = model.get_last_selfattention(img.to(device))
# Divide every key column by its mean over all queries (epsilon avoids a
# division by zero).
attentions_normalized = attentions / (attentions.mean(axis=-2)[:, :, None, :]+1e-7)
# head_sum = attentions_normalized.sum(axis=1)[:, None]
# attentions = nn.functional.normalize(attentions, p=1, dim=-1)
# Append three extra "heads" along dim 1: the sum of the normalised heads, the
# sum of the raw heads, and the number of heads whose attention exceeds 0.005.
attentions_normalized = torch.cat((attentions_normalized,
                                   attentions_normalized.sum(axis=1)[:, None],
                                   attentions.sum(axis=1)[:,None],
                                  (attentions > 0.005).sum(axis=1)[:, None]), axis=1)


Please use the `--image_path` argument to indicate the path of the image you wish to visualize.
Since no image path have been provided, we take the first image in our paper.


In [97]:
# Settings shared by both pixel sliders; continuous_update=False is passed to
# both so they are configured identically.
_slider_options = dict(min=0, step=1, continuous_update=False)
h_slider = widgets.IntSlider(max=IMAGE_SIZE[0] - 1, description="height", **_slider_options)
w_slider = widgets.IntSlider(max=IMAGE_SIZE[1] - 1, description="width", **_slider_options)
# th_slider = widgets.FloatSlider(min=0, max=1.0, value=0.0, step=0.05, description="threshold", continuous_update=False)

def redraw(h, w, cls):
    """Redraw the attention maps for the current widget values.

    Args:
        h: row of the query pixel.
        w: column of the query pixel.
        cls: when true, ignore ``h``/``w`` and query the class token instead.
    """
    pixel = None if cls else (h, w)
    # draw_attentions renders the figure itself and returns nothing, so its
    # result must not be passed to display() (that only printed "None").
    draw_attentions(attentions_normalized, pixel, THRESHOLD)


# Bind redraw() to the two sliders; `cls=True` supplies the initial value for
# redraw's `cls` argument (presumably rendered as a checkbox by ipywidgets -
# confirm).
interactive_plot = interactive(redraw, h=h_slider, w=w_slider, cls=True)
# The last child is presumably the output area of the interactive widget;
# giving it a fixed height keeps the layout from jumping on each redraw
# (TODO confirm against the ipywidgets documentation).
output = interactive_plot.children[-1]
output.layout.height = '500px'
# Bare expression so the notebook renders the widget as the cell output.
interactive_plot

interactive(children=(IntSlider(value=0, continuous_update=False, description='height', max=319), IntSlider(va…