# SpinX 3D Modelling (Latest v.13.4.2021)

v13.4.2021:
- Added support for OME-TIFF
- Expanded the data array from 5D to 6D
- Restructured classes and functions
- Moved SpinX classes and functions into the separate `spinx` module
- Restructured code for easier access to user parameters
- Bug fixes
- Resolved undeclared variables

### Load dependencies
Please install all libraries used in this Jupyter notebook before running it.

In [1]:
import os
import sys
import numpy as np
from numpy import linalg, sqrt
#from tkinter import filedialog
#from tkinter import * # File dialog
from skimage import io
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import rc
from io import BytesIO
import skimage
from skimage.color import rgb2gray, gray2rgb, label2rgb
from skimage.measure import label, regionprops, find_contours
from skimage.segmentation import clear_border
from skimage.morphology import erosion
from skimage.morphology import dilation
from skimage.morphology import square
from skimage.morphology import black_tophat
from skimage.morphology import disk
from skimage.morphology import remove_small_objects
from skimage.draw import circle_perimeter
from scipy import ndimage as ndi
from scipy import optimize
from scipy.signal import medfilt
from scipy.ndimage import gaussian_filter
from scipy.spatial import ConvexHull
import scipy
from skimage import util, draw
from skimage.filters import threshold_local
from skimage import img_as_bool
from skimage.morphology import convex_hull_image
import imageio
import math
from math import atan2, asin, pi, sqrt, log
import platform
import pandas as pd
import datetime
import time
from statistics import mean, stdev, median # Statistics
import termtables as tt # Print table
from scipy.spatial.transform import Rotation as R
import copy
from itertools import product, combinations
from natsort import natsorted
import spinx
import shutil # To create zip

In [2]:
# Solve library issue with 3D plots (matplotlib):
# the plotting code below uses ax.set_aspect("equal") on 3D axes, which only
# worked up to matplotlib 3.0.2, or ax.set_box_aspect(), which was added in
# matplotlib 3.3.0. Versions in between support neither approach.
import re

mpl_vers = matplotlib.__version__


