In [None]:
# Standard Libraries
import glob
from random import randint

# Numerical Computing
import numpy as np

# Image Processing and Computer Vision
import cv2
from tifffile import imread

# Data Visualization
import matplotlib
import matplotlib.pyplot as plt
# Show image pixels without interpolation. The value must be the string 'none':
# this rcParam is string-valued and recent Matplotlib releases reject a Python None here.
matplotlib.rcParams["image.interpolation"] = 'none'

# Progress Monitoring
from tqdm import tqdm

# Deep Learning and Image Segmentation (StarDist)
from csbdeep.utils import normalize
from csbdeep.io import save_tiff_imagej_compatible
from stardist import random_label_cmap, _draw_polygons, export_imagej_rois
from stardist.models import StarDist2D
from stardist.plot import render_label


Setting Up for Image Segmentation

This segment of the code is designed to prepare for image segmentation tasks. It involves initializing the environment, reading in image data, and verifying the image dimensions.
Steps:

   - Initialize Random Seed: Setting a fixed seed for NumPy's random number generator so that NumPy-based random operations are reproducible.

   - Create Random Label Color Map: Generating a random color map, which is typically used for visualizing labeled data in segmentation tasks.

   - Read Image Data: Loading TIFF images from a specified directory for segmentation. This step involves finding the image file paths and reading the images into an array.

   - Check Image Size: Verifying the dimensions of the loaded images to ensure they match the expected format and size.

In [None]:
# Seed NumPy's global random number generator so NumPy-based randomness is repeatable
np.random.seed(42)

# Generate a random label color map for visualizing segmentation labels
lbl_cmap = random_label_cmap()

# Collect the TIFF paths in sorted order so image indices are stable between runs.
# `glob` is imported as a module, so the function must be called as glob.glob().
image_pattern = 'path/to/your/images/*.tif'
image_paths = sorted(glob.glob(image_pattern))
if not image_paths:
    # Fail early with a clear message instead of an IndexError on X[0] below
    raise FileNotFoundError("No images found matching %r" % image_pattern)

# Load every image into memory
X = [imread(path) for path in image_paths]

# Check the shape of the first image in the dataset
X[0].shape

Normalizing Image Data for Segmentation

This section of the code is focused on preparing the image data for segmentation by normalizing the pixel intensities. Normalization is a crucial step in image processing as it standardizes the range of pixel values across the dataset.
Steps:

   - Determine Number of Channels: The number of channels in the images is identified. If the image is grayscale (a 2D array), it has one channel. If it is multi-channel (a 3D array), the number of channels is the size of the last dimension of the array.

   - Set Normalization Axis: Define the axes along which normalization will occur. This can be set to normalize channels independently or jointly.

   - Normalization Process: Apply normalization to each image in the dataset. The normalization scales the pixel values to a specified percentile range, enhancing the model's ability to learn from the data.

In [None]:
# A 2D array is a single-channel image; otherwise the channel count is the
# size of the last axis
first_image = X[0]
if first_image.ndim == 2:
    n_channel = 1
else:
    n_channel = first_image.shape[-1]

# Axes passed to normalize()
axis_norm = (0, 1)  # Normalize channels independently
# axis_norm = (0, 1, 2)  # Alternative: normalize channels jointly

# Report the normalization mode when the images have more than one channel
if n_channel > 1:
    joint = axis_norm is None or 2 in axis_norm
    print("Normalizing image channels %s." % ('jointly' if joint else 'independently'))

# Replace each image by its percentile-normalized version (1st to 99.8th percentile)
X = [normalize(image, 1, 99.8, axis=axis_norm) for image in tqdm(X)]

Visualizing Normalized Images

This code block is designed for visualizing the normalized images from the dataset. It arranges the images in a grid layout, making it easy to visually inspect each image and its normalization effect.
Steps:

   - Create Plotting Grid: A grid of subplots is created using plt.subplots. The grid size is 7 rows by 8 columns, accommodating 56 images.

   - Plot Images: Each image is displayed in one of the subplot axes. The images are shown in grayscale, and if an image has multiple channels, only the first channel is displayed.

   - Set Titles and Format: Each subplot is titled with its corresponding index in the dataset. The axes are turned off for a cleaner look, and the layout is adjusted for better spacing between images.

In [None]:
# 7 x 8 grid of panels (room for 56 images)
fig, ax = plt.subplots(7, 8, figsize=(16, 16))
panels = list(ax.flat)

# Fill the panels in order; zip() stops as soon as panels or images run out
for index, (panel, image) in enumerate(zip(panels, X)):
    # Multi-channel images are represented by their first channel only
    if image.ndim == 2:
        picture = image
    else:
        picture = image[..., 0]
    panel.imshow(picture, cmap='gray')

    # Label the panel with the image's index in X
    panel.set_title(index)

# Hide the axes on every panel (including unused ones) and tighten the spacing
for panel in panels:
    panel.axis('off')
plt.tight_layout()

Loading a Pre-Trained StarDist 2D Model

This section of the code focuses on loading a pre-trained StarDist 2D model. The StarDist model is specifically designed for image segmentation tasks, particularly for star-convex object segmentation.
Steps:

   - Model Instantiation: The StarDist2D model is instantiated by specifying the model name and the directory where the model is saved.

   - Model Parameters:
       - None (the first argument, the configuration) indicates that no new configuration is supplied, so the configuration saved with the pre-trained model is used.
       - name is the name of the pre-trained model to be loaded.
       - basedir is the directory path where the model files are located.

In [None]:
# Instantiate the pre-trained model: no new configuration is passed (None);
# the model is identified by its name and the directory that contains it.
model = StarDist2D(
    None,
    name='model name',
    basedir='path/to/your/model/directory',
)

