In [1]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from skimage import io, measure, segmentation, morphology, draw
import networkx as nx
from tqdm.notebook import tqdm
import os
import re
from google.colab import drive

# Give this Colab runtime access to the user's Google Drive.
drive.mount('/content/drive')

# Source folders: labeled nuclei masks and (adjusted) membrane masks.
nuclei_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/Static-A-1/Nuclei'
membrane_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/Static-A-1/Membrane_Adjusted'

# Destination folder for the fused segmentation results.
output_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/Static-A-1/Fused'

# Create the destination (and any missing parents) unless it already exists.
if not os.path.isdir(output_dir):
    os.makedirs(output_dir, exist_ok=True)

# Pipeline functions: load masks, find neighbouring nuclei, test for membrane barriers between them, merge, relabel, and visualize
def load_images(nuclei_path, membrane_path):
    """
    Load nuclei and membrane images.

    Parameters:
    -----------
    nuclei_path : str
        Path to the nuclei mask image
    membrane_path : str
        Path to the membrane mask image

    Returns:
    --------
    nuclei_mask : ndarray
        Nuclei mask exactly as read from disk; expected to be a label image
        where each nucleus has a unique integer ID (not checked here)
    membrane_mask : ndarray (uint8)
        Binary membrane mask (1 for membrane, 0 for background). Every pixel
        with a value > 0 in the input image counts as membrane.
    """
    nuclei_mask = io.imread(nuclei_path)
    membrane_mask = io.imread(membrane_path)

    # Always binarize. A `max() > 1` guard would skip the conversion for bool
    # images and for float images with values in (0, 1], leaving non-{0, 1}
    # values that downstream `== 1` membrane tests would not recognise.
    membrane_mask = (membrane_mask > 0).astype(np.uint8)

    return nuclei_mask, membrane_mask

def get_nuclei_properties(nuclei_mask):
    """
    Measure every labeled region of a nuclei mask.

    Thin wrapper around ``skimage.measure.regionprops``.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask

    Returns:
    --------
    nuclei_props : list
        Region properties (label, centroid, ...) for each labeled nucleus
    """
    region_properties = measure.regionprops(nuclei_mask)
    return region_properties

def dilate_nuclei_mask(nuclei_mask, dilation_radius=3):
    """
    Dilate each nucleus in the mask to detect proximity.

    Every nucleus is dilated on its own, so neighbouring nuclei get separate
    (possibly overlapping) dilated footprints.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask (0 = background)
    dilation_radius : int
        Number of binary-dilation iterations. With SciPy's default
        structuring element this grows each nucleus by up to
        ``dilation_radius`` pixels along each axis (a diamond, not a disk).
        Values <= 0 mean "no dilation".

    Returns:
    --------
    dilated_masks : dict
        Dictionary mapping nucleus ID to a full-size uint8 mask
        (1 inside the dilated nucleus, 0 elsewhere)
    """
    radius = max(int(dilation_radius), 0)
    dilated_masks = {}

    for label in np.unique(nuclei_mask):
        if label == 0:  # Skip background
            continue

        binary_mask = nuclei_mask == label
        dilated = np.zeros(nuclei_mask.shape, dtype=np.uint8)

        if radius == 0:
            # binary_dilation treats iterations < 1 as "repeat until nothing
            # changes", which floods the whole image; 0 means no growth here.
            dilated[binary_mask] = 1
            dilated_masks[label] = dilated
            continue

        # Dilate only a window around the nucleus: its bounding box padded by
        # `radius` on every side contains the whole dilated result, so the
        # output is identical to dilating the full image but much cheaper.
        coords = np.nonzero(binary_mask)
        window = tuple(
            slice(max(int(axis.min()) - radius, 0),
                  min(int(axis.max()) + radius + 1, size))
            for axis, size in zip(coords, nuclei_mask.shape)
        )
        dilated[window] = ndimage.binary_dilation(
            binary_mask[window],
            iterations=radius,
        )
        dilated_masks[label] = dilated

    return dilated_masks

