In [None]:
# Standard library imports
import numpy as np
import matplotlib.pyplot as plt

# Third-party specialized imports
import tifffile

# Scikit-image imports
from skimage import data, filters, io
from skimage.measure import label, regionprops
from skimage.morphology import (
    disk, 
    dilation, 
    remove_small_objects, 
    skeletonize
)
from skimage.segmentation import expand_labels, find_boundaries

from scipy import ndimage

import os

# Detect if running in Binder or set manually
IN_BINDER = os.environ.get('BINDER_SERVICE_HOST') is not None

# Configuration flags
USE_NAPARI = not IN_BINDER  # Automatically use matplotlib in Binder
SHOW_IMAGES = True  # Set to False to skip all visualization

# napari viewer handle; stays None unless napari initializes successfully below
viewer = None

# Setup based on environment
if not SHOW_IMAGES:
    # Report the real state instead of claiming a matplotlib backend is in use
    print("Visualization disabled (SHOW_IMAGES = False)")
elif USE_NAPARI:
    try:
        import napari
        # Enable the Qt event loop for napari. get_ipython() only exists inside
        # IPython/Jupyter; elsewhere it raises NameError and we fall back below.
        if not IN_BINDER:
            get_ipython().run_line_magic('gui', 'qt')
        viewer = napari.Viewer()
        print("Using napari for visualization")
    except Exception as e:
        # Best-effort: any napari/Qt failure degrades to matplotlib
        print(f"Could not initialize napari: {e}")
        print("Falling back to matplotlib")
        USE_NAPARI = False
else:
    print("Using matplotlib for visualization")



In [None]:
def display_image(image, name='Image', colormap='gray', contrast_limits=None):
    """
    Display an image using either napari or matplotlib.

    Nothing is shown when the module-level flag ``SHOW_IMAGES`` is False.
    napari is used when ``USE_NAPARI`` is set and a ``viewer`` exists;
    otherwise the image is drawn with matplotlib (with a colorbar).

    Parameters:
    -----------
    image : numpy.ndarray
        The image to display
    name : str
        Name/title for the image (napari layer name or matplotlib title)
    colormap : str
        Colormap to use ('gray', 'viridis', 'plasma', etc.)
    contrast_limits : sequence of two numbers or None
        Optional (min, max) values for contrast adjustment. Tuples, lists and
        2-element numpy arrays are accepted; None means autoscale.
    """
    if not SHOW_IMAGES:
        return

    if USE_NAPARI and viewer is not None:
        # Napari display: adds a new image layer to the shared viewer
        viewer.add_image(image, name=name, colormap=colormap,
                         contrast_limits=contrast_limits)
        return

    # Matplotlib display. Compare against None explicitly: truth-testing a
    # numpy array of limits raises "truth value ... is ambiguous".
    if contrast_limits is not None:
        vmin, vmax = contrast_limits
    else:
        vmin = vmax = None

    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(image, cmap=colormap, vmin=vmin, vmax=vmax)
    ax.set_title(name)
    ax.axis('off')
    fig.colorbar(im, ax=ax)
    plt.show()

def display_labels(labels, name='Labels'):
    """
    Display a label image using either napari or matplotlib.

    Nothing is shown when the module-level flag ``SHOW_IMAGES`` is False.
    napari is used when ``USE_NAPARI`` is set and a ``viewer`` exists;
    otherwise the labels are drawn with matplotlib.

    In the matplotlib fallback, each non-zero label id gets a colour from the
    20-entry 'tab20' palette (chosen as ``id % 20``), and label 0 is drawn black.
    Passing the raw ids to ``imshow(cmap='tab20')`` would instead rescale them
    over min..max. With many cells, neighbouring ids would then share a colour,
    and the background would be coloured like the lowest labels.

    Parameters:
    -----------
    labels : numpy.ndarray
        The label image to display (integer or boolean; 0 is background)
    name : str
        Name/title for the labels
    """
    if not SHOW_IMAGES:
        return

    if USE_NAPARI and viewer is not None:
        # Napari display: adds a Labels layer to the shared viewer
        viewer.add_labels(labels, name=name)
        return

    # Matplotlib display: build an RGB image with a cyclic palette
    label_ids = np.asarray(labels).astype(np.int64)
    palette = np.asarray(plt.cm.tab20.colors)  # 20 RGB colours in [0, 1]
    rgb = palette[label_ids % len(palette)]
    rgb[label_ids == 0] = 0.0  # background -> black

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(rgb, interpolation='nearest')
    ax.set_title(name)
    ax.axis('off')
    plt.show()

