In [1]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from skimage import io, measure, segmentation, morphology, draw
import networkx as nx
from tqdm.notebook import tqdm
import os
import re
from google.colab import drive

# Mount Google Drive so the thesis folders below are reachable
drive.mount('/content/drive')

# Root folder of the segmented data for one experimental condition.
# Change this single value to run the pipeline on another condition.
segmented_root = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/1.4Pa-x20'

# Define input and output directories (all inside segmented_root)
nuclei_dir = os.path.join(segmented_root, 'Nuclei_filtered')
membrane_dir = os.path.join(segmented_root, 'Membrane')
output_dir = os.path.join(segmented_root, 'Seed_or')
visualization_dir = os.path.join(segmented_root, 'Seed_res_vis_or')

# Create output directories if they don't exist
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visualization_dir, exist_ok=True)

# Define all the functions from the original code
def load_images(nuclei_path, membrane_path):
    """
    Load nuclei and membrane images.

    Parameters:
    -----------
    nuclei_path : str
        Path to the nuclei mask image
    membrane_path : str
        Path to the membrane mask image

    Returns:
    --------
    nuclei_mask : ndarray
        Labeled nuclei mask, returned exactly as read from disk
        (presumably one unique integer ID per nucleus, 0 = background)
    membrane_mask : ndarray of uint8
        Binary membrane mask: 1 where the input is > 0, 0 elsewhere
    """
    nuclei_mask = io.imread(nuclei_path)
    membrane_mask = io.imread(membrane_path)

    # Always binarise. Binarising only when max() > 1 left fractional masks
    # (e.g. values of 0.5) unconverted, and check_membrane_barrier only
    # recognises membrane pixels that compare equal to 1.
    membrane_mask = (membrane_mask > 0).astype(np.uint8)

    return nuclei_mask, membrane_mask

def get_nuclei_properties(nuclei_mask):
    """
    Measure every labelled nucleus.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask (0 = background)

    Returns:
    --------
    nuclei_props : list
        One ``skimage.measure.regionprops`` entry per non-zero label
    """
    properties = measure.regionprops(nuclei_mask)
    return properties

def dilate_nuclei_mask(nuclei_mask, dilation_radius=3):
    """
    Dilate each nucleus in the mask to detect proximity.

    Each nucleus is dilated ``dilation_radius`` times with SciPy's default
    structuring element (connectivity 1, i.e. a cross in 2-D), so the grown
    region is a diamond of city-block radius ``dilation_radius``.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask (0 = background)
    dilation_radius : int
        Number of dilation iterations. Values <= 0 mean "no dilation".

    Returns:
    --------
    dilated_masks : dict
        Dictionary mapping nucleus ID to a full-size uint8 (0/1) mask of the
        dilated nucleus
    """
    dilated_masks = {}
    for label in np.unique(nuclei_mask):
        if label == 0:  # Skip background
            continue

        binary_mask = nuclei_mask == label

        # binary_dilation treats iterations < 1 as "repeat until nothing
        # changes", which would flood the whole image; skip dilation instead.
        if dilation_radius > 0:
            binary_mask = ndimage.binary_dilation(binary_mask, iterations=dilation_radius)

        dilated_masks[label] = binary_mask.astype(np.uint8)

    return dilated_masks

def find_overlapping_nuclei(dilated_masks, nuclei_props):
    """
    Find pairs of nuclei whose dilated masks overlap.

    A bounding-box test is done first, and the pixel-wise overlap test only
    runs on the intersection of the two boxes. The result is the same as
    comparing the full masks, without scanning the whole image for every
    pair of distant nuclei.

    Parameters:
    -----------
    dilated_masks : dict
        Dictionary of dilated masks (same-shape arrays, non-zero = inside)
        for each nucleus
    nuclei_props : list
        Region properties for each nucleus; only ``.label`` and
        ``.centroid`` are used

    Returns:
    --------
    overlapping_pairs : list
        List of tuples (id1, id2, centroid1, centroid2) for overlapping
        nuclei, sorted by increasing distance between the first two centroid
        coordinates (ties keep discovery order)
    """
    nuclei_ids = list(dilated_masks.keys())
    overlapping_pairs = []

    # Create mapping from nucleus ID to its centroid
    centroids = {prop.label: prop.centroid for prop in nuclei_props}

    # Bounding box (tuple of slices) of each dilated mask, or None if empty
    bboxes = {}
    for nucleus_id in nuclei_ids:
        coords = np.nonzero(dilated_masks[nucleus_id])
        if coords[0].size == 0:
            bboxes[nucleus_id] = None
        else:
            bboxes[nucleus_id] = tuple(slice(int(c.min()), int(c.max()) + 1) for c in coords)

    def box_intersection(box1, box2):
        """Return the slices shared by two boxes, or None if they are disjoint."""
        window = []
        for s1, s2 in zip(box1, box2):
            start, stop = max(s1.start, s2.start), min(s1.stop, s2.stop)
            if start >= stop:
                return None
            window.append(slice(start, stop))
        return tuple(window)

    for i, id1 in enumerate(nuclei_ids):
        if bboxes[id1] is None:
            continue
        for id2 in nuclei_ids[i + 1:]:
            if bboxes[id2] is None:
                continue
            window = box_intersection(bboxes[id1], bboxes[id2])
            if window is None:
                continue  # Boxes are disjoint, so the masks cannot overlap
            # logical_and instead of a uint8 product, which could wrap to 0
            if np.any(np.logical_and(dilated_masks[id1][window], dilated_masks[id2][window])):
                overlapping_pairs.append((
                    id1, id2,
                    centroids[id1], centroids[id2]
                ))

    # Sort by proximity (using centroid distance as an approximation)
    overlapping_pairs.sort(key=lambda x: np.sqrt(
        (x[2][0] - x[3][0])**2 + (x[2][1] - x[3][1])**2
    ))

    return overlapping_pairs

def create_line_corridor(centroid1, centroid2, corridor_width=15):
    """
    Create a corridor of parallel lines between two centroids.

    The centre line joins the two centroids (rounded to the nearest pixel).
    The other lines are shifted perpendicular to it in 1-pixel steps,
    symmetrically about the centre. Exactly ``max(1, corridor_width)`` lines
    are returned, for even widths as well as odd ones.

    Parameters:
    -----------
    centroid1, centroid2 : tuple
        Centroids of two nuclei (y, x)
    corridor_width : int
        Number of parallel lines in the corridor (values < 1 give one line)

    Returns:
    --------
    lines : list
        List of lines, where each line is a list of points (y, x) produced by
        ``skimage.draw.line``; empty if the rounded centroids coincide
    """
    def to_pixel(value):
        # Round half up. int() truncates towards zero, which shifts the
        # corridor off the true centroid-to-centroid segment.
        return int(np.floor(value + 0.5))

    y1, x1 = to_pixel(centroid1[0]), to_pixel(centroid1[1])
    y2, x2 = to_pixel(centroid2[0]), to_pixel(centroid2[1])

    dx, dy = x2 - x1, y2 - y1
    length = np.hypot(dx, dy)

    if length == 0:  # Handle the case where centroids are at the same position
        return []

    # Unit vector perpendicular to the centroid-to-centroid direction
    px, py = -dy / length, dx / length

    # Offsets centred on 0: -7..7 for 15 lines, -1.5..1.5 for 4 lines.
    # A negative width is clamped to one line; an empty corridor would be
    # read downstream as "no barrier".
    n_lines = max(1, int(corridor_width))
    offsets = np.arange(n_lines) - (n_lines - 1) / 2.0

    lines = []
    for offset in offsets:
        offset_x, offset_y = offset * px, offset * py
        rr, cc = draw.line(
            to_pixel(y1 + offset_y), to_pixel(x1 + offset_x),
            to_pixel(y2 + offset_y), to_pixel(x2 + offset_x),
        )
        lines.append(list(zip(rr, cc)))

    return lines

