In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from utils import *
def bbox_iou(bboxes_1, bboxes_2):
    """Return the pairwise intersection-over-union of two box sets.

    Each box is a row laid out as (y1, x1, y2, x2).

    Args:
        bboxes_1: array of shape (N, 4).
        bboxes_2: array of shape (M, 4).

    Returns:
        A float64 array of shape (N, M); entry [i, j] is the IoU between
        bboxes_1[i] and bboxes_2[j].
    """
    # Small constant added to the union so degenerate boxes never divide by zero.
    eps = np.finfo(np.float32).eps
    iou_matrix = np.zeros((bboxes_1.shape[0], bboxes_2.shape[0]))

    # Areas of the second set are the same for every row of the first set.
    areas_2 = (bboxes_2[:, 2] - bboxes_2[:, 0]) * (bboxes_2[:, 3] - bboxes_2[:, 1])

    for row, box in enumerate(bboxes_1):
        # Corners of the overlap rectangle between `box` and every box in bboxes_2.
        top = np.maximum(box[0], bboxes_2[:, 0])
        left = np.maximum(box[1], bboxes_2[:, 1])
        bottom = np.minimum(box[2], bboxes_2[:, 2])
        right = np.minimum(box[3], bboxes_2[:, 3])

        # Negative extents mean "no overlap" and are clamped to zero.
        inter = np.maximum(0.0, bottom - top) * np.maximum(0.0, right - left)

        area_1 = (box[2] - box[0]) * (box[3] - box[1])
        iou_matrix[row] = inter / (area_1 + areas_2 - inter + eps)

    return iou_matrix

def format_loc(anchors, base_anchors):
    """Encode boxes as (dy, dx, dh, dw) offsets relative to anchors.

    Both inputs hold one (y1, x1, y2, x2) box per row and are paired row by row.

    Args:
        anchors: array of shape (N, 4), the reference boxes.
        base_anchors: array of shape (N, 4), the boxes to encode.

    Returns:
        Array of shape (N, 4) with columns (dy, dx, dh, dw): centre shifts
        divided by the anchor size, and log ratios of the box sizes.
    """
    eps = np.finfo(np.float32).eps

    anchor_h = anchors[:, 2] - anchors[:, 0]
    anchor_w = anchors[:, 3] - anchors[:, 1]
    anchor_cy = anchors[:, 0] + anchor_h * 0.5
    anchor_cx = anchors[:, 1] + anchor_w * 0.5

    target_h = base_anchors[:, 2] - base_anchors[:, 0]
    target_w = base_anchors[:, 3] - base_anchors[:, 1]
    target_cy = base_anchors[:, 0] + target_h * 0.5
    target_cx = base_anchors[:, 1] + target_w * 0.5

    # Clamp the anchor sizes used as divisors; the centres above use the raw sizes.
    safe_h = np.maximum(eps, anchor_h)
    safe_w = np.maximum(eps, anchor_w)

    return np.stack(
        (
            (target_cy - anchor_cy) / safe_h,
            (target_cx - anchor_cx) / safe_w,
            np.log(target_h / safe_h),
            np.log(target_w / safe_w),
        ),
        axis=1,
    )


def deformat_loc(anchors, formatted_base_anchor):
    """Decode (dy, dx, dh, dw) offsets back into (y1, x1, y2, x2) boxes.

    This is the inverse of `format_loc` for the same `anchors`.

    Args:
        anchors: array of shape (N, 4), the reference boxes as (y1, x1, y2, x2).
        formatted_base_anchor: array of shape (N, 4) with columns (dy, dx, dh, dw).

    Returns:
        Array with the same shape and dtype as `anchors` holding the decoded boxes.
    """
    anchor_h = anchors[:, 2] - anchors[:, 0]
    anchor_w = anchors[:, 3] - anchors[:, 1]
    anchor_cy = anchors[:, 0] + anchor_h * 0.5
    anchor_cx = anchors[:, 1] + anchor_w * 0.5

    dy, dx, dh, dw = formatted_base_anchor.T

    # Decoded centre and half-extent along each axis.
    decoded_cy = dy * anchor_h + anchor_cy
    decoded_cx = dx * anchor_w + anchor_cx
    half_h = np.exp(dh) * anchor_h * 0.5
    half_w = np.exp(dw) * anchor_w * 0.5

    decoded = np.zeros_like(anchors)
    decoded[:, 0] = decoded_cy - half_h
    decoded[:, 1] = decoded_cx - half_w
    decoded[:, 2] = decoded_cy + half_h
    decoded[:, 3] = decoded_cx + half_w
    return decoded


