## import libraries

In [None]:
import os
import skimage
import skimage.io
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
import math

## read in the files (2D binary masks)

In [None]:
# Gather the mask PNGs in lexicographic order so slice order is reproducible.
files = sorted(glob.glob('/Users/jje/Desktop/mdxtrainset/masks/*.png'))
files

In [None]:
# Load every mask as a single-channel image and binarise it: pixels whose
# gray value is exactly 178 become 1, everything else becomes 0.
img = []
for item in files:
    print(item)
    image = cv2.imread(item, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising;
    # without this check np.where would silently produce a 0-d "mask".
    if image is None:
        raise IOError(f"could not read mask image: {item}")

    convert = np.where(image == 178, 1, 0)
    img.append(convert)
# Stack the per-file masks into one array indexed by slice first.
img = np.asarray(img)

In [None]:
# Identify individual objects (cells): give every connected foreground
# component of each slice its own integer label (connectivity=2, value 0 is
# background). The labelled slice overwrites the binary one in place.
# Import the submodule explicitly; a bare `import skimage` is not guaranteed
# to expose `skimage.measure` on older scikit-image releases.
import skimage.measure

for i, mask in enumerate(img):
    img[i] = skimage.measure.label(mask, connectivity=2, background=0)

# View a single slice as a sanity check. Slice 27 is shown when it exists;
# the index is clamped so shorter stacks do not raise IndexError.
plt.imshow(img[min(27, len(img) - 1)])

## define functions

In [None]:
def bb_intersection_over_union(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is a 4-sequence ``(min0, min1, max0, max1)``. Both maxima are
    treated as inclusive coordinates, hence the ``+ 1`` in every extent.

    Returns a float: 0.0 when the boxes do not overlap and 1.0 when they
    coincide.
    """
    def _area(box):
        # Extent along each axis, counting both end coordinates.
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)

    # Corners of the overlap rectangle.
    low0, low1 = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    high0, high1 = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

    # Clamp each extent at zero so disjoint boxes give an empty overlap.
    overlap = max(0, high0 - low0 + 1) * max(0, high1 - low1 + 1)

    # Union = sum of both areas minus the part counted twice.
    union = _area(boxA) + _area(boxB) - overlap
    return overlap / float(union)
def dist(p1, p2):
    """Return the Euclidean distance between two points.

    Only the first two coordinates of ``p1`` and ``p2`` are used.
    """
    delta0 = p1[0] - p2[0]
    delta1 = p1[1] - p2[1]
    return math.sqrt((delta0 ** 2) + (delta1 ** 2))
def RelabelZ_with_fixes(previousImage, currentImage, threshold):
    """Carry labels from the previous slice over to matching objects in the current one.

    Each object in ``currentImage`` is compared with the object in
    ``previousImage`` whose centroid is nearest. If the centroids are closer
    than ``threshold`` and the bounding boxes have an IoU above 0.5, the
    current object takes over the previous object's label. Otherwise it keeps
    a fresh label that is larger than every label in ``previousImage``.

    Parameters
    ----------
    previousImage : integer label image of the already-processed slice.
    currentImage : integer label image of the slice to relabel (not modified).
    threshold : maximum centroid distance, in pixels, for two objects to match.

    Returns
    -------
    A relabelled copy of ``currentImage``.
    """
    # Imported here so the function does not depend on names that the
    # notebook's import cell does not provide.
    from scipy import spatial
    from skimage import measure
    from skimage.segmentation import relabel_sequential

    # Shift the current labels above the previous maximum so the two label
    # sets cannot collide before matching.
    currentImage = relabel_sequential(
        currentImage, offset=previousImage.max() + 1
    )[0]
    relabelimage = currentImage.copy()

    previous_regions = measure.regionprops(previousImage)
    current_regions = measure.regionprops(currentImage)
    if not previous_regions or not current_regions:
        # Nothing to match against (or nothing to match): keep fresh labels.
        return relabelimage

    # Nearest previous centroid for every current centroid, in one query.
    tree = spatial.cKDTree([region.centroid for region in previous_regions])
    distances, nearest = tree.query(
        [region.centroid for region in current_regions]
    )

    for region, distance, prev_idx in zip(current_regions, distances, nearest):
        match = previous_regions[prev_idx]
        iou = bb_intersection_over_union(region.bbox, match.bbox)
        if distance < threshold and iou > 0.5:
            # Take the label from the region itself: the pixel under a
            # centroid may lie outside a concave object.
            relabelimage[currentImage == region.label] = match.label
    return relabelimage
def merge_labels_across_volume(labelvol, relabelfunc, threshold=20):
    """Propagate labels slice by slice through a 3D stack of label images.

    The first slice is copied unchanged. Every following slice is relabelled
    by ``relabelfunc(previous_result, current_slice, threshold=threshold)``,
    where ``previous_result`` is the already-relabelled slice before it.

    Parameters
    ----------
    labelvol : 3D array of per-slice label images, slice index first.
        It is never modified.
    relabelfunc : callable returning the relabelled current slice.
    threshold : forwarded to ``relabelfunc`` as a keyword argument.

    Returns
    -------
    A new array shaped like ``labelvol`` holding the relabelled slices.
    """
    nz, _, _ = labelvol.shape
    res = np.zeros_like(labelvol)
    if nz == 0:
        # Empty stack: there is no first slice to seed the result with.
        return res

    # The progress bar is cosmetic; fall back to plain iteration without tqdm.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    res[0, ...] = labelvol[0, ...]
    for i in progress(range(nz - 1)):
        # Some relabel functions modify their input in place, so hand over a
        # copy of just this slice instead of copying the whole volume.
        current = labelvol[i + 1, ...].copy()
        res[i + 1] = relabelfunc(res[i, ...], current, threshold=threshold)
    return res

## process the images

In [None]:
# Voxel spacing as reported by the microscope (in µm), one entry per axis.
original_spacing = np.array([1, 1, 1])

# Stretch the first axis by a factor of 6 relative to the other two.
# NOTE(review): the factor 6 is taken from the code as-is; confirm it matches
# the actual slice spacing / downsampling of the acquisition.
rescaled_spacing = original_spacing * np.array([6, 1, 1])

# Express the spacing in units of the last axis, so in-plane pixels are 1 apart.
spacing = rescaled_spacing / rescaled_spacing[-1]

# Keep a second name for the labelled stack, then link labels across slices.
seg = img
relabelled_fixes = merge_labels_across_volume(
    labelvol=img, relabelfunc=RelabelZ_with_fixes
)

## view the result in napari

In [None]:
# Make sure napari is in scope for this cell.
import napari

# Give both layers the same voxel scale; otherwise the labels layer is
# stretched along the first axis while the image layer is not, and the two
# no longer line up.
viewer = napari.view_image(img, scale=spacing)
viewer.add_labels(relabelled_fixes, name='re', scale=spacing)