Skip to content

Commit

Permalink
Merge pull request open-mmlab#2 from xiaohu2015/yolox
Browse files Browse the repository at this point in the history
add train loss for yolox
  • Loading branch information
hhaAndroid authored Jul 22, 2021
2 parents e2612a3 + d90150c commit 20783b6
Showing 1 changed file with 145 additions and 71 deletions.
216 changes: 145 additions & 71 deletions mmdet/models/dense_heads/yolox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,75 @@ def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):

return output

class IOUloss(nn.Module):
def __init__(self, reduction="none", loss_type="iou"):
super(IOUloss, self).__init__()
self.reduction = reduction
self.loss_type = loss_type

def forward(self, pred, target):
assert pred.shape[0] == target.shape[0]

pred = pred.view(-1, 4)
target = target.view(-1, 4)
tl = torch.max(
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
)
br = torch.min(
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
)

area_p = torch.prod(pred[:, 2:], 1)
area_g = torch.prod(target[:, 2:], 1)

en = (tl < br).type(tl.type()).prod(dim=1)
area_i = torch.prod(br - tl, 1) * en
iou = (area_i) / (area_p + area_g - area_i + 1e-16)

if self.loss_type == "iou":
loss = 1 - iou ** 2
elif self.loss_type == "giou":
c_tl = torch.min(
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
)
c_br = torch.max(
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
)
area_c = torch.prod(c_br - c_tl, 1)
giou = iou - (area_c - area_i) / area_c.clamp(1e-16)
loss = 1 - giou.clamp(min=-1.0, max=1.0)

if self.reduction == "mean":
loss = loss.mean()
elif self.reduction == "sum":
loss = loss.sum()

return loss

def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError

if xyxy:
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
)

area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
return area_i / (area_a[:, None] + area_b - area_i)

@HEADS.register_module()
class YOLOXHead(BaseDenseHead, BBoxTestMixin):
Expand Down Expand Up @@ -164,8 +233,7 @@ def __init__(
self.use_l1 = False
self.l1_loss = nn.L1Loss(reduction="none")
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
# self.iou_loss = IOUloss(reduction="none")
self.iou_loss = None
self.iou_loss = IOUloss(reduction="none")
self.strides = strides
self.grids = [torch.zeros(1)] * len(in_channels)
self.expanded_strides = [None] * len(in_channels)
Expand Down Expand Up @@ -207,7 +275,7 @@ def forward(self, feats):
reg_output = self.reg_preds[k](reg_feat)
obj_output = self.obj_preds[k](reg_feat)

output = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
output = [reg_output, obj_output, cls_output]

outputs.append(output)

Expand All @@ -221,6 +289,8 @@ def get_bboxes(self,
rescale=False,
with_nms=True):
cfg = self.test_cfg if cfg is None else cfg
outputs = [torch.cat([reg_out, obj_out.sigmoid(), cls_out.sigmoid()], 1) \
for reg_out, obj_out, cls_out in outputs]
self.hw = [x.shape[-2:] for x in outputs]
# [batch, n_anchors_all, 85]
outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
Expand All @@ -246,65 +316,71 @@ def get_bboxes(self,

imgs_det.append(tuple([mlvl_bboxes, mlvl_label]))
return imgs_det

def loss(self,
preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
preds (list[list[Tensor]]): level predictions (reg_output, obj_output, cls_output)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""

def loss(self, xin, labels=None, imgs=None):
outputs = []
origin_preds = []
x_shifts = []
y_shifts = []
expanded_strides = []

for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
zip(self.cls_convs, self.reg_convs, self.strides, xin)
):
x = self.stems[k](x)
cls_x = x
reg_x = x

cls_feat = cls_conv(cls_x)
cls_output = self.cls_preds[k](cls_feat)

reg_feat = reg_conv(reg_x)
reg_output = self.reg_preds[k](reg_feat)
obj_output = self.obj_preds[k](reg_feat)

if self.training:
output = torch.cat([reg_output, obj_output, cls_output], 1)
output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
x_shifts.append(grid[:, :, 0])
y_shifts.append(grid[:, :, 1])
expanded_strides.append(
torch.zeros(1, grid.shape[1]).fill_(stride_this_level).type_as(xin[0])
for k, (out, stride_this_level) in enumerate(zip(preds, self.strides)):
reg_output, obj_output, cls_output = out
output = torch.cat([reg_output, obj_output, cls_output], 1)
output, grid = self.get_output_and_grid(output, k, stride_this_level, reg_output.type())
# output [N, A*H*W, 5+K] (x, y, w, h, obj, classes...)
x_shifts.append(grid[:, :, 0])
y_shifts.append(grid[:, :, 1])
expanded_strides.append(
torch.zeros(1, grid.shape[1]).fill_(stride_this_level).type_as(reg_output.type())
)
if self.use_l1:
batch_size = reg_output.shape[0]
hsize, wsize = reg_output.shape[-2:]
reg_output = reg_output.view(batch_size, self.n_anchors, 4, hsize, wsize)
reg_output = (
reg_output.permute(0, 1, 3, 4, 2)
.reshape(batch_size, -1, 4)
)
if self.use_l1:
batch_size = reg_output.shape[0]
hsize, wsize = reg_output.shape[-2:]
reg_output = reg_output.view(batch_size, self.n_anchors, 4, hsize, wsize)
reg_output = (
reg_output.permute(0, 1, 3, 4, 2)
.reshape(batch_size, -1, 4)
)
origin_preds.append(reg_output.clone())

else:
output = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
origin_preds.append(reg_output.clone())

outputs.append(output)

if self.training:
return self.get_losses(
imgs, x_shifts, y_shifts, expanded_strides, labels,
torch.cat(outputs, 1), origin_preds, dtype=xin[0].dtype
)
else:
self.hw = [x.shape[-2:] for x in outputs]
# [batch, n_anchors_all, 85]
outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
if self.decode_in_inference:
return self.decode_outputs(outputs, dtype=xin[0].type())
else:
return outputs
loss_iou, loss_obj, loss_cls, loss_l1 = self.get_losses(
imgs, x_shifts, y_shifts, expanded_strides, gt_labels, gt_boxes,
torch.cat(outputs, 1), origin_preds, dtype=origin_preds[0].dtype)

if self.use_l1:
return dict(
loss_cls=loss_cls,
loss_iou=loss_iou,
loss_obj=loss_obj,
loss_l1=loss_l1)
else:
return dict(
loss_cls=loss_cls,
loss_iou=loss_iou,
loss_obj=loss_obj)

def get_output_and_grid(self, output, k, stride, dtype):
grid = self.grids[k]

Expand Down Expand Up @@ -343,27 +419,18 @@ def decode_outputs(self, outputs, dtype):
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
return outputs

def get_losses(
self, imgs, x_shifts, y_shifts, expanded_strides, labels, outputs, origin_preds, dtype,
):
def get_losses(self, imgs, x_shifts, y_shifts, expanded_strides, gt_labels, gt_boxes,
outputs, origin_preds, dtype):
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]

# calculate targets
mixup = labels.shape[2] > 5
if mixup:
label_cut = labels[..., :5]
else:
label_cut = labels
nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1) # number of objects

total_num_anchors = outputs.shape[1]
x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all]
y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all]
expanded_strides = torch.cat(expanded_strides, 1)
expanded_strides = torch.cat(expanded_strides, 1) # [1, n_anchors_all]
if self.use_l1:
origin_preds = torch.cat(origin_preds, 1)
origin_preds = torch.cat(origin_preds, 1) #[N, n_anchors_all, 4]

