In [1]:
import torch
import torchvision
from mmtrack.models.mot import QDTrack
from mmtrack.apis import batch_inference_mot, init_model, inference_mot
from mmdet.models import FasterRCNN
import cv2
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image

from mmtrack.utils import register_all_modules
register_all_modules()
from mmdet.models import StandardRoIHead

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def draw_boxes(boxes, labels, image):
    for i, box in enumerate(boxes):
        color = (0, 0, 0)
        cv2.rectangle(
            image,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color, 2
        )
        cv2.putText(image, str(labels[i]), (int(box[0]), int(box[1] - 5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)
    return image

In [12]:
class BoxScoreTarget:
    def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
        self.labels = labels
        self.bounding_boxes = bounding_boxes
        self.iou_threshold = iou_threshold

    def __call__(self, model_outputs):
        output = torch.Tensor([0])
        if torch.cuda.is_available():
            output = output.cuda()

        if len(model_outputs["boxes"]) == 0:
            return output

        for box, label in zip(self.bounding_boxes, self.labels):
            box = torch.Tensor(box[None, :])
            if torch.cuda.is_available():
                box = box.cuda()

            ious = torchvision.ops.box_iou(box, model_outputs["boxes"])
            index = ious.argmax()
            if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label:
                score = ious[0, index] + model_outputs["scores"][index]
                output = output + score
        return output


class WrappedModel(torch.nn.Module):
    def __init__(self, config, checkpoint):
        super().__init__()
        self.model = init_model(config, checkpoint, device="cpu")
        assert isinstance(self.model, QDTrack)
        self.target_layers = self.model.detector.backbone

    def forward(self, img: torch.Tensor):
        img = img.moveaxis(0, -1).cpu().numpy()
        # results = batch_inference_mot(self.model, [img], [0])[0]
        results = inference_mot(self.model, img, 0)

        return {
            "boxes": results.pred_track_instances.bboxes,
            "labels": results.pred_track_instances.labels,
            "scores": results.pred_track_instances.scores,
        }

In [20]:
from mmcv import VideoReader
imgs = VideoReader("../demo/test1.mp4")

model = WrappedModel(
    config="../configs/mot/qdtrack/qdtrack_faster-rcnn_r50_fpn_aic.py",
    checkpoint="../checkpoints/qdtrack_faster-rcnn_aic.pth")

img_np = imgs[0]
img = torch.from_numpy(img_np).moveaxis(-1, 0)
data = model(img)
targets = [BoxScoreTarget(data["labels"], data["boxes"])]

03/19 09:58:18 - mmengine - INFO - load model from: open-mmlab://detectron2/resnet50_caffe
03/19 09:58:18 - mmengine - INFO - Loads checkpoint by openmmlab backend from path: open-mmlab://detectron2/resnet50_caffe

unexpected key in source state_dict: conv1.bias

03/19 09:58:19 - mmengine - INFO - load model from: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth
03/19 09:58:19 - mmengine - INFO - Loads checkpoint by http backend from path: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth
Loads checkpoint by local backend from path: ../checkpoints/qdtrack_faster-rcnn_aic.pth


In [None]:
from mmdet.models import StandardRoIHead

In [24]:
target_layers = [model.model.detector.neck]
def reshape_transform(tensor):
    target_size = tensor[0].shape[-2:]
    outs = []
    for feat in tensor:
        outs.append(torch.nn.functional.interpolate(torch.abs(feat), target_size, mode='bilinear'))
    outs = torch.cat(outs, dim=1)
    print(outs.shape)
    return outs
cam = EigenCAM(model, target_layers, reshape_transform=reshape_transform, use_cuda=False)

grayscale_cam = cam(img, targets=targets)

torch.Size([1, 1280, 96, 160])


In [16]:
grayscale_cam.shape

(1, 1080, 1920)

In [25]:
cam_img = show_cam_on_image(img_np / 255.0, grayscale_cam[0], use_rgb=False)
cam_img.shape

(1080, 1920, 3)

In [26]:
import matplotlib.pyplot as plt
draw_boxes(data["boxes"], data["labels"], cam_img)
cv2.imwrite("cam_img.png", cam_img)

True