In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as ndi
from skimage.io import imread
import skimage.measure as skimeas
from os.path import join
from mayavi import mlab
from skimage import measure
%matplotlib inline  

In [2]:
filename = r'cells.tif'   
dirpath = r'' 
filepath = join(dirpath, filename) #final path to load image
img = imread(filepath) #loading image           

In [3]:
def myplt(img, title):
    """This function displays the input image with interpolation='none', cmap='magma' in the range of 0-255 (8-bits). """
    plt.imshow(img,interpolation='none',cmap='magma', vmin=0, vmax=255)
    plt.colorbar(fraction=0.046, pad=0.04) 
    plt.title(title)

In [4]:
# Function to plot pairs of images
def pltPair(img1, img2, title1, title2, cmap1, cmap2, lim):
    """This function displays a pair of input image with interpolation='none',in the range of lims. """

    plt.figure(figsize=(16,16))
    plt.subplot(1,2,1)
    plt.imshow(img1, interpolation='none', cmap=cmap1, vmin=lim[0], vmax=lim[1])
    plt.title(title1)
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.subplot(1,2,2)
    plt.imshow(img2, interpolation='none', cmap=cmap2, vmin=lim[2], vmax=lim[3])
    plt.title(title2)
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.show()

In [5]:
# Plot image and diplay image features
imgDim = img.shape
nPlanes = imgDim[0]
from ipywidgets import interact
@interact (z=(0,nPlanes-1,1))
def plot_slides (z):
  myplt(img[z,:,:], "Membranes")

print('Variable Type: ', type(img))
print('Image data type: ', img.dtype)
print('Image dimension: ', img.shape)
plt.figure(figsize=(20,20))
plt.show()

interactive(children=(IntSlider(value=153, description='z', max=306), Output()), _dom_classes=('widget-interac…

Variable Type:  <class 'numpy.ndarray'>
Image data type:  uint16
Image dimension:  (307, 786, 712)


<Figure size 1440x1440 with 0 Axes>

In [6]:
from datetime import datetime
import time
start_time = time.time()
# Create an image border mask

border_mask = np.zeros(img.shape, dtype=bool)
border_mask = ndi.binary_dilation(border_mask, border_value=1)

# Remove the cells at the border

# Create a copy of the initial result
clean_img = np.copy(img)

for cell_ID in np.unique(img):

    # Create a mask that contains only the 'current' cell of the iteration
    cell_mask = img ==cell_ID 
    
    # Use the cell mask and the border mask to test if the cell has pixels touching 
    # the image border or not.
    cell_border_overlap = np.logical_and(cell_mask, border_mask)  # Overlap of cell mask and boundary mask
    total_overlap_pixels = np.sum(cell_border_overlap)            # Sum overlapping pixels

    # If a cell touches the image boundary, delete it by setting its pixels in the segmentation to 0.
    if total_overlap_pixels > 0: 
        clean_img[cell_mask] = 0

# Re-label the remaining cells to keep the numbering consistent from 1 to N (with 0 as background).

for new_ID, cell_ID in enumerate(np.unique(clean_img)[1:]):  # The [1:] excludes 0 from the list (background)!
    clean_img[clean_img==cell_ID] = new_ID+1                  # The same here for the +1
        
nCells = clean_img.max();
print(str(nCells), ' cells detected after removing the ones at the border')
print("--- %s seconds ---" % (time.time() - start_time))

249  cells detected after removing the ones at the border
--- 409.03640508651733 seconds ---


In [7]:
@interact(z=(1,nPlanes-1,1))
def plot_slides(z):
  pltPair(img[z,:,:], clean_img[z,:,:], 'Cells initial', 'Cells final', 'jet', 'jet', [0,255,0,nCells])
  plt.show()

interactive(children=(IntSlider(value=153, description='z', max=306, min=1), Output()), _dom_classes=('widget-…

In [13]:
# Create the 3D reconstruction of the cells
start_time = time.time()
for cell_ID in np.unique(clean_img)[1:nCells]:#here we can choose how many cells we want to calculate
    cell_mask = clean_img==cell_ID
   
    verts, faces, normals, values = measure.marching_cubes(cell_mask, 0.0,step_size=2)

    thefile = open('Cell_reconstruction' + str(cell_ID) +'.txt', 'w')
    for item in verts:
          thefile.write("v {0} {1} {2}\n".format(item[0],item[1],item[2]))

    for item in normals:
          thefile.write("vn {0} {1} {2}\n".format(item[0],item[1],item[2]))

    for item in faces:
          thefile.write("f {0} {1} {2}\n".format(item[0],item[1],item[2]))  

    thefile.close()
print("--- %s seconds ---" % (time.time() - start_time))

--- 4.445544004440308 seconds ---