In [None]:
# All variable parameters are here
input_file = 'data/yeast.tif'                  # multi-plane image to segment
ground_truth_file = 'data/cellpose_truth.tif'  # NOTE(review): not read by any visible cell
show_images, save_result = True, False         # display intermediate steps / write the final mask

In [None]:
# Load the raw stack and keep its second plane (index 1) for segmentation
image = io.imread(input_file)
plane_of_interest = image[1, :, :]
if show_images:
    display_image(plane_of_interest, name='Original', colormap='gray')

# Cellpose segmentation of the same data, used as the reference ("ground truth").
# NOTE(review): this path is hard-coded and differs from `ground_truth_file`
# set in the parameter cell — confirm which file is intended.
cellpose_mask = io.imread('data/cellpose_mask.tif')
if show_images:
    display_labels(cellpose_mask, name="Cellpose Label Image")


In [None]:
import pandas as pd
import numpy as np
from IPython.display import display

# One row per experiment: the pipeline parameters used, then the scores obtained
columns = [
    # pipeline parameters
    'blur_radius_for_normalisation',
    'median_filter_size',
    'threshold_bright_pixels',
    'remove_objects_below_size',
    'min_distance_bridge_cut',
    'dilate_borders_px',
    'remove_objects_above_size',
    'label_expansion',
    # evaluation metrics
    'IoU_Score',
    'Ground_Truth_Objects',
    'Detected_Objects',
]

results_table = pd.DataFrame(columns=columns)
print("Results table initialized:")
display(results_table)


In [None]:
# Analysis parameters — edit here, then re-run the cells below

# Pseudo flat field: the image is divided by its Gaussian-blurred version.
# This value is passed to the blur as its sigma, in pixels.
blur_radius_for_normalisation = 20

# Radius of the disk footprint of the median filter that removes noise
median_filter_size = 3

# Normalized intensity above which a pixel counts as part of a bright cell outline
threshold_bright_pixels = 1.14  # 1.15 original

# Bright components smaller than this (pixels) are assumed not to be cell borders
remove_objects_below_size = 140  # 150 original

# Splits seeds joined by a thin bridge: seed pixels closer than this many
# pixels to the seed edge are dropped
min_distance_bridge_cut = 4

# Dilation radius (pixels) of the border mask, meant to close cell outlines
# without losing small cells.
# Too large: small cells are swallowed by the border and missed.
# Too small: outlines stay open, so a cell interior is not enclosed and stays
# connected to its neighbours or to the background.
dilate_borders_px = 9  # original 10

# Seed regions of this size or more (pixels) are discarded: they are the
# background or interstices between cells rather than cells
remove_objects_above_size = 5000

# Distance (pixels) by which each seed label is grown so it better matches the cell outline
label_expansion = 16

In [None]:
# Normalize: divide by the Gaussian-blurred version of the image (pseudo flat field), output is float32
def normalize_with_blur(image, blur_radius):
    """
    Normalize an image by dividing it by its Gaussian-blurred version.

    Parameters
    ----------
    image : numpy.ndarray
        Input image; converted to float32 before blurring and division.
    blur_radius : float
        Passed to ``filters.gaussian`` as ``sigma`` (the Gaussian standard
        deviation in pixels, not a hard radius).

    Returns
    -------
    numpy.ndarray
        float32 ratio ``image / blurred``. Pixels where the blurred image is
        exactly 0 are set to 0 instead of becoming inf/NaN. Such values would
        otherwise propagate into the median filter and threshold steps.
    """
    image_f32 = np.asarray(image, dtype=np.float32)
    blurred = filters.gaussian(image_f32, sigma=blur_radius)

    # Guarded division: only divide where the denominator is non-zero
    normalized = np.zeros(image_f32.shape, dtype=np.float32)
    np.divide(image_f32, blurred, out=normalized, where=blurred != 0)
    return normalized

