In [40]:
import torch
from torch import nn

In [41]:
class DistanceLoss(nn.Module):
    """Element-wise distance loss weighted per keypoint.

    Args:
        loss_type (str): ``'L2'`` selects ``nn.MSELoss``, ``'L1'`` selects
            ``nn.L1Loss`` (case-insensitive); any other value falls back to
            ``nn.SmoothL1Loss``.
        reduction (str | None): ``'mean'`` or ``'sum'`` reduce the weighted
            loss to a scalar; ``None`` returns it unreduced.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``, ``'sum'`` or ``None``.
    """

    def __init__(self, loss_type='L2', reduction='mean'):
        super().__init__()
        kind = loss_type.lower()
        if kind == 'l2':
            self.criterion = nn.MSELoss(reduction='none')
        elif kind == 'l1':
            self.criterion = nn.L1Loss(reduction='none')
        else:
            self.criterion = nn.SmoothL1Loss(reduction='none')

        # Explicit exception instead of ``assert``: asserts are stripped under
        # ``python -O`` and an invalid value would then slip through silently.
        if reduction not in ('mean', 'sum', None):
            raise ValueError(f"Error: {reduction=}")
        self.reduction = reduction

    def forward(self, output, target, target_weight):
        """Compute the weighted distance between ``output`` and ``target``.

        Args:
            output (tensor): prediction, e.g. [N, K, H, W] or [N, K, D].
            target (tensor): same shape as ``output``.
            target_weight (tensor): per-keypoint weight, e.g. [N, K, 1].

        Returns:
            tensor: a scalar for ``'mean'``/``'sum'``; otherwise the weighted
            element-wise loss with the shape of ``output``.
        """
        loss = self.criterion(output, target)
        weight = target_weight
        # A weight with fewer dims than the loss (e.g. [N, K, 1] against a
        # [N, K, H, W] heatmap loss) gets one trailing axis, exactly as before.
        # A weight that already matches the loss rank (e.g. [N, K, 1] against
        # a [N, K, D] loss) broadcasts as-is; unconditionally unsqueezing it
        # used to produce a shape error.
        if weight.dim() < loss.dim():
            weight = weight.unsqueeze(-1)
        loss *= weight

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

In [42]:
class KLDiscretLoss(nn.Module):
    """Loss between predicted and target 1-D coordinate vectors (x and y).

    Despite the name, no KL divergence is computed: each coordinate vector is
    compared with Smooth L1. The element-wise loss is averaged over the
    vector length to give one value per (sample, joint). That value is
    multiplied by the joint's target weight before averaging over samples and
    joints, so a zero weight removes that joint of that sample from the loss.
    """

    def __init__(self):
        super().__init__()
        # Unreduced so the per-sample target weights are applied *before*
        # averaging; reducing first would merely scale the batch mean by the
        # average weight instead of masking individual samples.
        self.criterion_ = nn.SmoothL1Loss(reduction='none')

    def criterion(self, dec_outs, labels):
        """Return the Smooth L1 loss averaged over the last (vector) dim.

        Args:
            dec_outs (tensor): predicted vectors, [..., L].
            labels (tensor): target vectors, same shape as ``dec_outs``.

        Returns:
            tensor: shape ``dec_outs.shape[:-1]``.
        """
        return self.criterion_(dec_outs, labels).mean(dim=-1)

    def forward(self, output_x, output_y, target_x, target_y, target_weight):
        """
        Args:
            output_x (tensor): [B, K, Lx] predicted x vectors.
            output_y (tensor): [B, K, Ly] predicted y vectors.
            target_x (tensor): [B, K, Lx] target x vectors.
            target_y (tensor): [B, K, Ly] target y vectors.
            target_weight (tensor): [B, K, 1] or [B, K] per-joint weight.

        Returns:
            tensor: scalar; mean over samples and joints of the weighted x
            loss plus the same for y.
        """
        # [B, K]; reshape rather than squeeze() so a batch of one keeps its
        # batch dimension.
        weight = target_weight.reshape(output_x.shape[:2])
        loss_x = self.criterion(output_x, target_x) * weight
        loss_y = self.criterion(output_y, target_y) * weight
        # Same value as averaging each joint over the batch, summing over
        # joints and dividing by the joint count.
        return loss_x.mean() + loss_y.mean()

