In [None]:
from tifffile import imread, imsave
import os, re, sys, csv
import numpy as np
import matplotlib.pyplot as pyp
from skimage.morphology import remove_small_objects
from skimage.segmentation import find_boundaries
from skimage.measure import regionprops, regionprops_table, label
from skimage.segmentation import clear_border
import cv2
import copy
import pandas as pd
from scipy import ndimage as ndi
import napari
sys.path.append('~/3D_IMC_paper/Python/python_3d_imc_tools')
from io_files import image_filepath_for_3D_stack


In [None]:
# Copied the function from https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.expand_labels
# because it could not be imported from the installed skimage (presumably that version predates expand_labels).
"""
expand_labels is derived from code that was
originally part of CellProfiler, code licensed under BSD license.
Website: http://www.cellprofiler.org
Copyright (c) 2020 Broad Institute
All rights reserved.
Original authors: CellProfiler team
"""
import numpy as np
from scipy.ndimage import distance_transform_edt

def expand_labels(label_image, distance=1):
    """Grow every labelled region outwards by up to ``distance`` pixels.

    Each background pixel (value 0) whose Euclidean distance to the nearest
    labelled pixel is ``<= distance`` takes that nearest pixel's label, so
    regions grow without overwriting or crossing into each other. Works for
    label arrays of any dimensionality.

    Local copy of ``skimage.segmentation.expand_labels`` (itself derived from
    CellProfiler's "IdentifySecondaryObjects (Distance-N)" module).

    Parameters
    ----------
    label_image : ndarray of int
        Label image; 0 marks background.
    distance : float, optional
        Maximum Euclidean growth distance in pixels (default 1).

    Returns
    -------
    ndarray
        New label array with the same shape and dtype as ``label_image``.

    Notes
    -----
    When labels are more than ``distance`` apart this equals a morphological
    dilation with a ball of radius ``distance``. When a background pixel is
    equally close to several regions, which label wins is decided by
    ``scipy.ndimage.distance_transform_edt`` and is not guaranteed.

    Examples
    --------
    >>> labels = np.array([0, 1, 0, 0, 0, 0, 2])
    >>> expand_labels(labels, distance=1)
    array([1, 1, 1, 0, 0, 2, 2])
    >>> expand_labels(labels, distance=3)
    array([1, 1, 1, 1, 2, 2, 2])
    """
    # For every pixel: distance to, and coordinates of, the nearest labelled pixel.
    dist_to_label, nearest_coords = distance_transform_edt(
        label_image == 0, return_indices=True
    )
    within_reach = dist_to_label <= distance

    # One index array per axis, selecting the source pixel each reachable pixel copies.
    source_index = tuple(axis_coords[within_reach] for axis_coords in nearest_coords)

    expanded = np.zeros_like(label_image)
    expanded[within_reach] = label_image[source_index]
    return expanded


###### Explanation of how the function 'remove_small_disconnected_objects' works:
###### label = 10336
###### object_i (the minimal bounding box that contains the object; note it also contains pixels of other labels):
###### [[10336 10336 13346 13346 13346 13346 13346]
###### [    0     0 13346 13346 13346 13346 13346]
###### [    0     0 13346 13346 13346 13346 13346]
###### [    0 15317 10336 13346 13346 13346 10336]
###### [15317 15317 15317 15317 13346 10336     0]]

###### object_i == label:
###### [[ True  True False False False False False]
######  [False False False False False False False]
###### [False False False False False False False]
###### [False False  True False False False  True]
###### [False False False False False  True False]]
###### number of pixels with this label (count of True values): 5

