## Interface slice error detection

Goal: Detect merge errors by identifying objects with poor IoU (intersection-over-union) scores across the interface slices of adjacent blocks.

In [1]:
from cerebellum.utils.data_io import *
import json
import numpy as np

# Read the table of dataset locations used throughout this notebook.
with open('data_locs.json') as f:
    data_locs = json.load(f)
block_indices = np.array([0, 1]) # TO CHANGE: which blocks are you checking
# z value of each requested block: block index * block size, shifted by 14
zz = data_locs["block-size"]*block_indices+14

# load blocks of interest
# The configured file is used as-is for z == 14; for any other z its last
# 7 characters are swapped for a zero-padded "<z>.h5" suffix.
base_pred_file = data_locs["trials"]["dir"] + data_locs["trials"]["pf48nm-cropped-relabeled"]
pred_blocks = []
for z in zz.tolist():
    if z == 14:
        pred_file = base_pred_file
    else: # adjust block index
        pred_file = base_pred_file[:-7]+"%04d.h5"%(z)
    pred_blocks.append(read3d_h5(pred_file, 'main'))
    print(pred_blocks[-1].shape)

  from ._conv import register_converters as _register_converters


(90, 540, 488)
(90, 540, 488)


In [2]:
# Interface between the two loaded blocks: the slice at index block-size - 1
# of block 0 and the slice at index 0 of block 1. Each is wrapped in a list so
# the resulting array keeps a leading axis of length 1.
last_z = data_locs["block-size"]-1
last_slice = np.array([pred_blocks[0][last_z,:,:]])
first_slice = np.array([pred_blocks[1][0,:,:]])

# shapes first, then the largest label value found in each slice
print("%s %s" % (last_slice.shape, first_slice.shape))
print("%s %s" % (np.max(last_slice), np.max(first_slice)))

(1, 540, 488) (1, 540, 488)
3659 3743


In [3]:
from cerebellum.error_analysis.voxel_methods import *

from functools import reduce
import time

# IOU CALCULATION (inter-block): last slice of block 0 vs. first slice of block 1
# NOTE(review): the saved output of this cell is a RuntimeError ("Labels in
# source slice are not ordered. Relabel and try again"), so the loaded blocks
# presumably need relabeling before this call can succeed -- confirm.
ints, unions, ious, orders, runtime = slice_iou(last_slice, first_slice)
# keep everything except the timing so the detector can reuse it
inter_iou_results = (ints, unions, ious, orders)
print(runtime)

RuntimeError: Labels in source slice are not ordered. Relabel and try again

In [None]:
# ERROR DET
# find all IDs with IoU below a threshold and object in last_slice bigger than object in first_slice
def slice_iou_error_detector(last_slice, first_slice, iou_thresh, iou_results=None, check_order=False):
    """
    Flag the IDs whose IoU between last_slice and first_slice is poor.

    An ID is flagged when its IoU is strictly positive and strictly below
    iou_thresh. IDs are the positions in the `ious` array at which both
    conditions hold.

    Args:
        last_slice, first_slice: the two slices to compare; only used (passed
            to slice_iou) when iou_results is not supplied.
        iou_thresh: IoU value below which an ID is reported.
        iou_results: optional precomputed (ints, unions, ious, orders) tuple;
            when given, slice_iou is not called.
        check_order: when True, additionally require orders == True at that
            position (i.e. only report errors in which the last_slice object
            is the bigger one).

    Returns:
        Sorted list of flagged IDs.
    """
    if iou_results is None:
        ints, unions, ious, orders, _ = slice_iou(last_slice, first_slice)
    else:
        (ints, unions, ious, orders) = iou_results

    # one index array per criterion; an ID must appear in all of them
    criteria = [np.argwhere(ious<iou_thresh), np.argwhere(ious>0)]
    if check_order:
        criteria.append(np.argwhere(orders==True))

    detected_ids = criteria[0]
    for candidate_ids in criteria[1:]:
        detected_ids = np.intersect1d(detected_ids, candidate_ids)
    return detected_ids.tolist()

In [None]:
# IoU cutoff for the inter-block (interface) check; the cached IoU results are
# reused and only errors with the bigger object in last_slice are kept.
inter_thresh = 0.8
inter_errors = slice_iou_error_detector(
    last_slice, first_slice, inter_thresh,
    iou_results=inter_iou_results, check_order=True)
print(len(inter_errors))

## Intra-block error detection

