In [None]:
from collections import defaultdict

import cv2
import torch
import numpy as np
from argus import load_model
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader

from src.detect.utils import cxcywh2xyxy
from src.detect.dataset import DetectDataset
from src.detect.transforms import val_transform, train_transform
from src.detect.yolox.metamodel import YOLOXMetaModel

%load_ext autoreload
%autoreload 2

In [None]:
# Data locations for the split being inspected.
img_dir = '/workdir/data/datasets/train/'
annot_file = '/workdir/data/annot/converted_train.json'
# NOTE(review): presumably the (start, end) fraction of samples DetectDataset uses — confirm.
samples_range = (0.0, 0.1)

# Training-time augmentation pipeline at the network input resolution.
input_size = (640, 640)
transform = train_transform(input_size=input_size, fill_value=127, max_labels=128)

In [None]:
dataset = DetectDataset(
    img_dir=img_dir,
    annot_file=annot_file,
    samples_range=samples_range,
    transform=transform,
)
# shuffle=False: batches are drawn in dataset order.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=1, pin_memory=True)

In [None]:
# Grab the first batch from the loader for inspection in the cells below.
batch = next(iter(loader))

In [None]:
def show_img(img):
    """Display an (H, W, C) image array with matplotlib at 200 dpi.

    The channel axis is reversed before display (BGR -> RGB, assuming
    OpenCV-style channel order on input).
    """
    plt.figure(dpi=200)
    rgb = img[:, :, ::-1]
    plt.imshow(rgb)
    plt.show()

def visualize_img(img_tensor, bbox_tensor=None, bbox_color=(255,0,0), cx: bool = True):
    """Draw boxes onto a (C, H, W) image tensor and display it via ``show_img``.

    Args:
        img_tensor: image tensor with channels first. Values are scaled by 255
            and clipped to [0, 255] before conversion to uint8.
            NOTE(review): this assumes pixels in [0, 1] — confirm against the transform.
        bbox_tensor: optional tensor of boxes, one box per row. Rows with no
            positive coordinate are skipped (presumably label padding).
        bbox_color: color tuple passed straight to ``cv2.rectangle``.
        cx: if True, boxes are converted with ``cxcywh2xyxy`` before drawing;
            otherwise they are drawn as (x1, y1, x2, y2).
    """
    pixels = np.moveaxis(img_tensor.clone().cpu().numpy() * 255.0, 0, -1)
    # UMat gives cv2 a drawable buffer even though moveaxis produced a strided view.
    canvas = cv2.UMat(np.clip(pixels, 0, 255).astype(np.uint8))

    if bbox_tensor is not None:
        boxes = bbox_tensor.clone().cpu().numpy()
        boxes = cxcywh2xyxy(boxes) if cx else boxes
        for box in boxes:
            if not np.any(box > 0):
                continue
            left, top, right, bottom = (int(v) for v in box)
            cv2.rectangle(canvas, (left, top), (right, bottom), bbox_color, 2)

    show_img(canvas.get())
    
    

In [None]:
# Shape of the label tensor for the batch (presumably (batch, max_labels, 5) — confirm).
batch[1].size()

In [None]:
# Ground-truth boxes of the first image; label column 0 is dropped
# (presumably the class id — confirm against DetectDataset).
visualize_img(img_tensor=batch[0][0], bbox_tensor=batch[1][0, :, 1:])

In [None]:
# Load the trained checkpoint; fall back to CPU so the notebook also runs without a GPU.
model = load_model('/workdir/data/experiments/YOLOX_train_001/model-002-0.664820.pth',
                   device='cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Run the model on the image tensor of the inspected batch.
images = batch[0]
pred = model.predict(images)

In [None]:
def postprocess(prediction, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    """Turn raw per-anchor predictions into per-image, NMS-filtered detections.

    Args:
        prediction: tensor of shape (batch, num_anchors, 5 + C). The last axis is
            read as (cx, cy, w, h, obj_conf, class scores...). When C == 0, every
            box is treated as class 0 with class confidence 1.
            NOTE(review): column layout assumed from YOLOX conventions — confirm
            against the model head.
        conf_thre: boxes whose objectness (column 4) is below this are dropped.
        nms_thre: IoU threshold for NMS.
        class_agnostic: if True, NMS suppresses across all classes; otherwise
            boxes only suppress boxes of the same predicted class.

    Returns:
        A list with one entry per image: ``None`` if no box survives, otherwise
        a (K, 7) tensor with columns
        (x1, y1, x2, y2, obj_conf, class_conf, class_pred).
        ``prediction`` itself is not modified.
    """
    # (cx, cy, w, h) -> (x1, y1, x2, y2), computed out-of-place so the caller's tensor is untouched.
    half_w = prediction[:, :, 2] / 2
    half_h = prediction[:, :, 3] / 2
    boxes_xyxy = torch.stack((
        prediction[:, :, 0] - half_w,
        prediction[:, :, 1] - half_h,
        prediction[:, :, 0] + half_w,
        prediction[:, :, 1] + half_h,
    ), dim=-1)

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue

        obj_conf = image_pred[:, 4:5]
        if image_pred.size(1) > 5:
            # Best class and its score from the class-score columns.
            class_conf, class_pred = torch.max(image_pred[:, 5:], 1, keepdim=True)
        else:
            # Objectness-only head: single class, neutral class confidence.
            class_conf = torch.ones_like(obj_conf)
            class_pred = torch.zeros_like(obj_conf, dtype=torch.long)

        # Threshold on objectness only (as before). No .squeeze(): with a single
        # anchor it would produce a 0-d mask that adds a dimension when indexing.
        conf_mask = image_pred[:, 4] >= conf_thre
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((boxes_xyxy[i], obj_conf, class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        scores = detections[:, 4] * detections[:, 5]
        if class_agnostic:
            keep = torchvision.ops.nms(detections[:, :4], scores, nms_thre)
        else:
            keep = torchvision.ops.batched_nms(detections[:, :4], scores, detections[:, 6], nms_thre)

        output[i] = detections[keep]

    return output

In [None]:
# Pass a copy so `pred` stays intact even if postprocess writes into its argument.
output = postprocess(pred.clone(), conf_thre=0.3, class_agnostic=True)
first_image_dets = output[0]
if first_image_dets is not None:
    print(first_image_dets.shape)

In [None]:
# postprocess yields None for an image with no surviving boxes; show the bare image then
# instead of failing on None[:, :4]. Boxes are already (x1, y1, x2, y2), hence cx=False.
visualize_img(batch[0][0], None if output[0] is None else output[0][:, :4], cx=False)