Example notebook: running GeoSeg to correct cell-segmentation errors produced by 2D-based methods

In [None]:
import sys
!{sys.executable} -m pip install scikit-image
!{sys.executable} -m pip install numpy
!{sys.executable} -m pip install matplotlib
!{sys.executable} -m pip install scipy

In [3]:
import CandidateSearching as CS
import interpolate as IP

import os
import re
import h5py
import numpy as np
from scipy import ndimage
import matplotlib.colors as mcolors
from matplotlib import pyplot as plt
from cellstitch import evaluation as eval

Please enter your dataset's directory name and its file-name pattern below.

If you are using the plant dataset, each image stack is indexed from 0 to 99

For example, Anther_00.npy, ..., Anther_99.npy

In [7]:
# Dataset location and the per-stack file-name pattern (index zero-padded to two digits).
# os.path.join keeps the path portable. On Windows it gives exactly '.\Leaf'.
# A hard-coded r'.\Leaf' would, on Linux/macOS, name a directory literally called ".\Leaf".
directory_name = os.path.join(".", "Leaf")  # Example of directory for the dataset
image_stack = "Leaf_{:02d}.npy"  # Example of file name for each image stack in the dataset

In [None]:
# Scan every image stack with CS.main and collect the suspected 2D-segmentation errors.
# Each collected entry is (stack index, cell A id, cell B id, layer where the missing mask starts).
problem_masks = []

# Number of image stacks to scan (stack indices 0 .. num_dataset - 1)
num_dataset = 20

# `stack_idx` instead of `iter`, so the builtin iter() is not shadowed for later cells.
for stack_idx in range(num_dataset):
    print(stack_idx)  # progress indicator
    file_path = os.path.join(directory_name, image_stack.format(stack_idx))
    # Only the first three fields of each result are used.
    for cell_a, cell_b, missing_layer, *_ in CS.main(file_path):
        problem_masks.append((stack_idx, cell_a, cell_b, missing_layer))

print("********************************************************")
print("List of problematic mask in this dataset:")
for stack_idx, cell_a, cell_b, missing_layer in problem_masks:
    print(f"Image Stack {stack_idx}, Cell_A id: {cell_a}, Cell_B id: {cell_b}, Missing mask layer starts at: {missing_layer}")

The previous step yields a set of suspected cases in which the algorithm detected a 2D-segmentation error.

The following step corrects each of these 2D-segmentation errors by cross-layer interpolation:

In [10]:
# Correct each suspected 2D-segmentation error by cross-layer interpolation.
# Each entry of wrongedMask (an alias of problem_masks) is a tuple:
#   [0] id of the image stack in the dataset
#   [1] label of the upper cell (cell A)
#   [2] label of the lower cell (cell B)
#   [3] layer where the missing mask starts (not used below)
wrongedMask = problem_masks

augmented_dir = directory_name + "_augmented"
os.makedirs(augmented_dir, exist_ok=True)


def _interpolate_missing_mask(mask_A, mask_B, lbl_A, lbl_B):
    """Interpolate a filled 2D mask between cell A's and cell B's contours.

    Only pixels equal to ``lbl_A`` in ``mask_A`` and to ``lbl_B`` in ``mask_B``
    are kept before the contours are extracted. The interpolated coordinates are
    reshaped to (N, 2) points. They are passed to ``IP.connect_boundary`` together
    with the first two dimensions of ``mask_A``, and the result is hole-filled.

    Returns:
        The output of ``ndimage.binary_fill_holes``: a boolean array with the
        shape produced by ``IP.connect_boundary`` (presumably ``mask_A``'s
        height x width; confirm against interpolate.py).
    """
    adapted_mask_A = np.copy(mask_A)
    adapted_mask_B = np.copy(mask_B)
    adapted_mask_A[adapted_mask_A != lbl_A] = 0
    adapted_mask_B[adapted_mask_B != lbl_B] = 0
    contour_A = IP.get_contours(adapted_mask_A)
    contour_B = IP.get_contours(adapted_mask_B)
    recovered = IP.interpolate(IP.mask_to_coord(contour_A), IP.mask_to_coord(contour_B))
    recovered_array_2d = np.reshape(np.array(recovered), (-1, 2))
    boundary = IP.connect_boundary(
        recovered_array_2d, (np.shape(mask_A)[0], np.shape(mask_A)[1]), lbl_A
    )
    return ndimage.binary_fill_holes(boundary)


for info in wrongedMask:
    stack_idx, lbl_A, lbl_B = info[0], info[1], info[2]
    save_path = os.path.join(augmented_dir, "augmented_{:02d}_masks.npy".format(stack_idx))

    # Corrections accumulate per image stack. Once a stack has an augmented copy,
    # keep correcting that copy instead of the original stack. The copy may come
    # from an earlier iteration or from a previous run of this cell.
    if os.path.exists(save_path):
        file_path = save_path
    else:
        file_path = os.path.join(directory_name, image_stack.format(stack_idx))

    array = CS.load_array(file_path)
    cell_dict = CS.extract_cells_info(array)
    try:
        lowest_layer_A = cell_dict[lbl_A].lowest_layer
        mask_A = array[lowest_layer_A]
        mask_B = array[cell_dict[lbl_B].highest_layer]
    except (KeyError, IndexError) as err:
        # A label can vanish after an earlier correction merged it into another cell
        # (see the relabelling below). Presumably cell_dict then raises KeyError.
        # Skip such entries, but report them instead of failing silently.
        print(f"Skipping image stack {stack_idx} (cell {lbl_A} -> {lbl_B}): {err!r}")
        continue

    filled_mask = _interpolate_missing_mask(mask_A, mask_B, lbl_A, lbl_B)

    # Paint the interpolated region with cell B's label onto the layer right after
    # cell A's lowest layer. Work on a copy so the loaded `array` is not mutated.
    # Boolean indexing replaces tuple(zip(*np.argwhere(...))). For an empty
    # interpolation that idiom yields the index `()`, which selects the WHOLE layer,
    # so every pixel would be relabelled as cell B.
    target_layer = lowest_layer_A + 1
    array_new = np.copy(array)
    array_new[target_layer][filled_mask != 0] = lbl_B

    # Merge cell A into cell B so the repaired cell carries a single label.
    array_new[array_new == lbl_A] = lbl_B

    np.save(save_path, array_new)
    