In [None]:
# Pseudo flat-field correction of the selected plane
normalized_image = normalize_with_blur(
    plane_of_interest,
    blur_radius=blur_radius_for_normalisation,
)
if show_images:
    display_image(normalized_image, name='Normalized', colormap='gray')

In [None]:
# Apply a median filter with a disk footprint to suppress pixel noise
filtered_image = filters.median(normalized_image, disk(median_filter_size))

if show_images:
    # Title follows the configured size instead of a hard-coded "3"
    display_image(filtered_image, name=f'Median_Disk_{median_filter_size}', colormap='gray')

In [None]:
# Bright pixels (above the normalized threshold) outline the cells
thresholded_image = filtered_image > threshold_bright_pixels

if show_images:
    display_image(
        thresholded_image,
        name=f'Threshold_{threshold_bright_pixels}',
        colormap='gray',
    )

In [None]:
# Drop small bright specks that cannot be cell borders.
# connectivity=2 is 8-connectivity in 2D (diagonal neighbours are connected).
cleaned_image = remove_small_objects(
    thresholded_image,
    min_size=remove_objects_below_size,
    connectivity=2,
)

if show_images:
    display_image(cleaned_image, name=f'Cleaned Objects below {remove_objects_below_size}', colormap='gray')

In [None]:
# Thicken the border mask with a disk so that gaps in cell outlines close
dilated_image = dilation(
    cleaned_image,
    disk(dilate_borders_px),
)

if show_images:
    display_image(dilated_image, name=f'Dilated_{dilate_borders_px}', colormap='gray')

In [None]:
# Keep only objects SMALLER than a certain size
def keep_small_objects(binary_mask, max_size):
    """
    Return a boolean mask with only the connected components of
    ``binary_mask`` whose pixel count is strictly below ``max_size``.

    Connectivity is ``remove_small_objects``' default.

    Parameters
    ----------
    binary_mask : array-like
        Foreground mask; any non-zero value counts as foreground.
    max_size : int
        Components with ``max_size`` pixels or more are discarded.

    Returns
    -------
    numpy.ndarray of bool
    """
    # Work on a real boolean array: integer input would be treated as a label
    # image by remove_small_objects, and `~` on integers is a bitwise NOT.
    mask = np.asarray(binary_mask, dtype=bool)
    # remove_small_objects drops components smaller than min_size, so what
    # remains is exactly the LARGE components (size >= max_size) ...
    large_objects = remove_small_objects(mask, min_size=max_size)
    # ... and removing those from the mask leaves only the small ones.
    return mask & ~large_objects

# Seeds = regions enclosed by the dilated borders: invert the border mask and
# drop components too large to be cells (background, interstices).
# np.logical_not is a true logical inversion even if dilated_image is not bool.
seeds_raw = keep_small_objects(np.logical_not(dilated_image), max_size=remove_objects_above_size)

if show_images:
    display_image(seeds_raw, name='Seeds', colormap='gray')


In [None]:

def distance_threshold_mask(binary_mask, min_distance=2):
    """
    Keep only foreground pixels lying at least ``min_distance`` pixels
    (Euclidean) from the nearest background pixel.

    Used to cut thin bridges between seeds: pixels near a seed edge, including
    whole narrow bridges, are removed.

    Parameters
    ----------
    binary_mask : array-like
        Foreground mask; non-zero values are foreground.
    min_distance : float
        Minimum distance to the nearest background pixel for a pixel to be kept.

    Returns
    -------
    numpy.ndarray of bool
        A subset of the foreground of ``binary_mask``. Background pixels are
        never selected, even when ``min_distance <= 0``.
    """
    mask = np.asarray(binary_mask, dtype=bool)
    # Distance from every foreground pixel to the nearest background (zero) pixel;
    # background pixels get distance 0.
    distance = ndimage.distance_transform_edt(mask)
    # AND with the mask so that min_distance <= 0 cannot pull in background pixels
    return mask & (distance >= min_distance)

