# Cell segmentation of C. elegans embryos

In [None]:
# Importing libraries
import sys
import matplotlib.pyplot as plt 
import numpy as np 
#import tifffile
#import shutil
#import zipfile
#import os
import seaborn as sns
import pandas as pd
import pathlib
import warnings
#import imageio
import re  
from skimage.io import imread
import os; from os import listdir; from os.path import isfile, join
from skimage import io
from cellpose import models
#from cellpose import plot
from scipy.ndimage import binary_fill_holes
from skimage.morphology import erosion, disk, square, remove_small_objects
import matplotlib.patches as patches
from scipy.ndimage import binary_dilation
from skimage.measure import label

import cv2
warnings.filterwarnings("ignore")

In [None]:
# Make the project's `src` folder importable.
# It sits inside the directory two levels above the working directory (parents[1] is the grandparent).
current_dir = pathlib.Path.cwd()
fa_dir = current_dir.parents[1] / 'src'
sys.path.append(str(fa_dir))

# This import only works after the sys.path change above.
import fish_analyses as fa

In [None]:
# Folder holding the input images (absolute, machine-specific path: adjust it to your setup).
directory = pathlib.Path(
    '/home/luisub/Desktop/FISH_Processing/dev/codes_C_elegans/OneDrive_1_6-12-2024/4-cell/all_images'
)


In [None]:
# Collect every PNG file in `directory`, order the files by the number embedded in their names,
# and load all of them into memory.
def _file_number_sort_key(file_name):
    """Return a sort key made from all the digits in *file_name*.

    Numbered names are ordered by the integer formed by their digits (e.g. 'img_10.png' -> 10).
    Names without any digit sort after all numbered names, instead of raising ValueError.
    """
    # Raw string: '\D' in a plain string literal is an invalid escape sequence.
    digits = re.sub(r'\D', '', file_name)
    return (digits == '', int(digits) if digits else 0)


list_files_names_complete = sorted(
    [f for f in listdir(directory) if isfile(join(directory, f)) and f.lower().endswith('.png')],
    key=str.lower,
)  # case-insensitive alphabetical order first; the stable numeric sort below keeps it for ties
list_files_names_complete.sort(key=_file_number_sort_key)
# Absolute, resolved path of every selected file.
path_files_complete = [str(directory.joinpath(f).resolve()) for f in list_files_names_complete]
number_files = len(path_files_complete)
list_images_complete = [imread(f) for f in path_files_complete]
print(number_files)
if list_images_complete:
    print(list_images_complete[0].shape)

In [None]:
# Show the resolved paths of the PNG files loaded above (displayed as this cell's output).
path_files_complete

In [None]:
# Plot channel 0 of every raw image side by side, in a single row with one column per image.
# squeeze=False keeps `axs` 2-D, so the indexing below also works when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, image in zip(axs[0], list_images_complete):
    ax.imshow(image[:, :, 0])
    ax.axis('off')
plt.show()

In [None]:
# Rescale channel 0 of every image with fa.RemoveExtrema, using the 0-99.8 percentile range.
# Presumably this limits intensities to that range; see fa.RemoveExtrema for the exact behavior.
list_rescaled_images = [
    fa.RemoveExtrema(image[:, :, 0], min_percentile=0, max_percentile=99.8).remove_outliers()
    for image in list_images_complete
]

# Plot the rescaled images in a single row.
# squeeze=False keeps `axs` 2-D even when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, image in zip(axs[0], list_rescaled_images):
    ax.imshow(image, cmap='Greys_r')
    ax.axis('off')

In [None]:
# Cell segmentation with Cellpose.
# The model is created once and reused for every image, instead of being re-instantiated
# on each loop iteration.
model = models.Cellpose(gpu=True, model_type='cyto2')  # model_type: 'cyto', 'cyto2' or 'nuclei'
list_masks = []
for image in list_rescaled_images:
    # model.eval returns a tuple; element 0 holds the masks.
    # Per the Cellpose docs this is an integer label image with 0 = background.
    masks = model.eval(image, diameter=100, flow_threshold=1, channels=[0, 0], augment=True)[0]
    list_masks.append(masks)