def find_overlapping_nuclei(dilated_masks, nuclei_props):
    """
    Find pairs of nuclei whose dilated masks overlap.

    Each mask's bounding box is computed once. A pair is only compared
    pixel-wise inside the intersection of the two boxes, and pairs with
    disjoint boxes are skipped outright.

    Parameters:
    -----------
    dilated_masks : dict
        Dictionary of full-size dilated masks for each nucleus
    nuclei_props : list
        List of region properties for each nucleus (``label`` and
        ``centroid`` are used)

    Returns:
    --------
    overlapping_pairs : list
        List of tuples (id1, id2, centroid1, centroid2) for overlapping
        nuclei, sorted by increasing centroid distance. id1 precedes id2 in
        the iteration order of ``dilated_masks``.
    """
    nuclei_ids = list(dilated_masks.keys())

    # Create mapping from nucleus ID to its centroid
    centroids = {prop.label: prop.centroid for prop in nuclei_props}

    # Per-mask bounding box as (inclusive lower corner, exclusive upper corner).
    boxes = {}
    for nucleus_id in nuclei_ids:
        coords = np.nonzero(dilated_masks[nucleus_id])
        if coords[0].size == 0:
            continue  # an empty mask cannot overlap anything
        boxes[nucleus_id] = (
            tuple(int(axis.min()) for axis in coords),
            tuple(int(axis.max()) + 1 for axis in coords),
        )

    overlapping_pairs = []
    for i, id1 in enumerate(nuclei_ids):
        if id1 not in boxes:
            continue
        lo1, hi1 = boxes[id1]
        for id2 in nuclei_ids[i + 1:]:
            if id2 not in boxes:
                continue
            lo2, hi2 = boxes[id2]
            lo = [max(a, b) for a, b in zip(lo1, lo2)]
            hi = [min(a, b) for a, b in zip(hi1, hi2)]
            if any(start >= stop for start, stop in zip(lo, hi)):
                continue  # disjoint boxes: the masks cannot overlap
            window = tuple(slice(start, stop) for start, stop in zip(lo, hi))
            # logical_and instead of `*`: a product of uint8 masks can wrap
            # to 0 for non-binary values (e.g. 16 * 16).
            if np.any(np.logical_and(dilated_masks[id1][window],
                                     dilated_masks[id2][window])):
                overlapping_pairs.append((
                    id1, id2,
                    centroids[id1], centroids[id2]
                ))

    # Sort by proximity (using centroid distance as an approximation)
    overlapping_pairs.sort(key=lambda x: np.sqrt(
        (x[2][0] - x[3][0])**2 + (x[2][1] - x[3][1])**2
    ))

    return overlapping_pairs

def create_line_corridor(centroid1, centroid2, corridor_width=15):
    """
    Create a corridor of parallel lines between two centroids.

    Lines are spaced one pixel apart, measured perpendicular to the
    centroid-to-centroid direction. Each line is rasterised with
    ``skimage.draw.line`` between its rounded end points.

    Parameters:
    -----------
    centroid1, centroid2 : tuple
        Centroids of two nuclei (y, x); may be floating point
    corridor_width : int
        Number of parallel lines in the corridor. Lines are generated for
        offsets ``-(corridor_width // 2) .. +(corridor_width // 2)``, so an
        even width yields ``corridor_width + 1`` lines.

    Returns:
    --------
    lines : list
        List of lines, where each line is a list of points (y, x). Points are
        not clipped to any image, so callers must bounds-check them. Empty if
        the two centroids coincide.
    """
    y1, x1 = float(centroid1[0]), float(centroid1[1])
    y2, x2 = float(centroid2[0]), float(centroid2[1])

    dy, dx = y2 - y1, x2 - x1
    length = np.hypot(dx, dy)
    if length == 0:  # Handle the case where centroids are at the same position
        return []

    # Unit vector perpendicular to the centroid-to-centroid direction.
    perp_y, perp_x = dx / length, -dy / length

    half_width = corridor_width // 2

    lines = []
    for offset in range(-half_width, half_width + 1):
        offset_y, offset_x = offset * perp_y, offset * perp_x
        # Round to the nearest pixel. int() truncation would pull coordinates
        # toward zero, shifting lines on one side of the corridor by a pixel
        # and breaking its symmetry.
        rr, cc = draw.line(
            int(round(y1 + offset_y)), int(round(x1 + offset_x)),
            int(round(y2 + offset_y)), int(round(x2 + offset_x)),
        )
        lines.append(list(zip(rr, cc)))

    return lines