In [None]:
def remove_small_disconnected_objects(mask_2D, min_size, pxl_connect):
    """Keep only the largest connected fragment of each label in a 2D label image.

    A label can be split into several disconnected fragments on one 2D plane
    (e.g. the pieces only merge on a neighbouring plane). For every label that
    has more than one fragment this function:

    1. removes every fragment with ``<= min_size`` pixels, and
    2. removes every fragment smaller than the largest fragment (fragments
       tied for largest are all kept).

    If every fragment is ``<= min_size`` the whole label disappears. Labels
    made of a single fragment are left untouched. Pixels of OTHER labels are
    never modified, even when they lie inside the bounding box of the label
    being processed.

    Parameters
    ----------
    mask_2D : ndarray of int
        Label image (0 = background). Modified IN PLACE.
    min_size : int
        Fragments with ``<= min_size`` pixels are always removed.
    pxl_connect : int or None
        Connectivity used to split a label into fragments, with the same
        meaning as ``skimage.measure.label(..., connectivity=...)``
        (1 = 4-neighbourhood, 2 = 8-neighbourhood in 2D; None = full).

    Returns
    -------
    ndarray
        ``mask_2D`` itself (same object, modified in place).
    """
    connectivity = mask_2D.ndim if pxl_connect is None else pxl_connect
    structure = ndi.generate_binary_structure(mask_2D.ndim, connectivity)

    for index, bbox in enumerate(ndi.find_objects(mask_2D)):
        if bbox is None:  # this label value does not occur in the image
            continue
        label_id = index + 1

        # View into mask_2D: assignments below edit mask_2D directly.
        object_i = mask_2D[bbox]
        fragments, n_fragments = ndi.label(object_i == label_id, structure=structure)
        if n_fragments <= 1:
            continue

        fragment_sizes = np.bincount(fragments.ravel())
        drop = fragment_sizes <= min_size
        drop |= fragment_sizes < fragment_sizes[1:].max()
        # Index 0 counts every bbox pixel that is NOT this label (background or
        # other objects); it must never be selected, otherwise neighbouring
        # labels inside the bounding box would be erased.
        drop[0] = False
        object_i[drop[fragments]] = 0

    return mask_2D

In [None]:
## Adapted from skimage's _props_to_dict: https://github.com/scikit-image/scikit-image/blob/main/skimage/measure/_regionprops.py#L869-L1161

# Column dtype for each known region property.
COL_DTYPES = {
    'area': int,
    'bbox': int,
    'bbox_area': int,
    'moments_central': float,
    'centroid': float,
    'convex_area': int,
    'convex_image': object,
    'coords': object,
    'eccentricity': float,
    'equivalent_diameter': float,
    'euler_number': int,
    'extent': float,
    'feret_diameter_max': float,
    'filled_area': int,
    'filled_image': object,
    'moments_hu': float,
    'image': object,
    'inertia_tensor': float,
    'inertia_tensor_eigvals': float,
    'intensity_image': object,
    'label': int,
    'local_centroid': float,
    'major_axis_length': float,
    'max_intensity': int,
    'mean_intensity': float,
    'min_intensity': int,
    'minor_axis_length': float,
    'moments': float,
    'moments_normalized': float,
    'orientation': float,
    'perimeter': float,
    'slice': object,
    'solidity': float,
    'weighted_moments_central': float,
    'weighted_centroid': float,
    'weighted_moments_hu': float,
    'weighted_local_centroid': float,
    'weighted_moments': float,
    'weighted_moments_normalized': float
}

# Properties stored as one object per region (never raveled into columns).
OBJECT_COLUMNS = {
    'image', 'coords', 'convex_image', 'slice',
    'filled_image', 'intensity_image'
}


def _fallback_column_dtype(value):
    """Best-effort column dtype for a property that is not listed in COL_DTYPES.

    Numeric (bool/int/float/complex) values keep their numpy dtype; anything
    else (strings, ragged sequences, arbitrary objects) is stored as object.
    """
    try:
        value_dtype = np.asarray(value).dtype
    except (TypeError, ValueError):  # e.g. ragged sequences
        return object
    return value_dtype if value_dtype.kind in 'biufc' else object