# Plot the segmentation masks in a single row.
# squeeze=False keeps `axs` 2-D even when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, masks in zip(axs[0], list_masks):
    ax.imshow(masks, cmap='Greys_r')
    ax.axis('off')

In [None]:
# calculating the membrane mask
def compute_membrane_mask(mask, num_pixels_to_dilate=10):
    """Return a boolean mask of cell pixels that lie close to a *different* cell.

    Each positive label is dilated by ``num_pixels_to_dilate`` iterations. Pixels of the
    resulting ring (the dilated region minus the cell itself) that belong to another
    non-zero label are marked. The union of these marks over all labels is hole-filled.

    Args:
        mask: integer label image, 0 = background (non-negative labels assumed).
        num_pixels_to_dilate: number of binary-dilation iterations (default 10, the
            value that used to be hard-coded). Must be >= 1: scipy repeats the dilation
            until the result stops changing when ``iterations`` < 1.

    Returns:
        Boolean array with the same shape as ``mask``. It is all False when ``mask`` has
        no positive labels, or when no two cells are within dilation reach of each other.
    """
    # Cast to a wide integer type (np.int was removed in NumPy 1.24).
    mask = np.asarray(mask).astype(np.int64)
    foreground = mask > 0
    mask_neighbors = np.zeros(mask.shape, dtype=bool)
    # Visit only labels that actually occur; absent labels in 1..max contribute nothing.
    for label_id in np.unique(mask[foreground]):
        mask_i = mask == label_id
        mask_i_dilated = binary_dilation(mask_i, iterations=num_pixels_to_dilate)
        # XOR keeps only the ring grown around cell i (cell i itself is excluded).
        ring = mask_i_dilated ^ mask_i
        # Keep the ring pixels that lie on another cell.
        mask_neighbors |= ring & foreground
    # Fill holes once, after all labels are processed; this is also defined for an empty mask.
    return binary_fill_holes(mask_neighbors)

In [None]:
# Membrane mask of the first image.
# The label image is passed unchanged: casting it to int8 wraps any label above 127
# to a negative value (and label 256 to 0), which corrupts the mask.
membrane_mask = compute_membrane_mask(list_masks[0])
plt.imshow(membrane_mask, cmap='Greys_r')
plt.axis('off')

In [None]:
# Remove objects smaller than MIN_CELL_SIZE pixels.
# Then relabel the remaining cells consecutively as 1..N (background stays 0).
MIN_CELL_SIZE = 5000
list_masks_filtered = []
for masks in list_masks:
    masks_filtered = remove_small_objects(masks, min_size=MIN_CELL_SIZE)
    # Map every surviving label explicitly, so that removed cells leave no gaps in the numbering
    # and no two cells end up sharing a label.
    reorder_masks = np.zeros_like(masks_filtered)
    surviving_labels = np.unique(masks_filtered[masks_filtered > 0])
    for new_label, old_label in enumerate(surviving_labels, start=1):
        reorder_masks[masks_filtered == old_label] = new_label
    list_masks_filtered.append(reorder_masks)

# Plot the filtered, relabelled masks in a single row.
# squeeze=False keeps `axs` 2-D even when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, masks in zip(axs[0], list_masks_filtered):
    ax.imshow(masks, cmap='Greys_r')
    ax.axis('off')

In [None]:
# Binary whole-embryo mask per image: 1.0 wherever any cell label is present, 0.0 elsewhere.
# Each mask is built from its own shape, so images of different sizes are supported.
list_masks_all_embryo = [np.where(mask > 0, 1.0, 0.0) for mask in list_masks_filtered]

# Plot the whole-embryo masks in a single row.
# squeeze=False keeps `axs` 2-D even when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, masks in zip(axs[0], list_masks_all_embryo):
    ax.imshow(masks, cmap='Greys_r')
    ax.axis('off')

In [None]:
# Fill enclosed holes (e.g. gaps between cells) in every whole-embryo mask.
list_masks_filled = [binary_fill_holes(mask) for mask in list_masks_all_embryo]

# Plot the filled masks in a single row.
# squeeze=False keeps `axs` 2-D even when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, masks in zip(axs[0], list_masks_filled):
    ax.imshow(masks, cmap='Greys_r')
    ax.axis('off')

