diff --git a/src/detector.c b/src/detector.c index fc560f8048f..ba4d809b3a6 100644 --- a/src/detector.c +++ b/src/detector.c @@ -556,7 +556,7 @@ int detections_comparator(const void *pa, const void *pb) return 0; } -void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou) +void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh) { int j; list *options = read_data_cfg(datacfg); @@ -597,7 +597,7 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float const float thresh = .005; const float nms = .45; - const float iou_thresh = 0.5; + //const float iou_thresh = 0.5; int nthreads = 4; image *val = calloc(nthreads, sizeof(image)); @@ -876,7 +876,12 @@ void validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100); mean_average_precision = mean_average_precision / classes; - printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision*100); + if (iou_thresh == 0.5) { + printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision * 100); + } + else { + printf("\n average precision (AP) = %f, or %2.2f %% for IoU threshold = %f \n", mean_average_precision, mean_average_precision * 100, iou_thresh); + } for (i = 0; i < classes; ++i) { @@ -1235,6 +1240,7 @@ void run_detector(int argc, char **argv) char *outfile = find_char_arg(argc, argv, "-out", 0); char *prefix = find_char_arg(argc, argv, "-prefix", 0); float thresh = find_float_arg(argc, argv, "-thresh", .25); // 0.24 + float iou_thresh = find_float_arg(argc, argv, "-iou_thresh", .5); // 0.5 for mAP float hier_thresh = find_float_arg(argc, argv, "-hier", .5); int cam_index = find_int_arg(argc, argv, "-c", 0); int frame_skip = find_int_arg(argc, argv, "-s", 0); @@ -1285,7 +1291,7 @@ void run_detector(int argc, char **argv) else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show); else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if(0==strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); - else if(0==strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh); + else if(0==strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh); else if(0==strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show); else if(0==strcmp(argv[2], "demo")) { list *options = read_data_cfg(datacfg);