def check_membrane_barrier(lines, membrane_mask, threshold=0.5):
    """
    Decide whether membrane separates two nuclei.

    A line counts as blocked when at least one of its in-bounds points lies
    on a membrane pixel (value equal to 1). Points outside the mask are
    ignored.

    Parameters:
    -----------
    lines : list
        Corridor lines, each a sequence of (y, x) points
    membrane_mask : ndarray
        Binary membrane mask
    threshold : float
        Minimum fraction of blocked lines that counts as a barrier

    Returns:
    --------
    has_barrier : bool
        True if blocked_count / len(lines) >= threshold
    blocked_count : int
        Number of lines that touch the membrane
    """
    # No lines (e.g. coincident centroids) means nothing can block
    if not lines:
        return False, 0

    height, width = membrane_mask.shape[0], membrane_mask.shape[1]

    def line_hits_membrane(points):
        if len(points) == 0:
            return False
        coords = np.asarray(points)
        rows, cols = coords[:, 0], coords[:, 1]
        inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
        return bool(np.any(membrane_mask[rows[inside], cols[inside]] == 1))

    blocked_count = sum(1 for points in lines if line_hits_membrane(points))
    return (blocked_count / len(lines)) >= threshold, blocked_count

def merge_nuclei(nuclei_mask, overlapping_pairs, membrane_mask, corridor_width=15, barrier_threshold=0.5):
    """
    Merge nuclei that belong to the same cell based on membrane barriers.

    Pairs are visited in the given order (closest first when they come from
    find_overlapping_nuclei). For each pair whose nuclei are not already in
    the same group, a corridor of lines between the two centroids is checked
    against the membrane mask. If it is not blocked, the two groups are
    joined. Merging is transitive (A-B and B-C put A, B and C in one cell)
    and does not depend on the order in which such pairs are visited.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask (0 = background)
    overlapping_pairs : list
        List of (id1, id2, centroid1, centroid2) tuples
    membrane_mask : ndarray
        Binary membrane mask
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists

    Returns:
    --------
    merged_mask : ndarray
        Nuclei mask after merging; every group carries its smallest label
    merge_graph : networkx.Graph
        One node per nucleus label, one edge per accepted merge
    """
    merge_graph = nx.Graph()

    # Union-find forest over nucleus labels: following parent[] from any
    # label reaches the representative (smallest) label of its group.
    parent = {}
    for label in np.unique(nuclei_mask):
        if label > 0:  # Skip background
            merge_graph.add_node(label)
            parent[label] = label

    def find_root(label):
        root = label
        while parent[root] != root:
            root = parent[root]
        # Path compression keeps later lookups short
        while parent[label] != root:
            next_label = parent[label]
            parent[label] = root
            label = next_label
        return root

    print(f"Processing {len(overlapping_pairs)} overlapping nuclei pairs...")

    for id1, id2, centroid1, centroid2 in tqdm(overlapping_pairs):
        # Ignore IDs that are not foreground labels of the mask
        if id1 not in parent or id2 not in parent:
            continue

        root1, root2 = find_root(id1), find_root(id2)
        if root1 == root2:
            continue  # Already in the same cell through earlier merges

        # Create corridor of lines between centroids
        lines = create_line_corridor(centroid1, centroid2, corridor_width)

        # Check if there's a membrane barrier
        has_barrier, blocked_count = check_membrane_barrier(
            lines, membrane_mask, barrier_threshold
        )

        # If no barrier, merge the nuclei
        if not has_barrier:
            print(f"Merging nuclei {id1} and {id2} (blocked lines: {blocked_count}/{len(lines)})")

            # Always merge higher ID into lower ID (group-wise: smaller root wins)
            parent[max(root1, root2)] = min(root1, root2)

            # Add edge in the merge graph
            merge_graph.add_edge(max(id1, id2), min(id1, id2))

    # Paint every pixel with its group's representative label in one pass
    unique_labels, inverse = np.unique(nuclei_mask, return_inverse=True)
    lookup = np.array(
        [find_root(label) if label in parent else label for label in unique_labels],
        dtype=nuclei_mask.dtype,
    )
    merged_mask = lookup[inverse].reshape(nuclei_mask.shape)

    return merged_mask, merge_graph

def relabel_mask(mask):
    """
    Relabel a mask to have consecutive IDs.

    Positive labels are mapped, in increasing order, to 1..K. Zero and any
    negative values become 0. This takes a single pass, instead of one
    full-image scan per label.

    Parameters:
    -----------
    mask : ndarray
        Input mask

    Returns:
    --------
    relabeled_mask : ndarray
        Mask with consecutive IDs, same shape and dtype as ``mask``
    """
    # Sorted distinct values plus, for each pixel, its index into them
    unique_ids, inverse = np.unique(mask, return_inverse=True)

    # Positive IDs sit at the end of the sorted array, so their running
    # count gives 1, 2, ..., K; everything else maps to 0
    is_object = unique_ids > 0
    new_ids = np.where(is_object, np.cumsum(is_object), 0)

    return new_ids[inverse].reshape(mask.shape).astype(mask.dtype)