def check_membrane_barrier(lines, membrane_mask, threshold=0.5):
    """
    Check if there's a membrane barrier between two nuclei.

    A line counts as blocked when at least one of its in-bounds points lies
    on membrane. Points outside the image are ignored.

    NOTE: lines rasterised with Bresenham's algorithm are 8-connected, so a
    single line can cross a one-pixel-thick diagonal membrane without sharing
    a pixel with it. Using many parallel lines reduces, but does not remove,
    this risk.

    Parameters:
    -----------
    lines : list
        List of lines in the corridor; each line is a list of (y, x) points
    membrane_mask : ndarray
        Membrane mask indexed as [y, x]; any value > 0 counts as membrane
    threshold : float
        Fraction of lines that must be blocked to consider it a barrier

    Returns:
    --------
    has_barrier : bool
        True if a membrane barrier exists, False otherwise
    blocked_count : int
        Number of lines that are blocked by a membrane
    """
    # Skip empty lines list (could happen if centroids are at the same position)
    if not lines:
        return False, 0

    height, width = membrane_mask.shape[:2]
    blocked_lines = 0

    for line_points in lines:
        points = np.asarray(line_points, dtype=np.intp).reshape(-1, 2)
        ys, xs = points[:, 0], points[:, 1]
        inside = (ys >= 0) & (ys < height) & (xs >= 0) & (xs < width)
        # `> 0` rather than `== 1`, so non-binarized masks (e.g. 0/255) work.
        if np.any(membrane_mask[ys[inside], xs[inside]] > 0):
            blocked_lines += 1

    # Check if enough lines are blocked to consider it a barrier
    has_barrier = (blocked_lines / len(lines)) >= threshold

    return has_barrier, blocked_lines

def merge_nuclei(nuclei_mask, overlapping_pairs, membrane_mask, corridor_width=15, barrier_threshold=0.5):
    """
    Merge nuclei that belong to the same cell based on membrane barriers.

    Pairs are examined in the given order. Two nuclei are merged when the
    line corridor between their centroids is not blocked by membrane.

    Merges are transitive: once A and B are merged, a later barrier-free
    pair (B, C) also pulls C into the same cell. Each resulting cell is
    labeled with the smallest original nucleus ID it contains.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask
    overlapping_pairs : list
        List of overlapping nuclei pairs (id1, id2, centroid1, centroid2)
    membrane_mask : ndarray
        Binary membrane mask
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists

    Returns:
    --------
    merged_mask : ndarray
        Nuclei mask after merging (same shape and dtype as nuclei_mask)
    merge_graph : networkx.Graph
        One node per nucleus ID; one edge per barrier-free pair that merged
    """
    # Create a graph to track merges
    merge_graph = nx.Graph()

    # Unique labels, plus each pixel's index into `labels`. The index is used
    # at the end to relabel the whole mask in a single vectorized pass.
    labels, inverse = np.unique(nuclei_mask, return_inverse=True)

    # Add all nuclei as nodes
    for label in labels:
        if label > 0:  # Skip background
            merge_graph.add_node(label)

    # Union-find forest over nucleus IDs. Only non-root IDs are stored, and
    # every root is the smallest ID of its group.
    parent = {}

    def find_root(nucleus_id):
        path = []
        while nucleus_id in parent:
            path.append(nucleus_id)
            nucleus_id = parent[nucleus_id]
        for node in path:  # path compression
            parent[node] = nucleus_id
        return nucleus_id

    print(f"Processing {len(overlapping_pairs)} overlapping nuclei pairs...")

    for id1, id2, centroid1, centroid2 in tqdm(overlapping_pairs):
        root1, root2 = find_root(id1), find_root(id2)
        if root1 == root2:
            continue  # already part of the same cell

        # Create corridor of lines between the two nuclei's centroids
        lines = create_line_corridor(centroid1, centroid2, corridor_width)

        # Check if there's a membrane barrier
        has_barrier, blocked_count = check_membrane_barrier(
            lines, membrane_mask, barrier_threshold
        )

        # If no barrier, merge the two groups
        if not has_barrier:
            print(f"Merging nuclei {id1} and {id2} (blocked lines: {blocked_count}/{len(lines)})")

            # Always merge the higher root into the lower one
            parent[max(root1, root2)] = min(root1, root2)

            # Add edge in the merge graph
            merge_graph.add_edge(max(id1, id2), min(id1, id2))

    # Map every label to its group's root (background and other non-positive
    # values map to themselves), then apply that mapping to all pixels at once.
    roots = np.array(
        [find_root(label) if label > 0 else label for label in labels],
        dtype=nuclei_mask.dtype,
    )
    merged_mask = roots[inverse].reshape(nuclei_mask.shape)

    return merged_mask, merge_graph