In [None]:
# IOU CALCULATION
def slice2slice_iou(seg):
    """
    Compute IoU results between every pair of consecutive slices of seg.

    Args:
        seg: array indexed as seg[slice, :, :]; slices are taken along axis 0.

    Returns:
        List with one (ints, unions, ious, orders) tuple per consecutive
        slice pair, i.e. seg.shape[0] - 1 entries (empty when seg has fewer
        than two slices). The accumulated runtime reported by slice_iou is
        printed, not returned.
    """
    n_slices = seg.shape[0]
    print("Starting slice-to-slice IoU calculation for %d slices"%(n_slices))
    runtime = 0
    iou_results = []
    for i in range(n_slices-1):
        # each slice keeps a leading axis of length 1, as slice_iou is called
        # with such arrays elsewhere in this notebook
        from_slice = np.array([seg[i,:,:]])
        to_slice = np.array([seg[i+1,:,:]])
        ints, unions, ious, orders, ti = slice_iou(from_slice, to_slice)
        iou_results.append((ints, unions, ious, orders))
        runtime += ti
    print("Runtime: %f"%(runtime))
    return iou_results

In [None]:
# Intra-block analysis is run on the first loaded block only.
first_block = pred_blocks[0]
seg = first_block[:,:,:]
intra_iou_results = slice2slice_iou(seg)

In [None]:
# ERROR DET
def intra_block_error_detector(seg, iou_thresh, iou_results=None):
    """
    Run the slice IoU error detector on every pair of consecutive slices.

    Args:
        seg: array indexed as seg[slice, :, :].
        iou_thresh: IoU cutoff forwarded to slice_iou_error_detector.
        iou_results: optional per-pair IoU results (as returned by
            slice2slice_iou); computed from seg when omitted.

    Returns:
        List with one entry per consecutive slice pair; entry i holds the IDs
        flagged between slice i and slice i+1.
    """
    if iou_results is None:
        iou_results = slice2slice_iou(seg)
    n_pairs = seg.shape[0]-1
    return [
        slice_iou_error_detector(np.array([seg[i,:,:]]),
                                 np.array([seg[i+1,:,:]]),
                                 iou_thresh,
                                 iou_results=iou_results[i])
        for i in range(n_pairs)
    ]

In [None]:
# IoU cutoff for the intra-block (slice-to-slice) check; cached IoU results are reused
intra_thresh = 0.5
intra_block_errors = intra_block_error_detector(
    seg, intra_thresh, iou_results=intra_iou_results)

In [None]:
# Flatten the per-slice-pair detections into a single list of unique IDs.
n_slices = seg.shape[0]
all_flagged_ids = [err_id
                   for i in range(n_slices-1)
                   for err_id in intra_block_errors[i]]
pooled_intra_block_errors = list(set(all_flagged_ids))
print(len(pooled_intra_block_errors))

## Validate against GT skeletons