def visualize_results(nuclei_mask, membrane_mask, merged_mask, save_path=None):
    """
    Show the original nuclei, the nuclei with membrane overlay, and the
    merged result side by side.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Original nuclei mask
    membrane_mask : ndarray
        Membrane mask, drawn in grey at 50% opacity over the nuclei
    merged_mask : ndarray
        Nuclei mask after merging
    save_path : str, optional
        If given, the figure is also written there at 300 dpi
    """
    # (title, [(image, imshow keyword arguments), ...]) for each panel
    panels = [
        ('Original Nuclei Mask',
         [(nuclei_mask, {'cmap': 'nipy_spectral'})]),
        ('Nuclei with Membrane Overlay',
         [(nuclei_mask, {'cmap': 'nipy_spectral'}),
          (membrane_mask, {'cmap': 'gray', 'alpha': 0.5})]),
        ('Merged Nuclei Mask',
         [(merged_mask, {'cmap': 'nipy_spectral'})]),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (title, layers) in zip(axes, panels):
        for image, imshow_kwargs in layers:
            ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def visualize_merge_graph(merge_graph, save_path=None):
    """
    Visualize the merge graph.

    Every nucleus is a node and every accepted merge is an edge. Nodes are
    coloured by connected component, i.e. by final cell.

    Parameters:
    -----------
    merge_graph : networkx.Graph
        Graph representing merge operations
    save_path : str, optional
        Path to save the visualization
    """
    plt.figure(figsize=(10, 8))

    # Each connected component is one final cell (singletons included)
    components = list(nx.connected_components(merge_graph))

    # Colour index of a node = index of its component; every node belongs
    # to exactly one component, so no default is needed
    color_map = {node: i for i, component in enumerate(components) for node in component}
    node_colors = [color_map[node] for node in merge_graph.nodes()]

    # Only components with more than one node are cells whose nuclei were
    # merged; counting all components would include every untouched nucleus
    n_multinucleated = sum(1 for component in components if len(component) > 1)

    # Draw the graph
    pos = nx.spring_layout(merge_graph, seed=42)
    nx.draw_networkx(
        merge_graph, pos,
        node_color=node_colors,
        cmap=plt.cm.tab20,
        node_size=200,
        with_labels=True
    )

    plt.title(f'Nuclei Merge Graph ({n_multinucleated} cells with multiple nuclei)')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def _count_labels(mask):
    """Return the number of distinct non-zero values (object labels) in *mask*."""
    return int(np.count_nonzero(np.unique(mask)))


def segment_cells(nuclei_path, membrane_path, proximity_threshold=3,
                 corridor_width=15, barrier_threshold=0.5, output_dir=None,
                 visualization_dir=None, filename_prefix=None):
    """
    Main function to segment cells using nuclei and membrane images.

    Pipeline:
    1. Load both masks.
    2. Find nuclei whose dilated masks touch.
    3. Merge touching nuclei that are not separated by membrane.
    4. Relabel the result to 1..N.
    5. Save and/or display the results.

    Parameters:
    -----------
    nuclei_path : str
        Path to nuclei mask image
    membrane_path : str
        Path to membrane mask image
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists
    output_dir : str, optional
        Directory to save output TIF files
    visualization_dir : str, optional
        Directory to save visualization PNG files
    filename_prefix : str, optional
        Prefix to use for output filenames

    Returns:
    --------
    final_mask : ndarray
        Final segmented cell mask (labels 1..N, background 0)

    Raises:
    -------
    ValueError
        If the nuclei and membrane masks do not have the same shape
    """
    print(f"Processing {filename_prefix}...")
    print("Loading images...")
    nuclei_mask, membrane_mask = load_images(nuclei_path, membrane_path)

    print(f"Nuclei mask shape: {nuclei_mask.shape}, unique IDs: {_count_labels(nuclei_mask)}")
    print(f"Membrane mask shape: {membrane_mask.shape}, values: {np.unique(membrane_mask)}")

    # The barrier test indexes the membrane mask with nuclei coordinates, so
    # both masks must cover the same pixel grid.
    if nuclei_mask.shape != membrane_mask.shape:
        raise ValueError(
            f"Nuclei mask shape {nuclei_mask.shape} does not match "
            f"membrane mask shape {membrane_mask.shape}"
        )

    print("Analyzing nuclei properties...")
    nuclei_props = get_nuclei_properties(nuclei_mask)

    print(f"Dilating nuclei with radius {proximity_threshold}...")
    dilated_masks = dilate_nuclei_mask(nuclei_mask, proximity_threshold)

    print("Finding overlapping nuclei...")
    overlapping_pairs = find_overlapping_nuclei(dilated_masks, nuclei_props)
    print(f"Found {len(overlapping_pairs)} potentially overlapping nuclei pairs")

    print("Merging nuclei based on membrane barriers...")
    merged_mask, merge_graph = merge_nuclei(
        nuclei_mask, overlapping_pairs, membrane_mask,
        corridor_width, barrier_threshold
    )

    print("Relabeling mask to have consecutive IDs...")
    final_mask = relabel_mask(merged_mask)

    print(f"Original nuclei count: {_count_labels(nuclei_mask)}")
    print(f"Final cell count: {_count_labels(final_mask)}")

    # Save outputs if directory is provided
    if output_dir:
        mask_filename = f"{filename_prefix}_segmented_cells.tif" if filename_prefix else "segmented_cells.tif"
        mask_path = os.path.join(output_dir, mask_filename)
        # uint16 holds labels up to 65535; switch to uint32 rather than let
        # larger labels wrap around
        save_dtype = np.uint16 if final_mask.max(initial=0) <= np.iinfo(np.uint16).max else np.uint32
        io.imsave(mask_path, final_mask.astype(save_dtype))
        print(f"Saved TIF file to: {mask_path}")

    # Save visualizations if directory is provided
    if visualization_dir:
        # Create a subfolder for this image if filename_prefix is provided
        if filename_prefix:
            image_vis_dir = os.path.join(visualization_dir, filename_prefix)
            os.makedirs(image_vis_dir, exist_ok=True)
        else:
            image_vis_dir = visualization_dir

        vis_filename = f"{filename_prefix}_segmentation_results.png" if filename_prefix else "segmentation_results.png"
        vis_path = os.path.join(image_vis_dir, vis_filename)
        visualize_results(nuclei_mask, membrane_mask, final_mask, save_path=vis_path)
        print(f"Saved visualization to: {vis_path}")

        graph_filename = f"{filename_prefix}_merge_graph.png" if filename_prefix else "merge_graph.png"
        graph_path = os.path.join(image_vis_dir, graph_filename)
        visualize_merge_graph(merge_graph, save_path=graph_path)
        print(f"Saved merge graph to: {graph_path}")
    else:
        # Display visualizations
        visualize_results(nuclei_mask, membrane_mask, final_mask)
        visualize_merge_graph(merge_graph)

    return final_mask

# Function to find matching nuclei and membrane files
def _prefix_before_part(filename, marker):
    """Return the '_'-joined parts of *filename* before the exact part *marker*, or None."""
    parts = filename.split('_')
    if marker not in parts:
        return None
    return "_".join(parts[:parts.index(marker)])


def _pairs_by_marker(nuclei_dir, membrane_dir, nuclei_files, membrane_files):
    """Pair files whose names agree up to the exact 'Nuclei' / 'Cadherins' part."""
    membrane_prefixes = [(name, _prefix_before_part(name, "Cadherins")) for name in membrane_files]
    pairs = []
    for nuclei_file in nuclei_files:
        prefix = _prefix_before_part(nuclei_file, "Nuclei")
        if prefix is None:
            continue
        for membrane_file, membrane_prefix in membrane_prefixes:
            if membrane_prefix == prefix:
                pairs.append((os.path.join(nuclei_dir, nuclei_file),
                              os.path.join(membrane_dir, membrane_file),
                              prefix))
    return pairs


def _pairs_by_leading_parts(nuclei_dir, membrane_dir, nuclei_files, membrane_files, prefix_length=8):
    """Pair Nuclei/Cadherins files whose first *prefix_length* '_'-parts are identical."""
    def leading(name):
        parts = name.split('_')
        return "_".join(parts[:prefix_length]) if len(parts) >= prefix_length else None

    pairs = []
    for nuclei_file in nuclei_files:
        if "Nuclei" not in nuclei_file:
            continue
        prefix = leading(nuclei_file)
        if prefix is None:
            continue
        for membrane_file in membrane_files:
            if "Cadherins" in membrane_file and leading(membrane_file) == prefix:
                pairs.append((os.path.join(nuclei_dir, nuclei_file),
                              os.path.join(membrane_dir, membrane_file),
                              prefix))
    return pairs


def _pairs_by_sequence(nuclei_dir, membrane_dir, nuclei_files, membrane_files):
    """Pair each nuclei file with the first membrane file carrying the same 'seqNNN' token."""
    pairs = []
    for nuclei_file in nuclei_files:
        seq_match = re.search(r'seq(\d+)', nuclei_file)
        if not seq_match:
            continue
        seq_num = seq_match.group(0)  # e.g., "seq018"
        # (?!\d) stops "seq01" from matching inside "seq018"
        seq_token = re.compile(re.escape(seq_num) + r'(?!\d)')
        for membrane_file in membrane_files:
            if seq_token.search(membrane_file):
                pairs.append((os.path.join(nuclei_dir, nuclei_file),
                              os.path.join(membrane_dir, membrane_file),
                              f"sequence_{seq_num}"))
                break  # Found a match for this nuclei file
    return pairs


def _pairs_in_order(nuclei_dir, membrane_dir, nuclei_files, membrane_files):
    """Pair the i-th nuclei file with the i-th membrane file, in the given list order."""
    pairs = []
    for nuclei_file, membrane_file in zip(nuclei_files, membrane_files):
        # Prefix = file stem up to "_Nuclei" (the whole stem if absent)
        prefix = os.path.splitext(nuclei_file)[0].split('_Nuclei')[0]
        pairs.append((os.path.join(nuclei_dir, nuclei_file),
                      os.path.join(membrane_dir, membrane_file),
                      prefix))
    return pairs


def find_file_pairs(nuclei_dir, membrane_dir):
    """
    Find matching pairs of nuclei and membrane files.

    Strategies are tried in order until one yields pairs:

    1. Identical name prefix before the exact "Nuclei" / "Cadherins" part.
       Example names:
       "denoised_1.4Pa_A1_20dec21_20xA_L2RA_FlatA_seq018_Nuclei_regional_tophat_filtered_mask"
       "denoised_1.4Pa_A1_20dec21_20xA_L2RA_FlatA_seq018_Cadherins_regional_tophat_cadherins_mask_cleaned"
    2. Identical first 8 '_'-separated parts.
    3. Same "seqNNN" token.
    4. Last resort: positional pairing of the name-sorted file lists.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files

    Returns:
    --------
    file_pairs : list
        List of tuples (nuclei_path, membrane_path, common_prefix)
    """
    # Sort so every strategy, and especially the positional fallback, is
    # deterministic; os.listdir() returns names in arbitrary order.
    nuclei_files = sorted(f for f in os.listdir(nuclei_dir) if f.endswith('.tif'))
    membrane_files = sorted(f for f in os.listdir(membrane_dir) if f.endswith('.tif'))

    print(f"Found {len(nuclei_files)} .tif files in nuclei directory")
    print(f"Found {len(membrane_files)} .tif files in membrane directory")

    # Print some example filenames to help with debugging
    if nuclei_files:
        print("Example nuclei files:")
        for f in nuclei_files[:3]:
            print(f"  - {f}")

    if membrane_files:
        print("Example membrane files:")
        for f in membrane_files[:3]:
            print(f"  - {f}")

    file_pairs = _pairs_by_marker(nuclei_dir, membrane_dir, nuclei_files, membrane_files)

    if not file_pairs:
        print("No matches found with exact part matching. Trying a more flexible approach...")
        file_pairs = _pairs_by_leading_parts(nuclei_dir, membrane_dir, nuclei_files, membrane_files)

    print(f"Found {len(file_pairs)} matching file pairs")

    if not file_pairs:
        print("Trying to match based on sequence numbers...")
        file_pairs = _pairs_by_sequence(nuclei_dir, membrane_dir, nuclei_files, membrane_files)

    print(f"Final count: {len(file_pairs)} matching file pairs")

    if not file_pairs and nuclei_files and membrane_files:
        print("No automatic matches found. Using manual pairing...")
        print("WARNING: files are paired purely by sorted name order; check the pairs before trusting the results.")
        file_pairs = _pairs_in_order(nuclei_dir, membrane_dir, nuclei_files, membrane_files)
        print(f"Created {len(file_pairs)} pairs by sequential matching")

    return file_pairs

# Main script to process all matching file pairs
def process_all_pairs(nuclei_dir, membrane_dir, output_dir, visualization_dir,
                     proximity_threshold=3, corridor_width=15, barrier_threshold=0.5):
    """
    Process all matching pairs of nuclei and membrane files.

    A failure on one pair is reported (message plus traceback) and does not
    stop the remaining pairs from being processed.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files
    output_dir : str
        Directory to save TIF output files
    visualization_dir : str
        Directory to save visualization PNG files
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists
    """
    import traceback  # For detailed per-pair error reports

    # Create output directories
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(visualization_dir, exist_ok=True)

    # Print the directories to verify paths
    print(f"Nuclei directory: {nuclei_dir}")
    print(f"Membrane directory: {membrane_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Visualization directory: {visualization_dir}")

    # Find all matching file pairs
    file_pairs = find_file_pairs(nuclei_dir, membrane_dir)

    if not file_pairs:
        print("No file pairs found. Please check the file naming patterns and directories.")

        # The membrane dir should not be inside the nuclei dir. Compare whole
        # path components: str.startswith() would also flag siblings that
        # merely share a name prefix (".../Nuclei" vs ".../Nuclei_filtered").
        nuclei_abs = os.path.abspath(nuclei_dir)
        try:
            nested = os.path.commonpath([nuclei_abs, os.path.abspath(membrane_dir)]) == nuclei_abs
        except ValueError:  # e.g. paths on different Windows drives
            nested = False

        if nested:
            print("WARNING: Membrane directory appears to be a subdirectory of the nuclei directory.")
            print("This might cause issues with file matching.")

            # Suggest a possible solution
            suggested_membrane_dir = os.path.join(os.path.dirname(nuclei_dir), "Cadherins")
            print(f"Consider using a parallel directory structure, e.g.: {suggested_membrane_dir}")

        return

    # Process each pair; keep going if a single pair fails
    for nuclei_path, membrane_path, prefix in tqdm(file_pairs):
        try:
            segment_cells(
                nuclei_path, membrane_path,
                proximity_threshold=proximity_threshold,
                corridor_width=corridor_width,
                barrier_threshold=barrier_threshold,
                output_dir=output_dir,
                visualization_dir=visualization_dir,
                filename_prefix=prefix
            )
            print(f"Completed processing {prefix}")
        except Exception as e:
            print(f"Error processing {prefix}: {e}")
            traceback.print_exc()

# Run the main processing function with the directory paths defined above.
# First report whether the membrane directory sits inside the nuclei directory.
print("Original directory settings:")
print(f"Nuclei directory: {nuclei_dir}")
print(f"Membrane directory: {membrane_dir}")

# Compare whole path components; str.startswith() would also match sibling
# folders that merely share a name prefix.
_nuclei_abs = os.path.abspath(nuclei_dir)
if os.path.commonpath([_nuclei_abs, os.path.abspath(membrane_dir)]) == _nuclei_abs:
    print("\nWARNING: Membrane directory appears to be a subdirectory of nuclei directory.")
    print("This can cause issues with file discovery.")
    print("You have two options:")

    print("\nOption 1: Continue with current directory structure")
    print(f"This will look for Cadherins files inside {membrane_dir}.")

    # Option 2: Try a parallel directory structure
    suggested_membrane_dir = os.path.join(os.path.dirname(nuclei_dir), "Cadherins")
    print(f"\nOption 2: Use a parallel directory structure: {suggested_membrane_dir}")
    print("This would look for files in a Cadherins directory next to Nuclei_filtered.")

    # Non-interactive notebook cell: proceed with Option 1
    print("\nSince this is in a code file, we'll proceed with Option 1, but you can modify the paths if needed.")

process_all_pairs(
    nuclei_dir=nuclei_dir,
    membrane_dir=membrane_dir,
    output_dir=output_dir,
    visualization_dir=visualization_dir,
    proximity_threshold=12,
    corridor_width=15,
    barrier_threshold=0.6
)


# Main script to process all matching file pairs.
# NOTE(review): this re-defines process_all_pairs after it has already been
# run above; it shadows the earlier definition for later cells.
def process_all_pairs(nuclei_dir, membrane_dir, output_dir, visualization_dir,
                     proximity_threshold=3, corridor_width=15, barrier_threshold=0.5):
    """
    Process all matching pairs of nuclei and membrane files.

    A failure on one pair is reported (message plus traceback) and does not
    stop the remaining pairs from being processed.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files
    output_dir : str
        Directory to save TIF output files
    visualization_dir : str
        Directory to save visualization PNG files
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists
    """
    import traceback  # For detailed per-pair error reports

    # Create output directories
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(visualization_dir, exist_ok=True)

    # Print the directories to verify paths
    print(f"Nuclei directory: {nuclei_dir}")
    print(f"Membrane directory: {membrane_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Visualization directory: {visualization_dir}")

    # Find all matching file pairs
    file_pairs = find_file_pairs(nuclei_dir, membrane_dir)

    if not file_pairs:
        print("No file pairs found. Please check the file naming patterns and directories.")

        # The membrane dir should not be inside the nuclei dir. Compare whole
        # path components: str.startswith() would also flag siblings that
        # merely share a name prefix (".../Nuclei" vs ".../Nuclei_filtered").
        nuclei_abs = os.path.abspath(nuclei_dir)
        try:
            nested = os.path.commonpath([nuclei_abs, os.path.abspath(membrane_dir)]) == nuclei_abs
        except ValueError:  # e.g. paths on different Windows drives
            nested = False

        if nested:
            print("WARNING: Membrane directory appears to be a subdirectory of the nuclei directory.")
            print("This might cause issues with file matching.")

            # Suggest a possible solution
            suggested_membrane_dir = os.path.join(os.path.dirname(nuclei_dir), "Cadherins")
            print(f"Consider using a parallel directory structure, e.g.: {suggested_membrane_dir}")

        return

    # Process each pair; keep going if a single pair fails
    for nuclei_path, membrane_path, prefix in tqdm(file_pairs):
        try:
            segment_cells(
                nuclei_path, membrane_path,
                proximity_threshold=proximity_threshold,
                corridor_width=corridor_width,
                barrier_threshold=barrier_threshold,
                output_dir=output_dir,
                visualization_dir=visualization_dir,
                filename_prefix=prefix
            )
            print(f"Completed processing {prefix}")
        except Exception as e:
            print(f"Error processing {prefix}: {e}")
            traceback.print_exc()


Output hidden; open in https://colab.research.google.com to view.

This version also incorporates the Golgi channel. In our case, however, the Golgi can be left out, because we have removed the Golgi and increased the distance thresholds instead.

## Not really

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from skimage import io, measure, segmentation, morphology, draw
import networkx as nx
from tqdm.notebook import tqdm
import os
import re
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Define input and output directories
nuclei_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/1.4Pa-x20/Nuclei_filtered'
membrane_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/1.4Pa-x20/Cadherins'
golgi_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/1.4Pa-x20/Golgi'
output_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/1.4Pa-x20/Seed'
visualization_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/1.4Pa-x20/Seed_res_vis'

# Create output directories if they don't exist
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visualization_dir, exist_ok=True)

