This notebook is part of the `deepcell-tf` documentation: https://deepcell.readthedocs.io/.

# Nuclear segmentation and tracking

In [1]:
import os

import numpy as np

import imageio
import matplotlib.pyplot as plt

import cv2
from skimage import io

import tensorflow as tf
from tensorflow.keras import backend as K

## Nuclear Segmentation

### Initialize nuclear model

The application will download pretrained weights for nuclear segmentation. For more information about application objects, please see our [documentation](https://deepcell.readthedocs.io/en/master/API/deepcell.applications.html).

In [2]:
from deepcell.applications import NuclearSegmentation

# Instantiate the nuclear segmentation application; as described above, this
# downloads the pretrained weights.
app = NuclearSegmentation()

Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearSegmentation-3.tar.gz


## Use the application to generate labeled images

Typically, neural networks perform best on test data that is similar to the training data. In the realm of biological imaging, the most common difference between datasets is the resolution of the data measured in microns per pixel. The training resolution of the model can be identified using `app.model_mpp`.

In [6]:
# Show the resolution the model was trained at, so it can be compared with the
# resolution of the input data.
print('Training Resolution:', app.model_mpp, 'microns per pixel')

Training Resolution: 0.65 microns per pixel


The resolution of the input data can be specified in `app.predict` using the `image_mpp` option. The `Application` will rescale the input data to match the training resolution and then rescale to the original size before returning the labeled image.

In [None]:
# Load one 3D image stack from the validation set.
DATASET_DIR = '../../data/'
print(os.listdir(DATASET_DIR))
IMAGES_DIR = os.path.join(DATASET_DIR, 'val_data', '#3c NUCLEAR MORPHOLOGY - RAW')

# os.listdir returns entries in arbitrary order; sort them so the stack picked
# by position is the same on every run and on every machine.
images_names = sorted(os.listdir(IMAGES_DIR))
images = io.imread(os.path.join(IMAGES_DIR, images_names[1]))

# Rescale each slice to 8 bits for the network. The 255/65535 factor maps a
# 16-bit intensity range onto 0..255 (assumes the raw stack is 16-bit - confirm).
images = np.array([cv2.convertScaleAbs(image, alpha=(255.0/65535.0)) for image in images])
images = images.astype(np.uint8)

# Quick visual check on the middle slice of the stack.
plt.imshow(images[len(images)//2])
plt.show()

# Add a trailing channel axis: (slices, rows, cols) -> (slices, rows, cols, 1).
images = np.expand_dims(images, -1)
print(images.shape)

# Prediction. image_mpp is the resolution of the input data in microns per
# pixel; 0.6 is the value used for this dataset (confirm against acquisition metadata).
y_pred = app.predict(images, image_mpp=.6)

print(y_pred.shape)
plt.imshow(y_pred[len(y_pred)//2, ..., 0])
plt.show()

### Keep only the largest nucleus

In [None]:
# Count how many pixels each label occupies in the prediction.
labels, counts = np.unique(y_pred, return_counts=True)
print("Number of predicted labels: {}".format(len(labels)))

# Ignore the background (label 0) explicitly instead of assuming it is the
# first entry returned by np.unique.
is_foreground = labels != 0
foreground_labels = labels[is_foreground]
foreground_counts = counts[is_foreground]

if foreground_labels.size == 0:
    # Nothing was segmented: the filtered prediction is empty.
    reformat_pred = np.zeros_like(y_pred)
else:
    max_count_idx = np.argmax(foreground_counts)  # position of the largest object
    print("Maximum index: {}".format(max_count_idx))
    # Look the label VALUE up from its position. Using `position + 1` as the
    # label is only correct when the labels are exactly 0, 1, 2, ... without gaps.
    biggest_label = foreground_labels[max_count_idx]
    # Every other label is set to zero in the images.
    reformat_pred = np.where(y_pred == biggest_label, y_pred, 0)

In [None]:
# Display one slice of the filtered prediction, five slices before the middle
# of the stack.
slice_idx = len(y_pred)//2 - 5
plt.imshow(reformat_pred[slice_idx, ..., 0])
plt.show()

### Loop over the dataset

In [None]:
def pad_square(images):
    """
    Zero-pad a batch of images so that its two spatial axes have equal length.

    Parameters
    ----------
    images : array with four axes; axes 1 and 2 are the ones made equal.

    Returns
    -------
    (padded, pad_spec)
        `padded` is the zero-padded array, or the input object itself when the
        two axes already match. `pad_spec` is the per-axis [before, after] list
        handed to `np.pad`, or an empty list when nothing was padded.
    """
    _, size_1, size_2, _ = images.shape
    gap = abs(size_1 - size_2)

    # Already square: hand the input back untouched with an empty pad spec.
    if gap == 0:
        return images, []

    # Split the missing width as evenly as possible; an odd remainder goes
    # to the trailing side.
    lead = gap // 2
    trail = gap - lead

    # Only the shorter of the two spatial axes receives padding.
    short_axis = 2 if size_1 > size_2 else 1
    pad_spec = [[0, 0], [0, 0], [0, 0], [0, 0]]
    pad_spec[short_axis] = [lead, trail]

    padded = np.pad(images, pad_spec, 'constant', constant_values=0)
    return padded, pad_spec

In [None]:
# Locate the input stacks and make sure the output folder exists before the
# first save (io.imsave does not create missing directories).
DATASET_DIR = '../../data/'
IMAGES_DIR = os.path.join(DATASET_DIR, 'val_data', '#3c NUCLEAR MORPHOLOGY - RAW')
OUTPUT_DIR = os.path.join(DATASET_DIR, 'deepcell_output')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Sorted so that the processing order (and the progress log) is reproducible.
images_names = sorted(os.listdir(IMAGES_DIR))

for i, image_name in enumerate(images_names):

    # Read the stack.
    print("current index: {}/{}".format(i, len(images_names)))
    images = io.imread(os.path.join(IMAGES_DIR, image_name))

    # Rescale each slice to 8 bits for the network (255/65535 maps a 16-bit
    # range onto 0..255; assumes 16-bit input - confirm).
    images = np.array([cv2.convertScaleAbs(image, alpha=(255.0/65535.0)) for image in images])
    images = images.astype(np.uint8)

    # Adapt the image dimensions to the network. The network accepts
    # non-square inputs in some cases, but the exact conditions are unclear,
    # so every stack is zero-padded to a square.
    images = np.expand_dims(images, -1)
    padded_images, nbof_zeros = pad_square(images)
    assert padded_images.shape[1] == padded_images.shape[2]

    # Prediction.
    y_pred = app.predict(padded_images)
    print(y_pred.shape)

    # Crop the PREDICTION back to the original size. pad_square returns an
    # empty list when no padding was added; otherwise entries 1 and 2 hold the
    # [before, after] padding of the two spatial axes.
    if nbof_zeros != []:
        _, (x_before, x_after), (y_before, y_after), _ = nbof_zeros
        _, x_size, y_size, _ = y_pred.shape
        y_pred = y_pred[:, x_before:x_size - x_after, y_before:y_size - y_after, :]
    # The saved labels must line up with the raw (unpadded) stack.
    assert y_pred.shape[1:3] == images.shape[1:3]
    print(y_pred.shape)

    # Count the pixels of every label, ignoring the background (label 0).
    labels, counts = np.unique(y_pred, return_counts=True)
    is_foreground = labels != 0
    foreground_labels = labels[is_foreground]
    foreground_counts = counts[is_foreground]

    # Save only if at least one object was segmented.
    if foreground_labels.size > 0:

        # Label VALUE of the largest object. Using `argmax position + 1` as the
        # label is only correct when labels are exactly 0, 1, 2, ... without gaps.
        biggest_label = foreground_labels[np.argmax(foreground_counts)]

        # Every other label is set to zero in the images.
        reformat_pred = np.where(y_pred == biggest_label, y_pred, 0)

        # Save the filtered prediction under the same file name as the input.
        io.imsave(os.path.join(OUTPUT_DIR, image_name), np.squeeze(reformat_pred))

### Save labeled images as a gif to visualize

In [8]:
def plot(im1, im2, vmin, vmax):
    """
    Render a raw image next to its segmentation and return the figure as pixels.

    Parameters
    ----------
    im1 : 2D image shown in the left panel, titled 'Raw'.
    im2 : 2D label image shown in the right panel, titled 'Segmented', with
        the 'jet' colormap.
    vmin, vmax : color limits for the right panel. Passing the same limits for
        every frame keeps the label colors stable across an animation.

    Returns
    -------
    uint8 array of shape (height, width, 3) holding the rendered RGB figure.
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(im1)
    ax[0].axis('off')
    ax[0].set_title('Raw')
    ax[1].imshow(im2, cmap='jet', vmin=vmin, vmax=vmax)
    ax[1].set_title('Segmented')
    ax[1].axis('off')

    fig.canvas.draw()  # draw the canvas, cache the renderer
    # Read the pixels through buffer_rgba(): tostring_rgb() is deprecated and
    # removed in recent Matplotlib releases. The buffer already has the shape
    # (height, width, 4), so no reshape from get_width_height() is needed.
    # Drop the alpha channel and copy so the result no longer references the
    # canvas once the figure is closed.
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)

    return image

# Use one color scale for every frame; computing the limits once also avoids
# rescanning the whole prediction for each slice.
label_min, label_max = y_pred.min(), y_pred.max()

# One frame per slice: raw input on the left, prediction on the right.
# `images` is the input stack that produced `y_pred`; the name `x` used here
# previously was never defined as an image array.
imageio.mimsave(
    './labeled.gif',
    [plot(images[i, ..., 0], y_pred[i, ..., 0], label_min, label_max)
     for i in range(y_pred.shape[0])]
)