Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions config/diadem_metric.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"remove_spur": false,
"count_excess_nodes": true,
"align_tree_by_root": false,
"list_miss": false,
"list_distant_matches": false,
"list_continuations": false,
"list_miss": true,
"list_distant_matches": true,
"list_continuations": true,
"find_proper_root": true,
"z_scale": 1,
"TRAJECTORY_NONE": -1.0,
Expand Down
2 changes: 1 addition & 1 deletion config/ssd_metric.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"threshold_mode": 1,
"ssd_threshold": 2,
"ssd_threshold": 1,
"up_sample_threshold": 1.0,
"z_scale": 1,
"debug": false
Expand Down
45 changes: 21 additions & 24 deletions pyneval/cli/pyneval.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,6 @@ def set_configs(abs_dir, args):
# info: how many trees read
print("There are {} test image(s)".format(len(test_swc_trees)))

# argument: output
output_dir = None
if args.output:
output_dir = os.path.join(abs_dir, args.output)

# argument: detail
detail_dir = None
if args.detail:
detail_dir = os.path.join(abs_dir, args.detail)

# argument: config
config_path = args.config
if config_path is None:
Expand All @@ -208,36 +198,46 @@ def set_configs(abs_dir, args):
except Exception:
raise Exception("[Error: ]Error in analyzing config json file")

# argument: output
output_dir = None
if args.output:
output_dir = os.path.join(abs_dir, args.output)

# argument: detail
detail_dir = None
if args.detail:
detail_dir = os.path.join(abs_dir, args.detail)
config["detail"] = True

# argument: debug
is_debug = args.debug

return gold_swc_tree, test_swc_trees, metric, output_dir, detail_dir, config, is_debug


def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir, file_name_extra=""):
def excute_metric(metric, gold_swc_tree, test_swc_tree, config, detail_dir, output_dir):
metric_method = get_metric_method(metric)
test_swc_name = test_swc_tree.get_name()
gold_swc_name = gold_swc_tree.get_name()

result = metric_method(gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree, config=config)
result, res_gold_swc_tree, res_test_swc_tree = metric_method(gold_swc_tree=gold_swc_tree,
test_swc_tree=test_swc_tree, config=config)

print("---------------Result---------------")
for key in result:
print("{} = {}".format(key.ljust(15, ' '), result[key]))
print("----------------End-----------------\n")

if file_name_extra == "reverse":
file_name = gold_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc"
else:
file_name = test_swc_name[:-4] + "_" + metric + "_" + file_name_extra + ".swc"
file_name = test_swc_name[:-4] + "_" + metric + "_"

if detail_dir:
swc_save(swc_tree=gold_swc_tree,
out_path=os.path.join(detail_dir, file_name))
swc_save(swc_tree=res_gold_swc_tree,
out_path=os.path.join(detail_dir, file_name + "recall.swc"))
swc_save(swc_tree=res_test_swc_tree,
out_path=os.path.join(detail_dir, file_name + "precision.swc"))

if output_dir:
read_json.save_json(data=result,
json_file_path=os.path.join(output_dir, file_name))
json_file_path=os.path.join(output_dir, file_name + ".json"))


# command program
Expand All @@ -254,9 +254,6 @@ def run():
for test_swc_tree in test_swc_trees:
excute_metric(metric=metric, gold_swc_tree=gold_swc_tree, test_swc_tree=test_swc_tree,
config=config, detail_dir=detail_dir, output_dir=output_dir)
if metric in ["length_metric", "diadem_metric"]:
excute_metric(metric=metric, gold_swc_tree=test_swc_tree, test_swc_tree=gold_swc_tree,
config=config, detail_dir=detail_dir, output_dir=output_dir, file_name_extra="reverse")


if __name__ == "__main__":
Expand All @@ -278,4 +275,4 @@ def run():

# pyneval --gold .\\data\test_data\geo_metric_data\gold_34_23_10.swc --test .\data\test_data\geo_metric_data\test_34_23_10.swc --metric branch_metric

# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/ --metric branch_metric --detail ./output
# pyneval --gold ./data/test_data/geo_metric_data/gold_fake_data1.swc --test ./data/test_data/geo_test/test_fake_data1.swc --metric branch_metric --detail ./output/detail --output ./output/output
2 changes: 1 addition & 1 deletion pyneval/io/read_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def save_json(json_file_path, data, DEBUG=False):
raise Exception("[Error: ] \" {} \" is not a json file. Wrong format".format(json_file_path))
try:
with open(json_file_path, 'w') as f:
json.dump(data, f)
json.dump(data, f, indent=4)
if DEBUG:
print(type(data))
except:
Expand Down
2 changes: 1 addition & 1 deletion pyneval/metric/branch_leaf_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def branch_leaf_metric(gold_swc_tree, test_swc_tree, config):
"pt_cost": branch_result_tuple[7],
"iso_node_num": branch_result_tuple[8]
}
return branch_result
return branch_result, gold_swc_tree, test_swc_tree


if __name__ == "__main__":
Expand Down
51 changes: 26 additions & 25 deletions pyneval/metric/diadem_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,24 +780,24 @@ def color_tree_only():
if g_list_miss:
if len(g_miss) > 0:
for node in g_miss:
# 3 means this node is missed
node.data._type = 2
# 9 means this node is missed
node.data._type = 9
if len(g_excess_nodes) > 0:
for node in g_excess_nodes.keys():
# 4 means this node is excessive
node.data._type = 3
# 10 means this node is excessive
node.data._type = 10

if g_list_continuations:
if len(g_continuation) > 0:
for node in g_continuation:
# 5 means this node is a continuation
node.data._type = 4
# 11 means this node is a continuation
node.data._type = 11

if g_list_distant_matches:
if len(g_distance_match) > 0:
for node in g_distance_match:
# 6 means this node is a distant match
node.data._type = 5
# 12 means this node is a distant match
node.data._type = 12


def print_result():
Expand All @@ -821,8 +821,8 @@ def print_result():
print("node_ID = {} poi = {} weight = {}".format(
node.data.get_id(), node.data._pos, g_weight_dict[node]
))
# 3 means this node is missed
node.data._type = 2
# 9 means this node is missed
node.data._type = 9
print("--END--")
else:
print("---Nodes that are missed:None---")
Expand All @@ -835,8 +835,8 @@ def print_result():
print("node_ID = {} poi = {} weight = {}".format(
node.data.get_id(), node.data._pos, g_excess_nodes[node]
))
# 4 means this node is excessive
node.data._type = 3
# 10 means this node is excessive
node.data._type = 10
else:
print("---extra Nodes in test reconstruction: None---")

Expand All @@ -848,8 +848,8 @@ def print_result():
print("node_ID = {} poi = {} weight = {}".format(
node.data.get_id(), node.data._pos, g_weight_dict[node]
))
# 5 means this node is a continuation
node.data._type = 4
# 11 means this node is a continuation
node.data._type = 11
else:
print("---continuation Nodes None---")

Expand All @@ -861,8 +861,8 @@ def print_result():
print("node_ID = {} poi = {} weight = {}".format(
node.data.get_id(), node.data._pos, g_weight_dict[node]
))
# 6 means this node is a distant match
node.data._type = 5
# 12 means this node is a distant match
node.data._type = 12
else:
print("Distant Matches: none")