In [43]:
class SimDRLoss(nn.Module):
    """Decode heatmaps into SimDR x/y vectors and compute their loss.

    Each joint's heatmap is flattened and fed to two shared linear layers.
    They produce an x vector of length ``int(split_ratio * image_size[0])``
    and a y vector of length ``int(split_ratio * image_size[1])``, which are
    scored against the targets with ``KLDiscretLoss``. The linear layers are
    trainable parameters of this module, so they only learn if this module's
    parameters are given to the optimizer.

    Args:
        cfg: not read. The previously commented-out lookups
            (``cfg.DATASET.image_size``, ``cfg.DATASET.heatmap_size``,
            ``cfg.LOSS.simdr_split_ratio``) show where the keyword arguments
            below are meant to come from.
        image_size (tuple): sizes scaling the x and y vector lengths.
        heatmap_size (tuple): heatmap dims; their product must equal
            ``H * W`` of the heatmaps passed to ``forward``.
        split_ratio (int | float): SimDR split ratio ``k``.
    """

    def __init__(self, cfg=None, *, image_size=(12, 12), heatmap_size=(6, 6),
                 split_ratio=2):
        super().__init__()
        self.simdr_width = int(split_ratio * image_size[0])
        self.simdr_height = int(split_ratio * image_size[1])

        in_features = int(heatmap_size[0] * heatmap_size[1])
        self.x_shared_decoder = nn.Linear(in_features, self.simdr_width)
        self.y_shared_decoder = nn.Linear(in_features, self.simdr_height)
        self.loss = KLDiscretLoss()

    def forward(self, heatmap, simdr_x, simdr_y, target_weight):
        """
        Args:
            heatmap (tensor): [B, K, H, W]
            simdr_x (tensor): [B, K, simdr_width] target x vector
            simdr_y (tensor): [B, K, simdr_height] target y vector
            target_weight (tensor): [B, K, 1]

        Returns:
            tensor: the value returned by ``KLDiscretLoss``.
        """
        features = heatmap.flatten(start_dim=2)  # [B, K, H * W], computed once
        pred_x = self.x_shared_decoder(features)
        pred_y = self.y_shared_decoder(features)
        return self.loss(pred_x, pred_y, simdr_x, simdr_y, target_weight)


In [44]:
# Smoke test of SimDRLoss with its default sizes: 6x6 heatmaps and
# 2 * 12 = 24-long x/y target vectors. Inputs are random and unseeded, so the
# displayed loss value differs between runs.
hm = torch.rand(2, 3, 6, 6)   # heatmaps [B=2, K=3, H=6, W=6]
tx = torch.rand(2, 3, 24)     # target x vectors [B, K, simdr_width]
ty = torch.rand(2, 3, 24)     # target y vectors [B, K, simdr_height]
tw = torch.ones((2, 3, 1))    # per-joint weights [B, K, 1], all ones
a = SimDRLoss()
a(hm, tx, ty, tw)             # last expression: the notebook displays the loss

tensor(0.4933, grad_fn=<DivBackward0>)

In [None]:

class MultiTaskLoss(nn.Module):
    """Multi-task loss with optional learnable (automatic) task weighting.

    Reference for the automatic weighting scheme (MTL):
    https://zhuanlan.zhihu.com/p/367881339

    Only the keypoint term is implemented: ``forward`` returns the weighted
    L2 keypoint loss and does not yet read ``self.p`` or
    ``self.loss_weight[1]``.

    Args:
        cfg: not read; the hard-coded settings mirror the ``cfg.LOSS`` fields
            noted beside them.
    """

    def __init__(self, cfg):
        super().__init__()
        self.criterion = DistanceLoss(loss_type='L2', reduction='mean')
        self.loss_weight = [1., 1.]   # intended source: cfg.LOSS.loss_weight
        self.auto_weight = True       # intended source: cfg.LOSS.auto_weight
        if not self.auto_weight:
            return
        # One learnable weight per task. As an nn.Parameter attribute it is
        # registered on this module and so appears in self.parameters(); it is
        # optimized only if those parameters are handed to the optimizer.
        initial = torch.ones(len(self.loss_weight), requires_grad=True)
        self.p = nn.Parameter(initial, requires_grad=True)

    def forward(self, output, target, target_weight):
        """Compute the weighted keypoint loss.

        ``target`` and ``target_weight`` are moved to ``output``'s device
        before the loss is evaluated.

        Returns:
            tuple: ``(kpt_loss, loss_dict)`` where ``kpt_loss`` is the loss
            tensor and ``loss_dict`` maps ``'kpt_loss'`` to its Python float.
        """
        dev = output.device
        target_on_dev = target.to(dev)
        weight_on_dev = target_weight.to(dev)
        raw = self.criterion(output, target_on_dev, weight_on_dev)
        kpt_loss = raw * self.loss_weight[0]
        return kpt_loss, {'kpt_loss': kpt_loss.item()}