This notebook walks through the procedure for linking object IDs across two blocks using IoU (intersection-over-union) tracking of objects at the interface slices, i.e. the last slice of the source block and the first slice of the target block. It assumes that the source and target blocks with the chosen names have already been generated using the SegPrep module.

In [1]:
from cerebellum.data_prep.seg_prep import *

import json

# --- parameters used throughout the notebook ---
resolution = (30, 48, 48)      # handed to SegPrep; presumably the voxel size -- TODO confirm units/axis order
block_size = 60                # multiplied by a block id to locate that block
affinity_offset = 14           # constant shift added to every block position
wz_thresh = 0.5                # waterz threshold that is encoded in the block names
sblock_id, tblock_id = 0, 1    # source block id, target block id

  from ._conv import register_converters as _register_converters


In [2]:
# Load the source block: segmentation volume and per-object bounding boxes.
zz_sb = affinity_offset + sblock_id*block_size
sblock_name = "waterz%.2f-48nm-crop2gt-%04d" % (wz_thresh, zz_sb)  # source block name
sblock = SegPrep(sblock_name, resolution)
sblock.read_internal(stage="filtered")
# Warning! If objects are relabeled, set stage argument appropriately
sblock.read_bboxes()
# Keep short handles on the two pieces the later cells work with.
sblock_seg, sbbox_dict = sblock.data, sblock.bbox_dict

In [3]:
# Load the target block the same way as the source block.
zz_tb = affinity_offset + tblock_id*block_size
tblock_name = "waterz%.2f-48nm-crop2gt-%04d" % (wz_thresh, zz_tb)  # target block name
tblock = SegPrep(tblock_name, resolution)
tblock.read_internal(stage="filtered")
# Warning! If objects are relabeled, set stage argument appropriately
tblock.read_bboxes()
# Keep short handles on the two pieces the later cells work with.
tblock_seg, tbbox_dict = tblock.data, tblock.bbox_dict

In [4]:
from cerebellum.error_analysis.voxel_methods import iou_rank

# Interface slices, copied out with a leading singleton axis:
# the last slice of the source block and the first slice of the target block.
sslice = np.array(sblock_seg[-1, :, :])[np.newaxis]
tslice = np.array(tblock_seg[0, :, :])[np.newaxis]

# Ranking every object pair at once is left disabled here.
#iou_results = iou_rank(sslice, tslice)

In [5]:
# test single object tracking
from cerebellum.error_correction.slice_stitch import slice2slice_iou_calc

# Track source object 3293 across the interface and display the IoU of the
# first candidate returned for it.
t_ids, ints, unions = slice2slice_iou_calc(sslice, tslice, 3293)
ints[0] / float(unions[0])

0.9302325581395349

