# Bursty Prompt Tuning: detection on COCO

This is a detection demo using our BPT with a Cascade Mask R-CNN model. No GPU is needed: the demo runs inference on the CPU.

## Import packages

In [13]:

import torch, json
from torchvision.transforms import functional as F
from Model.backbone import SimpleFeaturePyramid, MAE_bpt_shallow
from Model.detection import MaskRCNN
from PIL import Image
from utils import Visualizer
from collections import OrderedDict
from pathlib import Path
import os
import warnings
warnings.filterwarnings("ignore")

## Prepare model

In [14]:
def create_model(num_classes, weights_path='./ckpt.pth'):
    """Build the BPT ViT-B + SimpleFeaturePyramid Cascade Mask R-CNN and load its weights.

    Args:
        num_classes: number of classes passed to ``MaskRCNN``.
        weights_path: checkpoint file to load; defaults to ``./ckpt.pth``.

    Returns:
        The constructed model with the checkpoint weights loaded via
        ``load_state_dict(..., strict=False)``. The load result is printed.
    """
    # ViT-B backbone with BPT prompts; these hyper-parameters must match the checkpoint.
    vit_backbone = MAE_bpt_shallow.__dict__["MAE_bpt_vit_b"](
        drop_path_rate=0.0,
        num_prompts=100,
        channels=100,
    )

    # Four-level feature pyramid from the ViT features (scale factors 4x, 2x, 1x, 0.5x).
    backbone = SimpleFeaturePyramid(
        backbone=vit_backbone,
        out_channels=256,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        top_block=True,
        norm="LN",
    )

    roi_heads = "CascadeRoIHead"
    model = MaskRCNN(backbone=backbone, num_classes=num_classes, roi_heads=roi_heads,
                     box_head_norm="LN", mask_head_norm="LN")

    # load weights
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    checkpoint = torch.load(weights_path, map_location="cpu")
    # Training checkpoints wrap the weights under 'model'; otherwise treat the file as a bare state dict.
    ckpt = checkpoint["model"] if "model" in checkpoint else checkpoint

    # Keys presumably carry a 'module.' prefix from a (Distributed)DataParallel wrapper.
    # Strip it only when present so that un-prefixed keys are not truncated.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in ckpt.items()
    )

    # strict=False tolerates missing/unexpected keys, so print the result to surface any mismatch.
    msg = model.load_state_dict(new_state_dict, strict=False)
    print(msg)
    return model

## Visualizer

In [15]:
# Input images: a random sample of COCO val-set image ids, stored as ./visual/images/<id>.jpg
sample_image_ids = (
    "000000087144", "000000110211", "000000142238", "000000171382",
    "000000173799", "000000185409", "000000188296", "000000197388",
    "000000211120", "000000227898", "000000268375", "000000314034",
    "000000319935", "000000336232", "000000382111", "000000457559",
    "000000570688",
)
img_path_list = [f"./visual/images/{image_id}.jpg" for image_id in sample_image_ids]
    
# Build the detector for CPU-only inference; nn.Module.to() and .eval() both return the module itself.
device = "cpu"
model = create_model(num_classes=91).to(device).eval()

# Class-name lookup passed to the Visualizer (as class_name) to label detections.
# The `with` block closes the file itself, so no explicit close() is needed;
# an explicit encoding keeps the read independent of the platform locale.
with open("./Dataset/coco91_classname.json", 'r', encoding="utf-8") as f:
    category_index = json.load(f)

# Destination for rendered predictions; created on demand (including parents).
save_dir = Path("./visual") / "results" / "cascade-bpt"
save_dir.mkdir(exist_ok=True, parents=True)

# Backbone constructor name -> short tag used as the output-filename prefix.
model_tags = dict(MAE_bpt_vit_b="vit_b")

for img_path in img_path_list:
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"input image not found: {img_path}")
    img_pil = Image.open(img_path).convert("RGB")

    # Output file: <model tag>-<image file stem>.png inside save_dir.
    img_name = Path(img_path).stem
    save_path = str(save_dir / f'{model_tags["MAE_bpt_vit_b"]}-{img_name}.png')
    print("save results in:")
    print(save_path)

    # predict: F.to_tensor gives a CHW float tensor; unsqueeze adds the batch dimension.
    img_tensor = F.to_tensor(img_pil)
    img_batch = torch.unsqueeze(img_tensor, dim=0).to(device)
    with torch.no_grad():
        pred = model(img_batch)[0]  # predictions for the single image in the batch

    # Draw the instance predictions (masks enabled) over the image and save the figure.
    # NOTE(review): img_rgb receives the tensor (as before), not the PIL image -- confirm Visualizer expects this.
    visualizer = Visualizer(img_rgb=img_tensor, visual_masks=True, linewidth=1.5, dpi=800, class_name=category_index)
    vis_output = visualizer.draw_instance_predictions(pred=pred)
    vis_output.save(filepath=save_path)

Fast R-CNN box head: 4Conv1FC
<All keys matched successfully>
sava results in:
visual/results/cascade-bpt/vit_b-000000087144.png
sava results in:
visual/results/cascade-bpt/vit_b-000000110211.png
sava results in:
visual/results/cascade-bpt/vit_b-000000142238.png
sava results in:
visual/results/cascade-bpt/vit_b-000000171382.png
sava results in:
visual/results/cascade-bpt/vit_b-000000173799.png
sava results in:
visual/results/cascade-bpt/vit_b-000000185409.png
sava results in:
visual/results/cascade-bpt/vit_b-000000188296.png
sava results in:
visual/results/cascade-bpt/vit_b-000000197388.png
sava results in:
visual/results/cascade-bpt/vit_b-000000211120.png
sava results in:
visual/results/cascade-bpt/vit_b-000000227898.png
sava results in:
visual/results/cascade-bpt/vit_b-000000268375.png
sava results in:
visual/results/cascade-bpt/vit_b-000000314034.png
sava results in:
visual/results/cascade-bpt/vit_b-000000319935.png
sava results in:
visual/results/cascade-bpt/vit_b-000000336232.png
