# Predict spindle

Notebook that uses trained models to make predictions for the supported profiles (see step 2), detecting the precise location and contour of each object.

### 1) Load dependencies
A complete list of key libraries is stored in `requirements.txt`. To install all dependencies, copy the library folder to the conda environment:

`ml_env_prediction`

This environment can also be used for measurements.

### 2) Select profile (user input)

The user defines the profile, which selects the correct model to load for prediction:
`profiles = ["CellMembrane", "Spindle"]`
which is equivalent to:

    "CellMembrane" = 0
    "Spindle" = 1


### 3) Load configuration settings
This will load the config file with the required model settings to use the model for prediction on new images.

### 4) Import image folder (Export folder will be generated automatically)
The user selects the folder with images for import. Supported file formats are: `*.png, *.jpg, *.jpeg, *.tif`

### 5) Execute inference and visualize prediction results
Run the model on all images and produce visualizations for inspecting the segmentation results.

### 6) Property table
Save the prediction in the following format:
    |-binary mask
    |-semantic mask
    |-overlay
    .csv sheet

The .csv sheet contains information such as:
1. bounding box coordinates for each object
2. (x,y) centroids
3. Computational run time

# -------------
### 7) Restart the kernel after each profile
# -------------

---

In [None]:
"""
    Prediction pipeline for SpinX
    Author: David Dang
"""

In [None]:
import os
import shutil
import requests
import sys
import colorsys
import random
import math
import re
import time
import datetime
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import time
import skimage
from skimage.color import rgb2gray, gray2rgb, label2rgb
from skimage.measure import label, regionprops
from skimage.segmentation import clear_border
from skimage import exposure, img_as_ubyte
from natsort import natsorted
import termtables as tt # Print table
from statistics import stdev # Statistics
import pandas as pd
import cv2
# Root directory of the project. This is the current working directory, so the
# local packages imported below (mrcnn, and later samples) are presumably only
# found when the notebook is started from the project root — confirm.
ROOT_DIR = os.getcwd()
print(ROOT_DIR)
# Import Model: make the local Mask R-CNN package (mrcnn) importable
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
from mrcnn import visualize
from mrcnn.visualize import display_images
import mrcnn.model as modellib
from mrcnn.model import log

# IPython magic: render matplotlib figures inline in the notebook
%matplotlib inline 

# Directory to save logs and trained model (also where downloaded models are stored)
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

In [None]:
# Display the log / model directory as the cell output
MODEL_DIR

### 2) Select profile

In [None]:
# List of different profiles
profiles = ["CellMembrane", "Spindle"]
#####################    USER INPUT    #####################
############################################################
user_profile = profiles[1]
# Select profile by index: 0 = "CellMembrane", 1 = "Spindle"
############################################################


# IMPLEMENTED IN APEER UI - these values are passed as arguments to execute():
#   condition   = 1      # 1 enables conditional filtering of detected objects
#   object_size = 20000  # area threshold in pixels, e.g. (9.044 pixel/micron x 20 micron)^2;
#                        # removes larger objects for "Spindle", smaller ones for "CellMembrane"
#   min_score   = 0.92   # minimum detection score (0.96 was also used)
# Keep segmented objects with respect to time frame one (based on centroid distance):
#   time_lapse  = 1
#   n_frames    = 21
#   n_slices    = 3

# Turn plot off for faster processing (reduces run time per image by ~1sec)
plot_on = 0

# Resolve the profile-specific module and the class name it registers.
# The module names cell_membrane / spindle stay bound at module level as before.
if user_profile == "CellMembrane":
    from samples import cell_membrane
    profile_module = cell_membrane
    profile_class = "cell_membrane"
elif user_profile == "Spindle":
    from samples import spindle
    profile_module = spindle
    profile_class = "spindle"
else:
    # Fail here instead of later with an undefined `config`/`dataset`.
    raise ValueError("Unknown profile '{}'; expected one of {}".format(user_profile, profiles))

# Load config file of the selected profile
config = profile_module.CustomConfig()
# Dataset object only used for its class names (e.g. when plotting predictions);
# no validation images are loaded here.
dataset = profile_module.CustomDataset()
dataset.add_class(profile_class, 1, profile_class)  # Add new class
# Must call before using the dataset
dataset.prepare()
print('### Selected profile: ' + user_profile)

# 3. Load configuration settings

In [None]:
def get_ax(rows=1, cols=1, size=16):
    """Create a new figure and return its Matplotlib Axes.

    Central place to control how large the notebook's visualisations are
    rendered: every subplot occupies ``size`` x ``size`` inches, arranged
    in a ``rows`` x ``cols`` grid.
    """
    figure_size = (size * cols, size * rows)
    _figure, axes = plt.subplots(rows, cols, figsize=figure_size)
    return axes

# Override the training configurations with a few
# changes for inferencing.
# The base class is whatever config class the selected profile provided, so the
# training settings are inherited and only the values below are changed.
class InferenceConfig(config.__class__):
    # Run detection on one image at a time
    # (execute() calls model.detect() with a single image per call)
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

# Rebind `config` to the inference variant and print all of its settings
config = InferenceConfig()
config.display()

# Device to load the neural network on.
# NOTE(review): DEVICE is not referenced in the visible code — confirm it is still used.
DEVICE = "/gpu:0"  # /cpu:0 or /gpu:0

# Inspect the model in training or inference modes
# values: 'inference' or 'training'
# TODO: code for 'training' test mode not ready yet
TEST_MODE = "inference"

In [None]:
def imclearborder(imgBW, radius):
    """
    Remove all connected components that come within ``radius`` pixels of the
    image border.

    Parameters
    ----------
    imgBW : numpy.ndarray
        Single-channel 8-bit binary image (non-zero = foreground), the input
        format required by ``cv2.findContours``.
    radius : int
        A component is removed when any point of its contour lies at a
        distance ``<= radius`` from any image edge. The same rule applies on
        all four sides, so ``radius=0`` removes exactly the components that
        touch the border (the ``skimage.segmentation.clear_border`` convention).

    Returns
    -------
    numpy.ndarray
        A copy of ``imgBW`` in which the removed components are filled with 0.
    """
    imgBWcopy = imgBW.copy()
    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy); taking element [-2] works with both. findContours
    # gets its own copy because older OpenCV versions modify their input.
    contours = cv2.findContours(imgBW.copy(), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]

    imgRows, imgCols = imgBW.shape[:2]
    last_row = imgRows - 1
    last_col = imgCols - 1

    for idx, cnt in enumerate(contours):
        # Each contour is an (n_points, 1, 2) array of (x, y) = (col, row).
        points = cnt.reshape(-1, 2)
        cols = points[:, 0]
        rows = points[:, 1]
        near_border = ((rows <= radius) | (rows >= last_row - radius)
                       | (cols <= radius) | (cols >= last_col - radius))
        if near_border.any():
            # Thickness -1 fills the whole component with background (0).
            cv2.drawContours(imgBWcopy, contours, idx, 0, -1)

    return imgBWcopy