Expand Down Expand Up @@ -935,8 +935,8 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
"""
global g_spur_set
global g_weight_dict
gold_swc_tree.type_clear(0)
test_swc_tree.type_clear(1)
gold_swc_tree.set_node_type_by_topo(root_id=1)
test_swc_tree.set_node_type_by_topo(root_id=5)
diadem_init()
config_init(config)
diadam_match_utils.diadem_utils_init(config)
Expand Down Expand Up @@ -979,6 +979,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
print('match1 = {}, match2 = {}'.format(
key.data.get_id(), g_matches[key].data.get_id()
))
color_tree_only()
if debug:
for k in g_weight_dict:
print("id = {} wt = {}".format(k.data.get_id(), g_weight_dict[k]))
Expand All @@ -991,7 +992,7 @@ def diadem_metric(gold_swc_tree, test_swc_tree, config):
"score_sum": g_score_sum,
"final_score": g_final_score
}
return res
return res, gold_swc_tree, test_swc_tree


def pyneval_diadem_metric(gold_swc, test_swc, config):
Expand All @@ -1010,7 +1011,7 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
gold_tree.load_list(read_swc.adjust_swcfile(gold_swc))
test_tree.load_list(read_swc.adjust_swcfile(test_swc))

diadem_res= diadem_metric(gold_swc_tree=gold_tree,
diadem_res = diadem_metric(gold_swc_tree=gold_tree,
test_swc_tree=test_tree,
config=config)

Expand All @@ -1030,8 +1031,8 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
testTree = swc_node.SwcTree()
goldTree = swc_node.SwcTree()

goldTree.load("../../data/test_data/topo_metric_data/ExampleGoldStandard.swc")
testTree.load("../../data/test_data/topo_metric_data/ExampleTest.swc")
goldTree.load("../../data/test_data/topo_metric_data/gold_fake_data3.swc")
testTree.load("../../data/test_data/topo_metric_data/test_fake_data3.swc")
config_utils.get_default_threshold(goldTree)
config = read_json.read_json("../../config/diadem_metric.json")
config_schema = read_json.read_json("../../config/schemas/diadem_metric_schema.json")
Expand All @@ -1041,9 +1042,9 @@ def pyneval_diadem_metric(gold_swc, test_swc, config):
except Exception as e:
raise Exception("[Error: ]Error in analyzing config json file")

diadem_result = diadem_metric(test_swc_tree=testTree,
gold_swc_tree=goldTree,
config=config)
diadem_result, tmp1, tmp2 = diadem_metric(test_swc_tree=testTree,
gold_swc_tree=goldTree,
config=config)
print("matched weight = {}\n"
"total weight = {}\n"
"diadem score = {}\n".
Expand Down
4 changes: 3 additions & 1 deletion pyneval/metric/length_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def length_metric(gold_swc_tree, test_swc_tree, config):

gold_swc_tree.z_rescale(z_scale)
test_swc_tree.z_rescale(z_scale)
gold_swc_tree.set_node_type_by_topo(root_id=1)
test_swc_tree.set_node_type_by_topo(root_id=5)

if rad_mode == 1:
rad_threshold *= -1
Expand All @@ -113,7 +115,7 @@ def length_metric(gold_swc_tree, test_swc_tree, config):
"recall": recall,
"precision": precision
}
return res
return res, gold_swc_tree, test_swc_tree


def web_length_metric(gold_swc, test_swc, mode, rad_threshold, len_threshold):
Expand Down
2 changes: 1 addition & 1 deletion pyneval/metric/link_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def link_metric(gold_swc_tree, test_swc_tree, config):
"edge_loss": edge_loss,
"tree_dis_loss": tree_dis_loss
}
return res
return res, None, None


if __name__ == "__main__":
Expand Down
7 changes: 3 additions & 4 deletions pyneval/metric/ssd_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,6 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree,
t2g_score, t2g_num = get_mse(src_tree=u_test_swc_tree, tar_tree=u_gold_swc_tree,
ssd_threshold=ssd_threshold, mode=threshold_mode)

if "detail_path" in config:
swc_writer.swc_save(u_gold_swc_tree, config["detail_path"][:-4] + "_gold_upsampled.swc")
swc_writer.swc_save(u_test_swc_tree, config["detail_path"][:-4] + "_test_upsampled.swc")

if debug:
print("recall_num = {}, pre_num = {}, gold_tot_num = {}, test_tot_num = {} {} {}".format(
g2t_num, t2g_num, u_gold_swc_tree.size(), u_test_swc_tree.size(), gold_swc_tree.length(), test_swc_tree.length()
Expand All @@ -133,6 +129,9 @@ def ssd_metric(gold_swc_tree: swc_node.SwcTree, test_swc_tree: swc_node.SwcTree,
"recall": 1 - g2t_num/u_gold_swc_tree.size(),
"precision": 1 - t2g_num/u_test_swc_tree.size()
}

if "detail" in config:
return res, u_gold_swc_tree, u_test_swc_tree
return res


Expand Down
2 changes: 1 addition & 1 deletion pyneval/metric/volume_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def volume_metric(gold_swc_tree, test_swc_tree, config=None):
res = {
"recall": recall
}
return res
return res, None, None


if __name__ == "__main__":
Expand Down