## Preliminaries
This section contains the boilerplate necessary for the other sections. Run it first.

In [1]:
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from datasets_.coco import build, CocoDetection
from pathlib import Path
import cv2
from PIL import Image
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
from util.misc import nested_tensor_from_tensor_list
import skimage
import colorsys
import random
from skimage.measure import find_contours
from matplotlib.patches import Polygon
from skimage import io
import argparse
import datasets_.transforms as T
import copy
import glob
import re
torch.set_grad_enabled(False);

In [2]:
# COCO class names, indexed by the integer class id used by the plotting
# helpers in this notebook (CLASSES[label] / CLASSES[cl]).
# The 'N/A' entries are placeholders -- presumably for category ids that the
# COCO annotations leave unused, so that list position lines up with the
# dataset's label ids. TODO confirm against the annotation files.
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

In [3]:
def make_coco_transforms(image_set, args):
    """Build the preprocessing pipeline applied to a COCO sample.

    Both supported splits ('train' and 'val') get the same pipeline: a resize
    to a single scale of 800 (longest side capped at 800 * 1333 // 800 = 1333)
    followed by tensor conversion and per-channel normalization.

    Args:
        image_set: name of the split; must be 'train' or 'val'.
        args: accepted for interface compatibility; not read by this function.

    Returns:
        A ``T.Compose`` holding the resize and normalization transforms.

    Raises:
        ValueError: if ``image_set`` is neither 'train' nor 'val'.
    """
    if image_set not in ('val', 'train'):
        raise ValueError(f'unknown {image_set}')

    scales = [800]
    target_size = scales[-1]

    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    resize = T.RandomResize([target_size], max_size=target_size * 1333 // 800)

    return T.Compose([resize, to_normalized_tensor])

def plot_gt(im, labels, bboxes_scaled, output_dir):
    """Draw ground-truth boxes with class captions and save them as 'gt_img.png'.

    Args:
        im: image array as accepted by the cv2 drawing functions; it is
            deep-copied, so the caller's array is left untouched.
        labels: object exposing ``.tolist()`` that yields one class index
            (into ``CLASSES``) per box.
        bboxes_scaled: object exposing ``.tolist()`` that yields one
            ``(xmin, ymin, xmax, ymax)`` box per label, in pixel units.
        output_dir: directory the annotated image is written to.
    """
    line_thickness = 3
    font_thickness = max(line_thickness - 1, 1)
    font_scale = line_thickness / 3
    box_color = [255, 0, 0]
    canvas = copy.deepcopy(im)

    for class_id, box in zip(labels.tolist(), bboxes_scaled.tolist()):
        xmin, ymin, xmax, ymax = box
        top_left = (int(xmin), int(ymin))
        bottom_right = (int(xmax), int(ymax))
        cv2.rectangle(canvas, top_left, bottom_right, box_color, line_thickness, cv2.LINE_AA)

        caption = f'{CLASSES[class_id]}'
        text_w, text_h = cv2.getTextSize(caption, 0, fontScale=font_scale, thickness=font_thickness)[0]
        # Filled background for the caption, sitting just above the box corner.
        caption_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
        cv2.rectangle(canvas, top_left, caption_corner, box_color, -1, cv2.LINE_AA)  # filled
        cv2.putText(canvas, caption, (top_left[0], top_left[1] - 2), 0, font_scale,
                    [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)

    out_path = os.path.join(output_dir,'gt_img.png')
    cv2.imwrite(out_path, canvas)
    print(f"{out_path} saved.")

def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner format.

    The coordinates are read from the LAST dimension, so the function accepts
    a single box of shape ``(4,)``, the usual ``(N, 4)`` layout, and batched
    inputs such as ``(B, N, 4)``.

    Args:
        x: tensor whose last dimension has size 4, holding
            ``(x_c, y_c, w, h)``.

    Returns:
        Tensor of the same shape holding ``(xmin, ymin, xmax, ymax)``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    corners = [x_c - half_w, y_c - half_h,
               x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)

def rescale_bboxes(out_bbox, size):
    """Turn relative (cx, cy, w, h) boxes into pixel-space corner boxes.

    Args:
        out_bbox: tensor of boxes in (center_x, center_y, width, height)
            form; multiplied by the image size, so values are expected to be
            relative to it (assumed to lie in [0, 1] -- confirm with caller).
        size: ``(img_w, img_h)`` pair, width first.

    Returns:
        Tensor of ``(xmin, ymin, xmax, ymax)`` boxes scaled by the image
        width/height.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the boxes' own device so CUDA inputs do not hit a
    # CPU/GPU device-mismatch error.
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=b.device)
    return b * scale
def draw_bbox_in_img(fname, bbox_scaled, score, color=[0,255,0]):
    """Overlay scored boxes on the image file ``fname`` and overwrite it in place.

    Args:
        fname: path of the image; it is read with ``cv2.imread`` and written
            back to the same path.
        bbox_scaled: object exposing ``.tolist()`` that yields one
            ``(xmin, ymin, xmax, ymax)`` box per entry of ``score``.
        score: iterable of per-class score vectors; the arg-max entry picks
            the caption from ``CLASSES`` and its value is printed next to it.
        color: color triple handed to the cv2 drawing calls.
    """
    line_thickness = 3
    font_thickness = max(line_thickness - 1, 1)  # font thickness
    font_scale = line_thickness / 3
    canvas = cv2.imread(fname)

    for class_scores, box in zip(score, bbox_scaled.tolist()):
        xmin, ymin, xmax, ymax = box
        top_left = (int(xmin), int(ymin))
        bottom_right = (int(xmax), int(ymax))
        cv2.rectangle(canvas, top_left, bottom_right, color, line_thickness, cv2.LINE_AA)

        best = class_scores.argmax()
        caption = f'{CLASSES[best]}: {class_scores[best]:0.2f}'
        text_w, text_h = cv2.getTextSize(caption, 0, fontScale=font_scale, thickness=font_thickness)[0]
        # Filled background for the caption, sitting just above the box corner.
        caption_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
        cv2.rectangle(canvas, top_left, caption_corner, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(canvas, caption, (top_left[0], top_left[1] - 2), 0, font_scale,
                    [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)

    cv2.imwrite(fname, canvas)

def plot_results(cv2_img, prob, boxes, output_dir):
    """Draw predicted boxes with 'class: score' captions and save 'pred_img.png'.

    Args:
        cv2_img: image array as accepted by the cv2 drawing functions; it is
            deep-copied, so the caller's array is left untouched.
        prob: iterable of per-class score vectors, one per box; the arg-max
            entry selects the caption from ``CLASSES``.
        boxes: object exposing ``.tolist()`` that yields one
            ``(xmin, ymin, xmax, ymax)`` box per score vector, in pixel units.
        output_dir: directory the annotated image is written to.
    """
    line_thickness = 3  # thickness line
    font_thickness = max(line_thickness - 1, 1)  # font thickness
    font_scale = line_thickness / 3
    box_color = [0, 0, 255]
    canvas = copy.deepcopy(cv2_img)

    for class_scores, box in zip(prob, boxes.tolist()):
        xmin, ymin, xmax, ymax = box
        top_left = (int(xmin), int(ymin))
        bottom_right = (int(xmax), int(ymax))
        cv2.rectangle(canvas, top_left, bottom_right, box_color, line_thickness, cv2.LINE_AA)

        best = class_scores.argmax()
        caption = f'{CLASSES[best]}: {class_scores[best]:0.2f}'
        text_w, text_h = cv2.getTextSize(caption, 0, fontScale=font_scale, thickness=font_thickness)[0]
        # Filled background for the caption, sitting just above the box corner.
        caption_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
        cv2.rectangle(canvas, top_left, caption_corner, box_color, -1, cv2.LINE_AA)  # filled
        cv2.putText(canvas, caption, (top_left[0], top_left[1] - 2), 0, font_scale,
                    [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)

    out_path = os.path.join(output_dir,'pred_img.png')
    cv2.imwrite(out_path, canvas)
    print(f"{out_path} saved.")
    
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Return ``path``, or the next free numbered variant of it.

    ``runs/exp`` becomes ``runs/exp{sep}2``, ``runs/exp{sep}3``, ... when the
    plain path already exists (and ``exist_ok`` is false). A file suffix is
    kept at the end: ``out.txt`` -> ``out2.txt``.

    Existing indices are detected on the final path component only, with the
    name and ``sep`` taken literally, so digits in parent directories
    (``exp7/exp``) or regex/glob metacharacters in the name do not disturb
    the numbering.

    Args:
        path: file or directory path (str or Path).
        exist_ok: when true, an existing ``path`` is returned unchanged.
        sep: separator inserted between the name and the number.
        mkdir: when true, create the directory (the path itself if it has no
            suffix, otherwise its parent) if it does not exist yet.

    Returns:
        The resulting ``Path``.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        # Sibling entries that start with the same name (+ separator).
        siblings = glob.glob(f"{glob.escape(str(path))}{glob.escape(sep)}*")
        numbered = re.compile(rf"{re.escape(path.name)}{re.escape(sep)}(\d+)")
        indices = []
        for sibling in siblings:
            found = numbered.match(Path(sibling).name)
            if found:
                indices.append(int(found.group(1)))
        n = max(indices) + 1 if indices else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    directory = path if path.suffix == '' else path.parent  # directory
    if not directory.exists() and mkdir:
        directory.mkdir(parents=True, exist_ok=True)  # make directory
    return path

def save_pred_fig(output_dir, output_dic, keep):
    """Render the kept predictions onto '<output_dir>/img.png' via plot_results.

    Args:
        output_dir: directory containing 'img.png'; plot_results writes its
            output there as well.
        output_dic: mapping with 'pred_boxes' and 'pred_logits'; only the
            first batch element is used.
        keep: index/mask selecting which predictions of that element to draw.
    """
    image = cv2.imread(os.path.join(output_dir, "img.png"))
    height, width = image.shape[:2]

    kept_boxes = output_dic['pred_boxes'][0, keep].cpu()
    # Softmax over the class axis, then drop the last class column.
    class_probs = output_dic['pred_logits'].softmax(-1)[0, :, :-1]

    plot_results(image,
                 class_probs[keep],
                 rescale_bboxes(kept_boxes, (width, height)),
                 output_dir)

def save_gt_fig(output_dir, gt_anno):
    """Render the ground-truth annotation onto '<output_dir>/img.png' via plot_gt.

    Args:
        output_dir: directory containing 'img.png'; plot_gt writes its output
            there as well.
        gt_anno: mapping with 'boxes' (passed through rescale_bboxes together
            with the image size) and 'labels'.
    """
    image = cv2.imread(os.path.join(output_dir, "img.png"))
    height, width = image.shape[:2]
    pixel_boxes = rescale_bboxes(gt_anno['boxes'], (width, height))
    plot_gt(image, gt_anno['labels'], pixel_boxes, output_dir)

def get_one_query_meanattn(vis_attn,h_featmap,w_featmap, patch_size=16):
    """Average one query's attention over dim 0 and upsample it.

    Args:
        vis_attn: attention tensor; after averaging over its first dimension
            it must hold ``h_featmap * w_featmap`` values.
        h_featmap: height of the feature-map grid.
        w_featmap: width of the feature-map grid.
        patch_size: integer nearest-neighbour upsampling factor applied to
            both spatial dimensions. Defaults to 16, the value that used to
            be hard-coded.

    Returns:
        numpy array of shape ``(1, h_featmap * patch_size, w_featmap * patch_size)``.
    """
    mean_attentions = vis_attn.mean(0).reshape(h_featmap, w_featmap)
    # interpolate() wants (batch, channel, H, W); the batch dim is dropped again.
    upsampled = nn.functional.interpolate(
        mean_attentions.unsqueeze(0).unsqueeze(0), scale_factor=patch_size, mode="nearest")
    return upsampled[0].cpu().numpy()

def get_one_query_attn(vis_attn, h_featmap, w_featmap, nh, patch_size=16):
    """Reshape one query's per-head attention into maps and upsample them.

    Args:
        vis_attn: attention tensor holding ``nh * h_featmap * w_featmap``
            values.
        h_featmap: height of the feature-map grid.
        w_featmap: width of the feature-map grid.
        nh: number of attention heads (size of the leading output dimension).
        patch_size: integer nearest-neighbour upsampling factor applied to
            both spatial dimensions. Defaults to 16, the value that used to
            be hard-coded.

    Returns:
        numpy array of shape ``(nh, h_featmap * patch_size, w_featmap * patch_size)``.
    """
    attentions = vis_attn.reshape(nh, h_featmap, w_featmap)
    # Heads act as the channel dim for interpolate(); the batch dim is dropped again.
    upsampled = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")
    return upsampled[0].cpu().numpy()

## Set args
Set the arguments used for loading the model and saving the visualizations:
- `--patch_size`: must match the setting of the loaded model
- `--project`: directory where the visualizations are saved
- `--name`: run name inside `--project`; keep the default value
- `--index`: index of the image in the COCO train/val split
- `--backbone_name`: must match the setting of the loaded model
- `--coco_path`: path to the COCO dataset
- `--resume`: path of the model checkpoint to load

In [4]:
def get_args_parser():
    """Create the argument parser holding the visualization settings.

    The parser is built with ``add_help=False`` so it can be passed as a
    parent to another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser with the options --patch_size, --project,
        --name, --index, --backbone_name, --coco_path, --image_set,
        --pre_trained, --det_token_num, --init_pe_size, --mid_pe_size and
        --resume.
    """
    parser = argparse.ArgumentParser('Visualize Self-Attention maps', add_help=False)
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument('--project', default='./visualization_new', help='Path where to save visualizations.')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--index', default=5, type=int, help='index of dataset')
    parser.add_argument('--backbone_name', default='tiny', type=str,
                        help="Name of the deit backbone to use")
    # NOTE(review): the default is a machine-specific absolute path.
    parser.add_argument('--coco_path', default='/data/qingsong/dataset/coco', type=str,
                        help="path to the COCO dataset root")
    parser.add_argument('--image_set', default='val', type=str,
                        help="split")
    parser.add_argument('--pre_trained', default='',
                        help="set imagenet pretrained model path if not train yolos from scratch")
    parser.add_argument("--det_token_num", default=100, type=int,
                        help="Number of det token in the deit backbone")
    parser.add_argument('--init_pe_size', nargs='+', type=int, default=[512,864],
                        help="init pe size (h,w)")
    parser.add_argument('--mid_pe_size', nargs='+', type=int, default=[512,864],
                        help="mid pe size (h,w)")
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    return parser
# Parse an explicit empty argument list so every option takes its default
# and the command line of the notebook kernel is never consulted.
parser = argparse.ArgumentParser(
    'Visualize Self-Attention maps',
    parents=[get_args_parser()],
)
args = parser.parse_args([])
# Pick a fresh project/name directory (exp, exp2, exp3, ...) for this run.
args.output_dir = str(increment_path(Path(args.project, args.name)))

## load model & coco dataset

In [5]:
# import os
# pretrain_path = '/data/qingsong/pretrain'

# import torch
# from models.detector import Detector
# yolos = Detector(
#         num_classes=91,
#         pre_trained=None,
#         det_token_num=100,
#         backbone_name='tiny',
#         init_pe_size=[800, 1333],
#         mid_pe_size=[512, 864],
#         use_checkpoint=False,
#     )
# yolos.load_state_dict(torch.load(os.path.join(pretrain_path, 'yolos_ti.pth'))['model'], strict=False)


# Settings namespace passed to YOLOS(...) and load_checkpoint(...) further
# down in this cell. The meaning of individual fields is defined by those
# callees; the notes below are reviewer hints, not guarantees.
swiss_args = argparse.Namespace(
    num_layers=12,
    vocab_size=101,  # NOTE(review): presumably 1 + num_det_tokens (the forward cell builds input_ids over 1 + num_det_tokens values) -- confirm
    num_det_tokens=100,
    hidden_size=192,
    num_attention_heads=3,
    hidden_dropout=0.,
    attention_dropout=0.,
    in_channels=3,
    image_size=[512, 864],
    patch_size=16,
    pre_len=1,
    post_len=100,
    inner_hidden_size=None,
    hidden_size_per_attention_head=None,
    checkpoint_activations=False,
    checkpoint_num_layers=1,
    sandwich_ln=False,
    post_ln=False,
    # Single-process setup: also used for init_process_group / model parallel init below.
    model_parallel_size=1,
    world_size=1,
    rank=0,
    num_classes=1000,
    # NOTE(review): machine-specific absolute checkpoint directory.
    load='/data/qingsong/pretrain/yolos',
    # NOTE(review): the old_* fields presumably describe the checkpoint's original
    # image size / token layout -- confirm against the model code.
    old_image_size=[800, 1333],
    old_pre_len=1,
    old_post_len=100,
    old_checkpoint=None,
    mode='inference',
    num_det_classes=92
    )

# Initialise a single-process torch.distributed group and the model-parallel
# state, then build the YOLOS model and load its checkpoint.
# NOTE(review): os and torch are already imported in the first cell; the
# re-imports here are redundant.
import os
import torch
# Rendezvous address "tcp://<MASTER_ADDR>:<MASTER_PORT>", taken from the
# environment with local defaults.
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
master_port = os.getenv('MASTER_PORT', '12468')
init_method += master_ip + ':' + master_port
# NOTE(review): the 'nccl' backend normally requires CUDA devices, while the
# forward cell moves the model to CPU -- confirm this is intended.
# NOTE(review): re-running this cell in the same kernel presumably fails
# because the default process group already exists -- TODO confirm / guard.
torch.distributed.init_process_group(
        backend='nccl',
        world_size=swiss_args.world_size, rank=swiss_args.rank, init_method=init_method)
# These imports sit after the process-group setup in the original cell; the
# order is kept as-is in case the libraries depend on it.
import SwissArmyTransformer.mpu as mpu
mpu.initialize_model_parallel(swiss_args.model_parallel_size)
from yolos_model import YOLOS
from SwissArmyTransformer.training.deepspeed_training import load_checkpoint
swiss_model = YOLOS(swiss_args)
# Loads weights from the location configured in swiss_args.load.
load_checkpoint(swiss_model, swiss_args)
swiss_model.get_mixin('pos_embedding').reinit() # patch_embedding should not reinit for inference
# `model` is the name the following cells use.
model = swiss_model

# Locate the COCO image folder / annotation file for the requested split and
# pull out the single sample selected by args.index.
root = Path(args.coco_path)
assert root.exists(), f'provided COCO path {root} does not exist'
mode = 'instances'
image_set = args.image_set
PATHS = {
    split: (root / f"{split}2017", root / "annotations" / f'{mode}_{split}2017.json')
    for split in ("train", "val")
}
img_folder, ann_file = PATHS[image_set]
dataset = CocoDetection(
    img_folder,
    ann_file,
    transforms=make_coco_transforms(image_set, None),
    return_masks=False,
)
img_data, img_anno = dataset[args.index]
# Add a leading batch dimension before wrapping the image in a nested tensor.
ret = nested_tensor_from_tensor_list(img_data.unsqueeze(0))

> initializing model parallel with size 1
global rank 0 is loading checkpoint /data/qingsong/pretrain/yolos/1/mp_rank_00_model_states.pt
  successfully loaded /data/qingsong/pretrain/yolos/1/mp_rank_00_model_states.pt
loading annotations into memory...
Done (t=0.62s)
creating index...
index created!


## forward to get pred & attn

In [6]:
# Run on CPU; use the commented line instead for GPU inference.
# device = torch.device("cuda")
device = torch.device("cpu")
model = model.eval()
model.to(device)
ret = ret.to(device)


images = ret.tensors
batch_size, _, height, width = images.shape

# Take the patch size from the model configuration instead of repeating the
# literal 16, and look the detection head up only once.
patch_size = swiss_args.patch_size
num_det_tokens = model.get_mixin('det_head').num_det_tokens
num_patches = (height // patch_size) * (width // patch_size)
# One leading token + image patches + detection tokens.
seq_len = 1 + num_patches + num_det_tokens
position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
input_ids = torch.arange(1 + num_det_tokens).unsqueeze(0).repeat(batch_size, 1).long()
encoded_input = {'input_ids': input_ids, 'image': images, 'position_ids': position_ids}
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
# Added after the .to(device) pass because None has no .to().
encoded_input['attention_mask'] = None

# The first element of the model output is the prediction dict used below.
outputs = model(**encoded_input, offline=False,
                height=height // patch_size, width=width // patch_size)[0]

# outputs = yolos(ret)

# attention = model.forward_return_attention(ret)
# attention = attention[-1].detach().cpu()
# nh = attention.shape[1] # number of head
# attention = attention[0, :, -args.det_token_num:, 1:-args.det_token_num]
#forward input to get pred
result_dic = outputs
# Select the detection tokens to visualize: softmax over the class axis of
# the first batch element, drop the last class column (presumably the
# "no object" class -- confirm), and keep predictions whose best class
# probability exceeds 0.9.
probas = result_dic['pred_logits'].softmax(-1)[0, :, :-1].cpu()
keep = probas.max(-1).values > 0.9
# Indices of the kept predictions.
vis_indexs = torch.nonzero(keep).squeeze(1)
# Save the input image as img.png. This must happen before the two calls
# below: save_pred_fig and save_gt_fig read img.png back from output_dir.
os.makedirs(args.output_dir, exist_ok=True)
img = ret.tensors.squeeze(0).cpu()
torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))

# Draw the kept predictions on img.png -> pred_img.png
save_pred_fig(args.output_dir, result_dic, keep)

# Draw the ground-truth annotation on img.png -> gt_img.png
save_gt_fig(args.output_dir, img_anno)




visualization_new/exp/pred_img.png saved.
visualization_new/exp/gt_img.png saved.
