In [None]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import cv2
import random
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

from mmdet.datasets import build_dataset, build_dataloader
from mmdet.apis import init_detector, inference_detector
from mmcv import Config

In [None]:
# Model definition (mmcv config file) and the trained weights to evaluate.
cfg = Config.fromfile(
    "../configs/gwhd/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_gwhd.py"
)
checkpoint_file = "../experiments/moby_swin_t_imnet_mask_rcnn_3x/latest.pth"

# Instantiate the detector described by `cfg` and load the checkpoint into it.
model = init_detector(cfg, checkpoint=checkpoint_file)

In [None]:
# Re-point the config's validation split at the GWHD 2021 test annotations/images
# (paths are relative to the notebook's working directory), then build it with
# test_mode=True.
cfg.data.val.update(
    ann_file="gwhd_2021/annotations/test.json",
    img_prefix="gwhd_2021/images",
)
dataset = build_dataset(cfg.data.val, default_args=dict(test_mode=True))

In [None]:
from mmcv import ProgressBar

# Walk the test split in a fixed order (shuffle=False) so that results[i]
# corresponds to dataset[i]; the evaluation cell presumably relies on that pairing.
dataloader = build_dataloader(
    dataset,
    samples_per_gpu=16,
    workers_per_gpu=cfg.data.workers_per_gpu,
    shuffle=False,
    dist=False,
)

results = []
prog_bar = ProgressBar(len(dataset))
for batch in dataloader:
    # Only the file paths are taken from the loader; inference_detector
    # reads and preprocesses each image again itself.
    batch_paths = [meta["filename"] for meta in batch["img_metas"][0].data[0]]
    batch_results = inference_detector(model, batch_paths)
    results.extend(batch_results)
    # One progress tick per image in this batch.
    for _ in batch_results:
        prog_bar.update()

In [None]:
# Sanity check: the inference loop must yield exactly one result per test image,
# otherwise evaluation would pair detections with the wrong images.
assert len(results) == len(dataset), (
    f"got {len(results)} results for {len(dataset)} images"
)
len(results)

In [None]:
# Score the collected detections with the dataset's own evaluation routine
# (`dataloader.dataset` is the very `dataset` object built earlier).
eval_res = dataset.evaluate(results)

In [None]:
# Visual sanity check: run the detector on a random (but reproducible) sample of
# test images and tile the annotated images into a GRID_ROWS x GRID_COLS mosaic.
GRID_ROWS, GRID_COLS = 4, 4
SCORE_THR = 0.5            # minimum detection score for a box to be drawn
VIS_SEED = 0               # fixed so every re-run shows the same images
BOX_COLOR = (0, 255, 255)  # cyan, since boxes are drawn on the RGB image


def _annotate(img_path, score_thr=SCORE_THR):
    """Run `model` on one image and return an RGB copy with predicted boxes drawn.

    The detector is fed the BGR array returned by ``cv2.imread``. mmcv also loads
    images as BGR by default, so this matches the channel order the model's
    preprocessing expects. Conversion to RGB happens only for display.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    result = inference_detector(model, img_bgr)

    # Mask R-CNN-style models return (bbox_result, segm_result); bbox-only
    # models return bbox_result directly. bbox_result holds one array per class
    # whose rows are x1, y1, x2, y2, score.
    bbox_result = result[0] if isinstance(result, tuple) else result

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    for class_dets in bbox_result:
        for x1, y1, x2, y2, score in class_dets:
            if score > score_thr:
                cv2.rectangle(img_rgb, (int(x1), int(y1)), (int(x2), int(y2)), BOX_COLOR, 2)
    return img_rgb


assert len(dataset) > 0, "dataset is empty; nothing to visualize"
n_tiles = GRID_ROWS * GRID_COLS
rng = random.Random(VIS_SEED)
# Valid indices are 0 .. len(dataset) - 1. Prefer distinct images when there are enough.
if len(dataset) >= n_tiles:
    tile_indices = rng.sample(range(len(dataset)), n_tiles)
else:
    tile_indices = [rng.randrange(len(dataset)) for _ in range(n_tiles)]

tiles = [_annotate(dataset[idx]["img_metas"][0].data["filename"]) for idx in tile_indices]

# Stack GRID_ROWS tiles into each column, then place the columns side by side.
# cv2.vconcat/hconcat require all sampled images to share the same size.
columns = [cv2.vconcat(tiles[c * GRID_ROWS:(c + 1) * GRID_ROWS]) for c in range(GRID_COLS)]
final_img = cv2.hconcat(columns)

fig, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.imshow(final_img)
ax.set_title(f"Predicted boxes (score > {SCORE_THR}) on {n_tiles} random test images")
ax.axis("off");