Displaying Segmentation Results on a Random Image

This code block demonstrates how to apply the trained StarDist 2D model to a randomly selected image from the dataset and visualize the segmentation results.
Steps:

   - Select a Random Image:
        A random index is generated to select an image from the dataset.
        The selected image is then loaded for segmentation.

   - Segmentation:
        The predict_instances method of the StarDist 2D model is used to perform segmentation on the selected image.
        This method returns the segmented labels and additional details about the segmentation.

   - Visualization:
        The original image and the segmentation results (labels) are visualized together.
        The segmentation labels are overlaid on the original image with some transparency (alpha=0.5) for better visibility.

In [None]:
# Select a random image from the whole dataset. randint() includes both
# endpoints, so the valid index range is 0 .. len(X) - 1; starting at 0 lets
# every image be chosen and avoids a ValueError on datasets with fewer than
# ten images.
# Note: this draws from Python's `random` module, which np.random.seed() does not seed.
img_num = randint(0, len(X) - 1)
img = X[img_num]

# Perform segmentation using the StarDist 2D model
labels, details = model.predict_instances(img)

# Show the image (first channel only if multi-channel) with the labels
# overlaid semi-transparently
plt.figure(figsize=(8, 8))
plt.imshow(img if img.ndim == 2 else img[..., 0], clim=(0, 1), cmap='gray')
plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)
plt.axis('off')

Visualizing Segmentation Results with Detailed Annotations

The example function is designed to visualize the segmentation results of a specific image using the trained StarDist 2D model. It provides a detailed view by displaying both the original image with annotated segmentation polygons and the segmented image with labeled regions.
Functionality:

   - Normalize and Segment Image:
        The selected image is normalized, and then segmentation is performed using the predict_instances method.

   - Visualization Setup:
        Two subplots are created: one for the annotated original image and another for the image with segmentation labels overlaid.

   - Annotate with Segmentation Polygons:
        The _draw_polygons function from StarDist is used to draw segmentation polygons, illustrating how the model perceives and segments different objects in the image.

   - Overlay Segmentation Labels:
        The segmentation labels are overlaid on the original image in the second subplot, providing a clear view of the segmentation output.

   - Display Function:
        The function example can be called with the model and an image index to visualize the segmentation results for that specific image.

In [None]:
def example(model, i, show_dist=True):
    """Show the model's prediction for image ``X[i]`` in two side-by-side panels.

    The left panel shows the image with the predicted polygons drawn on top;
    the right panel shows the image with the label mask overlaid.

    Relies on the notebook-level names ``X``, ``axis_norm`` and ``lbl_cmap``.

    Args:
        model: Object providing ``predict_instances`` (here the loaded StarDist2D model).
        i: Index into ``X`` of the image to display.
        show_dist: Forwarded unchanged to ``_draw_polygons``.
    """
    # Normalize the selected image and run the prediction
    image = normalize(X[i], 1, 99.8, axis=axis_norm)
    labels, details = model.predict_instances(image)

    # Multi-channel images are displayed by their first channel only
    if image.ndim == 2:
        gray = image
    else:
        gray = image[..., 0]

    plt.figure(figsize=(13, 10))

    # Left panel: image with polygon outlines
    plt.subplot(121)
    plt.imshow(gray, cmap='gray')
    plt.axis('off')
    view_limits = plt.axis()  # remember the limits so they can be restored after drawing
    _draw_polygons(details['coord'], details['points'], details['prob'], show_dist=show_dist)
    plt.axis(view_limits)

    # Right panel: image with the label mask overlaid semi-transparently
    plt.subplot(122)
    plt.imshow(gray, cmap='gray')
    plt.axis('off')
    plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)

    plt.tight_layout()
    plt.show()


# Visualize segmentation results for a specific image
example(model, 16)

Segmenting and Visualizing an External Image

This code demonstrates how to load an external image, preprocess it for segmentation, apply the trained StarDist 2D model to segment the image, and then visualize both the original and segmented images.
Steps:

   - Load External Image:
        An image is loaded from a specified path using cv2.imread.
        The image is converted to grayscale, which is a common preprocessing step in segmentation tasks.

   - Segmentation:
        The StarDist 2D model is used to segment the preprocessed image. The segmented output and additional details are captured.

   - Visualization:
        The original and segmented images are visualized side by side for comparison.
        The render_label function from StarDist is used to overlay the segmentation results on the original image.

In [None]:
# Load the image from disk
image_path = "path/to/your/image.TIFF"
my_img = cv2.imread(image_path)
if my_img is None:
    # cv2.imread reports a missing or unreadable file by returning None
    # instead of raising, which would otherwise surface as an opaque
    # error inside cvtColor below
    raise FileNotFoundError("Could not read image: %s" % image_path)

# Convert the loaded BGR image to a single grayscale channel
my_img = cv2.cvtColor(my_img, cv2.COLOR_BGR2GRAY)

# Normalize with the same percentiles and axes used for the dataset images
my_img_norm = normalize(my_img, 1, 99.8, axis=axis_norm)

# Perform segmentation
segmented_img, details_img = model.predict_instances(my_img_norm, verbose=True)

In [None]:
# Default figure size used for the figure created below
plt.rcParams["figure.figsize"] = (50, 50)

# One entry per panel: title, picture to draw, extra imshow options
panels = (
    ("Input Image", my_img_norm, {"cmap": "gray"}),
    ("Prediction", render_label(segmented_img, img=my_img_norm), {}),
)

# Draw the input and the rendered prediction side by side
for position, (title, picture, options) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture, **options)
    plt.axis("off")
    plt.title(title)
plt.show()