def nms(rois, scores, nms_thresh):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every other
    remaining box whose IoU with it is greater than `nms_thresh`.

    Args:
        rois: array of shape (N, 4), boxes as (y1, x1, y2, x2).
        scores: 1-D array of N scores, one per box.
        nms_thresh: boxes with IoU <= this value against the kept box survive.

    Returns:
        List of indices into `rois`, in descending score order, of the kept boxes.
    """
    # Candidate indices, best score first.
    order = scores.argsort()[::-1]
    keep_index = []

    while order.size > 0:
        best = order[0]
        keep_index.append(best)

        remaining = order[1:]
        # IoU of the kept box against every remaining candidate (row 0 of a 1 x K matrix).
        overlaps = bbox_iou(rois[best][np.newaxis, :], rois[remaining])[0]
        order = remaining[overlaps <= nms_thresh]

    return keep_index


In [4]:
# Dummy input: a single all-zero 3-channel image of 800 x 800 pixels.
image_size = (800, 800)
image = torch.zeros(1, 3, *image_size, dtype=torch.float32)

# Ground-truth boxes, one per row, laid out as (y1, x1, y2, x2), with their class ids.
bbox = torch.tensor(
    [[20, 30, 400, 500],
     [300, 400, 500, 600]],
    dtype=torch.float32,
)
labels = torch.tensor([6, 8], dtype=torch.long)

# Pixel spacing between neighbouring anchor centres (used by the anchor cells below).
sub_sample = 16

# Backbone: the first 30 modules of the pretrained VGG16 feature extractor.
vgg16 = torchvision.models.vgg16(pretrained=True)
req_features = vgg16.features[:30]
output_map = req_features(image)
print(output_map.shape)

torch.Size([1, 512, 50, 50])


In [5]:
# Anchor sizes (in feature-map cells, multiplied by sub_sample below) and aspect ratios.
anchor_scale = [8, 16, 32]
ratio = [0.5, 1, 2] # H/W

len_anchor_scale = len(anchor_scale)
len_ratio = len(ratio)
len_anchor_template = len_anchor_scale * len_ratio
# One row per (scale, ratio) pair; sized from the lists above instead of a literal 9
# so that editing anchor_scale or ratio cannot silently overflow the template.
anchor_template = np.zeros((len_anchor_template, 4))

for idx, scale in enumerate(anchor_scale):
    # h / w == ratio and h * w == (scale * sub_sample) ** 2 for every ratio.
    h = scale * np.sqrt(ratio) * sub_sample
    w = scale / np.sqrt(ratio) * sub_sample
    y1 = -h/2
    x1 = -w/2
    y2 = h/2
    x2 = w/2
    # Rows belonging to this scale: one per aspect ratio.
    rows = slice(idx*len_ratio, (idx+1)*len_ratio)
    anchor_template[rows, 0] = y1
    anchor_template[rows, 1] = x1
    anchor_template[rows, 2] = y2
    anchor_template[rows, 3] = x2

print(anchor_template)


[[ -45.254834    -90.50966799   45.254834     90.50966799]
 [ -64.          -64.           64.           64.        ]
 [ -90.50966799  -45.254834     90.50966799   45.254834  ]
 [ -90.50966799 -181.01933598   90.50966799  181.01933598]
 [-128.         -128.          128.          128.        ]
 [-181.01933598  -90.50966799  181.01933598   90.50966799]
 [-181.01933598 -362.03867197  181.01933598  362.03867197]
 [-256.         -256.          256.          256.        ]
 [-362.03867197 -181.01933598  362.03867197  181.01933598]]


In [6]:
feature_map_size = (50, 50)
# Anchor centres in image pixels: the first centre is at (8, 8) and they are 16 px apart.
ctr_y = np.arange(8, 800, 16)
ctr_x = np.arange(8, 800, 16)

# ctr[i, j] holds the (y, x) centre of feature-map cell (i, j).
ctr = np.zeros((*feature_map_size, 2))
ctr[:, :, 0] = ctr_y[:, np.newaxis]  # y varies along axis 0
ctr[:, :, 1] = ctr_x[np.newaxis, :]  # x varies along axis 1
print(ctr.shape)


(50, 50, 2)


In [7]:
# Place every template anchor at every centre with a single broadcast add:
#   ctr              (H, W, 1, 1, 2)  -> (y, x) centre of each cell
#   anchor_template  (A, 2, 2)        -> [[y1, x1], [y2, x2]] per template anchor
# The sum has shape (H, W, A, 2, 2) and is flattened to one (y1, x1, y2, x2) row per anchor.
# This replaces the per-cell Python loop and the hard-coded anchor count of 9.
anchors = (ctr[:, :, np.newaxis, np.newaxis, :] + anchor_template.reshape(-1, 2, 2)).reshape(-1, 4)
print(anchors.shape) # (22500, 4)

(22500, 4)


In [None]:
# Indices of anchors that lie entirely inside the image.
# Bounds come from image_size (height, width) rather than a literal 800, matching
# how the proposal-clipping cell further down treats the y and x columns.
valid_index = np.where((anchors[:, 0] >= 0)
                      &(anchors[:, 1] >= 0)
                      &(anchors[:, 2] <= image_size[0])
                      &(anchors[:, 3] <= image_size[1]))[0]
print(valid_index.shape) # 8940

In [None]:
# The anchors that passed the inside-image filter.
valid_anchors = anchors[valid_index]

# One label per valid anchor, initialised to -1 ("ignore").
valid_labels = np.full((valid_index.shape[0],), -1, dtype=np.int32)

print(valid_anchors.shape) # (8940,4)
print(bbox.shape) # torch.Size([2,4])




In [9]:
# IoU of every valid anchor against every ground-truth box: (8940 anchors, 2 boxes).
ious = bbox_iou(valid_anchors, bbox.numpy())

pos_iou_thres = 0.7
neg_iou_thred = 0.3

# Scenario A - as in the paper: an anchor whose best IoU is >= 0.7 is positive,
# one whose best IoU is below 0.3 is negative, and the rest are ignored.
anchor_max_iou = ious.max(axis=1)
pos_iou_anchor_label = np.flatnonzero(anchor_max_iou >= pos_iou_thres)
neg_iou_anchor_label = np.flatnonzero(anchor_max_iou < neg_iou_thred)
valid_labels[pos_iou_anchor_label] = 1
valid_labels[neg_iou_anchor_label] = 0

# Scenario B - for each ground-truth box, every anchor that attains that box's
# highest IoU is positive (this overrides a label assigned in scenario A).
gt_max_iou = ious.max(axis=0)
gt_max_iou_anchor_label = np.nonzero(ious == gt_max_iou)[0]
print(gt_max_iou_anchor_label)
valid_labels[gt_max_iou_anchor_label] = 1


NameError: name 'bbox_iou' is not defined

In [None]:
# Sub-sample the labelled anchors down to a mini-batch of n_sample_anchors,
# with at most a pos_ratio share of positives; dropped anchors are set to -1 (ignore).
# NOTE: np.random is not seeded in this cell, so the selection changes between runs
# unless a seed is set elsewhere.
n_sample_anchors = 256
pos_ratio = 0.5

# Cast to int: n_sample_anchors * pos_ratio is a float, and np.random.choice
# rejects a non-integer `size`.
max_pos_sample = int(n_sample_anchors * pos_ratio)

pos_index = np.where(valid_labels == 1)[0]
total_n_pos = len(pos_index)
n_pos_sample = min(total_n_pos, max_pos_sample)
n_neg_sample = n_sample_anchors - n_pos_sample

# Too many positives: randomly ignore the surplus.
if total_n_pos > n_pos_sample:
    disable_index = np.random.choice(pos_index, size=total_n_pos - n_pos_sample, replace=False)
    valid_labels[disable_index] = -1

# Too many negatives: randomly ignore the surplus. Guarded so that having fewer
# negatives than requested does not ask np.random.choice for a negative size.
neg_index = np.where(valid_labels == 0)[0]
if len(neg_index) > n_neg_sample:
    disable_index = np.random.choice(neg_index, size=len(neg_index) - n_neg_sample, replace=False)
    valid_labels[disable_index] = -1

In [None]:

# For every valid anchor, pick the ground-truth box it overlaps most.
argmax_iou = ious.argmax(axis=1)
max_iou_box = bbox.numpy()[argmax_iou]
print(max_iou_box.shape) # 8940, 4
print(valid_anchors.shape) # 8940, 4

# Regression targets: each matched ground-truth box encoded relative to its anchor.
anchor_loc_format_target = format_loc(anchors=valid_anchors, base_anchors=max_iou_box)
print(anchor_loc_format_target.shape) # 8940, 4

In [None]:
# Scatter the per-valid-anchor targets back onto the full anchor list.
# Anchors outside the image keep label -1 (ignore) and an all-zero location target.
anchor_target_labels = np.full((len(anchors),), -1, dtype=np.int32)
anchor_target_labels[valid_index] = valid_labels

anchor_target_format_locations = np.zeros((len(anchors), 4), dtype=np.float32)
anchor_target_format_locations[valid_index] = anchor_loc_format_target

print(anchor_target_labels.shape) # 22500,
print(anchor_target_format_locations.shape) # 22500, 4

# RPN code

In [None]:
in_channel = 512
mid_channel = 512
n_anchor = 9

# 3x3 convolution with stride 1 and padding 1, so the spatial size is preserved.
conv1 = nn.Conv2d(in_channel, mid_channel, kernel_size=3, stride=1, padding=1)
# 1x1 heads: 4 box offsets and 2 class scores per anchor at every location.
reg_layer = nn.Conv2d(mid_channel, n_anchor*4, kernel_size=1, stride=1, padding=0)
cls_layer = nn.Conv2d(mid_channel, n_anchor*2, kernel_size=1, stride=1, padding=0)

x = conv1(output_map)
anchor_pred_format_locations = reg_layer(x)
anchor_pred_scores = cls_layer(x)

print(anchor_pred_format_locations.shape) # torch.Size([1, 36, 50, 50])
print(anchor_pred_scores.shape) # torch.Size([1, 18, 50, 50])


In [None]:
# Move channels last and regroup so that each row describes one anchor:
#   locations (1, 36, H, W) -> (1, H, W, 36) -> (1, H*W*9, 4)
#   scores    (1, 18, H, W) -> (1, H, W, 18) -> (1, H*W*9, 2)
# NOTE: both names are rebound to the reshaped 3-D tensors, so re-running this cell
# without first re-running the RPN head cell fails (permute expects 4 dimensions).
locs_channels_last = anchor_pred_format_locations.permute(0, 2, 3, 1).contiguous()
anchor_pred_format_locations = locs_channels_last.view(1, -1, 4)

scores_channels_last = anchor_pred_scores.permute(0, 2, 3, 1).contiguous()
anchor_pred_scores = scores_channels_last.view(1, -1, 2)

# Score at class index 1, the index used for positive anchors in the target labels.
objectness_pred_scores = anchor_pred_scores[..., 1]

In [None]:
# Sanity-check that targets and predictions line up anchor by anchor.
print(anchor_target_labels.shape)
print(anchor_target_format_locations.shape)
print(anchor_pred_scores.shape)
print(anchor_pred_format_locations.shape)

# Predictions for the single image in the batch.
rpn_scores = anchor_pred_scores[0]
rpn_format_locs = anchor_pred_format_locations[0]

# Targets as tensors (torch.from_numpy shares memory with the NumPy arrays).
gt_rpn_scores = torch.from_numpy(anchor_target_labels)
gt_rpn_format_locs = torch.from_numpy(anchor_target_format_locations)

In [None]:
# Classification loss over the sampled anchors; anchors labelled -1 are ignored.
rpn_cls_loss = F.cross_entropy(rpn_scores, gt_rpn_scores.long(), ignore_index=-1)
print(rpn_cls_loss)


####### location loss
# Only positive anchors (label 1) contribute to the box regression loss.
mask = gt_rpn_scores > 0
mask_target_format_locs = gt_rpn_format_locs[mask]
mask_pred_format_locs = rpn_format_locs[mask]

print(mask_target_format_locs.shape)
print(mask_pred_format_locs.shape)

x = torch.abs(mask_target_format_locs - mask_pred_format_locs)
# Smooth L1 with the switch point at |x| = 1: 0.5 * x^2 below it, |x| - 0.5 at or above it.
# The two branches meet at x = 1 (both give 0.5), so the loss is continuous; a switch
# point of 0.5 with strict comparisons on both sides made it jump and return 0 at x == 0.5.
rpn_loc_loss = torch.where(x < 1, 0.5 * x ** 2, x - 0.5).sum()
print(rpn_loc_loss)

rpn_lambda = 10
# Number of positive anchors; clamped to 1 so a batch without positives does not divide by zero.
N_reg = mask.float().sum().clamp(min=1)

rpn_loss = rpn_cls_loss + rpn_lambda / N_reg * rpn_loc_loss
print(rpn_loss)

In [None]:
# Proposal-layer settings.
nms_thresh = 0.7          # IoU above which a lower-scoring proposal is suppressed
n_train_pre_nms = 12000   # proposals kept before NMS (training)
n_train_post_nms = 2000   # proposals kept after NMS (training)
n_test_pre_nms = 6000     # proposals kept before NMS (testing)
n_test_post_nms = 300     # proposals kept after NMS (testing)
min_size = 16             # proposals must be larger than this in both height and width


print(anchors.shape) # 22500, 4
print(anchor_pred_format_locations.shape) # 1, 22500, 4

# Decode the predicted offsets of the single image in the batch into (y1, x1, y2, x2) boxes.
pred_offsets = anchor_pred_format_locations[0].detach().numpy()
rois = deformat_loc(anchors=anchors, formatted_base_anchor=pred_offsets)
print(rois.shape) # 22500, 4

print(rois)

In [None]:
# Clip the proposals to the image: columns 0 and 2 are y, columns 1 and 3 are x.
y_cols = slice(0, 4, 2)
x_cols = slice(1, 4, 2)
rois[:, y_cols] = rois[:, y_cols].clip(0, image_size[0])
rois[:, x_cols] = rois[:, x_cols].clip(0, image_size[1])
print(rois)

h = rois[:, 2] - rois[:, 0]
w = rois[:, 3] - rois[:, 1]

# Drop proposals that are not larger than min_size in both dimensions.
# NOTE: this rebinds `valid_index`, which earlier held the inside-image anchor indices.
valid_index = np.flatnonzero((h > min_size) & (w > min_size))
valid_rois = rois[valid_index]
valid_scores = objectness_pred_scores[0][valid_index].detach().numpy()

In [None]:
# Rank the surviving proposals by objectness score, best first.
valid_score_order = np.argsort(valid_scores.ravel())[::-1]

# Keep at most n_train_pre_nms of the best-scoring proposals for NMS.
pre_train_valid_score_order = valid_score_order[:n_train_pre_nms]
pre_train_valid_rois = valid_rois[pre_train_valid_score_order]
pre_train_valid_scores = valid_scores[pre_train_valid_score_order]

print(pre_train_valid_rois.shape) # up to 12000, 4
print(pre_train_valid_scores.shape) # up to 12000,
print(pre_train_valid_score_order.shape) # up to 12000,

In [None]:
# Suppress overlapping proposals, then keep at most n_train_post_nms of the survivors
# (nms returns indices in descending score order).
keep_index = nms(rois=pre_train_valid_rois, scores=pre_train_valid_scores, nms_thresh=nms_thresh)
top_keep_index = keep_index[:n_train_post_nms]
post_train_valid_rois = pre_train_valid_rois[top_keep_index]
post_train_valid_scores = pre_train_valid_scores[top_keep_index]
print(post_train_valid_rois.shape) # up to 2000, 4
print(post_train_valid_scores.shape) # up to 2000,