In [None]:
# import modules
from __future__ import division
import os
import numpy as np
import skimage.io as io
import scipy.ndimage as ndi
import sympy.geometry as sp
import matplotlib.pyplot as plt
import scipy.spatial as spatial

from sympy import Point3D
from sympy.core.cache import *
from skimage.morphology import disk
from scipy.spatial.distance import cdist
from skimage.segmentation import find_boundaries
%matplotlib inline

def real_space_convert(px_points, z_px_res, x_px_res, y_px_res):
    """Scale pixel coordinates to real-space units.

    Parameters
    ----------
    px_points : array_like, shape (N, 3) or (3,)
        Coordinates in pixels. Column 0 is multiplied by ``z_px_res``,
        column 1 by ``x_px_res`` and column 2 by ``y_px_res``.
    z_px_res, x_px_res, y_px_res : float
        Size of one pixel along columns 0, 1 and 2 respectively.

    NOTE: the parameter names do not follow the column order used in this
    notebook. Callers pass ``(points, z_res, y_res, x_res)`` positionally for
    (z, y, x) coordinates, so column 1 is effectively scaled by the y
    resolution and column 2 by the x resolution. Keep passing positionally.

    Returns
    -------
    numpy.ndarray of float, same shape as ``px_points``.
    """
    # Always compute in floating point. Allocating the result with the input's
    # dtype would silently truncate the scaled values whenever the input holds
    # integer pixel indices (e.g. from np.where / np.argwhere).
    scale = np.array([z_px_res, x_px_res, y_px_res], dtype=float)
    return np.asarray(px_points, dtype=float) * scale

In [None]:
# load mT mask, lumen mask and lumen boundaries
#
# The data directory can be overridden with the ICM_DATA_DIR environment
# variable so the notebook is not tied to one machine's mount point; the
# original location is kept as the default.

fpath = os.environ.get('ICM_DATA_DIR', r'/Volumes/TOSHIBA EXT/Figure 6/20181119 most successful trial/injected 20181210')
fname_mT_mask = r'mT mask AQ_W0001_P0001.tif'
fname_lumen = r'lumen mT mask AQ_W0001_P0001.tif'
fname_lumen_boundaries = r'W0001_P0001 lumen boundaries.npy'

mT_mask = io.imread(os.path.join(fpath,fname_mT_mask))
# the lumen mask is used as a boolean mask everywhere below
lumen = io.imread(os.path.join(fpath,fname_lumen)).astype(bool)
lumen_boundaries = np.load(os.path.join(fpath,fname_lumen_boundaries))

# Pixel size along z, y and x. Presumably micrometres per pixel (arrays
# converted with these values are suffixed _um) -- confirm against the
# acquisition metadata.
z_res = 2
y_res = 0.2306294
x_res = 0.2306294

In [None]:
# check shape of inputs and reshape if necessary
print 'mT mask', mT_mask.shape
#mT_mask = np.rollaxis(mT_mask, ---)
#print 'mT mask reshaped', mT_mask.shape

In [None]:
print 'lumen', lumen.shape
#lumen = np.rollaxis(lumen, ---)
#print 'lumen reshaped', lumen.shape

In [None]:
print 'lumen boundaries', lumen_boundaries.shape
#lumen_boundaries = np.rollaxis(lumen, ---)
#print 'lumen boundaries reshaped', lumen_boundaries.shape

In [None]:
# find center of lumen
lumen_centers = np.zeros((lumen.shape[0],3),dtype=np.float)
for timepoint in range(lumen.shape[0]):
    stack = lumen[timepoint,:,:,:]
    if stack.max() > 0:
        stack_center = ndi.measurements.center_of_mass(stack)
        lumen_centers[timepoint,:] = stack_center
    else:
        print 'timepoint ', timepoint, ' empty'
        continue
lumen_centers_um = real_space_convert(lumen_centers,z_res,y_res,x_res)

In [None]:
# smooth mT mask, fill holes, find outer boundaries, find center

