In [71]:
import torch
import numpy as np
from torch.nn import MSELoss, Module
from tqdm import tqdm

In [70]:
def JointMSE(output, target, target_weight=None):
    criterion = MSELoss()
    batch_size = output.shape[0]
    num_joints = output.shape[1]
    heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
    heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
    loss = 0

    for idx in range(num_joints):
        heatmap_pred = heatmaps_pred[idx].squeeze()
        heatmap_gt = heatmaps_gt[idx].squeeze()
        if target_weight is not None:
            loss += 0.5 * criterion(heatmap_pred.mul(target_weight[:, idx]),
                                    heatmap_gt.mul(target_weight[:, idx]))
        else:
            loss += 0.5 * criterion(heatmap_pred, heatmap_gt)
    
    return loss / num_joints

In [42]:
class JointsMSELoss(Module):
    def __init__(self, use_target_weight=False):
        super(JointsMSELoss, self).__init__()
        self.criterion = MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight=None):
        batch_size = output.shape[0]
        num_joints = output.shape[1]
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                if target_weight is None:
                    raise NameError
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints

In [43]:
crit = JointsMSELoss()

In [75]:
mxx = 0
for i in tqdm(range(100)):
    aa = torch.randn(256, 17, 64, 48)
    bb = torch.randn(256, 17, 64, 48)
    mxx = max(mxx, crit(aa, bb))
print(mxx)

100%|███████████████████████████████████████████████████████████████████████████| 100/100 [00:19<00:00,  5.15it/s]

tensor(1.0011)





In [62]:
A = np.zeros((14, 17, 64, 48))
A[0][0][1][1] = 1
A[1][0][32][24] = 1
A[2][0][1][1] = 1
A[3][0][32][24] = 1
A[4][0][1][1] = 1
A[5][0][32][24] = 1
A[6][0][1][1] = 1
A[7][0][32][24] = 1
A[8][0][1][1] = 1
A[9][0][32][24] = 1
A[10][0][1][1] = 1
A[11][0][32][24] = 1
A[12][0][1][1] = 1
A[13][0][32][24] = 1
A = torch.tensor(A, dtype=torch.float32)
B = np.zeros((14, 17, 64, 48))
B[0][0][63][44] = 1
B = torch.tensor(B, dtype=torch.float32)

In [68]:
L = crit(A,B)

In [69]:
print(L)

tensor(1.0258e-05)


In [56]:
L.backward()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn