From daa2407cd292145038b1750db71efd54c41214ce Mon Sep 17 00:00:00 2001 From: zengwang430521 Date: Wed, 1 Sep 2021 13:17:54 +0800 Subject: [PATCH] fix a bug in the get_accuracy function of interhand_3d_head (#890) --- mmpose/core/evaluation/top_down_eval.py | 2 +- mmpose/models/heads/interhand_3d_head.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mmpose/core/evaluation/top_down_eval.py b/mmpose/core/evaluation/top_down_eval.py index a710809577..b5395808b2 100644 --- a/mmpose/core/evaluation/top_down_eval.py +++ b/mmpose/core/evaluation/top_down_eval.py @@ -672,7 +672,7 @@ def multilabel_classification_accuracy(pred, gt, mask, thr=0.5): pred, gt = pred[valid], gt[valid] if pred.shape[0] == 0: - acc = 0 # when no sample is with gt labels, set acc to 0. + acc = 0.0 # when no sample is with gt labels, set acc to 0. else: # The classification of a sample is regarded as correct # only if it's correct for all labels. diff --git a/mmpose/models/heads/interhand_3d_head.py b/mmpose/models/heads/interhand_3d_head.py index 252963469a..3615c60931 100644 --- a/mmpose/models/heads/interhand_3d_head.py +++ b/mmpose/models/heads/interhand_3d_head.py @@ -372,11 +372,12 @@ def get_accuracy(self, output, target, target_weight): multiple heads. """ accuracy = dict() - accuracy['acc_classification'] = multilabel_classification_accuracy( + avg_acc = multilabel_classification_accuracy( output[2].detach().cpu().numpy(), target[2].detach().cpu().numpy(), target_weight[2].detach().cpu().numpy(), ) + accuracy['acc_classification'] = float(avg_acc) return accuracy def forward(self, x):