In [1]:
import pickle

In [2]:
import torch
from yolo.model import box_ops

In [3]:
# Read the three pickled prediction arrays, one file per index 0..2.
# Caution: pickle.load can run arbitrary code, so only load files you trust.
predictions = []
for i in (0, 1, 2):
    with open(f'yolov5_preds_{i}.pkl', 'rb') as fp:
        level_output = pickle.load(fp)
    predictions.append(level_output)

# Sanity check on what was loaded (expected to be NumPy arrays).
for arr in predictions:
    print(arr.shape, arr.dtype)

# Wrap every loaded array as a torch tensor and report the same summary again.
predictions_tensor = list(map(torch.from_numpy, predictions))
for pred in predictions_tensor:
    print(pred.shape, pred.dtype)


(1, 40, 40, 3, 85) float32
(1, 20, 20, 3, 85) float32
(1, 10, 10, 3, 85) float32
torch.Size([1, 40, 40, 3, 85]) torch.float32
torch.Size([1, 20, 20, 3, 85]) torch.float32
torch.Size([1, 10, 10, 3, 85]) torch.float32


In [4]:
# Pixel stride of each prediction level; `inference` multiplies grid
# coordinates by this value to obtain box centres.
strides = (8, 16, 32)

# Anchor (width, height) pairs: one row of three anchors per prediction level,
# listed in the same order as `strides`.
anchors = [
    [[10, 13], [16, 30], [33, 23]],
    [[30, 61], [62, 45], [59, 119]],
    [[116, 90], [156, 198], [373, 326]],
]

# Upper bound on the number of boxes kept per image after NMS.
detections = 100

In [None]:
def inference(preds, image_shapes, scale_factors, max_size, score_thresh, nms_thresh, merge=True):
    """Decode raw multi-level predictions into per-image detections.

    Reads the notebook-level globals ``strides``, ``anchors`` and
    ``detections`` and the project helper module ``box_ops``.

    Args:
        preds: Sequence of raw (pre-sigmoid) prediction tensors, one per entry
            of ``strides``/``anchors``. Each tensor is indexed as
            (image, y, x, anchor, channel): channels 0:2 hold the box-centre
            offsets, 2:4 the box sizes, 4 the objectness score and 5: the
            per-class scores.
        image_shapes: Per-image shapes. Index 0 bounds the y coordinates and
            index 1 the x coordinates when boxes are clamped.
        scale_factors: Indexable by image index; the final boxes of image
            ``i`` are divided by ``scale_factors[i]``.
        max_size: Forwarded unchanged as the last argument of
            ``box_ops.batched_nms``.
        score_thresh: Minimum value for both the objectness score and the
            combined objectness * class score.
        nms_thresh: IoU threshold handed to ``box_ops.batched_nms``; also the
            IoU threshold used to pick boxes for merging.
        merge: When True, every box kept by NMS is replaced by the
            score-weighted average of all same-label candidate boxes whose IoU
            with it exceeds ``nms_thresh``. The original author's note says
            this slightly increases AP and decreases speed by ~14%.

    Returns:
        A list with one dict per entry of ``image_shapes``, holding ``boxes``
        (format: xmin, ymin, xmax, ymax), ``labels`` and ``scores``.
    """
    anchors_tens = torch.tensor(anchors)
    ids, ps, boxes = [], [], []
    for pred, stride, anchor_wh in zip(preds, strides, anchors_tens):
        pred = torch.sigmoid(pred)
        # Indices (image, row, column, anchor) of every cell whose objectness
        # score passes the threshold.
        n, y, x, a = torch.where(pred[..., 4] > score_thresh)
        p = pred[n, y, x, a]

        # Box centre: predicted offset plus grid position, scaled by the stride.
        xy = torch.stack((x, y), dim=1)
        xy = (2 * p[:, :2] - 0.5 + xy) * stride

        # The anchor table is built on the CPU. Move it next to the index
        # tensor so that predictions living on another device (e.g. CUDA) can
        # index it; this is a no-op when both are already on the same device.
        anchor_wh = anchor_wh.to(a.device)
        wh = 4 * p[:, 2:4] ** 2 * anchor_wh[a]
        box = torch.cat((xy, wh), dim=1)

        ids.append(n)
        ps.append(p)
        boxes.append(box)

    ids = torch.cat(ids)
    ps = torch.cat(ps)
    boxes = torch.cat(boxes)

    boxes = box_ops.cxcywh2xyxy(boxes)
    # Combined score = objectness * class score; one candidate is emitted for
    # every (box, class) pair above the threshold.
    logits = ps[:, [4]] * ps[:, 5:]
    indices, labels = torch.where(logits > score_thresh)
    ids, boxes, scores = ids[indices], boxes[indices], logits[indices, labels]

    results = []
    for i, im_s in enumerate(image_shapes):
        keep = torch.where(ids == i)[0]
        # Indexing with a tensor copies, so the in-place clamps below do not
        # modify `boxes`.
        box, label, score = boxes[keep], labels[keep], scores[keep]

        if len(box) > 0:
            # Clip x coordinates to [0, im_s[1]] and y coordinates to [0, im_s[0]].
            box[:, 0].clamp_(0, im_s[1])
            box[:, 1].clamp_(0, im_s[0])
            box[:, 2].clamp_(0, im_s[1])
            box[:, 3].clamp_(0, im_s[0])

            keep = box_ops.batched_nms(box, score, label, nms_thresh, max_size)
            keep = keep[:detections]

            nms_box, nms_label = box[keep], label[keep]
            if merge:
                # Only boxes with the same label as the kept box may contribute.
                mask = nms_label[:, None] == label[None]
                iou = (box_ops.box_iou(nms_box, box) * mask) > nms_thresh
                weights = iou * score[None]
                nms_box = torch.mm(weights, box) / weights.sum(1, keepdim=True)

            box, label, score = nms_box / scale_factors[i], nms_label, score[keep]
        results.append(dict(boxes=box, labels=label, scores=score))  # boxes format: (xmin, ymin, xmax, ymax)

    return results

In [18]:
from yolo.model.transform import Transformer
import cv2
from torchvision import transforms

In [19]:
# Build the project's preprocessing transform with min and max size both set
# to 320 and the largest level stride, then switch it to eval mode.
transformer = Transformer(min_size=320, max_size=320, stride=max(strides))
transformer.eval()

image_path = "test_one/000000317863.jpg"
ori_img = cv2.imread(image_path)
# cv2.imread reports failure by returning None instead of raising, so fail
# here with a clear message rather than with an opaque error further down.
if ori_img is None:
    raise OSError(f"cv2.imread could not load image: {image_path}")
resized_img = cv2.resize(ori_img, (320, 320))  # not used in this cell

# The original-resolution image (not `resized_img`) goes through the transformer.
img = transforms.ToTensor()(ori_img)
images, targets, scale_factors, image_shapes = transformer([img], targets=None)
# Largest of the dimensions from index 2 onward of the batched images tensor
# (presumably height and width -- TODO confirm against Transformer).
max_size = max(images.shape[2:])

In [20]:
# Decode the loaded predictions with a 0.3 score threshold and a 0.4 NMS
# threshold (merge keeps its default of True).
res = inference(
    predictions_tensor,
    image_shapes,
    scale_factors,
    max_size,
    score_thresh=0.3,
    nms_thresh=0.4,
)

tensor([[207.5607, 165.4065, 267.2700, 215.4290],
        [194.9681, 180.0877, 264.3637, 214.4550],
        [208.8387, 173.3563, 265.4771, 217.8499],
        [207.2782, 172.2788, 266.4529, 219.1154],
        [182.8741,  47.3400, 278.9813, 196.6384],
        [183.1544,  48.9135, 278.5332, 195.0475],
        [183.4637,  47.9795, 278.3604, 195.7554],
        [183.7036,  48.6794, 279.6385, 197.2514],
        [184.0784,  48.9719, 279.8970, 196.6795],
        [184.4643,  48.1092, 279.6195, 197.4302],
        [184.8304,  48.3282, 282.0228, 197.2558],
        [184.6781,  49.6504, 282.3526, 195.9738],
        [184.9131,  48.3072, 282.2008, 196.6118],
        [182.8175,  57.2802, 281.7211, 197.1939],
        [183.8640,  57.7651, 281.3806, 196.4873],
        [184.1888,  56.7443, 281.2935, 197.4725],
        [196.7213, 169.3411, 267.6656, 215.3331],
        [196.9402, 169.3810, 267.5216, 215.4283],
        [196.6767, 169.0664, 267.5294, 215.3776],
        [187.4988, 179.6867, 268.2648, 218.8331],