mT_smooth = np.zeros_like(mT_mask)
mT_smooth_centers = np.zeros((mT_mask.shape[0],3),dtype=np.float)
mT_filled = np.zeros_like(mT_smooth, dtype = np.uint8)
mT_outer_boundaries = np.zeros_like(mT_filled, dtype = np.uint8)
embryo = np.zeros_like(mT_mask, dtype = np.uint8)


size = --- # for structuring element
se = disk(size) # create structurig element
iter_rep = ---


for timepoint in range(mT_mask.shape[0]):
    
    if lumen[timepoint,:,:,:].max() > 0:
    
        print 'timepoint', timepoint
        stack = mT_mask[timepoint,:,:,:].astype(np.uint8)
        for z_slice in range(stack.shape[0]):
            stack[z_slice,:,:] = ndi.binary_opening(stack[z_slice,:,:], structure = se, iterations = iter_rep)
        mT_smooth[timepoint,:,:] = stack
        
        stack_filled = stack + lumen[timepoint,:,:,:]
        stack_filled = stack_filled.astype(bool)
        for z_slice in range(stack.shape[0]):
            stack_filled[z_slice,:,:] = ndi.binary_fill_holes(stack_filled[z_slice,:,:], structure = se)
        
        labels, num_labels = ndi.label(stack_filled)
        sizes = ndi.measurements.sum(stack_filled,labels,index=range(num_labels+1))
        print num_labels, sizes.max(), 'embryo size'
    
        for index,size in enumerate(sizes):
            if size < sizes.max():
                labels[labels == index] = 0
        filt_sizes = ndi.measurements.sum(labels.astype(bool))
        print 'filt sizes', filt_sizes
        print 'embryo size retained', filt_sizes.max()
        
        stack_filled = labels.astype(bool)
        mT_filled[timepoint,:,:,:] = stack_filled
        mT_outer_boundaries[timepoint,:,:,:] = find_boundaries(stack_filled)
        
        cells = stack_filled.astype(np.uint8)
        cells[lumen[timepoint,:,:,:].astype(np.uint8) == 1] = 0
        embryo[timepoint,:,:,:] = cells
        
        stack_center = ndi.measurements.center_of_mass(cells)
        print stack_center
        mT_smooth_centers[timepoint,:] = stack_center
    
    else:
        print 'timepoint ', timepoint, ' empty'
        continue

mT_smooth_centers_um = real_space_convert(mT_smooth_centers,z_res,y_res,x_res)
print mT_smooth_centers_um

In [None]:
# find point in lumen boundary points 'perpendicular' to mask center
ICM_widths = np.zeros((mT_mask.shape[0],1), dtype = np.float)
L_boundary_pts = np.zeros_like(mT_smooth_centers)
O_boundary_pts = np.zeros_like(mT_smooth_centers)