# Define all the functions
def load_images(nuclei_path, membrane_path, golgi_path=None):
    """
    Load nuclei, membrane, and optionally Golgi images.

    Parameters:
    -----------
    nuclei_path : str
        Path to the nuclei mask image
    membrane_path : str
        Path to the membrane mask image
    golgi_path : str, optional
        Path to the Golgi mask image

    Returns:
    --------
    nuclei_mask : ndarray
        Labeled nuclei mask, returned exactly as read from disk
    membrane_mask : ndarray of uint8
        Binary membrane mask: 1 where the input is > 0, 0 elsewhere
    golgi_mask : ndarray or None
        Golgi mask as read from disk, or None if no Golgi path was given
    """
    nuclei_mask = io.imread(nuclei_path)

    # Always binarise. Binarising only when max() > 1 left fractional masks
    # (e.g. values of 0.5) unconverted, so they were not treated as membrane.
    membrane_mask = (io.imread(membrane_path) > 0).astype(np.uint8)

    # Load Golgi mask if provided
    golgi_mask = io.imread(golgi_path) if golgi_path else None

    return nuclei_mask, membrane_mask, golgi_mask

def get_nuclei_properties(nuclei_mask):
    """
    Measure every labeled region of a nuclei mask.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask (0 = background)

    Returns:
    --------
    nuclei_props : list
        Output of ``skimage.measure.regionprops`` — one entry per label,
        exposing at least ``label`` and ``centroid`` used elsewhere
    """
    region_list = measure.regionprops(nuclei_mask)
    return region_list