def addInnerContour(crop_obj_mask, mask_inner):
    """
    Cut the (filled) inner mask out of an object mask.

    Both masks are presumably 0/255 uint8 images of the same shape (TODO
    confirm with callers). The inner mask is inverted and multiplied with
    the object mask, so pixels set in ``mask_inner`` become 0 while the
    remaining pixels keep their object-mask value.

    Returns
    -------
    numpy.ndarray
        The combined mask as ``uint8``.
    """
    # 1 where the inner mask is empty, 0 where it is set
    keep_fraction = (255 - mask_inner) / 255
    # Object mask rescaled from [0, 255] to [0, 1]
    object_fraction = crop_obj_mask / 255
    combined = keep_fraction * object_fraction
    # Back to the [0, 255] uint8 range
    return np.asarray(combined * 255).astype('uint8')

def random_colors(N, bright=True):
    """
    Return ``N`` visually distinct RGB colours in random order.

    Hues are spaced evenly around the HSV colour wheel at full saturation
    (value 1.0, or 0.7 when ``bright`` is False), converted to RGB tuples of
    floats in [0, 1], then shuffled with the ``random`` module.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette

def binarize(img):
    """
    Return a boolean foreground mask that is True where the image is non-zero.

    3-D (colour) images are first reduced to grayscale with ``rgb2gray``, so
    the result is always a single 2-D mask. 2-D images are thresholded
    directly; ``rgb2gray`` is not called on them because it expects a
    channel axis.
    """
    img = np.asarray(img)
    if img.ndim == 3:
        # Threshold the grayscale image so the mask is (H, W), not (H, W, C).
        img = rgb2gray(img)
    return img > 0

def grayToRgb(im):
    """
    Replicate a 2-D grayscale image into an (rows, cols, 3) uint8 RGB image.

    All three channels receive the same values, cast to uint8 by NumPy
    assignment. A non-2-D input raises ValueError when its shape is unpacked.
    """
    n_rows, n_cols = im.shape
    rgb = np.empty((n_rows, n_cols, 3), dtype=np.uint8)
    # Broadcast the single plane into every channel in one assignment.
    rgb[...] = im[:, :, np.newaxis]
    return rgb

def get_fname(s, file_ext):
    """
    Extract a file name from a URL or path.

    The name starts after the last '/' in ``s`` and ends right after the
    first occurrence of ``file_ext`` that follows that '/'. Anything after
    the extension (e.g. a query string such as ``?dl=1``) is dropped.

    Parameters
    ----------
    s : str
        URL or path containing the file name.
    file_ext : str
        Extension that terminates the file name, e.g. ``'.h5'``.

    Returns
    -------
    str
        The file name including ``file_ext``. If ``file_ext`` does not occur
        after the last '/', everything after the last '/' is returned.
    """
    start = s.rfind("/") + 1
    # Search only after `start`, so an earlier match (e.g. in a host name)
    # cannot produce an empty or truncated slice.
    end = s.find(file_ext, start)
    if end == -1:
        return s[start:]
    return s[start:end + len(file_ext)]

def check_model_name(model_file_name, model_name):
    """
    Normalise a user model file name to ``spinx_model_<model_name>_<epoch><ext>``.

    Example: ``spinx_model_cell_membrane_01234.h5``.

    The prefix and model-name fields of the returned name are always
    ``'spinx_model'`` and ``model_name``; a confirmation is printed when the
    user's file already contains them at the expected positions. The epoch
    field is kept from the user's file when the 4 characters at its expected
    position are all digits; otherwise ``'1234'`` is used. The original file
    extension is preserved.
    """
    extension = os.path.splitext(model_file_name)[1]
    expected_prefix = 'spinx_model'
    default_epoch = '1234'  # 4-digit example

    # Field offsets inside the user's file name ("<prefix>_<name>_<epoch>...")
    name_start = len(expected_prefix) + 1
    epoch_start = name_start + len(model_name) + 1

    user_prefix = model_file_name[:len(expected_prefix)]
    user_name = model_file_name[name_start:name_start + len(model_name)]
    user_epoch = model_file_name[epoch_start:epoch_start + len(default_epoch)]

    if user_prefix == expected_prefix:
        print('Prefix matches')
    if user_name == model_name:
        print('Model name matches')

    epoch_ok = len(user_epoch) == len(default_epoch) and user_epoch.isdigit()
    if epoch_ok:
        print('4-digits match')
    epoch = user_epoch if epoch_ok else default_epoch

    return '_'.join([expected_prefix, model_name, epoch]) + extension

In [None]:
class SpinX:
    """
    Load a flat list of image files and reshape them into the multi-dimensional
    arrays used by the SpinX pipeline.
    """
    def __init__(self):
        pass

    def LoadImageList(self, img_list, index):
        """
        Read the ``index``-th image of ``img_list``.

        3-D images are converted to grayscale with ``rgb2gray``. Images with
        more than three dimensions are returned unchanged after a message is
        printed.

        Returns
        -------
        (numpy.ndarray, str)
            The image and its file name (basename, including extension).
        """
        file_path = img_list[index]
        # Obtain filename
        filename = os.path.basename(file_path)
        # Read image
        img = skimage.io.imread(file_path)
        if len(img.shape) == 3:
            img = rgb2gray(img)
        elif len(img.shape) > 3:
            print('Input image dimension is larger than 3')
        output = np.array(img)
        return output, filename

    def _load_stack(self, img_list):
        """Load every image of ``img_list`` in order; return (images, names)."""
        if len(img_list) == 0:
            raise ValueError('img_list is empty; there are no images to convert.')
        images = []
        names = []
        for i in range(len(img_list)):
            img, name = self.LoadImageList(img_list, i)
            images.append(img)
            names.append(name)
        return images, names

    def Convert5D(self, img_list, n_slices, n_time):
        """
        Stack all images into a 5-D array (H x W x D x T x C), i.e. (Y x X x Z x T x C).

        Z refers to z-slices, T to time points and C to the cell id. Files are
        assigned in list order with the z-slice varying fastest, then time,
        then cell.

        Returns
        -------
        array5d : numpy.ndarray
            The stacked images.
        n_cells : int
            ``len(img_list) // (n_slices * n_time)``.
        name_list : numpy.ndarray
            File names with shape (n_cells, n_time, n_slices).

        Raises
        ------
        ValueError
            If ``img_list`` is empty or its length does not match
            ``n_slices * n_time * n_cells`` (raised by ``reshape``).
        """
        images, names = self._load_stack(img_list)
        n_cells = len(img_list) // (n_slices * n_time)
        # Image dimensions are taken from the first image (index 0)
        img_height, img_width = images[0].shape[:2]
        # Fortran order makes the z-slice index vary fastest, matching the file order
        array5d = np.dstack(images).reshape(img_height, img_width, n_slices, n_time, n_cells, order='F')
        # Same grouping for the names (C order: last axis = slice varies fastest)
        name_list = np.array(names).reshape(n_cells, n_time, n_slices)
        return array5d, n_cells, name_list

    def Convert6D(self, img_list, n_slices, n_time, n_channel=None):
        """
        Stack all images into a 6-D array (H x W x D x T x S x C), i.e. (Y x X x Z x T x S x C).

        Z refers to z-slices, T to time points, S to series/cell and C to
        channel. Files are assigned in list order with the z-slice varying
        fastest, then time, then series, then channel.

        Parameters
        ----------
        n_channel : int, optional
            Number of channels. None or 0 means a single channel.

        Returns
        -------
        array6d : numpy.ndarray
            The stacked images.
        n_cells : int
            ``len(img_list) // (n_slices * n_time * n_channel)``.
        name_list : list of str
            File names in load order (flat list).

        Raises
        ------
        ValueError
            If ``img_list`` is empty or its length does not fit the requested
            dimensions (raised by ``reshape``).
        """
        images, names = self._load_stack(img_list)
        # Set channel (default: single channel)
        if not n_channel:
            n_channel = 1
        # Channels are part of the file count, so they must be divided out too
        n_cells = len(img_list) // (n_slices * n_time * n_channel)
        # Image dimensions are taken from the first image (index 0)
        img_height, img_width = images[0].shape[:2]
        # Fortran order makes the z-slice index vary fastest, matching the file order
        array6d = np.dstack(images).reshape(img_height, img_width, n_slices, n_time, n_cells, n_channel, order='F')
        return array6d, n_cells, names

In [None]:
class OME_TIFF():
    """
    Read and write OME-TIFF images and create or edit their OME-XML metadata.

    Adapted from APEER: apeer-ometiff-library (https://github.com/apeer-micro/apeer-ometiff-library)
    and CellProfiler: python-bioformats (https://github.com/CellProfiler/python-bioformats)

    The project-local ``omexmlClass`` module and the ``tifffile`` package are
    imported inside the methods that use them.
    """
    def __init__(self):
        pass

    def read_ometiff(self, input_path):
        """
        Read an OME-TIFF file.

        Parameters
        ----------
        input_path : str
            Path of the OME-TIFF file.

        Returns
        -------
        array : numpy.ndarray
            Pixel data expanded to 5-D and reordered to (T, Z, C, Y, X).
        metadata : omexmlClass.OMEXML
            Parsed OME-XML metadata.
        omexml_string : str
            The raw OME-XML metadata read by tifffile.

        Raises
        ------
        Exception
            If the DimensionOrder of the first image is not one of the six
            XY-leading orders handled below.
        """
        import omexmlClass
        import tifffile
        with tifffile.TiffFile(input_path) as tif:
            array = tif.asarray()
            omexml_string = tif.ome_metadata

        # Turn the OME-XML string into a Bioformats-style object for parsing
        metadata = omexmlClass.OMEXML(omexml_string)

        # Only the first image of the file is inspected
        pixels = metadata.image(0).Pixels
        size_c = pixels.SizeC
        size_t = pixels.SizeT
        size_z = pixels.SizeZ

        # Re-insert axes of length 1 so that the array is always 5-D.
        # NOTE(review): the insertion positions assume the leading axes are
        # already in (T, Z, C) order; for other DimensionOrders with a
        # singleton C/Z/T the new axis may land at the wrong position before
        # the moveaxis calls below — confirm against the tifffile version used.
        if size_c == 1:
            array = np.expand_dims(array, axis=-3)
        if size_z == 1:
            array = np.expand_dims(array, axis=-4)
        if size_t == 1:
            array = np.expand_dims(array, axis=-5)

        # Move the three leading axes so the array is returned in (T, Z, C, Y, X) order
        dim_format = pixels.DimensionOrder

        if dim_format == "XYCZT":
            pass
        elif dim_format == "XYZCT":
            array = np.moveaxis(array, 1, 2)
        elif dim_format == "XYCTZ":
            array = np.moveaxis(array, 0, 1)
        elif dim_format == "XYZTC":
            array = np.moveaxis(array, 0, 2)
        elif dim_format == "XYTZC":
            array = np.moveaxis(array, 0, 2)
            array = np.moveaxis(array, 0, 1)
        elif dim_format == "XYTCZ":
            array = np.moveaxis(array, 1, 2)
            array = np.moveaxis(array, 0, 1)
        else:
            print(array.shape)
            raise Exception("Unknown dimension format: {}".format(dim_format))

        return array, metadata, omexml_string

    def update_omexml(self, omexml, Image_ID=None, Image_Name=None, Image_AcquisitionDate=None,
                      DimensionOrder=None, dType=None, SizeT=None, SizeZ=None, SizeC=None, SizeX=None, SizeY=None,
                      PhysicalSizeX=None, PhysicalSizeY=None, PhysicalSizeZ=None,
                      ExposureTime=None,
                      Channel_ID=None, Channel_Name=None, Channel_SamplesPerPixel=None):
        """
        Return a copy of the OME-XML ``omexml`` with the given fields replaced.

        Only arguments with a truthy value are applied; ``None`` (and also
        ``0`` / empty strings) leave the corresponding field unchanged.

        Returns
        -------
        bytes
            The updated OME-XML, encoded with ``str.encode()`` (UTF-8).
        """
        import omexmlClass
        metadata = omexmlClass.OMEXML(omexml)

        if Image_ID:
            metadata.image().set_ID(Image_ID)
        if Image_Name:
            metadata.image().set_Name(Image_Name)
        if Image_AcquisitionDate:
            # NOTE(review): every other setter works on metadata.image() directly;
            # confirm the image object really exposes an `.Image` attribute.
            metadata.image().Image.AcquisitionDate = Image_AcquisitionDate

        if DimensionOrder: # Dimension order
            metadata.image().Pixels.DimensionOrder = DimensionOrder
        if dType: # The pixel bit type, for instance PT_UINT8
            metadata.image().Pixels.PixelType = dType
        if SizeT: # The dimensions of the image in the T direction in pixels
            metadata.image().Pixels.set_SizeT(SizeT)
        if SizeZ: # The dimensions of the image in the Z direction in pixels
            metadata.image().Pixels.set_SizeZ(SizeZ)
        if SizeC: # The dimensions of the image in the C direction in pixels
            metadata.image().Pixels.set_SizeC(SizeC)
        if SizeX: # The dimensions of the image in the X direction in pixels
            metadata.image().Pixels.set_SizeX(SizeX)
        if SizeY: # The dimensions of the image in the Y direction in pixels
            metadata.image().Pixels.set_SizeY(SizeY)
        if PhysicalSizeX: # The length of a single pixel in X direction
            metadata.image().Pixels.set_PhysicalSizeX(PhysicalSizeX)
        if PhysicalSizeY: # The length of a single pixel in Y direction
            metadata.image().Pixels.set_PhysicalSizeY(PhysicalSizeY)
        if PhysicalSizeZ: # The length of a single pixel in Z direction
            metadata.image().Pixels.set_PhysicalSizeZ(PhysicalSizeZ)

        if ExposureTime: # Duration of exposure time in seconds
            # NOTE(review): this assigns to an attribute named `set_ExposureTime`
            # instead of calling a setter, so the value is probably never written
            # to the XML (or `.Plane` raises AttributeError). Confirm the
            # omexmlClass Plane API and call the setter on the intended plane.
            metadata.image().Plane.set_ExposureTime = ExposureTime

        # NOTE(review): channel fields are set via metadata.image().Channel, while
        # gen_omexml() accesses channels via Pixels.Channel(i) — confirm which is valid.
        if Channel_ID:
            metadata.image().Channel.ID = Channel_ID
        if Channel_Name:
            metadata.image().Channel.Name = Channel_Name
        if Channel_SamplesPerPixel:
            metadata.image().Channel.SamplesPerPixel = Channel_SamplesPerPixel

        metadata = metadata.to_xml().encode()

        return metadata

    def gen_omexml(self, array):
        """
        Generate an OME-XML template describing ``array``.

        ``array`` must be 5-D and is assumed to be ordered (T, Z, C, Y, X);
        the DimensionOrder written to the XML is therefore "XYCZT".

        Returns
        -------
        bytes
            The generated OME-XML, UTF-8 encoded.
        """
        import omexmlClass

        # Dimension order is assumed to be TZCYX
        dim_order = "TZCYX"

        metadata = omexmlClass.OMEXML()
        shape = array.shape
        assert ( len(shape) == 5), "Expected array of 5 dimensions"

        metadata.image().set_Name("IMAGE")
        metadata.image().set_ID("0")

        pixels = metadata.image().Pixels
        pixels.ome_uuid = metadata.uuidStr
        pixels.set_ID("0")

        pixels.channel_count = shape[2]

        pixels.set_SizeT(shape[0])
        pixels.set_SizeZ(shape[1])
        pixels.set_SizeC(shape[2])
        pixels.set_SizeY(shape[3])
        pixels.set_SizeX(shape[4])

        # OME lists the fastest-varying axis first, hence the reversed string
        pixels.set_DimensionOrder(dim_order[::-1])

        pixels.set_PixelType(omexmlClass.get_pixel_type(array.dtype))

        for i in range(pixels.SizeC):
            channel = pixels.Channel(i)
            channel.set_ID("Channel:0:" + str(i))
            channel.set_Name("C:" + str(i))
            channel.set_SamplesPerPixel(1)

        pixels.populate_TiffData()

        return metadata.to_xml().encode()

    def write_ometiff(self, output_path, array, mode='minisblack', omexml_str = None):
        """
        Write ``array`` as an OME-TIFF.

        When ``omexml_str`` is None, a template is generated with
        ``gen_omexml``. Returns the OME-XML that was written as the description.
        """
        import tifffile
        if omexml_str is None:
            omexml_str = self.gen_omexml(array)

        tifffile.imwrite(output_path, array,  photometric = mode, description=omexml_str, metadata = None)
        return omexml_str

    def write_omexml(self, path, omexml_str):
        """
        Pretty-print an OME-XML string.

        If ``path`` is ``'print'`` the XML is only printed. Otherwise it is
        written (UTF-8 encoded) to ``path`` and then printed.
        """
        import xml.dom.minidom # Prettify XML
        omexml_parse = xml.dom.minidom.parseString(omexml_str)
        omexml_pretty = omexml_parse.toprettyxml()

        if path == 'print':
            print(omexml_pretty)
        else:
            # Export XML; the context manager closes the file even if writing fails
            with open(path, "wb") as f:
                f.write(omexml_pretty.encode())
            print(omexml_pretty)

# 6. Obtain properties and export predictions

In [None]:
def execute(image_paths, 
            load_model, 
            condition, 
            object_size, 
            min_score, 
            time_lapse, 
            n_frames, 
            n_slices, 
            export_ome_tiff, 
            pixel_x=0, 
            pixel_y=0, 
            pixel_z=0):
    
    # ================= Setting parameters ================= #
    convert_format = 1
    pref_ext = '.png'
    
    # ================= IMPORT LIST OF FILES ================= #
    # Check if it is a list
    if isinstance(image_paths, list): 
        print("your object is a list !") 
    else: 
        image_paths = [image_paths]
    
    # ================= CREATE OUTPUT DIR ================= #    
    OUTPUT_DIR = 'output/'
    if not os.path.exists( OUTPUT_DIR ):
        # Main Folder
        os.makedirs(  OUTPUT_DIR )
        print('Output: ##### Create Output folder. #####')
                
    
    file_type_list = []
    #def load_image_list
    image_list = []
    for filepath in image_paths:
        # Obtain filename
        filename = os.path.basename(filepath)
        # Keep file extension after first dot (.ome.tiff)
        if filename.split(os.extsep, 1)[1].lower() == 'ome.tiff':
            image_list.append(filepath)
            file_type_list.append('ome-tiff')
        elif os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg', '.tif']:
            image_list.append(filepath)
            file_type_list.append('tif-png')
    # Sort list alphabetically (libary)
    image_list = natsorted(image_list)

    # Conver to a set (get unique values) to check if the file format is consistent
    if len(set(file_type_list)) > 1:
        sys.exit('Mixed file formats were imported.')
    # Get file format of loaded images
    file_type = list(set(file_type_list))[0]
    
    if file_type == 'ome-tiff':
        filename_ome_list = []
        # Use OME-TIFF loader
        OME = OME_TIFF()
        # Most efficient way to stack numpy in a loop is by append to list
        ome_list = []
        meta_list = []
        for i, file in enumerate(image_list):
            # Read OME-TIFF: T, Z, C, Y, X
            ome_array, metadata, xml_str = OME.read_ometiff(file)
            # Obtain filename
            filename_ome_list.append(os.path.basename(file))
            ome_list.append(ome_array)
            meta_list.append(metadata)
        # Convert list to 6D array: S, T, Z, C, Y, X
        array6d_ome = np.stack(ome_list, axis=0)
        array6d_ome.shape
        # Meta data
        pixel_x = meta_list[0].image(0).Pixels.get_PhysicalSizeX()
        pixel_y = meta_list[0].image(0).Pixels.get_PhysicalSizeY()
        pixel_z = meta_list[0].image(0).Pixels.get_PhysicalSizeZ()
        
        # Print meta data
        print( 'Pixel size X: ' + str(pixel_x) )
        print( 'Pixel size Y: ' + str(pixel_y) )
        print( 'Pixel size Z: ' + str(pixel_z) )
        
        # Convert to SpinX 6D array format by permutation
        # From: [S, T, Z, C, Y, X] to [Y, X, Z, T, S, C]
        perm_ome_spinx = (4, 5, 2, 1, 0, 3 ) # Order

        array6d_sx = np.transpose(array6d_ome, perm_ome_spinx)
        array6d_sx.shape
    else:
        SX = SpinX()
        array6d_sx, _, filename_img_list = SX.Convert6D(image_paths, n_slices, n_frames)
    # ================= IMPORT LIST OF FILES END ================= #
    
    
    
    
    
    # ================= MODEL IMPORT ================= #
    # Check if user select pre-trained model
    user_list = ['NA', 'na', 'Default', 'DEFAULT', '0']
    # Load SpinX model if user types in 'Default' (or anything but an URL)
    if load_model in user_list or len(load_model) < 10:
    # Load default model
        model_default = 1
    else:
        file_ext = '.h5'
        # Extract model name
        model_name = get_fname(load_model, file_ext)
        
        # Check file name format of the new model. Rename if needed.

        model_name = check_model_name(model_file_name = model_name,
                                      model_name = config.NAME) # Get name of the model from the config file
        
        
        # Create folder for new model
        out_name = config.NAME # Get name of the model from the config file
        model_dir = os.path.join(MODEL_DIR,
                                  out_name + '{}'.format(
                                      datetime.datetime.now().
                                      strftime("%Y%m%dT%H%M")))
        os.mkdir(model_dir)
        # Download model and save it in the model folder
        myfile = requests.get(load_model)
        open(os.path.join(model_dir, model_name), 'wb').write(myfile.content)
        model_default = 0
        
    # Create model in inference mode
    model = modellib.MaskRCNN(mode="inference", 
                              model_dir=MODEL_DIR,
                              config=config)
    
    if model_default == 1:
        # Find default model (first model)
        weights_path = model.find_first()
    else:
        # Find latest model
        weights_path = model.find_last()
    
    # Load weights
    model.load_weights(weights_path, by_name=True)
    print("Loading model and weights completed.")

    # Obtain folder and file name to store in csv
    # Get head of path
    path_head = os.path.split(weights_path)[0]
    # Get folder name (last position)
    path_folder = os.path.split(path_head)[1]
    # Get model name
    path_name = os.path.split(weights_path)[1]
    # Merge
    model_path_name = os.path.join(path_folder, path_name)
    # ================= MODEL IMPORT END ================= #

    # User specification
    df = pd.DataFrame(columns=['model_name',
                               'filename',
                               'img_height',
                               'img_width',
                               'img_dim',
                               'cond_filt',
                               'time_filt',
                               'series',
                               'time',
                               'z-slice',
                               'obj_id',
                               'bbox_x1',
                               'bbox_y1',
                               'bbox_x2',
                               'bbox_y2',
                               'bbox_height',
                               'bbox_width',
                               'pred_runtime',
                               'post_runtime',
                               'excluded_obj'
                              ])



    ##### in progress
    total_time = []
    counter = 0
    num_img = len(image_paths)
    bad_obj_counter = 0

    # Preallocate variables
    bad_objects = []
    
    # For Timelapse
    centroid_list = []
    # Number of cells
    N_cells = array6d_sx.shape[4]
    list_break = np.arange(0, num_img, n_frames*n_slices)
    cell_id = 0

    # Set channel
    ch = 0
    
    array6d_raw = [] # Store 6D array for raw
    array6d_bin = [] # Store 6D array for binary
    array6d_label = [] # Store 6D array for label
    array6d_overlay = [] # Store 6D array for overlay    
    
    # Loop through Series [Y, X, Z, T, S, C]
    for s in range(array6d_sx.shape[4]):
        if file_type == 'ome-tiff':
            filename = filename_ome_list[s]
        elif file_type == 'tif-png':
            filename = filename_img_list[counter]
            
        # Loop through Time [Y, X, Z, T, S, C]
        for t in range(array6d_sx.shape[3]):
            # Loop through Z-Slices [Y, X, Z, T, S, C]
            for z in range(array6d_sx.shape[2]):
                # Start timer
                start = time.time()
                print('counter var: ' + str(counter))
                img =  array6d_sx[:,:,z,t,s,ch]
                img_orig = np.array(img)
                img_orig = img_as_ubyte(exposure.rescale_intensity(img_orig)) # Rescale uint16 to uint8
                img_orig = grayToRgb(img_orig)
                # Obtain image dimensions
                img_height = img_orig.shape[0]
                img_width = img_orig.shape[1]
                img_dim = img_orig.shape[2]

                # Prediction
                results = model.detect([img_orig], verbose=1)
                r = results[0]
                if plot_on == 1:
                    visualize.display_instances(img_orig, r['rois'], r['masks'], r['class_ids'], 
                                                dataset.class_names, r['scores'], title="Predictions", figsize=(15,15))

                # End timer
                end = time.time()
                e_time = end - start
                print('Elapsed time for prediction: %f seconds' %(round(e_time,3)))
                print() 
                print() 
                total_time.append(e_time)

                # th-dataframe
                df_line = pd.DataFrame(columns=['model_name',
                                                'filename',
                                                'img_height',
                                                'img_width',
                                                'img_dim',
                                                'cond_filt',
                                                'time_filt',
                                                'series',
                                                'time',
                                                'z-slice',
                                                'obj_id',
                                                'bbox_x1',
                                                'bbox_y1',
                                                'bbox_x2',
                                                'bbox_y2',
                                                'bbox_height',
                                                'bbox_width',
                                                'pred_runtime',
                                                'post_runtime',
                                                'excluded_obj'
                                               ])    

                ### Loop through bounding boxes and masks (every mask has a bounding box)

                # Create canvas for merging all masks (each value represents an instance) Note: A different data structure is needed for multiple classes
                merged_mask = np.zeros((img_height, img_width), 'uint8')
                # Create canvas for binary mask
                binary_mask = np.zeros((img_height, img_width), 'uint8')

                # Use conditional filtering
                if condition == 1:
                    if len(r['rois']) > 1:
                        all_idx = []
                        idx_to_clear = []
                        area_array = []
                        # Loop through detect ROIS
                        for m in range(len(r['rois'])):
                            all_idx.append(m)
                            # Clear objects that are very close to the border
                            bw_clear = imclearborder(r['masks'][:,:,m].astype('uint8'), 10)
                            if np.sum(bw_clear)==0:
                                # Store indices to be cleared
                                idx_to_clear.append(m)


                            # Keep only the object with the highest score
                            if r['scores'][m] < min_score:
                                idx_to_clear.append(m)

                            # Check for area size
                            props = regionprops(r['masks'][:,:,m].astype('uint8'))
                            if user_profile == "Spindle":
                                if props[0].area > object_size:
                                    idx_to_clear.append(m)
                            elif user_profile == "CellMembrane":
                                if props[0].area < object_size:
                                    idx_to_clear.append(m)
                            area_array.append(props[0].area)
                        # Find largest blob
                        #max_value = max(area_array)
                        #max_index = area_array.index(max_value)

                        # Identify sublist from previous list and complete list
                        #list_remaining = list(set(all_idx)^set(idx_to_clear))
                        #to_remove = list(filter(lambda a: a != max_index, list_remaining))
                        #idx_to_clear = idx_to_clear + to_remove

                        ## Keep only unique ids
                        idx_to_clear = list(set(idx_to_clear))
                        bad_objects = len(idx_to_clear)
                        r['rois'] = np.delete(r['rois'],idx_to_clear,0)
                        r['class_ids'] = np.delete(r['class_ids'],idx_to_clear,0)
                        r['scores'] = np.delete(r['scores'],idx_to_clear,0)
                        r['masks'] = np.delete(r['masks'],idx_to_clear,2)
                        # If empty array (repeat buy lower min_score requirement by 10%)
                        if r['rois'].size == 0:
                            # Re-run model on image
                            results = model.detect([img_orig], verbose=1)
                            r = results[0]
                            if len(r['rois']) > 1:
                                all_idx = []
                                idx_to_clear = []
                                area_array = []
                                for m in range(len(r['rois'])):
                                    all_idx.append(m)
                                    # Clear objects that are very close to the border
                                    bw_clear = imclearborder(r['masks'][:,:,m].astype('uint8'), 10)
                                    if np.sum(bw_clear)==0:
                                        # Store indices to be cleared
                                        idx_to_clear.append(m)


                                    # Keep only the object with the highest score
                                    if r['scores'][m] <= min_score-(min_score*0.10):
                                        idx_to_clear.append(m)

                                    # Check for area size
                                    props = regionprops(r['masks'][:,:,m].astype('uint8'))
                                    if user_profile == "Spindle":
                                        if props[0].area > object_size:
                                            idx_to_clear.append(m)
                                    elif user_profile == "CellMembrane":
                                        if props[0].area < object_size:
                                            idx_to_clear.append(m)
                                    area_array.append(props[0].area)
                                # Find largest blob
                                #max_value = max(area_array)
                                #max_index = area_array.index(max_value)

                                # Identify sublist from previous list and complete list
                                #list_remaining = list(set(all_idx)^set(idx_to_clear))
                                #to_remove = list(filter(lambda a: a != max_index, list_remaining))
                                #idx_to_clear = idx_to_clear + to_remove

                                ## Keep only unique ids
                                idx_to_clear = list(set(idx_to_clear))
                                bad_objects = len(idx_to_clear)
                                r['rois'] = np.delete(r['rois'],idx_to_clear,0)
                                r['class_ids'] = np.delete(r['class_ids'],idx_to_clear,0)
                                r['scores'] = np.delete(r['scores'],idx_to_clear,0)
                                r['masks'] = np.delete(r['masks'],idx_to_clear,2)
                    else:
                        bad_objects = 0

                if time_lapse == 1:
                    idx_to_clear_time = []
                    dist_list = []
                    all_idx_time = []
                    # Check if ROI exists
                    if r['rois'].size > 0:
                        props_cen = regionprops(r['masks'].astype('uint8'))
                        yx_centroid = [props_cen[0].centroid[0], props_cen[0].centroid[1]]
                        centroid_list.append(yx_centroid)
                        if t == 0:
                            # Change based on number of cells
                            cell_id += 1
                        else:
                            for n in range(len(r['rois'])):
                                all_idx_time.append(n)
                                current = regionprops(r['masks'][:,:,n].astype('uint8'))
                                current_yx = [current[0].centroid[0], current[0].centroid[1]]


                                y1 = current_yx[0]
                                y2 = centroid_list[t-1][0]
                                x1 = current_yx[1]
                                x2 = centroid_list[t-1][1]

                                dist = np.sqrt( (x2 - x1)**2 + (y2 - y1)**2 )
                                dist_list.append(dist)

                            min_dist = min(dist_list)
                            min_dist_idx = dist_list.index(min_dist)
                            filt_idx_time = list(filter(lambda a: a != min_dist_idx, all_idx_time))
                            r['rois'] = np.delete(r['rois'],filt_idx_time,0)
                            r['class_ids'] = np.delete(r['class_ids'],filt_idx_time,0)
                            r['scores'] = np.delete(r['scores'],filt_idx_time,0)
                            r['masks'] = np.delete(r['masks'],filt_idx_time,2)
                            bad_objects = bad_objects + len(filt_idx_time)
                    else:
                        yx_centroid = []  
                        centroid_list.append(yx_centroid)


                # Obtain tensor with masks (format: H x W x instances)
                mask_tensor = r['masks']
                # Number of objects
                num_obj = mask_tensor.shape[2]
                # Obtain tensor with bounding boxes (format: NUMBER OF OBJECTS x (y1, x1, y2, x2))
                bbox_tensor = r['rois']
                # Total number of bounding boxes
                num_bbox = len(bbox_tensor)


                if num_bbox > 0:
                    # Loop through instances
                    for jj in range(0, num_bbox):
                        # Original format of bounding box: y1, x1, y2, x2
                        th_bbox = r['rois'][jj]
                        # Re-order coordinates
                        x1 = th_bbox[1]
                        y1 = th_bbox[0]
                        x2 = th_bbox[3]
                        y2 = th_bbox[2]

                        # Calculate the bounding box dimensions
                        bbox_height = x2 - x1
                        bbox_width = y2 - y1        
                        # Write data to dataframe
                        #df_line.loc[jj] = model_path_name, filename, img_height, img_width, img_dim, jj, x1, y1, x2, y2, bbox_height, bbox_width, e_time, 0, 0, 0, bad_objects


                        ### MASK
                        # Take th mask from the tensor
                        th_mask = mask_tensor[:, :, jj]
                        # Convert true/false to image
                        th_mask_img = th_mask.astype(np.uint8)
                        th_mask_img*= 255

                        # Assign value if the th-mask is 1
                        merged_mask[th_mask_img==255] = jj + 1

                        # Create binary
                        binary_mask[th_mask_img==255] = 255
                else:
                    jj = []
                    x1 = []
                    y1 = []
                    x2 = []
                    y2 = []
                    bbox_height = []
                    bbox_width = []

                # Sort dataframe by obj_id
                df_line.sort_values(by=['obj_id'])
                # Append dataframe for each image
                df = df.append(df_line)
                
                # Create overlay
                # Binarize image
                bw = binarize(merged_mask)

                # Delete border connected to image border
                label_image = label(bw)
                image_label_overlay = label2rgb(label_image, image=img_orig, image_alpha = 0.6, bg_label = 0)
                # Rescale overlay
                img_overlay = image_label_overlay*255
                img_overlay = img_overlay.astype('uint8')
                # Export merged mask
                

                # Append to 6D array (for OME-TIFF)
                # Rescale raw image
                img_raw = rgb2gray(img_orig)*255
                img_raw = img_raw.astype('uint8')
                # For 6D array
                array6d_raw.append(img_raw) # Raw 6D
                array6d_bin.append(binary_mask) # Binary 6D
                array6d_label.append(merged_mask) # Label 6D
                array6d_overlay.append(img_overlay) # Overlay 6D  

        
                end_2 = time.time()
                post_time = end_2 - end
                print('Elapsed time for post-processing: %f seconds' %(round(post_time,3)))
                # Write data to dataframe
                df_line.loc[0] = model_path_name, filename, img_height, img_width, img_dim, condition, time_lapse, s, t, z, jj, x1, y1, x2, y2, bbox_height, bbox_width, e_time, post_time, bad_objects
                df = df.append(df_line)

                counter += 1
        
    
    # Reshape 6d array to OME-TIFF Format: [S, T, Z, C, Y, X] from SpinX [Y, X, Z, T, S, C]
    # Raw ( Don't use order='F' for a reshaping non 1D array - use it for 1D only)
    array6d_raw = np.array(array6d_raw).reshape(array6d_sx.shape[4], array6d_sx.shape[3], array6d_sx.shape[2], array6d_sx.shape[5], array6d_sx.shape[0] , array6d_sx.shape[1])
    # Binary
    array6d_bin = np.array(array6d_bin).reshape(array6d_sx.shape[4], array6d_sx.shape[3], array6d_sx.shape[2], array6d_sx.shape[5], array6d_sx.shape[0] , array6d_sx.shape[1])
    # Label
    array6d_lab = np.array(array6d_label).reshape(array6d_sx.shape[4], array6d_sx.shape[3], array6d_sx.shape[2], array6d_sx.shape[5], array6d_sx.shape[0] , array6d_sx.shape[1])
    # Overlay (The overlay is a RGB. Hence, the dimensions are now: STZYXC with C = 3)
    array6d_overlay = np.array(array6d_overlay).reshape(array6d_sx.shape[4], array6d_sx.shape[3], array6d_sx.shape[2], array6d_sx.shape[0] , array6d_sx.shape[1], 3)
    
    
    # =========== EXPORT =========== #
    # Preallocate
    output_raw = [] # Store output raw images here
    output_binary = [] # Store output binary images here
    output_label = [] # Store output label images here
    output_overlay = [] # Store output overlay images here
    
    # Export individual images from 6D array
    # Create folder structure
 
    list_dir_raw = os.path.join(OUTPUT_DIR, 'raw')
    if not os.path.exists( list_dir_raw ):
        # Main Folder
        os.makedirs( list_dir_raw )
        print('Output: ##### Create raw input folder. #####')

    list_dir_bin = os.path.join(OUTPUT_DIR, 'binary')
    if not os.path.exists( list_dir_bin ):
        # Main Folder
        os.makedirs( list_dir_bin )
        print('Output: ##### Create binary output folder. #####')

    list_dir_lab = os.path.join(OUTPUT_DIR, 'label')
    if not os.path.exists( list_dir_lab ):
        # Main Folder
        os.makedirs( list_dir_lab )
        print('Output: ##### Create label output folder. #####')

    list_dir_overlay = os.path.join(OUTPUT_DIR, 'overlay')
    if not os.path.exists( list_dir_overlay ):
        # Main Folder
        os.makedirs( list_dir_overlay )
        print('Output: ##### Create overlay output folder. #####')
            
    # If user wants to export as OME-TIFF
    if export_ome_tiff == 1:
        # Use OME-TIFF loader
        OME = OME_TIFF()
        # Loop through Series [S, T, Z, C, Y, X]
        for s in range(array6d_raw.shape[0]):
            print('Series ' + str(s) )
            # Get 5D array and write output as multi-page OME-TIFF files
            array5d_raw = array6d_raw[s,:,:,:,:,:]
            array5d_bin = array6d_bin[s,:,:,:,:,:]
            array5d_label = array6d_lab[s,:,:,:,:,:]
            array5d_overlay = array6d_overlay[s,:,:,:,:,:]
            
            if file_type == 'ome-tiff':
                # Get OME-TIFF file name and extension
                ome_tiff_name = filename_ome_list[s].split(os.extsep, 1)[0]
                ext = '.' + filename_ome_list[s].split(os.extsep, 1)[1]
            elif file_type == 'tif-png':
                ome_tiff_name = filename_img_list[s*n_frames*n_slices].split(os.extsep, 1)[0]
                ext = '.ome.tiff'
            
            ome_tiff_name_raw = ome_tiff_name + ext
            ome_tiff_name_binary = ome_tiff_name + '_BINARY' + ext
            ome_tiff_name_label = ome_tiff_name + '_LABEL' + ext
            ome_tiff_name_overlay = ome_tiff_name + '_OVERLAY' + ext
            # Generate OME-XML template
            omexml_temp = OME.gen_omexml(array5d_bin)
            # Update OME-XML
            omexml_upd = OME.update_omexml(omexml_temp, PhysicalSizeX=pixel_x, PhysicalSizeY=pixel_y, PhysicalSizeZ=pixel_z)
            # Write OME-TIFF
            omexml_new = OME.write_ometiff(os.path.join(list_dir_raw, ome_tiff_name_raw), array5d_raw, mode='minisblack', omexml_str = omexml_upd)
            _ = OME.write_ometiff(os.path.join(list_dir_bin, ome_tiff_name_binary), array5d_bin, mode='minisblack', omexml_str = omexml_upd)
            _ = OME.write_ometiff(os.path.join(list_dir_lab, ome_tiff_name_label), array5d_label, mode='minisblack', omexml_str = omexml_upd)
            _ = OME.write_ometiff(os.path.join(list_dir_overlay, ome_tiff_name_overlay), array5d_overlay, mode='rgb', omexml_str = omexml_upd)
            # Export OME-XML
            OME.write_omexml(os.path.join(OUTPUT_DIR, ome_tiff_name +'.xml'), omexml_new)
            # Store full path for export
            output_raw.append(ome_tiff_name_raw)
            output_binary.append(ome_tiff_name_binary)
            output_label.append(ome_tiff_name_label)
            output_overlay.append(ome_tiff_name_overlay)
    else:
        counter_export = 0 # Counting index for list export
        # Loop through Series [S, T, Z, C, Y, X]
        for s in range(array6d_raw.shape[0]):
            # Loop through Times [S, T, Z, C, Y, X]
            for t in range(array6d_raw.shape[1]):
                # Loop through Z-slices [S, T, Z, C, Y, X]
                for z in range(array6d_raw.shape[2]):
                    img_raw = array6d_raw[s,t,z,ch,:,:] # Raw
                    img_bin = array6d_bin[s,t,z,ch,:,:] # Binary
                    img_lab = array6d_lab[s,t,z,ch,:,:] # Label
                    img_overlay = array6d_overlay[s,t,z,:,:,:] # Overlay
                    # Obtain filename is different for OME-TIFF and list of TIF
                    if file_type == 'ome-tiff':
                        filename_export_list = filename_ome_list[s].split(os.extsep, 1)[0] + '_s' + str(s) + '_t' + str(t) + '_z' + str(z)
                        # Export raw image as single tif
                        filename_export_list_raw = filename_export_list + '.tif'
                    elif file_type == 'tif-png':
                        filename_export_list = os.path.splitext(os.path.basename(filename_img_list[counter_export]))[0]
                        # Export raw image as single tif
                        filename_export_list_raw = filename_export_list + '.tif'
                        
                    filename_export_list_binary = filename_export_list + '_BINARY' + pref_ext
                    filename_export_list_label = filename_export_list + '_LABEL' + pref_ext
                    filename_export_list_overlay = filename_export_list + '_OVERLAY' + pref_ext

                    # Export
                    cv2.imwrite(os.path.join(list_dir_raw, filename_export_list_raw), img_raw) # Binary
                    cv2.imwrite(os.path.join(list_dir_bin, filename_export_list_binary), img_bin) # Binary
                    cv2.imwrite(os.path.join(list_dir_lab, filename_export_list_label), img_lab) # Label
                    cv2.imwrite(os.path.join(list_dir_overlay, filename_export_list_overlay), img_overlay) # Overlay
                    # Store full path for export
                    output_raw.append( os.path.join(list_dir_raw, filename_export_list_raw) )
                    output_binary.append( os.path.join(list_dir_bin, filename_export_list_binary) )
                    output_label.append( os.path.join(list_dir_lab, filename_export_list_label) )
                    output_overlay.append( os.path.join(list_dir_overlay, filename_export_list_overlay) )
                    counter_export += 1
     
    # Convert seconds to hh:mm:ss
    hours, seconds =  sum(total_time) // 3600, sum(total_time) % 3600
    minutes, seconds = sum(total_time) // 60, sum(total_time) % 60

    # Runtime calculations
    total_runtime = str(f"{round(hours):02d}" + "h " + f"{round(minutes):02d}" + "mins " + f"{round(seconds):02d}" + "secs")
    avg_runtime = str( round(sum(total_time)/len(total_time),3) ) + " seconds"
    # Variances cant be computed with only 1 value (assigned to 0)
    if len(total_time) < 2:
        var_runtime = str(0) + " seconds"
    else:
        var_runtime = str( round(stdev(total_time), 3) ) + " seconds"

        # Change alignment for adding more columns 'c'
    pred_tab = tt.to_string(
        [[ num_img, total_runtime, avg_runtime, var_runtime ]],
        header=["N", "Total prediction time:", "Avg. prediction time for one image", "SD of prediction time:"],
        style=tt.styles.ascii_thin_double,
        alignment="lccr",
        # padding=(0, 1),
    )
    print(pred_tab)

    # Assign runtime
    #df['total_pred_runtime'] = total_time[0]


    # Export dataframe to .csv
    csv_filename = "{}{:%Y%m%dT%H%M}.csv".format("summaryTable_", datetime.datetime.now())
    df.to_csv(os.path.join(OUTPUT_DIR, csv_filename), index=False)
    
    # Delete if new model was imported
    if model_default == 0:
        # Delete downloaded file once executed
        shutil.rmtree(model_dir)
    
    print('=== SpinX AI: Completed ===')
    
    # Create a ZIP file of the OUTPUT_DIR
    print('=== Export as ZIP archive: Start ===')
    # Destination of ZIP file (in ROOT DIR)
    output_zip_name = 'spinx_ai_output'
    shutil.make_archive(output_zip_name, 'zip', OUTPUT_DIR)
    output_zip_name_full = output_zip_name + '.zip'
    print('=== Export as ZIP archive: Completed ===')
    return {'output_zip': output_zip_name_full}

In [None]:
# # Multiple objects (OME-TIFF)
# execute(
#     image_paths = [
#     'input_ome_tiff/exp2020-026-set001_hela_his2b-gfp_mcherry-tubulin_mg132-10uM_01-07-DMSO_08-19-MARK2i-10uM01_17_R3D_EQT_w605_t01_z1.ome.tiff',
#     'input_ome_tiff/exp2020-026-set001_hela_his2b-gfp_mcherry-tubulin_mg132-10uM_01-07-DMSO_08-19-MARK2i-10uM01_18_R3D_EQT_w605_t01_z1.ome.tiff'
#     ],
#     load_model = 'NA',
#     condition = 1, 
#     object_size = 20000, 
#     min_score = 0.92, 
#     time_lapse = 1, 
#     n_frames = 5, 
#     n_slices = 3, 
#     export_ome_tiff = 1)

In [None]:
## Multiple objects (LIST)
#
#input_dir = 'input'
#list_full = os.listdir(input_dir)
## Sort list
#list_full = natsorted(list_full)
#list_filt = []
#for fp in list_full:
#    if os.path.splitext(fp)[1] in ['.png', '.jpg', '.jpeg', '.tif']:
#            list_filt.append(os.path.join(input_dir, fp))
#
#
#execute(
#    image_paths = list_filt,
#    load_model = 'NA',
#    condition = 1, 
#    object_size = 20000, 
#    min_score = 0.92, 
#    time_lapse = 1, 
#    n_frames = 5, 
#    n_slices = 3, 
#    export_ome_tiff = 1)