The code blocks below are copied from our notebook on [skeleton-based error detection](file:///skeleton_error_detection.ipynb).

In [None]:
# Name of the skeleton-analysis run to validate against; later cells read
# from the 'skeletons/<prefix>/error-analysis/' and
# 'skeletons/<prefix>/error-detection/' directories built from it.
prefix = "pred-all" # TO CHANGE

In [None]:
# read merge errors identified in GT skeleton analysis
def read_merges(read_path):
    """
    Read the merged IDs recorded in <read_path>/merged.ids.

    The first line of the file is an integer header (number of pairs of GT
    merged skeletons). Of the remaining lines, every other one, starting with
    the first, is parsed as an integer ID; the lines in between are ignored.

    Args:
        read_path: directory containing merged.ids.

    Returns:
        List of unique merged IDs (order not guaranteed).

    Raises:
        IOError/OSError if the file cannot be opened, ValueError if the header
        or a selected line is not an integer.
    """
    # the context manager closes the file even if parsing fails
    with open(read_path + "/merged.ids", "r") as f:
        int(f.readline())  # header: parsed to validate it, value itself is not needed
        id_lines = f.readlines()[::2]
    return list(set(int(mstr) for mstr in id_lines))

# read correct IDs identified in GT skeleton analysis
def read_corrects(read_path):
    """
    Read the correct IDs recorded in <read_path>/correct.ids.

    The first line of the file is an integer header (number of correct
    entries). Every following line is comma-separated and its second field is
    parsed as an integer ID.

    Args:
        read_path: directory containing correct.ids.

    Returns:
        List of IDs in file order (duplicates are kept).

    Raises:
        IOError/OSError if the file cannot be opened, ValueError/IndexError if
        a line does not have an integer second field.
    """
    # the context manager closes the file even if parsing fails
    with open(read_path + "/correct.ids", "r") as f:
        int(f.readline())  # header: parsed to validate it, value itself is not needed
        return [int(cstr.split(',')[1]) for cstr in f.readlines()]

In [None]:
# Load the reference labels (merged / correct IDs) produced by the GT skeleton analysis
err_analysis_path = 'skeletons/'+prefix+'/error-analysis/'

merge_ids = read_merges(err_analysis_path)
corr_ids = read_corrects(err_analysis_path)

import json
summary_file = err_analysis_path + 'error-analysis-summary.json'
with open(summary_file) as f:
    err_summ = json.load(f)
# sanity check: the summary must report as many correct IDs as were just read
n_correct_in_summary = err_summ["results"]["correct"]
assert len(corr_ids) == n_correct_in_summary

print("# merges from GT analysis: %d" % len(merge_ids))
print("# corrects from GT analysis: %d" % len(corr_ids))

In [None]:
def pr_analysis(detected_ids, merge_ids, corr_ids, write_path=None):
    """
    Score detected IDs against the known merged and correct IDs.

    Counts are computed on sets, so duplicate IDs are ignored. Detected IDs
    that appear in neither merge_ids nor corr_ids do not affect any count.

    Args:
        detected_ids: IDs flagged by a detector.
        merge_ids: IDs known to be merge errors (positives).
        corr_ids: IDs known to be correct (negatives).
        write_path: optional path prefix; when given, the false positives and
            false negatives are saved with np.save to write_path+'false_pos'
            and write_path+'false_neg'.

    Returns:
        (precision, recall) as floats. A ratio whose denominator is zero
        (nothing flagged, or no known merges) is reported as 0.0 instead of
        raising ZeroDivisionError.
    """
    detected = set(detected_ids)
    merged = set(merge_ids)
    correct = set(corr_ids)

    # sorted so that printed and saved ID lists are reproducible
    true_pos = sorted(merged & detected)
    false_pos = sorted(correct & detected)
    true_neg = sorted(correct - detected)
    false_neg = sorted(merged - detected)

    print("True positives: %d" % len(true_pos))
    print("False positives: %d" % len(false_pos))
    print("True negatives: %d" % len(true_neg))
    print("False negatives: %d" % len(false_neg))

    n_flagged = len(true_pos) + len(false_pos)
    n_merged = len(true_pos) + len(false_neg)
    # 1.* forces float division on Python 2
    precision = len(true_pos)/(1.*n_flagged) if n_flagged else 0.0
    recall = len(true_pos)/(1.*n_merged) if n_merged else 0.0

    print("Precision: %f" % precision)
    print("Recall: %f" % recall)
    print("False pos: %s" % (false_pos,))
    print("False neg: %s" % (false_neg,))
    if write_path is not None:
        np.save(write_path+'false_pos', false_pos)
        np.save(write_path+'false_neg', false_neg)
    return (precision, recall)

In [None]:
# Union of the intra-block and inter-block detections, scored against the GT labels.
detected_ids = list(set(pooled_intra_block_errors)|set(inter_errors))
print(len(detected_ids))
# pass the list built above (the previous name `detected_ids_to_check` was never defined)
pr_analysis(detected_ids, merge_ids, corr_ids, write_path=None)

In [None]:
# take intersection with IDs detected by junction presence in skeleton
read_path = 'skeletons/'+prefix+'/error-detection/'
# each variable is named after the file it is loaded from
sk_false_pos = np.load(read_path+'false_pos.npy')
sk_true_pos = np.load(read_path+'true_pos.npy')
# everything the skeleton-based detector flagged, right or wrong
sk_detected_ids = list(set(sk_true_pos)|set(sk_false_pos))

# keep only the IDs flagged by both the IoU-based and the skeleton-based detectors
joint_detected_ids = list(set(detected_ids) & set(sk_detected_ids))
pr_analysis(joint_detected_ids, merge_ids, corr_ids, write_path=None)

PR curve sweeping the IoU threshold. I ran this experiment quickly by sweeping only the threshold for the inter-block part; the threshold for the intra-block part was determined manually.

In [None]:
# plot PR curve sweeping error detection threshold
import matplotlib.pyplot as plt

iou_threshs = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# Re-run the inter-block detector at every threshold, reusing the cached IoU
# results and the same order check as the single-threshold run above.
# NOTE(review): this replaces a call to `stitch_error_detector`, which is not
# defined in this notebook; confirm it was the old name of this detector.
stitch_det_results = [
    slice_iou_error_detector(last_slice, first_slice, t,
                             iou_results=inter_iou_results, check_order=True)
    for t in iou_threshs
]

# one (precision, recall) pair per threshold
pr_results = [pr_analysis(result, merge_ids, corr_ids) for result in stitch_det_results]
p = [res[0] for res in pr_results]
r = [res[1] for res in pr_results]
plt.plot(p,r)
plt.xlabel('Precision')
plt.ylabel('Recall')
plt.show()