for timepoint in range(mT_mask.shape[0]):
    L_stack = lumen_boundaries[timepoint,:,:,:]
    O_stack = mT_outer_boundaries[timepoint,:,:,:]
    
    if L_stack.max() > 0:
        zL,yL,xL = np.where(lumen_boundaries[timepoint,:,:,:])
        zL = np.reshape(zL,(len(zL),1))
        yL = np.reshape(yL,(len(yL),1))
        xL = np.reshape(xL,(len(xL),1))
        coordinatesL = np.concatenate((zL,yL,xL),axis = 1)
        print 'lumen boundaries space conversion', coordinatesL[1,:], 'px',
        coordinatesL_um = real_space_convert(coordinatesL,z_res,y_res,x_res)
        print coordinatesL_um[1,:], 'um'
        
        mT_smooth_center_um = np.reshape(mT_smooth_centers_um[timepoint,:],(1,3))
        
        all_distances = cdist(mT_smooth_center_um,coordinatesL_um)
        print all_distances.shape, 'all distances'
        min_dist = all_distances.min(axis = 1)
        min_dist_coordinates = np.where(min_dist == all_distances)
        print coordinatesL_um[min_dist_coordinates[1]].shape, 'IDed lumen boundary points'
        
        L_boundary_pt = coordinatesL_um[min_dist_coordinates[1]][0,:]
        L_boundary_pts[timepoint,:] = L_boundary_pt
        print L_boundary_pt, 'lumen boundary point'
        
        line = sp.line3d.Line3D(Point3D(lumen_centers_um[timepoint,:]),Point3D(mT_smooth_center_um[0,:]))
        
        zO,yO,xO = np.where(mT_outer_boundaries[timepoint,:,:,:])
        zO = np.reshape(zO,(len(zO),1))
        yO = np.reshape(yO,(len(yO),1))
        xO = np.reshape(xO,(len(xO),1))
        coordinatesO = np.concatenate((zO,yO,xO),axis = 1)
        print 'outer boundaries space conversion', coordinatesO[1,:], 'px',
        coordinatesO_um = real_space_convert(coordinatesO,z_res,y_res,x_res)
        
        print coordinatesO_um[1,:], 'um'
        
        hull = spatial.ConvexHull(coordinatesO_um)
        facets = hull.simplices
        
        for facet in range(facets.shape[0]):
            test_facet = facets[facet]
            point1 = Point3D(coordinatesO_um[test_facet[0],:])
            point2 = Point3D(coordinatesO_um[test_facet[1],:])
            point3 = Point3D(coordinatesO_um[test_facet[2],:])
            plane = sp.Plane(point1, point2, point3)
            
            p1 = np.array(point1)
            p2 = np.array(point2)
            p3 = np.array(point3)
            
            bounds = np.vstack((p1,p2,p3))
            
            z_min = bounds[:,0].min()
            y_min = bounds[:,1].min()
            x_min = bounds[:,2].min()
            z_max = bounds[:,0].max()
            y_max = bounds[:,1].max()
            x_max = bounds[:,2].max()
            
            if line.intersection(plane):
                O_boundary_pt = line.intersection(plane)[0]
                
                if (O_boundary_pt[0] >= z_min and O_boundary_pt[0] <= z_max and
                    O_boundary_pt[1] >= y_min and O_boundary_pt[1] <= y_max and
                    O_boundary_pt[2] >= x_min and O_boundary_pt[2] <= x_max ):
                    
                    if (O_boundary_pt.distance(Point3D(L_boundary_pt)) <
                        O_boundary_pt.distance(Point3D(lumen_centers_um[timepoint,:]))):
                    
                        O_boundary_pts[timepoint,:] = np.array([float(O_boundary_pt[0]),
                                                                float(O_boundary_pt[1]),
                                                                float(O_boundary_pt[2])])
                        print O_boundary_pts[timepoint,:], 'outer boundary point'
                        break
                
            else:
                clear_cache()
                continue
        
        ICM_widths[timepoint,:] = O_boundary_pt.distance(Point3D(L_boundary_pt))
        print ICM_widths[timepoint,:], 'ICM width'
    else:
        print 'timepoint ', timepoint, ' empty'
        continue

In [None]:
# Save all results next to the input data.
# TODO: replace every --:-- slice and --- suffix placeholder before running;
# give each file a distinct suffix, otherwise later saves overwrite earlier ones.
# filled embryo mask (largest connected component), one volume per timepoint
io.imsave(os.path.join(fpath, fname_lumen[--:--] + '---.tif'),mT_filled.astype(np.uint8))
# embryo mask with the lumen removed ("cells")
io.imsave(os.path.join(fpath, fname_lumen[--:--] + '---.tif'),embryo.astype(np.uint8))
# ICM width per timepoint, shape (T, 1), computed from the *_um coordinates
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'),ICM_widths)
# outer-boundary masks of the filled embryo
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'),mT_outer_boundaries)
# lumen centres of mass, shape (T, 3): converted units, then pixels
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'),lumen_centers_um)
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'),lumen_centers)
# centres of mass of the embryo minus lumen, shape (T, 3): converted units, then pixels
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'), mT_smooth_centers_um)
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'), mT_smooth_centers)
# outer-boundary and lumen-boundary end points of the ICM width, shape (T, 3), converted units
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'), O_boundary_pts)
np.save(os.path.join(fpath,   fname_lumen[--:--] + '---.npy'), L_boundary_pts)