Skip to content

Commit

Permalink
fix a bug in the get_accuracy function of interhand_3d_head (open-mml…
Browse files Browse the repository at this point in the history
  • Loading branch information
zengwang430521 committed Sep 1, 2021
1 parent 7500a83 commit daa2407
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mmpose/core/evaluation/top_down_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
pred, gt = pred[valid], gt[valid]

if pred.shape[0] == 0:
acc = 0 # when no sample is with gt labels, set acc to 0.
acc = 0.0 # when no sample is with gt labels, set acc to 0.
else:
# The classification of a sample is regarded as correct
# only if it's correct for all labels.
Expand Down
3 changes: 2 additions & 1 deletion mmpose/models/heads/interhand_3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,12 @@ def get_accuracy(self, output, target, target_weight):
multiple heads.
"""
accuracy = dict()
accuracy['acc_classification'] = multilabel_classification_accuracy(
avg_acc = multilabel_classification_accuracy(
output[2].detach().cpu().numpy(),
target[2].detach().cpu().numpy(),
target_weight[2].detach().cpu().numpy(),
)
accuracy['acc_classification'] = float(avg_acc)
return accuracy

def forward(self, x):
Expand Down

0 comments on commit daa2407

Please sign in to comment.