# Drop seed pixels close to the seed edge so seeds joined by a thin bridge separate
seeds_split = distance_threshold_mask(
    seeds_raw,
    min_distance=min_distance_bridge_cut,
)
if show_images:
    display_image(seeds_split, name='Seeds Split', colormap='gray')

In [None]:
# Give every connected seed its own integer id (skimage.measure.label, default connectivity)
labels_raw = label(seeds_split)

if show_images:
    display_labels(labels_raw, name="Seeds Labels")

In [None]:
# Grow each seed label outward by `label_expansion` pixels to cover the whole cell
labels = expand_labels(labels_raw, distance=label_expansion)

if show_images:
    display_labels(labels, name="Labels, Final")

if save_result:
    # Write the final label image to disk as TIFF
    import tifffile
    tifffile.imwrite('data/result_mask.tif', labels)

if show_images:
    # Boolean mask of pixels on label borders ('thick' mode)
    boundaries = find_boundaries(labels, mode='thick')

    # Copy of the labels with every boundary pixel reset to background (0)
    blobby_labels_eroded = labels.copy()
    blobby_labels_eroded[boundaries] = 0

    # NOTE(review): blobby_labels_eroded is built but never shown; only the
    # boundary mask is displayed here — confirm which one was intended.
    display_labels(boundaries, name="Cell Delimitation")

In [None]:
import numpy as np

def iou_binary(labels1, labels2):
    """
    Compute the foreground IoU (Jaccard index) between two label images.

    Both images are binarized first (any label > 0 is foreground). This
    measures overall foreground overlap, not per-object matching.

    Parameters
    ----------
    labels1, labels2 : numpy.ndarray
        Label images of identical shape.

    Returns
    -------
    float
        |A & B| / |A | B|; 1.0 when neither image has any foreground.

    Raises
    ------
    ValueError
        If the two images do not have the same shape. Without this check, numpy
        would silently broadcast e.g. (1, W) against (H, W).
    """
    binary1 = np.asarray(labels1) > 0
    binary2 = np.asarray(labels2) > 0
    if binary1.shape != binary2.shape:
        raise ValueError(
            f"Label images must have the same shape, got {binary1.shape} and {binary2.shape}"
        )

    union = np.count_nonzero(binary1 | binary2)
    if union == 0:
        # Both images are empty (intersection is then necessarily 0 too):
        # treat them as a perfect match.
        return 1.0

    intersection = np.count_nonzero(binary1 & binary2)
    return intersection / union

# Foreground overlap between our final labels and the cellpose reference
iou_score = iou_binary(labels, cellpose_mask)
print("IoU: " + format(iou_score, ".3f"))

In [None]:
# Calculate metrics for the current parameter set
iou_score = iou_binary(labels, cellpose_mask)
# Count distinct non-zero ids (0 = background). Unlike len(unique) - 1, this
# stays correct when an image happens to contain no 0 pixel.
n_gt_objects = int(np.count_nonzero(np.unique(cellpose_mask)))
n_detected_objects = int(np.count_nonzero(np.unique(labels)))


# Create new row with current parameters and results
new_row = {
    'blur_radius_for_normalisation': blur_radius_for_normalisation,
    'median_filter_size': median_filter_size,
    'threshold_bright_pixels': threshold_bright_pixels,
    'remove_objects_below_size': remove_objects_below_size,
    'min_distance_bridge_cut': min_distance_bridge_cut,
    'dilate_borders_px': dilate_borders_px,
    'remove_objects_above_size': remove_objects_above_size,
    'label_expansion': label_expansion,
    'IoU_Score': round(iou_score, 3),
    'Ground_Truth_Objects': n_gt_objects,
    'Detected_Objects': n_detected_objects
}

# Add to results table
new_row_df = pd.DataFrame([new_row])
if results_table.empty:
    # Concatenating onto the empty, object-dtype table is deprecated in
    # pandas >= 2.1 (FutureWarning about empty entries and result dtype);
    # start the table from the first row instead.
    results_table = new_row_df
else:
    results_table = pd.concat([results_table, new_row_df], ignore_index=True)

# Display updated table
print(f"Experiment {len(results_table)} added:")
display(results_table)