Experiment: fixing split errors within a block using skeletons (so far only candidate split detection is implemented)

In [1]:
from cerebellum.data_prep.seg_prep import *
from cerebellum.skeletonize import gen_skeletons
from cerebellum.ibex.utilities.dataIO import ReadSkeletons
from cerebellum.error_analysis.skel_segeval import SkelEval
from cerebellum.error_analysis.voxel_segeval import VoxEval

# --- Configuration -----------------------------------------------------------
resolution = (30, 48, 48)          # segmentation voxel size (z, y, x), nm
output_resolution = (80, 80, 80)   # skeleton (downsampled) resolution
affinity_offset = 14               # affinity offset along the z-axis
block_size = 60                    # block thickness in z-slices
wz_thresh = 0.5                    # waterz threshold (used in block names)

# Skeleton error-analysis thresholds, passed to SkelEval as t_om / t_m / t_s
t_om, t_m, t_s = 0.9, 0.5, 0.8

  from ._conv import register_converters as _register_converters


In [2]:
# Identify the block: zz is the block's z index (block offset plus the
# affinity offset) and is embedded in the block name.
block_id = 0
zz = affinity_offset + block_size * block_id
block_name = "waterz%.2f-48nm-crop2gt-%04d" % (wz_thresh, zz)

# Segmentation stage names: filtered/downsampled, then relabeled.
stage_pre_relabel = "filt-dsmpl"
stage_post_relabel = "%s-relab" % stage_pre_relabel

# Relabeling preprocessing (presumably already run; uncomment to regenerate
# the relabeled stage on disk):
# block = SegPrep(block_name, resolution)
# block.read_internal(stage=stage_pre_relabel)
# block.relabel(use_bboxes=True)
# block.write(stage=stage_post_relabel)

In [3]:
# Skeleton generation: writes skeletons to ./skeletons/<block_name>/.
# Skeletons of different stages are not saved separately yet, and
# overwrite_prev=True replaces any previously generated output.
# gen_skeletons(block_name, resolution, stage=stage_post_relabel, dsmpl_res=output_resolution, overwrite_prev=True)

In [4]:
# Load this block's downsampled skeletons, including their edges.
skeletons = ReadSkeletons(
    block_name,
    downsample_resolution=output_resolution,
    read_edges=True,
)

In [15]:
# visualize known split skeletons
#relmap = np.load('./segs/' + block_name + '/relabeling-map.npy')
#pre_ids = [2461, 30296]
#plot_ids = [np.argwhere(relmap==pre_id)[0][0] for pre_id in pre_ids]
#print plot_ids
plot_ids = [21607, 4544, 4567]

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
%matplotlib qt

fig = plt.figure(figsize=(16,12))
ax = Axes3D(fig)
for pi in plot_ids:
    print "%d has length: %f"%(pi, skeletons[pi].length())
    nodes = skeletons[pi].get_nodes()
    edges = skeletons[pi].get_edges()
    eps, epvecs = skeletons[pi].get_endpoints()
    print "%d endpoints:"%(pi), eps
    print "%d endpoint vectors:"%(pi), epvecs
    ax.scatter(nodes[:,2],nodes[:,1],nodes[:,0], s=10, c='r')
    ax.set_xlim3d(0,skeletons[pi].grid_size[2])
    ax.set_ylim3d(0,skeletons[pi].grid_size[1])
    ax.set_zlim3d(0,skeletons[pi].grid_size[0])
    for i in range(edges.shape[0]):
        ln_x = [edges[i][0][2], edges[i][1][2]]
        ln_y = [edges[i][0][1], edges[i][1][1]]
        ln_z = [edges[i][0][0], edges[i][1][0]]
        plt.plot(ln_x, ln_y, ln_z, 'b-')
#         ax.quiver(eps[:,2],eps[:,1],eps[:,0],
#                   epvecs[:,2],epvecs[:,1],epvecs[:,0],length=10,normalize=True)
    plt.show()

21607 has length: 1141.003138
21607 endpoints: [[ 28 279 321]
 [ 58 276 324]]
21607 endpoint vectors: [[-0.94868332  0.         -0.31622776]
 [ 0.94868332 -0.31622776  0.        ]]
4544 has length: 819.800635
4544 endpoints: [[  7 283 316]
 [ 10 278 315]
 [ 18 281 318]]
4544 endpoint vectors: [[-0.89442718  0.44721359  0.        ]
 [-0.57735026 -0.57735026 -0.57735026]
 [ 0.70710677  0.          0.70710677]]
4567 has length: 2119.512225
4567 endpoints: [[ 52 281 321]
 [  7 291 319]
 [ 23 281 319]]