In [None]:
# Dilate each filled embryo mask.
# If the dilated mask has several disconnected components, keep only the largest one;
# all smaller fragments are discarded.
list_masks_dilated = []
NUM_PIXELS_TO_DILATE = 20
for mask in list_masks_filled:
    dilated_mask = binary_dilation(mask, iterations=NUM_PIXELS_TO_DILATE)
    label_image = label(dilated_mask)
    if label_image.max() > 1:
        # Pixel count per component label; index 0 is the background and is skipped.
        component_sizes = np.bincount(label_image.ravel())[1:]
        # Boolean result, matching the dtype of the single-component case.
        dilated_mask = label_image == (np.argmax(component_sizes) + 1)
    list_masks_dilated.append(dilated_mask)

# Plot the dilated masks in a single row.
# squeeze=False keeps `axs` 2-D even when there is only one image.
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, masks in zip(axs[0], list_masks_dilated):
    ax.imshow(masks, cmap='Greys_r')
    ax.axis('off')

In [None]:
# Centroid (row, column) of every dilated embryo mask.
# The result is NaN for an empty mask.
list_center_of_mass = []
for mask in list_masks_dilated:
    rows, cols = np.nonzero(mask > 0)
    list_center_of_mass.append(np.array([rows.mean(), cols.mean()]))

# Plot each mask with its centroid as a red dot.
# scatter takes (x, y) = (column, row).
fig, axs = plt.subplots(1, number_files, figsize=(30, 10), squeeze=False)
for ax, masks, (row_c, col_c) in zip(axs[0], list_masks_dilated, list_center_of_mass):
    ax.imshow(masks, cmap='Greys_r')
    ax.scatter(col_c, row_c, c='r', s=100)
    ax.axis('off')

In [None]:
# Fit an ellipse to the outline of the first dilated embryo mask.
image = list_masks_dilated[0]
# findContours returns (contours, hierarchy) in OpenCV 4.x but (image, contours, hierarchy)
# in OpenCV 3.x; index -2 selects the contour list in both versions.
contours = cv2.findContours(image.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
if len(contours) == 0:
    raise ValueError('No contour found in the first dilated mask; cannot fit an ellipse.')
# Assume the largest external contour is the region of interest.
cnt = max(contours, key=cv2.contourArea)
# cv2.fitEllipse needs a contour with at least 5 points.
ellipse = cv2.fitEllipse(cnt)
# Center, full axis lengths, and rotation angle (degrees).
(xc, yc), (d1, d2), angle = ellipse
fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
# Both cv2.fitEllipse and patches.Ellipse work with full axis lengths and degrees,
# so the values are passed through in the order fitEllipse returns them.
ellipse_patch = patches.Ellipse(xy=(xc, yc), width=d1, height=d2, angle=angle, edgecolor='r', facecolor='none', linewidth=2.5)
ax.add_patch(ellipse_patch)
ax.set_axis_off()
plt.tight_layout()
plt.show()
print(d1, d2, angle)


In [None]:
# Overlay the ellipse fitted to the first embryo on the membrane mask of the same image.
membrane_mask = compute_membrane_mask(list_masks[0])
fig, ax = plt.subplots()
ellipse_patch = patches.Ellipse(xy=(xc, yc), width=d1, height=d2, angle=angle, edgecolor='r', facecolor='none', linewidth=2.5)
ax.add_patch(ellipse_patch)
ax.imshow(membrane_mask, cmap='Greys_r')
ax.axis('off')

In [None]:
# Deliberate stop, so that "Run All" does not continue past this cell.
# A bare `raise` here only produced the confusing "No active exception to reraise".
raise RuntimeError('Execution intentionally stopped here; run the following cells manually.')

In [None]:
# Remove masks that can be enclosed in an ellipse (axis limits 0.5-1.5).
# The filtering itself is done by fa.RemoveEllipses.
list_masks_filtered_ellipse = [
    fa.RemoveEllipses(masks, min_axis=0.5, max_axis=1.5).remove_ellipses()
    for masks in list_masks_filtered
]

In [None]:
# label cell based on cell characteristics