cls_targets = []
reg_targets = []
Expand All @@ -375,7 +442,7 @@ def get_losses(
num_gts = 0.0

for batch_idx in range(outputs.shape[0]):
num_gt = int(nlabel[batch_idx])
num_gt = len(gt_labels[batch_idx])
num_gts += num_gt
if num_gt == 0:
cls_target = outputs.new_zeros((0, self.num_classes))
Expand All @@ -384,16 +451,24 @@ def get_losses(
obj_target = outputs.new_zeros((total_num_anchors, 1))
fg_mask = outputs.new_zeros(total_num_anchors).bool()
else:
gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]
gt_classes = labels[batch_idx, :num_gt, 0]
gt_bboxes_per_image = gt_bboxes[batch_idx]
# convert x1,y1,x2,y2 to xywh
gt_bboxes_per_image = torch.stack(
[(gt_bboxes_per_image[:, 0] + gt_bboxes_per_image[:, 2]) * 0.5,
(gt_bboxes_per_image[:, 1] + gt_bboxes_per_image[:, 3]) * 0.5,
(gt_bboxes_per_image[:, 2] - gt_bboxes_per_image[:, 0]),
(gt_bboxes_per_image[:, 3] - gt_bboxes_per_image[:, 1])
], dim=1
)
gt_classes = gt_labels[batch_idx]
bboxes_preds_per_image = bbox_preds[batch_idx]

try:
gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
# noqa
batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
cls_preds, bbox_preds, obj_preds, labels, imgs,
cls_preds, bbox_preds, obj_preds, imgs,
)
except RuntimeError:
logger.error(
Expand All @@ -406,7 +481,7 @@ def get_losses(
# noqa
batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
cls_preds, bbox_preds, obj_preds, labels, imgs, "cpu",
cls_preds, bbox_preds, obj_preds, imgs, "cpu",
)

torch.cuda.empty_cache()
Expand Down Expand Up @@ -452,9 +527,8 @@ def get_losses(
loss_l1 = 0.0

reg_weight = 5.0
loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

return loss, reg_weight * loss_iou, loss_obj, loss_cls, loss_l1, num_fg / max(num_gts, 1)
return reg_weight * loss_iou, loss_obj, loss_cls, loss_l1

def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
l1_target[:, 0] = gt[:, 0] / stride - x_shifts
Expand All @@ -467,7 +541,7 @@ def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
def get_assignments(
self, batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,
bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts,
cls_preds, bbox_preds, obj_preds, labels, imgs, mode="gpu",
cls_preds, bbox_preds, obj_preds, imgs, mode="gpu",
):

if mode == "cpu":
Expand Down

0 comments on commit 20783b6

Please sign in to comment.