def relabel_mask(mask):
    """
    Relabel a mask to have consecutive IDs.

    Positive labels are renumbered 1..K in ascending order of their original
    value. Zero and any negative values become 0.

    Parameters:
    -----------
    mask : ndarray
        Input mask

    Returns:
    --------
    relabeled_mask : ndarray
        Mask with consecutive IDs (same shape and dtype as ``mask``)
    """
    # Unique labels, plus each pixel's position in that sorted label list.
    labels, inverse = np.unique(mask, return_inverse=True)

    # Lookup table: 0 for non-positive labels, 1..K for positive ones.
    new_ids = np.zeros(labels.shape, dtype=mask.dtype)
    positive = labels > 0
    new_ids[positive] = np.arange(1, np.count_nonzero(positive) + 1)

    # One vectorized gather replaces a full-image comparison per label.
    # reshape() keeps this correct whichever shape np.unique returns `inverse` in.
    return new_ids[inverse].reshape(mask.shape)

def visualize_results(nuclei_mask, membrane_mask, merged_mask, save_path=None):
    """
    Show the original nuclei, the nuclei with the membrane overlaid, and the
    merged result side by side in one 1x3 figure.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Original nuclei mask
    membrane_mask : ndarray
        Membrane mask (drawn in gray at 50% opacity over the nuclei)
    merged_mask : ndarray
        Nuclei mask after merging
    save_path : str, optional
        If given, the figure is also written there at 300 dpi before display
    """
    _, (ax_original, ax_overlay, ax_merged) = plt.subplots(1, 3, figsize=(18, 6))

    ax_original.imshow(nuclei_mask, cmap='nipy_spectral')
    ax_original.set_title('Original Nuclei Mask')

    ax_overlay.imshow(nuclei_mask, cmap='nipy_spectral')
    ax_overlay.imshow(membrane_mask, cmap='gray', alpha=0.5)
    ax_overlay.set_title('Nuclei with Membrane Overlay')

    ax_merged.imshow(merged_mask, cmap='nipy_spectral')
    ax_merged.set_title('Merged Nuclei Mask')

    # Images only: hide ticks and frames on every panel.
    for panel in (ax_original, ax_overlay, ax_merged):
        panel.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def visualize_merge_graph(merge_graph, save_path=None):
    """
    Visualize the merge graph.

    Nodes are coloured by connected component, so nuclei that ended up in the
    same cell share a colour.

    Parameters:
    -----------
    merge_graph : networkx.Graph
        Graph representing merge operations (one node per nucleus)
    save_path : str, optional
        Path to save the visualization
    """
    plt.figure(figsize=(10, 8))

    # Get connected components (each represents one cell; singletons are
    # nuclei that were never merged)
    components = list(nx.connected_components(merge_graph))

    # Assign a different color to each component
    color_map = {
        node: index
        for index, component in enumerate(components)
        for node in component
    }

    # Set node colors
    node_colors = [color_map.get(node, len(components)) for node in merge_graph.nodes()]

    # Only components with more than one node are multi-nucleated cells;
    # counting all components would include every unmerged nucleus.
    n_multinucleated = sum(1 for component in components if len(component) > 1)

    # Draw the graph
    pos = nx.spring_layout(merge_graph, seed=42)
    nx.draw_networkx(
        merge_graph, pos,
        node_color=node_colors,
        cmap=plt.cm.tab20,
        node_size=200,
        with_labels=True
    )

    plt.title(f'Nuclei Merge Graph ({n_multinucleated} cells with multiple nuclei)')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def segment_cells(nuclei_path, membrane_path, proximity_threshold=3,
                  corridor_width=15, barrier_threshold=0.5, output_dir=None, filename_prefix=None):
    """
    Main function to segment cells using nuclei and membrane images.

    Parameters:
    -----------
    nuclei_path : str
        Path to nuclei mask image
    membrane_path : str
        Path to membrane mask image
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists
    output_dir : str, optional
        Directory to save outputs
    filename_prefix : str, optional
        Prefix to use for output filenames

    Returns:
    --------
    merged_mask : ndarray
        Final segmented cell mask

    Raises:
    -------
    ValueError
        If the nuclei and membrane images do not share the same (y, x) size
    """
    print(f"Processing {filename_prefix}...")
    print("Loading images...")
    nuclei_mask, membrane_mask = load_images(nuclei_path, membrane_path)

    # Barrier detection indexes the membrane mask with (y, x) coordinates
    # derived from the nuclei mask, so both must share the same image grid.
    # Otherwise out-of-range points are silently skipped and results are wrong.
    if nuclei_mask.shape[:2] != membrane_mask.shape[:2]:
        raise ValueError(
            f"Nuclei mask shape {nuclei_mask.shape} does not match "
            f"membrane mask shape {membrane_mask.shape}"
        )

    # Count positive labels explicitly; `len(unique) - 1` is off by one when
    # the mask contains no background (0) pixels.
    n_nuclei = int(np.count_nonzero(np.unique(nuclei_mask) > 0))

    print(f"Nuclei mask shape: {nuclei_mask.shape}, unique IDs: {n_nuclei}")
    print(f"Membrane mask shape: {membrane_mask.shape}, values: {np.unique(membrane_mask)}")

    print("Analyzing nuclei properties...")
    nuclei_props = get_nuclei_properties(nuclei_mask)

    print(f"Dilating nuclei with radius {proximity_threshold}...")
    dilated_masks = dilate_nuclei_mask(nuclei_mask, proximity_threshold)

    print("Finding overlapping nuclei...")
    overlapping_pairs = find_overlapping_nuclei(dilated_masks, nuclei_props)
    print(f"Found {len(overlapping_pairs)} potentially overlapping nuclei pairs")

    print("Merging nuclei based on membrane barriers...")
    merged_mask, merge_graph = merge_nuclei(
        nuclei_mask, overlapping_pairs, membrane_mask,
        corridor_width, barrier_threshold
    )

    print("Relabeling mask to have consecutive IDs...")
    final_mask = relabel_mask(merged_mask)
    n_cells = int(np.count_nonzero(np.unique(final_mask) > 0))

    print(f"Original nuclei count: {n_nuclei}")
    print(f"Final cell count: {n_cells}")

    # Save outputs if directory is provided
    if output_dir:
        # Create a subfolder for this image if filename_prefix is provided
        if filename_prefix:
            image_output_dir = os.path.join(output_dir, filename_prefix)
            os.makedirs(image_output_dir, exist_ok=True)
        else:
            image_output_dir = output_dir

        # Save final mask. After relabeling, the largest label equals the cell
        # count; use uint32 when that exceeds uint16 so labels do not wrap.
        save_dtype = np.uint16 if n_cells <= np.iinfo(np.uint16).max else np.uint32
        mask_filename = f"{filename_prefix}_segmented_cells.tif" if filename_prefix else "segmented_cells.tif"
        io.imsave(os.path.join(image_output_dir, mask_filename),
                  final_mask.astype(save_dtype))

        # Save visualizations
        vis_filename = f"{filename_prefix}_segmentation_results.png" if filename_prefix else "segmentation_results.png"
        visualize_results(
            nuclei_mask, membrane_mask, final_mask,
            save_path=os.path.join(image_output_dir, vis_filename)
        )

        graph_filename = f"{filename_prefix}_merge_graph.png" if filename_prefix else "merge_graph.png"
        visualize_merge_graph(
            merge_graph,
            save_path=os.path.join(image_output_dir, graph_filename)
        )
    else:
        # Display visualizations
        visualize_results(nuclei_mask, membrane_mask, final_mask)
        visualize_merge_graph(merge_graph)

    return final_mask