def dilate_nuclei_mask(nuclei_mask, dilation_radius=3):
    """
    Dilate each nucleus in the mask to detect proximity.

    Each nucleus is dilated with scipy's default (cross-shaped, connectivity-1)
    structuring element, repeated ``dilation_radius`` times, so the mask grows
    by up to ``dilation_radius`` pixels in city-block distance.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask (0 = background)
    dilation_radius : int
        Number of dilation iterations. Values below 1 disable dilation and
        return each nucleus' own pixels (scipy would otherwise keep dilating
        until the mask stops changing, flooding the whole background).

    Returns:
    --------
    dilated_masks : dict
        Mapping nucleus ID -> full-size uint8 0/1 mask of the (dilated) nucleus
    """
    dilated_masks = {}
    for nucleus_id in np.unique(nuclei_mask):
        if nucleus_id == 0:  # Skip background
            continue

        binary_mask = (nuclei_mask == nucleus_id).astype(np.uint8)

        # Guard: ndimage treats iterations < 1 as "repeat until convergence",
        # which would make every nucleus touch every other one.
        if dilation_radius >= 1:
            binary_mask = ndimage.binary_dilation(
                binary_mask,
                iterations=dilation_radius
            ).astype(np.uint8)

        dilated_masks[nucleus_id] = binary_mask

    return dilated_masks

def _intersect_boxes(box_a, box_b):
    """Return the overlapping window of two slice tuples, or None if they are disjoint."""
    window = []
    for slice_a, slice_b in zip(box_a, box_b):
        start = max(slice_a.start, slice_b.start)
        stop = min(slice_a.stop, slice_b.stop)
        if start >= stop:
            return None
        window.append(slice(start, stop))
    return tuple(window)


def find_overlapping_nuclei(dilated_masks, nuclei_props):
    """
    Find pairs of nuclei whose dilated masks overlap.

    Each dilated mask's bounding box is computed once; a pair is only compared
    pixel-wise inside the intersection of the two boxes, and pairs whose boxes
    do not intersect are skipped. This gives the same pairs as comparing the
    full images, without touching every pixel for every pair.

    Parameters:
    -----------
    dilated_masks : dict
        Dictionary of dilated masks for each nucleus (nonzero = covered)
    nuclei_props : list
        Region properties exposing ``label`` and ``centroid`` for each nucleus

    Returns:
    --------
    overlapping_pairs : list
        List of tuples (id1, id2, centroid1, centroid2) for overlapping nuclei,
        sorted (stably) by Euclidean centroid distance, closest first
    """
    nuclei_ids = list(dilated_masks.keys())
    centroids = {prop.label: prop.centroid for prop in nuclei_props}

    # Bounding box (tuple of slices) of each dilated mask; None when empty
    boxes = {}
    for nucleus_id in nuclei_ids:
        found = ndimage.find_objects(
            np.asarray(dilated_masks[nucleus_id] != 0, dtype=np.uint8)
        )
        boxes[nucleus_id] = found[0] if found else None

    overlapping_pairs = []
    for i, id1 in enumerate(nuclei_ids):
        box1 = boxes[id1]
        if box1 is None:
            continue
        for id2 in nuclei_ids[i + 1:]:
            box2 = boxes[id2]
            if box2 is None:
                continue
            window = _intersect_boxes(box1, box2)
            if window is None:
                continue
            # logical_and avoids the overflow an integer product could hit
            if np.any(np.logical_and(dilated_masks[id1][window],
                                     dilated_masks[id2][window])):
                overlapping_pairs.append((
                    id1, id2,
                    centroids[id1], centroids[id2]
                ))

    # Sort by proximity (using centroid distance as an approximation)
    overlapping_pairs.sort(key=lambda x: np.sqrt(
        (x[2][0] - x[3][0])**2 + (x[2][1] - x[3][1])**2
    ))

    return overlapping_pairs

def create_line_corridor(centroid1, centroid2, corridor_width=15):
    """
    Rasterise a bundle of parallel segments joining two centroids.

    Both centroids are truncated to integer pixel coordinates. Segments are
    offset along the unit perpendicular by every integer in
    ``[-(corridor_width // 2), corridor_width // 2]``, so an odd width yields
    exactly ``corridor_width`` segments and an even width yields one more.

    Parameters:
    -----------
    centroid1, centroid2 : tuple
        Centroids of two nuclei as (row, col)
    corridor_width : int
        Nominal number of parallel segments (see note above for even values)

    Returns:
    --------
    lines : list
        One list of (row, col) pixel tuples per segment, traced with
        ``skimage.draw.line``; empty when both centroids truncate to the same pixel
    """
    r0, c0 = int(centroid1[0]), int(centroid1[1])
    r1, c1 = int(centroid2[0]), int(centroid2[1])

    step_c, step_r = c1 - c0, r1 - r0
    span = np.sqrt(step_c**2 + step_r**2)
    if span == 0:
        return []

    unit_c, unit_r = step_c / span, step_r / span
    # Rotate the unit direction by 90 degrees to get the perpendicular
    perp_c, perp_r = -unit_r, unit_c

    reach = corridor_width // 2
    corridor = []
    for k in range(-reach, reach + 1):
        shift_r, shift_c = k * perp_r, k * perp_c
        rr, cc = draw.line(
            int(r0 + shift_r), int(c0 + shift_c),
            int(r1 + shift_r), int(c1 + shift_c),
        )
        corridor.append(list(zip(rr, cc)))
    return corridor