def _version_tuple(version):
    """Return the leading numeric (major, minor, micro) part of *version*.

    distutils.version.StrictVersion is deliberately not used: distutils was
    removed in Python 3.12, and StrictVersion rejects release-candidate, dev
    and post-release strings such as '3.8.0rc1' or '3.4.2.post1+g1234'.
    Missing minor/micro components are treated as 0.

    Raises:
        ValueError: if *version* does not start with a number.
    """
    match = re.match(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', version)
    if match is None:
        raise ValueError('Cannot parse Matplotlib version: %r' % version)
    return tuple(int(part or 0) for part in match.groups())


mpl_vers_tuple = _version_tuple(mpl_vers)

# Use old 3d axes plotting if 1. Else use newer version
if mpl_vers_tuple <= (3, 0, 2):
    old_3d = 1
elif mpl_vers_tuple >= (3, 3, 0):
    old_3d = 0
else:
    # Fail here with a clear message instead of leaving `old_3d` undefined,
    # which would otherwise surface as a NameError deep inside execute().
    raise RuntimeError('Wrong Matplotlib version. Either use 3.0.2 (or older) or 3.3.0 (or newer).')

In [3]:
# Instantiate the SpinX helper objects used by execute().
# The constructors are evaluated left to right, i.e. in the same order as
# SpinX -> SpinX_PSF -> SpinX_Modelling -> EllipsoidTool.
(
    SX,        # spinx.SpinX: image import, contour extraction, stack alignment
    SX_PSF,    # spinx.SpinX_PSF: PSF loading/simulation, fitting, normalisation
    SX_Model,  # spinx.SpinX_Modelling instance
    ET,        # spinx.EllipsoidTool: ellipsoid fitting and plotting
) = (
    spinx.SpinX(),
    spinx.SpinX_PSF(),
    spinx.SpinX_Modelling(),
    spinx.EllipsoidTool(),
)


In [7]:
def execute(
    input_cell_raw = None,
    input_cell_mask = None, 
    input_spindle_raw = None,
    input_spindle_mask = None,
    exp_acronym = 'mark2i',
    exp_name = 'exp2021-001-set001',
    exp_set = 1,
    exp_group_a = 'control',
    n_group_a = 1,
    exp_group_b = 'mark2i',
    n_group_b = 1,
    n_frames = 5,
    n_slices = 3,
    exp_interval = 3,
    pixel_x = 0.0688750,
    pixel_y = 0.0688750,
    pixel_z = 2,
    voxel_x = 512,
    voxel_y = 512,
    voxel_z = 512,
    wavelength = 0.605,
    na = 1.42,
    magification = 60,
    res_lateral = 0.100,
    res_axial = 0.200,
    ri_specimen = 1.34,
    ri_coverslip = 1.522,
    ri_medium = 1.524,
    working_dist = 150,
    coverslip_thickness = 170,
    particle_dist = 2
    ):
    
    # =================== Generate Output folder
    # Modelling Notebook location
    file_model_root = 1
    FILE_DIR = os.path.abspath(os.getcwd())
    # If modelling file is in ROOT folder
    if file_model_root == 0:
        ROOT_DIR = os.path.split(FILE_DIR)[0]
    else:
        ROOT_DIR = FILE_DIR
    ROOT_DIR
    
    OUTPUT_BASE_DIR = os.path.join(ROOT_DIR, 'output')
    OUTPUT_CSV_DIR = os.path.join(OUTPUT_BASE_DIR, 'output_csv', exp_acronym)
    # Make dir for modelling output
    # Path 
    OUTPUT_DIR = os.path.join(OUTPUT_BASE_DIR, 'figs', exp_acronym)
    # Create folder structure
    SX.create_dir_structure(OUTPUT_DIR, exp_name)
    # Create folder structure
    if not os.path.exists( OUTPUT_CSV_DIR ):
        # Main Folder
        os.makedirs( OUTPUT_CSV_DIR )
        print('Output: ##### Create csv output folder. #####')
    else:
        print('Output: ##### csv output folder exists. #####')

    # =================== LATEX settings
    latex_on = 0 # If Latex is available, set 1. Else 0
    # Rename variables
    # =================== Imaging settings
    condition_group = [exp_group_a, exp_group_b]
    pixelsize_xy = pixel_x # micron per pixel (length of pixel in micron)
    pixels_per_micron_xy = 1/pixelsize_xy # number of pixels per micron
    pixels_per_micron_z = pixel_z

    exp_duration = n_frames*exp_interval # Total length of experiment in minutes

    focal_slice = 1 # Best focal plane
    median_slice = np.median(np.arange(1,(n_slices+1)))
    # =================== modelling settings
    voxel_x = voxel_x // pixels_per_micron_xy #px (refers to: X-AXIS)
    voxel_y = voxel_y // pixels_per_micron_xy #px (refers to: Y-AXIS)
    voxel_z = voxel_z // pixels_per_micron_xy #px (refers to: Z-AXIS)

    # Set threshold (pole_cortax gap) to 1 micron
    thres_dist_cortex = 1

    # =================== PSF settings
    psf_true = 0 # Import generated/experimental PSF or 0 to simulate a PSF

    # CY5
    #wavelength_user_ex = 0.645 # in micron
    #wavelength_user_em = 0.705 # in micron

    #wavelength_user_ex = 0.470 # in micron
    #wavelength_user_em = 0.525 # in micron

    # TRITC
    #wavelength_user_ex = 0.555 # in micron
    #wavelength_user_em = 0.605 # in micron


    #=======================================#
    # Cell Cortex 3D Reconstruction
    #=======================================#
    print('=== Cell Cortex 3D reconstruction: Start ===')
    # Raw
    array6d_mem_raw, n_mem_raw, name_list_raw = SX.multi_importer(input_cell_raw, n_slices, n_frames)
    # Mask
    array6d_mem_mask, n_mem_mask, name_list_mask = SX.multi_importer(input_cell_mask, n_slices, n_frames)
    # Check if the numbers are equal.
    if n_mem_raw == n_mem_mask:
        n_cells = n_mem_mask
    else:
        print("Number of raw and mask images does not match.")
    """
    Output: Height x Width x Depth x Time X Cells X Channel
    """
    # Create a sequence for the position of z-slices: Align focal slice to (0,0). Works for even and odd no. slices.
    start_p = (focal_slice)*pixels_per_micron_z*(-1)
    end_p = (n_slices - focal_slice-1)*pixels_per_micron_z
    dist_slices = SX.generate_seq(start_p, end_p, pixels_per_micron_z)
    # Convert micron to pixel !!! Later: Already in pixel, convert to micron
    # Obtain Y, X, Z coordinates of object boundary
    contours6d_orig, info6d_orig = SX.boundary_6d(array6d_mem_mask, dist_slices, pixels_per_micron_xy)
    # Repeat to make a copy for correction
    contours6d, info6d = SX.boundary_6d(array6d_mem_mask, dist_slices, pixels_per_micron_xy)

    # Reconstruction starts here
    
    # Turn on latex for axis labels
    if latex_on == 1:
        rc('text', usetex=True)
    else:
        rc('text', usetex=False)    

    # Decide, which ellipsoid fit to use (0 = mininum enclosing fit; 1 =  least square fit - 3x faster)
    fitting_select = 0

    # Add pixels on top and bot plane for balancing
    est_select = 1

    membrane_correction = 0 # Slow

    # Align segmented cortex contour to the reference centroid (z-slice with the largest dimameter)
    align_stack = 1

    #Turn plot on (1) or off (0)
    plot_on = 0
    # Print minor-plots (for inspection)
    subplot_on = 0


    factor = 1 # Surrounding neighbor pixel distance
    merged_contours6d = []
    merged_contours6d_split = []
    merged_surface6d = []
    merged_axis6d = []
    merged_center6d = []
    merged_radii6d = []
    merged_rotation6d = []
    total_time = []

    for c_id in range(len(info6d)):
        # Start timer
        start = time.time()
        print("---------------| Cell id:" + str(c_id))
        temp_contours6d = []
        temp_contours6d_split = []
        temp_surface6d = []
        temp_axis6d = []
        temp_center6d = []
        temp_radii6d = []
        temp_rotation6d = []
        for tp in range(len(info6d[0])):
            print("Time point:" + str(tp))
            # Find out which slice has the largest diameter (equals to mid-plane of the object)
            max_idx, max_value = SX.find_diameter(np.array(info6d[c_id][tp])[:,1])
            # Use the obtained index to get centroid coordinates
            x0 = info6d[c_id][tp][max_idx][3]
            y0 = info6d[c_id][tp][max_idx][4]

            if align_stack == 1:
                for z in range(len(contours6d[c_id][tp])):
                    if z == max_idx:
                        continue
                    else:
                        contours6d[c_id][tp][z], info6d[c_id][tp][z][3], info6d[c_id][tp][z][4] = SX.align_stack(contours6d[c_id][tp][z], [info6d[c_id][tp][z][3], info6d[c_id][tp][z][4]], [x0, y0])
                if subplot_on == 1:
                    # No correction
                    SX.align_stack_plot(array6d_mem_raw[:,:,:,tp,c_id,0], contours6d_orig[c_id][tp], info6d_orig[c_id][tp], '', exp_name, OUTPUT_DIR)
                    # Correction
                    SX.align_stack_plot(array6d_mem_raw[:,:,:,tp,c_id,0], contours6d[c_id][tp], info6d[c_id][tp], '_corrected', exp_name, OUTPUT_DIR)

            if membrane_correction == 1:
                # Find mid-plane
                mid_plane_idx = median(range(n_slices))
                # Define gap distances away from mid plane
                gap_dist = contours6d[c_id][tp][mid_plane_idx][0,2] - contours6d[c_id][tp][mid_plane_idx-1][0,2]
                balanced_contours6d = contours6d[c_id][tp].copy()
                if max_idx < mid_plane_idx:
                    diff_z = n_slices-(max_idx + 1) # Calculate how many z-slices are missing to balance
                    idx_incr = max_idx
                    for z in range(diff_z):
                        idx_incr = idx_incr + 1
                        # Copy XY-column
                        xyz = contours6d[c_id][tp][idx_incr].copy()
                        # Adjust for z coordinates
                        z_dist = balanced_contours6d[0][0,2]-gap_dist
                        xyz[:,2] = np.full((1, len(xyz[:,2])), z_dist)
                        balanced_contours6d.insert(0, xyz)
                        idx_counter = idx_counter + 1
                elif max_idx > mid_plane_idx:
                    diff_z = max_idx-1 # Calculate how many z-slices are missing to balance
                    idx_incr = max_idx - 1
                    for z in range(diff_z):
                        idx_incr = idx_incr - 1
                        # Copy XY-column
                        xyz = contours6d[c_id][tp][idx_incr].copy()
                        # Adjust for z coordinates
                        z_dist = balanced_contours6d[-1][0,2]+gap_dist
                        xyz[:,2] = np.full((1, len(xyz[:,2])), z_dist)
                        balanced_contours6d.append(xyz)
                contours6d[c_id][tp] = balanced_contours6d.copy()


            # Create array with Y, X, Z coordinates for the estimated data point (including neighbor pixels) that at the very bot and top of the object.
            if est_select == 1:
                #if max_idx != median_slice:

                # Use the obtained index to get centroid coordinates
                x0 = info6d[c_id][tp][max_idx][3]
                y0 = info6d[c_id][tp][max_idx][4]


                # Estimate bot plane
                est_bot = SX.slice_estimate(x0, y0, factor, max_value, slice_select="bot")
                # Estimate top plane
                est_top = SX.slice_estimate(x0, y0, factor, max_value, slice_select="top")
                # Incoperate bot plane to the first position in contours5d array
                contours6d[c_id][tp].insert(0,est_bot)
                # Incoperate top plane to the last position in contours5d array
                contours6d[c_id][tp].append(est_top)

            # Read iith 3D contour
            coord_array = contours6d[c_id][tp]
            temp_contours6d_split.append(coord_array)
            # Convert list to array
            coord_array = np.vstack(coord_array)
            temp_contours6d.append(coord_array)

            # Find the ellipsoid
            if fitting_select == 0:
                center, radii, rotation = ET.getMinVolEllipse(coord_array, .01)
            elif fitting_select == 1:
                data_reg = ET.data_regularize(coord_array)
                center, radii, rotation, _ = ET.ellipsoid_fit(data_reg)

            # ==== Plot
            fig = plt.figure(figsize=(15, 15))
            ax = fig.add_subplot(111, projection='3d')
            if old_3d == 1:
                #scaling = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
                #ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3)
                ax.set_aspect("equal") # Works only with older matplotlib==3.0.2 (unsolved bug with 3.3.1)
            else:
                xs = np.array([0,voxel_x])
                ys = np.array([0,voxel_y])
                zs = np.array([0,voxel_z])
                ax.set_box_aspect((np.ptp(xs), np.ptp(ys), np.ptp(zs)))  # aspect ratio is 1:1:1


            # plot points
            ax.scatter(coord_array[:,0], coord_array[:,1], coord_array[:,2], color='g', marker='*', s=1)

            # plot ellipsoid
            xyz_cell, xyz_cell_axis = ET.plotEllipsoid(center, radii, rotation, ax=ax, plotAxes=True, rStride=4, cStride=4, cageColor = 'k', cageAlpha=0.2)
            ax.set(xlim=(0, voxel_x), ylim=(0, voxel_y), zlim=(0-(voxel_z/2), (voxel_z/2)))

            # draw cube
            rx = [0, voxel_x]
            ry = [0, voxel_y]
            rz = [0-(voxel_z/2), (voxel_z/2)]
            for s, e in combinations(np.array(list(product(rx, ry, rz))), 2):
                if np.sum(np.abs(s-e)) == rx[1]-rx[0]:
                    ax.plot3D(*zip(s, e), "k--", linewidth=1, antialiased=True)

            # Change view
            ax.view_init(elev=30, azim=60)

            ax.tick_params(axis='both', which='major', labelsize=20, pad=15)
            ax.tick_params(axis='both', which='minor', labelsize=20, pad=15)

            # Disable z-label rotation
            ax.zaxis.set_rotate_label(False)  # disable automatic rotation

            # Increase fontsize
            ax.set_xlabel('$\mathbf{x}$', fontsize=30, labelpad=10)
            ax.set_ylabel('$\mathbf{y}$', fontsize=30, labelpad=15)
            ax.set_zlabel('$\mathbf{z}$', fontsize=30, labelpad=15, rotation=0)

            # make the grid lines transparent
            ax.xaxis._axinfo["grid"]['color'] =  (1,1,1,0)
            ax.yaxis._axinfo["grid"]['color'] =  (1,1,1,0)
            ax.zaxis._axinfo["grid"]['color'] =  (1,1,1,0)

            # Save
            name_export = 'cell_' + str(c_id) + '_tp_' + str(tp)
            #full_name_export = 'figs/' + exp_name + '/model/cortex/segmentation/' + name_export + '.pdf'
            full_name_export = os.path.join(OUTPUT_DIR, exp_name, 'model' , 'cortex', 'segmentation', name_export + '.pdf')
            plt.savefig(full_name_export, dpi=600, bbox_inches='tight', transparent=True)

            if plot_on == 0:
                #plt.show(block=False)
                #plt.pause(3)
                plt.close(fig)
                del fig
            else:
                plt.show()
                del fig
            # ==== Plot end

            # Append to list
            temp_surface6d.append(xyz_cell)
            temp_axis6d.append(xyz_cell_axis)

            # Append to list
            temp_center6d.append(center)
            temp_radii6d.append(radii)
            temp_rotation6d.append(rotation)

        # End timer
        end = time.time()
        e_time = end - start
        print('Elapsed time for one cell: %f seconds' %(round(e_time,3)))
        print() 
        print() 
        total_time.append(e_time)

        # Merge surface coordinates of all time points into new list per cell
        merged_contours6d.append(temp_contours6d)
        merged_contours6d_split.append(temp_contours6d_split)
        merged_surface6d.append(temp_surface6d)
        merged_axis6d.append(temp_axis6d)
        # Merge center, radii, rotation
        merged_center6d.append(temp_center6d)
        merged_radii6d.append(temp_radii6d)
        merged_rotation6d.append(temp_rotation6d)

    # Convert seconds to hh:mm:ss
    hours, seconds =  sum(total_time) // 3600, sum(total_time) % 3600
    minutes, seconds = sum(total_time) // 60, sum(total_time) % 60

    # Runtime calculations
    total_runtime = str(f"{round(hours):02d}" + "h " + f"{round(minutes):02d}" + "mins " + f"{round(seconds):02d}" + "secs")
    avg_runtime = str( round(sum(total_time)/len(total_time),3) ) + " seconds"
    # Variances cant be computed with only 1 value (assigned to 0)
    if len(total_time) < 2:
        var_runtime = str(0) + " seconds"
    else:
        var_runtime = str( round(stdev(total_time), 3) ) + " seconds"

        # Change alignment for adding more columns 'c'
    pred_tab = tt.to_string(
        [[ n_cells, total_runtime, avg_runtime, var_runtime ]],
        header=["N", "Total run time:", "Avg. run time for one cell:", "SD of run time:"],
        style=tt.styles.ascii_thin_double,
        alignment="lccr",
        # padding=(0, 1),
    )
    print(pred_tab)    
    print('=== Cell Cortex 3D reconstruction: Completed ===')
    #=======================================#
    # PSF Simulation (Gibson-Lanni model)
    #=======================================#
    print('=== PSF simulation: Start ===')
    if psf_true == 1:
        PSF = SX_PSF.psf_load(BEAD_DIR)
        if PSF.dtype != np.uint16:
            PSF_conv = 2**16 * PSF
            PSF = PSF_conv.astype(np.uint16)
        thres = 2000
    else:
        _, PSF, res_axial = SX_PSF.generate_psf(wavelength, na, magification, ri_specimen, ri_coverslip, ri_medium, working_dist, coverslip_thickness, res_lateral, res_axial, particle_dist)
        thres = 2000
    fit_model_xy, fit_model_z, fit_model_z_norm = SX_PSF.psf_fit(PSF, thres)    
    # Map intensity changes to spatial distance
    zFromInt1 = lambda I: fit_model_z_norm.stddev * sqrt(-1*log(I)) + fit_model_z_norm.mean
    zFromInt2 = lambda I: fit_model_z_norm.stddev * (-1)*sqrt(-1*log(I)) + fit_model_z_norm.mean
    # Vectorize function to allow multiple values in np.array
    zFromInt1_vec = np.vectorize(zFromInt1)
    zFromInt2_vec = np.vectorize(zFromInt2)
    print('=== PSF simulation: Completed ===')
    #=======================================#
    # Spindle 3D reconstruction
    #=======================================#
    print('=== Spindle 3D reconstruction: Start ===')
    # Raw
    array6d_spind_raw, n_cells_raw, name_list_raw = SX.multi_importer(input_spindle_raw, n_slices, n_frames)
    # Mask
    array6d_spind_mask, n_cells_mask, name_list_mask  = SX.multi_importer(input_spindle_mask, n_slices, n_frames)

    # Check if the numbers are equal.
    if n_cells_raw == n_cells_mask:
        n_cells = n_cells_mask
    else:
        print("Number of raw and mask images does not match.")
    """
    Output: Height x Width x Depth x Time X Cells X Channel
    """
    # Create a sequence for the position of z-slices: Align focal slice to (0,0). Works for even and odd no. slices.
    start_p = (focal_slice)*pixels_per_micron_z*(-1)
    end_p = (n_slices - focal_slice-1)*pixels_per_micron_z
    dist_slices = SX.generate_seq(start_p, end_p, pixels_per_micron_z)
    # Reconstruction starts here
    # Turn on latex for axis labels
    if latex_on == 1:
        rc('text', usetex=True)
    else:
        rc('text', usetex=False) 

    # Decide, which ellipsoid fit to use (0 = mininum enclosing fit; 1 =  least square fit - 3x faster)
    fitting_select = 0

    # Plot axes
    plot_axes = 1

    # Plot on
    plot_on = 0

    # Turn z-corrections using PSF
    psf_correction = 1

    # Pole xy-correction
    pole_xy_correct = 1

    # Switch condition based on number of cells
    cond_idx = 0

    merged_array_s = [] # Coordinates of spindle signals
    merged_array_split_s = [] # Coordinates of spindle signals split by planes
    # No correction
    merged_array_s_no_correction = [] # Coordinates of spindle signals
    merged_array_split_s_no_correction = [] # Coordinates of spindle signals split by planes

    merged_surface6d_s = [] # Surface coordinates
    merged_axis6d_s = [] # Axis coordinates
    merged_center6d_s = [] # Centroid coordinates
    merged_radii6d_s = [] # Semi-axes: a, b, c
    merged_rotation6d_s = [] # Rotation

    # Keep only top 30% of the brighest pixels
    prop_data = 0.3

    start = []
    total_time = []

    count = 0
    cell_count = 0

    # Dataframe to store before and after correction
    # Preallocate dataframe
    df_cor = pd.DataFrame(columns=[
                                'N',
                                'filename',
                                'exp_set',
                                'condition',
                                'cell_id_total',
                                'cell_id',
                                'time_point',
                                'spind_len',
                                'spind_len_cor',
                                'spind_pole1_dist',
                                'spind_pole2_dist'
                                ])

    # Loop over cells
    for c_id in range(len(array6d_spind_mask[0,0,0,0,:,0])):
    #for c_id in range(1,2):
        # Start timer
        start = time.time()
        print("---------------| Cell id:" + str(c_id))
        temp_array_tp_s = []
        temp_array_tp_split_s = []
        # No correction
        temp_array_tp_split_s_no_correction = []
        temp_array_tp_s_no_correction = []
        temp_surface6d_tp_s = []
        temp_axis6d_tp_s = []
        temp_center6d_tp_s = []
        temp_radii6d_tp_s = []
        temp_rotation6d_tp_s = []


        # Loop over time points
        for tp in range(len(array6d_spind_mask[0,0,0,:,0,0])):
            count += 1
    #    for tp in range(20,21):
            print("Time point:" + str(tp))   
            temp_array_z_s = []
            temp_array_z_s_no_correction = []
            #Loop over z-sections
            for z in range(len(array6d_spind_mask[0,0,:,0,0,0])):
                # Read mask from z-stack
                temp_mask = array6d_spind_mask[:,:,z,tp,c_id,0]
                # Read the raw image corresponding to the mask
                temp_raw = array6d_spind_raw[:,:,z,tp,c_id,0]
                # Add blur
                #temp_raw = gaussian_filter(temp_raw, sigma=1, mode='nearest')
                # Convert mask to boolean [0, 1]
                mask_bl = temp_mask//temp_mask.max()
                # Burn mask on image
                burned_img = temp_raw * mask_bl
                # Keep only the top 30% pixels (based on intensity values) - exclude 0s
                qntval = np.quantile(burned_img[np.nonzero(burned_img)], 1-prop_data)
                if qntval == burned_img.max():
                    qntval = qntval - 1.0
                # Turn pixels to 0 if under quantile value
                binary_qnt = burned_img > qntval
                # Convert new mask to uint8
                binary_qnt_scale = binary_qnt * 255 # Scale up to 255
                updated_mask = binary_qnt_scale.astype(np.uint8)
                # Find coordinates with signals
                #indices = np.where(updated_mask == [255])
                indices_all = np.where(updated_mask == [255])
                # Transpose
                indices_all = np.transpose(np.vstack([indices_all[0], indices_all[1]]))           
                # Create a Hull
                hull = ConvexHull(indices_all)
                indices = hull.points.T
                # Convert to micron
                indices_um = indices / pixels_per_micron_xy

                coordinates_um = np.transpose(np.vstack((indices_um[0], indices_um[1], np.full((1, len(indices_um[0])),dist_slices[z]) )))
                # Keep a copy with pixel (ONLY X,Y but Z is in micron)
                coordinates_pxl = np.transpose(np.vstack((indices[0], indices[1], np.full((1, len(indices[0])),dist_slices[z]) )))
                # To reduce computational time, dropout n data points while keeping object structure.
                # The number of remaining data points are at least 20% (can be changed through parameter input)
                keep_steps = round( len(coordinates_um)/(len(coordinates_um) - (len(coordinates_um) * (1-prop_data))) )
                red_coordinates = coordinates_um[0::keep_steps]
                red_coordinates_pxl = coordinates_pxl[0::keep_steps]
                # Make a copy to overlay non-correction vs. correction
                red_coordinates_no_correction = red_coordinates.copy() # No correction
                # Convert to float
                red_coordinates = red_coordinates.astype(float)
                red_coordinates_pxl = red_coordinates_pxl.astype(float)

                if psf_correction == 1:
                    temp_raw_norm = SX_PSF.norm_img(temp_raw)
                    intValue_inImg = temp_raw_norm[red_coordinates_pxl[:,1].astype(int), red_coordinates_pxl[:,0].astype(int)] # Swap X,Y to Y,X
                    # Max intensity is equal to amplitude of Gaussian fit (due to normalisation of PSF and input image to [0-1])
                    max_int = fit_model_z_norm.amplitude[0]
                    if max_int > 1.0:
                        max_int = 1.0

                    # Calculate differences in intensity values between reference image and input image
                    difIntValue = intValue_inImg - max_int
                    # Since we use max intensity based on image type, we have to map only negative values
                    difIntValue[difIntValue<0] = zFromInt2_vec( abs(difIntValue[difIntValue<0]) ) - fit_model_z_norm.mean

                    if dist_slices[z] < 0:
                        red_coordinates[:,2] = dist_slices[z] - (difIntValue*res_axial)
                    else:
                        red_coordinates[:,2] = dist_slices[z] + (difIntValue*res_axial)

                temp_array_z_s.append(red_coordinates)
                temp_array_z_s_no_correction.append(red_coordinates_no_correction) # No correction

            # Convert list to array
            temp_array_tp_split_s.append(temp_array_z_s)
            coord_array_s = np.vstack(temp_array_z_s)
            temp_array_tp_s.append(coord_array_s)
            # Do the same for no correction
            temp_array_tp_split_s_no_correction.append(temp_array_z_s_no_correction)
            coord_array_no_correction = np.vstack(temp_array_z_s_no_correction)
            temp_array_tp_s_no_correction.append(coord_array_no_correction)

            # Find the ellipsoid
            if fitting_select == 0:
                center, radii, rotation_raw = ET.getMinVolEllipse(coord_array_s, .01)
            elif fitting_select == 1:
                data_reg = ET.data_regularize(coord_array_s)
                center, radii, rotation_raw, _ = ET.ellipsoid_fit(data_reg) # Rotation before pole correction





            # ==== Plot
            fig = plt.figure(figsize=(15, 15))
            ax = fig.add_subplot(111, projection='3d')
            if old_3d == 1:
                #scaling = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
                #ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3)
                ax.set_aspect("equal") # Works only with older matplotlib==3.0.2 (unsolved bug with 3.3.1)
            else:
                xs = np.array([0,voxel_x])
                ys = np.array([0,voxel_y])
                zs = np.array([0,voxel_z])
                ax.set_box_aspect((np.ptp(xs), np.ptp(ys), np.ptp(zs)))  # aspect ratio is 1:1:1

            rx = [0, voxel_x]
            ry = [0, voxel_y]
            rz = [0-(voxel_z/2), (voxel_z/2)]

            ax.set(xlim=(rx[0], rx[1]), ylim=(ry[0], ry[1]), zlim=(rz[0], rz[1]))



            # Change view
            ax.view_init(elev=30, azim=60)


            if plot_axes == 1:
                # Increase thickness
                #for axis in [ax.w_xaxis, ax.w_yaxis, ax.w_zaxis]:
                #    axis.line.set_linewidth(8)

                # Axes style
                # make the grid lines transparent
                ax.xaxis._axinfo["grid"]['color'] =  (1,1,1,0)
                ax.yaxis._axinfo["grid"]['color'] =  (1,1,1,0)
                ax.zaxis._axinfo["grid"]['color'] =  (1,1,1,0)

                ax.tick_params(axis='both', which='major', labelsize=20, pad=15)
                ax.tick_params(axis='both', which='minor', labelsize=20, pad=15)

                # Disable z-label rotation
                ax.zaxis.set_rotate_label(False)  # disable automatic rotation

                # Increase fontsize
                ax.set_xlabel('$\mathbf{x}$', fontsize=30, labelpad=10)
                ax.set_ylabel('$\mathbf{y}$', fontsize=30, labelpad=15)
                ax.set_zlabel('$\mathbf{z}$', fontsize=30, labelpad=15, rotation=0)
                # draw voxel
                rx = [0, voxel_x]
                ry = [0, voxel_y]
                rz = [0-(voxel_z/2), (voxel_z/2)]
                for s, e in combinations(np.array(list(product(rx, ry, rz))), 2):
                    if np.sum(np.abs(s-e)) == rx[1]-rx[0]:
                        ax.plot3D(*zip(s, e), "k--", linewidth=1, antialiased=True)
                    if np.sum(np.abs(s-e)) == rz[1]-rz[0]:
                        ax.plot3D(*zip(s, e), "k--", linewidth=1, antialiased=True)    

            else:
                ax.axis('off')

            # plot points (no correction)
            #ax.scatter(coord_array_no_correction[:,0], coord_array_no_correction[:,1], coord_array_no_correction[:,2], color='silver', marker='*', s=1)
            # plot points
            ax.scatter(coord_array_s[:,0], coord_array_s[:,1], coord_array_s[:,2], color='m', marker='*', s=1)

            # plot ellipsoid
            xyz_spin, xyz_spin_axis = ET.plotEllipsoid(center, radii, rotation_raw, ax=ax, plotAxes=True, cageColor='k', cageAlpha=0.2)


            # Save
            name_export = 'cell_' + str(c_id) + '_tp_' + str(tp)
            #full_name_export = 'figs/' + exp_name + '/model/spindle/axes/' + name_export + '.pdf'
            full_name_export = os.path.join(OUTPUT_DIR, exp_name, 'model' , 'spindle', 'axes', name_export + '.pdf')
            plt.savefig(full_name_export, dpi=600, bbox_inches='tight', transparent=True)

            if plot_on == 1:
                plt.show()

            plt.close(fig)
            del fig

            # ==== Plot end


            # Apply pole XY correction: shrink the spindle length axis to the extent
            # of spindle signal in the z-merged mask, then refit the ellipsoid.
            if pole_xy_correct == 1:

                # Mask
                slice_height, slice_width = array6d_spind_mask[:,:,0,tp,c_id,0].shape # Image dimensions from reference image (slice)
                merged_slices = np.zeros(shape=(slice_height, slice_width)) # Preallocate canvas to merge masks through 3D
                merged_slices = merged_slices.astype(np.uint8)

                # Raw MAX projection (background image of the QC figure below)
                raw_max = np.max(array6d_spind_raw[:,:,:,tp,c_id,0], axis=2)

                for z_id in range( len(array6d_spind_mask[0,0,:,0,0,0]) ):
                    z_img = array6d_spind_mask[:,:,z_id,tp,c_id,0]
                    merged_slices += z_img
                    # uint8 addition wraps on overflow; a wrapped sum is smaller than the
                    # addend, so saturate those pixels to the dtype maximum instead
                    merged_slices[merged_slices<z_img]= np.iinfo(merged_slices.dtype).max

                # Sample the merged mask along the spindle length axis (ax_id = 2).
                # Axis coordinates are divided by pixelsize_xy to obtain pixel indices,
                # which are clipped to the image: points outside the field of view would
                # otherwise raise IndexError (too large) or silently wrap (negative).
                ax_id = 2
                axis_px_row = np.clip((xyz_spin_axis[ax_id][:,0]/pixelsize_xy).astype(int), 0, slice_height - 1)
                axis_px_col = np.clip((xyz_spin_axis[ax_id][:,1]/pixelsize_xy).astype(int), 0, slice_width - 1)
                line_length = merged_slices[axis_px_row, axis_px_col]

                # First and last samples along the axis that hit spindle signal.
                # The former rule ("first maximum" + "number of non-zero samples" - 1)
                # only agreed with this for uniform, gap-free masks and could index
                # past the end of the axis otherwise.
                signal_idx = np.flatnonzero(line_length)
                if signal_idx.size > 0:
                    first_val = signal_idx[0]
                    last_val = signal_idx[-1]
                else:
                    # No signal on the axis: keep the uncorrected end points
                    first_val = 0
                    last_val = len(line_length) - 1

                # Re-derive centroid and length-axis radius from the corrected end points
                center = ( xyz_spin_axis[ax_id][first_val,:] + xyz_spin_axis[ax_id][last_val,:] ) / 2
                radius_correct = np.linalg.norm(xyz_spin_axis[ax_id][first_val,:] - xyz_spin_axis[ax_id][last_val,:]) / 2

                # QC figure on the raw MAX projection: uncorrected poles (red dots)
                # vs. corrected poles (green crosses)
                fig, ax = plt.subplots()
                ax.axis('off')
                ax.axes.get_xaxis().set_visible(False)
                ax.axes.get_yaxis().set_visible(False)
                ax.imshow(raw_max, cmap='gray')
                # First pole
                ax.plot(xyz_spin_axis[ax_id][0,1]/pixelsize_xy, xyz_spin_axis[ax_id][0,0]/pixelsize_xy, 'ro')
                ax.plot(xyz_spin_axis[ax_id][first_val,1]/pixelsize_xy, xyz_spin_axis[ax_id][first_val,0]/pixelsize_xy, 'g+')
                # Second pole
                ax.plot(xyz_spin_axis[ax_id][-1,1]/pixelsize_xy, xyz_spin_axis[ax_id][-1,0]/pixelsize_xy, 'ro')
                ax.plot(xyz_spin_axis[ax_id][last_val,1]/pixelsize_xy, xyz_spin_axis[ax_id][last_val,0]/pixelsize_xy, 'g+')

                name_export = 'cell_' + str(c_id) + '_tp_' + str(tp)
                full_name_export = os.path.join(OUTPUT_DIR, exp_name, 'model' , 'spindle', 'correct', name_export + '.pdf')
                plt.savefig(full_name_export, dpi=300, bbox_inches='tight',transparent=True, pad_inches=0)
                plt.close('all')

                # Store the corrected length-axis radius (index 2 = length axis)
                radii[2] = radius_correct

                print('Before Correction (Surface): ' + str(xyz_spin[0]))
                print('Before Correction (height): ' + str(np.linalg.norm(xyz_spin_axis[0][0] - xyz_spin_axis[0][-1])))
                print('Before Correction (width): ' + str(np.linalg.norm(xyz_spin_axis[1][0] - xyz_spin_axis[1][-1])))
                print('Before Correction (length): ' + str(np.linalg.norm(xyz_spin_axis[2][0] - xyz_spin_axis[2][-1])))

                # Refit ellipsoid with the corrected center/radius; ax=None and
                # plotAxes=False, so only the returned coordinates are used here
                xyz_spin_cor, xyz_spin_axis_cor = ET.plotEllipsoid(center, radii, rotation_raw, ax=None, plotAxes=False)

                print('-------------')

                print('After Correction (Surface): ' + str(xyz_spin_cor[0]))
                print('After Correction (height): ' + str(np.linalg.norm(xyz_spin_axis_cor[0][0] - xyz_spin_axis_cor[0][-1])))
                print('After Correction (width): ' + str(np.linalg.norm(xyz_spin_axis_cor[1][0] - xyz_spin_axis_cor[1][-1])))
                print('After Correction (length): ' + str(np.linalg.norm(xyz_spin_axis_cor[2][0] - xyz_spin_axis_cor[2][-1])))

                # Store in dataframe
                df_cor.at[count, 'N'] = count
                df_cor.at[count, 'filename'] = name_list_raw[c_id][tp][focal_slice]
                df_cor.at[count, 'exp_set'] = exp_set

                # Switch to the second condition group once its first cell is reached
                if c_id == n_group_a:
                    cell_count = 0 # Reset
                    cond_idx = 1

                condition = condition_group[cond_idx]
                df_cor.at[count, 'condition'] = condition
                df_cor.at[count, 'cell_id_total'] = c_id
                df_cor.at[count, 'cell_id'] = cell_count
                df_cor.at[count, 'time_point'] = tp

                df_cor.at[count, 'spind_len'] = np.linalg.norm(xyz_spin_axis[2][0] - xyz_spin_axis[2][-1]) # Spindle length before corrections
                df_cor.at[count, 'spind_len_cor'] = np.linalg.norm(xyz_spin_axis_cor[2][0] - xyz_spin_axis_cor[2][-1]) # Spindle length after corrections
                df_cor.at[count, 'spind_pole1_dist'] = np.linalg.norm(xyz_spin_axis[ax_id][0,:] - xyz_spin_axis[ax_id][first_val,:]) # Spindle pole 1 distance w/o and with correction
                df_cor.at[count, 'spind_pole2_dist'] = np.linalg.norm(xyz_spin_axis[ax_id][-1,:] - xyz_spin_axis[ax_id][last_val,:]) # Spindle pole 2 distance w/o and with correction

                # Overwrite with new corrections
                xyz_spin = xyz_spin_cor
                xyz_spin_axis = xyz_spin_axis_cor


            # Collect this time point's spindle ellipsoid geometry (surface and axes)
            temp_surface6d_tp_s.append(xyz_spin)
            temp_axis6d_tp_s.append(xyz_spin_axis)
            # ...and its parameters. rotation_raw is stored as-is because the
            # pole correction above only changes xy positions/length, not angles.
            temp_center6d_tp_s.append(center)
            temp_radii6d_tp_s.append(radii)
            temp_rotation6d_tp_s.append(rotation_raw)

        # Advance the cell index within the current condition group
        cell_count = cell_count + 1

        # Store this cell's per-time-point results in the experiment-wide lists
        merged_array_s.append(temp_array_tp_s)
        merged_array_split_s.append(temp_array_tp_split_s)  # per-plane coordinates stay accessible
        # Uncorrected variants
        merged_array_s_no_correction.append(temp_array_tp_s_no_correction)
        merged_array_split_s_no_correction.append(temp_array_tp_split_s_no_correction)
        # Ellipsoid surfaces and axes for all time points of this cell
        merged_surface6d_s.append(temp_surface6d_tp_s)
        merged_axis6d_s.append(temp_axis6d_tp_s)
        # Ellipsoid parameters for all time points of this cell
        merged_center6d_s.append(temp_center6d_tp_s)
        merged_radii6d_s.append(temp_radii6d_tp_s)
        merged_rotation6d_s.append(temp_rotation6d_tp_s)

        # Per-cell wall-clock time
        end = time.time()
        e_time = end - start
        print('Elapsed time for one cell: {:f} seconds'.format(round(e_time, 3)))
        print('\n')
        total_time.append(e_time)

    # Convert the accumulated run time (seconds) to hh:mm:ss.
    # Minutes are the remainder after removing whole hours; deriving them from
    # the total (sum // 60) reported e.g. "01h 62mins" for a 1 h 2 min run.
    # Rounding the total once also avoids a "60secs" field.
    hours, seconds = divmod(int(round(sum(total_time))), 3600)
    minutes, seconds = divmod(seconds, 60)

    if pole_xy_correct == 1:
        # Export the pole-correction table collected in df_cor
        csv_filename = "{}{:%Y%m%dT%H%M}.csv".format("refinement_" + str(exp_name) + '_', datetime.datetime.now())
        df_cor.to_csv(os.path.join(OUTPUT_CSV_DIR, csv_filename), index=False)

    # Runtime calculations
    total_runtime = f"{hours:02d}h {minutes:02d}mins {seconds:02d}secs"
    # Guard against an empty run (no cells) instead of dividing by zero
    if len(total_time) > 0:
        avg_runtime = str( round(sum(total_time)/len(total_time),3) ) + " seconds"
    else:
        avg_runtime = str(0) + " seconds"
    # The standard deviation needs at least 2 values; report 0 otherwise
    if len(total_time) < 2:
        var_runtime = str(0) + " seconds"
    else:
        var_runtime = str( round(stdev(total_time), 3) ) + " seconds"

    # termtables `alignment` in the summary table below has one character per
    # column ("l"/"c"/"r"); extend it when adding columns.
    # Print the run-time summary table
    pred_tab = tt.to_string(
        [[n_cells, total_runtime, avg_runtime, var_runtime]],
        header=[
            "N",
            "Total run time:",
            "Avg. run time for one cell:",
            "SD of run time:",
        ],
        alignment="lccr",
        style=tt.styles.ascii_thin_double,
    )
    print(pred_tab)
    print('=== Spindle 3D reconstruction: Completed ===')

    # ======================================= #
    #  3D Modelling
    # ======================================= #
    print('=== 3D Modelling: Start ===')
    # Preallocate dataframe
    # Empty results table; the list below fixes the column names and their order.
    # Presumably one row per cell and time point (cf. 'cell_id' / 'time_point').
    # Naming: h / w / l follow the height / width / length axis order used below
    # (ax_h = 0, ax_w = 1, ax_l = 2); x1..z1 / x2..z2 are presumably the x, y, z
    # coordinates of axis end point ("pole") 1 and 2.
    df = pd.DataFrame(columns=['N',
                               # Experiment / image metadata
                               'filename',
                               'exp_set',
                               'condition',
                               'img_height',
                               'img_width',
                               'img_dim',
                               'wavelength',
                               'pixels_per_micron_xy',
                               'gap_micron_z',
                               # Identifiers and time
                               'cell_id_total',
                               'cell_id',
                               'time_point',
                               'time_point_mins',
                               # Cell ellipsoid: centroid and pole coordinates
                               'cell_centroid_x',
                               'cell_centroid_y',
                               'cell_centroid_z',
                               'cell_pole_h_x1',
                               'cell_pole_h_y1',
                               'cell_pole_h_z1',
                               'cell_pole_h_x2',
                               'cell_pole_h_y2',
                               'cell_pole_h_z2',
                               'cell_pole_w_x1',
                               'cell_pole_w_y1',
                               'cell_pole_w_z1',
                               'cell_pole_w_x2',
                               'cell_pole_w_y2',
                               'cell_pole_w_z2',
                               'cell_pole_l_x1',
                               'cell_pole_l_y1',
                               'cell_pole_l_z1',
                               'cell_pole_l_x2',
                               'cell_pole_l_y2',
                               'cell_pole_l_z2',
                               # Cell ellipsoid: axes and shape descriptors
                               'cell_axis_a',
                               'cell_axis_b',
                               'cell_axis_c',
                               'cell_2d_major_axis',
                               'cell_2d_minor_axis',
                               'cell_axis_ratio',
                               'cell_volume',
                               'cell_area',
                               'cell_sphericity',
                               # Cell orientation (Euler angles deg/rad, quaternion) and translation
                               'cell_eangle_alpha_deg',
                               'cell_eangle_beta_deg',
                               'cell_eangle_gamma_deg',
                               'cell_eangle_alpha_rad',
                               'cell_eangle_beta_rad',
                               'cell_eangle_gamma_rad',
                               'cell_qangle_x',
                               'cell_qangle_y',
                               'cell_qangle_z',
                               'cell_qangle_w',
                               'cell_translation_dist',
                               # Spindle ellipsoid: centroid and pole coordinates
                               'spindle_centroid_x',
                               'spindle_centroid_y',
                               'spindle_centroid_z',
                               'spindle_pole_h_x1',
                               'spindle_pole_h_y1',
                               'spindle_pole_h_z1',
                               'spindle_pole_h_x2',
                               'spindle_pole_h_y2',
                               'spindle_pole_h_z2',
                               'spindle_pole_w_x1',
                               'spindle_pole_w_y1',
                               'spindle_pole_w_z1',
                               'spindle_pole_w_x2',
                               'spindle_pole_w_y2',
                               'spindle_pole_w_z2',
                               'spindle_pole_l_x1',
                               'spindle_pole_l_y1',
                               'spindle_pole_l_z1',
                               'spindle_pole_l_x2',
                               'spindle_pole_l_y2',
                               'spindle_pole_l_z2',
                               # Length-axis pole coordinate differences between time points
                               # (presumably filled from pole_l_dif computed in the loop below)
                               'spindle_pole_l_x1_dif',
                               'spindle_pole_l_y1_dif',
                               'spindle_pole_l_z1_dif',
                               'spindle_pole_l_x2_dif',
                               'spindle_pole_l_y2_dif',
                               'spindle_pole_l_z2_dif',
                               # Pole-tracking flags per axis (presumably corrected_*_s from pole_tracking3d_v3)
                               'spindle_pole_h_corrected',
                               'spindle_pole_w_corrected',
                               'spindle_pole_l_corrected',
                               # Spindle ellipsoid: axes, axis ratios and shape descriptors
                               'spindle_axis_h',
                               'spindle_axis_w',
                               'spindle_axis_l',
                               'spindle_axis_h_ratio',
                               'spindle_axis_w_ratio',
                               'spindle_axis_l_ratio',
                               'spindle_volume',
                               'spindle_area',
                               'spindle_sphericity',
                               # Spindle orientation (Euler angles deg/rad, quaternion) and translation
                               'spindle_eangle_alpha_deg',
                               'spindle_eangle_beta_deg',
                               'spindle_eangle_gamma_deg',
                               'spindle_eangle_alpha_rad',
                               'spindle_eangle_beta_rad',
                               'spindle_eangle_gamma_rad',
                               'spindle_qangle_x',
                               'spindle_qangle_y',
                               'spindle_qangle_z',
                               'spindle_qangle_w',
                               'spindle_translation_dist',
                               # Pole-to-cortex distance / closeness / velocity, per axis and pole
                               'pole_cortex_dist_h_1',
                               'pole_cortex_dist_h_2',
                               'pole_cortex_dist_w_1',
                               'pole_cortex_dist_w_2',
                               'pole_cortex_dist_l_1',
                               'pole_cortex_dist_l_2',
                               'pole_cortex_close_h_1',
                               'pole_cortex_close_h_2',
                               'pole_cortex_close_w_1',
                               'pole_cortex_close_w_2',
                               'pole_cortex_close_l_1',
                               'pole_cortex_close_l_2',
                               'pole_cortex_velo_h_1',
                               'pole_cortex_velo_h_2',
                               'pole_cortex_velo_w_1',
                               'pole_cortex_velo_w_2',
                               'pole_cortex_velo_l_1',
                               'pole_cortex_velo_l_2',
                               # Spindle pole displacement and mean squared displacement
                               'spindle_pole_disp_1',
                               'spindle_pole_disp_2',
                               'spindle_pole_msd_1',
                               'spindle_pole_msd_2',
                               'spindle_pole_msd_total',
                               # Pole displacement split into long / eq / z components
                               # (absolute, then the *_rel2d and *_rel3d variants)
                               'spindle_pole_long_disp_1',
                               'spindle_pole_eq_disp_1',
                               'spindle_pole_z_disp_1',
                               'spindle_pole_long_disp_2',
                               'spindle_pole_eq_disp_2',
                               'spindle_pole_z_disp_2',
                               'spindle_pole_long_disp_1_rel2d',
                               'spindle_pole_eq_disp_1_rel2d',
                               'spindle_pole_z_disp_1_rel2d',
                               'spindle_pole_long_disp_2_rel2d',
                               'spindle_pole_eq_disp_2_rel2d',
                               'spindle_pole_z_disp_2_rel2d',
                               'spindle_pole_long_disp_1_rel3d',
                               'spindle_pole_eq_disp_1_rel3d',
                               'spindle_pole_z_disp_1_rel3d',
                               'spindle_pole_long_disp_2_rel3d',
                               'spindle_pole_eq_disp_2_rel3d',
                               'spindle_pole_z_disp_2_rel3d',
                               # Spindle pole velocities (total, then long / eq / z components)
                               'spindle_pole_velo_1',
                               'spindle_pole_velo_2',
                               'spindle_pole_long_velo_1',
                               'spindle_pole_eq_velo_1',
                               'spindle_pole_z_velo_1',
                               'spindle_pole_long_velo_2',
                               'spindle_pole_eq_velo_2',
                               'spindle_pole_z_velo_2',
                              ])
    
    # ---- 3D Modelling: configuration ----
    # LaTeX rendering of axis labels is enabled only when latex_on == 1
    rc('text', usetex=bool(latex_on == 1))

    plot_on = 0      # 1 -> also show figures interactively
    tracker = 1      # 1 -> track spindle/cell poles consistently across time points
    rt_select = 1    # Ray-tracing method: 0 = heuristic, 1 = analytic
    axis_color = ('r', 'g', 'b')
    poles_s = []

    # Running counters; cond_idx switches the condition group based on the cell index
    count, cell_count, cond_idx = 0, 0, 0
    total_time = []

    # Accumulators (each name gets its own, independent list)
    disp_l1_all, disp_l2_all = [], []                                # post-filtering
    merged_temp_d_h, merged_temp_d_w, merged_temp_d_l = [], [], []   # pole-cortex distances
    merged_pole_h_s, merged_pole_w_s, merged_pole_l_s = [], [], []   # spindle poles
    merged_centroid_s = []                                           # spindle centroids
    merged_centroid = []                                             # cell centroids
    # Loop over cells
    for c_id in range(len(merged_axis6d_s[:])):
        # Start timer
        start = time.time()
        print("---------------| Cell id:" + str(c_id))
        # For poles
        temp_poles = []
        temp_pole_h = []
        temp_pole_w = []
        temp_pole_l = []
        temp_poles_s = []
        temp_pole_h_s = []
        temp_pole_w_s = []
        temp_pole_l_s = []
        # Spindle centroid
        temp_centroid_s = []
        # Cell centroid
        temp_centroid = []
        # For post filtering
        temp_disp_l1 = []
        temp_disp_l2 = []

        # For pole to cortex plot
        data_raw_1 = np.array([])
        data_raw_2 = np.array([])
        time_raw_mins_1 = np.array([])
        time_raw_mins_2 = np.array([])

        # For spindle pole-cortex interaction
        temp_d_h = []
        temp_d_w = []
        temp_d_l = []
        # Spindle height, width and length
        temp_h_s = []
        temp_w_s = []
        temp_l_s = []

        # For rotation
        angles_list = []
        angles_list_s = []

        # For animation
        fused_all = []
        plot_all = []

        # Row ids for each cell
        temp_count = []
        # Loop over time points
        for tp in range(len(merged_axis6d_s[0][:])):
            print("Time point:" + str(tp))
            count += 1

            # Cell
            # Axis have a specific order red:height, green:width, blue:length. For ellipsoid, only 3 axis are needed.
            ax_h = 0
            ax_w = 1
            ax_l = 2
            # === Cell height axis
            # Obtain x,y,z for height pole 1 & 2 [x, y, z] (select first and last value)
            pole_h_1 = np.asarray([ merged_axis6d[c_id][tp][ax_h][0,0], merged_axis6d[c_id][tp][ax_h][0,1], merged_axis6d[c_id][tp][ax_h][0,2] ])
            pole_h_2 = np.asarray([ merged_axis6d[c_id][tp][ax_h][-1,0], merged_axis6d[c_id][tp][ax_h][-1,1], merged_axis6d[c_id][tp][ax_h][-1,2] ])
            if tracker == 1:
                # Tracking algorithm to track individual spindle poles consistantly (Skip for first time point t0)
                if tp == 0:
                    pole_h = [pole_h_1, pole_h_2]
                    corrected_h = 0
                elif tp > 0:
                    # Get pole 1 and 2 from previous time point t-1
                    pole_h_t1 = np.asarray([ merged_axis6d[c_id][tp-1][ax_h][0,0], merged_axis6d[c_id][tp-1][ax_h][0,1], merged_axis6d[c_id][tp-1][ax_h][0,2] ])
                    pole_h_t2 = np.asarray([ merged_axis6d[c_id][tp-1][ax_h][-1,0], merged_axis6d[c_id][tp-1][ax_h][-1,1], merged_axis6d[c_id][tp-1][ax_h][-1,2] ])

                    # Get previous and current poles
                    current_poles_h = np.vstack((pole_h_1, pole_h_2))
                    prev_poles_h = np.vstack((pole_h_t1, pole_h_t2))
                    # Apply Pole Tracking in 3D
                    pole_h_1, pole_h_2, corrected_h = SX_Model.pole_tracking3d_v3(current_poles_h, prev_poles_h ,'h')
                    pole_h = [pole_h_1, pole_h_2]
                    # After correction, update the merged_axis5d_s array as well by flipping arrays!!!
                    if corrected_h == 1:
                        merged_axis6d[c_id][tp][ax_h] = merged_axis6d[c_id][tp][ax_h][::-1]  
            else:
                pole_h = [pole_h_1, pole_h_2]
                corrected_h = 'nan'

            # === Cell width axis
            # Obtain x,y,z for width pole 1 & 2 [x, y, z]
            pole_w_1 = np.asarray([ merged_axis6d[c_id][tp][ax_w][0,0], merged_axis6d[c_id][tp][ax_w][0,1], merged_axis6d[c_id][tp][ax_w][0,2] ])
            pole_w_2 = np.asarray([ merged_axis6d[c_id][tp][ax_w][-1,0], merged_axis6d[c_id][tp][ax_w][-1,1], merged_axis6d[c_id][tp][ax_w][-1,2] ])       
            if tracker == 1:
                # Tracking algorithm to track individual spindle poles consistantly (Skip for first time point t0)
                if tp == 0:
                    pole_w = [pole_w_1, pole_w_2]
                    corrected_w = 0
                elif tp > 0:
                    # Get pole 1 and 2 from previous time point t-1
                    pole_w_t1 = np.asarray([ merged_axis6d[c_id][tp-1][ax_w][0,0], merged_axis6d[c_id][tp-1][ax_w][0,1], merged_axis6d[c_id][tp-1][ax_w][0,2] ])
                    pole_w_t2 = np.asarray([ merged_axis6d[c_id][tp-1][ax_w][-1,0], merged_axis6d[c_id][tp-1][ax_w][-1,1], merged_axis6d[c_id][tp-1][ax_w][-1,2] ])
                    # Get previous and current poles
                    current_poles_w = np.vstack((pole_w_1, pole_w_2))
                    prev_poles_w = np.vstack((pole_w_t1, pole_w_t2))
                    # Apply Pole Tracking in 3D
                    pole_w_1, pole_w_2, corrected_w = SX_Model.pole_tracking3d_v3(current_poles_w, prev_poles_w, 'w')
                    pole_w = [pole_w_1, pole_w_2]
                    # After correction, update the merged_axis5d_s array as well by flipping arrays!!!
                    if corrected_w == 1:
                        merged_axis6d[c_id][tp][ax_w] = merged_axis6d[c_id][tp][ax_w][::-1]           
            else:
                pole_w = [pole_w_1, pole_w_2]
                corrected_w = 'nan'


            # === Cell length axis
            # Obtain x,y,z for length pole 1 & 2 [x, y, z]
            pole_l_1 = np.asarray([ merged_axis6d[c_id][tp][ax_l][0,0], merged_axis6d[c_id][tp][ax_l][0,1], merged_axis6d[c_id][tp][ax_l][0,2] ])
            pole_l_2 = np.asarray([ merged_axis6d[c_id][tp][ax_l][-1,0], merged_axis6d[c_id][tp][ax_l][-1,1], merged_axis6d[c_id][tp][ax_l][-1,2] ])

            if tracker == 1:
                # Tracking algorithm to track individual spindle poles consistantly (Skip for first time point t0)
                if tp == 0:
                    pole_l = [pole_l_1, pole_l_2]
                    corrected_l = 0
                elif tp > 0:
                    # Get pole 1 and 2 from previous time point t-1
                    pole_l_t1 = np.asarray([ merged_axis6d[c_id][tp-1][ax_l][0,0], merged_axis6d[c_id][tp-1][ax_l][0,1], merged_axis6d[c_id][tp-1][ax_l][0,2] ])
                    pole_l_t2 = np.asarray([ merged_axis6d[c_id][tp-1][ax_l][-1,0], merged_axis6d[c_id][tp-1][ax_l][-1,1], merged_axis6d[c_id][tp-1][ax_l][-1,2] ])
                    # Get previous and current poles
                    current_poles_l = np.vstack((pole_l_1, pole_l_2))
                    prev_poles_l = np.vstack((pole_l_t1, pole_l_t2))
                    # Apply Pole Tracking in 3D
                    pole_l_1, pole_l_2, corrected_l = SX_Model.pole_tracking3d_v3(current_poles_l, prev_poles_l, 'l')
                    pole_l = [pole_l_1, pole_l_2]
                    # After correction, update the merged_axis5d_s array as well by flipping arrays!!!
                    if corrected_l == 1:
                        merged_axis6d[c_id][tp][ax_l] = merged_axis6d[c_id][tp][ax_l][::-1]
            else:
                pole_l = [pole_l_1, pole_l_2]
                corrected_l = 'nan'

            # Spindle
            # Axis have a specific order red:height, green:width, blue:length. For ellipsoid, only 3 axis are needed.
            ax_h_s = 0
            ax_w_s = 1
            ax_l_s = 2

            correct_spindle = 1
            #if correct_spindle == 2:
            #    # Obtain x,y,z for length pole 1 & 2 [x, y, z]
            #    pole_l_1_s = np.asarray([ merged_axis5d_s[c_id][tp][ax_l_s][0,0], merged_axis5d_s[c_id][tp][ax_l_s][0,1], merged_axis5d_s[c_id][tp][ax_l_s][0,2] ])
            #    pole_l_2_s = np.asarray([ merged_axis5d_s[c_id][tp][ax_l_s][-1,0], merged_axis5d_s[c_id][tp][ax_l_s][-1,1], merged_axis5d_s[c_id][tp][ax_l_s][-1,2] ])
            #    pole_l_s = [pole_l_1_s, pole_l_2_s]
            #    
            #    p3_a, p3_b, d_total, d_corrected = correct_spindle_pole(IM_MAX, pole_l_s[0], pole_l_s[1])
            #    # Ratio between corrected spindle length / uncorrected
            #    ratio_d = d_corrected/d_total
            #    # Select tp-th centroid of ellipsoid
            #    c_s = merged_center5d_s[c_id][tp]
            #    old_centr = c_s
            #    mid_point_cor = midpoint(p3_a, p3_b)
            #    # Update spindle length in data base
            #    merged_radii5d_s[c_id][tp][2] = d_corrected/2
            #    merged_center5d_s[c_id][tp] = mid_point_cor

            #    xyz_spin_s, xyz_spin_axis_s = ET.plotEllipsoid(merged_center5d_s[c_id][tp], merged_radii5d_s[c_id][tp], merged_rotation5d_s[c_id][tp], ax=ax, plotAxes=True, cageColor='k', cageAlpha=0.2)
            #    # Update centroid in data base
            #    merged_surface5d_s[c_id][tp] = xyz_spin_s
            #    merged_axis5d_s[c_id][tp] = xyz_spin_axis_s


            # === Spindle height axis
            # Obtain x,y,z for height pole 1 & 2 [x, y, z] (select first and last value)
            pole_h_1_s = np.asarray([ merged_axis6d_s[c_id][tp][ax_h_s][0,0], merged_axis6d_s[c_id][tp][ax_h_s][0,1], merged_axis6d_s[c_id][tp][ax_h_s][0,2] ])
            pole_h_2_s = np.asarray([ merged_axis6d_s[c_id][tp][ax_h_s][-1,0], merged_axis6d_s[c_id][tp][ax_h_s][-1,1], merged_axis6d_s[c_id][tp][ax_h_s][-1,2] ])
            if tracker == 1:
                # Tracking algorithm to track individual spindle poles consistantly (Skip for first time point t0)
                if tp == 0:
                    pole_h_s = [pole_h_1_s, pole_h_2_s]
                    corrected_h_s = 0
                elif tp > 0:
                    # Get pole 1 and 2 from previous time point t-1
                    pole_h_t1_s = np.asarray([ merged_axis6d_s[c_id][tp-1][ax_h_s][0,0], merged_axis6d_s[c_id][tp-1][ax_h_s][0,1], merged_axis6d_s[c_id][tp-1][ax_h_s][0,2] ])
                    pole_h_t2_s = np.asarray([ merged_axis6d_s[c_id][tp-1][ax_h_s][-1,0], merged_axis6d_s[c_id][tp-1][ax_h_s][-1,1], merged_axis6d_s[c_id][tp-1][ax_h_s][-1,2] ])

                    # Get previous and current poles
                    current_poles_h_s = np.vstack((pole_h_1_s, pole_h_2_s))
                    prev_poles_h_s = np.vstack((pole_h_t1_s, pole_h_t2_s))
                    # Apply Pole Tracking in 3D
                    pole_h_1_s, pole_h_2_s, corrected_h_s = SX_Model.pole_tracking3d_v3(current_poles_h_s, prev_poles_h_s ,'h')
                    pole_h_s = [pole_h_1_s, pole_h_2_s]
                    # After correction, update the merged_axis5d_s array as well by flipping arrays!!!
                    if corrected_h_s == 1:
                        merged_axis6d_s[c_id][tp][ax_h_s] = merged_axis6d_s[c_id][tp][ax_h_s][::-1]  
            else:
                pole_h_s = [pole_h_1_s, pole_h_2_s]
                corrected_h_s = 'nan'

            # === Spindle width axis
            # Obtain x,y,z for width pole 1 & 2 [x, y, z]
            pole_w_1_s = np.asarray([ merged_axis6d_s[c_id][tp][ax_w_s][0,0], merged_axis6d_s[c_id][tp][ax_w_s][0,1], merged_axis6d_s[c_id][tp][ax_w_s][0,2] ])
            pole_w_2_s = np.asarray([ merged_axis6d_s[c_id][tp][ax_w_s][-1,0], merged_axis6d_s[c_id][tp][ax_w_s][-1,1], merged_axis6d_s[c_id][tp][ax_w_s][-1,2] ])       
            if tracker == 1:
                # Tracking algorithm to track individual spindle poles consistantly (Skip foar first time point t0)
                if tp == 0:
                    pole_w_s = [pole_w_1_s, pole_w_2_s]
                    corrected_w_s = 0
                elif tp > 0:
                    # Get pole 1 and 2 from previous time point t-1
                    pole_w_t1_s = np.asarray([ merged_axis6d_s[c_id][tp-1][ax_w_s][0,0], merged_axis6d_s[c_id][tp-1][ax_w_s][0,1], merged_axis6d_s[c_id][tp-1][ax_w_s][0,2] ])
                    pole_w_t2_s = np.asarray([ merged_axis6d_s[c_id][tp-1][ax_w_s][-1,0], merged_axis6d_s[c_id][tp-1][ax_w_s][-1,1], merged_axis6d_s[c_id][tp-1][ax_w_s][-1,2] ])
                    # Get previous and current poles
                    current_poles_w_s = np.vstack((pole_w_1_s, pole_w_2_s))
                    prev_poles_w_s = np.vstack((pole_w_t1_s, pole_w_t2_s))
                    # Apply Pole Tracking in 3D
                    pole_w_1_s, pole_w_2_s, corrected_w_s = SX_Model.pole_tracking3d_v3(current_poles_w_s, prev_poles_w_s, 'w')
                    pole_w_s = [pole_w_1_s, pole_w_2_s]
                    # After correction, update the merged_axis5d_s array as well by flipping arrays!!!
                    if corrected_w_s == 1:
                        merged_axis6d_s[c_id][tp][ax_w_s] = merged_axis6d_s[c_id][tp][ax_w_s][::-1]           
            else:
                pole_w_s = [pole_w_1_s, pole_w_2_s]
                corrected_w_s = 'nan'


            # === Spindle length axis
            # Obtain x,y,z for length pole 1 & 2 [x, y, z]
            pole_l_1_s = np.asarray([ merged_axis6d_s[c_id][tp][ax_l_s][0,0], merged_axis6d_s[c_id][tp][ax_l_s][0,1], merged_axis6d_s[c_id][tp][ax_l_s][0,2] ])
            pole_l_2_s = np.asarray([ merged_axis6d_s[c_id][tp][ax_l_s][-1,0], merged_axis6d_s[c_id][tp][ax_l_s][-1,1], merged_axis6d_s[c_id][tp][ax_l_s][-1,2] ])

            if tracker == 1:
                # Tracking algorithm to track individual spindle poles consistantly (Skip for first time point t0)
                if tp == 0:
                    pole_l_s = [pole_l_1_s, pole_l_2_s]
                    corrected_l_s = 0
                elif tp > 0:
                    # Get pole 1 and 2 from previous time point t-1
                    pole_l_t1_s = np.asarray([ merged_axis6d_s[c_id][tp-1][ax_l_s][0,0], merged_axis6d_s[c_id][tp-1][ax_l_s][0,1], merged_axis6d_s[c_id][tp-1][ax_l_s][0,2] ])
                    pole_l_t2_s = np.asarray([ merged_axis6d_s[c_id][tp-1][ax_l_s][-1,0], merged_axis6d_s[c_id][tp-1][ax_l_s][-1,1], merged_axis6d_s[c_id][tp-1][ax_l_s][-1,2] ])
                    # Get previous and current poles
                    current_poles_l_s = np.vstack((pole_l_1_s, pole_l_2_s))
                    prev_poles_l_s = np.vstack((pole_l_t1_s, pole_l_t2_s))
                    # Apply Pole Tracking in 3D
                    pole_l_1_s, pole_l_2_s, corrected_l_s = SX_Model.pole_tracking3d_v3(current_poles_l_s, prev_poles_l_s, 'l')
                    pole_l_s = [pole_l_1_s, pole_l_2_s]
                    # After correction, update the merged_axis5d_s array as well by flipping arrays!!!
                    if corrected_l_s == 1:
                        merged_axis6d_s[c_id][tp][ax_l_s] = merged_axis6d_s[c_id][tp][ax_l_s][::-1]
            else:
                pole_l_s = [pole_l_1_s, pole_l_2_s]
                corrected_l_s = 'nan'



            # Optional QC figure for the pole tracking: maximum-intensity projection
            # (over axis 2) of the raw spindle stack of this cell/time point, with the
            # tracked length-axis poles (pole 1 black, pole 2 dark orange) and the
            # spindle ellipsoid centre (red) overlaid. Coordinates are multiplied by
            # pixels_per_micron_xy to place them on image pixels; index [1] is drawn
            # on the horizontal axis and index [0] on the vertical axis of the image.
            # NOTE(review): the 3D model figures below colour pole 1 dark orange and
            # pole 2 black, i.e. the opposite of this figure -- confirm which is intended.
            if correct_spindle == 1:
                img_raw = array6d_spind_raw[:,:,:,tp,c_id, 0]
                IM_MAX = np.max(img_raw, axis=2)
                figz, axz = plt.subplots()
                axz.axis('off')
                axz.imshow(IM_MAX, cmap='gray')
                axz.axes.get_xaxis().set_visible(False)
                axz.axes.get_yaxis().set_visible(False)
                axz.plot(pole_l_s[0][1]*pixels_per_micron_xy, pole_l_s[0][0]*pixels_per_micron_xy, 'ko')
                axz.plot(pole_l_s[1][1]*pixels_per_micron_xy, pole_l_s[1][0]*pixels_per_micron_xy, color='darkorange', marker='o')
                axz.plot(merged_center6d_s[c_id][tp][1]*pixels_per_micron_xy, merged_center6d_s[c_id][tp][0]*pixels_per_micron_xy, color='r', marker='o')

                name_export = 'cell_' + str(c_id) + '_tp_' + str(tp)
                full_name_export = os.path.join(OUTPUT_DIR, exp_name, 'model' , 'spindle', 'tracking', name_export + '.pdf')
                # Save and close through the explicit figure handle instead of relying
                # on pyplot's implicit "current figure".
                figz.savefig(full_name_export, dpi=300, bbox_inches='tight', transparent=True, pad_inches=0)
                plt.close(figz)

            # Frame-to-frame change of the two length-axis spindle poles.
            # The first time point has no predecessor, so each pole is compared with
            # itself (zero difference). Later frames are compared with the pole pair
            # stored for the previous frame, i.e. the last entry of temp_pole_l_s.
            ref_poles_l_s = pole_l_s if tp == 0 else temp_pole_l_s[-1]
            pole_l_dif = [pole_l_s[0] - ref_poles_l_s[0],
                          pole_l_s[1] - ref_poles_l_s[1]]


            # Per-frame bookkeeping for this cell.
            # One row with all six spindle pole coordinates, ordered h1, h2, w1, w2, l1, l2.
            poles_tp_s = np.hstack((pole_h_s[0], pole_h_s[1],
                                    pole_w_s[0], pole_w_s[1],
                                    pole_l_s[0], pole_l_s[1]))
            temp_poles_s.append(poles_tp_s)
            # Append the pole pairs of the cell (cortex) axes, then those of the
            # spindle axes, each to its own per-cell history list.
            for pole_history, pole_pair in ((temp_pole_h, pole_h), (temp_pole_w, pole_w),
                                            (temp_pole_l, pole_l), (temp_pole_h_s, pole_h_s),
                                            (temp_pole_w_s, pole_w_s), (temp_pole_l_s, pole_l_s)):
                pole_history.append(pole_pair)
            # Centroids, taken as the midpoint of the length-axis poles: spindle first, then cell
            cent_s = SX_Model.midpoint(np.asarray(pole_l_s[0]), np.asarray(pole_l_s[1]))
            temp_centroid_s.append(cent_s)
            cent_c = SX_Model.midpoint(np.asarray(pole_l[0]), np.asarray(pole_l[1]))
            temp_centroid.append(cent_c)
            # Displacement computed from the CELL centroid history (temp_centroid)
            disp_centroid = SX_Model.centroid_displacement3d(temp_centroid)

            # ---- Per-time-point model parameters ----
            # Cell (cortex) ellipsoid fit for this frame: sampled surface points,
            # centre, radii and rotation matrix.
            surface_3dcoords = merged_surface6d[c_id][tp]
            c, r, rot = (merged_center6d[c_id][tp],
                         merged_radii6d[c_id][tp],
                         merged_rotation6d[c_id][tp])
            # Spindle ellipsoid fit: centre and radii only. Its rotation matrix
            # (merged_rotation5d_s) is intentionally not read, because it predates
            # the pole-tracking correction.
            c_s, r_s = merged_center6d_s[c_id][tp], merged_radii6d_s[c_id][tp]

            # Extend the height, width and length spindle axes by a factor of 5
            # (500% of the initial axis length) for the cortex intersection below.
            length = 5
            p3_h, p3_w, p3_l = (SX_Model.extend3d_line(axis_poles[0], axis_poles[1], length)
                                for axis_poles in (pole_h_s, pole_w_s, pole_l_s))

            # Intersect the three extended spindle axes with the cell cortex:
            #   rt_select == 0 -> heuristic ray tracing against the sampled cortex
            #                     surface points (surface_3dcoords)
            #   rt_select == 1 -> analytical line/ellipsoid intersection using the
            #                     fitted cell centre c and radii r
            # NOTE(review): the rotation matrix `rot` is not passed to the analytical
            # solver, so it presumably assumes an axis-aligned ellipsoid -- confirm.
            # Any other rt_select used to leave intersect_h/w/l unassigned for this
            # frame. That raised a NameError on the first frame and silently reused the
            # previous frame's intersections afterwards, so fail loudly instead.
            if rt_select == 0: # Heuristic
                intersect_h = SX_Model.hrt(surface_3dcoords, p3_h)
                intersect_w = SX_Model.hrt(surface_3dcoords, p3_w)
                intersect_l = SX_Model.hrt(surface_3dcoords, p3_l)
            elif rt_select == 1: # Analytical solution
                intersect_h = SX_Model.line_ellipsoid_intersection(p3_h, c, r)
                intersect_w = SX_Model.line_ellipsoid_intersection(p3_w, c, r)
                intersect_l = SX_Model.line_ellipsoid_intersection(p3_l, c, r)
            else:
                raise ValueError('rt_select must be 0 (heuristic) or 1 (analytical), got %r' % (rt_select,))

            # - Joint
            # Calculate 3D pole-to-cortex distances d_h/w/l and length l_h/w/l.
            # d_* holds one distance per pole (indexed [0]/[1] further down);
            # l_* is the spindle axis length along h/w/l (stored as spindle_axis_*).
            d_h, l_h = SX_Model.pole_cortex_distance3d(pole_h_s, intersect_h)
            d_w, l_w = SX_Model.pole_cortex_distance3d(pole_w_s, intersect_w)
            d_l, l_l = SX_Model.pole_cortex_distance3d(pole_l_s, intersect_l)

            # Per-cell distance histories. After the time loop they are passed to
            # SX_Model.close_cortex to flag frames where a pole is near the cortex.
            temp_d_h.append(d_h)
            temp_d_w.append(d_w)
            temp_d_l.append(d_l)
            # Get spindle height, width, length ratio
            # spindle_ratio receives the whole axis-length history so far.
            # NOTE(review): norm_t=0 presumably normalises to the first time point -- confirm in spinx.
            temp_h_s.append(l_h)
            temp_w_s.append(l_w)
            temp_l_s.append(l_l)
            spindle_h_ratio = SX_Model.spindle_ratio(temp_h_s, norm_t=0)
            spindle_w_ratio = SX_Model.spindle_ratio(temp_w_s, norm_t=0)
            spindle_l_ratio = SX_Model.spindle_ratio(temp_l_s, norm_t=0)

            # Calculate 3D pole displacement (Pole 1 and 2 which is pole length)
            disp_l = SX_Model.pole_displacement3d(temp_pole_l_s)
            # Spindle centroid displacement (not written to the DataFrame in this section)
            disp_centroid_s = SX_Model.centroid_displacement3d(temp_centroid_s)
            temp_disp_l1.append(disp_l[0])
            temp_disp_l2.append(disp_l[1])

            # Calculate 3D MSD of the length-axis poles from their history.
            # msd is indexed per pole ([0], [1]); msd_total is stored as spindle_pole_msd_total.
            msd, msd_total = SX_Model.compute_msd(temp_pole_l_s)

            # Decompose spindle movement (spindle poles) in longitudinal (along x-axis), equatorial (along y-axis) and z-axis NOTE: Relative to image coordinate system
            decompose_l_1, decompose_l_2 = SX_Model.decomposition3d(temp_pole_l_s)

            # Decompose spindle movement (spindle poles) in longitudinal (along spindle long axis), equatorial (spindle width axis) and spindle height axis
            # At tp == 0 there is no previous frame, so both poles get zero components.
            # Otherwise the pole pairs of tp-1 and tp are passed to rotate_line3d, and the
            # result is decomposed in 2D and in 3D. temp_pole_l_s[tp] is the entry appended
            # for this frame, so this assumes temp_pole_l_s is a per-cell list indexed by tp.
            # proj_point is returned but not used here.
            if tp == 0:
                decompose_l_rel2d = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
                decompose_l_rel3d = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
            else:
                temp_poles_t0_to_t1 = []
                temp_poles_t0 = temp_pole_l_s[tp-1]
                temp_poles_t1 = temp_pole_l_s[tp]
                temp_poles_t0_to_t1.append(temp_poles_t0) # Append pole coordinates at t0
                temp_poles_t0_to_t1.append(temp_poles_t1) # Append pole coordinates at t1
                rotatedLine = SX_Model.rotate_line3d(temp_poles_t0_to_t1, scale=False, plot=False)
                decompose_l_rel2d, proj_point = SX_Model.decomposition3d_v2(rotatedLine, plot=False, vers='2d')
                decompose_l_rel3d, proj_point = SX_Model.decomposition3d_v2(rotatedLine, plot=False, vers='3d')

            # TODO: add a column identifying whether pole 1 or pole 2 is larger

            # ==== Cell ====
            # Volume, surface area and sphericity of the fitted cell ellipsoid (radii r)
            vol = SX_Model.ellipsoid_volume(r)
            surf = SX_Model.ellipsoid_surface_area(r)
            sphericity = SX_Model.sphericity3d(vol, surf)
            # Relative intrinsic Euler angles (degrees and radians), quaternion
            # [x, y, z, w] and translation distance of the cell axes between tp-1 and tp.
            # The first frame has no predecessor and gets all-zero values.
            # These come from the pole vectors via SX_Model.rel_angle3d rather than
            # R.from_matrix(rot), because the pipeline targets an older scipy (1.3.2).
            # Background: https://acko.net/blog/animate-your-way-to-glory-pt2/#quaternions
            if tp != 0:
                vec3d_t0 = [temp_pole_h[tp-1], temp_pole_w[tp-1], temp_pole_l[tp-1]]
                vec3d_t1 = [temp_pole_h[tp], temp_pole_w[tp], temp_pole_l[tp]]
                angles, rad, quat, dist3d = SX_Model.rel_angle3d(vec3d_t0, vec3d_t1)
            else:
                angles, rad = np.zeros(3), np.zeros(3)
                quat, dist3d = np.zeros(4), 0.0
            angles_list.append(angles)

            # ==== Spindle ====
            # Volume, surface area and sphericity of the fitted spindle ellipsoid (radii r_s)
            vol_s = SX_Model.ellipsoid_volume(r_s)
            surf_s = SX_Model.ellipsoid_surface_area(r_s)
            sphericity_s = SX_Model.sphericity3d(vol_s, surf_s)
            # Same relative rotation measures for the (tracked) spindle axes
            if tp != 0:
                vec3d_t0 = [temp_pole_h_s[tp-1], temp_pole_w_s[tp-1], temp_pole_l_s[tp-1]]
                vec3d_t1 = [temp_pole_h_s[tp], temp_pole_w_s[tp], temp_pole_l_s[tp]]
                angles_s, rad_s, quat_s, dist3d_s = SX_Model.rel_angle3d(vec3d_t0, vec3d_t1)
            else:
                angles_s, rad_s = np.zeros(3), np.zeros(3)
                quat_s, dist3d_s = np.zeros(4), 0.0
            angles_list_s.append(angles_s)

            # Pole-to-cortex plot: pass this frame's length-axis distances
            # (d_l[0], d_l[1]) together with the running series. The helper returns
            # the (presumably extended) series, which replace the current ones.
            (data_raw_1, data_raw_2,
             time_raw_mins_1, time_raw_mins_2) = SX_Model.plot_pole_cortex(
                data_raw_1, data_raw_2, time_raw_mins_1, time_raw_mins_2,
                c_id, tp, d_l[0], d_l[1], n_frames, exp_interval, exp_name, OUTPUT_DIR)

            # Raw spindle and membrane stacks of this cell/time point, plus the
            # membrane contour on the focal slice, fused into one overlay image.
            stack_spindle = array6d_spind_raw[:, :, :, tp, c_id, 0]
            stack_membrane = array6d_mem_raw[:, :, :, tp, c_id, 0]
            focus_cont_membrane = contours6d[c_id][tp][focal_slice]
            fused_img = SX_Model.fuse_img(stack_spindle, stack_membrane, focus_cont_membrane,
                                          c_id, tp, exp_interval, pixels_per_micron_xy,
                                          exp_name, OUTPUT_DIR)

            # Write the overlay as one video frame and keep it for the per-cell GIF
            fused_path = os.path.join(OUTPUT_DIR, exp_name, 'video_frames', 'fused',
                                      'overlay_c_' + str(c_id) + '_tp_' + str(tp) + '.png')
            imageio.imwrite(fused_path, fused_img)
            fused_all.append(fused_img)

            # Extract filename
            # Write to pandas dataframe
            # One DataFrame row per (cell, time point); `count` is the row label.
            # NOTE(review): if `df` is created without predefined columns, the order of
            # these df.at assignments may determine the CSV column order -- keep it stable.

            # ==== Meta data
            df.at[count, 'N'] = count
            df.at[count, 'filename'] = name_list_raw[c_id][tp][focal_slice]
            df.at[count, 'exp_set'] = exp_set

            # When c_id reaches n_group_a, switch to the second condition (cond_idx = 1)
            # and restart the per-condition cell numbering. This runs on every time point
            # of that cell; that is harmless, because cell_count is only incremented after
            # the time loop.
            if c_id == n_group_a:
                cell_count = 0 # Reset
                cond_idx = 1

            condition = condition_group[cond_idx]
            df.at[count, 'condition'] = condition
            df.at[count, 'img_height'] = stack_spindle.shape[0]
            df.at[count, 'img_width'] = stack_spindle.shape[1]
            df.at[count, 'img_dim'] = stack_spindle.shape[2]  # third stack axis (the axis max-projected for the 2D overlay)
            df.at[count, 'wavelength'] = wavelength
            df.at[count, 'pixels_per_micron_xy'] = pixels_per_micron_xy
            df.at[count, 'gap_micron_z'] = pixels_per_micron_z  # NOTE(review): column says z gap in microns, value is pixels_per_micron_z -- confirm
            df.at[count, 'cell_id_total'] = c_id  # global cell index
            df.at[count, 'cell_id'] = cell_count  # cell index within its condition


            df.at[count, 'time_point'] = tp
            df.at[count, 'time_point_mins'] = tp*exp_interval

            # ==== Cell
            # Cell (cortex) ellipsoid centre and the two end points of each cell axis.
            df.at[count, 'cell_centroid_x'] = c[0]
            df.at[count, 'cell_centroid_y'] = c[1]
            df.at[count, 'cell_centroid_z'] = c[2]
            df.at[count, 'cell_pole_h_x1'] = pole_h[0][0]
            df.at[count, 'cell_pole_h_y1'] = pole_h[0][1]
            df.at[count, 'cell_pole_h_z1'] = pole_h[0][2]
            df.at[count, 'cell_pole_h_x2'] = pole_h[1][0]
            df.at[count, 'cell_pole_h_y2'] = pole_h[1][1]
            df.at[count, 'cell_pole_h_z2'] = pole_h[1][2]
            df.at[count, 'cell_pole_w_x1'] = pole_w[0][0]
            df.at[count, 'cell_pole_w_y1'] = pole_w[0][1]
            df.at[count, 'cell_pole_w_z1'] = pole_w[0][2]
            df.at[count, 'cell_pole_w_x2'] = pole_w[1][0]
            df.at[count, 'cell_pole_w_y2'] = pole_w[1][1]
            df.at[count, 'cell_pole_w_z2'] = pole_w[1][2]
            df.at[count, 'cell_pole_l_x1'] = pole_l[0][0]
            df.at[count, 'cell_pole_l_y1'] = pole_l[0][1]
            df.at[count, 'cell_pole_l_z1'] = pole_l[0][2]
            df.at[count, 'cell_pole_l_x2'] = pole_l[1][0]
            df.at[count, 'cell_pole_l_y2'] = pole_l[1][1]
            df.at[count, 'cell_pole_l_z2'] = pole_l[1][2]
            # Find out which slice has the largest diameter (equals to mid-plane of the object).
            # This must run BEFORE the 2D axis columns are written. Previously
            # cell_2d_major_axis / cell_2d_minor_axis read max_idx before it was computed
            # for this frame, so they used a stale index (or raised NameError).
            max_idx, _ = SX.find_diameter(np.array(info6d[c_id][tp])[:,1])
            df.at[count, 'cell_2d_major_axis'] = info6d[c_id][tp][max_idx][0]
            df.at[count, 'cell_2d_minor_axis'] = info6d[c_id][tp][max_idx][1]
            df.at[count, 'cell_axis_ratio'] = info6d[c_id][tp][max_idx][0] / info6d[c_id][tp][max_idx][1] #If > 1 (major axis is longer)

            # Full ellipsoid axis lengths (twice the fitted radii)
            df.at[count, 'cell_axis_a'] = r[0]*2
            df.at[count, 'cell_axis_b'] = r[1]*2
            df.at[count, 'cell_axis_c'] = r[2]*2

            df.at[count, 'cell_volume'] = vol
            df.at[count, 'cell_area'] = surf
            df.at[count, 'cell_sphericity'] = sphericity
            df.at[count, 'cell_eangle_alpha_deg'] = angles[0] # roll: Cell rolling (along the x-axis)
            df.at[count, 'cell_eangle_beta_deg'] = angles[1] # pitch: Cell tilting (along the y-axis)
            df.at[count, 'cell_eangle_gamma_deg'] = angles[2] # yaw: Cell rotation (along the z-axis)
            df.at[count, 'cell_eangle_alpha_rad'] = rad[0] # roll: Cell rolling (along the x-axis)
            df.at[count, 'cell_eangle_beta_rad'] = rad[1] # pitch: Cell tilting (along the y-axis)
            df.at[count, 'cell_eangle_gamma_rad'] = rad[2] # yaw: Cell rotation (along the z-axis)
            df.at[count, 'cell_qangle_x'] = quat[0]
            df.at[count, 'cell_qangle_y'] = quat[1]
            df.at[count, 'cell_qangle_z'] = quat[2]
            df.at[count, 'cell_qangle_w'] = quat[3]
            df.at[count, 'cell_translation_dist'] = dist3d

            # ==== Spindle
            # Spindle ellipsoid centre (fitted centre c_s, not the pole midpoint cent_s)
            df.at[count, 'spindle_centroid_x'] = c_s[0]
            df.at[count, 'spindle_centroid_y'] = c_s[1]
            df.at[count, 'spindle_centroid_z'] = c_s[2]

            # End points (pole 1 and pole 2) of the spindle height, width and length axes
            df.at[count, 'spindle_pole_h_x1'] = pole_h_s[0][0]
            df.at[count, 'spindle_pole_h_y1'] = pole_h_s[0][1]
            df.at[count, 'spindle_pole_h_z1'] = pole_h_s[0][2]
            df.at[count, 'spindle_pole_h_x2'] = pole_h_s[1][0]
            df.at[count, 'spindle_pole_h_y2'] = pole_h_s[1][1]
            df.at[count, 'spindle_pole_h_z2'] = pole_h_s[1][2]
            df.at[count, 'spindle_pole_w_x1'] = pole_w_s[0][0]
            df.at[count, 'spindle_pole_w_y1'] = pole_w_s[0][1]
            df.at[count, 'spindle_pole_w_z1'] = pole_w_s[0][2]
            df.at[count, 'spindle_pole_w_x2'] = pole_w_s[1][0]
            df.at[count, 'spindle_pole_w_y2'] = pole_w_s[1][1]
            df.at[count, 'spindle_pole_w_z2'] = pole_w_s[1][2]
            df.at[count, 'spindle_pole_l_x1'] = pole_l_s[0][0]
            df.at[count, 'spindle_pole_l_y1'] = pole_l_s[0][1]
            df.at[count, 'spindle_pole_l_z1'] = pole_l_s[0][2]
            df.at[count, 'spindle_pole_l_x2'] = pole_l_s[1][0]
            df.at[count, 'spindle_pole_l_y2'] = pole_l_s[1][1]
            df.at[count, 'spindle_pole_l_z2'] = pole_l_s[1][2]
            # Change of each length-axis pole since the previous frame (zero at tp == 0)
            df.at[count, 'spindle_pole_l_x1_dif'] = pole_l_dif[0][0]
            df.at[count, 'spindle_pole_l_y1_dif'] = pole_l_dif[0][1]
            df.at[count, 'spindle_pole_l_z1_dif'] = pole_l_dif[0][2]
            df.at[count, 'spindle_pole_l_x2_dif'] = pole_l_dif[1][0]
            df.at[count, 'spindle_pole_l_y2_dif'] = pole_l_dif[1][1]
            df.at[count, 'spindle_pole_l_z2_dif'] = pole_l_dif[1][2]

            # Pole-tracking flags. For the length axis the value is 1 when the poles were
            # flipped by SX_Model.pole_tracking3d_v3, and the string 'nan' when tracking
            # was not applied.
            df.at[count, 'spindle_pole_h_corrected'] = corrected_h_s
            df.at[count, 'spindle_pole_w_corrected'] = corrected_w_s
            df.at[count, 'spindle_pole_l_corrected'] = corrected_l_s

            # Spindle axis lengths and their ratio series from SX_Model.spindle_ratio
            df.at[count, 'spindle_axis_h'] = l_h
            df.at[count, 'spindle_axis_w'] = l_w
            df.at[count, 'spindle_axis_l'] = l_l
            df.at[count, 'spindle_axis_h_ratio'] = spindle_h_ratio
            df.at[count, 'spindle_axis_w_ratio'] = spindle_w_ratio
            df.at[count, 'spindle_axis_l_ratio'] = spindle_l_ratio


            df.at[count, 'spindle_volume'] = vol_s
            df.at[count, 'spindle_area'] = surf_s
            df.at[count, 'spindle_sphericity'] = sphericity_s

            df.at[count, 'spindle_eangle_alpha_deg'] = angles_s[0] # roll: Spindle rolling (along the x-axis)
            df.at[count, 'spindle_eangle_beta_deg'] = angles_s[1] # pitch: Spindle tilting (along the y-axis)
            df.at[count, 'spindle_eangle_gamma_deg'] = angles_s[2] # yaw: Spindle rotation (along the z-axis)
            df.at[count, 'spindle_eangle_alpha_rad'] = rad_s[0] # roll: Spindle rolling (along the x-axis)
            df.at[count, 'spindle_eangle_beta_rad'] = rad_s[1] # pitch: Spindle tilting (along the y-axis)
            df.at[count, 'spindle_eangle_gamma_rad'] = rad_s[2] # yaw: Spindle rotation (along the z-axis)
            df.at[count, 'spindle_qangle_x'] = quat_s[0]
            df.at[count, 'spindle_qangle_y'] = quat_s[1]
            df.at[count, 'spindle_qangle_z'] = quat_s[2]
            df.at[count, 'spindle_qangle_w'] = quat_s[3]
            df.at[count, 'spindle_translation_dist'] = dist3d_s

            # Pole-to-cortex distances per axis and pole
            df.at[count, 'pole_cortex_dist_h_1'] = d_h[0]
            df.at[count, 'pole_cortex_dist_h_2'] = d_h[1]
            df.at[count, 'pole_cortex_dist_w_1'] = d_w[0]
            df.at[count, 'pole_cortex_dist_w_2'] = d_w[1]
            df.at[count, 'pole_cortex_dist_l_1'] = d_l[0]
            df.at[count, 'pole_cortex_dist_l_2'] = d_l[1]

            # Close-to-cortex flags default to 0 here; after the time loop the frames
            # reported by SX_Model.close_cortex are set to 1.
            df.at[count, 'pole_cortex_close_h_1'] = 0
            df.at[count, 'pole_cortex_close_h_2'] = 0
            df.at[count, 'pole_cortex_close_w_1'] = 0
            df.at[count, 'pole_cortex_close_w_2'] = 0
            df.at[count, 'pole_cortex_close_l_1'] = 0
            df.at[count, 'pole_cortex_close_l_2'] = 0

            # NOTE(review): the "velo" columns currently store the pole-to-cortex
            # DISTANCES d_* (identical to pole_cortex_dist_*), not velocities -- confirm
            # whether this is a placeholder that is converted later.
            df.at[count, 'pole_cortex_velo_h_1'] = d_h[0]
            df.at[count, 'pole_cortex_velo_h_2'] = d_h[1]
            df.at[count, 'pole_cortex_velo_w_1'] = d_w[0]
            df.at[count, 'pole_cortex_velo_w_2'] = d_w[1]
            df.at[count, 'pole_cortex_velo_l_1'] = d_l[0]
            df.at[count, 'pole_cortex_velo_l_2'] = d_l[1]

            df.at[count, 'spindle_pole_disp_1'] = disp_l[0]
            df.at[count, 'spindle_pole_disp_2'] = disp_l[1]

            #df.at[count, 'spindle_centroid_disp'] = disp_centroid_s

            df.at[count, 'spindle_pole_msd_1'] = msd[0]
            df.at[count, 'spindle_pole_msd_2'] = msd[1]
            df.at[count, 'spindle_pole_msd_total'] = msd_total


            # Pole displacement decomposed along the image axes (SX_Model.decomposition3d)
            df.at[count, 'spindle_pole_long_disp_1'] = decompose_l_1[0]
            df.at[count, 'spindle_pole_eq_disp_1'] = decompose_l_1[1]
            df.at[count, 'spindle_pole_z_disp_1'] = decompose_l_1[2]
            df.at[count, 'spindle_pole_long_disp_2'] = decompose_l_2[0]
            df.at[count, 'spindle_pole_eq_disp_2'] = decompose_l_2[1]
            df.at[count, 'spindle_pole_z_disp_2'] = decompose_l_2[2]

            # Pole displacement decomposed relative to the spindle axes, 2D variant
            df.at[count, 'spindle_pole_long_disp_1_rel2d'] = decompose_l_rel2d[0][0]
            df.at[count, 'spindle_pole_eq_disp_1_rel2d'] = decompose_l_rel2d[0][1]
            df.at[count, 'spindle_pole_z_disp_1_rel2d'] = decompose_l_rel2d[0][2]
            df.at[count, 'spindle_pole_long_disp_2_rel2d'] = decompose_l_rel2d[1][0]
            df.at[count, 'spindle_pole_eq_disp_2_rel2d'] = decompose_l_rel2d[1][1]
            df.at[count, 'spindle_pole_z_disp_2_rel2d'] = decompose_l_rel2d[1][2]

            # ... and 3D variant
            df.at[count, 'spindle_pole_long_disp_1_rel3d'] = decompose_l_rel3d[0][0]
            df.at[count, 'spindle_pole_eq_disp_1_rel3d'] = decompose_l_rel3d[0][1]
            df.at[count, 'spindle_pole_z_disp_1_rel3d'] = decompose_l_rel3d[0][2]
            df.at[count, 'spindle_pole_long_disp_2_rel3d'] = decompose_l_rel3d[1][0]
            df.at[count, 'spindle_pole_eq_disp_2_rel3d'] = decompose_l_rel3d[1][1]
            df.at[count, 'spindle_pole_z_disp_2_rel3d'] = decompose_l_rel3d[1][2]


            # Velocities = displacement divided by the acquisition interval exp_interval
            df.at[count, 'spindle_pole_velo_1'] = disp_l[0]/exp_interval
            df.at[count, 'spindle_pole_velo_2'] = disp_l[1]/exp_interval

            df.at[count, 'spindle_pole_long_velo_1'] = decompose_l_1[0]/exp_interval
            df.at[count, 'spindle_pole_eq_velo_1'] = decompose_l_1[1]/exp_interval
            df.at[count, 'spindle_pole_z_velo_1'] = decompose_l_1[2]/exp_interval
            df.at[count, 'spindle_pole_long_velo_2'] = decompose_l_2[0]/exp_interval
            df.at[count, 'spindle_pole_eq_velo_2'] = decompose_l_2[1]/exp_interval
            df.at[count, 'spindle_pole_z_velo_2'] = decompose_l_2[2]/exp_interval

            # Store row id per cell (used after the time loop to flag close-to-cortex frames)
            temp_count.append(count)

            # ---- Complete 3D model figure ----
            # Contents: cell ellipsoid, spindle axes, extended axes and their cortex
            # intersections, and spindle poles. The figure is rendered into plot_all
            # (for the per-cell GIF) and saved to model/complete_model/.
            fig = plt.figure(figsize=(15,15))
            ax = fig.add_subplot(111, projection='3d')

            if old_3d == 1:
                ax.set_aspect("equal") # Works only with older matplotlib==3.0.2 (unsolved bug with 3.3.1)
            else:
                xs = np.array([0,512])
                ys = np.array([0,512])
                zs = np.array([0,512])
                ax.set_box_aspect((np.ptp(xs), np.ptp(ys), np.ptp(zs)))  # aspect ratio is 1:1:1

            # Outline of the imaged volume (x: 0..voxel_x, y: 0..voxel_y,
            # z: -voxel_z/2..voxel_z/2), drawn with dashed edges. An edge joins two
            # corners that differ in exactly one coordinate. This holds even when the
            # extents differ. The former length comparison against the x/z extents
            # dropped the y edges unless voxel_y matched one of them, and drew edges
            # twice when extents coincided.
            rx = [0, voxel_x]
            ry = [0, voxel_y]
            rz = [0-(voxel_z/2), (voxel_z/2)]
            for s, e in combinations(np.array(list(product(rx, ry, rz))), 2):
                if np.count_nonzero(s != e) == 1:
                    ax.plot3D(*zip(s, e), "k--", linewidth=1, antialiased=True)
            # Set axes limits
            ax.set(xlim=(rx[0], rx[1]), ylim=(ry[0], ry[1]), zlim=(rz[0], rz[1]))
            # Change view
            ax.view_init(elev=30, azim=60)

            # Make the grid lines transparent (private matplotlib attribute)
            ax.xaxis._axinfo["grid"]['color'] =  (1,1,1,0)
            ax.yaxis._axinfo["grid"]['color'] =  (1,1,1,0)
            ax.zaxis._axinfo["grid"]['color'] =  (1,1,1,0)

            # Plot cell
            ET.plotEllipsoid(merged_center6d[c_id][tp], merged_radii6d[c_id][tp], merged_rotation6d[c_id][tp], ax=ax, plotAxes=False, cageColor='k', cageAlpha=0.2)

            # Spindle ellipsoid axes, one colour per axis
            for i in range(3):
                ax.plot(merged_axis6d_s[c_id][tp][i][:,0], merged_axis6d_s[c_id][tp][i][:,1], merged_axis6d_s[c_id][tp][i][:,2], axis_color[i], linewidth=5)

            # Spindle axes between their two cortex intersections (ray tracing result)
            ax.plot([intersect_h[0][0],intersect_h[1][0]], [intersect_h[0][1],intersect_h[1][1]], [intersect_h[0][2],intersect_h[1][2]], color='m', linestyle='solid', marker='o',
                    markerfacecolor='navy', markeredgecolor='navy', markersize=5)
            ax.plot([intersect_w[0][0],intersect_w[1][0]], [intersect_w[0][1],intersect_w[1][1]], [intersect_w[0][2],intersect_w[1][2]], color='m', linestyle='solid', marker='o',
                    markerfacecolor='navy', markeredgecolor='navy', markersize=5)
            ax.plot([intersect_l[0][0],intersect_l[1][0]], [intersect_l[0][1],intersect_l[1][1]], [intersect_l[0][2],intersect_l[1][2]], color='m', linestyle='solid', marker='o',
                    markerfacecolor='navy', markeredgecolor='navy', markersize=5)

            # Optional: Control dash axis (the full extended lines)
            ax.plot([p3_h[0][0],p3_h[1][0]], [p3_h[0][1],p3_h[1][1]], [p3_h[0][2],p3_h[1][2]], color='silver',linestyle='dashed')
            ax.plot([p3_w[0][0],p3_w[1][0]], [p3_w[0][1],p3_w[1][1]], [p3_w[0][2],p3_w[1][2]], color='silver',linestyle='dashed')
            ax.plot([p3_l[0][0],p3_l[1][0]], [p3_l[0][1],p3_l[1][1]], [p3_l[0][2],p3_l[1][2]], color='silver',linestyle='dashed')

            # Spindle length pole 1 and 2
            ax.plot([pole_l_s[0][0]], [pole_l_s[0][1]], [pole_l_s[0][2]], color='darkorange', marker='o', markersize=10, fillstyle='full')
            ax.plot([pole_l_s[1][0]], [pole_l_s[1][1]], [pole_l_s[1][2]], color='black', marker='o', markersize=10, fillstyle='full')

            # Spindle height pole 1 and 2
            ax.plot([pole_h_s[0][0]], [pole_h_s[0][1]], [pole_h_s[0][2]], color='darkorange', marker='o', markersize=5, fillstyle='full')
            ax.plot([pole_h_s[1][0]], [pole_h_s[1][1]], [pole_h_s[1][2]], color='black', marker='o', markersize=5, fillstyle='full')

            # Spindle width pole 1 and 2
            ax.plot([pole_w_s[0][0]], [pole_w_s[0][1]], [pole_w_s[0][2]], color='darkorange', marker='o', markersize=5, fillstyle='full')
            ax.plot([pole_w_s[1][0]], [pole_w_s[1][1]], [pole_w_s[1][2]], color='black', marker='o', markersize=5, fillstyle='full')

            ax.tick_params(axis='both', which='major', labelsize=20, pad=15)
            ax.tick_params(axis='both', which='minor', labelsize=20, pad=15)

            # Disable z-label rotation
            ax.zaxis.set_rotate_label(False)  # disable automatic rotation

            # Increase fontsize. Raw strings: '\m' is not a valid Python escape
            # sequence (DeprecationWarning / SyntaxWarning on newer Pythons).
            ax.set_xlabel(r'$\mathbf{x}$', fontsize=30, labelpad=10)
            ax.set_ylabel(r'$\mathbf{y}$', fontsize=30, labelpad=15)
            ax.set_zlabel(r'$\mathbf{z}$', fontsize=30, labelpad=15, rotation=0)

            # Convert the plot to an image array of shape (height, width, channels),
            # sized from the figure bbox in pixels. The shape is passed positionally
            # because the `newshape=` keyword of np.reshape is deprecated.
            with BytesIO() as io_buf:
                fig.savefig(io_buf, format='raw')
                plot_to_image = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                                           (int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
            plot_all.append(plot_to_image) # Store in list

            # Save through the explicit figure handle
            name_export = 'cell_' + str(c_id) + '_tp_' + str(tp)
            full_name_export = os.path.join(OUTPUT_DIR, exp_name, 'model' , 'complete_model', name_export + '.pdf')
            fig.savefig(full_name_export, dpi=600, bbox_inches='tight', transparent=True)
            if plot_on == 1:
                plt.show()
            plt.close(fig)
            del fig

            # Save segmentation plot
            # ---- Spindle segmentation figure ----
            # Contents: segmented spindle points (magenta stars), fitted spindle
            # ellipsoid and its three axes, and the length-axis poles. Saved to
            # model/spindle/axes_no/.
            fig2 = plt.figure(figsize=(15,15))
            ax2 = fig2.add_subplot(111, projection='3d')
            if old_3d == 1:
                ax2.set_aspect("equal") # Works only with older matplotlib==3.0.2 (unsolved bug with 3.3.1)
            else:
                xs = np.array([0,voxel_x])
                ys = np.array([0,voxel_y])
                zs = np.array([0,voxel_z])
                # Box aspect voxel_x:voxel_y:voxel_z (not 1:1:1). It matches the
                # proportions of the view window below, so all axes share one scale.
                ax2.set_box_aspect((np.ptp(xs), np.ptp(ys), np.ptp(zs)))

            # View window: half the volume extent in each dimension, centred on the
            # cell centre c. Its edges are drawn dashed; an edge joins two corners
            # that differ in exactly one coordinate (robust to unequal voxel_x/y/z).
            rx = [c[0] - voxel_x/4, c[0] + voxel_x/4]
            ry = [c[1] - voxel_y/4, c[1] + voxel_y/4]
            rz = [c[2] - voxel_z/4, c[2] + voxel_z/4]
            for s, e in combinations(np.array(list(product(rx, ry, rz))), 2):
                if np.count_nonzero(s != e) == 1:
                    ax2.plot3D(*zip(s, e), "k--", linewidth=1, antialiased=True)

            # Set axes limits
            ax2.set(xlim=(rx[0], rx[1]), ylim=(ry[0], ry[1]), zlim=(rz[0], rz[1]))

            # Change view
            ax2.view_init(elev=30, azim=60)

            # Set plot_axes = 1 to show styled grid, ticks and labels; otherwise axes are hidden
            plot_axes = 0
            if plot_axes == 1:
                # Axes style (private matplotlib attribute)
                ax2.xaxis._axinfo["grid"]['color'] = 'whitesmoke'
                ax2.yaxis._axinfo["grid"]['color'] = 'whitesmoke'
                ax2.zaxis._axinfo["grid"]['color'] = 'whitesmoke'

                ax2.tick_params(axis='both', which='major', labelsize=20, pad=15)
                ax2.tick_params(axis='both', which='minor', labelsize=20, pad=15)

                # Disable z-label rotation
                ax2.zaxis.set_rotate_label(False)  # disable automatic rotation

                # Increase fontsize (raw strings: '\m' is not a valid escape sequence)
                ax2.set_xlabel(r'$\mathbf{x}$', fontsize=30, labelpad=20)
                ax2.set_ylabel(r'$\mathbf{y}$', fontsize=30, labelpad=20)
                ax2.set_zlabel(r'$\mathbf{z}$', fontsize=30, labelpad=20, rotation=0)
            else:
                ax2.axis('off')

            # Spindle ellipsoid axes, one colour per axis
            for i in range(3):
                ax2.plot(merged_axis6d_s[c_id][tp][i][:,0], merged_axis6d_s[c_id][tp][i][:,1], merged_axis6d_s[c_id][tp][i][:,2], axis_color[i], linewidth=5)
            # plot points
            ax2.scatter(merged_array_s[c_id][tp][:,0], merged_array_s[c_id][tp][:,1], merged_array_s[c_id][tp][:,2], color='m', marker='*', s=3)
            ET.plotEllipsoid(merged_center6d_s[c_id][tp], merged_radii6d_s[c_id][tp], merged_rotation6d_s[c_id][tp], ax=ax2, plotAxes=False, cageColor='k', cageAlpha=0.2)

            # Spindle length pole 1 and 2
            ax2.plot([pole_l_s[0][0]], [pole_l_s[0][1]], [pole_l_s[0][2]], color='darkorange', marker='o', markersize=10, fillstyle='full')
            ax2.plot([pole_l_s[1][0]], [pole_l_s[1][1]], [pole_l_s[1][2]], color='black', marker='o', markersize=10, fillstyle='full')

            # Save through the explicit figure handle
            name_export = 'cell_' + str(c_id) + '_tp_' + str(tp)
            full_name_export = os.path.join(OUTPUT_DIR, exp_name, 'model' , 'spindle', 'axes_no', name_export + '.pdf')
            fig2.savefig(full_name_export, dpi=300, bbox_inches='tight', transparent=True)

            if plot_on == 1:
                plt.show()
            plt.close(fig2)
            del fig2

        # Advance the per-condition cell counter (reset when the second condition starts)
        cell_count += 1


        # ---- Per-cell aggregation, after all time points of cell c_id ----
        # Merge all pole coordinates through time to cells
        poles_s.append(temp_poles_s)
        disp_l1_all.append(np.asarray(temp_disp_l1))
        disp_l2_all.append(np.asarray(temp_disp_l2))
        # Merge Pole-Cortex distances
        merged_temp_d_l.append(temp_d_l)
        # Merge poles
        merged_pole_h_s.append(temp_pole_h_s)
        merged_pole_w_s.append(temp_pole_w_s)
        merged_pole_l_s.append(temp_pole_l_s)
        # Merge centroids
        merged_centroid.append(temp_centroid) # Cell
        merged_centroid_s.append(temp_centroid_s) # Spindle

        # Search for time points where spindle poles are very close to their cortex.
        # For each pole, SX_Model.close_cortex yields positions in this cell's time
        # series (presumably frames below thres_dist_cortex). temp_count maps those
        # positions to DataFrame row labels. Only pole 1 and pole 2 have a column,
        # so any further entries are ignored (as before). This one loop replaces
        # three copies of the same logic for h, w and l.
        for axis_key, dists_by_frame in (('h', temp_d_h), ('w', temp_d_w), ('l', temp_d_l)):
            idx_close = SX_Model.close_cortex(dists_by_frame, thres_dist_cortex, num_poles=2)
            for pole_no, pole_frame_idx in enumerate(idx_close, start=1):
                if pole_no > 2:
                    break
                for idx in pole_frame_idx:
                    # Set value to 1 if pole is close to cortex
                    df.loc[temp_count[idx], 'pole_cortex_close_{}_{}'.format(axis_key, pole_no)] = 1




        # Export fused images as video
        #imageio.mimsave('figs/' + exp_name + '/video/raw/fused_cell_' + str(c_id) + '.gif', fused_all, fps=55)
        #imageio.mimsave('figs/' + exp_name + '/video/model/model_cell_' + str(c_id) + '.gif', plot_all, fps=55)
        fused_video_path = os.path.join(OUTPUT_DIR, exp_name, 'video', 'raw', 'fused_cell_' + str(c_id) + '.gif')
        imageio.mimsave(fused_video_path, fused_all, fps=55)
        fused_model_path = os.path.join(OUTPUT_DIR, exp_name, 'video', 'model', 'model_cell_' + str(c_id) + '.gif')
        imageio.mimsave(fused_model_path, plot_all, fps=55)

        # End timer
        end = time.time()
        e_time = end - start
        print('Elapsed time for one cell: %f seconds' %(round(e_time,3)))
        print() 
        print() 
        total_time.append(e_time)

    # Apply filtering after (test when data set is large)
    #if len(disp_l1_all[0]) > 2:
    #    x_bar_d1, y_hat_d1, rmse_d1, opt_window_size_d1 = SX_Model.data_filter_multi(disp_l1_all,n_frames//exp_interval) # Spindle pole1 displacement

    # Export csv file
    csv_filename = "{}{:%Y%m%dT%H%M}.csv".format("modelling6d_" + str(exp_name) + '_', datetime.datetime.now())
    df.to_csv(os.path.join(OUTPUT_CSV_DIR, csv_filename), index=False)



    # Convert seconds to hh:mm:ss
    hours, seconds =  sum(total_time) // 3600, sum(total_time) % 3600
    minutes, seconds = sum(total_time) // 60, sum(total_time) % 60

    # Runtime calculations
    total_runtime = str(f"{round(hours):02d}" + "h " + f"{round(minutes):02d}" + "mins " + f"{round(seconds):02d}" + "secs")
    avg_runtime = str( round(sum(total_time)/len(total_time),3) ) + " seconds"
    # Variances cant be computed with only 1 value (assigned to 0)
    if len(total_time) < 2:
        var_runtime = str(0) + " seconds"
    else:
        var_runtime = str( round(stdev(total_time), 3) ) + " seconds"

        # Change alignment for adding more columns 'c'
    pred_tab = tt.to_string(
        [[ n_cells, total_runtime, avg_runtime, var_runtime ]],
        header=["N", "Total run time:", "Avg. run time for one cell:", "SD of run time:"],
        style=tt.styles.ascii_thin_double,
        alignment="lccr",
        # padding=(0, 1),
    )
    print(pred_tab)
    print('=== 3D Modelling: Completed ===')
    
    # Create a ZIP file of the OUTPUT_DIR
    print('=== Export as ZIP archive: Start ===')
    # Destination of ZIP file (in ROOT DIR)
    output_zip_name = exp_acronym + '_' + exp_name
    shutil.make_archive(output_zip_name, 'zip', OUTPUT_BASE_DIR)
    output_zip_name_full = output_zip_name + '.zip'
    print('=== Export as ZIP archive: Completed ===')
    return {'output_zip': output_zip_name_full}


In [10]:
# Toggle for the local test run below; set to False to skip it.
run_locally = True

if run_locally:
    # Local testing: inputs are read from sub-folders of the current working directory.
    ROOT_DIR = os.path.abspath(os.getcwd())
    # Top-level input folder. 'input_ome_tiff' holds the OME-TIFF data set; set this
    # to 'input' to use the non-OME-TIFF data set instead. Both are expected to
    # contain the same 'cell_cortex'/'spindle' -> 'raw'/'mask' sub-folder layout.
    INPUT_DIR_NAME = 'input_ome_tiff'

    # Membrane (cell cortex): raw images and segmentation masks
    MEMBRANE_DIR = os.path.join(ROOT_DIR, INPUT_DIR_NAME, 'cell_cortex')
    MEMBRANE_RAW_DIR = os.path.join(MEMBRANE_DIR, 'raw')
    MEMBRANE_BIN_DIR = os.path.join(MEMBRANE_DIR, 'mask')
    membrane_list_raw, ftype_cell_r = SX.get_list_dir(MEMBRANE_RAW_DIR)
    membrane_list_mask, ftype_cell_m = SX.get_list_dir(MEMBRANE_BIN_DIR)

    # Spindle: raw images and segmentation masks
    SPINDLE_DIR = os.path.join(ROOT_DIR, INPUT_DIR_NAME, 'spindle')
    SPINDLE_RAW_DIR = os.path.join(SPINDLE_DIR, 'raw')
    SPINDLE_BIN_DIR = os.path.join(SPINDLE_DIR, 'mask')
    spindle_list_raw, ftype_spin_r = SX.get_list_dir(SPINDLE_RAW_DIR)
    spindle_list_mask, ftype_spin_m = SX.get_list_dir(SPINDLE_BIN_DIR)

    execute(
        input_cell_raw=membrane_list_raw,  # Input raw images of cell cortex
        input_cell_mask=membrane_list_mask,  # Segmentation masks of cell cortex
        input_spindle_raw=spindle_list_raw,  # Input raw images of spindle
        input_spindle_mask=spindle_list_mask,  # Segmentation masks of spindle
        exp_acronym='mark2i',  # Acronym for experiment
        exp_name='exp2021-001-set001',  # Experiment name
        exp_set=1,  # Set number of the experiment
        exp_group_a='control',  # Label for condition A
        n_group_a=1,  # Number of movies in condition A
        exp_group_b='mark2i',  # Label for condition B
        n_group_b=1,  # Number of movies in condition B
        n_frames=5,  # Number of frames
        n_slices=3,  # Number of slices
        exp_interval=3,  # Temporal resolution of the movie
        pixel_x=0.11020,  # Pixel size X (if metadata exists)
        pixel_y=0.11020,  # Pixel size Y
        pixel_z=2,  # Pixel size Z
        voxel_x=512,  # Voxel width
        voxel_y=512,  # Voxel height
        voxel_z=512,  # Voxel depth
        wavelength=0.605,  # Emitted wavelength in microns
        na=1.42,  # Numerical aperture
        magification=60,  # Magnification (keyword spelled as in execute()'s signature)
        res_lateral=0.100,  # Lateral resolution (Chapter 6: https://cdn.southampton.ac.uk/assets/imported/transforms/content-block/UsefulDownloads_Download/F21A6D82AB864B598A07D487C756A92E/Delta%20Vision%20Elite%20User%20Manual.pdf)
        res_axial=0.200,  # Axial resolution
        ri_specimen=1.34,  # Specimen refractive index (RI)
        ri_coverslip=1.522,  # Coverslip RI design value (use oil calculator app)
        ri_medium=1.524,  # Immersion medium RI design value
        working_dist=150,  # Working distance (immersion medium thickness) design value (http://facilities.igc.gulbenkian.pt/microscopy/microscopy-dv.php)
        coverslip_thickness=170,  # Coverslip thickness design value
        particle_dist=2,  # Particle distance
        )