# Function to find matching nuclei and membrane files
def find_file_pairs(nuclei_dir, membrane_dir):
    """
    Find matching pairs of nuclei and membrane files.

    A nuclei file ``<prefix>_Nuclei...`` pairs with a membrane file
    ``<prefix>_membrane...`` that has the same prefix (up to the first
    ``_Nuclei`` / ``_membrane``). Only regular files are considered.

    Files are scanned in sorted order, so the result is deterministic. If
    several files in one directory share a prefix, the first in sorted order
    is used and a warning is printed.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files

    Returns:
    --------
    file_pairs : list
        List of tuples (nuclei_path, membrane_path, common_prefix), sorted by
        prefix
    """
    def index_by_prefix(directory, pattern, kind):
        # Map prefix -> file name for the regular files in `directory`.
        index = {}
        for name in sorted(os.listdir(directory)):
            if not os.path.isfile(os.path.join(directory, name)):
                continue
            match = pattern.match(name)
            if not match:
                continue
            prefix = match.group(1)
            if prefix in index:
                print(f"Warning: multiple {kind} files for prefix '{prefix}'; "
                      f"using '{index[prefix]}', ignoring '{name}'")
                continue
            index[prefix] = name
        return index

    # Prefix runs until the first "_Nuclei" / "_membrane" in the file name.
    nuclei_files = index_by_prefix(nuclei_dir, re.compile(r'(.+?)_Nuclei'), 'nuclei')
    membrane_files = index_by_prefix(membrane_dir, re.compile(r'(.+?)_membrane'), 'membrane')

    common_prefixes = sorted(nuclei_files.keys() & membrane_files.keys())

    return [
        (
            os.path.join(nuclei_dir, nuclei_files[prefix]),
            os.path.join(membrane_dir, membrane_files[prefix]),
            prefix,
        )
        for prefix in common_prefixes
    ]