4567 endpoint vectors: [[ 0.94868332 -0.31622776  0.        ]
 [-0.94868332  0.31622776  0.        ]
 [-0.31622776 -0.94868332  0.        ]]


In [16]:
def skel_detector(skeletons, thresholds):
    """
    Select candidate split skeletons to inspect for merging.

    Only the first (filtering) stage is implemented: a skeleton is kept if it
    has exactly two endpoints and its length lies strictly between the
    "min-len" and "max-len" thresholds. Pairing candidates into merges
    (e.g. by tracking them along the z-axis) is not implemented yet.

    Parameters
    ----------
    skeletons : sequence of skeleton objects
        Each element must expose ``label``, ``length()`` and ``endpoints``;
        ``len(skeletons)`` is used for the summary message.
    thresholds : dict
        Must contain "min-len" and "max-len" (same units as ``length()``).

    Returns
    -------
    list
        Labels of the candidate skeletons, in input order.
    """
    import time  # explicit: do not rely on a star import to provide `time`

    start_time = time.time()
    min_len = thresholds["min-len"]
    max_len = thresholds["max-len"]

    # First stage: the endpoint count is checked first because it is a plain
    # attribute read, while length() is a method call.
    sk_inspect = [sk.label for sk in skeletons
                  if len(sk.endpoints) == 2 and min_len < sk.length() < max_len]
    print("Found %d skeletons to inspect out of %d in volume"
          % (len(sk_inspect), len(skeletons)))

    # TODO: second stage -- compare endpoints / endpoint vectors between
    # candidate pairs to find split partners, skipping spine endpoints whose
    # endpoint vector is all zeros.

    print("Split skeleton detection complete in %f s" % (time.time() - start_time))
    return sk_inspect

In [28]:
# Length window for candidate split fragments, expressed relative to the
# block's z extent (number of slices times z resolution).
block_extent = resolution[0] * block_size
sk_threshes = dict()
sk_threshes["min-len"] = 0.2 * block_extent
sk_threshes["max-len"] = 1.0 * block_extent
detected_sks = skel_detector(skeletons, sk_threshes)

Found 3773 skeletons to inspect out of 43300 in volume
Split skeleton detection complete in 3.366117 s


In [29]:
# check for known splits
gt_splits = [804, 388, 81, 861, 363, 84, 89, 91, 658] # feed in
pred_splits = [sk_eval.sk_eval.gt2pred[gid].ids_out for gid in gt_splits] 
preds_caught = [([pid in detected_sks for pid in ps]) for ps in pred_splits]
print len_caught

[[False, False], [True, True, False, False], [True, False, False], [False, False, False, False, False], [True, True, False, False], [False, False], [False, False, False, False], [True, False], [True, False, False]]


In [8]:
# benchmark against results from GT skeleton error analysis
gt_block_name = "gt-48nm-%dslices-%04d"%(block_size, zz)
sk_eval = SkelEval(gt_block_name, block_name, stage=stage_post_relabel, 
                   t_om=t_om, t_m=t_m, t_s=t_s, 
                   include_zero_split=False, include_zero_merge=True,
                   overwrite_prev=True)
sk_eval.pr_analysis(detected_sks, "split")

Starting error analysis of waterz0.50-48nm-crop2gt-0014 against skeletons of gt-48nm-60slices-0014
Starting evaluation of 43300 labels in 60x540x489 predicted segmentation against 1003 GT skeletons
Using error thresholds: t_om=0.90, t_m=0.50, t_s=0.80
Skeleton evaluation time: 3.78245401382
Results:
3 omissions, 10 merges, 160 splits, 4 hybrid, 826 correct
GT ERL: 2005, Prediction ERL: 1651
GT TRL: 1742560, Prediction TRL: 1572320
Omitted RL: 3839, Merged RL: 25897, Split RL: 140503
Benchmarking against 557 error IDs and 826 correct IDs
True positives: 176
False positives: 45
True negatives: 781
False negatives: 381
Precision: 0.796380
Recall: 0.315978
False pos: [4994, 4843, 3973, 10247, 4233, 2314, 3214, 2833, 3348, 3096, 3935, 2213, 2215, 4915, 2356, 2742, 4281, 3039, 4797, 4036, 4933, 3145, 3786, 2507, 4172, 3022, 2640, 12280, 3416, 2777, 2652, 2397, 2655, 2256, 4200, 4331, 2156, 3566, 26277, 3953, 2674, 2804, 3576, 3453, 2687]
False neg: [10241, 4100, 36872, 2062, 36879, 2065, 411

(0.7963800904977375, 0.31597845601436264)

[2474, 33712]