Skip to content

Commit

Permalink
Fixed bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
NieXC committed Aug 22, 2018
1 parent 9a7bb58 commit 1121ab6
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')

parser.add_argument('--evaluate', default=False, type=bool, metavar='BOOL', help='evaluate or train')
parser.add_argument('--calc-map', default=False, type=bool, metavar='BOOL', help='Calculate mAP or not')
parser.add_argument('--pred-path', default='exps/preds/mat_results/pred_keypoints_mpii_multi.mat', type=str, metavar='PATH', help='path to save the predction results in .mat format')
parser.add_argument('--visualization', default=False, type=bool, metavar='BOOL', help='visualize prediction or not')
parser.add_argument('--vis-dir', default='exps/preds/vis_results', metavar='DIR', help='path to save visualization results')
Expand Down Expand Up @@ -142,7 +143,7 @@ def main():
visualization=args.visualization, \
vis_result_dir=args.vis_dir, \
pred_path=args.pred_path, \
phase='VAL')
is_calc_map=args.calc_map)
return

for epoch in range(args.start_epoch, args.epochs):
Expand Down Expand Up @@ -181,7 +182,7 @@ def main():
visualization=args.visualization, \
vis_result_dir=args.vis_dir, \
pred_path=args.pred_path, \
phase='VAL')
is_calc_map=True)

is_best = map_avg > best_map
best_map = max(map_avg, best_map)
Expand Down Expand Up @@ -275,7 +276,7 @@ def evaluate(model, \
vis_result_dir='preds/vis_results', \
gt_path='dataset/mpi/val_gt/mpi_val_groundtruth.mat', \
pred_path='exps/preds/mat_results/pred_keypoints_mpii_multi.mat', \
phase='VAL'):
is_calc_map=True):

model.eval()
mp_pose_list = eval_util.multi_image_testing_on_mpi_mp_dataset(model, \
Expand All @@ -297,9 +298,9 @@ def evaluate(model, \
vis_result_dir=vis_result_dir)

eval_util.save_mppe_results_to_mpi_format(mp_pose_list, save_path=pred_path)

map_avg = 0.0
if phase == 'VAL':
if is_calc_map:
map_all = calc_mAP(gt_path=gt_path, pred_path=pred_path)
map_avg = map_all[-1]
map_all_list.append(map_all)
Expand Down

0 comments on commit 1121ab6

Please sign in to comment.