# Main script to process all matching file pairs
def process_all_pairs(nuclei_dir, membrane_dir, output_dir,
                      proximity_threshold=3, corridor_width=15, barrier_threshold=0.5):
    """
    Process all matching pairs of nuclei and membrane files.

    Best-effort batch: a failure on one pair is reported (with traceback) and
    processing continues with the next pair.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files
    output_dir : str
        Directory to save outputs
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists

    Returns:
    --------
    failures : list
        List of (prefix, error message) tuples for pairs that failed;
        empty if everything succeeded
    """
    import traceback

    # Find all matching file pairs
    file_pairs = find_file_pairs(nuclei_dir, membrane_dir)

    print(f"Found {len(file_pairs)} matching file pairs")

    failures = []
    for nuclei_path, membrane_path, prefix in tqdm(file_pairs):
        try:
            # segment_cells writes this pair's outputs into
            # output_dir/<prefix>/, creating the folder if needed.
            segment_cells(
                nuclei_path, membrane_path,
                proximity_threshold=proximity_threshold,
                corridor_width=corridor_width,
                barrier_threshold=barrier_threshold,
                output_dir=output_dir,
                filename_prefix=prefix
            )
        except Exception as e:
            # Keep going with the remaining pairs, but keep the full traceback
            # so the failure can actually be diagnosed.
            print(f"Error processing {prefix}: {e}")
            traceback.print_exc()
            failures.append((prefix, str(e)))
        else:
            print(f"Completed processing {prefix}")

    if failures:
        failed = ", ".join(prefix for prefix, _ in failures)
        print(f"{len(failures)} of {len(file_pairs)} pairs failed: {failed}")

    return failures

# Batch-run the fusion pipeline over every matched nuclei/membrane pair,
# using the thresholds chosen for this dataset.
process_all_pairs(
    nuclei_dir, membrane_dir, output_dir,
    barrier_threshold=0.6,
    corridor_width=15,
    proximity_threshold=5,
)

print("All cell segmentation completed!")

Output hidden; open in https://colab.research.google.com to view.