def skimage_props_to_dict(regions, properties=('label', 'bbox'), separator='-'):
    """Convert a list of region properties into a column dictionary.

    Scalar and object properties become one column named after the property.
    Array-valued properties (e.g. 'bbox') are raveled into one column per
    element, named ``prop + separator + index`` (e.g. 'bbox-0', 'bbox-1').

    Parameters
    ----------
    regions : sequence
        Region objects (e.g. from ``skimage.measure.regionprops``) supporting
        both ``getattr(region, prop)`` and ``region[prop]``.
    properties : iterable of str
        Property names to extract.
    separator : str
        Joins a property name and element index for array properties.

    Returns
    -------
    dict of str -> ndarray
        Each value has length ``len(regions)``. If ``regions`` is empty, each
        requested property maps to an empty array (array properties are not
        expanded because their shape cannot be known).
    """
    out = {}
    n = len(regions)
    if n == 0:
        return {prop: np.zeros(0, dtype=COL_DTYPES.get(prop, float)) for prop in properties}

    for prop in properties:
        first_value = getattr(regions[0], prop)
        if prop in COL_DTYPES:
            dtype = COL_DTYPES[prop]
        else:
            dtype = _fallback_column_dtype(first_value)
        column_buffer = np.zeros(n, dtype=dtype)

        # Scalars and objects get a single column per property;
        # array properties are raveled into one column per element.
        if np.isscalar(first_value) or prop in OBJECT_COLUMNS or np.dtype(dtype).kind == 'O':
            for i in range(n):
                column_buffer[i] = regions[i][prop]
            out[prop] = np.copy(column_buffer)
        else:
            if isinstance(first_value, np.ndarray):
                shape = first_value.shape
            else:
                shape = (len(first_value),)

            for ind in np.ndindex(shape):
                loc = ind if len(ind) > 1 else ind[0]
                for k in range(n):
                    column_buffer[k] = regions[k][prop][loc]
                modified_prop = separator.join(map(str, (prop,) + ind))
                out[modified_prop] = np.copy(column_buffer)
    return out

In [None]:
def load_channel_stack_for_napari(channel_name_to_load, base_folder, missing, crop_im = True,
                                  y_start=None, y_end=None, x_start=None, x_end=None):
    """Load, normalise and lightly blur the 3D image stack of one metal channel.

    Steps:
    1. Read the stack located via ``image_filepath_for_3D_stack`` in
       ``base_folder + "/" + channel_name_to_load``.
    2. Optionally insert an interpolated plane at index ``missing``.
    3. Min-max rescale every plane to 0..65535 (cv2.NORM_MINMAX, 16-bit output)
       and apply a 3x3 Gaussian blur with sigma 1.
    4. Optionally crop every plane to rows ``y_start:y_end`` and columns
       ``x_start:x_end``.

    Parameters
    ----------
    channel_name_to_load : str
        Name of the channel sub-folder inside ``base_folder``.
    base_folder : str
        Folder holding one sub-folder per channel.
    missing : int or None
        If not None, a new plane is inserted at this index. It is the mean of
        its two neighbours, i.e. the loaded planes ``missing - 1`` and
        ``missing``. Must satisfy ``1 <= missing <= n_planes - 1``.
    crop_im : bool
        Crop the stack to the given bounds. When True, at least one bound must
        be given.
    y_start, y_end, x_start, x_end : int or None
        Crop bounds with Python slice semantics (None = open end).

    Returns
    -------
    ndarray
        The processed stack. NOTE(review): the normalised planes are written
        back into the loaded array, so the result keeps the TIFF's dtype; the
        0..65535 values only survive if that dtype can hold them — confirm the
        input dtype.

    Raises
    ------
    ValueError
        If ``missing`` is out of range, or ``crop_im`` is True but no crop
        bound is given.
    """
    metal_folder = base_folder + "/" + channel_name_to_load
    image_path1 = image_filepath_for_3D_stack(metal_folder)
    image1 = imread(image_path1, pattern=None)

    if missing is not None:
        n_planes = image1.shape[0]
        if not 1 <= missing <= n_planes - 1:
            raise ValueError(
                "missing must be between 1 and %d, got %r" % (n_planes - 1, missing)
            )
        # After np.insert the new plane sits at index `missing`, between the
        # loaded planes missing-1 and missing, so those are its neighbours.
        missing_slice_image = np.mean(np.array([image1[missing - 1], image1[missing]]), axis=0)
        image1 = np.insert(image1, missing, missing_slice_image, axis=0)

    for i in range(image1.shape[0]):
        # CV_16U output is already limited to 0..65535, so no extra clipping is needed.
        tmp_im = cv2.normalize(image1[i, :, :], None, alpha=0, beta=65535,
                               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_16U)
        image1[i, :, :] = cv2.GaussianBlur(tmp_im, (3, 3), 1)

    if crop_im:
        if all(bound is None for bound in (y_start, y_end, x_start, x_end)):
            raise ValueError(
                "crop_im=True requires at least one of y_start, y_end, x_start, x_end"
            )
        image1 = image1[:, y_start:y_end, x_start:x_end]

    print('Max pixel value:', np.max(image1))
    print('Median pixel value:', np.percentile(image1, 50))
    return image1