def check_membrane_barrier(lines, membrane_mask, threshold=0.5):
    """
    Decide whether the membrane separates two nuclei along a line corridor.

    A line is counted as blocked when any of its points that fall inside the
    image equals 1 in ``membrane_mask``; points outside the image are ignored.

    Parameters:
    -----------
    lines : list
        Corridor lines, each a list of (row, col) points
    membrane_mask : ndarray
        Binary membrane mask
    threshold : float
        Minimum fraction of blocked lines for the corridor to count as a barrier

    Returns:
    --------
    has_barrier : bool
        True when blocked_count / len(lines) >= threshold; False for no lines
    blocked_count : int
        Number of blocked lines
    """
    if not lines:
        return False, 0

    n_rows, n_cols = membrane_mask.shape[0], membrane_mask.shape[1]

    def _crosses_membrane(points):
        return any(
            0 <= y < n_rows and 0 <= x < n_cols and membrane_mask[y, x] == 1
            for y, x in points
        )

    blocked_count = sum(1 for points in lines if _crosses_membrane(points))
    return blocked_count / len(lines) >= threshold, blocked_count

def check_golgi_linkage(nuclei_mask, golgi_mask, id1, id2):
    """
    Report whether one Golgi object touches both of two nuclei.

    Each nucleus is grown by two binary-dilation iterations; the Golgi labels
    (excluding 0) found under each grown nucleus are compared.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask
    golgi_mask : ndarray or None
        Labeled Golgi mask; None disables the check
    id1, id2 : int
        Labels of the two nuclei

    Returns:
    --------
    linked : bool
        True if at least one nonzero Golgi label appears near both nuclei;
        always False when golgi_mask is None
    """
    if golgi_mask is None:
        return False

    def _touching_golgi(nucleus_id):
        # A small halo catches Golgi lying directly next to the nucleus
        halo = ndimage.binary_dilation(nuclei_mask == nucleus_id, iterations=2)
        return set(np.unique(golgi_mask[halo])) - {0}

    return bool(_touching_golgi(id1) & _touching_golgi(id2))

def merge_nuclei(nuclei_mask, overlapping_pairs, membrane_mask, golgi_mask=None,
                 corridor_width=15, barrier_threshold=0.5):
    """
    Merge nuclei that belong to the same cell based on membrane barriers and Golgi linkage.

    Pairs are processed in the given order. Each nucleus label is tracked in a
    union-find forest whose root is always the smallest label in its group, so
    a nucleus that was already merged into another one still takes part in
    later pairs through its group (merges are transitive and do not depend on
    which label happened to survive). A pair is merged when its corridor has no
    membrane barrier OR the two nuclei share a Golgi object; pairs already in
    the same group are skipped.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Labeled nuclei mask
    overlapping_pairs : list
        List of (id1, id2, centroid1, centroid2) tuples
    membrane_mask : ndarray
        Binary membrane mask
    golgi_mask : ndarray, optional
        Labeled Golgi mask
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists

    Returns:
    --------
    merged_mask : ndarray
        Copy of nuclei_mask where every nucleus carries the smallest label of
        its merged group
    merge_graph : networkx.Graph
        One node per nucleus label; one edge per merged nucleus pair
    """
    labels = [label for label in np.unique(nuclei_mask) if label > 0]

    merge_graph = nx.Graph()
    merge_graph.add_nodes_from(labels)

    # Union-find parent pointers; roots are the minimum label of their group
    parent = {label: label for label in labels}

    def _find(label):
        root = label
        while parent[root] != root:
            root = parent[root]
        # Path compression keeps later lookups short
        while parent[label] != root:
            parent[label], label = root, parent[label]
        return root

    print(f"Processing {len(overlapping_pairs)} overlapping nuclei pairs...")

    for id1, id2, centroid1, centroid2 in tqdm(overlapping_pairs):
        # Ignore IDs that are not present in the mask at all
        if id1 not in parent or id2 not in parent:
            continue

        root1, root2 = _find(id1), _find(id2)
        if root1 == root2:
            continue  # Already part of the same cell

        # Evidence is always taken between the two original nuclei
        lines = create_line_corridor(centroid1, centroid2, corridor_width)
        has_barrier, blocked_count = check_membrane_barrier(
            lines, membrane_mask, barrier_threshold
        )
        golgi_linked = check_golgi_linkage(nuclei_mask, golgi_mask, id1, id2)

        # Merge if EITHER there's no membrane barrier OR they're linked by Golgi
        if has_barrier and not golgi_linked:
            continue

        merge_reason = "no membrane barrier" if not has_barrier else "Golgi linkage"
        if golgi_linked and not has_barrier:
            merge_reason = "no membrane barrier AND Golgi linkage"

        print(f"Merging nuclei {id1} and {id2} due to {merge_reason}")
        if not has_barrier:
            print(f"  - Blocked lines: {blocked_count}/{len(lines)}")

        # Always merge the higher group into the lower one
        parent[max(root1, root2)] = min(root1, root2)
        merge_graph.add_edge(max(id1, id2), min(id1, id2))

    # Paint the final labels once, instead of rewriting the mask per merge
    merged_mask = nuclei_mask.copy()
    for label in labels:
        root = _find(label)
        if root != label:
            merged_mask[nuclei_mask == label] = root

    return merged_mask, merge_graph

def relabel_mask(mask):
    """
    Relabel a mask to have consecutive IDs.

    Positive labels are mapped, in ascending order, to 1..K; all other pixels
    (0, negative values, NaN) become 0. The result keeps the input dtype and
    shape. The mapping is vectorised with ``np.searchsorted``, so the cost is
    independent of the number of labels.

    Parameters:
    -----------
    mask : ndarray
        Input mask

    Returns:
    --------
    relabeled_mask : ndarray
        Mask with consecutive IDs
    """
    unique_ids = np.unique(mask)
    unique_ids = unique_ids[unique_ids > 0]

    relabeled_mask = np.zeros_like(mask)
    if unique_ids.size:
        foreground = mask > 0
        # unique_ids is sorted, so each label's index is its rank (0-based)
        relabeled_mask[foreground] = np.searchsorted(unique_ids, mask[foreground]) + 1

    return relabeled_mask

