diff --git a/config/diadem_metric.json b/config/diadem_metric.json index 66694ea..341b979 100644 --- a/config/diadem_metric.json +++ b/config/diadem_metric.json @@ -3,9 +3,9 @@ "remove_spur": false, "count_excess_nodes": true, "align_tree_by_root": false, - "list_miss": false, - "list_distant_matches": false, - "list_continuations": false, + "list_miss": true, + "list_distant_matches": true, + "list_continuations": true, "find_proper_root": true, "z_scale": 1, "TRAJECTORY_NONE": -1.0, diff --git a/config/ssd_metric.json b/config/ssd_metric.json index aca4ea2..7a23452 100644 --- a/config/ssd_metric.json +++ b/config/ssd_metric.json @@ -1,6 +1,6 @@ { "threshold_mode": 1, - "ssd_threshold": 2, + "ssd_threshold": 1, "up_sample_threshold": 1.0, "z_scale": 1, "debug": false diff --git a/pyneval/cli/pyneval.py b/pyneval/cli/pyneval.py index caa90a8..2be81f0 100644 --- a/pyneval/cli/pyneval.py +++ b/pyneval/cli/pyneval.py @@ -185,16 +185,6 @@ def set_configs(abs_dir, args): # info: how many trees read print("There are {} test image(s)".format(len(test_swc_trees))) - # argument: output - output_dir = None - if args.output: - output_dir = os.path.join(abs_dir, args.output) - - # argument: detail - detail_dir = None - if args.detail: - detail_dir = os.path.join(abs_dir, args.detail) - # argument: config config_path = args.config if config_path is None: @@ -208,36 +198,46 @@ def set_configs(abs_dir, args): except Exception: raise Exception("[Error: ]Error in analyzing config json file") + # argument: output + output_dir = None + if args.output: + output_dir = os.path.join(abs_dir, args.output) + + # argument: detail + detail_dir = None + if args.detail: + detail_dir = os.path.join(abs_dir, args.detail) + config["detail"] = True + # argument: debug is_debug = args.debug return gold_swc_tree, test_swc_trees, metric, output_dir, detail_dir, config, is_debug -def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir, file_name_extra=""): +def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir): metric_method = get_metric_method(metric) test_swc_name = test_swc_tree.get_name() - gold_swc_name = gold_swc_tree.get_name() - result = metric_method(gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree, config=config) + result, res_gold_swc_tree, res_test_swc_tree = metric_method(gold_swc_tree=gold_swc_tree, + test_swc_tree=test_swc_tree, config=config) print("---------------Result---------------") for key in result: print("{} = {}".format(key.ljust(15, ' '), result[key])) print("----------------End-----------------\n") - if file_name_extra == "reverse": - file_name = gold_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc" - else: - file_name = test_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc" + file_name = test_swc_name[:-4] + "_" + metric + "_" if detail_dir: - swc_save(swc_tree=gold_swc_tree, - out_path=os.path.join(detail_dir, file_name)) + swc_save(swc_tree=res_gold_swc_tree, + out_path=os.path.join(detail_dir, file_name + "recall.swc")) + swc_save(swc_tree=res_test_swc_tree, + out_path=os.path.join(detail_dir, file_name + "precision.swc")) if output_dir: read_json.save_json(data=result, - json_file_path=os.path.join(output_dir, file_name)) + json_file_path=os.path.join(output_dir, file_name + ".json")) # command program @@ -254,9 +254,6 @@ def run(): for test_swc_tree in test_swc_trees: excute_metric(metric=metric, gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree, config=config, detail_dir=detail_dir, output_dir=output_dir) - if metric in ["length_metric", "diadem_metric"]: - excute_metric(metric=metric, gold_swc_tree=test_swc_tree, test_swc_tree=gold_swc_tree, - config=config, detail_dir=detail_dir, output_dir=output_dir, file_name_extra="reverse") if __name__ == "__main__": @@ -278,4 +275,4 @@ def run(): # pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric -# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/ --metric branch_metric --detail ./output +# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/test_fake_data1.swc --metric branch_metric --detail ./output/detail --output ./output/output diff --git a/pyneval/io/read_json.py b/pyneval/io/read_json.py index 48df698..06a2c5f 100644 --- a/pyneval/io/read_json.py +++ b/pyneval/io/read_json.py @@ -26,7 +26,7 @@ def save_json(json_file_path, data, DEBUG=False): raise Exception("[Error: ] \" {} \" is not a json file. Wrong format".format(json_file_path)) try: with open(json_file_path, 'w') as f: - json.dump(data, f) + json.dump(data, f, indent=4) if DEBUG: print(type(data)) except: diff --git a/pyneval/metric/branch_leaf_metric.py b/pyneval/metric/branch_leaf_metric.py index 34ae7d1..d6d8fb4 100644 --- a/pyneval/metric/branch_leaf_metric.py +++ b/pyneval/metric/branch_leaf_metric.py @@ -188,7 +188,7 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config): "pt_cost": branch_result_tuple[7], "iso_node_num": branch_result_tuple[8] } - return branch_result + return branch_result, gold_swc_tree, test_swc_tree if __name__ == "__main__": diff --git a/pyneval/metric/diadem_metric.py b/pyneval/metric/diadem_metric.py index 91cadd4..cf6ed1e 100644 --- a/pyneval/metric/diadem_metric.py +++ b/pyneval/metric/diadem_metric.py @@ -780,24 +780,24 @@ def color_tree_only(): if g_list_miss: if len(g_miss) > 0: for node in g_miss: - # 3 means this node is missed - node.data._type = 2 + # 9 means this node is missed + node.data._type = 9 if len(g_excess_nodes) > 0: for node in g_excess_nodes.keys(): - # 4 means this node is excessive - node.data._type = 3 + # 10 means this node is excessive + node.data._type = 10 if g_list_continuations: if len(g_continuation) > 0: for node in g_continuation: - # 5 means this node is a continuation - node.data._type = 4 + # 11 means this node is a continuation + node.data._type = 11 if g_list_distant_matches: if len(g_distance_match) > 0: for node in g_distance_match: - # 6 means this node is a distant match - node.data._type = 5 + # 12 means this node is a distant match + node.data._type = 12 def print_result(): @@ -821,8 +821,8 @@ def print_result(): print("node_ID = {} poi = {} weight = {}".format( node.data.get_id(), node.data._pos, g_weight_dict[node] )) - # 3 means this node is missed - node.data._type = 2 + # 9 means this node is missed + node.data._type = 9 print("--END--") else: print("---Nodes that are missed:None---") @@ -835,8 +835,8 @@ def print_result(): print("node_ID = {} poi = {} weight = {}".format( node.data.get_id(), node.data._pos, g_excess_nodes[node] )) - # 4 means this node is excessive - node.data._type = 3 + # 10 means this node is excessive + node.data._type = 10 else: print("---extra Nodes in test reconstruction: None---") @@ -848,8 +848,8 @@ def print_result(): print("node_ID = {} poi = {} weight = {}".format( node.data.get_id(), node.data._pos, g_weight_dict[node] )) - # 5 means this node is a continuation - node.data._type = 4 + # 11 means this node is a continuation + node.data._type = 11 else: print("---continuation Nodes None---") @@ -861,8 +861,8 @@ def print_result(): print("node_ID = {} poi = {} weight = {}".format( node.data.get_id(), node.data._pos, g_weight_dict[node] )) - # 6 means this node is a distant match - node.data._type = 5 + # 12 means this node is a distant match + node.data._type = 12 else: print("Distant Matches: none") @@ -935,8 +935,8 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config): """ global g_spur_set global g_weight_dict - gold_swc_tree.type_clear(0) - test_swc_tree.type_clear(1) + gold_swc_tree.set_node_type_by_topo(root_id=1) + test_swc_tree.set_node_type_by_topo(root_id=5) diadem_init() config_init(config) diadam_match_utils.diadem_utils_init(config) @@ -979,6 +979,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config): print('match1 = {}, match2 = {}'.format( key.data.get_id(), g_matches[key].data.get_id() )) + color_tree_only() if debug: for k in g_weight_dict: print("id = {} wt = {}".format(k.data.get_id(), g_weight_dict[k])) @@ -991,7 +992,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config): "score_sum": g_score_sum, "final_score": g_final_score } - return res + return res, gold_swc_tree, test_swc_tree def pyneval_diadem_metric(gold_swc, test_swc, config): @@ -1010,7 +1011,7 @@ def pyneval_diadem_metric(gold_swc, test_swc, config): gold_tree.load_list(read_swc.adjust_swcfile(gold_swc)) test_tree.load_list(read_swc.adjust_swcfile(test_swc)) - diadem_res= diadem_metric(gold_swc_tree=gold_tree, + diadem_res = diadem_metric(gold_swc_tree=gold_tree, test_swc_tree=test_tree, config=config) @@ -1030,8 +1031,8 @@ def pyneval_diadem_metric(gold_swc, test_swc, config): testTree = swc_node.SwcTree() goldTree = swc_node.SwcTree() - goldTree.load("../../data/test_data/topo_metric_data/ExampleGoldStandard.swc") - testTree.load("../../data/test_data/topo_metric_data/ExampleTest.swc") + goldTree.load("../../data/test_data/topo_metric_data/gold_fake_data3.swc") + testTree.load("../../data/test_data/topo_metric_data/test_fake_data3.swc") config_utils.get_default_threshold(goldTree) config = read_json.read_json("../../config/diadem_metric.json") config_schema = read_json.read_json("../../config/schemas/diadem_metric_schema.json") @@ -1041,9 +1042,9 @@ def pyneval_diadem_metric(gold_swc, test_swc, config): except Exception as e: raise Exception("[Error: ]Error in analyzing config json file") - diadem_result = diadem_metric(test_swc_tree=testTree, - gold_swc_tree=goldTree, - config=config) + diadem_result, tmp1, tmp2 = diadem_metric(test_swc_tree=testTree, + gold_swc_tree=goldTree, + config=config) print("matched weight = {}\n" "total weight = {}\n" "diadem score = {}\n". diff --git a/pyneval/metric/length_metric.py b/pyneval/metric/length_metric.py index e68977d..567232a 100644 --- a/pyneval/metric/length_metric.py +++ b/pyneval/metric/length_metric.py @@ -93,6 +93,8 @@ def length_metric(gold_swc_tree, test_swc_tree, config): gold_swc_tree.z_rescale(z_scale) test_swc_tree.z_rescale(z_scale) + gold_swc_tree.set_node_type_by_topo(root_id=1) + test_swc_tree.set_node_type_by_topo(root_id=5) if rad_mode == 1: rad_threshold *= -1 @@ -113,7 +115,7 @@ def length_metric(gold_swc_tree, test_swc_tree, config): "recall": recall, "precision": precision } - return res + return res, gold_swc_tree, test_swc_tree def web_length_metric(gold_swc, test_swc, mode, rad_threshold, len_threshold): diff --git a/pyneval/metric/link_metric.py b/pyneval/metric/link_metric.py index 2f199a3..94785da 100644 --- a/pyneval/metric/link_metric.py +++ b/pyneval/metric/link_metric.py @@ -88,7 +88,7 @@ def link_metric(gold_swc_tree, test_swc_tree, config): "edge_loss": edge_loss, "tree_dis_loss": tree_dis_loss } - return res + return res, None, None if __name__ == "__main__": diff --git a/pyneval/metric/ssd_metric.py b/pyneval/metric/ssd_metric.py index f5335c6..f216b63 100644 --- a/pyneval/metric/ssd_metric.py +++ b/pyneval/metric/ssd_metric.py @@ -119,10 +119,6 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree, t2g_score, t2g_num = get_mse(src_tree=u_test_swc_tree, tar_tree=u_gold_swc_tree, ssd_threshold=ssd_threshold, mode=threshold_mode) - if "detail_path" in config: - swc_writer.swc_save(u_gold_swc_tree, config["detail_path"][:-4] + "_gold_upsampled.swc") - swc_writer.swc_save(u_test_swc_tree, config["detail_path"][:-4] + "_test_upsampled.swc") - if debug: print("recall_num = {}, pre_num = {}, gold_tot_num = {}, test_tot_num = {} {} {}".format( g2t_num, t2g_num, u_gold_swc_tree.size(), u_test_swc_tree.size(), gold_swc_tree.length(), test_swc_tree.length() @@ -133,6 +129,9 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree, "recall": 1 - g2t_num/u_gold_swc_tree.size(), "precision": 1 - t2g_num/u_test_swc_tree.size() } + + if "detail" in config: + return res, u_gold_swc_tree, u_test_swc_tree return res diff --git a/pyneval/metric/volume_metric.py b/pyneval/metric/volume_metric.py index 2459c09..78152dd 100644 --- a/pyneval/metric/volume_metric.py +++ b/pyneval/metric/volume_metric.py @@ -104,7 +104,7 @@ def volume_metric(gold_swc_tree, test_swc_tree, config=None): res = { "recall": recall } - return res + return res, None, None if __name__ == "__main__":