This notebook tests error detection through checking number of connected components from slie to slice

## Error detection

In [1]:
from cerebellum.utils.data_io import *
import json

with open('data_locs.json') as f:
	data_locs = json.load(f)
block_index = 0 # TO CHANGE: which block you are processing
zz = data_locs["block-size"]*block_index+14

# load GT and initial segmentation
gt_file = data_locs["trials"]["dir"] + data_locs["trials"]["gt48nm-cropped-relabeled"]
pred_file = data_locs["trials"]["dir"] + data_locs["trials"]["pf48nm-cropped-relabeled"]
if zz!=14: # adjust block index
    gt_file = gt_file[:-7]+"%04d.h5"%(zz)
    pred_file = pred_file[:-7]+"%04d.h5"%(zz)
gt = read3d_h5(gt_file, 'main')
pred = read3d_h5(pred_file, 'main')
print gt.shape, pred.shape

  from ._conv import register_converters as _register_converters


(90, 540, 488) (90, 540, 488)


In [2]:
n_objs = np.max(pred)+1
print n_objs

3660


In [3]:
# test for single object
import time

from skimage.morphology import label
from skimage.measure import regionprops

start_time = time.time()
obj_id = 1073
obj_mask = np.zeros_like(pred)
obj_mask[pred==obj_id] = 1