def visualize_results(nuclei_mask, membrane_mask, merged_mask, save_path=None):
    """
    Visualize the segmentation results.

    Draws three panels (original nuclei, nuclei with membrane overlay, merged
    mask), optionally saves them, shows the figure and then closes it so that
    batch runs on non-inline backends do not accumulate open figures.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Original nuclei mask
    membrane_mask : ndarray
        Membrane mask
    merged_mask : ndarray
        Nuclei mask after merging
    save_path : str, optional
        Path to save the visualization (PNG at 300 dpi)
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Plot original nuclei mask
    axes[0].imshow(nuclei_mask, cmap='nipy_spectral')
    axes[0].set_title('Original Nuclei Mask')
    axes[0].axis('off')

    # Plot membrane mask overlay
    axes[1].imshow(nuclei_mask, cmap='nipy_spectral')
    axes[1].imshow(membrane_mask, cmap='gray', alpha=0.5)
    axes[1].set_title('Nuclei with Membrane Overlay')
    axes[1].axis('off')

    # Plot merged mask
    axes[2].imshow(merged_mask, cmap='nipy_spectral')
    axes[2].set_title('Merged Nuclei Mask')
    axes[2].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure; otherwise every processed image keeps one alive
    plt.close(fig)

def visualize_results_with_golgi(nuclei_mask, membrane_mask, golgi_mask, merged_mask, save_path=None):
    """
    Visualize the segmentation results including Golgi.

    Draws a 2x2 grid (original nuclei, membrane overlay, Golgi overlay, merged
    mask), optionally saves it, shows the figure and then closes it so that
    batch runs on non-inline backends do not accumulate open figures.

    Parameters:
    -----------
    nuclei_mask : ndarray
        Original nuclei mask
    membrane_mask : ndarray
        Membrane mask
    golgi_mask : ndarray
        Golgi mask (any nonzero pixel is drawn)
    merged_mask : ndarray
        Nuclei mask after merging
    save_path : str, optional
        Path to save the visualization (PNG at 300 dpi)
    """
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    axes = axes.flatten()

    # Plot original nuclei mask
    axes[0].imshow(nuclei_mask, cmap='nipy_spectral')
    axes[0].set_title('Original Nuclei Mask')
    axes[0].axis('off')

    # Plot membrane mask overlay
    axes[1].imshow(nuclei_mask, cmap='nipy_spectral')
    axes[1].imshow(membrane_mask, cmap='gray', alpha=0.5)
    axes[1].set_title('Nuclei with Membrane Overlay')
    axes[1].axis('off')

    # Plot Golgi mask overlay
    axes[2].imshow(nuclei_mask, cmap='nipy_spectral')
    axes[2].imshow(golgi_mask > 0, cmap='Greens', alpha=0.6)
    axes[2].set_title('Nuclei with Golgi Overlay')
    axes[2].axis('off')

    # Plot merged mask
    axes[3].imshow(merged_mask, cmap='nipy_spectral')
    axes[3].set_title('Merged Nuclei Mask')
    axes[3].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure; otherwise every processed image keeps one alive
    plt.close(fig)

def visualize_merge_graph(merge_graph, save_path=None):
    """
    Visualize the merge graph.

    Every connected component is one final cell; nodes are coloured per
    component. The title reports how many components contain more than one
    nucleus (singleton components are cells that were never merged). The
    figure is closed after showing it.

    Parameters:
    -----------
    merge_graph : networkx.Graph
        Graph representing merge operations
    save_path : str, optional
        Path to save the visualization (PNG at 300 dpi)
    """
    fig = plt.figure(figsize=(10, 8))

    # Get connected components (each represents one final cell)
    components = list(nx.connected_components(merge_graph))

    # Assign a different colour index to each component
    color_map = {
        node: index
        for index, component in enumerate(components)
        for node in component
    }
    node_colors = [color_map.get(node, len(components)) for node in merge_graph.nodes()]

    # Draw the graph
    pos = nx.spring_layout(merge_graph, seed=42)
    nx.draw_networkx(
        merge_graph, pos,
        node_color=node_colors,
        cmap=plt.cm.tab20,
        node_size=200,
        with_labels=True
    )

    # Only components with 2+ nodes are cells that actually merged nuclei
    multi_nucleus_cells = sum(1 for component in components if len(component) > 1)
    plt.title(f'Nuclei Merge Graph ({multi_nucleus_cells} cells with multiple nuclei)')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    plt.close(fig)

def segment_cells(nuclei_path, membrane_path, golgi_path=None, proximity_threshold=3,
                 corridor_width=15, barrier_threshold=0.5, output_dir=None,
                 visualization_dir=None, filename_prefix=None):
    """
    Run the full nuclei-merging pipeline on one image set.

    Steps: load masks, find nuclei whose dilated masks touch, merge those not
    separated by membrane (or linked by Golgi), relabel consecutively, then
    save and/or display the results.

    Parameters:
    -----------
    nuclei_path : str
        Path to nuclei mask image
    membrane_path : str
        Path to membrane mask image
    golgi_path : str, optional
        Path to Golgi mask image
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists
    output_dir : str, optional
        Directory for the uint16 TIF cell mask
    visualization_dir : str, optional
        Directory for PNG visualisations; when omitted, figures are only shown
    filename_prefix : str, optional
        Prefix for output filenames (also used as a visualisation subfolder)

    Returns:
    --------
    final_mask : ndarray
        Final segmented cell mask with consecutive labels
    """

    def _named(stem):
        # Output filenames carry the image prefix when one is provided
        return f"{filename_prefix}_{stem}" if filename_prefix else stem

    print(f"Processing {filename_prefix}...")
    print("Loading images...")
    nuclei_mask, membrane_mask, golgi_mask = load_images(nuclei_path, membrane_path, golgi_path)
    has_golgi = golgi_mask is not None

    nucleus_count = len(np.unique(nuclei_mask)) - 1
    print(f"Nuclei mask shape: {nuclei_mask.shape}, unique IDs: {nucleus_count}")
    print(f"Membrane mask shape: {membrane_mask.shape}, values: {np.unique(membrane_mask)}")
    print(
        f"Golgi mask shape: {golgi_mask.shape}, unique IDs: {len(np.unique(golgi_mask)) - 1}"
        if has_golgi else "No Golgi mask provided."
    )

    print("Analyzing nuclei properties...")
    nuclei_props = get_nuclei_properties(nuclei_mask)

    print(f"Dilating nuclei with radius {proximity_threshold}...")
    dilated_masks = dilate_nuclei_mask(nuclei_mask, proximity_threshold)

    print("Finding overlapping nuclei...")
    overlapping_pairs = find_overlapping_nuclei(dilated_masks, nuclei_props)
    print(f"Found {len(overlapping_pairs)} potentially overlapping nuclei pairs")

    print("Merging nuclei based on membrane barriers and Golgi linkage...")
    merged_mask, merge_graph = merge_nuclei(
        nuclei_mask, overlapping_pairs, membrane_mask, golgi_mask,
        corridor_width, barrier_threshold
    )

    print("Relabeling mask to have consecutive IDs...")
    final_mask = relabel_mask(merged_mask)

    print(f"Original nuclei count: {nucleus_count}")
    print(f"Final cell count: {len(np.unique(final_mask)) - 1}")

    if output_dir:
        mask_path = os.path.join(output_dir, _named("segmented_cells.tif"))
        io.imsave(mask_path, final_mask.astype(np.uint16))
        print(f"Saved TIF file to: {mask_path}")

    def _show_results(save_path=None):
        # Use the 4-panel layout only when a Golgi mask is available
        if has_golgi:
            visualize_results_with_golgi(
                nuclei_mask, membrane_mask, golgi_mask, final_mask,
                save_path=save_path
            )
        else:
            visualize_results(
                nuclei_mask, membrane_mask, final_mask,
                save_path=save_path
            )

    if not visualization_dir:
        _show_results()
        visualize_merge_graph(merge_graph)
        return final_mask

    image_vis_dir = visualization_dir
    if filename_prefix:
        image_vis_dir = os.path.join(visualization_dir, filename_prefix)
        os.makedirs(image_vis_dir, exist_ok=True)

    results_path = os.path.join(image_vis_dir, _named("segmentation_results.png"))
    _show_results(save_path=results_path)
    print(f"Saved visualization to: {results_path}")

    graph_path = os.path.join(image_vis_dir, _named("merge_graph.png"))
    visualize_merge_graph(merge_graph, save_path=graph_path)
    print(f"Saved merge graph to: {graph_path}")

    return final_mask

def _prefix_before_keyword(filename, keyword):
    """Return the '_'-joined tokens preceding `keyword` in the file stem, or None."""
    # Split the stem (no extension) so that 'x_Nuclei.tif' yields the token 'Nuclei'
    parts = os.path.splitext(filename)[0].split('_')
    if keyword not in parts:
        return None
    return "_".join(parts[:parts.index(keyword)])


def _index_by_prefix(filenames, keyword):
    """Map each prefix to the first filename (in the given order) carrying `keyword`."""
    index = {}
    for name in filenames:
        prefix = _prefix_before_keyword(name, keyword)
        if prefix is not None:
            index.setdefault(prefix, name)
    return index


def _first_with_seq(filenames, seq_token):
    """Return the first filename containing `seq_token` not followed by another digit."""
    pattern = re.compile(re.escape(seq_token) + r'(?!\d)')
    return next((name for name in filenames if pattern.search(name)), None)


def find_file_triplets(nuclei_dir, membrane_dir, golgi_dir=None, *,
                       nuclei_keyword="Nuclei", membrane_keyword="Cadherins",
                       golgi_keyword="Golgi"):
    """
    Find matching triplets of nuclei, membrane, and optionally Golgi files.

    Only '.tif' files are considered; each directory listing is sorted so that
    results are deterministic. Three strategies are tried in order:

    1. Prefix matching: the prefix is everything before the keyword token
       (e.g. 'img01' in 'img01_Nuclei_mask.tif' or 'img01_Nuclei.tif').
    2. Sequence matching (only if 1 found nothing): files sharing the same
       'seq<digits>' token, where the token must not be followed by a digit.
    3. Sequential pairing (only if 1 and 2 found nothing): i-th nuclei file
       with i-th membrane (and Golgi) file in sorted order.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files
    golgi_dir : str, optional
        Directory containing Golgi mask files
    nuclei_keyword, membrane_keyword, golgi_keyword : str
        Filename tokens identifying each file type

    Returns:
    --------
    file_triplets : list
        List of tuples (nuclei_path, membrane_path, golgi_path, common_prefix)
        If no Golgi file is found, golgi_path will be None
    """
    def _list_tifs(directory):
        return sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    nuclei_files = _list_tifs(nuclei_dir)
    membrane_files = _list_tifs(membrane_dir)
    golgi_files = [] if golgi_dir is None else _list_tifs(golgi_dir)

    print(f"Found {len(nuclei_files)} .tif files in nuclei directory")
    print(f"Found {len(membrane_files)} .tif files in membrane directory")
    if golgi_dir:
        print(f"Found {len(golgi_files)} .tif files in Golgi directory")

    def _golgi_path(golgi_file):
        return None if golgi_file is None else os.path.join(golgi_dir, golgi_file)

    file_triplets = []

    # Strategy 1: shared prefix before the keyword token
    membrane_by_prefix = _index_by_prefix(membrane_files, membrane_keyword)
    golgi_by_prefix = _index_by_prefix(golgi_files, golgi_keyword) if golgi_dir else {}
    for nuclei_file in nuclei_files:
        prefix = _prefix_before_keyword(nuclei_file, nuclei_keyword)
        if prefix is None:
            continue
        membrane_file = membrane_by_prefix.get(prefix)
        if membrane_file is None:
            continue
        file_triplets.append((
            os.path.join(nuclei_dir, nuclei_file),
            os.path.join(membrane_dir, membrane_file),
            _golgi_path(golgi_by_prefix.get(prefix)),
            prefix,
        ))

    # Strategy 2: shared 'seqNNN' token
    if not file_triplets:
        print("Trying to match based on sequence numbers...")

        for nuclei_file in nuclei_files:
            seq_match = re.search(r'seq(\d+)', nuclei_file)
            if not seq_match:
                continue

            seq_num = seq_match.group(0)  # e.g., "seq018"
            membrane_file = _first_with_seq(membrane_files, seq_num)
            if membrane_file is None:
                continue
            golgi_file = _first_with_seq(golgi_files, seq_num) if golgi_dir else None

            file_triplets.append((
                os.path.join(nuclei_dir, nuclei_file),
                os.path.join(membrane_dir, membrane_file),
                _golgi_path(golgi_file),
                f"sequence_{seq_num}",
            ))

    print(f"Found {len(file_triplets)} matching file triplets")
    print(f"Of these, {sum(1 for _, _, g, _ in file_triplets if g is not None)} include Golgi files")

    # Strategy 3 (last resort): pair files by position in sorted order
    if not file_triplets and nuclei_files and membrane_files:
        print("No automatic matches found. Using manual pairing...")

        marker = f"_{nuclei_keyword}"
        for i, (nuclei_file, membrane_file) in enumerate(zip(nuclei_files, membrane_files)):
            golgi_path = None
            if golgi_dir and i < len(golgi_files):
                golgi_path = os.path.join(golgi_dir, golgi_files[i])

            # Derive a readable prefix from the nuclei filename
            nuclei_name = os.path.splitext(nuclei_file)[0]
            prefix = nuclei_name.split(marker)[0] if marker in nuclei_name else nuclei_name

            file_triplets.append((
                os.path.join(nuclei_dir, nuclei_file),
                os.path.join(membrane_dir, membrane_file),
                golgi_path,
                prefix,
            ))

        print(f"Created {len(file_triplets)} pairs by sequential matching")

    return file_triplets

def process_all_triplets(nuclei_dir, membrane_dir, golgi_dir, output_dir, visualization_dir,
                       proximity_threshold=3, corridor_width=15, barrier_threshold=0.5):
    """
    Batch-run ``segment_cells`` over every matched nuclei/membrane/Golgi set.

    Failures on one image set are reported (message plus traceback) and the
    loop continues with the next set.

    Parameters:
    -----------
    nuclei_dir : str
        Directory containing nuclei mask files
    membrane_dir : str
        Directory containing membrane mask files
    golgi_dir : str
        Directory containing Golgi mask files
    output_dir : str
        Directory to save TIF output files (created if missing)
    visualization_dir : str
        Directory to save visualization PNG files (created if missing)
    proximity_threshold : int
        Dilation radius for proximity detection
    corridor_width : int
        Width of the corridor for barrier detection
    barrier_threshold : float
        Threshold for determining if a barrier exists
    """
    for directory in (output_dir, visualization_dir):
        os.makedirs(directory, exist_ok=True)

    # Echo the configured locations so path mistakes are easy to spot
    for label, path in (
        ("Nuclei", nuclei_dir),
        ("Membrane", membrane_dir),
        ("Golgi", golgi_dir),
        ("Output", output_dir),
        ("Visualization", visualization_dir),
    ):
        print(f"{label} directory: {path}")

    file_triplets = find_file_triplets(nuclei_dir, membrane_dir, golgi_dir)

    if not file_triplets:
        print("No file triplets found. Please check the file naming patterns and directories.")

        # A membrane folder nested under the nuclei folder hints at a layout problem
        if membrane_dir.startswith(nuclei_dir):
            print("WARNING: Membrane directory appears to be a subdirectory of the nuclei directory.")
            print("This might cause issues with file matching.")
            suggestion = os.path.join(os.path.dirname(nuclei_dir), "Cadherins")
            print(f"Consider using a parallel directory structure, e.g.: {suggestion}")
        return

    for nuclei_path, membrane_path, golgi_path, prefix in tqdm(file_triplets):
        try:
            segment_cells(
                nuclei_path, membrane_path, golgi_path,
                proximity_threshold=proximity_threshold,
                corridor_width=corridor_width,
                barrier_threshold=barrier_threshold,
                output_dir=output_dir,
                visualization_dir=visualization_dir,
                filename_prefix=prefix
            )
            print(f"Completed processing {prefix}")
            print(
                "  - Used Golgi mask for additional merging criteria"
                if golgi_path else
                "  - No Golgi mask available for this image"
            )
        except Exception as e:
            # Batch boundary: report and keep going with the next image set
            print(f"Error processing {prefix}: {e}")
            import traceback
            traceback.print_exc()

# Report missing input folders before starting the batch
print("Checking directory structure...")
for _dir_label, _dir_path in (("Nuclei", nuclei_dir), ("Membrane", membrane_dir)):
    if not os.path.exists(_dir_path):
        print(f"WARNING: {_dir_label} directory does not exist: {_dir_path}")

# A missing Golgi folder is created (empty) so that listing it does not fail
if not os.path.exists(golgi_dir):
    print(f"WARNING: Golgi directory does not exist: {golgi_dir}")
    print("Creating Golgi directory...")
    os.makedirs(golgi_dir, exist_ok=True)

# Main execution - segment every matched image set with Golgi integration
process_all_triplets(
    nuclei_dir, membrane_dir, golgi_dir, output_dir, visualization_dir,
    proximity_threshold=14,
    corridor_width=15,
    barrier_threshold=0.6,
)

Output hidden; open in https://colab.research.google.com to view.