In [None]:
# Following functions adapted from https://github.com/BodenmillerGroup/ImcPluginsCP/blob/d14624f2f47bd0c745b5b4345d440d8cc103b563/plugins/correctspillovermeasurements.py
import scipy.optimize as spo


def compensate_ls(dat, sm):
    """Spillover-compensate ``dat`` by unconstrained least squares.

    Solves ``comp @ sm = dat`` for ``comp``. Rows of ``dat`` are measured
    objects, columns are channels; ``sm`` is the spillover matrix whose
    columns correspond to the columns of ``dat``.
    """
    compdat = np.linalg.lstsq(sm.T, dat.T, rcond=None)[0]
    return compdat.T


def compensate_nnls(dat, sm):
    """Spillover-compensate ``dat`` row by row with non-negative least squares.

    Same system as ``compensate_ls`` but every compensated value is >= 0.
    """
    def nnls(x):
        return spo.nnls(sm.T, x)[0]

    return np.apply_along_axis(nnls, 1, dat)


def compensate_dat(dat, sm, method):
    """Compensate measurements by solving ``comp @ sm = dat`` for ``comp``.

    Parameters
    ----------
    dat : 2D ndarray
        Measurements; rows are objects, columns are channels.
    sm : 2D ndarray
        Spillover matrix whose columns match the columns of ``dat``.
    method : {'METHOD_LS', 'METHOD_NNLS'}
        Unconstrained or non-negative least squares.

    Returns
    -------
    ndarray
        Float copy of ``dat`` with compensated values. Rows containing any
        non-finite value are set entirely to NaN. If no row is fully finite,
        ``dat`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    solvers = {'METHOD_LS': compensate_ls, 'METHOD_NNLS': compensate_nnls}
    if method not in solvers:
        raise ValueError(
            "Unknown compensation method %r; expected one of %s" % (method, sorted(solvers))
        )

    # Only compensate rows (objects) whose measurements are all finite.
    fil = np.all(np.isfinite(dat), 1)
    if not fil.any():
        # Don't compensate if there are no valid rows.
        return dat

    # Float copy: compensated values are non-integer and NaN must be storable.
    if np.issubdtype(dat.dtype, np.floating):
        compdat = dat.copy()
    else:
        compdat = dat.astype(float)

    compdat[fil, :] = solvers[method](dat[fil, :], sm)
    # Rows with any non-finite value are set to NaN.
    compdat[~fil, :] = np.nan
    return compdat

In [None]:
# Function adapted from https://github.com/BodenmillerGroup/ImcPluginsCP/blob/master/plugins/smoothmultichannel.py
def clip_hot_pixels(img, hp_filter_shape, hp_threshold):
    """Replace isolated hot pixels by the maximum of their neighbourhood.

    A pixel is "hot" when it exceeds the maximum of its neighbourhood (a
    ``hp_filter_shape`` window that excludes the pixel itself) by more than
    ``hp_threshold``. Hot pixels are set to that neighbourhood maximum; all
    other pixels are unchanged.

    Parameters
    ----------
    img : 2D ndarray
        Image to clean; not modified.
    hp_filter_shape : sequence of two odd ints
        Neighbourhood window size, e.g. (3, 3).
    hp_threshold : number
        Minimum excess over the neighbourhood maximum to count as hot.

    Returns
    -------
    ndarray
        Cleaned copy of ``img`` with the same dtype.

    Raises
    ------
    ValueError
        If either window dimension is even.
    """
    if hp_filter_shape[0] % 2 != 1 or hp_filter_shape[1] % 2 != 1:
        raise ValueError(
            "Invalid hot pixel filter shape: %s" % str(hp_filter_shape)
        )
    hp_filter_footprint = np.ones(hp_filter_shape)
    # Exclude the centre so each pixel is compared with its neighbours only.
    hp_filter_footprint[hp_filter_shape[0] // 2, hp_filter_shape[1] // 2] = 0
    max_img = ndi.maximum_filter(
        img, footprint=hp_filter_footprint, mode="reflect"
    )
    # Subtract in floating point: for unsigned images (e.g. uint16),
    # img - max_img wraps around for every pixel darker than its neighbours
    # and would flag it as hot.
    hp_mask = img.astype(np.float64) - max_img > hp_threshold
    img = img.copy()
    img[hp_mask] = max_img[hp_mask]
    return img

In [None]:
# INPUT: folder with the registered single-channel TIFFs of the whole 3D model,
# i.e. one image per slice.
# os.path.expanduser resolves the leading '~' (os.listdir, open and imread
# take the path literally and do not expand it).
input_base = os.path.expanduser('~/3D_model201710/3D_registred_tiffs/IMC_fullStack_registred/imageJ_registration/full_model_aligned/')


# INPUT: stack of single-channel TIFFs per slice for each metal channel to which
# the transformation will be applied. The TIFFs of each channel are in their own folder.
stack_registred = input_base + 'SIMILARITY10_Nd148'
labels_name = 'segmentation_hwatershed_500.00_90%.tif'
initial_labels = input_base + labels_name


# Overlapping area of the image stack, i.e. the area of the full 3D model used for downstream analysis.
row_start = 448 #y
row_end = 936 #y
col_start = 386 #x
col_end = 1038 #x

# CSV panel used to replace metal names with antibody target names for downstream analysis
csv_pannel = os.path.expanduser("~/model201710_panel.csv")
csv_pannel_new_name = 'clean_target'
csv_pannel_metal = 'Metal Tag'


# Per-metal properties to measure; the volume ('area') of each object is measured separately.
# regionprops returns regions in ascending label order (author's note).
properties_to_measure_for_metals = ['mean_intensity']
additional_properties_to_measure = ('label', 'area')

# Data type of the intensity images to measure
dtype_to_measure = np.uint16

# Channels not used for downstream data analysis
channels_not_to_include = ['Y89','Ru96','Ru98','Ru99','Ru100', 'Ru101', 'Ru102', 'Ru104', 'Ir191','In115',
                           'La139','Nd145', 'Gd155', 'Gd158','Dy163','Er170','Eu151', 'Sm154', 'Er166', 'Tm169', 'Dy161']

new_lables_name = 'final' + '_' + labels_name
measured_mask_name = 'measured_mask' + '_' + new_lables_name
# Depends on the number of objects found in the segmentation output
dtype_for_cell_labels = 'uint32'
# Output table names
intensities_table_name = input_base + 'model201710_mean_intensities.csv'
extra_properties_table_name = input_base + 'model201710_labels_area.csv'

##### 1.1 Load a nuclear and a cytoplasmic channel and the segmentation mask. Remove 3D objects smaller than 11 voxels (i.e. 10 voxels or fewer).

In [None]:
# Nuclear (Ir191) and cytoplasmic (Pr141) channels, uncropped, for visual QC.
channel_1 = 'Ir191'
channel_2 = 'Pr141'
ir_im_stack = load_channel_stack_for_napari(channel_1, stack_registred, missing=None, crop_im=False)
cyto_im_stack = load_channel_stack_for_napari(channel_2, stack_registred, missing=None, crop_im=False)



In [None]:
mask_im_stack_ini = imread(initial_labels, pattern = None)
napari_name_seg = 'segmentation'

# Keep every second plane (indices 0, 2, 4, ...) along axis 1 of the 4-D label
# stack. (n + 1) // 2 is exactly the number of even indices in range(n);
# round(n / 2) uses banker's rounding and would drop the last even plane for
# n = 5, 9, 13, ... (e.g. round(2.5) == 2).
n_planes_ini = mask_im_stack_ini.shape[1]
mask_im_stack = np.zeros(
    ((n_planes_ini + 1) // 2, mask_im_stack_ini.shape[2], mask_im_stack_ini.shape[3]),
    dtype=dtype_for_cell_labels,
)

for i in range(mask_im_stack.shape[0]):
    # Axis 0 of the raw stack is broadcast away; presumably it has length 1 — TODO confirm.
    mask_im_stack[i, :, :] = mask_im_stack_ini[:, i * 2, :, :]


In [None]:
# Dimensions of the raw label stack (indexed above as [:, plane, row, col]).
np.shape(mask_im_stack_ini)

In [None]:
# Drop 3D objects with fewer than 11 voxels (remove_small_objects removes
# objects strictly smaller than min_size, i.e. size < 11, not <= 11).
small_removed = remove_small_objects(mask_im_stack, connectivity=1, min_size=11)

##### 1.2 Remove small disconnected objects and small objects 

In [None]:
improved_mask = np.zeros(mask_im_stack.shape, dtype=mask_im_stack.dtype)
for k in range(small_removed.shape[0]):
    # NOTE: slice_2D is a view, so remove_small_disconnected_objects also edits small_removed in place.
    slice_2D = small_removed[k, :, :]
    # Keep the largest fragment of each label; fragments of <= 7 pixels are always removed.
    disconnect_removed = remove_small_disconnected_objects(slice_2D, pxl_connect=1, min_size=7)
    # Then drop whole 2D objects with fewer than 4 pixels (strict <).
    small_slice_removed = remove_small_objects(disconnect_removed, connectivity=1, min_size=4)
    # Grow every label by one pixel without overlapping its neighbours.
    improved_mask[k, :, :] = expand_labels(small_slice_removed, 1)

# Finally drop 3D objects with fewer than 21 voxels (strict <).
improved_mask = remove_small_objects(improved_mask, connectivity=1, min_size=21)


In [None]:
# Remove objects whose mean Ir191 (iridium) signal is below 1.
# The Ir191 stack is clipped at its 99th percentile before measuring.
object_prop_dict = dict()
# A list, not a string: `name in 'Ir191'` would be a substring test.
channels_to_include = ['Ir191']
min_iridium_mean_intensity = 1

measured_labels = None
for file_name in os.listdir(stack_registred):
    if file_name in channels_to_include:
        im_files = os.path.join(stack_registred, file_name)
        img_list = image_filepath_for_3D_stack(im_files)
        metal_im = imread(img_list, pattern=None)
        metal_im = np.array(metal_im, dtype=dtype_to_measure)
        metal_im = np.clip(metal_im, 0, np.percentile(metal_im, 99))
        object_prop = regionprops(improved_mask, intensity_image=metal_im)
        # Row i of the table below describes object_prop[i]. Keep each region's
        # label so rows can be mapped back to mask values: labels have gaps
        # (and start at 1), so a row position is NOT a label.
        measured_labels = np.array([region.label for region in object_prop])
        out_dict = skimage_props_to_dict(object_prop, properties=properties_to_measure_for_metals)
        for prop_to_measure in properties_to_measure_for_metals:
            new_key = file_name + '_' + prop_to_measure
            out_dict[new_key] = out_dict.pop(prop_to_measure)

        object_prop_dict.update(out_dict)

if measured_labels is None:
    raise FileNotFoundError('No Ir191 channel folder found in ' + stack_registred)

props_table = pd.DataFrame(object_prop_dict)
ir_mean = props_table['Ir191_mean_intensity'].to_numpy()
low_iridium_signal_objects = measured_labels[ir_mean < min_iridium_mean_intensity]

# Zero all low-signal objects in one vectorised pass.
updated_mask = improved_mask.copy()
updated_mask[np.isin(improved_mask, low_iridium_signal_objects)] = 0

# Separate variable so re-running the cell does not prepend input_base twice.
new_labels_path = input_base + new_lables_name
imsave(new_labels_path, updated_mask)

In [None]:
scaling_factors = [2, 1, 1]
# Display only the overlapping area bounded by row_start/row_end and col_start/col_end.
overlap = (slice(None), slice(row_start, row_end), slice(col_start, col_end))
# NOTE(review): napari.gui_qt() is deprecated in recent napari releases — confirm against the installed version.
with napari.gui_qt():
    viewer = napari.view_image(ir_im_stack[overlap], name=channel_1, scale=scaling_factors)
    viewer.add_image(cyto_im_stack[overlap], name=channel_2, scale=scaling_factors)
    viewer.add_labels(updated_mask[overlap], name=napari_name_seg, scale=scaling_factors)

#### 2. Measure intensities for all channels in the model

In [None]:
# WARNING: this measures only the intensities for objects in the overlapping area
mask_to_measure = updated_mask[:, row_start:row_end, col_start:col_end]

# Objects cut by the crop can shrink; drop those now below 21 voxels (strict <).
mask_to_measure = remove_small_objects(mask_to_measure, min_size=21, connectivity=1)

mask_to_measure_name = input_base + measured_mask_name
imsave(mask_to_measure_name, mask_to_measure)

object_prop_dict = dict()

# sorted(): os.listdir order is arbitrary; this makes the column order reproducible.
for file_name in sorted(os.listdir(stack_registred)):
    im_files = os.path.join(stack_registred, file_name)
    # Skip excluded channels and stray files: every channel lives in its own folder.
    if file_name in channels_not_to_include or not os.path.isdir(im_files):
        continue
    img_list = image_filepath_for_3D_stack(im_files)
    metal_im = imread(img_list, pattern=None)
    metal_im = np.array(metal_im, dtype=dtype_to_measure)
    metal_im = metal_im[:, row_start:row_end, col_start:col_end]

    # Clip hot pixels plane by plane. The planes are passed as float64 so the
    # neighbourhood comparison inside clip_hot_pixels cannot wrap around for
    # unsigned integers.
    metal_im_smoothed = np.zeros(metal_im.shape)
    for layer in range(metal_im.shape[0]):
        metal_im_smoothed[layer] = clip_hot_pixels(metal_im[layer, :, :].astype(np.float64), [3, 3], 50)

    # Measure on the hot-pixel-clipped stack. Regions come back in label order;
    # include 'label' in properties_to_measure_for_metals to verify if this code changes.
    object_prop = regionprops(mask_to_measure, intensity_image=metal_im_smoothed)
    out_dict = skimage_props_to_dict(object_prop, properties=properties_to_measure_for_metals)
    for prop_to_measure in properties_to_measure_for_metals:
        new_key = file_name + '_' + prop_to_measure
        out_dict[new_key] = out_dict.pop(prop_to_measure)

    object_prop_dict.update(out_dict)

In [None]:
# Map metal tags (the channel folder names, e.g. 'Ir191') to antibody target names.
names = {}
with open(os.path.expanduser(csv_pannel), 'r', newline='') as NN:
    reader = csv.DictReader(NN)
    for row in reader:
        names[row[csv_pannel_metal]] = row[csv_pannel_new_name]

# Strip the '_mean_intensity' suffix so every column is named by its metal only.
# A new dict is built so object_prop_dict stays untouched and the cell can be re-run.
mean_suffix = '_mean_intensity'
updated_measures = {}
for entry, values in object_prop_dict.items():
    if entry.endswith(mean_suffix):
        entry = entry[:-len(mean_suffix)]
    updated_measures[entry] = values

props_table = pd.DataFrame(updated_measures)

In [None]:
# pandas expands '~' itself.
ori_comp_matrix = '~/compensationMatrix.csv'
compensation_matrix = pd.read_csv(ori_comp_matrix, index_col=0)

# Channels measured AND present in the spillover matrix (in the matrix's column
# order), and measured channels the matrix does not cover.
metals_in_both = [m for m in compensation_matrix.columns if m in props_table.columns]
metals_missing_from_cm = [m for m in props_table.columns if m not in compensation_matrix.columns]

# Block spillover matrix [[S, 0], [0, I]]. S is the sub-matrix for the shared
# channels; its rows are selected by name so row and column order always agree.
# Channels without spillover data get an identity block (left uncompensated).
n_shared = len(metals_in_both)
cm_final = np.identity(n_shared + len(metals_missing_from_cm))
cm_final[:n_shared, :n_shared] = compensation_matrix.loc[metals_in_both, metals_in_both].to_numpy()

new_order = metals_in_both + metals_missing_from_cm
props_table_ordered = props_table[new_order]
uncompensated_data = props_table_ordered.values
comp_data = compensate_dat(uncompensated_data, cm_final, 'METHOD_NNLS')

props_table_compensated = pd.DataFrame(comp_data, columns=props_table_ordered.columns.values)
props_table_compensated = props_table_compensated.rename(columns=names)
props_table_compensated.to_csv(intensities_table_name, index=False)

In [None]:
# Per-object label and size (number of voxels) of the measured mask.
props_table = pd.DataFrame(
    regionprops_table(mask_to_measure, properties=additional_properties_to_measure)
)
props_table.to_csv(extra_properties_table_name, index=False)

#### Measure the major and minor axes of each object in the xy and z directions to check whether the segmentation is biased in the x-y direction

In [None]:
# For every object, record the largest average axis length, (minor + major) / 2,
# found on any single xy plane. Objects touching the 3D border are discarded first.
object_diameter_dict = dict()
objects_to_measure = clear_border(mask_to_measure.copy())

n_planes = objects_to_measure.shape[0]
# The first and last planes are skipped.
for plane in range(1, n_planes - 1):
    # Iterating the regions directly also copes with planes that contain no objects.
    for region in regionprops(objects_to_measure[plane, :, :]):
        # nan_to_num maps NaN axis lengths to 0 (comparing a float to the
        # string 'NaN' never matches).
        average_axis = (np.nan_to_num(region.minor_axis_length)
                        + np.nan_to_num(region.major_axis_length)) / 2
        best_so_far = object_diameter_dict.get(region.label, average_axis)
        object_diameter_dict[region.label] = max(best_so_far, average_axis)

diameter_table = pd.DataFrame.from_dict(object_diameter_dict, 'index', columns=['average_axis_xy'])
diameter_table = diameter_table.reset_index()
diameter_table = diameter_table.rename(columns={'index': 'label'})

In [None]:
# Attach the per-object maximum average xy axis to the label/area table (left join on label).
extra_props_table = props_table.merge(
    diameter_table, how='left', on='label', sort=True, suffixes=('_x', '_y')
)

In [None]:
image1_shape = mask_to_measure.shape

# Duplicate every plane except the last: plane i fills planes 2i and 2i+1 and the
# last plane fills plane 2(N-1), giving 2N-1 planes in total.
double_stack = np.repeat(mask_to_measure, 2, axis=0)[:-1].astype(mask_im_stack.dtype)

# Swap axes 0 and 1 so each index along axis 0 selects one row (y) position,
# i.e. a (z, x) section of the doubled stack.
axis_changed_lables = np.swapaxes(double_stack.copy(), 0, 1)
axis_changed_lables.shape

In [None]:
# Same measurement as for xy, but on the (z, x) sections of the doubled stack:
# largest average axis length, (minor + major) / 2, per object.
object_diameter_dict = dict()
axis_changed_lables = clear_border(axis_changed_lables)

# Skip the first and last section.
for section_index in range(1, axis_changed_lables.shape[0] - 1):
    section = axis_changed_lables[section_index, :, :]
    # Iterating the regions directly also copes with sections without objects.
    for region in regionprops(section):
        # NaN axis lengths count as 0 (a string comparison with 'NaN' never matches a float).
        minor = np.nan_to_num(region.minor_axis_length)
        major = np.nan_to_num(region.major_axis_length)
        average_axis = (minor + major) / 2
        if average_axis > object_diameter_dict.get(region.label, -np.inf):
            object_diameter_dict[region.label] = average_axis

diameter_table = pd.DataFrame.from_dict(object_diameter_dict, 'index', columns=['average_axis_z'])
diameter_table = diameter_table.reset_index()
diameter_table = diameter_table.rename(columns={'index': 'label'})

In [None]:
# Add the per-object maximum average z-section axis (left join on label).
final_extra_props = extra_props_table.merge(
    diameter_table, how='left', on='label', sort=True, suffixes=('_x', '_y')
)

In [None]:
# Average xy axis vs. average z axis per object; points on the y = x line are unbiased.
fig_scatter = final_extra_props.plot.scatter(
    x='average_axis_xy', y='average_axis_z', s=0.4, c='DarkGray'
)
fig_scatter.set_xlim([0, 18])
fig_scatter.set_ylim([0, 18])

# y = x reference line over the plotted range.
s = pd.Series(range(19))
s.plot.line(ax=fig_scatter, color='black', linewidth=0.5)

In [None]:
fig = fig_scatter.get_figure()
# Expand '~' explicitly rather than relying on the image writer to do it.
fig.savefig(os.path.expanduser('~/figures/xy_vs_z_average_axis_model201710.png'))

###### End of notebook