In [None]:
# Colab only: mount Google Drive so the /content/drive/... input and output
# paths used by the analysis cells below are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation, feature
from scipy import ndimage
from collections import Counter
import tifffile
from pathlib import Path
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

# Default figure size / resolution, then the ggplot look, for every plot below
plt.rcParams.update({
    'figure.figsize': (12, 10),
    'figure.dpi': 100,
})
plt.style.use('ggplot')

# Where segmented masks are read from, and where analysis results are written
base_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/Static-A-2"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/Static-A-2"

# exist_ok=True keeps re-running this cell harmless
os.makedirs(output_dir, exist_ok=True)

# Define function to extract sample ID from filenames
def extract_sample_id(filename):
    """Return the sample-ID prefix of a segmentation mask filename.

    The name is split on underscores. Every token before the first component
    marker ('Nuclei', 'Golgi', 'membrane' or 'cell') is re-joined with
    underscores, e.g.
    'denoised_0Pa_A1_20dec21_40x_L2RA_FlatA_seq007_contrast_Nuclei_mask.tif'
    -> 'denoised_0Pa_A1_20dec21_40x_L2RA_FlatA_seq007_contrast'.

    Markers must match a whole token exactly (case-sensitive), so a token such
    as 'cell.tif' or 'Cell' does not count. Returns None if no token matches.
    """
    markers = {'Nuclei', 'Golgi', 'membrane', 'cell'}
    tokens = str(filename).split('_')
    first_hit = next((idx for idx, tok in enumerate(tokens) if tok in markers), None)
    if first_hit is None:
        return None
    return '_'.join(tokens[:first_hit])

# Function to find matching files across different folders
def find_matching_files(base_dir):
    """
    Index segmentation masks by sample ID across the component folders.

    Scans base_dir/<folder> for each of 'Nuclei', 'Membrane_Adjusted', 'Golgi'
    and 'Cell'. Files are grouped by the sample ID that extract_sample_id
    derives from each filename. Each file is stored under its lower-cased
    folder name, except that 'Membrane_Adjusted' is stored as 'membrane'.

    Parameters:
    base_dir: Root directory containing the component sub-folders

    Returns:
    (file_dict, complete_samples):
      - file_dict maps sample_id -> {component_key: full_path}.
      - complete_samples is the sorted list of sample IDs that have at least
        a 'nuclei' and a 'cell' file.
    """
    # Folder name on disk -> key used in file_dict
    component_folders = {
        'Nuclei': 'nuclei',
        'Membrane_Adjusted': 'membrane',
        'Golgi': 'golgi',
        'Cell': 'cell',
    }
    tiff_suffixes = ('.tif', '.tiff')

    # sample_id -> {component_key: full_path}
    file_dict = {}

    for folder, key in component_folders.items():
        component_dir = os.path.join(base_dir, folder)
        if not os.path.isdir(component_dir):
            print(f"Warning: {component_dir} does not exist or is not a directory.")
            continue

        # Sorted so results do not depend on os.listdir's arbitrary order.
        # Both .tif and .tiff are accepted, in any letter case.
        files = sorted(f for f in os.listdir(component_dir)
                       if f.lower().endswith(tiff_suffixes))
        print(f"Found {len(files)} TIFF files in {folder} folder.")

        for file in files:
            sample_id = extract_sample_id(file)
            if not sample_id:
                continue
            components = file_dict.setdefault(sample_id, {})
            if key in components:
                # Two files resolve to the same sample/component. The later
                # one (in sorted order) wins, but make the collision visible.
                print(f"Warning: multiple {folder} files for sample {sample_id}; "
                      f"using {file} instead of {os.path.basename(components[key])}")
            components[key] = os.path.join(component_dir, file)

    # Only nuclei and cell masks are required for the analysis; membrane and
    # Golgi masks are optional.
    complete_samples = sorted(
        sample_id for sample_id, components in file_dict.items()
        if 'nuclei' in components and 'cell' in components
    )

    print(f"Found {len(complete_samples)} complete samples with at least nuclei and cell data.")

    return file_dict, complete_samples

# Function to load images for a specific sample
def load_sample_images(sample_id, file_dict):
    """
    Load every component mask registered for one sample.

    Parameters:
    sample_id: Sample key in file_dict
    file_dict: Mapping sample_id -> {component: filepath}, as returned by
        find_matching_files

    Returns:
    Dict mapping component name -> image array. Components whose file is
    missing or cannot be read/processed are reported and left out.

    'cell' and 'nuclei' masks that look binary are converted to label images
    with scipy.ndimage.label, so each connected object gets its own ID. A mask
    looks binary if all values are <= 1, or if it has a single distinct
    non-zero value such as 0/255. A label image that contains only one object
    is therefore also relabelled by connectivity.
    """
    images = {}
    for component, filepath in file_dict[sample_id].items():
        if not os.path.exists(filepath):
            print(f"Warning: File not found - {filepath}")
            continue

        try:
            img = io.imread(filepath)

            # Boolean masks -> uint8 so later integer casts and labelling work
            if img.dtype == bool:
                img = img.astype(np.uint8)

            if component in ('cell', 'nuclei'):
                # Checking only max <= 1 misses masks saved as 0/255. Those
                # would otherwise be treated as one object labelled 255.
                is_binary = np.max(img) <= 1 or np.unique(img[img != 0]).size <= 1
                if is_binary:
                    print(f"Converting binary {component} image to labeled image")
                    img, num_labels = ndimage.label(img)
                    print(f"Found {num_labels} {component} regions")

            images[component] = img
            print(f"Loaded {component} image: shape {img.shape}, dtype: {img.dtype}, value range [{np.min(img)}, {np.max(img)}]")
        except Exception as e:
            # Best effort: report the failure and keep loading the other components
            print(f"Error loading {filepath}: {str(e)}")

    return images

# Function to detect senescent cells
def detect_senescent_cells(cell_image, nucleus_image, expected_senescent_fraction=0.3,
                     size_threshold_factor=1.5, holes_ratio_quantile=0.7,
                     min_cell_area=100):
    """
    Detect senescent cells based on multiple features and adjust masks accordingly.

    Every cell label with area >= min_cell_area gets a senescence score of
    0-3, one point per indicator:
      - area above mean + size_threshold_factor * std of the analysed cells
      - more than one nucleus label overlapping the cell
      - holes-to-cell area ratio above the holes_ratio_quantile quantile
    A cell is classified as senescent when its score reaches the
    (1 - expected_senescent_fraction) quantile of all scores AND is at least 1.
    Holes inside senescent cells are then filled in the adjusted mask.

    Parameters:
    cell_image: Label image of cells (0 = background)
    nucleus_image: Label image of nuclei (0 = background), same shape as cell_image
    expected_senescent_fraction: Expected fraction of senescent cells (default 0.3)
    size_threshold_factor: Factor multiplied by std dev to set size threshold (default 1.5)
    holes_ratio_quantile: Quantile threshold for holes-to-cell ratio (default 0.7)
    min_cell_area: Cells with fewer pixels are skipped as artifacts (default 100)

    Returns:
    Dictionary with 'cell_metrics' (DataFrame, one row per analysed cell),
    'original_cell_image', 'adjusted_cell_image', 'senescent_count',
    'total_cells' and 'senescent_fraction'.

    Raises:
    ValueError: if cell_image and nucleus_image differ in shape
    """
    if np.shape(cell_image) != np.shape(nucleus_image):
        raise ValueError(f"cell_image shape {np.shape(cell_image)} does not match "
                         f"nucleus_image shape {np.shape(nucleus_image)}")

    # Ensure we have integer type images
    cell_image = cell_image.astype(np.int32)
    nucleus_image = nucleus_image.astype(np.int32)

    # Props of cells that pass the size filter, reused later for hole filling
    analysed_props = {}
    cell_metrics = []

    for cell_prop in measure.regionprops(cell_image):
        area = cell_prop.area
        if area < min_cell_area:
            continue

        cell_id = cell_prop.label
        analysed_props[cell_id] = cell_prop

        # Work inside the cell's bounding box. cell_prop.image is the cell's
        # boolean mask within cell_prop.slice, so no full-image scan per cell.
        # Hole detection is unchanged: a background region touching the box
        # edge is also connected to the cell's outside in the full image.
        cell_mask = cell_prop.image
        nuclei_under_cell = nucleus_image[cell_prop.slice][cell_mask]

        # Distinct nucleus labels overlapping the cell (0 is background)
        nuclei_count = int(np.count_nonzero(np.unique(nuclei_under_cell)))
        # Cell pixels covered by any nucleus
        nuclear_area = int(np.count_nonzero(nuclei_under_cell))
        nuclear_cytoplasmic_ratio = nuclear_area / area if area > 0 else 0

        # Holes: pixels enclosed by the cell but not part of it. This includes
        # pixels of any other cell lying inside the enclosure.
        filled_mask = ndimage.binary_fill_holes(cell_mask)
        holes_mask = filled_mask & ~cell_mask
        _, num_holes = ndimage.label(holes_mask)
        total_hole_area = int(np.count_nonzero(holes_mask))
        holes_to_cell_ratio = total_hole_area / area if area > 0 else 0

        # Shape metrics
        perimeter = cell_prop.perimeter if cell_prop.perimeter else 0
        circularity = (4 * np.pi * area) / (perimeter * perimeter) if perimeter > 0 else 0

        cell_metrics.append({
            'cell_id': cell_id,
            'area': area,
            'perimeter': perimeter,
            'circularity': circularity,
            'solidity': cell_prop.solidity,
            'nuclei_count': nuclei_count,
            'nuclear_area': nuclear_area,
            'nuclear_cytoplasmic_ratio': nuclear_cytoplasmic_ratio,
            'num_holes': num_holes,
            'total_hole_area': total_hole_area,
            'holes_to_cell_ratio': holes_to_cell_ratio
        })

    if not cell_metrics:
        print("No valid cells found for analysis")
        return {
            'cell_metrics': pd.DataFrame(),
            'original_cell_image': cell_image,
            'adjusted_cell_image': cell_image.copy(),
            'senescent_count': 0,
            'total_cells': 0,
            'senescent_fraction': 0
        }

    metrics_df = pd.DataFrame(cell_metrics)

    # Adaptive thresholds computed from this image's cells
    size_mean = metrics_df['area'].mean()
    size_std = metrics_df['area'].std()  # NaN for a single cell -> size test never fires
    size_threshold = size_mean + size_threshold_factor * size_std
    multinucleated = metrics_df['nuclei_count'] > 1
    holes_ratio_threshold = metrics_df['holes_to_cell_ratio'].quantile(holes_ratio_quantile)

    print(f"Size threshold: {size_threshold:.1f} pixels (mean {size_mean:.1f} + {size_threshold_factor} * std {size_std:.1f})")
    print(f"Holes ratio threshold: {holes_ratio_threshold:.4f} (quantile {holes_ratio_quantile})")

    # One point per senescence indicator
    metrics_df['senescence_score'] = (
        (metrics_df['area'] > size_threshold).astype(int)
        + multinucleated.astype(int)
        + (metrics_df['holes_to_cell_ratio'] > holes_ratio_threshold).astype(int)
    )

    # Pick the score cut-off that keeps roughly the expected fraction. The
    # quantile of a constant column is that constant, so no special case is
    # needed. The floor of 1 is essential: without it a cut-off of 0 (all
    # scores 0, or a majority of zero scores) would classify every cell,
    # including cells with no indicator at all, as senescent.
    quantile_level = float(np.clip(1 - expected_senescent_fraction, 0, 1))
    score_threshold = max(np.quantile(metrics_df['senescence_score'], quantile_level), 1)
    metrics_df['is_senescent'] = metrics_df['senescence_score'] >= score_threshold

    # Fill holes in senescent cells. Only background pixels are claimed, so
    # other cells lying inside a hole are not overwritten. Innermost cells go
    # first (smallest filled area = area + hole area), so a senescent cell
    # nested in another's hole keeps its own holes.
    adjusted_cell_image = cell_image.copy()
    senescent_rows = metrics_df[metrics_df['is_senescent']]
    fill_order = (senescent_rows['area'] + senescent_rows['total_hole_area']).sort_values(kind='stable').index
    for cell_id in senescent_rows.loc[fill_order, 'cell_id']:
        cell_prop = analysed_props[int(cell_id)]
        region = adjusted_cell_image[cell_prop.slice]  # view into adjusted_cell_image
        filled_mask = ndimage.binary_fill_holes(cell_prop.image)
        region[filled_mask & (region == 0)] = cell_id

    senescent_count = int(metrics_df['is_senescent'].sum())
    total_cells = len(metrics_df)
    senescent_fraction = senescent_count / total_cells if total_cells > 0 else 0

    print(f"Detected {senescent_count} senescent cells out of {total_cells} total cells ({senescent_fraction:.2%})")

    return {
        'cell_metrics': metrics_df,
        'original_cell_image': cell_image,
        'adjusted_cell_image': adjusted_cell_image,
        'senescent_count': senescent_count,
        'total_cells': total_cells,
        'senescent_fraction': senescent_fraction
    }

# Function to visualize senescence detection results
def visualize_senescence_detection(results, cell_image, nucleus_image, output_path=None):
    """
    Visualize the senescence detection results as a 2x2 figure.

    Panels:
      - original cell (green) and nucleus (blue) boundaries
      - senescent (red) vs normal (green) cells
      - original (green) vs adjusted (yellow) cell boundaries
      - box plot of cell area per class
    Per-class summaries are written below the panels.

    Parameters:
    results: Results from detect_senescent_cells
    cell_image: Original cell label image
    nucleus_image: Original nucleus label image
    output_path: If given, the figure is also saved there at 300 dpi
    """
    if results['total_cells'] == 0:
        print("No valid cells to visualize")
        return

    fig, axes = plt.subplots(2, 2, figsize=(16, 14))
    cell_metrics = results['cell_metrics']

    # 1. Original boundaries. Nuclei are drawn last, so they win where both overlap.
    cell_boundary = segmentation.find_boundaries(cell_image)
    nucleus_boundary = segmentation.find_boundaries(nucleus_image)

    overlay = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    overlay[cell_boundary] = (0, 255, 0)
    overlay[nucleus_boundary] = (0, 0, 255)

    axes[0, 0].imshow(overlay)
    axes[0, 0].set_title("Original Cells (green) and Nuclei (blue)")
    axes[0, 0].axis('off')

    # 2. Senescent vs normal cells. np.isin builds each mask in a single pass
    # instead of one full-image comparison per cell.
    is_senescent = cell_metrics['is_senescent'].astype(bool)
    senescent_mask = np.isin(cell_image, cell_metrics.loc[is_senescent, 'cell_id'].to_numpy())
    normal_mask = np.isin(cell_image, cell_metrics.loc[~is_senescent, 'cell_id'].to_numpy())

    classification = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    classification[senescent_mask] = (255, 0, 0)
    classification[normal_mask] = (0, 255, 0)

    axes[0, 1].imshow(classification)
    axes[0, 1].set_title(f"Senescent Cells (red): {results['senescent_fraction']:.1%}")
    axes[0, 1].axis('off')

    # 3. Adjusted (hole-filled) boundaries drawn over the original ones
    adjusted_boundary = segmentation.find_boundaries(results['adjusted_cell_image'])

    adjusted_overlay = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    adjusted_overlay[cell_boundary] = (0, 255, 0)
    adjusted_overlay[adjusted_boundary] = (255, 255, 0)

    axes[1, 0].imshow(adjusted_overlay)
    axes[1, 0].set_title("Original (green) vs Adjusted (yellow) Cell Boundaries")
    axes[1, 0].axis('off')

    # 4. Distribution of one key metric per class
    senescent_metrics = cell_metrics[is_senescent]
    normal_metrics = cell_metrics[~is_senescent]
    metric = 'area'

    if len(senescent_metrics) > 0 and len(normal_metrics) > 0:
        axes[1, 1].boxplot([normal_metrics[metric], senescent_metrics[metric]])
        # Tick labels are set separately because boxplot's `labels=` keyword
        # is deprecated (renamed `tick_labels`) in recent Matplotlib.
        axes[1, 1].set_xticklabels(['Normal', 'Senescent'])
        axes[1, 1].set_ylabel(f'Cell {metric}')
        axes[1, 1].set_title(f'Distribution of {metric} by Cell Type')
    else:
        axes[1, 1].text(0.5, 0.5, "Not enough data for box plot",
                        ha='center', va='center', transform=axes[1, 1].transAxes)

    # Summary text under the panels
    if len(senescent_metrics) > 0:
        sen_text = (f"Senescent cells (n={len(senescent_metrics)}):\n" +
                   f"Mean area: {senescent_metrics['area'].mean():.1f}\n" +
                   f"Mean nuclei count: {senescent_metrics['nuclei_count'].mean():.1f}")
    else:
        sen_text = "No senescent cells detected"

    if len(normal_metrics) > 0:
        norm_text = (f"Normal cells (n={len(normal_metrics)}):\n" +
                    f"Mean area: {normal_metrics['area'].mean():.1f}\n" +
                    f"Mean nuclei count: {normal_metrics['nuclei_count'].mean():.1f}")
    else:
        norm_text = "No normal cells detected"

    fig.text(0.02, 0.02, sen_text, fontsize=10)
    fig.text(0.52, 0.02, norm_text, fontsize=10)

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {output_path}")

    plt.show()
    # Release the figure. This function runs once per sample, and open
    # figures otherwise accumulate on non-inline backends.
    plt.close(fig)

# Function to create additional visualizations for multi-nucleated cells
def visualize_multinucleated_cells(results, cell_image, nucleus_image, output_path=None):
    """
    Create visualization specifically highlighting multinucleated cells.

    The left panel colours each analysed cell by nuclei count: gray 0, blue 1,
    green 2, red 3+. Nucleus boundaries are drawn in white. The right panel is
    a bar chart of how many cells fall in each count category.

    Parameters:
    results: Results from detect_senescent_cells
    cell_image: Original cell label image
    nucleus_image: Original nucleus label image
    output_path: If given, the figure is also saved there at 300 dpi
    """
    if results['total_cells'] == 0:
        print("No valid cells to visualize")
        return

    cell_metrics = results['cell_metrics']
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # Palette rows are indexed by capped nuclei count (3 means "3 or more").
    # The extra last row is black, for pixels that are not an analysed cell.
    palette = np.array([
        [100, 100, 100],  # 0 nuclei: gray
        [0, 0, 255],      # 1 nucleus: blue
        [0, 255, 0],      # 2 nuclei: green
        [255, 0, 0],      # 3+ nuclei: red
        [0, 0, 0],        # background / cells not in cell_metrics
    ], dtype=np.uint8)
    background_row = len(palette) - 1

    cell_ids = cell_metrics['cell_id'].to_numpy(dtype=np.int64)
    capped_counts = np.minimum(cell_metrics['nuclei_count'].to_numpy(dtype=np.int64), 3)

    # Lookup table label -> palette row, so the whole image is coloured in one
    # indexing step instead of one full-image scan per cell.
    labels = np.asarray(cell_image).astype(np.intp)
    lut = np.full(int(labels.max()) + 1, background_row, dtype=np.intp)
    lut[cell_ids] = capped_counts
    multi_colored = palette[lut[labels]]

    # White nucleus boundaries on top
    nucleus_boundary = segmentation.find_boundaries(nucleus_image)
    multi_colored[nucleus_boundary] = [255, 255, 255]

    axes[0].imshow(multi_colored)
    axes[0].set_title("Cell Nuclei Count\nGray: 0, Blue: 1, Green: 2, Red: 3+")
    axes[0].axis('off')

    # Cells per category 0, 1, 2, 3+ (always four bars, possibly zero-height)
    counts = np.bincount(capped_counts, minlength=4)
    bars = axes[1].bar(['0', '1', '2', '3+'], counts, color=['gray', 'blue', 'green', 'red'])

    for bar in bars:
        height = bar.get_height()
        axes[1].text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{int(height)}', ha='center', va='bottom')

    axes[1].set_xlabel('Number of Nuclei')
    axes[1].set_ylabel('Number of Cells')
    axes[1].set_title('Distribution of Nuclei Count')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Multinucleated visualization saved to {output_path}")

    plt.show()
    # Release the figure; this runs once per sample
    plt.close(fig)

# Function to process a sample for senescence analysis
def process_senescence_analysis(sample_id, file_dict, output_dir):
    """
    Process a sample for senescence analysis.

    Loads the sample's masks, runs detect_senescent_cells and writes the
    following into output_dir/<sample_id>/:
      - per-cell metrics CSV
      - adjusted (hole-filled) cell label mask TIFF
      - two visualization PNGs
      - one-row summary CSV

    Parameters:
    sample_id: ID of the sample to process
    file_dict: Dictionary mapping sample IDs to component file paths
    output_dir: Directory to save outputs

    Returns:
    Results dict from detect_senescent_cells, or None if the cell or nuclei
    mask could not be loaded. When no valid cells are found, the results are
    returned without writing any files.
    """
    print(f"Processing senescence analysis for sample {sample_id}")

    sample_output_dir = os.path.join(output_dir, sample_id)
    os.makedirs(sample_output_dir, exist_ok=True)

    images = load_sample_images(sample_id, file_dict)

    if 'cell' not in images or 'nuclei' not in images:
        print(f"Error: Required cell or nuclei image not found for sample {sample_id}")
        return None

    results = detect_senescent_cells(
        images['cell'],
        images['nuclei'],
        expected_senescent_fraction=0.3  # From 30% TNF-a treated cells
    )

    if results['total_cells'] == 0:
        print(f"No valid cells found in sample {sample_id}, skipping output generation")
        return results

    # 1. Per-cell metrics
    metrics_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_metrics.csv")
    results['cell_metrics'].to_csv(metrics_file, index=False)
    print(f"Saved cell metrics to {metrics_file}")

    # 2. Adjusted cell mask. uint16 keeps files small, but labels above 65535
    # would silently wrap around, so fall back to uint32 when needed.
    # check_contrast=False silences the "low contrast" warning, which is
    # meaningless for label images.
    adjusted = results['adjusted_cell_image']
    mask_dtype = np.uint16 if adjusted.max() <= np.iinfo(np.uint16).max else np.uint32
    adjusted_mask_file = os.path.join(sample_output_dir, f"{sample_id}_adjusted_cell_mask.tif")
    io.imsave(adjusted_mask_file, adjusted.astype(mask_dtype), check_contrast=False)
    print(f"Saved adjusted cell mask to {adjusted_mask_file}")

    # 3. Senescence overview figure
    vis_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_visualization.png")
    visualize_senescence_detection(results, images['cell'], images['nuclei'], vis_file)

    # 4. Nuclei-count figure
    multi_vis_file = os.path.join(sample_output_dir, f"{sample_id}_multinucleated_visualization.png")
    visualize_multinucleated_cells(results, images['cell'], images['nuclei'], multi_vis_file)

    # 5. One-row summary
    summary = {
        'sample_id': sample_id,
        'total_cells': results['total_cells'],
        'senescent_cells': results['senescent_count'],
        'senescent_fraction': results['senescent_fraction'],
        'normal_cells': results['total_cells'] - results['senescent_count'],
        'normal_fraction': 1 - results['senescent_fraction']
    }

    metrics_df = results['cell_metrics']
    if not metrics_df.empty:
        senescent = metrics_df[metrics_df['is_senescent']]
        normal = metrics_df[~metrics_df['is_senescent']]

        # Class comparisons only make sense when both classes are present
        if not senescent.empty and not normal.empty:
            summary['senescent_mean_area'] = senescent['area'].mean()
            summary['normal_mean_area'] = normal['area'].mean()
            summary['size_ratio'] = summary['senescent_mean_area'] / summary['normal_mean_area']

            summary['senescent_mean_nuclei'] = senescent['nuclei_count'].mean()
            summary['normal_mean_nuclei'] = normal['nuclei_count'].mean()

            # NOTE: despite the "_pct" suffix these are fractions in [0, 1].
            # The column names are kept for compatibility with existing CSVs.
            summary['senescent_multi_nuclei_pct'] = (senescent['nuclei_count'] > 1).mean()
            summary['normal_multi_nuclei_pct'] = (normal['nuclei_count'] > 1).mean()

    summary_df = pd.DataFrame([summary])
    summary_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_summary.csv")
    summary_df.to_csv(summary_file, index=False)
    print(f"Saved summary to {summary_file}")

    return results

# Function to compile results across all samples
def compile_cross_sample_results(output_dir, sample_ids):
    """Compile results across all analyzed samples.

    For each sample ID, reads the following files when they exist (missing
    files are skipped silently):
      - <output_dir>/<id>/<id>_senescence_summary.csv
      - <output_dir>/<id>/<id>_senescence_metrics.csv

    Writes into output_dir:
      - all_samples_senescence_summary.csv and senescence_percentages.png,
        if any summary was found
      - all_samples_cell_metrics.csv, if any metrics file was found
      - senescence_metrics_comparison.png and area_vs_nuclei_scatter.png,
        if the metrics also contain an 'is_senescent' column
    """
    summary_frames = []
    metric_frames = []

    for sid in sample_ids:
        sample_dir = os.path.join(output_dir, sid)
        summary_path = os.path.join(sample_dir, f"{sid}_senescence_summary.csv")
        metrics_path = os.path.join(sample_dir, f"{sid}_senescence_metrics.csv")

        if os.path.exists(summary_path):
            summary_frames.append(pd.read_csv(summary_path))

        if os.path.exists(metrics_path):
            per_cell = pd.read_csv(metrics_path)
            per_cell['sample_id'] = sid  # remember the sample each row came from
            metric_frames.append(per_cell)

    if summary_frames:
        summary_table = pd.concat(summary_frames, ignore_index=True)
        summary_csv = os.path.join(output_dir, "all_samples_senescence_summary.csv")
        summary_table.to_csv(summary_csv, index=False)
        print(f"Saved combined summary to {summary_csv}")

        # Bar chart: percentage of senescent cells in each sample
        plt.figure(figsize=(12, 6))
        plt.bar(summary_table['sample_id'],
                summary_table['senescent_fraction'] * 100)
        plt.xlabel('Sample ID')
        plt.ylabel('Senescent Cells (%)')
        plt.title('Percentage of Senescent Cells Across Samples')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "senescence_percentages.png"), dpi=300)
        plt.close()

    if metric_frames:
        metrics_table = pd.concat(metric_frames, ignore_index=True)
        metrics_csv = os.path.join(output_dir, "all_samples_cell_metrics.csv")
        metrics_table.to_csv(metrics_csv, index=False)
        print(f"Saved combined metrics to {metrics_csv}")

        if 'is_senescent' in metrics_table.columns:
            flagged = metrics_table['is_senescent']
            senescent = metrics_table[flagged]
            normal = metrics_table[~flagged]
            have_both_groups = len(senescent) > 0 and len(normal) > 0

            # 2x2 grid of box plots, one per metric
            fig, axes = plt.subplots(2, 2, figsize=(14, 12))
            for ax, metric in zip(axes.flatten(), ['area', 'nuclei_count', 'solidity', 'circularity']):
                if metric not in metrics_table.columns or not have_both_groups:
                    continue
                ax.boxplot([normal[metric], senescent[metric]],
                           labels=['Normal', 'Senescent'])
                ax.set_ylabel(metric)
                ax.set_title(f'{metric} Distribution')

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "senescence_metrics_comparison.png"), dpi=300)
            plt.close()

            # Scatter: cell area vs nuclei count, coloured by class
            plt.figure(figsize=(10, 8))
            plt.scatter(normal['area'], normal['nuclei_count'],
                        alpha=0.5, label='Normal', color='blue')
            plt.scatter(senescent['area'], senescent['nuclei_count'],
                        alpha=0.5, label='Senescent', color='red')
            plt.xlabel('Cell Area')
            plt.ylabel('Nuclei Count')
            plt.title('Cell Area vs Nuclei Count')
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "area_vs_nuclei_scatter.png"), dpi=300)
            plt.close()

    print("Compiled results across all samples")

# Function to run all processing
def main():
    """Run the senescence pipeline over every complete sample.

    Uses the module-level base_dir (input masks) and output_dir (results).
    Each sample that has at least nuclei and cell masks is processed, and
    the per-sample outputs are then compiled into dataset-level files.
    """
    print("Starting senescent cell analysis pipeline...")

    file_dict, complete_samples = find_matching_files(base_dir)
    if not complete_samples:
        print("No complete samples found. Exiting.")
        return

    for sample_id in complete_samples:
        print(f"\nProcessing sample {sample_id}")
        outcome = process_senescence_analysis(sample_id, file_dict, output_dir)
        if not outcome:
            # Loading failed for this sample; its error was already reported
            continue
        print(f"Completed senescence analysis for sample {sample_id}")
        print(
            f"Found {outcome['senescent_count']} senescent cells "
            f"({outcome['senescent_fraction']:.1%}) out of {outcome['total_cells']} total cells"
        )

    compile_cross_sample_results(output_dir, complete_samples)

    print("Senescence analysis pipeline completed!")

# Entry point (in a notebook cell __name__ is "__main__", so this runs too)
if __name__ == "__main__":
    main()

Output hidden; open in https://colab.research.google.com to view.

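Second try: revised version of the pipeline above (adds a nucleus-merging post-processing step, `merge_close_nuclei`).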

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation, feature
from scipy import ndimage
from collections import Counter
import tifffile
from pathlib import Path
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

# Default figure size / resolution, then the ggplot look, for every plot below
plt.rcParams.update({
    'figure.figsize': (12, 10),
    'figure.dpi': 100,
})
plt.style.use('ggplot')

# Where segmented masks are read from, and where analysis results are written
base_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/Static-A-2"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/Static-A-2"

# exist_ok=True keeps re-running this cell harmless
os.makedirs(output_dir, exist_ok=True)

# Define function to extract sample ID from filenames
def extract_sample_id(filename):
    """Return the sample-ID prefix of a segmentation mask filename.

    The name is split on underscores. Every token before the first component
    marker ('Nuclei', 'Golgi', 'membrane' or 'cell') is re-joined with
    underscores, e.g.
    'denoised_0Pa_A1_20dec21_40x_L2RA_FlatA_seq007_contrast_Nuclei_mask.tif'
    -> 'denoised_0Pa_A1_20dec21_40x_L2RA_FlatA_seq007_contrast'.

    Markers must match a whole token exactly (case-sensitive), so a token such
    as 'cell.tif' or 'Cell' does not count. Returns None if no token matches.
    """
    markers = {'Nuclei', 'Golgi', 'membrane', 'cell'}
    tokens = str(filename).split('_')
    first_hit = next((idx for idx, tok in enumerate(tokens) if tok in markers), None)
    if first_hit is None:
        return None
    return '_'.join(tokens[:first_hit])

# Function to find matching files across different folders
def find_matching_files(base_dir):
    """
    Index segmentation masks by sample ID across the component folders.

    Scans base_dir/<folder> for each of 'Nuclei', 'Membrane_Adjusted', 'Golgi'
    and 'Cell'. Files are grouped by the sample ID that extract_sample_id
    derives from each filename. Each file is stored under its lower-cased
    folder name, except that 'Membrane_Adjusted' is stored as 'membrane'.

    Parameters:
    base_dir: Root directory containing the component sub-folders

    Returns:
    (file_dict, complete_samples):
      - file_dict maps sample_id -> {component_key: full_path}.
      - complete_samples is the sorted list of sample IDs that have at least
        a 'nuclei' and a 'cell' file.
    """
    # Folder name on disk -> key used in file_dict
    component_folders = {
        'Nuclei': 'nuclei',
        'Membrane_Adjusted': 'membrane',
        'Golgi': 'golgi',
        'Cell': 'cell',
    }
    tiff_suffixes = ('.tif', '.tiff')

    # sample_id -> {component_key: full_path}
    file_dict = {}

    for folder, key in component_folders.items():
        component_dir = os.path.join(base_dir, folder)
        if not os.path.isdir(component_dir):
            print(f"Warning: {component_dir} does not exist or is not a directory.")
            continue

        # Sorted so results do not depend on os.listdir's arbitrary order.
        # Both .tif and .tiff are accepted, in any letter case.
        files = sorted(f for f in os.listdir(component_dir)
                       if f.lower().endswith(tiff_suffixes))
        print(f"Found {len(files)} TIFF files in {folder} folder.")

        for file in files:
            sample_id = extract_sample_id(file)
            if not sample_id:
                continue
            components = file_dict.setdefault(sample_id, {})
            if key in components:
                # Two files resolve to the same sample/component. The later
                # one (in sorted order) wins, but make the collision visible.
                print(f"Warning: multiple {folder} files for sample {sample_id}; "
                      f"using {file} instead of {os.path.basename(components[key])}")
            components[key] = os.path.join(component_dir, file)

    # Only nuclei and cell masks are required for the analysis; membrane and
    # Golgi masks are optional.
    complete_samples = sorted(
        sample_id for sample_id, components in file_dict.items()
        if 'nuclei' in components and 'cell' in components
    )

    print(f"Found {len(complete_samples)} complete samples with at least nuclei and cell data.")

    return file_dict, complete_samples

# Function to load images for a specific sample
def load_sample_images(sample_id, file_dict):
    """
    Load every component mask registered for one sample.

    Parameters:
    sample_id: Sample key in file_dict
    file_dict: Mapping sample_id -> {component: filepath}, as returned by
        find_matching_files

    Returns:
    Dict mapping component name -> image array. Components whose file is
    missing or cannot be read/processed are reported and left out.

    'cell' and 'nuclei' masks that look binary are converted to label images
    with scipy.ndimage.label, so each connected object gets its own ID. A mask
    looks binary if all values are <= 1, or if it has a single distinct
    non-zero value such as 0/255. A label image that contains only one object
    is therefore also relabelled by connectivity.
    """
    images = {}
    for component, filepath in file_dict[sample_id].items():
        if not os.path.exists(filepath):
            print(f"Warning: File not found - {filepath}")
            continue

        try:
            img = io.imread(filepath)

            # Boolean masks -> uint8 so later integer casts and labelling work
            if img.dtype == bool:
                img = img.astype(np.uint8)

            if component in ('cell', 'nuclei'):
                # Checking only max <= 1 misses masks saved as 0/255. Those
                # would otherwise be treated as one object labelled 255.
                is_binary = np.max(img) <= 1 or np.unique(img[img != 0]).size <= 1
                if is_binary:
                    print(f"Converting binary {component} image to labeled image")
                    img, num_labels = ndimage.label(img)
                    print(f"Found {num_labels} {component} regions")

            images[component] = img
            print(f"Loaded {component} image: shape {img.shape}, dtype: {img.dtype}, value range [{np.min(img)}, {np.max(img)}]")
        except Exception as e:
            # Best effort: report the failure and keep loading the other components
            print(f"Error loading {filepath}: {str(e)}")

    return images

# Function to merge nearby or overlapping nuclei (post-processing)
def _relabel_nuclei_sequential(label_image):
    """Renumber the positive labels of ``label_image`` to 1..N; everything <= 0 becomes 0.

    Unlike connected-component labelling, pixels that share a label keep sharing
    one label even when they are not spatially connected (needed for merged nuclei).
    """
    positive_labels = np.unique(label_image)
    positive_labels = positive_labels[positive_labels > 0]
    relabeled = np.zeros(label_image.shape, dtype=np.int32)
    foreground = label_image > 0
    # searchsorted gives each label's rank among the sorted positive labels
    relabeled[foreground] = np.searchsorted(positive_labels, label_image[foreground]) + 1
    return relabeled


def merge_close_nuclei(cell_image, nucleus_image, distance_threshold=10):
    """
    Merge nuclei that are close to each other and might be part of a multi-nucleated cell.

    For each cell, nuclei overlapping it are merged when their centroids are
    strictly closer than ``distance_threshold`` pixels (distance over the first two
    image axes). A nucleus is assigned to the lowest-labelled cell it overlaps.

    Parameters:
    cell_image: Label image of cells (0 = background)
    nucleus_image: Label image of nuclei (0 = background), same shape as cell_image
    distance_threshold: Maximum centroid distance (in pixels) for two nuclei to be merged

    Returns:
    Adjusted nucleus image. If nothing was merged, an unchanged copy of
    nucleus_image; otherwise an int32 label image in which each merged group
    shares a single label and labels are renumbered 1..N.
    """
    merged_nucleus_image = nucleus_image.copy()

    nucleus_ids = np.unique(nucleus_image)
    nucleus_ids = nucleus_ids[nucleus_ids > 0]
    if nucleus_ids.size == 0:
        return merged_nucleus_image

    # Centroid (mean pixel coordinate) of every nucleus, computed once up front
    centroids = ndimage.center_of_mass(np.ones(nucleus_image.shape), nucleus_image, nucleus_ids)
    centroid_by_id = dict(zip(nucleus_ids.tolist(), centroids))

    merged_nuclei = set()     # nuclei already assigned to a merge group
    processed_cells = set()   # cells whose nuclei have already been examined

    for nucleus_id in nucleus_ids.tolist():
        if nucleus_id in merged_nuclei:
            continue

        # Get the cell this nucleus belongs to
        overlapping_cells = np.unique(cell_image[nucleus_image == nucleus_id])
        overlapping_cells = overlapping_cells[overlapping_cells > 0]  # Remove background
        if overlapping_cells.size == 0:
            continue  # Nucleus doesn't overlap with any cell

        cell_id = overlapping_cells[0].item()  # Take the first (lowest-labelled) overlapping cell
        # A second pass over the same cell cannot create new merges: after the first
        # pass no two unmerged nuclei in it are within the threshold, and the set of
        # merged nuclei only grows. Skipping it avoids repeated full-image scans.
        if cell_id in processed_cells:
            continue
        processed_cells.add(cell_id)

        # Find all nuclei in this cell
        nuclei_in_cell = np.unique(nucleus_image[cell_image == cell_id])
        nuclei_in_cell = nuclei_in_cell[nuclei_in_cell > 0].tolist()  # Remove background
        if len(nuclei_in_cell) <= 1:
            continue  # Only one nucleus, nothing to merge

        nuclei_centers = [(nuc_id, centroid_by_id[nuc_id]) for nuc_id in nuclei_in_cell]

        # Check distances between nuclei centers
        for i, (nuc_id1, center1) in enumerate(nuclei_centers):
            if nuc_id1 in merged_nuclei:
                continue

            to_merge = [nuc_id1]
            for j, (nuc_id2, center2) in enumerate(nuclei_centers):
                if i != j and nuc_id2 not in merged_nuclei:
                    distance = np.sqrt((center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2)
                    if distance < distance_threshold:
                        to_merge.append(nuc_id2)

            if len(to_merge) > 1:
                merged_label = to_merge[0]  # Use the first nucleus ID as the group label
                for nuc_id in to_merge:
                    merged_nuclei.add(nuc_id)
                    merged_nucleus_image[nucleus_image == nuc_id] = merged_label

    # Make labels consecutive WITHOUT re-splitting merged groups: connected-component
    # relabelling (ndimage.label) would give non-touching merged nuclei separate labels
    # again and undo the merge.
    if merged_nuclei:
        merged_nucleus_image = _relabel_nuclei_sequential(merged_nucleus_image)

    return merged_nucleus_image

# IMPROVED: Function to detect senescent cells
def detect_senescent_cells(cell_image, nucleus_image, process_all_samples=False,
                          global_hole_stats=None, expected_senescent_fraction=0.3):
    """
    Detect senescent cells based on the presence of holes and other features.

    Every cell with area >= 100 px is measured (nuclei overlap, enclosed holes,
    shape). A subset is then labelled senescent:

    * With ``process_all_samples`` and ``global_hole_stats``: score = +2 if the cell
      has holes, +1 if its area exceeds ``global_hole_stats['size_threshold']``,
      +1 if it contains more than one nucleus. The top
      ``int(n_cells * expected_senescent_fraction)`` cells by score are senescent.
    * Otherwise, if the fraction of cells with holes exceeds the target, the cells
      with the largest total hole area are chosen. If it is below the target, a score
      decides: +3 holes, +1 area above the sample's 70th percentile, +1 multinucleated.
      If it equals the target exactly, every cell with holes is senescent.

    Ties are broken with a stable sort, so the selection does not depend on the
    sorting algorithm.

    Parameters:
    cell_image: Label image of cells (0 = background)
    nucleus_image: Label image of nuclei, same shape as cell_image
    process_all_samples: Flag to indicate if we're processing all samples at once
    global_hole_stats: Statistics from calculate_global_statistics (must contain
        'size_threshold'); only used together with process_all_samples
    expected_senescent_fraction: Expected fraction of senescent cells (default 0.3)

    Returns:
    Dictionary with 'cell_metrics' (DataFrame, one row per analysed cell),
    'original_cell_image', 'adjusted_cell_image' (holes of senescent cells filled
    where they are background), 'senescent_count', 'total_cells' and
    'senescent_fraction'.
    """
    # Ensure we have integer type images
    cell_image = cell_image.astype(np.int32)
    nucleus_image = nucleus_image.astype(np.int32)
    ndim = cell_image.ndim

    def bbox_to_slice(bbox):
        # regionprops bbox is (min_0, ..., min_{n-1}, max_0, ..., max_{n-1})
        return tuple(slice(bbox[k], bbox[k + ndim]) for k in range(ndim))

    def mark_top(df, column, count):
        # Mark the `count` rows with the largest `column` values as senescent.
        # Stable sort so ties at the cut-off are resolved deterministically.
        df = df.sort_values(column, ascending=False, kind='mergesort')
        df['is_senescent'] = False
        df.iloc[:count, df.columns.get_loc('is_senescent')] = True
        return df

    cell_metrics = []
    cell_slices = {}  # cell_id -> bounding-box slice, reused when filling holes

    for cell_prop in measure.regionprops(cell_image):
        cell_id = cell_prop.label

        # Skip very small objects (likely artifacts)
        if cell_prop.area < 100:
            continue

        # Work inside the cell's bounding box. Every pixel outside the box is
        # background for this cell, so hole filling on the crop gives the same holes
        # as on the full image without a full-image pass per cell.
        region = bbox_to_slice(cell_prop.bbox)
        cell_slices[cell_id] = region
        cell_mask = cell_image[region] == cell_id

        # Nucleus labels under this cell's pixels
        nucleus_labels = nucleus_image[region][cell_mask]
        nucleus_labels = nucleus_labels[nucleus_labels > 0]  # Remove background (0)
        nuclei_count = len(np.unique(nucleus_labels))
        nuclear_area = nucleus_labels.size  # cell pixels covered by any nucleus
        nuclear_cytoplasmic_ratio = nuclear_area / cell_prop.area if cell_prop.area > 0 else 0

        # Holes: pixels enclosed by the cell that are not part of it
        holes_mask = ndimage.binary_fill_holes(cell_mask) & ~cell_mask
        labeled_holes, num_holes = ndimage.label(holes_mask)
        hole_sizes = np.bincount(labeled_holes.ravel())[1:]  # pixel count per hole
        total_hole_area = int(hole_sizes.sum())
        holes_to_cell_ratio = total_hole_area / cell_prop.area if cell_prop.area > 0 else 0
        max_hole_size = hole_sizes.max() if num_holes else 0
        avg_hole_size = hole_sizes.mean() if num_holes else 0

        # Extract shape metrics
        perimeter = cell_prop.perimeter if cell_prop.perimeter else 0
        circularity = (4 * np.pi * cell_prop.area) / (perimeter * perimeter) if perimeter > 0 else 0

        cell_metrics.append({
            'cell_id': cell_id,
            'area': cell_prop.area,
            'perimeter': perimeter,
            'circularity': circularity,
            'solidity': cell_prop.solidity,
            'nuclei_count': nuclei_count,
            'nuclear_area': nuclear_area,
            'nuclear_cytoplasmic_ratio': nuclear_cytoplasmic_ratio,
            'num_holes': num_holes,
            'total_hole_area': total_hole_area,
            'max_hole_size': max_hole_size,
            'avg_hole_size': avg_hole_size,
            'holes_to_cell_ratio': holes_to_cell_ratio,
            'has_holes': num_holes > 0
        })

    # Handle empty metrics case
    if not cell_metrics:
        print("No valid cells found for analysis")
        return {
            'cell_metrics': pd.DataFrame(),
            'original_cell_image': cell_image,
            'adjusted_cell_image': cell_image.copy(),
            'senescent_count': 0,
            'total_cells': 0,
            'senescent_fraction': 0
        }

    metrics_df = pd.DataFrame(cell_metrics)
    total_cells = len(metrics_df)
    target_count = int(total_cells * expected_senescent_fraction)

    # Default: every cell with holes is potentially senescent
    metrics_df['is_senescent'] = metrics_df['has_holes']

    if process_all_samples and global_hole_stats is not None:
        # Global thresholds, fixed target fraction
        metrics_df['senescence_score'] = 0
        metrics_df.loc[metrics_df['has_holes'], 'senescence_score'] += 2  # Strong weight for holes
        metrics_df.loc[metrics_df['area'] > global_hole_stats['size_threshold'], 'senescence_score'] += 1
        metrics_df.loc[metrics_df['nuclei_count'] > 1, 'senescence_score'] += 1
        metrics_df = mark_top(metrics_df, 'senescence_score', target_count)
    else:
        hole_fraction = metrics_df['has_holes'].sum() / total_cells
        if hole_fraction > expected_senescent_fraction:
            # Too many cells have holes: keep those with the largest holes
            metrics_df = mark_top(metrics_df, 'total_hole_area', target_count)
        elif hole_fraction < expected_senescent_fraction:
            # Too few cells have holes: top up with large / multinucleated cells
            metrics_df['senescence_score'] = 0
            metrics_df.loc[metrics_df['has_holes'], 'senescence_score'] += 3  # Strong weight for holes
            metrics_df.loc[metrics_df['area'] > metrics_df['area'].quantile(0.7), 'senescence_score'] += 1
            metrics_df.loc[metrics_df['nuclei_count'] > 1, 'senescence_score'] += 1
            metrics_df = mark_top(metrics_df, 'senescence_score', target_count)

    # Create adjusted cell masks by filling holes in senescent cells
    adjusted_cell_image = cell_image.copy()
    for cell_id in metrics_df.loc[metrics_df['is_senescent'], 'cell_id']:
        cell_id = int(cell_id)
        region = cell_slices[cell_id]
        cell_mask = cell_image[region] == cell_id
        holes_mask = ndimage.binary_fill_holes(cell_mask) & ~cell_mask
        adjusted_region = adjusted_cell_image[region]  # view: writes go into adjusted_cell_image
        # Only claim background pixels: a hole may contain another cell, which must
        # not be erased from the adjusted mask.
        adjusted_region[holes_mask & (adjusted_region == 0)] = cell_id

    # Calculate statistics
    senescent_count = metrics_df['is_senescent'].sum()
    senescent_fraction = senescent_count / total_cells if total_cells > 0 else 0

    print(f"Detected {senescent_count} senescent cells out of {total_cells} total cells ({senescent_fraction:.2%})")

    return {
        'cell_metrics': metrics_df,
        'original_cell_image': cell_image,
        'adjusted_cell_image': adjusted_cell_image,
        'senescent_count': senescent_count,
        'total_cells': total_cells,
        'senescent_fraction': senescent_fraction
    }

# Function to calculate global statistics across all samples
def calculate_global_statistics(file_dict, complete_samples):
    """
    Calculate global statistics across all samples to use for consistent thresholding.

    Every cell with area >= 100 px from every sample that has both a 'cell' and a
    'nuclei' image contributes its area and its hole-to-area ratio.

    Parameters:
    file_dict: Dictionary mapping sample IDs to component file paths
    complete_samples: List of sample IDs with complete data

    Returns:
    Dictionary with 'total_cells', 'cells_with_holes', 'hole_fraction',
    'size_threshold' (area mean + 1.5 std), 'hole_ratio_threshold'
    (hole-ratio mean + 1 std), 'area_mean' and 'area_std'; or None when no valid
    cell was found.
    """
    print("Calculating global statistics across all samples...")

    all_cell_areas = []
    all_hole_ratios = []
    cells_with_holes = 0

    for sample_id in complete_samples:
        print(f"Processing sample {sample_id} for global statistics")

        images = load_sample_images(sample_id, file_dict)
        if 'cell' not in images or 'nuclei' not in images:
            print(f"Error: Required cell or nuclei image not found for sample {sample_id}")
            continue

        cell_image = images['cell'].astype(np.int32)
        ndim = cell_image.ndim

        for cell_prop in measure.regionprops(cell_image):
            # Skip very small objects (likely artifacts)
            if cell_prop.area < 100:
                continue

            # Hole filling on the cell's bounding box is equivalent to the full image
            # (outside the box is all background for this cell) and much cheaper.
            bbox = cell_prop.bbox
            region = tuple(slice(bbox[k], bbox[k + ndim]) for k in range(ndim))
            cell_mask = cell_image[region] == cell_prop.label
            holes_mask = ndimage.binary_fill_holes(cell_mask) & ~cell_mask
            total_hole_area = np.count_nonzero(holes_mask)

            all_cell_areas.append(cell_prop.area)
            all_hole_ratios.append(total_hole_area / cell_prop.area)
            if total_hole_area > 0:  # at least one hole
                cells_with_holes += 1

    total_cells = len(all_cell_areas)
    if total_cells == 0:
        print("No valid cells found across samples")
        return None

    # Calculate thresholds for metrics
    area_mean = np.mean(all_cell_areas)
    area_std = np.std(all_cell_areas)
    size_threshold = area_mean + 1.5 * area_std

    hole_ratio_mean = np.mean(all_hole_ratios)
    hole_ratio_std = np.std(all_hole_ratios)
    hole_ratio_threshold = hole_ratio_mean + 1.0 * hole_ratio_std

    global_stats = {
        'total_cells': total_cells,
        'cells_with_holes': cells_with_holes,
        'hole_fraction': cells_with_holes / total_cells,
        'size_threshold': size_threshold,
        'hole_ratio_threshold': hole_ratio_threshold,
        'area_mean': area_mean,
        'area_std': area_std
    }

    print(f"Global statistics: {global_stats}")

    return global_stats

# Function to visualize senescence detection results
def visualize_senescence_detection(results, cell_image, nucleus_image, output_path=None):
    """
    Draw a 2x2 overview of the senescence classification for one sample.

    Panels:
      [0, 0] cell outlines (green) with nucleus outlines (blue) drawn on top
      [0, 1] senescent cells filled red, other analysed cells filled green
      [1, 0] original cell outlines (green) overlaid by adjusted outlines (yellow)
      [1, 1] analysed cells with holes (red) / without holes (blue), outlines white
    Per-class mean area, hole count and nuclei count are written below the panels.

    Parameters:
    results: Dictionary returned by detect_senescent_cells
    cell_image: Original cell label image
    nucleus_image: Nucleus label image, drawn as outlines
    output_path: Optional path; when given the figure is also saved at 300 dpi
    """
    if results['total_cells'] == 0:
        print("No valid cells to visualize")
        return

    fig, axes = plt.subplots(2, 2, figsize=(16, 14))

    metrics = results['cell_metrics']
    rgb_shape = (*cell_image.shape, 3)

    def cells_where(selector):
        # Boolean image of all pixels belonging to the cell_ids picked by `selector`
        return np.isin(cell_image, metrics.loc[selector, 'cell_id'].astype(int).to_numpy())

    cell_outline = segmentation.find_boundaries(cell_image)
    nucleus_outline = segmentation.find_boundaries(nucleus_image)
    adjusted_outline = segmentation.find_boundaries(results['adjusted_cell_image'])

    # Panel [0, 0]: nuclei painted after cells so they win where outlines overlap
    outlines = np.zeros(rgb_shape, dtype=np.uint8)
    outlines[cell_outline] = (0, 255, 0)
    outlines[nucleus_outline] = (0, 0, 255)

    # Panel [0, 1]: classification fill
    senescent_rows = metrics['is_senescent'].astype(bool)
    classification = np.zeros(rgb_shape, dtype=np.uint8)
    classification[cells_where(senescent_rows)] = (255, 0, 0)
    classification[cells_where(~senescent_rows)] = (0, 255, 0)

    # Panel [1, 0]: adjusted outlines drawn over the original ones
    comparison = np.zeros(rgb_shape, dtype=np.uint8)
    comparison[cell_outline] = (0, 255, 0)
    comparison[adjusted_outline] = (255, 255, 0)

    # Panel [1, 1]: hole presence, with white outlines for clarity
    hole_rows = metrics['has_holes'].astype(bool)
    holes_rgb = np.zeros(rgb_shape, dtype=np.uint8)
    holes_rgb[cells_where(hole_rows)] = (255, 0, 0)
    holes_rgb[cells_where(~hole_rows)] = (0, 0, 255)
    holes_rgb[cell_outline, :] = [255, 255, 255]

    panels = [
        (axes[0, 0], outlines, "Original Cells (green) and Nuclei (blue)"),
        (axes[0, 1], classification, f"Senescent Cells (red): {results['senescent_fraction']:.1%}"),
        (axes[1, 0], comparison, "Original (green) vs Adjusted (yellow) Cell Boundaries"),
        (axes[1, 1], holes_rgb, "Cells With Holes (red) vs Without Holes (blue)"),
    ]
    for ax, image, title in panels:
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    def class_summary(heading, subset, empty_message):
        # Multi-line text block with the mean metrics of one class of cells
        if len(subset) == 0:
            return empty_message
        return (f"{heading} (n={len(subset)}):\n"
                f"Mean area: {subset['area'].mean():.1f}\n"
                f"Mean holes: {subset['num_holes'].mean():.1f}\n"
                f"Mean nuclei: {subset['nuclei_count'].mean():.1f}")

    senescent_metrics = metrics[metrics['is_senescent']]
    normal_metrics = metrics[~metrics['is_senescent']]
    fig.text(0.02, 0.02, class_summary("Senescent cells", senescent_metrics, "No senescent cells detected"), fontsize=10)
    fig.text(0.52, 0.02, class_summary("Normal cells", normal_metrics, "No normal cells detected"), fontsize=10)

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {output_path}")

    plt.show()

# Function to visualize multinucleated cells
def visualize_multinucleated_cells(results, cell_image, nucleus_image, output_path=None):
    """
    Plot analysed cells coloured by how many nuclei they contain, plus a histogram.

    Left panel: each cell listed in results['cell_metrics'] is filled gray (0 nuclei),
    blue (1), green (2) or red (3 or more); nucleus outlines are drawn in white.
    Right panel: bar chart of cell counts per bucket with the count above each bar.

    Parameters:
    results: Dictionary returned by detect_senescent_cells
    cell_image: Original cell label image
    nucleus_image: Nucleus label image, drawn as outlines
    output_path: Optional path; when given the figure is also saved at 300 dpi
    """
    if results['total_cells'] == 0:
        print("No valid cells to visualize")
        return

    metrics = results['cell_metrics']
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # Index = nuclei bucket (3 stands for "3 or more")
    bucket_colors = (
        [100, 100, 100],  # Gray: no nucleus
        [0, 0, 255],      # Blue: one nucleus
        [0, 255, 0],      # Green: two nuclei
        [255, 0, 0],      # Red: three or more
    )
    nuclei = metrics['nuclei_count']
    buckets = np.minimum(nuclei.to_numpy(), 3)
    cell_ids = metrics['cell_id'].astype(int).to_numpy()

    painted = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    for bucket, rgb in enumerate(bucket_colors):
        painted[np.isin(cell_image, cell_ids[buckets == bucket])] = rgb

    painted[segmentation.find_boundaries(nucleus_image)] = [255, 255, 255]

    axes[0].imshow(painted)
    axes[0].set_title("Cell Nuclei Count\nGray: 0, Blue: 1, Green: 2, Red: 3+")
    axes[0].axis('off')

    # Cells per bucket: exact counts for 0-2, everything else folded into "3+"
    counts = [int((nuclei == k).sum()) for k in (0, 1, 2)] + [int((nuclei >= 3).sum())]

    bars = axes[1].bar(['0', '1', '2', '3+'], counts, color=['gray', 'blue', 'green', 'red'])
    for bar in bars:
        height = bar.get_height()
        axes[1].text(bar.get_x() + bar.get_width() / 2., height + 0.1,
                     f'{int(height)}', ha='center', va='bottom')

    axes[1].set_xlabel('Number of Nuclei')
    axes[1].set_ylabel('Number of Cells')
    axes[1].set_title('Distribution of Nuclei Count')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Multinucleated visualization saved to {output_path}")

    plt.show()

# Function to create a visualization showing holes in cells
def visualize_cell_holes(cell_image, output_path=None):
    """
    Create a visualization showing cells (green) and enclosed holes (red).

    NOTE(review): holes are computed on the union of all cells, so background
    regions enclosed jointly by several neighbouring cells are also drawn red,
    whereas detect_senescent_cells looks for holes inside each individual cell.

    Parameters:
    cell_image: Label image of cells (0 = background)
    output_path: Optional path; when given the figure is also saved at 300 dpi
    """
    plt.figure(figsize=(14, 12))

    # Holes = background pixels fully enclosed by the foreground
    cell_mask = cell_image > 0
    holes_mask = ndimage.binary_fill_holes(cell_mask) & ~cell_mask

    vis_image = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    vis_image[cell_mask, 1] = 180  # Cells in green
    vis_image[holes_mask, 0] = 255  # Holes in red
    vis_image[segmentation.find_boundaries(cell_image)] = [255, 255, 255]  # White outlines

    plt.imshow(vis_image)
    plt.title('Cells (green) with Holes (red)')
    plt.axis('off')

    # Lay out BEFORE saving; previously tight_layout ran after savefig, so the
    # saved file never received the adjusted layout.
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Cell holes visualization saved to {output_path}")

    plt.show()

# Function to process a sample for senescence analysis
def _save_nuclei_merge_comparison(original_nuclei, merged_nuclei, output_path):
    """Save side-by-side nucleus outlines before (red) and after (green) merging."""
    plt.figure(figsize=(14, 7))
    panels = [
        (original_nuclei, (1, 0, 0), 'Original Nuclei'),
        (merged_nuclei, (0, 1, 0), 'Merged Nuclei'),
    ]
    for position, (labels, color, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(segmentation.mark_boundaries(
            np.zeros_like(labels, dtype=np.uint8),
            labels,
            color=color
        ))
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()


def _summarize_senescence_results(sample_id, results):
    """Build the one-row summary dict written to <sample_id>_senescence_summary.csv."""
    metrics = results['cell_metrics']
    total_cells = results['total_cells']
    cells_with_holes = metrics['has_holes'].sum()

    summary = {
        'sample_id': sample_id,
        'total_cells': total_cells,
        'senescent_cells': results['senescent_count'],
        'senescent_fraction': results['senescent_fraction'],
        'normal_cells': total_cells - results['senescent_count'],
        'normal_fraction': 1 - results['senescent_fraction'],
        'cells_with_holes': cells_with_holes,
        'hole_fraction': cells_with_holes / total_cells if total_cells > 0 else 0
    }

    if metrics.empty:
        return summary

    # Compare senescent vs normal cells (only when both groups exist)
    senescent = metrics[metrics['is_senescent']]
    normal = metrics[~metrics['is_senescent']]
    if not senescent.empty and not normal.empty:
        summary['senescent_mean_area'] = senescent['area'].mean()
        summary['normal_mean_area'] = normal['area'].mean()
        summary['size_ratio'] = summary['senescent_mean_area'] / summary['normal_mean_area']

        summary['senescent_mean_nuclei'] = senescent['nuclei_count'].mean()
        summary['normal_mean_nuclei'] = normal['nuclei_count'].mean()

        # Note: the *_pct values are fractions in [0, 1]
        summary['senescent_multi_nuclei_pct'] = (senescent['nuclei_count'] > 1).mean()
        summary['normal_multi_nuclei_pct'] = (normal['nuclei_count'] > 1).mean()

        summary['senescent_mean_holes'] = senescent['num_holes'].mean()
        summary['normal_mean_holes'] = normal['num_holes'].mean()

        summary['senescent_with_holes_pct'] = senescent['has_holes'].mean()
        summary['normal_with_holes_pct'] = normal['has_holes'].mean()

    return summary


def process_senescence_analysis(sample_id, file_dict, output_dir, global_stats=None, process_all_samples=False,
                                expected_senescent_fraction=0.3, nuclei_merge_distance=15):
    """
    Process a sample for senescence analysis.

    Loads the sample's images, merges nearby nuclei within the same cell, saves a
    nuclei before/after figure and a cell-hole figure, classifies cells with
    detect_senescent_cells, and writes per-cell metrics (CSV), the adjusted cell
    mask (TIFF), two result figures and a one-row summary CSV into
    ``<output_dir>/<sample_id>/``.

    Parameters:
    sample_id: ID of the sample to process
    file_dict: Dictionary mapping sample IDs to component file paths
    output_dir: Directory to save outputs
    global_stats: Global statistics for consistent thresholding
    process_all_samples: Flag to indicate if we're processing all samples together
    expected_senescent_fraction: Target senescent fraction passed to
        detect_senescent_cells (default 0.3, from 30% TNF-a treated cells)
    nuclei_merge_distance: Centroid distance in pixels below which nuclei in the
        same cell are merged (default 15)

    Returns:
    Results from detect_senescent_cells, or None if the cell or nuclei image is missing.
    """
    print(f"Processing senescence analysis for sample {sample_id}")

    sample_output_dir = os.path.join(output_dir, sample_id)
    os.makedirs(sample_output_dir, exist_ok=True)

    images = load_sample_images(sample_id, file_dict)
    if 'cell' not in images or 'nuclei' not in images:
        print(f"Error: Required cell or nuclei image not found for sample {sample_id}")
        return None

    # Polynucleated-cell correction: merge close nuclei that belong to the same cell
    merged_nucleus_image = merge_close_nuclei(images['cell'], images['nuclei'],
                                              distance_threshold=nuclei_merge_distance)

    _save_nuclei_merge_comparison(
        images['nuclei'], merged_nucleus_image,
        os.path.join(sample_output_dir, f"{sample_id}_nuclei_merge_comparison.png"))

    visualize_cell_holes(images['cell'],
                         os.path.join(sample_output_dir, f"{sample_id}_cell_holes.png"))

    results = detect_senescent_cells(
        images['cell'],
        merged_nucleus_image,  # Use the merged nuclei image
        process_all_samples=process_all_samples,
        global_hole_stats=global_stats,
        expected_senescent_fraction=expected_senescent_fraction
    )

    if results['total_cells'] == 0:
        print(f"No valid cells found in sample {sample_id}, skipping output generation")
        return results

    # 1. Per-cell metrics
    metrics_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_metrics.csv")
    results['cell_metrics'].to_csv(metrics_file, index=False)
    print(f"Saved cell metrics to {metrics_file}")

    # 2. Adjusted cell mask; uint16 would silently wrap labels above 65535
    adjusted_mask_file = os.path.join(sample_output_dir, f"{sample_id}_adjusted_cell_mask.tif")
    adjusted = results['adjusted_cell_image']
    save_dtype = np.uint16 if adjusted.max() <= np.iinfo(np.uint16).max else np.uint32
    io.imsave(adjusted_mask_file, adjusted.astype(save_dtype))
    print(f"Saved adjusted cell mask to {adjusted_mask_file}")

    # 3. Senescence overview figure
    vis_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_visualization.png")
    visualize_senescence_detection(results, images['cell'], merged_nucleus_image, vis_file)

    # 4. Multinucleated-cell figure
    multi_vis_file = os.path.join(sample_output_dir, f"{sample_id}_multinucleated_visualization.png")
    visualize_multinucleated_cells(results, images['cell'], merged_nucleus_image, multi_vis_file)

    # 5. Summary statistics
    summary_df = pd.DataFrame([_summarize_senescence_results(sample_id, results)])
    summary_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_summary.csv")
    summary_df.to_csv(summary_file, index=False)
    print(f"Saved summary to {summary_file}")

    return results

# Function to compile results across all samples
def _save_sample_bar_chart(summary, fraction_column, ylabel, title, output_path):
    """Save a per-sample bar chart of ``summary[fraction_column]`` shown in percent."""
    plt.figure(figsize=(12, 6))
    plt.bar(summary['sample_id'], summary[fraction_column] * 100)
    plt.xlabel('Sample ID')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()


def _boxplot_normal_vs_senescent(ax, normal, senescent, column, ylabel, title):
    """Draw boxplots of ``column`` for normal vs senescent cells on ``ax``."""
    # Drop NaNs: they would otherwise turn the box statistics into NaN
    ax.boxplot([normal[column].dropna(), senescent[column].dropna()])
    # Tick labels are set separately because boxplot's ``labels=`` keyword is
    # deprecated since Matplotlib 3.9 (renamed ``tick_labels``, absent in older versions)
    ax.set_xticklabels(['Normal', 'Senescent'])
    ax.set_ylabel(ylabel)
    ax.set_title(title)


def _save_area_scatter(normal, senescent, y_column, ylabel, title, output_path):
    """Save a scatter of cell area vs ``y_column`` (normal blue, senescent red)."""
    plt.figure(figsize=(10, 8))
    plt.scatter(normal['area'], normal[y_column], alpha=0.5, label='Normal', color='blue')
    plt.scatter(senescent['area'], senescent[y_column], alpha=0.5, label='Senescent', color='red')
    plt.xlabel('Cell Area')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()


def compile_cross_sample_results(output_dir, sample_ids):
    """
    Compile results across all analyzed samples.

    Reads ``<output_dir>/<sample_id>/<sample_id>_senescence_summary.csv`` and
    ``..._senescence_metrics.csv`` for each sample (missing files are skipped) and
    writes into ``output_dir``: all_samples_senescence_summary.csv,
    senescence_percentages.png and hole_fractions.png when summaries exist, and
    all_samples_cell_metrics.csv, senescence_metrics_comparison.png,
    area_vs_nuclei_scatter.png and area_vs_holes_scatter.png when metrics exist.

    Parameters:
    output_dir: Root directory containing one sub-directory per sample
    sample_ids: Sample IDs to collect
    """
    all_summaries = []
    all_metrics = []

    for sample_id in sample_ids:
        sample_dir = os.path.join(output_dir, sample_id)

        summary_file = os.path.join(sample_dir, f"{sample_id}_senescence_summary.csv")
        if os.path.exists(summary_file):
            all_summaries.append(pd.read_csv(summary_file))

        metrics_file = os.path.join(sample_dir, f"{sample_id}_senescence_metrics.csv")
        if os.path.exists(metrics_file):
            metrics = pd.read_csv(metrics_file)
            metrics['sample_id'] = sample_id  # Add sample ID
            all_metrics.append(metrics)

    if all_summaries:
        combined_summary = pd.concat(all_summaries, ignore_index=True)
        summary_output = os.path.join(output_dir, "all_samples_senescence_summary.csv")
        combined_summary.to_csv(summary_output, index=False)
        print(f"Saved combined summary to {summary_output}")

        _save_sample_bar_chart(combined_summary, 'senescent_fraction', 'Senescent Cells (%)',
                               'Percentage of Senescent Cells Across Samples',
                               os.path.join(output_dir, "senescence_percentages.png"))

        if 'hole_fraction' in combined_summary.columns:
            _save_sample_bar_chart(combined_summary, 'hole_fraction', 'Cells With Holes (%)',
                                   'Percentage of Cells With Holes Across Samples',
                                   os.path.join(output_dir, "hole_fractions.png"))

    if all_metrics:
        combined_metrics = pd.concat(all_metrics, ignore_index=True)
        metrics_output = os.path.join(output_dir, "all_samples_cell_metrics.csv")
        combined_metrics.to_csv(metrics_output, index=False)
        print(f"Saved combined metrics to {metrics_output}")

        if 'is_senescent' in combined_metrics.columns:
            senescent = combined_metrics[combined_metrics['is_senescent']]
            normal = combined_metrics[~combined_metrics['is_senescent']]
            both_groups = len(senescent) > 0 and len(normal) > 0

            # Boxplots comparing senescent vs normal cells (2x3 grid)
            metrics_to_plot = ['area', 'nuclei_count', 'solidity', 'num_holes', 'circularity']
            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            axes = axes.flatten()

            for i, metric in enumerate(metrics_to_plot):
                if i < len(axes) and metric in combined_metrics.columns and both_groups:
                    _boxplot_normal_vs_senescent(axes[i], normal, senescent, metric,
                                                 metric, f'{metric} Distribution')

            # Last free panel: holes to cell ratio
            if len(axes) > len(metrics_to_plot) and 'holes_to_cell_ratio' in combined_metrics.columns and both_groups:
                _boxplot_normal_vs_senescent(axes[len(metrics_to_plot)], normal, senescent,
                                             'holes_to_cell_ratio', 'Holes to Cell Ratio',
                                             'Holes to Cell Ratio Distribution')

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "senescence_metrics_comparison.png"), dpi=300)
            plt.close()

            _save_area_scatter(normal, senescent, 'nuclei_count', 'Nuclei Count',
                               'Cell Area vs Nuclei Count',
                               os.path.join(output_dir, "area_vs_nuclei_scatter.png"))

            if 'num_holes' in combined_metrics.columns:
                _save_area_scatter(normal, senescent, 'num_holes', 'Number of Holes',
                                   'Cell Area vs Number of Holes',
                                   os.path.join(output_dir, "area_vs_holes_scatter.png"))

    print("Compiled results across all samples")

# IMPROVED: Two-phase processing approach
def process_with_global_statistics(file_dict, complete_samples, output_dir):
    """
    Run the senescence analysis in two passes so all samples share one set of thresholds.

    Pass 1 gathers statistics over every sample in ``complete_samples`` via
    ``calculate_global_statistics`` and writes them to
    ``<output_dir>/global_statistics.csv``. Pass 2 calls
    ``process_senescence_analysis`` on each sample with those statistics and
    then ``compile_cross_sample_results``. If no usable statistics are
    returned (a falsy result or zero cells), each sample is processed on its
    own instead and the cross-sample compilation is skipped.

    Parameters:
    file_dict: Dictionary mapping sample IDs to component file paths
    complete_samples: List of sample IDs with complete data
    output_dir: Directory to save outputs
    """
    print("Starting two-phase senescence analysis...")

    # Pass 1: thresholds shared by every sample
    global_stats = calculate_global_statistics(file_dict, complete_samples)
    have_stats = bool(global_stats) and global_stats['total_cells'] != 0

    if not have_stats:
        print("Could not calculate global statistics. Falling back to per-sample processing.")
        for sample_id in complete_samples:
            process_senescence_analysis(sample_id, file_dict, output_dir)
        return

    stats_path = os.path.join(output_dir, "global_statistics.csv")
    pd.DataFrame([global_stats]).to_csv(stats_path, index=False)
    print(f"Saved global statistics to {stats_path}")

    # Pass 2: per-sample classification using the shared statistics
    for sample_id in complete_samples:
        print(f"\nProcessing sample {sample_id} with global statistics")
        outcome = process_senescence_analysis(
            sample_id,
            file_dict,
            output_dir,
            global_stats=global_stats,
            process_all_samples=True,
        )
        if not outcome:
            continue

        print(f"Completed senescence analysis for sample {sample_id}")
        print(f"Found {outcome['senescent_count']} senescent cells "
              f"({outcome['senescent_fraction']:.1%}) out of {outcome['total_cells']} total cells")

    compile_cross_sample_results(output_dir, complete_samples)

    print("Two-phase senescence analysis completed!")

# Function to run all processing
def main():
    """
    Pipeline entry point.

    Discovers the segmented mask files under the module-level ``base_dir``.
    If at least one sample has both nuclei and cell masks, runs the two-phase
    analysis writing into ``output_dir``. Otherwise it reports that there is
    nothing to do.
    """
    print("Starting improved senescent cell analysis pipeline...")

    file_dict, complete_samples = find_matching_files(base_dir)

    if complete_samples:
        # Shared (global) thresholds keep results comparable across samples
        process_with_global_statistics(file_dict, complete_samples, output_dir)
        print("Improved senescence analysis pipeline completed!")
    else:
        print("No complete samples found. Exiting.")


# Run the main function
if __name__ == "__main__":
    main()

Starting improved senescent cell analysis pipeline...
Found 18 TIFF files in Nuclei folder.
Found 18 TIFF files in Membrane_Adjusted folder.
Found 18 TIFF files in Golgi folder.
Found 18 TIFF files in Cell folder.
Found 18 complete samples with at least nuclei and cell data.
Starting two-phase senescence analysis...
Calculating global statistics across all samples...
Processing sample denoised_0Pa_A1_20dec21_20xA_L2RA_FlatA_seq001_contrast for global statistics
Loaded nuclei image: shape (1024, 1024), dtype: uint32, value range [0, 241]
Loaded membrane image: shape (1024, 1024), dtype: uint8, value range [0, 1]
Loaded golgi image: shape (1024, 1024), dtype: uint32, value range [0, 194]
Loaded cell image: shape (1024, 1024), dtype: uint32, value range [0, 241]
Processing sample denoised_0Pa_A1_20dec21_20xA_L2RA_FlatA_seq002_contrast for global statistics
Loaded nuclei image: shape (1024, 1024), dtype: uint32, value range [0, 329]
Loaded membrane image: shape (1024, 1024), dtype: uint8, 

KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation, feature, morphology
from scipy import ndimage
from collections import Counter
import tifffile
from pathlib import Path
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

# Set up matplotlib for better visualization
# (default figure size and resolution, plus the ggplot style, for every plot below)
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100
plt.style.use('ggplot')

# Define paths based on your Google Drive structure
# base_dir: root of the segmentation output. find_matching_files scans its
#   Nuclei, Membrane_Adjusted, Golgi and Cell sub-folders.
# output_dir: where the analysis writes its CSVs and figures.
base_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/Static-A-2"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/Static-A-2"

# Create output directory if it doesn't exist
# NOTE(review): runs as soon as this cell executes and assumes Google Drive is
# already mounted at /content/drive.
os.makedirs(output_dir, exist_ok=True)

# Define function to extract sample ID from filenames
def extract_sample_id(filename):
    """
    Return the sample-ID prefix of a mask filename, or None.

    The name is split on underscores, and every token before the first
    component token ('Nuclei', 'Golgi', 'membrane' or 'cell', matched
    case-sensitively) is re-joined. For example,
    ``denoised_0Pa_A1_20dec21_40x_L2RA_FlatA_seq007_contrast_Nuclei_mask.tif``
    gives ``denoised_0Pa_A1_20dec21_40x_L2RA_FlatA_seq007_contrast``.
    If the very first token is a component name, the result is ''.
    If no component token is present, None is returned.
    """
    component_names = {'Nuclei', 'Golgi', 'membrane', 'cell'}
    tokens = str(filename).split('_')

    position = next(
        (idx for idx, token in enumerate(tokens) if token in component_names),
        None,
    )
    if position is None:
        return None
    return '_'.join(tokens[:position])

# Function to find matching files across different folders
def find_matching_files(base_dir):
    """
    Index the per-component mask files under ``base_dir`` by sample ID.

    Scans the ``Nuclei``, ``Membrane_Adjusted``, ``Golgi`` and ``Cell``
    sub-folders for TIFF files and groups them by the sample ID returned by
    ``extract_sample_id``. Files with a ``.tif`` or ``.tiff`` extension in any
    letter case are accepted. Hidden files such as macOS ``._*`` resource
    forks are skipped.

    Parameters:
    base_dir: Directory containing one sub-folder per component

    Returns:
    (file_dict, complete_samples).
    file_dict maps sample ID -> {component key: file path}, with keys
    'nuclei', 'membrane', 'golgi' and 'cell'. complete_samples is a sorted
    list of the sample IDs that have at least a nuclei and a cell mask.
    """
    # Folder name -> key used in file_dict. Membrane masks live in an
    # "adjusted" folder but are exposed under the plain 'membrane' key.
    component_folders = {
        'Nuclei': 'nuclei',
        'Membrane_Adjusted': 'membrane',
        'Golgi': 'golgi',
        'Cell': 'cell',
    }

    file_dict = {}

    for folder, key in component_folders.items():
        component_dir = os.path.join(base_dir, folder)
        if not os.path.isdir(component_dir):
            print(f"Warning: {component_dir} does not exist.")
            continue

        # Sorted because os.listdir order is arbitrary. This makes the scan,
        # and which file wins on a duplicate, reproducible.
        files = sorted(
            name for name in os.listdir(component_dir)
            if name.lower().endswith(('.tif', '.tiff')) and not name.startswith('.')
        )
        print(f"Found {len(files)} TIFF files in {folder} folder.")

        for file in files:
            sample_id = extract_sample_id(file)
            if not sample_id:
                continue

            components = file_dict.setdefault(sample_id, {})
            if key in components:
                print(f"Warning: multiple {key} files for sample {sample_id}; using {file}")
            components[key] = os.path.join(component_dir, file)

    # A sample is usable when it has at least nuclei and cell masks;
    # membrane and golgi are optional.
    complete_samples = sorted(
        sample_id for sample_id, components in file_dict.items()
        if 'nuclei' in components and 'cell' in components
    )

    print(f"Found {len(complete_samples)} complete samples with at least nuclei and cell data.")

    return file_dict, complete_samples

# Function to load images for a specific sample
def load_sample_images(sample_id, file_dict):
    """
    Load every mask image registered for ``sample_id`` in ``file_dict``.

    Boolean images are cast to uint8. 'cell' and 'nuclei' images whose
    maximum is <= 1 are treated as binary masks and converted to label
    images with ``ndimage.label``. Missing files and read errors are
    reported and skipped rather than raised.

    Returns:
    Dict mapping component key -> image array, for the images that loaded.
    """
    loaded = {}

    for name, path in file_dict[sample_id].items():
        if not os.path.exists(path):
            print(f"Warning: File not found - {path}")
            continue

        try:
            image = io.imread(path)

            if image.dtype == bool:
                image = image.astype(np.uint8)

            # Binary cell/nuclei masks need labels so individual objects can be told apart
            needs_labels = name in ('cell', 'nuclei') and np.max(image) <= 1
            if needs_labels:
                print(f"Converting binary {name} image to labeled image")
                image, n_regions = ndimage.label(image)
                print(f"Found {n_regions} {name} regions")

            loaded[name] = image
            print(f"Loaded {name} image: shape {image.shape}, dtype: {image.dtype}, value range [{np.min(image)}, {np.max(image)}]")
        except Exception as err:
            print(f"Error loading {path}: {str(err)}")

    return loaded

# Improved function to detect and create multinucleated cells
def merge_close_nuclei(cell_image, nucleus_image, distance_threshold=20):
    """
    Merge nuclei that are close to each other and might be part of a multi-nucleated cell.

    For each cell that overlaps more than one nucleus, nuclei are grouped
    greedily. Each not-yet-grouped nucleus seeds a group and absorbs every
    other ungrouped nucleus whose centroid lies within ``distance_threshold``
    of the seed's centroid. It also absorbs nuclei within
    ``1.5 * distance_threshold`` whose smaller/larger area ratio with the seed
    exceeds 0.7. Each group, including single-nucleus groups, receives one
    new sequential label. The nucleus of a cell overlapping exactly one
    nucleus is copied with a new label. Nuclei that overlap no cell are not
    written to the output.

    Parameters:
    cell_image: Label image of cells (0 = background)
    nucleus_image: Label image of nuclei. If its maximum is <= 1 it is
        treated as a binary mask and labelled with ndimage.label first.
    distance_threshold: Centroid distance (in pixels) below which two nuclei
        of the same cell are merged. Similar-sized nuclei are merged up to
        1.5x this distance.

    Returns:
    Adjusted nucleus label image with merged nuclei. If no nucleus received
    a label, the labelled input is returned instead.
    """
    # Make working copies of the images
    nucleus_image_labeled = nucleus_image.copy()

    # If the nucleus image isn't already labeled, label it
    if np.max(nucleus_image_labeled) <= 1:
        nucleus_image_labeled, _ = ndimage.label(nucleus_image_labeled > 0)

    # Get properties of nuclei
    # NOTE(review): nucleus_props is never used below.
    nucleus_props = measure.regionprops(nucleus_image_labeled)

    # First, identify cells with multiple nuclei candidates
    # Maps cell label -> array of nucleus labels overlapping that cell
    # (only cells overlapping more than one nucleus are stored).
    cells_with_multiple_nuclei = {}

    # For each cell, find all nuclei inside it
    for cell_id in np.unique(cell_image):
        if cell_id == 0:  # Skip background
            continue

        # Get the mask for this cell
        cell_mask = cell_image == cell_id

        # Find all nuclei in this cell
        nuclei_in_cell = np.unique(nucleus_image_labeled[cell_mask])
        nuclei_in_cell = nuclei_in_cell[nuclei_in_cell > 0]  # Remove background

        if len(nuclei_in_cell) > 1:
            # This cell contains multiple nuclei
            cells_with_multiple_nuclei[cell_id] = nuclei_in_cell

    # Create a new image for merged nuclei
    merged_nucleus_image = np.zeros_like(nucleus_image_labeled)
    # Next output label to hand out; output labels are sequential from 1
    next_label = 1

    # Process cells with multiple nuclei
    # NOTE(review): a nucleus overlapping several cells is rewritten each time
    # one of those cells is processed, so the last write wins.
    for cell_id, nuclei_list in cells_with_multiple_nuclei.items():
        # NOTE(review): cell_mask is not used in the rest of this loop body.
        cell_mask = cell_image == cell_id

        # Check if these nuclei should be merged into a multinucleated cell
        # by analyzing their proximity and relative positions

        # Get properties of all nuclei in this cell
        # (area and centroid cover the whole nucleus, including any part that
        # lies outside this cell; centroid uses the first two axes, so a 2D
        # image is presumably assumed)
        nuclei_props = []
        for nuc_id in nuclei_list:
            nuc_mask = nucleus_image_labeled == nuc_id
            nuc_area = np.sum(nuc_mask)

            # Get centroid
            nuc_coords = np.where(nuc_mask)
            if len(nuc_coords[0]) > 0:
                centroid = (np.mean(nuc_coords[0]), np.mean(nuc_coords[1]))
                nuclei_props.append({
                    'id': nuc_id,
                    'area': nuc_area,
                    'centroid': centroid
                })

        # If we have multiple nuclei properties
        if len(nuclei_props) > 1:
            # Calculate pairwise distances between all nuclei
            # Greedy grouping: each ungrouped nucleus seeds a group and absorbs
            # ungrouped nuclei close to the *seed* (not to other members). The
            # grouping is therefore not transitive and depends on label order
            # (nuclei_list comes sorted from np.unique).
            merge_groups = []
            processed = set()

            for i, prop1 in enumerate(nuclei_props):
                if prop1['id'] in processed:
                    continue

                current_group = [prop1['id']]
                processed.add(prop1['id'])

                for j, prop2 in enumerate(nuclei_props):
                    if i != j and prop2['id'] not in processed:
                        # Calculate Euclidean distance
                        distance = np.sqrt(
                            (prop1['centroid'][0] - prop2['centroid'][0])**2 +
                            (prop1['centroid'][1] - prop2['centroid'][1])**2
                        )

                        # Also consider relative sizes - nuclei of similar sizes are more likely to be from the same cell
                        size_ratio = min(prop1['area'], prop2['area']) / max(prop1['area'], prop2['area'])

                        # Check if these nuclei should be merged
                        # Either they are very close or moderately close with similar sizes
                        if distance < distance_threshold or (distance < distance_threshold*1.5 and size_ratio > 0.7):
                            current_group.append(prop2['id'])
                            processed.add(prop2['id'])

                if len(current_group) > 0:
                    merge_groups.append(current_group)

            # Apply merging based on identified groups
            for group in merge_groups:
                if len(group) > 1:  # Only process groups with multiple nuclei
                    # Create a new label for this multinucleated cell
                    for nuc_id in group:
                        merged_nucleus_image[nucleus_image_labeled == nuc_id] = next_label
                    next_label += 1
                else:
                    # Single nucleus, copy as is
                    merged_nucleus_image[nucleus_image_labeled == group[0]] = next_label
                    next_label += 1
        else:
            # Just one nucleus, copy as is
            # NOTE(review): appears unreachable. Every id in nuclei_list is a
            # label present in the image, so nuclei_props gets one entry per
            # id, and nuclei_list always has more than one id here.
            for nuc_id in nuclei_list:
                merged_nucleus_image[nucleus_image_labeled == nuc_id] = next_label
                next_label += 1

    # Process cells with a single nucleus
    for cell_id in np.unique(cell_image):
        if cell_id == 0 or cell_id in cells_with_multiple_nuclei:  # Skip background and already processed cells
            continue

        # Get the mask for this cell
        cell_mask = cell_image == cell_id

        # Find the single nucleus in this cell
        nuclei_in_cell = np.unique(nucleus_image_labeled[cell_mask])
        nuclei_in_cell = nuclei_in_cell[nuclei_in_cell > 0]  # Remove background

        if len(nuclei_in_cell) == 1:
            # Copy this nucleus to the merged image
            merged_nucleus_image[nucleus_image_labeled == nuclei_in_cell[0]] = next_label
            next_label += 1

    # Check if any nuclei were missed and ensure the output is labeled
    if np.max(merged_nucleus_image) == 0:
        # Fall back to the original nucleus image
        print("Warning: No nuclei were successfully processed for merging.")
        return nucleus_image_labeled

    # Ensure no zeros inside nuclei (can happen with overlapping cells)
    # NOTE(review): the filled mask is multiplied by the label image, so pixels
    # added by binary_fill_holes (label 0 beforehand) stay 0. This line leaves
    # merged_nucleus_image unchanged and fills nothing. Confirm intent.
    merged_nucleus_image = ndimage.binary_fill_holes(merged_nucleus_image > 0) * merged_nucleus_image
    filled_mask = merged_nucleus_image > 0
    labeled_filled, _ = ndimage.label(filled_mask)

    # Make sure each nucleus has a unique label
    # Each connected component of the output is set to its most frequent
    # label, so touching nuclei with different labels end up sharing one.
    for i in range(1, np.max(labeled_filled) + 1):
        nucleus_mask = labeled_filled == i
        if np.sum(nucleus_mask) > 0:
            # Get the most common non-zero label in this region
            labels = merged_nucleus_image[nucleus_mask]
            labels = labels[labels > 0]
            if len(labels) > 0:
                most_common = np.bincount(labels).argmax()
                merged_nucleus_image[nucleus_mask] = most_common
            else:
                # If no existing label, assign a new one
                # NOTE(review): appears unreachable. Components come from
                # merged_nucleus_image > 0, so `labels` is never empty.
                merged_nucleus_image[nucleus_mask] = next_label
                next_label += 1

    return merged_nucleus_image

# IMPROVED: Function to detect senescent cells
def detect_senescent_cells(cell_image, nucleus_image, process_all_samples=False,
                          global_hole_stats=None, expected_senescent_fraction=0.3):
    """
    Detect senescent cells based on the presence of holes and other features.

    Every labelled cell of at least 100 px receives shape, nuclear and hole
    metrics. A "hole" is any pixel enclosed by the cell that is not part of
    it, including pixels of another cell completely surrounded by it.
    Cells are then flagged as senescent in one of two ways:

    * With global statistics (``process_all_samples`` and
      ``global_hole_stats`` given): score = 2 if the cell has holes,
      +1 if its area exceeds ``global_hole_stats['size_threshold']``,
      +1 if it overlaps more than one nucleus. The top
      ``expected_senescent_fraction`` of cells by score are senescent.
    * Otherwise, cells with holes are senescent. If they exceed the
      expected fraction, only those with the largest total hole area are
      kept. If they fall short, cells are ranked by a score (3 for holes,
      +1 for area above the sample's 70th percentile, +1 for
      multinucleation) and the top fraction is taken.

    In ``adjusted_cell_image`` the holes of senescent cells are filled.
    Only background pixels are claimed, so neighbouring cells enclosed by a
    senescent cell are preserved.

    Parameters:
    cell_image: Label image of cells
    nucleus_image: Label image of nuclei (same shape as cell_image)
    process_all_samples: Flag to indicate if we're processing all samples at once
    global_hole_stats: Statistics on holes across all samples (for global thresholding)
    expected_senescent_fraction: Expected fraction of senescent cells (default 0.3)

    Returns:
    Dictionary with 'cell_metrics' (DataFrame, one row per analysed cell),
    'original_cell_image', 'adjusted_cell_image', 'senescent_count',
    'total_cells' and 'senescent_fraction'.
    """
    # Ensure we have integer type images
    cell_image = cell_image.astype(np.int32)
    nucleus_image = nucleus_image.astype(np.int32)

    # Get cell properties
    cell_props = measure.regionprops(cell_image)

    # Region properties of every analysed cell, reused when filling holes below
    props_by_label = {}
    cell_metrics = []

    for cell_prop in cell_props:
        cell_id = cell_prop.label

        # Skip very small objects (likely artifacts)
        if cell_prop.area < 100:
            continue
        props_by_label[cell_id] = cell_prop

        # Work inside the cell's bounding box. cell_prop.image equals
        # (cell_image == cell_id) cropped to cell_prop.slice, so every metric
        # below matches its full-image counterpart. Each cell now costs
        # O(bbox) instead of O(image); full-image masks plus
        # binary_fill_holes on the whole image per cell dominated the runtime.
        bbox = cell_prop.slice
        cell_mask = cell_prop.image

        # Nucleus labels under this cell's pixels
        nuclei_pixels = nucleus_image[bbox][cell_mask]
        nuclei_in_cell = np.unique(nuclei_pixels)
        nuclei_in_cell = nuclei_in_cell[nuclei_in_cell > 0]  # Remove background (0)
        nuclei_count = len(nuclei_in_cell)

        # Nuclear area inside the cell. Every non-zero nucleus pixel under
        # the cell belongs to one of nuclei_in_cell by construction.
        nuclear_area = np.sum(nuclei_pixels > 0)
        nuclear_cytoplasmic_ratio = nuclear_area / cell_prop.area if cell_prop.area > 0 else 0

        # Holes: enclosed pixels that are not part of the cell
        filled_mask = ndimage.binary_fill_holes(cell_mask)
        holes_mask = filled_mask & ~cell_mask

        labeled_holes, num_holes = ndimage.label(holes_mask)
        # Pixel count of each hole label 1..num_holes
        hole_sizes = np.bincount(labeled_holes.ravel(), minlength=num_holes + 1)[1:]
        total_hole_area = np.sum(holes_mask)
        holes_to_cell_ratio = total_hole_area / cell_prop.area if cell_prop.area > 0 else 0

        max_hole_size = hole_sizes.max() if num_holes > 0 else 0
        avg_hole_size = np.mean(hole_sizes) if num_holes > 0 else 0

        # Extract shape metrics
        perimeter = cell_prop.perimeter if cell_prop.perimeter else 0
        circularity = (4 * np.pi * cell_prop.area) / (perimeter * perimeter) if perimeter > 0 else 0
        solidity = cell_prop.solidity

        cell_metrics.append({
            'cell_id': cell_id,
            'area': cell_prop.area,
            'perimeter': perimeter,
            'circularity': circularity,
            'solidity': solidity,
            'nuclei_count': nuclei_count,
            'nuclear_area': nuclear_area,
            'nuclear_cytoplasmic_ratio': nuclear_cytoplasmic_ratio,
            'num_holes': num_holes,
            'total_hole_area': total_hole_area,
            'max_hole_size': max_hole_size,
            'avg_hole_size': avg_hole_size,
            'holes_to_cell_ratio': holes_to_cell_ratio,
            'has_holes': num_holes > 0
        })

    # Handle empty metrics case
    if not cell_metrics:
        print("No valid cells found for analysis")
        return {
            'cell_metrics': pd.DataFrame(),
            'original_cell_image': cell_image,
            'adjusted_cell_image': cell_image.copy(),
            'senescent_count': 0,
            'total_cells': 0,
            'senescent_fraction': 0
        }

    metrics_df = pd.DataFrame(cell_metrics)

    # First, mark all cells with holes as potentially senescent
    metrics_df['is_senescent'] = metrics_df['has_holes']

    if process_all_samples and global_hole_stats is not None:
        # Prioritize cells with holes but use global size thresholds to reach target percentage
        metrics_df['senescence_score'] = 0
        metrics_df.loc[metrics_df['has_holes'], 'senescence_score'] += 2  # Strong weight for holes
        metrics_df.loc[metrics_df['area'] > global_hole_stats['size_threshold'], 'senescence_score'] += 1
        metrics_df.loc[metrics_df['nuclei_count'] > 1, 'senescence_score'] += 1

        # Sort by score and select top X% as senescent
        metrics_df = metrics_df.sort_values('senescence_score', ascending=False)
        senescent_count = int(len(metrics_df) * expected_senescent_fraction)
        metrics_df['is_senescent'] = False
        metrics_df.iloc[:senescent_count, metrics_df.columns.get_loc('is_senescent')] = True
    else:
        # For individual sample processing, prioritize cells with holes
        # but adjust to reach the expected senescent fraction
        has_holes_count = metrics_df['has_holes'].sum()
        total_cells = len(metrics_df)

        if has_holes_count / total_cells > expected_senescent_fraction:
            # If too many cells have holes, select the ones with the largest holes
            metrics_df = metrics_df.sort_values('total_hole_area', ascending=False)
            senescent_count = int(total_cells * expected_senescent_fraction)
            metrics_df['is_senescent'] = False
            metrics_df.iloc[:senescent_count, metrics_df.columns.get_loc('is_senescent')] = True
        elif has_holes_count / total_cells < expected_senescent_fraction:
            # If too few cells have holes, add cells based on size and multinucleation
            metrics_df['senescence_score'] = 0
            metrics_df.loc[metrics_df['has_holes'], 'senescence_score'] += 3  # Strong weight for holes
            metrics_df.loc[metrics_df['area'] > metrics_df['area'].quantile(0.7), 'senescence_score'] += 1
            metrics_df.loc[metrics_df['nuclei_count'] > 1, 'senescence_score'] += 1

            # Sort by score and select top X% as senescent
            metrics_df = metrics_df.sort_values('senescence_score', ascending=False)
            senescent_count = int(total_cells * expected_senescent_fraction)
            metrics_df['is_senescent'] = False
            metrics_df.iloc[:senescent_count, metrics_df.columns.get_loc('is_senescent')] = True

    # Create adjusted cell masks by filling holes in senescent cells
    adjusted_cell_image = cell_image.copy()

    for cell_id in metrics_df.loc[metrics_df['is_senescent'], 'cell_id']:
        cell_id = int(cell_id)
        cell_prop = props_by_label[cell_id]

        # View into the adjusted image restricted to this cell's bounding box
        region = adjusted_cell_image[cell_prop.slice]
        filled_mask = ndimage.binary_fill_holes(cell_prop.image)

        # Only claim enclosed background. Pixels of other cells inside the
        # hole keep their own label instead of being overwritten.
        region[filled_mask & (region == 0)] = cell_id

    # Calculate statistics
    senescent_count = metrics_df['is_senescent'].sum()
    total_cells = len(metrics_df)
    senescent_fraction = senescent_count / total_cells if total_cells > 0 else 0

    print(f"Detected {senescent_count} senescent cells out of {total_cells} total cells ({senescent_fraction:.2%})")

    return {
        'cell_metrics': metrics_df,
        'original_cell_image': cell_image,
        'adjusted_cell_image': adjusted_cell_image,
        'senescent_count': senescent_count,
        'total_cells': total_cells,
        'senescent_fraction': senescent_fraction
    }

# Function to calculate global statistics across all samples
def calculate_global_statistics(file_dict, complete_samples):
    """
    Calculate global statistics across all samples to use for consistent thresholding.

    Every labelled cell of at least 100 px in each sample's cell mask
    contributes its area and its hole-area / cell-area ratio. Samples whose
    cell or nuclei image could not be loaded are skipped.

    Parameters:
    file_dict: Dictionary mapping sample IDs to component file paths
    complete_samples: List of sample IDs with complete data

    Returns:
    Dictionary with the following keys, or None if no valid cell was found:
    - 'total_cells'
    - 'cells_with_holes'
    - 'hole_fraction'
    - 'size_threshold' (mean + 1.5 std of cell area)
    - 'hole_ratio_threshold' (mean + 1.0 std of the hole ratio)
    - 'area_mean'
    - 'area_std'
    """
    print("Calculating global statistics across all samples...")

    all_cell_areas = []
    all_hole_ratios = []
    cells_with_holes = 0
    total_cells = 0

    for sample_id in complete_samples:
        print(f"Processing sample {sample_id} for global statistics")

        images = load_sample_images(sample_id, file_dict)

        if 'cell' not in images or 'nuclei' not in images:
            print(f"Error: Required cell or nuclei image not found for sample {sample_id}")
            continue

        cell_props = measure.regionprops(images['cell'].astype(np.int32))

        for cell_prop in cell_props:
            # Skip very small objects (likely artifacts)
            if cell_prop.area < 100:
                continue

            all_cell_areas.append(cell_prop.area)

            # Hole detection on the cell's bounding-box crop. cell_prop.image
            # is (cell == label) within cell_prop.slice, so the result matches
            # the full-image computation while avoiding a full-image mask and
            # a full-image binary_fill_holes per cell.
            cell_mask = cell_prop.image
            holes_mask = ndimage.binary_fill_holes(cell_mask) & ~cell_mask
            _, num_holes = ndimage.label(holes_mask)
            total_hole_area = np.sum(holes_mask)
            holes_to_cell_ratio = total_hole_area / cell_prop.area if cell_prop.area > 0 else 0

            all_hole_ratios.append(holes_to_cell_ratio)
            if num_holes > 0:
                cells_with_holes += 1
            total_cells += 1

    if total_cells == 0:
        print("No valid cells found across samples")
        return None

    # Thresholds: "unusually large" means more than 1.5 std above the mean area
    area_mean = np.mean(all_cell_areas)
    area_std = np.std(all_cell_areas)
    size_threshold = area_mean + 1.5 * area_std

    hole_ratio_mean = np.mean(all_hole_ratios)
    hole_ratio_std = np.std(all_hole_ratios)
    hole_ratio_threshold = hole_ratio_mean + 1.0 * hole_ratio_std

    global_hole_fraction = cells_with_holes / total_cells

    global_stats = {
        'total_cells': total_cells,
        'cells_with_holes': cells_with_holes,
        'hole_fraction': global_hole_fraction,
        'size_threshold': size_threshold,
        'hole_ratio_threshold': hole_ratio_threshold,
        'area_mean': area_mean,
        'area_std': area_std
    }

    print(f"Global statistics: {global_stats}")

    return global_stats

# Function to visualize senescence detection results
def visualize_senescence_detection(results, cell_image, nucleus_image, output_path=None):
    """
    Visualize the senescence detection results.

    Draws a 2x2 figure:
    - top-left: original cell boundaries (green) and nucleus boundaries (blue);
    - top-right: senescent cells (red) vs normal cells (green);
    - bottom-left: original boundaries (green) overlaid with the boundaries of
      ``results['adjusted_cell_image']`` (yellow);
    - bottom-right: cells with holes (red) vs without holes (blue), with
      white cell boundaries.
    Per-class mean area, holes and nuclei are written below the plots. The
    figure is saved to ``output_path`` when given, shown, and then closed.

    Parameters:
    results: Results from detect_senescent_cells
    cell_image: Original cell label image
    nucleus_image: Original nucleus label image
    output_path: Path to save the visualization
    """
    if results['total_cells'] == 0:
        print("No valid cells to visualize")
        return

    cell_metrics = results['cell_metrics']

    # Per-class masks via one np.isin per class instead of one full-image
    # comparison per cell.
    cell_ids = cell_metrics['cell_id'].astype(int).to_numpy()
    is_senescent = cell_metrics['is_senescent'].astype(bool).to_numpy()
    has_holes = cell_metrics['has_holes'].astype(bool).to_numpy()

    senescent_mask = np.isin(cell_image, cell_ids[is_senescent])
    normal_mask = np.isin(cell_image, cell_ids[~is_senescent])
    has_holes_mask = np.isin(cell_image, cell_ids[has_holes])
    no_holes_mask = np.isin(cell_image, cell_ids[~has_holes])

    fig, axes = plt.subplots(2, 2, figsize=(16, 14))

    # 1. Original cell image with nuclei overlay
    cell_boundary = segmentation.find_boundaries(cell_image)
    nucleus_boundary = segmentation.find_boundaries(nucleus_image)

    overlay = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    overlay[cell_boundary] = (0, 255, 0)       # Cell boundaries in green
    overlay[nucleus_boundary] = (0, 0, 255)    # Nucleus boundaries in blue

    axes[0, 0].imshow(overlay)
    axes[0, 0].set_title("Original Cells (green) and Nuclei (blue)")
    axes[0, 0].axis('off')

    # 2. Senescent vs normal cells
    classification = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    classification[senescent_mask] = (255, 0, 0)  # Senescent cells in red
    classification[normal_mask] = (0, 255, 0)     # Normal cells in green

    axes[0, 1].imshow(classification)
    axes[0, 1].set_title(f"Senescent Cells (red): {results['senescent_fraction']:.1%}")
    axes[0, 1].axis('off')

    # 3. Adjusted cell masks (after filling holes in senescent cells)
    adjusted_boundary = segmentation.find_boundaries(results['adjusted_cell_image'])

    adjusted_overlay = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    adjusted_overlay[cell_boundary] = (0, 255, 0)        # Original boundaries in green
    adjusted_overlay[adjusted_boundary] = (255, 255, 0)  # Adjusted boundaries in yellow

    axes[1, 0].imshow(adjusted_overlay)
    axes[1, 0].set_title("Original (green) vs Adjusted (yellow) Cell Boundaries")
    axes[1, 0].axis('off')

    # 4. Visualization of cells with holes
    holes_vis = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    holes_vis[has_holes_mask] = (255, 0, 0)  # Cells with holes in red
    holes_vis[no_holes_mask] = (0, 0, 255)   # Cells without holes in blue
    holes_vis[cell_boundary, :] = [255, 255, 255]  # White boundaries for clarity

    axes[1, 1].imshow(holes_vis)
    axes[1, 1].set_title("Cells With Holes (red) vs Without Holes (blue)")
    axes[1, 1].axis('off')

    # Add metrics summary as text
    senescent_metrics = cell_metrics[cell_metrics['is_senescent']]
    normal_metrics = cell_metrics[~cell_metrics['is_senescent']]

    if len(senescent_metrics) > 0:
        sen_text = (f"Senescent cells (n={len(senescent_metrics)}):\n" +
                   f"Mean area: {senescent_metrics['area'].mean():.1f}\n" +
                   f"Mean holes: {senescent_metrics['num_holes'].mean():.1f}\n" +
                   f"Mean nuclei: {senescent_metrics['nuclei_count'].mean():.1f}")
    else:
        sen_text = "No senescent cells detected"

    if len(normal_metrics) > 0:
        norm_text = (f"Normal cells (n={len(normal_metrics)}):\n" +
                    f"Mean area: {normal_metrics['area'].mean():.1f}\n" +
                    f"Mean holes: {normal_metrics['num_holes'].mean():.1f}\n" +
                    f"Mean nuclei: {normal_metrics['nuclei_count'].mean():.1f}")
    else:
        norm_text = "No normal cells detected"

    fig.text(0.02, 0.02, sen_text, fontsize=10)
    fig.text(0.52, 0.02, norm_text, fontsize=10)

    # tight_layout ignores fig.text artists, so reserve the bottom band for
    # the summary text; otherwise it overlaps the lower subplots.
    fig.tight_layout(rect=(0, 0.1, 1, 1))

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {output_path}")

    plt.show()
    # Release the figure so many samples don't accumulate open figures
    plt.close(fig)

# Function to visualize multinucleated cells
def visualize_multinucleated_cells(results, cell_image, nucleus_image, output_path=None):
    """
    Create visualization specifically highlighting multinucleated cells.

    Left panel: every cell listed in ``results['cell_metrics']`` is filled with a
    color for its nucleus count (gray: 0, blue: 1, green: 2, red: 3 or more) and
    nucleus boundaries are drawn in white. Right panel: bar chart of how many
    cells fall into each of those four bins. The figure is saved when
    ``output_path`` is given and then shown.

    Rows whose 'nuclei_count' is NaN are ignored. Label values that do not
    appear in 'cell_metrics' stay black.

    Parameters:
    results: Results from detect_senescent_cells; reads 'total_cells' and the
        'cell_metrics' DataFrame ('cell_id' and 'nuclei_count' columns)
    cell_image: Original cell label image (0 = background)
    nucleus_image: Original nucleus label image
    output_path: Path to save the visualization
    """
    if results['total_cells'] == 0:
        print("No valid cells to visualize")
        return

    # Rows without a nucleus count cannot be binned, so leave them out.
    cell_metrics = results['cell_metrics'].dropna(subset=['nuclei_count'])

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # Color scheme per nucleus-count bin; bin 3 stands for "3 or more"
    color_map = {
        0: [100, 100, 100],  # Gray
        1: [0, 0, 255],      # Blue
        2: [0, 255, 0],      # Green
        3: [255, 0, 0]       # Red (3 or more)
    }

    # Counts may arrive as floats (e.g. iterrows() upcasting or a CSV round-trip).
    # Cast to int so they can index both the color map and the bar-chart bins,
    # and clamp into the 0..3 bin range.
    count_bins = [min(3, max(0, int(count))) for count in cell_metrics['nuclei_count']]
    cell_ids = [int(cell_id) for cell_id in cell_metrics['cell_id']]

    # Label -> color; for duplicate ids the last row wins.
    id_to_color = {cell_id: color_map[b] for cell_id, b in zip(cell_ids, count_bins)}

    # Color the whole image through a per-label palette in one pass instead of
    # one full-image comparison per cell.
    labels, inverse = np.unique(cell_image, return_inverse=True)
    palette = np.array([id_to_color.get(int(label), [0, 0, 0]) for label in labels],
                       dtype=np.uint8)
    # reshape(): the shape of return_inverse differs between NumPy versions
    multi_colored = palette[inverse.reshape(cell_image.shape)]

    # Draw nuclei boundaries
    nucleus_boundary = segmentation.find_boundaries(nucleus_image)
    multi_colored[nucleus_boundary] = [255, 255, 255]  # White nucleus boundaries

    # Display colored image
    axes[0].imshow(multi_colored)
    axes[0].set_title("Cell Nuclei Count\nGray: 0, Blue: 1, Green: 2, Red: 3+")
    axes[0].axis('off')

    # Distribution over the same four bins used for coloring (0, 1, 2, 3+)
    counts = [0, 0, 0, 0]
    for b in count_bins:
        counts[b] += 1

    # Plot bar chart
    bars = axes[1].bar(['0', '1', '2', '3+'], counts, color=['gray', 'blue', 'green', 'red'])

    # Add count labels on bars
    for bar in bars:
        height = bar.get_height()
        axes[1].text(bar.get_x() + bar.get_width()/2., height + 0.1,
                     f'{int(height)}', ha='center', va='bottom')

    axes[1].set_xlabel('Number of Nuclei')
    axes[1].set_ylabel('Number of Cells')
    axes[1].set_title('Distribution of Nuclei Count')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Multinucleated visualization saved to {output_path}")

    plt.show()

def _count_significant_holes_per_cell(cell_image, min_hole_size):
    """Return {cell_id: n_holes} for cells enclosing at least one hole of >= min_hole_size pixels.

    A hole of a cell is a 4-connected region of background (label 0) pixels that
    is enclosed by that single cell. Other cells nested inside it are not counted
    as hole pixels.
    """
    counts = {}
    # find_objects gives each label's bounding box (list index i -> label i + 1).
    # Any region enclosed by a cell lies inside its box, so filling holes on the
    # crop gives the same result as filling on the full image, at a fraction of the cost.
    for index, box in enumerate(ndimage.find_objects(cell_image)):
        if box is None:  # label value not present in the image
            continue
        cell_id = index + 1
        crop = cell_image[box]
        own_mask = crop == cell_id
        enclosed_background = ndimage.binary_fill_holes(own_mask) & (crop == 0)
        labeled, n_regions = ndimage.label(enclosed_background)
        if n_regions == 0:
            continue
        region_sizes = np.bincount(labeled.ravel())[1:]
        n_significant = int(np.count_nonzero(region_sizes >= min_hole_size))
        if n_significant:
            counts[cell_id] = n_significant
    return counts


# Improved function to create a visualization showing holes in cells
def visualize_cell_holes_improved(cell_image, output_path=None, min_significant_hole_size=10):
    """
    Create an improved visualization showing cells with holes.

    Produces a 2x2 figure:
      [0, 0] cells (green) and every hole of the combined cell mask (red)
      [0, 1] the same, keeping only holes of >= ``min_significant_hole_size`` px
      [1, 0] each cell colored by how many significant holes it encloses itself
             (blue 0, green 1, yellow 2, red 3+)
      [1, 1] histogram of the combined-mask hole sizes with the threshold marked
    A summary text (total cells, cells with holes, cells with multiple holes)
    is added at the bottom. The figure is saved when ``output_path`` is given
    and then closed (not shown).

    Panels [0, 0], [0, 1] and [1, 1] use holes of the *combined* foreground.
    These also include gaps closed off jointly by several touching cells.
    Panel [1, 0] and the summary count holes per cell: background regions
    enclosed by that one cell.

    Parameters:
    cell_image: Label image of cells (0 = background)
    output_path: Path to save the visualization
    min_significant_hole_size: Minimum hole area in pixels to count as significant
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 14))

    foreground = cell_image > 0
    # Background pixels enclosed by the union of all cells
    union_holes = ndimage.binary_fill_holes(foreground) & ~foreground
    cell_boundary = segmentation.find_boundaries(cell_image)

    # 1. Basic cell and hole visualization
    vis_image = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    vis_image[foreground, 1] = 180              # Cells in green
    vis_image[union_holes, 0] = 255             # Holes in red
    vis_image[cell_boundary] = [255, 255, 255]  # Cell boundaries in white

    axes[0, 0].imshow(vis_image)
    axes[0, 0].set_title('Cells (green) with Holes (red)')
    axes[0, 0].axis('off')

    # 2. Filtered holes visualization (removing small artifacts)
    labeled_holes, num_holes = ndimage.label(union_holes)
    # Pixel count of every hole in one pass; bincount index 0 is the non-hole pixels
    hole_sizes = np.bincount(labeled_holes.ravel(), minlength=num_holes + 1)[1:]
    keep_label = np.concatenate(([False], hole_sizes >= min_significant_hole_size))
    significant_holes_mask = keep_label[labeled_holes]

    filtered_vis_image = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    filtered_vis_image[foreground, 1] = 180                # Cells in green
    filtered_vis_image[significant_holes_mask, 0] = 255    # Significant holes in red
    filtered_vis_image[cell_boundary] = [255, 255, 255]    # Cell boundaries in white

    axes[0, 1].imshow(filtered_vis_image)
    axes[0, 1].set_title(f'Cells with Significant Holes (size ≥ {min_significant_hole_size} pixels)')
    axes[0, 1].axis('off')

    # 3. Color cells by number of holes they enclose themselves.
    # Per-cell holes are needed because union holes are background pixels and
    # therefore never overlap any single cell's own pixels.
    holes_per_cell = _count_significant_holes_per_cell(cell_image, min_significant_hole_size)

    color_map = {
        0: [0, 0, 180],     # Blue: no holes
        1: [0, 180, 0],     # Green: 1 hole
        2: [180, 180, 0],   # Yellow: 2 holes
        3: [180, 0, 0]      # Red: 3+ holes
    }

    labels, inverse = np.unique(cell_image, return_inverse=True)
    palette = np.zeros((len(labels), 3), dtype=np.uint8)
    for k, label in enumerate(labels):
        if label != 0:  # Background stays black
            palette[k] = color_map[min(3, holes_per_cell.get(int(label), 0))]
    # reshape(): the shape of return_inverse differs between NumPy versions
    hole_count_vis = palette[inverse.reshape(cell_image.shape)]
    hole_count_vis[cell_boundary] = [255, 255, 255]

    axes[1, 0].imshow(hole_count_vis)
    axes[1, 0].set_title('Cells by Hole Count\nBlue: 0, Green: 1, Yellow: 2, Red: 3+')
    axes[1, 0].axis('off')

    # 4. Histogram of hole sizes
    if hole_sizes.size > 0:
        axes[1, 1].hist(hole_sizes, bins=20, color='coral')
        axes[1, 1].set_xlabel('Hole Size (pixels)')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].set_title('Distribution of Hole Sizes')
        axes[1, 1].grid(True, alpha=0.3)

        # Vertical line for the significance threshold
        axes[1, 1].axvline(x=min_significant_hole_size, color='red', linestyle='--',
                           label=f'Threshold ({min_significant_hole_size} px)')
        axes[1, 1].legend()
    else:
        axes[1, 1].text(0.5, 0.5, "No holes detected",
                        ha='center', va='center', transform=axes[1, 1].transAxes)

    # Summary statistics; the cell count uses non-zero labels, so it is correct
    # even when the image has no background pixels.
    cell_count = int(np.count_nonzero(labels))
    cells_with_any_holes = len(holes_per_cell)
    multi_hole_count = sum(1 for n in holes_per_cell.values() if n > 1)

    def _share(part):
        # Guard against images without any cells
        return f"{part / cell_count:.1%}" if cell_count else "n/a"

    stats_text = (f"Total cells: {cell_count}\n" +
                  f"Cells with holes: {cells_with_any_holes} ({_share(cells_with_any_holes)})\n" +
                  f"Cells with multiple holes: {multi_hole_count} ({_share(multi_hole_count)})")

    fig.text(0.02, 0.02, stats_text, fontsize=10)

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Enhanced cell holes visualization saved to {output_path}")

    plt.close()  # Close the figure to avoid displaying in notebooks

# Function to process a sample for senescence analysis
def process_senescence_analysis(sample_id, file_dict, output_dir, global_stats=None, process_all_samples=False):
    """
    Process a sample for senescence analysis.

    Loads the sample's label images and runs ``detect_senescent_cells_enhanced``.
    It writes into ``<output_dir>/<sample_id>/``:
      - ``<sample_id>_raw_boundaries.png`` (always, once images are loaded)
    and, when at least one valid cell is found:
      - ``<sample_id>_enhanced_holes.png``
      - ``<sample_id>_senescence_metrics.csv``
      - ``<sample_id>_adjusted_cell_mask.tif``
      - ``<sample_id>_senescent_classification.png``
      - ``<sample_id>_size_hole_distribution.png`` (only if both senescent and
        normal cells are present)
      - ``<sample_id>_senescence_summary.csv``

    Parameters:
    sample_id: ID of the sample to process
    file_dict: Dictionary mapping sample IDs to component file paths
    output_dir: Directory to save outputs
    global_stats: Global statistics for consistent thresholding.
        NOTE(review): currently not passed to the detector, so every sample is
        still thresholded on its own. TODO: wire it in once
        detect_senescent_cells_enhanced accepts it.
    process_all_samples: Flag to indicate if we're processing all samples
        together (currently unused)

    Returns:
    The results dict from detect_senescent_cells_enhanced, or None when the
    cell or nuclei image is missing.
    """
    print(f"Processing senescence analysis for sample {sample_id}")

    # Create sample-specific output directory
    sample_output_dir = os.path.join(output_dir, sample_id)
    os.makedirs(sample_output_dir, exist_ok=True)

    # Load images
    images = load_sample_images(sample_id, file_dict)

    if 'cell' not in images or 'nuclei' not in images:
        print(f"Error: Required cell or nuclei image not found for sample {sample_id}")
        return None

    cell_image = images['cell']
    nuclei_image = images['nuclei']

    # Display raw image details
    print(f"Cell image shape: {cell_image.shape}, range: [{np.min(cell_image)}, {np.max(cell_image)}]")
    print(f"Nuclei image shape: {nuclei_image.shape}, range: [{np.min(nuclei_image)}, {np.max(nuclei_image)}]")

    # Raw overlay: cell boundaries in green, nuclei boundaries in blue
    plt.figure(figsize=(10, 10))
    cell_boundary = segmentation.find_boundaries(cell_image)
    nuclei_boundaries = segmentation.find_boundaries(nuclei_image)

    raw_vis = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    raw_vis[cell_boundary, 1] = 255
    raw_vis[nuclei_boundaries, 2] = 255

    plt.imshow(raw_vis)
    plt.title('Raw Cell (green) and Nuclei (blue) Boundaries')
    plt.savefig(os.path.join(sample_output_dir, f"{sample_id}_raw_boundaries.png"), dpi=300)
    plt.close()

    # Diagnostic only: connected components of the morphologically closed
    # foreground. Touching cells merge here, and the result is NOT passed to
    # the detector below.
    closed_cell_mask = ndimage.binary_closing(cell_image > 0, structure=np.ones((3, 3), bool))
    _, cell_count = ndimage.label(closed_cell_mask)
    print(f"Detected {cell_count} cells after preprocessing")

    # NOTE(review): detect_senescent_cells_enhanced is presumably defined in
    # another notebook cell. A recorded run raised NameError for it, so verify
    # it is defined before calling this function.
    results = detect_senescent_cells_enhanced(
        cell_image,
        nuclei_image,
        expected_senescent_fraction=0.3
    )

    # Skip saving if no valid cells found
    if results['total_cells'] == 0:
        print(f"No valid cells found in sample {sample_id}, skipping output generation")
        return results

    # Hole-detection overlay: cells green, holes red, cell boundaries white
    plt.figure(figsize=(12, 10))
    cell_holes = detect_holes_in_cell_mask(cell_image)

    hole_vis = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    hole_vis[cell_image > 0, 1] = 180
    for holes in cell_holes.values():
        for hole in holes:
            hole_vis[hole['mask'], 0] = 255
    hole_vis[cell_boundary] = [255, 255, 255]

    plt.imshow(hole_vis)
    plt.title(f'Enhanced Hole Detection - Found {len(cell_holes)} cells with holes')
    plt.axis('off')
    plt.savefig(os.path.join(sample_output_dir, f"{sample_id}_enhanced_holes.png"), dpi=300)
    plt.close()

    cell_metrics = results['cell_metrics']

    # 1. Save metrics as CSV
    metrics_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_metrics.csv")
    cell_metrics.to_csv(metrics_file, index=False)
    print(f"Saved cell metrics to {metrics_file}")

    # 2. Save adjusted cell mask
    adjusted_mask_file = os.path.join(sample_output_dir, f"{sample_id}_adjusted_cell_mask.tif")
    io.imsave(adjusted_mask_file, results['adjusted_cell_image'].astype(np.uint16))
    print(f"Saved adjusted cell mask to {adjusted_mask_file}")

    # 3. Senescent (red) vs normal (green) classification overlay
    plt.figure(figsize=(12, 10))
    sen_vis = np.zeros((*cell_image.shape, 3), dtype=np.uint8)
    for _, row in cell_metrics.iterrows():
        this_cell = cell_image == int(row['cell_id'])
        if row['is_senescent']:
            sen_vis[this_cell, 0] = 255  # Red for senescent
        else:
            sen_vis[this_cell, 1] = 180  # Green for normal
    sen_vis[cell_boundary] = [255, 255, 255]

    plt.imshow(sen_vis)
    plt.title(f'Senescent Cells (red): {results["senescent_fraction"]:.1%}')
    plt.axis('off')
    plt.savefig(os.path.join(sample_output_dir, f"{sample_id}_senescent_classification.png"), dpi=300)
    plt.close()

    # 4. Size distribution and hole counts (only when both classes exist)
    senescent = cell_metrics[cell_metrics['is_senescent']]
    normal = cell_metrics[~cell_metrics['is_senescent']]
    both_classes = len(senescent) > 0 and len(normal) > 0

    if both_classes:
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.hist([normal['area'], senescent['area']], bins=20,
                 alpha=0.7, label=['Normal', 'Senescent'])
        plt.xlabel('Cell Area (pixels)')
        plt.ylabel('Frequency')
        plt.title('Size Distribution')
        plt.legend()

        plt.subplot(1, 2, 2)
        with_holes = len(cell_metrics[cell_metrics['has_holes']])
        without_holes = results['total_cells'] - with_holes
        plt.bar(['Without Holes', 'With Holes'], [without_holes, with_holes],
                color=['blue', 'red'])
        plt.ylabel('Number of Cells')
        plt.title(f'Cells With Holes: {with_holes}/{results["total_cells"]} ({with_holes/results["total_cells"]:.1%})')

        plt.tight_layout()
        plt.savefig(os.path.join(sample_output_dir, f"{sample_id}_size_hole_distribution.png"), dpi=300)
        plt.close()

    # 5. Summary statistics
    summary = {
        'sample_id': sample_id,
        'total_cells': results['total_cells'],
        'senescent_cells': results['senescent_count'],
        'senescent_fraction': results['senescent_fraction'],
        'cells_with_holes': results['cells_with_holes'],
        'hole_fraction': results['cells_with_holes'] / results['total_cells'] if results['total_cells'] > 0 else 0
    }

    if both_classes:
        summary['senescent_mean_area'] = senescent['area'].mean()
        summary['normal_mean_area'] = normal['area'].mean()
        summary['size_ratio'] = summary['senescent_mean_area'] / summary['normal_mean_area']

    summary_df = pd.DataFrame([summary])
    summary_file = os.path.join(sample_output_dir, f"{sample_id}_senescence_summary.csv")
    summary_df.to_csv(summary_file, index=False)
    print(f"Saved summary to {summary_file}")

    return results

def _paired_boxplot(ax, normal_values, senescent_values, ylabel, title):
    """Draw side-by-side Normal / Senescent boxplots of one metric on ``ax``."""
    # NaNs are dropped here because boxplot does not drop them itself.
    ax.boxplot([normal_values.dropna(), senescent_values.dropna()])
    # Tick labels are set separately: boxplot's ``labels`` keyword was renamed to
    # ``tick_labels`` in Matplotlib 3.9; set_xticklabels works on old and new versions.
    ax.set_xticklabels(['Normal', 'Senescent'])
    ax.set_ylabel(ylabel)
    ax.set_title(title)


# Function to compile results across all samples
def compile_cross_sample_results(output_dir, sample_ids):
    """Compile results across all analyzed samples.

    For each sample, reads ``<sample_id>_senescence_summary.csv`` and
    ``<sample_id>_senescence_metrics.csv`` from ``<output_dir>/<sample_id>/``.
    Missing files are skipped. The concatenated tables are written as
    ``all_samples_senescence_summary.csv`` and ``all_samples_cell_metrics.csv``
    in ``output_dir``, together with comparison plots:
    ``senescence_percentages.png``, ``hole_fractions.png``,
    ``senescence_metrics_comparison.png``, ``area_vs_nuclei_scatter.png`` and
    ``area_vs_holes_scatter.png``. Each plot is produced only when its source
    columns exist.

    Parameters:
    output_dir: Directory containing one sub-directory per sample
    sample_ids: Iterable of sample IDs to include
    """
    all_summaries = []
    all_metrics = []

    for sample_id in sample_ids:
        sample_dir = os.path.join(output_dir, sample_id)

        summary_file = os.path.join(sample_dir, f"{sample_id}_senescence_summary.csv")
        if os.path.exists(summary_file):
            all_summaries.append(pd.read_csv(summary_file))

        metrics_file = os.path.join(sample_dir, f"{sample_id}_senescence_metrics.csv")
        if os.path.exists(metrics_file):
            metrics = pd.read_csv(metrics_file)
            metrics['sample_id'] = sample_id  # Tag rows with their sample of origin
            all_metrics.append(metrics)

    # Combine all summaries
    if all_summaries:
        combined_summary = pd.concat(all_summaries, ignore_index=True)
        summary_output = os.path.join(output_dir, "all_samples_senescence_summary.csv")
        combined_summary.to_csv(summary_output, index=False)
        print(f"Saved combined summary to {summary_output}")

        # Senescence percentages across samples
        plt.figure(figsize=(12, 6))
        plt.bar(combined_summary['sample_id'],
                combined_summary['senescent_fraction'] * 100)
        plt.xlabel('Sample ID')
        plt.ylabel('Senescent Cells (%)')
        plt.title('Percentage of Senescent Cells Across Samples')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "senescence_percentages.png"), dpi=300)
        plt.close()

        # Hole fractions across samples
        if 'hole_fraction' in combined_summary.columns:
            plt.figure(figsize=(12, 6))
            plt.bar(combined_summary['sample_id'],
                    combined_summary['hole_fraction'] * 100)
            plt.xlabel('Sample ID')
            plt.ylabel('Cells With Holes (%)')
            plt.title('Percentage of Cells With Holes Across Samples')
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "hole_fractions.png"), dpi=300)
            plt.close()

    # Combine all metrics
    if all_metrics:
        combined_metrics = pd.concat(all_metrics, ignore_index=True)
        metrics_output = os.path.join(output_dir, "all_samples_cell_metrics.csv")
        combined_metrics.to_csv(metrics_output, index=False)
        print(f"Saved combined metrics to {metrics_output}")

        if 'is_senescent' in combined_metrics.columns:
            senescent = combined_metrics[combined_metrics['is_senescent']]
            normal = combined_metrics[~combined_metrics['is_senescent']]
            both_groups = len(senescent) > 0 and len(normal) > 0

            # Boxplots comparing senescent vs normal cells on a 2x3 grid
            metrics_to_plot = ['area', 'nuclei_count', 'solidity', 'num_holes', 'circularity']
            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            axes = axes.flatten()
            used_panels = set()

            for i, metric in enumerate(metrics_to_plot):
                if both_groups and metric in combined_metrics.columns:
                    _paired_boxplot(axes[i], normal[metric], senescent[metric],
                                    metric, f'{metric} Distribution')
                    used_panels.add(i)

            # Last panel: holes-to-cell ratio
            ratio_panel = len(metrics_to_plot)
            if both_groups and 'holes_to_cell_ratio' in combined_metrics.columns:
                _paired_boxplot(axes[ratio_panel], normal['holes_to_cell_ratio'],
                                senescent['holes_to_cell_ratio'],
                                'Holes to Cell Ratio', 'Holes to Cell Ratio Distribution')
                used_panels.add(ratio_panel)

            # Hide panels that received no plot instead of leaving empty axes
            for i, ax in enumerate(axes):
                if i not in used_panels:
                    ax.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "senescence_metrics_comparison.png"), dpi=300)
            plt.close()

            # Scatter plot of area vs nuclei count
            plt.figure(figsize=(10, 8))
            plt.scatter(normal['area'], normal['nuclei_count'],
                        alpha=0.5, label='Normal', color='blue')
            plt.scatter(senescent['area'], senescent['nuclei_count'],
                        alpha=0.5, label='Senescent', color='red')
            plt.xlabel('Cell Area')
            plt.ylabel('Nuclei Count')
            plt.title('Cell Area vs Nuclei Count')
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "area_vs_nuclei_scatter.png"), dpi=300)
            plt.close()

            # Scatter plot of area vs holes
            if 'num_holes' in combined_metrics.columns:
                plt.figure(figsize=(10, 8))
                plt.scatter(normal['area'], normal['num_holes'],
                            alpha=0.5, label='Normal', color='blue')
                plt.scatter(senescent['area'], senescent['num_holes'],
                            alpha=0.5, label='Senescent', color='red')
                plt.xlabel('Cell Area')
                plt.ylabel('Number of Holes')
                plt.title('Cell Area vs Number of Holes')
                plt.legend()
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, "area_vs_holes_scatter.png"), dpi=300)
                plt.close()

    print("Compiled results across all samples")

# IMPROVED: Two-phase processing approach
def process_with_global_statistics(file_dict, complete_samples, output_dir):
    """
    Process all samples with consistent thresholding based on global statistics.

    Steps:
    1. Compute statistics over all samples with ``calculate_global_statistics``
       and save them to ``<output_dir>/global_statistics.csv``.
    2. Run ``process_senescence_analysis`` on every sample, passing those
       statistics along.
    3. Combine the per-sample outputs with ``compile_cross_sample_results``.

    If no usable global statistics are obtained (empty result, or
    'total_cells' missing or 0), each sample is processed without them. The
    per-sample outputs are still combined in that case.

    NOTE(review): in this file process_senescence_analysis does not appear to
    forward global_stats to its detector. Verify before relying on
    cross-sample consistency.

    Parameters:
    file_dict: Dictionary mapping sample IDs to component file paths
    complete_samples: List of sample IDs with complete data
    output_dir: Directory to save outputs
    """
    print("Starting two-phase senescence analysis...")

    # Phase 1: Calculate global statistics
    global_stats = calculate_global_statistics(file_dict, complete_samples)

    # .get(): a stats dict lacking 'total_cells' counts as unusable instead of raising KeyError
    if not global_stats or global_stats.get('total_cells', 0) == 0:
        print("Could not calculate global statistics. Falling back to per-sample processing.")
        for sample_id in complete_samples:
            process_senescence_analysis(sample_id, file_dict, output_dir)
        # Per-sample CSVs were written anyway, so combine them as the main path does
        compile_cross_sample_results(output_dir, complete_samples)
        return

    # Save global statistics
    global_stats_df = pd.DataFrame([global_stats])
    global_stats_output = os.path.join(output_dir, "global_statistics.csv")
    global_stats_df.to_csv(global_stats_output, index=False)
    print(f"Saved global statistics to {global_stats_output}")

    # Phase 2: Process each sample with global statistics
    for sample_id in complete_samples:
        print(f"\nProcessing sample {sample_id} with global statistics")
        senescence_results = process_senescence_analysis(
            sample_id, file_dict, output_dir,
            global_stats=global_stats,
            process_all_samples=True
        )

        if senescence_results:
            print(f"Completed senescence analysis for sample {sample_id}")
            print(f"Found {senescence_results['senescent_count']} senescent cells " +
                  f"({senescence_results['senescent_fraction']:.1%}) out of {senescence_results['total_cells']} total cells")

    # Compile results across all samples
    compile_cross_sample_results(output_dir, complete_samples)

    print("Two-phase senescence analysis completed!")

# Function to run all processing
def main():
    """Run the full pipeline on the samples found under the module-level ``base_dir``.

    Discovers sample files with ``find_matching_files`` and hands the complete
    samples to ``process_with_global_statistics``, which writes into the
    module-level ``output_dir``. Prints a message and stops when no complete
    sample is found.
    """
    print("Starting improved senescent cell analysis pipeline...")

    file_dict, complete_samples = find_matching_files(base_dir)

    if complete_samples:
        # Two-phase processing: global statistics first, then each sample
        process_with_global_statistics(file_dict, complete_samples, output_dir)
        print("Improved senescence analysis pipeline completed!")
    else:
        print("No complete samples found. Exiting.")

# Run the main function
# Guarded so that importing this file as a module does not start the pipeline;
# when the file/cell is executed directly, main() runs immediately.
if __name__ == "__main__":
    main()

Starting improved senescent cell analysis pipeline...
Found 18 TIFF files in Nuclei folder.
Found 18 TIFF files in Membrane_Adjusted folder.
Found 18 TIFF files in Golgi folder.
Found 18 TIFF files in Cell folder.
Found 18 complete samples with at least nuclei and cell data.
Starting two-phase senescence analysis...
Calculating global statistics across all samples...
Processing sample denoised_0Pa_A1_20dec21_20xA_L2RA_FlatA_seq001_contrast for global statistics
Loaded nuclei image: shape (1024, 1024), dtype: uint32, value range [0, 241]
Loaded membrane image: shape (1024, 1024), dtype: uint8, value range [0, 1]
Loaded golgi image: shape (1024, 1024), dtype: uint32, value range [0, 194]
Loaded cell image: shape (1024, 1024), dtype: uint32, value range [0, 241]
Processing sample denoised_0Pa_A1_20dec21_20xA_L2RA_FlatA_seq002_contrast for global statistics
Loaded nuclei image: shape (1024, 1024), dtype: uint32, value range [0, 329]
Loaded membrane image: shape (1024, 1024), dtype: uint8, 

NameError: name 'detect_senescent_cells_enhanced' is not defined