bbox = regionprops(obj_mask)[0].bbox
obj_mask = obj_mask[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
print obj_mask.shape

slice_id = 62
slice_cc = label(obj_mask[slice_id,:,:], connectivity=1)
n_cc = np.max(slice_cc)
print n_cc
print time.time()-start_time

(90, 10, 14)
2
0.166497945786


In [4]:
from skimage.morphology import label
from skimage.measure import regionprops
import time

def cc_measure(seg, connectivity=2, write_path=None):
    """
    Gets # connected components across z-slices for all objects in segmentation
    TODO: make more efficient with bounding boxes
    """
    start_time = time.time()
    n_objs = np.max(seg)+1
    n_slices = seg.shape[0]
    n_cc = np.zeros((n_objs, n_slices))
    print "checking CC rule for %d objects"%(n_objs)
    # go over each object, skip 0
    for obj_id in range(1,n_objs): 
        print "on obj %d"%obj_id
        obj_mask = np.zeros_like(seg)
        obj_mask[seg==obj_id] = 1
        if np.sum(obj_mask)==0: # ignore non-existent objects
            continue
        regions = regionprops(obj_mask)
        bbox = regions[0].bbox
        obj_mask = obj_mask[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
        # go over each slice
        check_slices = obj_mask.shape[0] # n_slices
        for slice_id in range(check_slices):
            n_cc[obj_id, slice_id] = np.max(label(obj_mask[slice_id,:,:], connectivity=connectivity))
    print time.time() - start_time
    if write_path is not None:
        np.save(write_path, n_cc)
    return n_cc

In [5]:
def cc_error_detector(n_cc):
    """
    Checks if CC # is conserved across all slices, returns obj_ids that are in violation
    
    n_cc (n_objs x n_slices ndarray): 
    """
    n_objs = n_cc.shape[0]
    ## check if CC # is conserved across all slices
    # flag_cc = [obj_id for obj_id in range(n_objs) if not np.all(n_cc[obj_id,:]==n_cc[obj_id,0])]
    
    # check if CC # is 1 across all slices
    flag_cc = [obj_id for obj_id in range(n_objs) if not np.all(n_cc[obj_id,:]==1)]
    return flag_cc

In [6]:
n_cc = cc_measure(pred, connectivity=2, write_path='./err-correction/pred-all/ncc')

checking CC rule for 3660 objects
on obj 1
on obj 2
on obj 3
on obj 4
on obj 5
on obj 6
on obj 7
on obj 8
on obj 9
on obj 10
on obj 11
on obj 12
on obj 13
on obj 14
on obj 15
on obj 16
on obj 17
on obj 18
on obj 19
on obj 20
on obj 21
on obj 22
on obj 23
on obj 24
on obj 25
on obj 26
on obj 27
on obj 28
on obj 29
on obj 30
on obj 31
on obj 32
on obj 33
on obj 34
on obj 35
on obj 36
on obj 37
on obj 38
on obj 39
on obj 40
on obj 41
on obj 42
on obj 43
on obj 44
on obj 45
on obj 46
on obj 47
on obj 48
on obj 49
on obj 50
on obj 51
on obj 52
on obj 53
on obj 54
on obj 55
on obj 56
on obj 57
on obj 58
on obj 59
on obj 60
on obj 61
on obj 62
on obj 63
on obj 64
on obj 65
on obj 66
on obj 67
on obj 68
on obj 69
on obj 70


KeyboardInterrupt: 

In [None]:
detected_ids = cc_error_detector(n_cc)
passed_ids = list(set(range(n_objs)).difference(set(detected_ids)))
print len(detected_ids)
print len(passed_ids)

## Connected component properties of detected objects

In [None]:
# how many objects are merged in each detected object
n_merge = [n_cc[i,:].max() for i in detected_ids]

import matplotlib.pyplot as plt

plt.hist(x=n_merge, bins=np.arange(9)-0.5, rwidth=0.5)
plt.xlabel('Max CC #')
plt.ylabel('Number of objects')
plt.show()

So most of the detected objects have max CC # of 2

In [None]:
def cc_binning(n_cc):
    """
    Sorts objects into bins based on max CC # over all slices
    """
    n_objs = n_cc.shape[0]
    max_cc = np.max(n_cc) # max CC's per slice across all objects
    cc_binned_objs = [[] for _ in range(max_cc+1)]
    for i in range(n_objs):
        cc_binned_objs[np.max(n_cc[i,:])].append(i)
    return cc_binned_objs

In [None]:
n_cc = n_cc.astype(np.int)
objs_binned = cc_binning(n_cc)

In [None]:
# objects whose max CC # is 1, but do not extend through all slices
chopped_objs = list(set(objs_binned[1]).difference(passed_ids))
print len(chopped_objs)
print chopped_objs

The above objects fall in the `1` bin of the histogram above. They are mostly at the boundaries of the imaging volume. As a result, they do not extend through all slices. They can safely be ignored for error correction.

### 2-object merges

In [None]:
# what fraction of slices contain 2 CC's?
split_frac = [n_cc[i,:].tolist().count(2) for i in objs_binned[2]]

plt.hist(x=split_frac, bins=pred.shape[0]/10)
plt.xlabel('# slices with 2 CCs')
plt.ylabel('Number of objects')
plt.show()

So most of the objects have <10 slices out of the total 90 in the block with 2 connected components. The rest of the slices are either merges or single object slices

In [None]:
print objs_binned[2]

## Validate against GT skeletons

In [None]:
from cerebellum.error_analysis.skel_segeval import SkelEval

i = 0 # BLOCK INDEX - to change or iterate on
zz = i*data_locs["block-size"]+data_locs["aff-offset"]
gt_name = "gt%04d"%(zz)
pred_name = "pred-pf-crop2gt-%04d"%(zz)
skel_eval = SkelEval(gt_name, pred_name)
merge_ids = skel_eval.get_merges(look_in="pred")
corr_ids = skel_eval.get_corrects()
print "No. of falsely merged pred objects (from error analysis): %d"%(len(merge_ids))

In [None]:
def pr_analysis(detected_ids, merge_ids, corr_ids, write_path=None):
    true_pos = list(set(merge_ids)&(set(detected_ids)))
    print "True positives: %d"%(len(true_pos))
    false_pos = list(set(corr_ids)&(set(detected_ids)))
    print "False positives: %d"%(len(false_pos))
    true_neg = list(set(corr_ids).difference(set(detected_ids)))
    print "True negatives: %d"%(len(true_neg))
    false_neg = list(set(merge_ids).difference(set(detected_ids)))
    print "False negatives: %d"%(len(false_neg))
    precision = len(true_pos)/(1.*len(true_pos)+len(false_pos))
    recall = len(true_pos)/(1.*len(true_pos)+len(false_neg))
    print "Precision: %f"%(precision)
    print "Recall: %f"%(recall)
    print "False pos:", false_pos
    print "False neg:", false_neg
    if write_path is not None:
        np.save(write_path+'false_pos', false_pos)
        np.save(write_path+'false_neg', false_neg)
    return (precision, recall)

In [None]:
pr_analysis(detected_ids, merge_ids, corr_ids)

The above precision estimate is a lower bound, because the false positive list above has many objects that are merges between labeled and unlabeled objects in GT. To get a more reasonable precision estimate, wethrow out those objects that have a thresholded fraction of their voxels in GT 0

In [None]:
from cerebellum.error_analysis.voxel_methods import intersection_list

def filter_detector(gt, pred, detected_ids, thresh=0.4):
    filtered_ids = detected_ids[:]
    for obj_id in detected_ids:
        gt_ids, gt_vols = intersection_list(pred, gt, obj_id) # note reversed order compared to function args
        if 0 in gt_ids.tolist():
            zero_content = float(gt_vols[gt_ids==0][0])/np.sum(gt_vols)
            #print obj_id, zero_content
            if zero_content > thresh:
                filtered_ids.remove(obj_id)
    return filtered_ids

To get a better lower bound for precision, reduce `thresh` until precision increases but recall remains close to that of unfiltered set

In [None]:
detected_ids_val = filter_detector(gt, pred, detected_ids, thresh=0.4)
print len(detected_ids_val)
pr_analysis(detected_ids_val, merge_ids, corr_ids)

Visually, many of these 'false positives' still seem to be labeled-unlabeled merges, so our precision is likely much higher

## Error correction of a single object

In [None]:
n_cc= np.load('./err-correction/pred-all/ncc.npy')

In [None]:
import matplotlib.pyplot as plt

test_id = 2252
n_slices = pred.shape[0]
plt.scatter(np.arange(n_slices), n_cc[test_id,:], c="g", alpha=0.5)
plt.xlabel("slice id")
plt.ylabel("CC #")
plt.show()

In [None]:
prefix = "pred-all" # TO CHANGE
output_resolution=(80,80,80)
from cerebellum.ibex.utilities.dataIO import ReadSkeletons
skeletons = ReadSkeletons(prefix, downsample_resolution=output_resolution, read_edges=True)

In [None]:
# plot skeleton
plot_sk = skeletons[test_id]
plot_sk.save_image('./err-correction/'+prefix+'/')

In [None]:
def skel_flow(point, skeleton):
    """
    returns best non-horizontal flow vector at point based on nearest skeleton edge
    
    vector is normalized such that z-component = 1
    """
    
    def dist_pt2ln(pt, ln):
        """
        finds euclidean distance between point and line
        Args:
            pt (3x, array)
            ln (2x3 array)
        """
        p1 = ln[0,:]
        p2 = ln[1,:]
        return np.linalg.norm(np.cross(p2-p1, p1-pt))/np.linalg.norm(p2-p1)
    
    try:
        assert type(point) is list and len(point)==3
    except:
        print point
    nodes = skeleton.get_nodes()
    dists = np.array([np.linalg.norm(np.array(node-point)) for node in nodes])
    nodes = [nodes[i] for i in np.argsort(dists)]
    #print nodes
    edges = skeleton.get_edges()
    edge_vecs = [edge[0]-edge[1] for edge in edges]
    edge_vecs = [np.divide(edge_vec, float(edge_vec[0])) for edge_vec in edge_vecs] # normalize such that z-comp is 1
    # iterate over nodes till you find a non-horizontal edge at a node closest to point
    found_flow = False
    check_edges = []
    while not found_flow:
        try:
            check_node = nodes.pop(0) # closest node to input point
            #print check_node
        except:
            print "Error! Could not find non-horizontal flow vector at this point"
        for i, (edge, edge_vec) in enumerate(zip(edges, edge_vecs)):
            allow_edge = ((np.all(check_node==edge[0]) or np.all(check_node==edge[1])) 
                             and not np.any(np.isnan(edge_vec)) and not np.any(np.isinf(edge_vec)))
            if allow_edge:
                check_edges.append(i)
        found_flow = (len(check_edges)>0)

    #print edges[check_edges]
    #print check_node
    # find edges for which the point is z-midway
    midway_edges = []
    for i in check_edges:
        z_lower = min(edges[i][0,0], edges[i][1,0])
        z_higher = max(edges[i][0,0], edges[i][1,0])
        if z_lower > point[0] or z_higher < point[0]:
            midway_edges.append(i)
    if len(midway_edges)>0: 
        check_edges = midway_edges
    # find edge closest to point
    check_dists = [dist_pt2ln(np.array(point), edges[i]) for i in check_edges]
    flow_edge_id = check_edges[np.argmin(check_dists)]
    #print edges[flow_edge_id]
    flow_vec = edge_vecs[flow_edge_id]
    return flow_vec

In [None]:
# point = [38,22+bbox[1],7+bbox[2]]
# print point
# skel_flow(point, skeletons[test_id])

In [None]:
# load high resolution segmentation for splitting
pred_file = './segs/'+'pred-pf-8nm-crop2gt-%04d/'%(zz)+'seg.h5'
pred = read3d_h5(pred_file, 'main')
print pred.shape
hres = (30,8,8)

In [None]:
from skimage.measure import regionprops

test_obj_mask = np.zeros_like(pred)
test_obj_mask[pred==test_id] = 1
regions = regionprops(test_obj_mask)
bbox = regions[0].bbox
print bbox
test_obj_mask = test_obj_mask[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
print test_obj_mask.shape

In [None]:
from skimage.morphology import label

from cerebellum.utils.data_io import create_folder
# generate labels for CC's in every slice
check_slices = test_obj_mask.shape[0] # n_slices
labeled_obj_mask = np.zeros_like(test_obj_mask)
fpath = './err-correction/'+prefix+'/labeled-obj-mask-8nm-%d/'%(test_id)
create_folder(fpath)

# save plots of slices
plt.ioff()
for slice_id in range(check_slices):
    labeled_obj_mask[slice_id,:,:] = label(test_obj_mask[slice_id,:,:], connectivity=2)
    fig, ax = plt.subplots(1,1,figsize=(15,15))
    ax.imshow(labeled_obj_mask[slice_id,:,:])
    plt.savefig(fpath+'slice%d.png'%(slice_id))
    plt.close(fig)

Test watershed based splitting of a slice with a false merge

In [None]:
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.morphology import watershed

def split_slice(fill_slice, seed_method="skel-flow",
                footprint=None, 
                pivot_slice=None, pivot_id=None, fill_dir=None, bbox_offset=None, 
                seg_res = None, skeleton=None,
                prop_from="dist-xform", tweak_boundary_seeds=True, min_seed_connectivity=1,
                save_path=None, plot_seeds=False):
    
    import itertools
    
    def snap(pt, image):
        """
        Checks if pt lies outside image extents, snaps to nearest pt inside image grid
        Args:
            pt (Nx, array)
            image (N dim array)
        """
        new_pt = list(pt)
        for i in range(len(image.shape)):
            if pt[i]<0:
                new_pt[i] = 0
            elif pt[i]>=image.shape[i]:
                new_pt[i] = image.shape[i]-1
        return tuple(new_pt)
    
    def on_boundary(pt, image, connectivity=2):
        """
        Checks if pt is on a boundary in a binary image
        Args:
            pt (2x, array)
            image (2 dim array)
            connectivity (int)
        """
        assert len(pt)==2
        assert len(np.unique(image))==2
        assert connectivity==1 or connectivity==2
        # generate neighbors
        if connectivity==2:
            neighbs = [[(pt[0]+d0, pt[1]+d1) for d0 in [-1,0,1]] for d1 in [-1,0,1]]
            neighbs = list(itertools.chain.from_iterable(neighbs))
            neighbs.remove(pt)
        elif connectivity==1:
            neighbs = [(pt[0]+1,pt[1]), (pt[0]-1,pt[1]), (pt[0],pt[1]-1), (pt[0],pt[1]+1)]
        neighbs = [snap(nb, image) for nb in neighbs]
        if any([image[nb]==0 for nb in neighbs]):
            return True
        else:
            return False
        
    def push_inwards(pt, image, target_conn=2):
        """
        Pushes a pt on boundary of image inwards, tries to return best connected neighbor
        Args:
            pt (2x, array)
            image (2 dim array)
        """
        assert target_conn>0
        try:
            assert on_boundary(pt, image, connectivity=target_conn)
        except:
            print "Point already inwards, returning as is"
            return pt
        neighbs = [[(pt[0]+d0, pt[1]+d1) for d0 in [-1,0,1]] for d1 in [-1,0,1]]
        neighbs = list(itertools.chain.from_iterable(neighbs))
        neighbs.remove(pt)
        neighbs = [snap(nb, image) for nb in neighbs]
        one_neighbs = [nb for nb in neighbs if not on_boundary(nb, image, connectivity=1)]
        two_neighbs = [nb for nb in neighbs if not on_boundary(nb, image, connectivity=2)]
        if len(two_neighbs)>0 and target_conn<=2:
            return two_neighbs[0] # arbitraily chosen
        elif len(two_neighbs)==0 and len(one_neighbs)>0 and target_conn==1:
            return one_neighbs[0] # arbitraily chosen
        else:
            print "Warning: Failed to find better connected neighbor. Returning point as is"
            return pt
    
    assert seed_method=="dist-xform" or seed_method=="skel-flow"
    
    distance = ndi.distance_transform_edt(fill_slice) # could also use mean affinity map here
    
    # generate seeds for watershed
    # Method 1: local maxima of distance xform of fill slice
    if seed_method=="dist-xform":
        assert type(footprint) is np.ndarray
        local_maxi = peak_local_max(distance, indices=False, num_peaks=2, footprint=footprint) # TODO: CCs>2
        seeds = ndi.label(local_maxi)[0]
        
    # Method 2: use flow vector from skeleton
    elif seed_method=="skel-flow":
        assert prop_from=="centroids" or prop_from=="dist-xform"
        assert fill_dir==-1 or fill_dir==1
        assert pivot_slice.shape == fill_slice.shape
        assert type(pivot_id) is int
        assert bbox_offset is not None and len(bbox_offset)==2
        assert skeleton is not None
        if tweak_boundary_seeds: assert min_seed_connectivity>0
        
        # Choose points in pivot slice to propagate
        max_ccs = np.max(pivot_slice) # TO CHANGE: extend to more CCs
        pivot_regions = [np.zeros_like(pivot_slice) for _ in range(max_ccs)]
        for i, pr in enumerate(pivot_regions):
            pr[pivot_slice==i+1] = 1
        # Method 1: centroids
        if prop_from=="centroids":
            pivot_centroids = [np.mean(np.argwhere(pr!=0), axis=0) for pr in pivot_regions]
            prop_pts = pivot_centroids
        # Method 2: local maxima of distance transform of each CC
        elif prop_from=="dist-xform":
            pivot_distances = [ndi.distance_transform_edt(pr) for pr in pivot_regions]
            pivot_dx_centers = [peak_local_max(pd, indices=True, num_peaks=1)[0] for pd in pivot_distances]
            prop_pts = pivot_dx_centers
        print "Pivot points for propagation:", prop_pts
        
        # Estimate flow
        assert skeleton.resolution[0]==seg_res[0]
        dsmpl = (seg_res[1]/skeleton.resolution[1], seg_res[2]/skeleton.resolution[2])
        flow_sources = [[pivot_id, (p[0]+bbox_offset[0])/dsmpl[0], (p[1]+bbox_offset[1])/dsmpl[1]] 
                        for p in prop_pts] # repackage for skel_flow()
        #print "Flow sources:", flow_sources
        flow_vectors = [skel_flow(point, skeleton) for point in flow_sources]
        print "Flow vectors:", [f.tolist() for f in flow_vectors]
        
        # Generate seeds
        seed_locs = [(int(p[0]+fill_dir*f[1]), 
                    int(p[1]+fill_dir*f[2])) for (p, f) in zip(prop_pts, flow_vectors)]
        print "Seed locs:", seed_locs
        seeds = np.zeros_like(fill_slice)
        seed_validity = [True for _ in range(max_ccs)]
        for i, seed_loc in enumerate(seed_locs):
            if fill_slice[seed_loc]==0:
                print "Warning: Skel-flow seed generated in empty region of fill slice. Ignoring and proceeding with watershed"
            # TODO: Re-enable tweak_boundary options after fixing on_boundary and push_inwards to operate on images with >1 CC
            #elif on_boundary(seed_loc, fill_slice, connectivity=min_seed_connectivity):
            #    if tweak_boundary_seeds:
            #        print "Warning: Skel-flow seed on boundary of fill slice. Will attempt to push inwards"
            #        seed_loc = push_inwards(seed_loc, fill_slice, target_conn=min_seed_connectivity)
            #    else:
            #        print "Warning: Skel-flow seed on boundary of fill slice. Ignoring and proceeding with watershed"
            seeds[seed_loc] = i+1

    # run watershed
    filled_slice = watershed(-distance, seeds, mask=fill_slice)
    # plot original slice with seeds and watershedded slice for debugging
    if plot_seeds:
        fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(15,15))
        ax1.imshow(fill_slice, alpha=0.5)
        ax1.imshow(seeds, alpha=0.5)
        ax2.imshow(distance)
        ax3.imshow(filled_slice)
        plt.show()
    # plot watershedded slice only
    if save_path is not None:
        fig, ax = plt.subplots(1,1,figsize=(15,15))
        ax.imshow(filled_slice)
        plt.savefig(save_path)
        plt.close(fig)
    
    return filled_slice

In [None]:
pivot_id = 82 # correct slice
f = 1 # flow dir
fill_id = pivot_id + f # ID of slice to work on
pivot_slice = labeled_obj_mask[pivot_id,:,:].copy()
fill_slice = labeled_obj_mask[fill_id,:,:].copy()
filled_slice = split_slice(fill_slice, 
                           seed_method="skel-flow", 
                           pivot_slice=pivot_slice, pivot_id=pivot_id,
                           fill_dir=f, bbox_offset=(bbox[1], bbox[2]), 
                           seg_res=hres, skeleton=skeletons[test_id],
                           prop_from="dist-xform", tweak_boundary_seeds=False,
                           plot_seeds=True)
#filled_slice = split_slice(fill_slice, seed_method="dist-xform", footprint=np.ones((5,5)))

Test iterative watershed based splitting for a neighborhood of incorrect slices starting from a pivotal slice

In [None]:
# generate pivotal slices and fill directions
pivot_slices = []
fill_dirs = []
for i in range(n_slices):
    if n_cc[test_id,i]==2:
        if i>0 and n_cc[test_id, i-1]==1:
            pivot_slices.append(i)
            fill_dirs.append(-1)
        if i<n_slices-1 and n_cc[test_id, i+1]==1:
            pivot_slices.append(i)
            fill_dirs.append(1)
        if len(pivot_slices)>2 and pivot_slices[-1]==pivot_slices[-2]==i:
            del pivot_slices[-1]
            del fill_dirs[-1]
            fill_dirs[-1]=0
print zip(pivot_slices, fill_dirs)
# generate fill neighborhoods
# each neighborhood is of the form [pivotal slice, terminal slice]
fill_nbds = []
for i in range(len(pivot_slices)):
    if fill_dirs[i]==-1 or fill_dirs[i]==0:
        if i==0:
            fill_nbds.append([pivot_slices[i], 0])
        else:
            fill_nbds.append([pivot_slices[i], (pivot_slices[i-1]+pivot_slices[i])/2+1])
    if fill_dirs[i]==1 or fill_dirs[i]==0:
        if i==len(pivot_slices)-1:
            fill_nbds.append([pivot_slices[i], len(pivot_slices)])
        else:
            fill_nbds.append([pivot_slices[i], (pivot_slices[i+1]+pivot_slices[i])/2])
print fill_nbds

In [None]:
def split_object(labeled_obj_mask, fill_nbds, bbox_offset, seg_res, skeleton, save_path):
    fill_nbds_updated = fill_nbds[:]
    for fill_nbd in fill_nbds_updated:
        print "\n\nSplitting slices in the neighborhood:", fill_nbd
        corrected_mask = labeled_obj_mask.copy()
        while fill_nbd[0]!=fill_nbd[1]: # till all slices in neighborhood are filled
            pivot_id = fill_nbd[0]
            pivot_slice = corrected_mask[pivot_id,:,:]
            fill_dir = int(np.sign(fill_nbd[1]-fill_nbd[0]))
            fill_id = fill_nbd[0]+fill_dir
            print "\nTracking %d and splitting %d"%(pivot_id, fill_id)
            fill_slice = corrected_mask[fill_id,:,:]
            plot_seeds = False
            if pivot_id in range(16,21):
                plot_seeds = True
            corrected_mask[fill_id,:,:] =  split_slice(fill_slice, seed_method="skel-flow", 
                                                       pivot_slice=pivot_slice, pivot_id=pivot_id,
                                                       fill_dir=fill_dir, bbox_offset=bbox_offset, 
                                                       seg_res=seg_res, skeleton=skeleton,
                                                       prop_from="dist-xform", 
                                                       tweak_boundary_seeds=False, min_seed_connectivity=1,
                                                       save_path=save_path+'slice%d.png'%(fill_id),
                                                       plot_seeds=plot_seeds)
            fill_nbd[0] += fill_dir # advance pivot slice in filling direction
    return corrected_mask

In [None]:
spath = './err-correction/pred-all/split-obj-mask-8nm-%d/'%(test_id)
create_folder(spath)
# import sys

# orig_stdout = sys.stdout
# f = open(fpath + 'split.log', 'w')
# sys.stdout = f

corrected_mask = split_object(labeled_obj_mask, fill_nbds, (bbox[1], bbox[2]),
                              hres, skeletons[test_id], 
                              spath)

# sys.stdout = orig_stdout
# f.close()