In [8]:
def block_lock(sblock_seg, tblock_seg, iou_thresh=0.5,
               sbbox_dict=None, tbbox_dict=None, search_span=None):
    """
    Lock objects in the target block to the IDs of objects in the source block.

    Every ID present in the last slice of the source block is tracked into the
    first slice of the target block with ``slice2slice_iou_calc``. When the
    first candidate it returns has an IoU strictly above ``iou_thresh``, that
    target object is relabeled with the source ID throughout the target block.

    All matches are decided on copies of the interface slices, and all
    relabeling uses masks taken from a snapshot of the original target labels.
    A freshly relabeled object therefore cannot be picked up again by a later
    match that targets the same numeric ID (source and target IDs are allowed
    to overlap).

    Currently only supports the sbbox_dict = None option.

    Args:
        sblock_seg: source label volume; only its last slice along axis 0 is used.
        tblock_seg: target label volume (numpy array). It is relabeled IN PLACE.
        iou_thresh: a match is accepted only when its IoU is > iou_thresh.
        sbbox_dict: optional mapping source ID -> bbox. When given, the IoU is
            computed on a window around the object's bbox instead of on the
            whole slice. Requires ``search_span``.
        tbbox_dict: optional mapping target ID -> bbox. When given, relabeling
            of a target object is restricted to its bbox.
        search_span: multiple of the bbox extent added on each side of the
            source bbox to form the search window. Ignored when
            ``sbbox_dict`` is None.

        Bboxes are indexed as (min0, min1, min2, max0, max1, max2), with the
        max entries used as exclusive slice ends.

    Returns:
        ``tblock_seg`` itself (the same array object, modified in place).

    Raises:
        AssertionError: if the blocks differ in size along axis 1 or 2.
        ValueError: if ``sbbox_dict`` is given without ``search_span``.

    Caveats:
        * A source ID written into the target block is not checked against the
          IDs of unmatched target objects; if they coincide, the two objects
          end up sharing a label.
        * NOTE(review): every ID in the source slice is tracked, including a
          background label if there is one -- confirm how
          ``slice2slice_iou_calc`` treats it.
        * A full copy of ``tblock_seg`` is held while relabeling.
    """
    start_time = time.time()
    for axis in (1, 2):
        assert sblock_seg.shape[axis] == tblock_seg.shape[axis]
    if sbbox_dict is not None and search_span is None:
        # Fail fast instead of hitting "None * int" in the middle of the loop.
        raise ValueError("search_span must be given when sbbox_dict is used")

    # Copies, so the matching below never sees partially relabeled data.
    sslice = np.array([sblock_seg[-1, :, :]])
    tslice = np.array([tblock_seg[0, :, :]])

    # Phase 1: decide which target ID gets locked to which source ID.
    locked = {}  # original target ID -> source ID
    for s_obj in np.unique(sslice):
        if sbbox_dict is None:
            int_ids, ints, unions = slice2slice_iou_calc(sslice, tslice, s_obj)
        else:
            sbbox = sbbox_dict[s_obj]
            pad1 = search_span*(sbbox[4]-sbbox[1])
            pad2 = search_span*(sbbox[5]-sbbox[2])
            # Same window for both slices (their in-plane shapes are equal).
            window = (slice(None),
                      slice(max(0, sbbox[1]-pad1),
                            min(sslice.shape[1], sbbox[4]+pad1)),
                      slice(max(0, sbbox[2]-pad2),
                            min(sslice.shape[2], sbbox[5]+pad2)))
            int_ids, ints, unions = slice2slice_iou_calc(sslice[window],
                                                         tslice[window], s_obj)
        if len(int_ids) == 0 or unions[0] == 0:
            # No candidate in the target slice, or a degenerate union.
            continue
        t_obj = int_ids[0]
        iou = float(ints[0])/unions[0]
        if iou <= iou_thresh:
            continue
        if t_obj == s_obj or t_obj in locked:
            # Already carries the right ID, or was claimed by an earlier
            # source object (the first claim wins).
            continue
        locked[t_obj] = s_obj

    # Phase 2: relabel against a snapshot of the original target labels.
    if locked:
        snapshot = np.array(tblock_seg)
        for t_obj, s_obj in locked.items():
            if tbbox_dict is None:
                tblock_seg[snapshot == t_obj] = s_obj
            else:
                tbbox = tbbox_dict[t_obj]
                region = (slice(tbbox[0], tbbox[3]),
                          slice(tbbox[1], tbbox[4]),
                          slice(tbbox[2], tbbox[5]))
                # Basic slicing gives a view, so this writes into tblock_seg.
                tblock_seg[region][snapshot[region] == t_obj] = s_obj

    # Single-argument call form: valid in both Python 2 and Python 3.
    print("Runtime: %f" % (time.time()-start_time))
    return tblock_seg

In [9]:
# Lock the target block to the source block.
# The returned array is tblock_seg itself, relabeled in place.
# search_span is not used by block_lock while sbbox_dict is None.
tblock_locked = block_lock(
    sblock_seg,
    tblock_seg,
    iou_thresh=0.5,
    sbbox_dict=None,
    tbbox_dict=tbbox_dict,
    search_span=2,
)

Runtime: 63.386322


In [10]:
# Store the relabeled volume on the target block and write it out as its own stage.
# NOTE(review): the "0" in the stage name presumably refers to the source
# block id -- consider deriving it from sblock_id; confirm before changing.
locked_stage = "locked-to-0"
tblock.data = tblock_locked
tblock.write(stage=locked_stage)

In [11]:
# Display the shape attribute of the target SegPrep block as a quick sanity check.
tblock.shape

[60, 540, 489]