In [4]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation
from scipy import ndimage
from collections import Counter
import re
import scipy.stats as stats

# Set up matplotlib for better visualization
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100
plt.style.use('ggplot')

# Define your input and output directories
cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
membrane_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Membrane"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Relation"

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Function to extract pressure value from filename
def extract_pressure(filename):
    match = re.search(r'(0Pa|1\.4Pa)', str(filename))
    if match:
        return match.group(1)
    return None

# Function to extract sample ID from filename to match cell and membrane files
def extract_sample_id(filename):
    # Extract pattern like "1.4Pa_U_05mar19_20x_L2R_Flat_seq005" from filename
    match = re.search(r'((?:0Pa|1\.4Pa)_U_[^_]+_20x_[^_]+_[^_]+_seq\d+)', str(filename))
    if match:
        return match.group(1)
    return None

# Function to find and organize mask files by pressure
def find_mask_files(cell_dir, membrane_dir):
    print("Organizing files by pressure and finding corresponding pairs...")

    # Create a dictionary to store file info by pressure
    pressure_dict = {'0Pa': [], '1.4Pa': []}

    # Find all cell mask files
    cell_files = [f for f in os.listdir(cell_dir) if f.endswith('_cell_mask_merged_conservative.tif') or
                                                    f.endswith('_cell_mask.tif')]
    print(f"Found {len(cell_files)} cell mask files.")

    # Find all membrane files
    membrane_files = [f for f in os.listdir(membrane_dir) if f.endswith('.tif')]
    print(f"Found {len(membrane_files)} membrane mask files.")

    # Debug: Print some examples to check pattern matching
    if cell_files:
        sample_cell = cell_files[0]
        sample_id = extract_sample_id(sample_cell)
        print(f"Example cell file: {sample_cell}")
        print(f"Extracted sample ID: {sample_id}")

    if membrane_files:
        sample_membrane = membrane_files[0]
        sample_id = extract_sample_id(sample_membrane)
        print(f"Example membrane file: {sample_membrane}")
        print(f"Extracted sample ID: {sample_id}")

    # Create a lookup dictionary for membrane files
    membrane_lookup = {}
    for membrane_file in membrane_files:
        sample_id = extract_sample_id(membrane_file)
        if sample_id:
            membrane_lookup[sample_id] = membrane_file
            print(f"Added to lookup: {sample_id} -> {membrane_file}")
        else:
            print(f"Warning: Could not extract sample ID from {membrane_file}")

    # Organize cell files and find corresponding membrane files
    pairs_found = 0
    for cell_file in cell_files:
        pressure = extract_pressure(cell_file)
        sample_id = extract_sample_id(cell_file)

        if pressure and pressure in pressure_dict and sample_id:
            # Find matching membrane file
            if sample_id in membrane_lookup:
                membrane_file = membrane_lookup[sample_id]

                file_pair = {
                    'cell_file': os.path.join(cell_dir, cell_file),
                    'membrane_file': os.path.join(membrane_dir, membrane_file),
                    'sample_id': sample_id
                }

                pressure_dict[pressure].append(file_pair)
                pairs_found += 1
                print(f"Matched pair: {cell_file} -> {membrane_file}")
            else:
                print(f"Warning: No matching membrane file found for {cell_file} with ID {sample_id}")
        else:
            print(f"Warning: Invalid pressure or sample ID for {cell_file}")

    print(f"Found {pairs_found} matching cell-membrane file pairs.")
    print(f"Organized by pressure:")
    for pressure, file_list in pressure_dict.items():
        print(f"  {pressure}: {len(file_list)} file pairs")

    return pressure_dict

# Function to load a mask image
def load_mask_image(filepath):
    try:
        img = io.imread(filepath)
        if img.dtype == bool:
            img = img.astype(np.uint8)
        if np.max(img) <= 1:
            img = (img > 0).astype(np.uint8)
        return img
    except Exception as e:
        print(f"Error loading {filepath}: {str(e)}")
        return None

# Function to analyze cell and membrane relationship
def analyze_cell_membrane_relation(cell_mask, membrane_mask):
    # Ensure both masks are binary (0 and 1)
    if np.max(cell_mask) > 1:
        binary_cell_mask = (cell_mask > 0).astype(np.uint8)
    else:
        binary_cell_mask = cell_mask.astype(np.uint8)

    if np.max(membrane_mask) > 1:
        binary_membrane_mask = (membrane_mask > 0).astype(np.uint8)
    else:
        binary_membrane_mask = membrane_mask.astype(np.uint8)

    # Get labeled cell mask for individual cell analysis
    labeled_cells, num_cells = ndimage.label(binary_cell_mask)
    print(f"Found {num_cells} individual cells")

    # Calculate total area of cells and membrane
    total_cell_area = np.sum(binary_cell_mask)
    total_membrane_area = np.sum(binary_membrane_mask)

    # Calculate intersection area (membrane overlapping with cells)
    intersection = np.logical_and(binary_cell_mask, binary_membrane_mask).astype(np.uint8)
    intersection_area = np.sum(intersection)

    # Calculate membrane outside cells
    membrane_outside_cells = np.logical_and(binary_membrane_mask, np.logical_not(binary_cell_mask)).astype(np.uint8)
    membrane_outside_area = np.sum(membrane_outside_cells)

    # Calculate overlap ratio
    if total_membrane_area > 0:
        membrane_cell_overlap_ratio = intersection_area / total_membrane_area
    else:
        membrane_cell_overlap_ratio = 0

    # Analyze membrane coverage for each cell
    cell_metrics = []

    for cell_id in range(1, num_cells + 1):
        # Extract single cell mask
        single_cell_mask = (labeled_cells == cell_id).astype(np.uint8)

        # Calculate cell properties
        cell_props = measure.regionprops(single_cell_mask)[0]

        # Calculate cell area
        cell_area = np.sum(single_cell_mask)

        # Calculate cell perimeter
        contours, _ = cv2.findContours(single_cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        if contours:
            cell_perimeter = cv2.arcLength(contours[0], True)
        else:
            cell_perimeter = 0

        # Calculate membrane overlap with this cell
        cell_membrane_overlap = np.logical_and(single_cell_mask, binary_membrane_mask).astype(np.uint8)
        cell_membrane_overlap_area = np.sum(cell_membrane_overlap)

        # Calculate cell boundary
        cell_boundary = segmentation.find_boundaries(single_cell_mask, mode='outer').astype(np.uint8)
        cell_boundary_length = np.sum(cell_boundary)

        # Calculate membrane on cell boundary
        membrane_on_boundary = np.logical_and(cell_boundary, binary_membrane_mask).astype(np.uint8)
        membrane_boundary_length = np.sum(membrane_on_boundary)

        # Calculate membrane coverage ratio for this cell
        if cell_area > 0:
            membrane_coverage_ratio = cell_membrane_overlap_area / cell_area
        else:
            membrane_coverage_ratio = 0

        # Calculate boundary coverage ratio
        if cell_boundary_length > 0:
            boundary_coverage_ratio = membrane_boundary_length / cell_boundary_length
        else:
            boundary_coverage_ratio = 0

        # Calculate shape metrics
        circularity = (4 * np.pi * cell_area) / (cell_perimeter ** 2) if cell_perimeter > 0 else 0

        # Store cell metrics
        metrics = {
            'cell_id': cell_id,
            'cell_area': cell_area,
            'cell_perimeter': cell_perimeter,
            'membrane_overlap_area': cell_membrane_overlap_area,
            'membrane_coverage_ratio': membrane_coverage_ratio,
            'boundary_length': cell_boundary_length,
            'membrane_on_boundary_length': membrane_boundary_length,
            'boundary_coverage_ratio': boundary_coverage_ratio,
            'circularity': circularity,
            'centroid_y': cell_props.centroid[0],
            'centroid_x': cell_props.centroid[1]
        }

        cell_metrics.append(metrics)

    # Calculate average metrics across all cells
    if cell_metrics:
        avg_membrane_coverage = np.mean([m['membrane_coverage_ratio'] for m in cell_metrics])
        avg_boundary_coverage = np.mean([m['boundary_coverage_ratio'] for m in cell_metrics])
        median_membrane_coverage = np.median([m['membrane_coverage_ratio'] for m in cell_metrics])
        median_boundary_coverage = np.median([m['boundary_coverage_ratio'] for m in cell_metrics])
    else:
        avg_membrane_coverage = 0
        avg_boundary_coverage = 0
        median_membrane_coverage = 0
        median_boundary_coverage = 0

    # Create summary results
    results = {
        'total_cells': num_cells,
        'total_cell_area': total_cell_area,
        'total_membrane_area': total_membrane_area,
        'intersection_area': intersection_area,
        'membrane_outside_area': membrane_outside_area,
        'membrane_cell_overlap_ratio': membrane_cell_overlap_ratio,
        'avg_membrane_coverage': avg_membrane_coverage,
        'avg_boundary_coverage': avg_boundary_coverage,
        'median_membrane_coverage': median_membrane_coverage,
        'median_boundary_coverage': median_boundary_coverage,
        'cell_metrics': cell_metrics
    }

    return results

# Function to create visualization of cell-membrane relationship
def visualize_cell_membrane_relation(cell_mask, membrane_mask, cell_metrics, output_path=None, title=None):
    # Create RGB image for visualization
    vis_img = np.zeros((cell_mask.shape[0], cell_mask.shape[1], 3), dtype=np.uint8)

    # Color cells based on membrane coverage
    labeled_cells, _ = ndimage.label((cell_mask > 0).astype(np.uint8))

    # Create a mapping from cell_id to membrane coverage
    coverage_map = {m['cell_id']: m['membrane_coverage_ratio'] for m in cell_metrics}
    boundary_map = {m['cell_id']: m['boundary_coverage_ratio'] for m in cell_metrics}

    # Maximum coverage for normalization
    max_coverage = max(max(coverage_map.values()) if coverage_map else 0, 0.001)

    # Color each cell based on membrane coverage
    for cell_id in range(1, np.max(labeled_cells) + 1):
        if cell_id not in coverage_map:
            continue

        # Get cell mask
        cell = (labeled_cells == cell_id)

        # Normalize coverage to 0-1 range
        coverage = coverage_map[cell_id] / max_coverage
        boundary_coverage = boundary_map[cell_id]

        # Create a color based on coverage (red channel)
        # Higher coverage = more red
        color = [int(255 * coverage), 0, 0]

        # Apply color to cell
        for i in range(3):
            vis_img[cell, i] = color[i]

    # Add membrane overlay in cyan
    membrane = (membrane_mask > 0)
    vis_img[membrane, 1] = 255  # Add green channel
    vis_img[membrane, 2] = 255  # Add blue channel

    # Display the image
    plt.figure(figsize=(12, 10))
    plt.imshow(vis_img)

    if title:
        plt.title(title)
    else:
        plt.title("Cell-Membrane Relationship Visualization")

    plt.axis('off')

    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='cyan', edgecolor='black', label='Membrane'),
        Patch(facecolor='red', edgecolor='black', label='High membrane coverage'),
        Patch(facecolor='darkred', edgecolor='black', label='Low membrane coverage')
    ]
    plt.legend(handles=legend_elements, loc='upper right')

    # Save the figure if output path is provided
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved visualization to {output_path}")

    plt.close()

# Function to process a single cell-membrane pair
def process_file_pair(file_pair, pressure):
    sample_id = file_pair['sample_id']
    print(f"\nProcessing {pressure} sample: {sample_id}")

    # Load cell mask
    cell_mask = load_mask_image(file_pair['cell_file'])
    if cell_mask is None:
        print(f"Failed to load cell mask: {file_pair['cell_file']}")
        return None

    # Load membrane mask
    membrane_mask = load_mask_image(file_pair['membrane_file'])
    if membrane_mask is None:
        print(f"Failed to load membrane mask: {file_pair['membrane_file']}")
        return None

    # Ensure masks have same dimensions
    if cell_mask.shape != membrane_mask.shape:
        print(f"Mask dimension mismatch: Cell {cell_mask.shape} vs Membrane {membrane_mask.shape}")
        # Resize membrane mask to match cell mask if needed
        membrane_mask = cv2.resize(membrane_mask, (cell_mask.shape[1], cell_mask.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

    # Analyze relationship
    try:
        results = analyze_cell_membrane_relation(cell_mask, membrane_mask)

        # Add sample information
        results['sample_id'] = sample_id
        results['pressure'] = pressure
        results['cell_file'] = os.path.basename(file_pair['cell_file'])
        results['membrane_file'] = os.path.basename(file_pair['membrane_file'])

        # Convert cell metrics to dataframe
        if results['cell_metrics']:
            # Add sample info to each cell metric
            for metrics in results['cell_metrics']:
                metrics['sample_id'] = sample_id
                metrics['pressure'] = pressure

            cell_df = pd.DataFrame(results['cell_metrics'])
            results['cell_metrics_df'] = cell_df

        print(f"Analysis complete: {results['total_cells']} cells, "
              f"Avg membrane coverage: {results['avg_membrane_coverage']:.2f}, "
              f"Avg boundary coverage: {results['avg_boundary_coverage']:.2f}")

        return results

    except Exception as e:
        print(f"Error during analysis: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# Function to process all files for a specific pressure
def process_pressure_group(file_pairs, pressure):
    print(f"\n=== Processing {pressure} group with {len(file_pairs)} file pairs ===")

    all_results = []
    all_cell_metrics = []

    for file_pair in file_pairs:
        results = process_file_pair(file_pair, pressure)

        if results:
            all_results.append(results)
            all_cell_metrics.extend(results['cell_metrics'])

    # Create combined results
    combined_results = {
        'pressure': pressure,
        'total_samples': len(all_results),
        'total_cells': sum(r['total_cells'] for r in all_results),
        'avg_membrane_cell_overlap': np.mean([r['membrane_cell_overlap_ratio'] for r in all_results]),
        'avg_membrane_coverage': np.mean([r['avg_membrane_coverage'] for r in all_results]),
        'avg_boundary_coverage': np.mean([r['avg_boundary_coverage'] for r in all_results]),
        'sample_results': all_results,
        'all_cell_metrics': all_cell_metrics
    }

    # Create dataframe of all cell metrics
    if all_cell_metrics:
        combined_results['cell_metrics_df'] = pd.DataFrame(all_cell_metrics)

    return combined_results

# Function to compare pressure groups
def compare_pressure_groups(pressure_results, output_dir):
    # Create output directory for comparison
    comparison_dir = os.path.join(output_dir, "pressure_comparison")
    os.makedirs(comparison_dir, exist_ok=True)

    # Extract pressure names
    pressure_names = list(pressure_results.keys())

    # Create a combined dataframe of all cell metrics
    all_metrics_df = pd.concat([pressure_results[p]['cell_metrics_df'] for p in pressure_names
                               if 'cell_metrics_df' in pressure_results[p]])

    # Save combined metrics
    all_metrics_df.to_csv(os.path.join(comparison_dir, "all_cell_membrane_metrics.csv"), index=False)

    # 1. Compare membrane coverage
    plt.figure(figsize=(10, 6))

    sns.boxplot(x='pressure', y='membrane_coverage_ratio', data=all_metrics_df)
    plt.title('Membrane Coverage by Pressure')
    plt.xlabel('Pressure')
    plt.ylabel('Membrane Coverage Ratio')

    # Perform statistical test
    group1 = all_metrics_df[all_metrics_df['pressure'] == pressure_names[0]]['membrane_coverage_ratio']
    group2 = all_metrics_df[all_metrics_df['pressure'] == pressure_names[1]]['membrane_coverage_ratio']

    if len(group1) > 0 and len(group2) > 0:
        # Mann-Whitney U test
        stat, p_value = stats.mannwhitneyu(group1, group2)
        plt.annotate(f'p-value: {p_value:.4f}', xy=(0.5, 0.95), xycoords='axes fraction',
                     ha='center', bbox=dict(boxstyle='round', fc='white', alpha=0.8))

    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, "membrane_coverage_comparison.png"), dpi=300)
    plt.close()

    # 2. Compare boundary coverage
    plt.figure(figsize=(10, 6))

    sns.boxplot(x='pressure', y='boundary_coverage_ratio', data=all_metrics_df)
    plt.title('Cell Boundary Membrane Coverage by Pressure')
    plt.xlabel('Pressure')
    plt.ylabel('Boundary Coverage Ratio')

    # Perform statistical test
    group1 = all_metrics_df[all_metrics_df['pressure'] == pressure_names[0]]['boundary_coverage_ratio']
    group2 = all_metrics_df[all_metrics_df['pressure'] == pressure_names[1]]['boundary_coverage_ratio']

    if len(group1) > 0 and len(group2) > 0:
        # Mann-Whitney U test
        stat, p_value = stats.mannwhitneyu(group1, group2)
        plt.annotate(f'p-value: {p_value:.4f}', xy=(0.5, 0.95), xycoords='axes fraction',
                     ha='center', bbox=dict(boxstyle='round', fc='white', alpha=0.8))

    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, "boundary_coverage_comparison.png"), dpi=300)
    plt.close()

    # 3. Compare circularity vs membrane coverage scatter plot
    plt.figure(figsize=(10, 8))

    sns.scatterplot(x='circularity', y='membrane_coverage_ratio', hue='pressure', data=all_metrics_df, alpha=0.7)
    plt.title('Cell Circularity vs Membrane Coverage')
    plt.xlabel('Cell Circularity')
    plt.ylabel('Membrane Coverage Ratio')
    plt.grid(True, alpha=0.3)

    # Add regression lines
    for pressure in pressure_names:
        pressure_data = all_metrics_df[all_metrics_df['pressure'] == pressure]
        if len(pressure_data) > 1:
            sns.regplot(x='circularity', y='membrane_coverage_ratio', data=pressure_data,
                       scatter=False, label=f"Trend for {pressure}")

    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, "circularity_vs_coverage.png"), dpi=300)
    plt.close()

    # 4. Compare cell area vs membrane coverage scatter plot
    plt.figure(figsize=(10, 8))

    sns.scatterplot(x='cell_area', y='membrane_coverage_ratio', hue='pressure', data=all_metrics_df, alpha=0.7)
    plt.title('Cell Area vs Membrane Coverage')
    plt.xlabel('Cell Area (pixels)')
    plt.ylabel('Membrane Coverage Ratio')
    plt.grid(True, alpha=0.3)

    # Add regression lines
    for pressure in pressure_names:
        pressure_data = all_metrics_df[all_metrics_df['pressure'] == pressure]
        if len(pressure_data) > 1:
            sns.regplot(x='cell_area', y='membrane_coverage_ratio', data=pressure_data,
                       scatter=False, label=f"Trend for {pressure}")

    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, "area_vs_coverage.png"), dpi=300)
    plt.close()

    # 5. Create summary statistics table
    summary_data = []

    for pressure in pressure_names:
        if 'cell_metrics_df' in pressure_results[pressure]:
            df = pressure_results[pressure]['cell_metrics_df']

            row = {
                'Pressure': pressure,
                'Total Samples': pressure_results[pressure]['total_samples'],
                'Total Cells': pressure_results[pressure]['total_cells'],
                'Avg Membrane-Cell Overlap': pressure_results[pressure]['avg_membrane_cell_overlap'],
                'Avg Membrane Coverage': pressure_results[pressure]['avg_membrane_coverage'],
                'Avg Boundary Coverage': pressure_results[pressure]['avg_boundary_coverage'],
                'Median Membrane Coverage': df['membrane_coverage_ratio'].median(),
                'Median Boundary Coverage': df['boundary_coverage_ratio'].median(),
                'Avg Cell Area': df['cell_area'].mean(),
                'Avg Cell Circularity': df['circularity'].mean()
            }

            summary_data.append(row)

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(os.path.join(comparison_dir, "pressure_comparison_summary.csv"), index=False)

    print(f"\nComparison results saved to {comparison_dir}")
    print("\nSummary Statistics:")
    print(summary_df)

    return summary_df

# Function to create histogram distribution plots
def create_distribution_plots(pressure_results, output_dir):
    dist_dir = os.path.join(output_dir, "distributions")
    os.makedirs(dist_dir, exist_ok=True)

    pressure_names = list(pressure_results.keys())

    # 1. Membrane coverage distribution
    plt.figure(figsize=(12, 8))

    for pressure in pressure_names:
        if 'cell_metrics_df' in pressure_results[pressure]:
            df = pressure_results[pressure]['cell_metrics_df']
            sns.histplot(df['membrane_coverage_ratio'], kde=True, label=pressure, alpha=0.6)

    plt.title('Distribution of Membrane Coverage Ratio')
    plt.xlabel('Membrane Coverage Ratio')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(dist_dir, "membrane_coverage_distribution.png"), dpi=300)
    plt.close()

    # 2. Boundary coverage distribution
    plt.figure(figsize=(12, 8))

    for pressure in pressure_names:
        if 'cell_metrics_df' in pressure_results[pressure]:
            df = pressure_results[pressure]['cell_metrics_df']
            sns.histplot(df['boundary_coverage_ratio'], kde=True, label=pressure, alpha=0.6)

    plt.title('Distribution of Boundary Coverage Ratio')
    plt.xlabel('Boundary Coverage Ratio')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(dist_dir, "boundary_coverage_distribution.png"), dpi=300)
    plt.close()

    # 3. Cell area distribution
    plt.figure(figsize=(12, 8))

    for pressure in pressure_names:
        if 'cell_metrics_df' in pressure_results[pressure]:
            df = pressure_results[pressure]['cell_metrics_df']
            sns.histplot(df['cell_area'], kde=True, label=pressure, alpha=0.6)

    plt.title('Distribution of Cell Area')
    plt.xlabel('Cell Area (pixels)')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(dist_dir, "cell_area_distribution.png"), dpi=300)
    plt.close()

    # 4. Cell circularity distribution
    plt.figure(figsize=(12, 8))

    for pressure in pressure_names:
        if 'cell_metrics_df' in pressure_results[pressure]:
            df = pressure_results[pressure]['cell_metrics_df']
            sns.histplot(df['circularity'], kde=True, label=pressure, alpha=0.6)

    plt.title('Distribution of Cell Circularity')
    plt.xlabel('Circularity')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(dist_dir, "cell_circularity_distribution.png"), dpi=300)
    plt.close()

    print(f"Distribution plots saved to {dist_dir}")

# Main function
def main():
    print("Starting cell-membrane relationship analysis...")

    # Find and organize files by pressure
    pressure_dict = find_mask_files(cell_mask_dir, membrane_dir)

    # Process each pressure group
    pressure_results = {}

    for pressure, file_pairs in pressure_dict.items():
        if file_pairs:
            # Process all file pairs for this pressure
            results = process_pressure_group(file_pairs, pressure)
            pressure_results[pressure] = results

            # Create pressure-specific output directory
            pressure_dir = os.path.join(output_dir, f"pressure_{pressure}")
            os.makedirs(pressure_dir, exist_ok=True)

            # Save cell metrics to CSV
            if 'cell_metrics_df' in results:
                results['cell_metrics_df'].to_csv(
                    os.path.join(pressure_dir, f"{pressure}_cell_membrane_metrics.csv"),
                    index=False
                )

            # Create visualizations for one sample
            if file_pairs and 'cell_metrics_df' in results:
                sample_pair = file_pairs[0]
                sample_cell_mask = load_mask_image(sample_pair['cell_file'])
                sample_membrane_mask = load_mask_image(sample_pair['membrane_file'])

                if sample_cell_mask is not None and sample_membrane_mask is not None:
                    # Ensure masks have same dimensions
                    if sample_cell_mask.shape != sample_membrane_mask.shape:
                        sample_membrane_mask = cv2.resize(
                            sample_membrane_mask,
                            (sample_cell_mask.shape[1], sample_cell_mask.shape[0]),
                            interpolation=cv2.INTER_NEAREST
                        )

                    # Get sample ID
                    sample_id = sample_pair['sample_id']

                    # Get cell metrics for this sample
                    sample_metrics = [
                        m for m in results['all_cell_metrics']
                        if m.get('sample_id') == sample_id
                    ]

                    if sample_metrics:
                        vis_output = os.path.join(
                            pressure_dir,
                            f"{pressure}_{sample_id}_visualization.png"
                        )

                        visualize_cell_membrane_relation(
                            sample_cell_mask,
                            sample_membrane_mask,
                            sample_metrics,
                            output_path=vis_output,
                            title=f"{pressure}: Cell-Membrane Relationship"
                        )

    # Compare pressure groups if we have at least two
    if len(pressure_results) >= 2:
        summary_df = compare_pressure_groups(pressure_results, output_dir)
        create_distribution_plots(pressure_results, output_dir)
    else:
        print("\nNot enough pressure groups for comparison")

    print("\nCell-membrane relationship analysis completed!")

# Run the main function
if __name__ == "__main__":
    main()

Starting cell-membrane relationship analysis...
Organizing files by pressure and finding corresponding pairs...
Found 8 cell mask files.
Found 8 membrane mask files.
Example cell file: 0Pa_U_05mar19_20x_L2RA_Flat_seq001_cell_mask_merged_conservative.tif
Extracted sample ID: 0Pa_U_05mar19_20x_L2RA_Flat_seq001
Example membrane file: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Nuclei_regional_tophat_cadherins_mask_cleaned.tif
Extracted sample ID: 0Pa_U_05mar19_20x_L2RA_Flat_seq001
Added to lookup: 0Pa_U_05mar19_20x_L2RA_Flat_seq001 -> denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Nuclei_regional_tophat_cadherins_mask_cleaned.tif
Added to lookup: 0Pa_U_05mar19_20x_L2RA_Flat_seq002 -> denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq002_Nuclei_regional_tophat_cadherins_mask_cleaned.tif
Added to lookup: 0Pa_U_05mar19_20x_L2RA_Flat_seq003 -> denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq003_Nuclei_regional_tophat_cadherins_mask_cleaned.tif
Added to lookup: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001 -> denoised_1.4Pa

In [21]:
# %% Setup Cell 1: Mount Google Drive
# Run this cell to mount your Google Drive where the data is stored.
# You will be prompted to authorize access.
from google.colab import drive
drive.mount('/content/drive')

# %% Setup Cell 2: Install Dependencies
# Run this cell to ensure necessary libraries are installed in the Colab environment.
!pip install opencv-python-headless scikit-image seaborn >> /dev/null
print("Dependencies checked/installed.")

# %% Main Code Cell: Analysis Script
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2 # OpenCV for image processing like contours
from skimage import io, measure, segmentation # Scikit-image for image analysis tasks
from scipy import ndimage, stats # Scipy for numerical operations and stats
from collections import Counter
import re
import math
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression
import traceback # For detailed error messages

print("Libraries imported successfully.")

# Set up matplotlib for better visualization in notebooks
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100 # Adjust DPI for plot resolution
plt.style.use('ggplot') # Use a visually appealing style

# --- Configuration: Define your input and output directories ---
# !!! IMPORTANT: Make sure these paths correctly point to your folders in Google Drive !!!
cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
membrane_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Membrane"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation"

# Create output directory if it doesn't exist
try:
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory created/verified: {output_dir}")
except OSError as e:
    print(f"Error creating output directory {output_dir}: {e}")
    # Handle error appropriately, maybe exit or use a default path
    # For now, we'll proceed, but saving might fail.

# --- Helper Functions ---

def extract_pressure(filename):
    """Extracts '0Pa' or '1.4Pa' from a filename."""
    # Regex looks for 0Pa or 1.4Pa explicitly
    match = re.search(r'(0Pa|1\.4Pa)', str(filename))
    if match:
        return match.group(1)
    # print(f"Warning: Could not extract pressure from filename: {filename}") # Optional warning
    return None

def extract_sample_id(filename):
    """Extracts a unique sample identifier from a filename to match pairs."""
    # Regex designed to capture the core part of the filename excluding mask type suffixes
    # Example: "1.4Pa_U_05mar19_20x_L2R_Flat_seq005"
    # It assumes the pressure value is at the start, followed by details, and ends before suffixes like '_cell_mask...'
    match = re.search(r'((?:0Pa|1\.4Pa)_U_[^_]+_20x_[^_]+_[^_]+_seq\d+)', str(filename))
    if match:
        return match.group(1)
    # print(f"Warning: Could not extract sample ID from filename: {filename}") # Optional warning
    return None

def find_mask_files(cell_dir, membrane_dir):
    """Finds and pairs cell and membrane mask files based on sample ID and pressure."""
    print("\n--- Finding and Pairing Mask Files ---")
    pressure_dict = {'0Pa': [], '1.4Pa': []}

    try:
        cell_files_raw = os.listdir(cell_dir)
        membrane_files_raw = os.listdir(membrane_dir)
    except FileNotFoundError as e:
        print(f"Error accessing directory: {e}")
        print("Please ensure the `cell_mask_dir` and `membrane_dir` paths are correct and Google Drive is mounted.")
        return None # Return None to indicate failure

    # Filter for relevant cell mask files (adjust suffixes if needed)
    cell_files = [f for f in cell_files_raw if f.endswith(('_cell_mask_merged_conservative.tif', '_cell_mask.tif')) and not f.startswith('.')]
    print(f"Found {len(cell_files)} potential cell mask files in {cell_dir}")

    # Filter for relevant membrane mask files (assuming all .tif are relevant)
    membrane_files = [f for f in membrane_files_raw if f.endswith('.tif') and not f.startswith('.')]
    print(f"Found {len(membrane_files)} potential membrane mask files in {membrane_dir}")

    if not cell_files or not membrane_files:
        print("Warning: No cell or membrane files found. Check directories and file naming.")
        return pressure_dict # Return empty dict

    # Create a lookup dictionary for membrane files based on their sample ID
    membrane_lookup = {}
    for membrane_file in membrane_files:
        sample_id = extract_sample_id(membrane_file)
        if sample_id:
            if sample_id in membrane_lookup:
                 print(f"Warning: Duplicate sample ID found for membrane file: {sample_id}. Using the last one found: {membrane_file}")
            membrane_lookup[sample_id] = membrane_file
        # else: # Optional: Be more verbose about files that don't match the pattern
            # print(f"Debug: Could not extract sample ID from membrane file: {membrane_file}")

    # Match cell files to membrane files
    pairs_found = 0
    processed_cell_ids = set()
    for cell_file in cell_files:
        pressure = extract_pressure(cell_file)
        sample_id = extract_sample_id(cell_file)

        if sample_id in processed_cell_ids:
            # print(f"Debug: Skipping already processed sample ID: {sample_id} from file {cell_file}")
            continue

        if pressure and pressure in pressure_dict and sample_id:
            if sample_id in membrane_lookup:
                membrane_file = membrane_lookup[sample_id]
                file_pair = {
                    'cell_file': os.path.join(cell_dir, cell_file),
                    'membrane_file': os.path.join(membrane_dir, membrane_file),
                    'sample_id': sample_id
                }
                pressure_dict[pressure].append(file_pair)
                pairs_found += 1
                processed_cell_ids.add(sample_id)
                # print(f"Matched pair for ID {sample_id}: {cell_file} <-> {membrane_file}") # Optional verbose matching log
            # else: # Optional: Be more verbose about missing matches
                # print(f"Warning: No matching membrane file found for cell file {cell_file} with ID {sample_id}")
        # else: # Optional: Be more verbose about files that don't fit criteria
            # if not pressure: print(f"Debug: Invalid pressure for {cell_file}")
            # if not sample_id: print(f"Debug: Invalid sample ID for {cell_file}")

    print(f"Total matching cell-membrane file pairs found: {pairs_found}")
    print("Pairs per pressure group:")
    for pressure, file_list in pressure_dict.items():
        print(f"  {pressure}: {len(file_list)} pairs")
    print("-" * 35)

    return pressure_dict

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's binary (0 or 1)."""
    try:
        img = io.imread(filepath)
        # Convert boolean masks or masks with values > 1 to binary uint8
        if img.dtype == bool:
            img = img.astype(np.uint8) # Convert boolean to 0/1
        elif np.max(img) > 1:
             # Assuming values > 0 are foreground
             img = (img > 0).astype(np.uint8)
        elif np.max(img) == 1:
             # Already binary 0/1, ensure type is uint8
             img = img.astype(np.uint8)
        else:
             # All zeros image
             img = img.astype(np.uint8)

        if img.ndim > 2:
            print(f"Warning: Image {os.path.basename(filepath)} has multiple channels ({img.shape}). Converting to grayscale.")
            # Attempt to convert to grayscale, assuming standard channel order
            if img.shape[2] == 3: # RGB
                 img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4: # RGBA
                 img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else: # Other multi-channel formats - take the first channel? Or fail?
                 print(f"Warning: Cannot automatically convert image with shape {img.shape}. Taking first channel.")
                 img = img[:,:,0]
            # Re-binarize after conversion
            img = (img > 0).astype(np.uint8)

        return img
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
        return None
    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def analyze_cell_membrane_size_relation(cell_mask, membrane_mask):
    """Analyzes individual cells from masks to extract size and membrane metrics."""
    # Ensure masks are binary (redundant check if load_mask_image works, but safe)
    binary_cell_mask = (cell_mask > 0).astype(np.uint8)
    binary_membrane_mask = (membrane_mask > 0).astype(np.uint8)

    # Label individual cells (connected components) in the cell mask
    labeled_cells, num_cells = ndimage.label(binary_cell_mask)
    if num_cells == 0:
        # print("No cells found in this mask.")
        return [] # Return empty list if no cells

    # print(f"Analyzing {num_cells} detected cells...")
    cell_metrics = []

    # Use ndimage.find_objects to get bounding boxes and slices for efficiency
    cell_slices = ndimage.find_objects(labeled_cells)

    for cell_id in range(1, num_cells + 1):
        # Get the slice corresponding to the current cell
        current_slice = cell_slices[cell_id - 1]
        # Extract the relevant region from the labeled mask and binary masks
        sub_labeled = labeled_cells[current_slice]
        sub_cell_mask = (sub_labeled == cell_id).astype(np.uint8)
        sub_membrane_mask = binary_membrane_mask[current_slice]

        # --- Calculate Cell Metrics ---
        cell_area = np.sum(sub_cell_mask)
        if cell_area < 5: # Skip tiny artifacts (adjust threshold if needed)
            continue

        # Cell Perimeter using OpenCV contours
        contours, _ = cv2.findContours(sub_cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cell_perimeter = cv2.arcLength(contours[0], True) if contours else 0

        # --- Calculate Membrane Metrics Relative to Cell ---
        # Total membrane overlapping with the cell area
        cell_membrane_overlap = np.logical_and(sub_cell_mask, sub_membrane_mask)
        total_membrane_area = np.sum(cell_membrane_overlap)

        # Cell Boundary (pixels just outside the cell)
        # 'outer' mode gives pixels adjacent to the object, not part of it. Use 'thick' or dilate/subtract for boundary pixels *within* the cell.
        # Let's redefine boundary as the pixels *belonging* to the cell at its edge.
        # Method 1: Using find_boundaries on the cell mask itself
        # cell_boundary = segmentation.find_boundaries(sub_cell_mask, mode='inner').astype(np.uint8)
        # Method 2: Using morphological erosion
        eroded_cell = ndimage.binary_erosion(sub_cell_mask, structure=np.ones((3,3))).astype(sub_cell_mask.dtype)
        cell_boundary = sub_cell_mask - eroded_cell
        cell_boundary_length = np.sum(cell_boundary) # Number of boundary pixels

        # Membrane located on this boundary
        membrane_on_boundary = np.logical_and(cell_boundary, sub_membrane_mask)
        boundary_membrane_area = np.sum(membrane_on_boundary)

        # Cell Interior (cell area minus the boundary pixels)
        cell_interior = eroded_cell # Use the eroded cell as the interior
        interior_area = np.sum(cell_interior)

        # Membrane located in the cell interior
        interior_membrane = np.logical_and(cell_interior, sub_membrane_mask)
        interior_membrane_area = np.sum(interior_membrane)

        # --- Calculate Ratios and Derived Metrics ---
        membrane_coverage_ratio = total_membrane_area / cell_area if cell_area > 0 else 0
        # Ratio of boundary pixels covered by membrane
        boundary_coverage_ratio = boundary_membrane_area / cell_boundary_length if cell_boundary_length > 0 else 0
        # Ratio of interior pixels covered by membrane
        interior_coverage_ratio = interior_membrane_area / interior_area if interior_area > 0 else 0

        # Shape Metrics
        # Circularity = 4 * pi * Area / Perimeter^2 (1 for perfect circle)
        circularity = (4 * np.pi * cell_area) / (cell_perimeter ** 2) if cell_perimeter > 0 else 0
        # Equivalent Diameter (diameter of circle with same area)
        equivalent_diameter = 2 * np.sqrt(cell_area / np.pi)

        # Ratio of membrane on boundary vs interior
        membrane_boundary_to_interior_ratio = boundary_membrane_area / interior_membrane_area if interior_membrane_area > 0 else (boundary_membrane_area if boundary_membrane_area > 0 else 0) # Avoid division by zero, handle cases

        metrics = {
            'cell_id': cell_id,
            'cell_area': cell_area,
            'cell_perimeter': cell_perimeter,
            'equivalent_diameter': equivalent_diameter,
            'circularity': circularity,
            'total_membrane_area': total_membrane_area, # Membrane within cell bounds
            'membrane_coverage_ratio': membrane_coverage_ratio, # total_membrane / cell_area
            'cell_boundary_length': cell_boundary_length, # Number of pixels in boundary band
            'boundary_membrane_area': boundary_membrane_area, # Membrane on the boundary band
            'boundary_coverage_ratio': boundary_coverage_ratio, # boundary_membrane / boundary_length
            'interior_area': interior_area, # Cell area excluding boundary band
            'interior_membrane_area': interior_membrane_area, # Membrane in the interior
            'interior_coverage_ratio': interior_coverage_ratio, # interior_membrane / interior_area
            'membrane_boundary_to_interior_ratio': membrane_boundary_to_interior_ratio
        }
        cell_metrics.append(metrics)

    return cell_metrics

def process_file_pair(file_pair, pressure):
    """Loads masks for a pair, handles resizing, runs analysis, returns metrics."""
    sample_id = file_pair['sample_id']
    # print(f"\nProcessing {pressure} sample: {sample_id}") # Reduce verbosity

    # Load cell mask
    cell_mask = load_mask_image(file_pair['cell_file'])
    if cell_mask is None:
        print(f"--> Failed to load cell mask: {os.path.basename(file_pair['cell_file'])}")
        return None

    # Load membrane mask
    membrane_mask = load_mask_image(file_pair['membrane_file'])
    if membrane_mask is None:
        print(f"--> Failed to load membrane mask: {os.path.basename(file_pair['membrane_file'])}")
        return None

    # Ensure masks have same dimensions - Resize membrane mask if necessary
    if cell_mask.shape != membrane_mask.shape:
        print(f"Warning: Mask dimension mismatch for {sample_id}. Cell: {cell_mask.shape}, Membrane: {membrane_mask.shape}. Resizing membrane mask...")
        try:
            # Use INTER_NEAREST for masks to avoid introducing intermediate values
            membrane_mask = cv2.resize(membrane_mask, (cell_mask.shape[1], cell_mask.shape[0]),
                                       interpolation=cv2.INTER_NEAREST)
        except Exception as e:
             print(f"--> Error resizing membrane mask for {sample_id}: {e}")
             return None # Skip this pair if resizing fails

    # Analyze relationship
    try:
        cell_metrics = analyze_cell_membrane_size_relation(cell_mask, membrane_mask)

        # Add sample information and pressure to each cell's metrics
        for metrics in cell_metrics:
            metrics['sample_id'] = sample_id
            metrics['pressure'] = pressure

        # print(f"-> Analysis complete for {sample_id}: {len(cell_metrics)} cells analyzed.") # Reduce verbosity
        return cell_metrics

    except Exception as e:
        print(f"--> Error during analysis of {sample_id}: {str(e)}")
        print(traceback.format_exc()) # Print full traceback for debugging
        return None

def process_pressure_group(file_pairs, pressure, output_dir):
    """Processes all file pairs for a specific pressure group."""
    print(f"\n=== Processing {pressure} Group ({len(file_pairs)} file pairs) ===")
    if not file_pairs:
        print("No files to process for this group.")
        return {'pressure': pressure, 'total_samples': 0, 'total_cells': 0, 'all_cell_metrics': [], 'cell_metrics_df': pd.DataFrame()}

    all_cell_metrics = []
    processed_count = 0
    failed_count = 0

    for i, file_pair in enumerate(file_pairs):
        # print(f"Processing pair {i+1}/{len(file_pairs)}...") # Simple progress indicator
        cell_metrics = process_file_pair(file_pair, pressure)
        if cell_metrics is not None:
            all_cell_metrics.extend(cell_metrics)
            processed_count += 1
        else:
            failed_count += 1

    print(f"--- {pressure} Group Summary ---")
    print(f"Successfully processed: {processed_count} pairs")
    print(f"Failed/Skipped: {failed_count} pairs")
    print(f"Total cells analyzed: {len(all_cell_metrics)}")

    results = {
        'pressure': pressure,
        'total_samples_processed': processed_count,
        'total_cells': len(all_cell_metrics),
        'all_cell_metrics': all_cell_metrics
    }

    # Create and save dataframe for this pressure group
    if all_cell_metrics:
        results_df = pd.DataFrame(all_cell_metrics)
        results['cell_metrics_df'] = results_df
        try:
            group_csv_path = os.path.join(output_dir, f"{pressure}_cell_metrics.csv")
            results_df.to_csv(group_csv_path, index=False)
            print(f"Saved {pressure} metrics to {group_csv_path}")
        except Exception as e:
            print(f"Error saving {pressure} metrics CSV: {e}")
    else:
        results['cell_metrics_df'] = pd.DataFrame() # Ensure DataFrame exists even if empty

    print(f"=== Finished processing {pressure} group ===")
    return results

# --- Analysis and Plotting Functions ---

def create_correlation_analysis(pressure_results, output_dir):
    """Generates scatter plots with regression lines for key metric correlations."""
    print("\n--- Generating Correlation Analysis Plots ---")
    corr_dir = os.path.join(output_dir, "correlation_analysis")
    os.makedirs(corr_dir, exist_ok=True)

    pressure_names = list(pressure_results.keys())
    valid_results = {p: r for p, r in pressure_results.items() if 'cell_metrics_df' in r and not r['cell_metrics_df'].empty}

    if not valid_results:
        print("No valid data found for correlation analysis.")
        return None

    # Combine dataframes from all valid pressure groups
    all_metrics_df = pd.concat([valid_results[p]['cell_metrics_df'] for p in valid_results])
    if all_metrics_df.empty:
        print("Combined dataframe is empty. Skipping correlation analysis.")
        return None

    # Save the combined dataframe
    try:
        all_metrics_csv_path = os.path.join(corr_dir, "all_combined_cell_metrics.csv")
        all_metrics_df.to_csv(all_metrics_csv_path, index=False)
        print(f"Saved combined metrics to {all_metrics_csv_path}")
    except Exception as e:
        print(f"Error saving combined metrics CSV: {e}")


    # --- Plot 1: Cell Area vs Membrane Coverage ---
    plt.figure(figsize=(12, 10))
    plot_success = False
    for pressure in valid_results.keys():
        df = valid_results[pressure]['cell_metrics_df']
        if df.empty or 'cell_area' not in df or 'membrane_coverage_ratio' not in df:
            continue

        sns.scatterplot(x='cell_area', y='membrane_coverage_ratio', data=df, alpha=0.5, label=f"{pressure} cells")
        # Add regression line only if there are enough points
        if len(df) > 1:
             sns.regplot(x='cell_area', y='membrane_coverage_ratio', data=df, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")

             # Calculate correlation and R-squared
             try:
                 correlation = df['cell_area'].corr(df['membrane_coverage_ratio'])
                 x = df['cell_area'].values.reshape(-1, 1)
                 y = df['membrane_coverage_ratio'].values
                 model = LinearRegression().fit(x, y)
                 r_squared = r2_score(y, model.predict(x))
                 annotation = f"{pressure}: r={correlation:.2f}, R²={r_squared:.2f}"
             except Exception as e:
                 print(f"Could not calculate stats for {pressure} Area vs Coverage: {e}")
                 annotation = f"{pressure}: Stats Error"

             plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(valid_results.keys()).index(pressure)), xycoords='axes fraction',
                          ha='left', fontsize=10, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
             plot_success = True

    if plot_success:
        plt.title('Cell Area vs Total Membrane Coverage Ratio', fontsize=16)
        plt.xlabel('Cell Area (pixels²)', fontsize=14)
        plt.ylabel('Membrane Coverage Ratio (Membrane Area / Cell Area)', fontsize=14)
        plt.grid(True, alpha=0.4)
        plt.legend(fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(corr_dir, "plot_corr_area_vs_membrane_coverage.png"), dpi=300)
        print("Saved: plot_corr_area_vs_membrane_coverage.png")
    else:
        print("Skipping plot: Cell Area vs Membrane Coverage (no data)")
    plt.close()


    # --- Plot 2: Cell Perimeter vs Boundary Coverage ---
    plt.figure(figsize=(12, 10))
    plot_success = False
    for pressure in valid_results.keys():
        df = valid_results[pressure]['cell_metrics_df']
        if df.empty or 'cell_perimeter' not in df or 'boundary_coverage_ratio' not in df:
            continue

        # Filter out potential NaNs or Infs if they occur
        df_filt = df[['cell_perimeter', 'boundary_coverage_ratio']].dropna()
        if df_filt.empty: continue

        sns.scatterplot(x='cell_perimeter', y='boundary_coverage_ratio', data=df_filt, alpha=0.5, label=f"{pressure} cells")
        if len(df_filt) > 1:
            sns.regplot(x='cell_perimeter', y='boundary_coverage_ratio', data=df_filt, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")
            try:
                correlation = df_filt['cell_perimeter'].corr(df_filt['boundary_coverage_ratio'])
                x = df_filt['cell_perimeter'].values.reshape(-1, 1)
                y = df_filt['boundary_coverage_ratio'].values
                model = LinearRegression().fit(x, y)
                r_squared = r2_score(y, model.predict(x))
                annotation = f"{pressure}: r={correlation:.2f}, R²={r_squared:.2f}"
            except Exception as e:
                print(f"Could not calculate stats for {pressure} Perimeter vs Boundary Coverage: {e}")
                annotation = f"{pressure}: Stats Error"
            plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(valid_results.keys()).index(pressure)), xycoords='axes fraction',
                         ha='left', fontsize=10, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
            plot_success = True

    if plot_success:
        plt.title('Cell Perimeter vs Boundary Membrane Coverage Ratio', fontsize=16)
        plt.xlabel('Cell Perimeter (pixels)', fontsize=14)
        plt.ylabel('Boundary Coverage Ratio (Boundary Membrane / Boundary Length)', fontsize=14)
        plt.grid(True, alpha=0.4)
        plt.legend(fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(corr_dir, "plot_corr_perimeter_vs_boundary_coverage.png"), dpi=300)
        print("Saved: plot_corr_perimeter_vs_boundary_coverage.png")
    else:
        print("Skipping plot: Cell Perimeter vs Boundary Coverage (no data)")
    plt.close()


    # --- Plot 3: Cell Circularity vs Membrane Coverage ---
    plt.figure(figsize=(12, 10))
    plot_success = False
    for pressure in valid_results.keys():
        df = valid_results[pressure]['cell_metrics_df']
        if df.empty or 'circularity' not in df or 'membrane_coverage_ratio' not in df:
            continue

        df_filt = df[['circularity', 'membrane_coverage_ratio']].dropna()
         # Filter out non-finite values for circularity if needed (e.g., if perimeter was 0)
        df_filt = df_filt[np.isfinite(df_filt['circularity'])]
        if df_filt.empty: continue

        sns.scatterplot(x='circularity', y='membrane_coverage_ratio', data=df_filt, alpha=0.5, label=f"{pressure} cells")
        if len(df_filt) > 1:
            sns.regplot(x='circularity', y='membrane_coverage_ratio', data=df_filt, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")
            try:
                correlation = df_filt['circularity'].corr(df_filt['membrane_coverage_ratio'])
                x = df_filt['circularity'].values.reshape(-1, 1)
                y = df_filt['membrane_coverage_ratio'].values
                model = LinearRegression().fit(x, y)
                r_squared = r2_score(y, model.predict(x))
                annotation = f"{pressure}: r={correlation:.2f}, R²={r_squared:.2f}"
            except Exception as e:
                print(f"Could not calculate stats for {pressure} Circularity vs Coverage: {e}")
                annotation = f"{pressure}: Stats Error"
            plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(valid_results.keys()).index(pressure)), xycoords='axes fraction',
                         ha='left', fontsize=10, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
            plot_success = True

    if plot_success:
        plt.title('Cell Circularity vs Total Membrane Coverage', fontsize=16)
        plt.xlabel('Circularity (4π Area / Perimeter²)', fontsize=14)
        plt.ylabel('Membrane Coverage Ratio', fontsize=14)
        plt.xlim(0, 1.1) # Circularity is theoretically <= 1
        plt.grid(True, alpha=0.4)
        plt.legend(fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(corr_dir, "plot_corr_circularity_vs_membrane_coverage.png"), dpi=300)
        print("Saved: plot_corr_circularity_vs_membrane_coverage.png")
    else:
        print("Skipping plot: Circularity vs Membrane Coverage (no data)")
    plt.close()


    # --- Plot 4: Membrane Distribution: Interior vs Boundary Coverage ---
    plt.figure(figsize=(12, 10))
    plot_success = False
    # Calculate reasonable point sizes based on area quantiles to avoid extreme variations
    min_size, max_size = 10, 200 # Min/max point sizes for plot
    try:
        area_q1 = all_metrics_df['cell_area'].quantile(0.1)
        area_q9 = all_metrics_df['cell_area'].quantile(0.9)
        # Normalize sizes between min_size and max_size based on area quantiles
        # Clip areas outside the 10-90 percentile range for sizing
        sizes_norm = np.clip(all_metrics_df['cell_area'], area_q1, area_q9)
        sizes_norm = min_size + (max_size - min_size) * (sizes_norm - area_q1) / (area_q9 - area_q1 + 1e-6) # Normalize
        all_metrics_df['plot_point_size'] = sizes_norm.fillna(min_size) # Handle potential NaNs
    except Exception:
         print("Warning: Could not calculate dynamic point sizes. Using fixed size.")
         all_metrics_df['plot_point_size'] = 50 # Fixed size fallback

    for pressure in valid_results.keys():
        df = all_metrics_df[all_metrics_df['pressure'] == pressure]
        if df.empty or 'interior_coverage_ratio' not in df or 'boundary_coverage_ratio' not in df:
            continue

        df_filt = df[['interior_coverage_ratio', 'boundary_coverage_ratio', 'plot_point_size']].dropna()
        if df_filt.empty: continue

        sns.scatterplot(
            x='interior_coverage_ratio',
            y='boundary_coverage_ratio',
            data=df_filt,
            alpha=0.6,
            s=df_filt['plot_point_size'], # Use pre-calculated sizes
            label=f"{pressure} cells"
        )
        plot_success = True

    if plot_success:
        plt.title('Membrane Distribution: Interior vs Boundary Coverage', fontsize=16)
        plt.xlabel('Interior Coverage Ratio (Interior Membrane / Interior Area)', fontsize=14)
        plt.ylabel('Boundary Coverage Ratio (Boundary Membrane / Boundary Length)', fontsize=14)
        plt.grid(True, alpha=0.4)
        # Create a proxy legend for point sizes if desired (more complex)
        plt.legend(fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(corr_dir, "plot_scatter_interior_vs_boundary_coverage.png"), dpi=300)
        print("Saved: plot_scatter_interior_vs_boundary_coverage.png")
    else:
         print("Skipping plot: Interior vs Boundary Coverage (no data)")
    plt.close()

    print("--- Correlation analysis complete ---")
    return all_metrics_df # Return the combined dataframe

def analyze_size_categories(all_metrics_df, output_dir):
    """Analyzes metrics grouped by cell size categories and pressure."""
    if all_metrics_df is None or all_metrics_df.empty:
        print("\n--- Skipping Size Category Analysis (No data) ---")
        return None

    print("\n--- Analyzing Metrics by Cell Size Categories ---")
    size_dir = os.path.join(output_dir, "size_categories")
    os.makedirs(size_dir, exist_ok=True)

    if 'cell_area' not in all_metrics_df.columns or all_metrics_df['cell_area'].isnull().all():
         print("Error: 'cell_area' column missing or empty. Cannot perform size categorization.")
         return None

    # Define size categories based on quartiles of cell area
    try:
        size_q1 = all_metrics_df['cell_area'].quantile(0.25)
        size_q2 = all_metrics_df['cell_area'].quantile(0.50)
        size_q3 = all_metrics_df['cell_area'].quantile(0.75)

        # Handle cases where quartiles might be equal (e.g., low variance data)
        bins = sorted(list(set([0, size_q1, size_q2, size_q3, float('inf')])))
        labels = ['Smallest', 'Small-Med', 'Med-Large', 'Largest']
        # Adjust labels if bins collapsed
        if len(bins) -1 < len(labels):
             labels = labels[:len(bins)-1]

        if len(bins) <= 2: # Not enough variation to categorize
             print("Warning: Not enough variation in cell area to create meaningful size categories.")
             all_metrics_df['size_category'] = 'All Sizes'
             print("Size categories defined:")
             print(f"  All Sizes: > 0 pixels")
        else:
             all_metrics_df['size_category'] = pd.cut(
                 all_metrics_df['cell_area'],
                 bins=bins,
                 labels=labels,
                 right=False, # [ S Q1 ) [ Q1 Q2 ) etc.
                 include_lowest=True # include 0
             )
             print("Size categories defined (based on area quartiles):")
             for i in range(len(labels)):
                 print(f"  {labels[i]}: {bins[i]:.1f} - {bins[i+1]:.1f} pixels²")

    except Exception as e:
        print(f"Error defining size categories: {e}. Assigning all to one category.")
        all_metrics_df['size_category'] = 'All Sizes'


    # Save categorized dataframe
    try:
        categorized_csv_path = os.path.join(size_dir, "cells_with_size_categories.csv")
        all_metrics_df.to_csv(categorized_csv_path, index=False)
        print(f"Saved categorized data to {categorized_csv_path}")
    except Exception as e:
        print(f"Error saving categorized data CSV: {e}")


    # --- Box Plots ---
    metrics_to_plot = {
        'membrane_coverage_ratio': 'Total Membrane Coverage by Size Category',
        'boundary_coverage_ratio': 'Boundary Membrane Coverage by Size Category',
        'interior_coverage_ratio': 'Interior Membrane Coverage by Size Category'
    }

    category_order = labels if 'labels' in locals() and len(bins)>2 else ['All Sizes'] # Ensure correct order

    for metric, title in metrics_to_plot.items():
        if metric not in all_metrics_df.columns or all_metrics_df[metric].isnull().all():
            print(f"Skipping plot for {metric}: Column missing or empty.")
            continue

        plt.figure(figsize=(14, 8))
        try:
             sns.boxplot(
                 x='size_category',
                 y=metric,
                 hue='pressure',
                 data=all_metrics_df,
                 order=category_order,
                 palette='viridis' # Use a different palette
             )
             plt.title(title, fontsize=16)
             plt.xlabel('Cell Size Category', fontsize=14)
             plt.ylabel(metric.replace('_', ' ').title(), fontsize=14)
             plt.xticks(rotation=15, ha='right')
             plt.grid(True, axis='y', alpha=0.4)
             plt.legend(title='Pressure')
             plt.tight_layout()
             plot_filename = f"plot_boxplot_{metric}_by_size_category.png"
             plt.savefig(os.path.join(size_dir, plot_filename), dpi=300)
             print(f"Saved: {plot_filename}")
        except Exception as e:
             print(f"Error generating boxplot for {metric}: {e}")
        finally:
             plt.close()


    # --- Stacked Bar Plot: Membrane Distribution (Boundary vs Interior Contribution) ---
    try:
        # Calculate average coverage ratios per category and pressure
        grouped_stats = all_metrics_df.groupby(['pressure', 'size_category'], observed=True).agg(
            mean_boundary_coverage=('boundary_coverage_ratio', 'mean'),
            mean_interior_coverage=('interior_coverage_ratio', 'mean'),
            mean_total_coverage=('membrane_coverage_ratio', 'mean')
        ).reset_index()

        # Calculate relative contribution (handle division by zero if total coverage is zero)
        grouped_stats['boundary_contribution'] = grouped_stats.apply(
            lambda row: row['mean_boundary_coverage'] / row['mean_total_coverage'] if row['mean_total_coverage'] > 0 else 0, axis=1)
        grouped_stats['interior_contribution'] = grouped_stats.apply(
            lambda row: row['mean_interior_coverage'] / row['mean_total_coverage'] if row['mean_total_coverage'] > 0 else 0, axis=1)

        # Ensure contributions roughly sum to 1 (or handle cases where they don't)
        # Normalize if needed, though ideally they should represent parts of the whole
        # grouped_stats['total_contribution'] = grouped_stats['boundary_contribution'] + grouped_stats['interior_contribution']
        # print(grouped_stats[['pressure', 'size_category', 'boundary_contribution', 'interior_contribution', 'total_contribution']]) # Debugging

        plt.figure(figsize=(14, 9))
        bar_width = 0.4 # Width for each pressure group within a size category
        size_categories = grouped_stats['size_category'].unique()
        pressures = grouped_stats['pressure'].unique()
        x_indices = np.arange(len(size_categories)) # Base indices for size categories

        for i, pressure in enumerate(pressures):
            pressure_data = grouped_stats[grouped_stats['pressure'] == pressure]
            # Calculate offset for bars of this pressure group
            offset = bar_width * (i - (len(pressures) - 1) / 2)

            plt.bar(x_indices + offset, pressure_data['boundary_contribution'], bar_width,
                    alpha=0.8, label=f'{pressure} - Boundary' if i == 0 else None, color=plt.cm.Paired(i*2)) # Boundary part
            plt.bar(x_indices + offset, pressure_data['interior_contribution'], bar_width,
                    bottom=pressure_data['boundary_contribution'], alpha=0.8,
                    label=f'{pressure} - Interior' if i == 0 else None, color=plt.cm.Paired(i*2+1)) # Interior part on top

        plt.xlabel('Cell Size Category', fontsize=14)
        plt.ylabel('Relative Contribution to Total Coverage', fontsize=14)
        plt.title('Membrane Location Contribution by Size and Pressure', fontsize=16)
        plt.xticks(x_indices, size_categories, rotation=15, ha='right')
        plt.ylim(0, 1.5) # Allow space for labels if sum > 1 due to definitions
        plt.grid(True, axis='y', alpha=0.4)
        # Create custom legend entries if automatic labels are messy
        from matplotlib.patches import Patch
        legend_elements = []
        for i, pressure in enumerate(pressures):
             legend_elements.append(Patch(facecolor=plt.cm.Paired(i*2), label=f'{pressure} Boundary', alpha=0.8))
             legend_elements.append(Patch(facecolor=plt.cm.Paired(i*2+1), label=f'{pressure} Interior', alpha=0.8))
        plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust layout for external legend

        plot_filename = "plot_bar_membrane_distribution_by_size.png"
        plt.savefig(os.path.join(size_dir, plot_filename), dpi=300)
        print(f"Saved: {plot_filename}")

    except Exception as e:
         print(f"Error generating stacked bar plot for membrane distribution: {e}")
         # print(traceback.format_exc()) # Uncomment for detailed error
    finally:
         plt.close()


    # --- Summary Statistics Table ---
    try:
        summary_stats = all_metrics_df.groupby(['pressure', 'size_category'], observed=True).agg(
            cell_count=('cell_id', 'count'),
            mean_cell_area=('cell_area', 'mean'),
            std_cell_area=('cell_area', 'std'),
            mean_total_coverage=('membrane_coverage_ratio', 'mean'),
            std_total_coverage=('membrane_coverage_ratio', 'std'),
            median_total_coverage=('membrane_coverage_ratio', 'median'),
            mean_boundary_coverage=('boundary_coverage_ratio', 'mean'),
            std_boundary_coverage=('boundary_coverage_ratio', 'std'),
            mean_interior_coverage=('interior_coverage_ratio', 'mean'),
            std_interior_coverage=('interior_coverage_ratio', 'std'),
            median_boundary_interior_ratio=('membrane_boundary_to_interior_ratio', 'median'),
        ).reset_index()

        # Save summary statistics
        summary_csv_path = os.path.join(size_dir, "summary_stats_by_size_category.csv")
        summary_stats.to_csv(summary_csv_path, index=False)
        print(f"Saved summary statistics to {summary_csv_path}")

    except Exception as e:
        print(f"Error calculating summary statistics: {e}")
        summary_stats = None # Indicate failure


    # --- Statistical Tests (Example: ANOVA between size categories within each pressure) ---
    stat_results = []
    if 'size_category' in all_metrics_df and all_metrics_df['size_category'].nunique() > 1: # Only run if multiple categories exist
        metrics_to_test = ['membrane_coverage_ratio', 'boundary_coverage_ratio', 'interior_coverage_ratio']
        for pressure in all_metrics_df['pressure'].unique():
            pressure_df = all_metrics_df[all_metrics_df['pressure'] == pressure]
            size_categories_in_pressure = pressure_df['size_category'].unique()

            if len(size_categories_in_pressure) < 2: # Need at least two groups to compare
                 continue

            for metric in metrics_to_test:
                if metric not in pressure_df.columns: continue

                # Prepare data for ANOVA: list of arrays, one for each size category
                groups = [group[metric].dropna().values for name, group in pressure_df.groupby('size_category', observed=True)]
                # Filter out empty groups or groups with insufficient data for ANOVA
                groups = [g for g in groups if len(g) > 1]

                if len(groups) >= 2: # Need at least two valid groups for ANOVA
                    try:
                        f_val, p_val = stats.f_oneway(*groups)
                        stat_results.append({
                            'pressure': pressure,
                            'metric': metric,
                            'test': 'ANOVA (vs size category)',
                            'f_value': f_val,
                            'p_value': p_val,
                            'significant (p<0.05)': p_val < 0.05
                        })
                    except Exception as e:
                         print(f"Error performing ANOVA for {metric} under {pressure}: {e}")
                         stat_results.append({
                            'pressure': pressure, 'metric': metric, 'test': 'ANOVA (vs size category)',
                            'f_value': None, 'p_value': None, 'significant (p<0.05)': 'Error', 'notes': str(e)
                         })

        # Save statistical test results
        if stat_results:
            try:
                stat_df = pd.DataFrame(stat_results)
                stat_csv_path = os.path.join(size_dir, "statistical_tests_size_categories.csv")
                stat_df.to_csv(stat_csv_path, index=False)
                print(f"Saved statistical test results to {stat_csv_path}")
            except Exception as e:
                 print(f"Error saving statistical test results CSV: {e}")
        else:
             print("No statistical tests performed (insufficient groups or data).")
    else:
         print("Skipping statistical tests: Only one size category present.")


    print("--- Size category analysis complete ---")
    return summary_stats # Return the summary table


def analyze_allometric_scaling(all_metrics_df, output_dir):
    """Performs allometric scaling analysis using log-log plots."""
    if all_metrics_df is None or all_metrics_df.empty:
        print("\n--- Skipping Allometric Scaling Analysis (No data) ---")
        return

    print("\n--- Performing Allometric Scaling Analysis ---")
    allo_dir = os.path.join(output_dir, "allometric_analysis")
    os.makedirs(allo_dir, exist_ok=True)

    # --- Prepare log-transformed data ---
    # Add a small constant (e.g., 1) before taking log to handle zeros
    log_constant = 1
    log_cols = {}
    cols_to_log = {
        'cell_area': 'log_cell_area',
        'total_membrane_area': 'log_total_membrane_area',
        'boundary_membrane_area': 'log_boundary_membrane_area',
        'interior_membrane_area': 'log_interior_membrane_area',
        'cell_perimeter': 'log_cell_perimeter'
    }

    for original_col, log_col in cols_to_log.items():
         if original_col in all_metrics_df:
              # Ensure column is numeric and handle potential non-positives if log_constant is not used
              numeric_col = pd.to_numeric(all_metrics_df[original_col], errors='coerce')
              if numeric_col.min() <= 0:
                  print(f"Warning: Column '{original_col}' contains non-positive values. Adding {log_constant} before log transform.")
                  all_metrics_df[log_col] = np.log(numeric_col + log_constant)
              else:
                  all_metrics_df[log_col] = np.log(numeric_col)
              # Replace -inf resulting from log(0+constant) or log(negative+constant) with NaN
              all_metrics_df[log_col] = all_metrics_df[log_col].replace(-np.inf, np.nan)
              log_cols[original_col] = log_col # Store mapping for later use
         else:
              print(f"Warning: Column '{original_col}' not found for log transform.")


    # --- Plot 1: Log(Total Membrane Area) vs Log(Cell Area) ---
    if 'log_cell_area' in all_metrics_df and 'log_total_membrane_area' in all_metrics_df:
        plt.figure(figsize=(12, 10))
        plot_success = False
        scaling_results = []

        for pressure in all_metrics_df['pressure'].unique():
            df = all_metrics_df[all_metrics_df['pressure'] == pressure]
            df_filt = df[['log_cell_area', 'log_total_membrane_area']].dropna()
            if df_filt.empty: continue

            sns.scatterplot(x='log_cell_area', y='log_total_membrane_area', data=df_filt, alpha=0.5, label=f"{pressure} cells")
            if len(df_filt) > 1:
                sns.regplot(x='log_cell_area', y='log_total_membrane_area', data=df_filt, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")
                try:
                    X = df_filt['log_cell_area'].values.reshape(-1, 1)
                    y = df_filt['log_total_membrane_area'].values
                    model = LinearRegression().fit(X, y)
                    scaling_exponent = model.coef_[0] # Slope is the scaling exponent 'b' in log(y)=log(a)+b*log(x)
                    intercept = model.intercept_ # log(a)
                    r_squared = r2_score(y, model.predict(X))
                    annotation = f"{pressure}: Exp={scaling_exponent:.3f}, R²={r_squared:.3f}"
                    scaling_results.append({'pressure': pressure, 'analysis': 'TotalMembrane_vs_CellArea', 'exponent': scaling_exponent, 'r_squared': r_squared, 'intercept': intercept})
                except Exception as e:
                     print(f"Could not calculate scaling for {pressure} Total Membrane vs Area: {e}")
                     annotation = f"{pressure}: Stats Error"
                plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(all_metrics_df['pressure'].unique()).index(pressure)), xycoords='axes fraction',
                             ha='left', fontsize=10, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
                plot_success = True

        if plot_success:
            plt.title('Allometric Scaling: Total Membrane Area vs Cell Area (Log-Log)', fontsize=16)
            plt.xlabel(f'Log(Cell Area + {log_constant})', fontsize=14)
            plt.ylabel(f'Log(Total Membrane Area + {log_constant})', fontsize=14)
            plt.grid(True, alpha=0.4)
            plt.legend(fontsize=10)
            plt.tight_layout()
            plt.savefig(os.path.join(allo_dir, "plot_loglog_total_membrane_vs_area.png"), dpi=300)
            print("Saved: plot_loglog_total_membrane_vs_area.png")
        else:
            print("Skipping plot: Log(Total Membrane) vs Log(Area) (no data)")
        plt.close()


    # --- Plot 2: Log(Boundary Membrane Area) vs Log(Cell Perimeter) ---
    # Note: Scaling boundary membrane vs perimeter might be more theoretically grounded than vs area.
    if 'log_cell_perimeter' in all_metrics_df and 'log_boundary_membrane_area' in all_metrics_df:
        plt.figure(figsize=(12, 10))
        plot_success = False

        for pressure in all_metrics_df['pressure'].unique():
             df = all_metrics_df[all_metrics_df['pressure'] == pressure]
             df_filt = df[['log_cell_perimeter', 'log_boundary_membrane_area']].dropna()
             if df_filt.empty: continue

             sns.scatterplot(x='log_cell_perimeter', y='log_boundary_membrane_area', data=df_filt, alpha=0.5, label=f"{pressure} cells")
             if len(df_filt) > 1:
                 sns.regplot(x='log_cell_perimeter', y='log_boundary_membrane_area', data=df_filt, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")
                 try:
                     X = df_filt['log_cell_perimeter'].values.reshape(-1, 1)
                     y = df_filt['log_boundary_membrane_area'].values
                     model = LinearRegression().fit(X, y)
                     scaling_exponent = model.coef_[0]
                     intercept = model.intercept_
                     r_squared = r2_score(y, model.predict(X))
                     annotation = f"{pressure}: Exp={scaling_exponent:.3f}, R²={r_squared:.3f}"
                     scaling_results.append({'pressure': pressure, 'analysis': 'BoundaryMembrane_vs_CellPerimeter', 'exponent': scaling_exponent, 'r_squared': r_squared, 'intercept': intercept})
                 except Exception as e:
                      print(f"Could not calculate scaling for {pressure} Boundary Membrane vs Perimeter: {e}")
                      annotation = f"{pressure}: Stats Error"
                 plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(all_metrics_df['pressure'].unique()).index(pressure)), xycoords='axes fraction',
                              ha='left', fontsize=10, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
                 plot_success = True

        if plot_success:
             plt.title('Allometric Scaling: Boundary Membrane Area vs Cell Perimeter (Log-Log)', fontsize=16)
             plt.xlabel(f'Log(Cell Perimeter + {log_constant})', fontsize=14)
             plt.ylabel(f'Log(Boundary Membrane Area + {log_constant})', fontsize=14)
             plt.grid(True, alpha=0.4)
             plt.legend(fontsize=10)
             plt.tight_layout()
             plt.savefig(os.path.join(allo_dir, "plot_loglog_boundary_membrane_vs_perimeter.png"), dpi=300)
             print("Saved: plot_loglog_boundary_membrane_vs_perimeter.png")
        else:
             print("Skipping plot: Log(Boundary Membrane) vs Log(Perimeter) (no data)")
        plt.close()

    # --- Plot 3: Log(Interior Membrane Area) vs Log(Cell Area) ---
    # Interior membrane might scale more closely with area (volume proxy).
    if 'log_cell_area' in all_metrics_df and 'log_interior_membrane_area' in all_metrics_df:
        plt.figure(figsize=(12, 10))
        plot_success = False

        for pressure in all_metrics_df['pressure'].unique():
             df = all_metrics_df[all_metrics_df['pressure'] == pressure]
             df_filt = df[['log_cell_area', 'log_interior_membrane_area']].dropna()
             if df_filt.empty: continue

             sns.scatterplot(x='log_cell_area', y='log_interior_membrane_area', data=df_filt, alpha=0.5, label=f"{pressure} cells")
             if len(df_filt) > 1:
                 sns.regplot(x='log_cell_area', y='log_interior_membrane_area', data=df_filt, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")
                 try:
                     X = df_filt['log_cell_area'].values.reshape(-1, 1)
                     y = df_filt['log_interior_membrane_area'].values
                     model = LinearRegression().fit(X, y)
                     scaling_exponent = model.coef_[0]
                     intercept = model.intercept_
                     r_squared = r2_score(y, model.predict(X))
                     annotation = f"{pressure}: Exp={scaling_exponent:.3f}, R²={r_squared:.3f}"
                     scaling_results.append({'pressure': pressure, 'analysis': 'InteriorMembrane_vs_CellArea', 'exponent': scaling_exponent, 'r_squared': r_squared, 'intercept': intercept})
                 except Exception as e:
                      print(f"Could not calculate scaling for {pressure} Interior Membrane vs Area: {e}")
                      annotation = f"{pressure}: Stats Error"
                 plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(all_metrics_df['pressure'].unique()).index(pressure)), xycoords='axes fraction',
                              ha='left', fontsize=10, bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
                 plot_success = True

        if plot_success:
             plt.title('Allometric Scaling: Interior Membrane Area vs Cell Area (Log-Log)', fontsize=16)
             plt.xlabel(f'Log(Cell Area + {log_constant})', fontsize=14)
             plt.ylabel(f'Log(Interior Membrane Area + {log_constant})', fontsize=14)
             plt.grid(True, alpha=0.4)
             plt.legend(fontsize=10)
             plt.tight_layout()
             plt.savefig(os.path.join(allo_dir, "plot_loglog_interior_membrane_vs_area.png"), dpi=300)
             print("Saved: plot_loglog_interior_membrane_vs_area.png")
        else:
             print("Skipping plot: Log(Interior Membrane) vs Log(Area) (no data)")
        plt.close()

    # Save scaling results to CSV
    if scaling_results:
         try:
             scaling_df = pd.DataFrame(scaling_results)
             scaling_csv_path = os.path.join(allo_dir, "summary_allometric_scaling_exponents.csv")
             scaling_df.to_csv(scaling_csv_path, index=False)
             print(f"Saved allometric scaling results to {scaling_csv_path}")
         except Exception as e:
              print(f"Error saving scaling results CSV: {e}")
    else:
         print("No scaling results calculated.")

    print("--- Allometric scaling analysis complete ---")


# --- Main Execution Block ---
if __name__ == "__main__":
    print("\n======= Starting Cell Membrane Analysis Script =======")

    # 1. Find and organize mask files
    # Assuming cell_mask_dir and membrane_dir are defined elsewhere in your script
    # Assuming find_mask_files is a function defined elsewhere
    pressure_file_dict = find_mask_files(cell_mask_dir, membrane_dir)

    if pressure_file_dict is None:
         print("\nSCRIPT HALTED: Could not find or access mask files. Please check paths and Drive mount.")
    else:
         # 2. Process files for each pressure group
         all_pressure_results = {}
         # Assuming process_pressure_group is a function defined elsewhere
         # Assuming output_dir is defined elsewhere
         # Assuming pd is imported as pandas
         import pandas as pd # Ensure pandas is imported if not already

         for pressure, file_list in pressure_file_dict.items():
             if not file_list:
                 print(f"\nNo files found for pressure group: {pressure}. Skipping.")
                 all_pressure_results[pressure] = {'pressure': pressure, 'total_samples_processed': 0, 'total_cells': 0, 'all_cell_metrics': [], 'cell_metrics_df': pd.DataFrame()}
                 continue
             pressure_results = process_pressure_group(file_list, pressure, output_dir)
             all_pressure_results[pressure] = pressure_results

         # 3. Perform combined analyses only if there's data
         # Check if any group produced a non-empty dataframe
         has_data = any(not results.get('cell_metrics_df', pd.DataFrame()).empty for results in all_pressure_results.values())

         if has_data:
             # 4. Create Correlation Analysis Plots and get combined DataFrame
             # Assuming create_correlation_analysis is a function defined elsewhere
             combined_df = create_correlation_analysis(all_pressure_results, output_dir)

             # --- Add Percentage Calculation ---
             if combined_df is not None and not combined_df.empty:
                 print("\n--- Calculating Percentage Metrics ---")
                 try:
                     # Calculate percentage columns by multiplying ratios by 100
                     combined_df['membrane_coverage_percentage'] = combined_df['membrane_coverage_ratio'] * 100
                     combined_df['boundary_coverage_percentage'] = combined_df['boundary_coverage_ratio'] * 100
                     combined_df['interior_coverage_percentage'] = combined_df['interior_coverage_ratio'] * 100
                     print("Added percentage columns (e.g., 'membrane_coverage_percentage') to the DataFrame.")

                     # Optional: You can now update plotting functions to use these
                     # new columns and adjust axis labels if desired. For example:
                     # In create_correlation_analysis:
                     #   replace y='membrane_coverage_ratio' with y='membrane_coverage_percentage'
                     #   replace plt.ylabel('Membrane Coverage Ratio ...') with plt.ylabel('Membrane Coverage (%) ...')
                     # Similar changes can be made in analyze_size_categories plots.

                 except KeyError as e:
                      print(f"Warning: Could not calculate percentages. Missing column: {e}")
                 except Exception as e:
                      print(f"Warning: Error calculating percentages: {e}")

                 # --- Continue with subsequent analyses using the updated combined_df ---
                 # 5. Analyze by Size Categories
                 # Assuming analyze_size_categories is a function defined elsewhere
                 analyze_size_categories(combined_df, output_dir)

                 # 6. Analyze Allometric Scaling
                 # Assuming analyze_allometric_scaling is a function defined elsewhere
                 analyze_allometric_scaling(combined_df, output_dir)
             # --- End of Percentage Calculation block ---
             else:
                  print("\nSkipping further analysis as combined data is empty or correlation analysis failed.")
         else:
              print("\nNo cell data was successfully processed. Skipping combined analyses.")


         print("\n======= Analysis Script Finished =======")
         print(f"Outputs (CSV files and PNG plots) saved in: {output_dir}")
         print("Please check the subfolders: correlation_analysis, size_categories, allometric_analysis")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dependencies checked/installed.
Libraries imported successfully.
Output directory created/verified: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation


--- Finding and Pairing Mask Files ---
Found 8 potential cell mask files in /content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative
Found 8 potential membrane mask files in /content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Membrane
Total matching cell-membrane file pairs found: 8
Pairs per pressure group:
  0Pa: 3 pairs
  1.4Pa: 5 pairs
-----------------------------------

=== Processing 0Pa Group (3 file pairs) ===
--- 0Pa Group Summary ---
Successfully processed: 3 pairs
Failed/Skipped: 0 pairs
Total cells analyzed: 3
Saved 0Pa metrics to /content/drive/MyDrive/knowledge/Universit

  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[:, None]
  c /= stddev[None, :]


Saved: plot_corr_area_vs_membrane_coverage.png


  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[:, None]
  c /= stddev[None, :]


Saved: plot_corr_perimeter_vs_boundary_coverage.png


  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[:, None]
  c /= stddev[None, :]


Saved: plot_corr_circularity_vs_membrane_coverage.png
Saved: plot_scatter_interior_vs_boundary_coverage.png
--- Correlation analysis complete ---

--- Calculating Percentage Metrics ---
Added percentage columns (e.g., 'membrane_coverage_percentage') to the DataFrame.

--- Analyzing Metrics by Cell Size Categories ---
Size categories defined (based on area quartiles):
  Smallest: 0.0 - 1048575.0 pixels²
  Small-Med: 1048575.0 - inf pixels²
Saved categorized data to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation/size_categories/cells_with_size_categories.csv
Saved: plot_boxplot_membrane_coverage_ratio_by_size_category.png
Saved: plot_boxplot_boundary_coverage_ratio_by_size_category.png
Saved: plot_boxplot_interior_coverage_ratio_by_size_category.png
Saved: plot_bar_membrane_distribution_by_size.png
Saved summary statistics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation/s

In [22]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install opencv-python-headless scikit-image seaborn -q

# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation
from scipy import ndimage, stats
import re
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression

# Better visualization settings
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100
plt.style.use('ggplot')

# Define your input and output directories
cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
membrane_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Membrane"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# --- Helper Functions ---
def extract_pressure(filename):
    """Extracts '0Pa' or '1.4Pa' from a filename."""
    match = re.search(r'(0Pa|1\.4Pa)', str(filename))
    return match.group(1) if match else None

def extract_sample_id(filename):
    """Extracts a unique sample identifier from a filename to match pairs."""
    match = re.search(r'((?:0Pa|1\.4Pa)_U_[^_]+_20x_[^_]+_[^_]+_seq\d+)', str(filename))
    return match.group(1) if match else None

def find_mask_files(cell_dir, membrane_dir):
    """Finds and pairs cell and membrane mask files based on sample ID and pressure."""
    print("\n--- Finding and Pairing Mask Files ---")
    pressure_dict = {'0Pa': [], '1.4Pa': []}

    cell_files = [f for f in os.listdir(cell_dir) if f.endswith(('_cell_mask_merged_conservative.tif', '_cell_mask.tif')) and not f.startswith('.')]
    membrane_files = [f for f in os.listdir(membrane_dir) if f.endswith('.tif') and not f.startswith('.')]

    print(f"Found {len(cell_files)} cell mask files and {len(membrane_files)} membrane mask files")

    # Create lookup dictionary for membrane files
    membrane_lookup = {}
    for membrane_file in membrane_files:
        sample_id = extract_sample_id(membrane_file)
        if sample_id:
            membrane_lookup[sample_id] = membrane_file

    # Match cell files to membrane files
    pairs_found = 0
    processed_cell_ids = set()
    for cell_file in cell_files:
        pressure = extract_pressure(cell_file)
        sample_id = extract_sample_id(cell_file)

        if sample_id in processed_cell_ids:
            continue

        if pressure and pressure in pressure_dict and sample_id:
            if sample_id in membrane_lookup:
                membrane_file = membrane_lookup[sample_id]
                file_pair = {
                    'cell_file': os.path.join(cell_dir, cell_file),
                    'membrane_file': os.path.join(membrane_dir, membrane_file),
                    'sample_id': sample_id
                }
                pressure_dict[pressure].append(file_pair)
                pairs_found += 1
                processed_cell_ids.add(sample_id)

    print(f"Total matching cell-membrane file pairs found: {pairs_found}")
    for pressure, file_list in pressure_dict.items():
        print(f"  {pressure}: {len(file_list)} pairs")

    return pressure_dict

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's binary (0 or 1)."""
    try:
        img = io.imread(filepath)
        # Convert to binary uint8
        if img.dtype == bool:
            img = img.astype(np.uint8)
        elif np.max(img) > 1:
            img = (img > 0).astype(np.uint8)
        else:
            img = img.astype(np.uint8)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]
            img = (img > 0).astype(np.uint8)

        return img
    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def analyze_cell_membrane_size_relation(cell_mask, membrane_mask):
    """Analyzes individual cells from masks to extract size and membrane metrics."""
    # Ensure masks are binary
    binary_cell_mask = (cell_mask > 0).astype(np.uint8)
    binary_membrane_mask = (membrane_mask > 0).astype(np.uint8)

    # Label individual cells in the cell mask
    labeled_cells, num_cells = ndimage.label(binary_cell_mask)
    if num_cells == 0:
        return []

    cell_metrics = []
    cell_slices = ndimage.find_objects(labeled_cells)

    for cell_id in range(1, num_cells + 1):
        current_slice = cell_slices[cell_id - 1]
        sub_labeled = labeled_cells[current_slice]
        sub_cell_mask = (sub_labeled == cell_id).astype(np.uint8)
        sub_membrane_mask = binary_membrane_mask[current_slice]

        # Calculate Cell Metrics
        cell_area = np.sum(sub_cell_mask)
        if cell_area < 5:  # Skip tiny artifacts
            continue

        # Cell Perimeter using OpenCV contours
        contours, _ = cv2.findContours(sub_cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cell_perimeter = cv2.arcLength(contours[0], True) if contours else 0

        # Calculate Membrane Metrics
        cell_membrane_overlap = np.logical_and(sub_cell_mask, sub_membrane_mask)
        total_membrane_area = np.sum(cell_membrane_overlap)

        # Define boundary using erosion
        eroded_cell = ndimage.binary_erosion(sub_cell_mask).astype(sub_cell_mask.dtype)
        cell_boundary = sub_cell_mask - eroded_cell
        cell_boundary_length = np.sum(cell_boundary)

        # Membrane on boundary
        membrane_on_boundary = np.logical_and(cell_boundary, sub_membrane_mask)
        boundary_membrane_area = np.sum(membrane_on_boundary)

        # Cell Interior
        cell_interior = eroded_cell
        interior_area = np.sum(cell_interior)

        # Membrane in interior
        interior_membrane = np.logical_and(cell_interior, sub_membrane_mask)
        interior_membrane_area = np.sum(interior_membrane)

        # Calculate Ratios
        membrane_coverage_ratio = total_membrane_area / cell_area if cell_area > 0 else 0
        boundary_coverage_ratio = boundary_membrane_area / cell_boundary_length if cell_boundary_length > 0 else 0
        interior_coverage_ratio = interior_membrane_area / interior_area if interior_area > 0 else 0

        # Shape Metrics
        circularity = (4 * np.pi * cell_area) / (cell_perimeter ** 2) if cell_perimeter > 0 else 0
        equivalent_diameter = 2 * np.sqrt(cell_area / np.pi)

        # Ratio of membrane on boundary vs interior
        membrane_boundary_to_interior_ratio = boundary_membrane_area / interior_membrane_area if interior_membrane_area > 0 else (boundary_membrane_area if boundary_membrane_area > 0 else 0)

        metrics = {
            'cell_id': cell_id,
            'cell_area': cell_area,
            'cell_perimeter': cell_perimeter,
            'equivalent_diameter': equivalent_diameter,
            'circularity': circularity,
            'total_membrane_area': total_membrane_area,
            'membrane_coverage_ratio': membrane_coverage_ratio,
            'cell_boundary_length': cell_boundary_length,
            'boundary_membrane_area': boundary_membrane_area,
            'boundary_coverage_ratio': boundary_coverage_ratio,
            'interior_area': interior_area,
            'interior_membrane_area': interior_membrane_area,
            'interior_coverage_ratio': interior_coverage_ratio,
            'membrane_boundary_to_interior_ratio': membrane_boundary_to_interior_ratio
        }
        cell_metrics.append(metrics)

    return cell_metrics

def process_file_pair(file_pair, pressure):
    """Loads masks for a pair, handles resizing, runs analysis, returns metrics."""
    sample_id = file_pair['sample_id']

    # Load masks
    cell_mask = load_mask_image(file_pair['cell_file'])
    membrane_mask = load_mask_image(file_pair['membrane_file'])

    if cell_mask is None or membrane_mask is None:
        return None

    # Ensure masks have same dimensions
    if cell_mask.shape != membrane_mask.shape:
        membrane_mask = cv2.resize(membrane_mask, (cell_mask.shape[1], cell_mask.shape[0]),
                                 interpolation=cv2.INTER_NEAREST)

    # Analyze relationship
    cell_metrics = analyze_cell_membrane_size_relation(cell_mask, membrane_mask)

    # Add sample information
    for metrics in cell_metrics:
        metrics['sample_id'] = sample_id
        metrics['pressure'] = pressure

    return cell_metrics

def process_pressure_group(file_pairs, pressure, output_dir):
    """Processes all file pairs for a specific pressure group."""
    print(f"\n=== Processing {pressure} Group ({len(file_pairs)} file pairs) ===")

    all_cell_metrics = []
    processed_count = 0
    failed_count = 0

    for file_pair in file_pairs:
        cell_metrics = process_file_pair(file_pair, pressure)
        if cell_metrics is not None:
            all_cell_metrics.extend(cell_metrics)
            processed_count += 1
        else:
            failed_count += 1

    print(f"Successfully processed: {processed_count} pairs")
    print(f"Total cells analyzed: {len(all_cell_metrics)}")

    # Create dataframe for this pressure group
    if all_cell_metrics:
        results_df = pd.DataFrame(all_cell_metrics)
        group_csv_path = os.path.join(output_dir, f"{pressure}_cell_metrics.csv")
        results_df.to_csv(group_csv_path, index=False)
        print(f"Saved {pressure} metrics to {group_csv_path}")
    else:
        results_df = pd.DataFrame()

    return {
        'pressure': pressure,
        'total_samples_processed': processed_count,
        'total_cells': len(all_cell_metrics),
        'all_cell_metrics': all_cell_metrics,
        'cell_metrics_df': results_df
    }

def create_correlation_analysis(pressure_results, output_dir):
    """Generates scatter plots with regression lines for key metric correlations."""
    print("\n--- Generating Correlation Analysis Plots ---")
    corr_dir = os.path.join(output_dir, "correlation_analysis")
    os.makedirs(corr_dir, exist_ok=True)

    pressure_names = list(pressure_results.keys())
    valid_results = {p: r for p, r in pressure_results.items() if 'cell_metrics_df' in r and not r['cell_metrics_df'].empty}

    if not valid_results:
        print("No valid data found for correlation analysis.")
        return None

    # Combine dataframes from all valid pressure groups
    all_metrics_df = pd.concat([valid_results[p]['cell_metrics_df'] for p in valid_results])

    # Save the combined dataframe
    all_metrics_csv_path = os.path.join(corr_dir, "all_combined_cell_metrics.csv")
    all_metrics_df.to_csv(all_metrics_csv_path, index=False)

    # --- Plot: Cell Area vs Membrane Coverage ---
    plt.figure(figsize=(12, 10))
    for pressure in valid_results.keys():
        df = valid_results[pressure]['cell_metrics_df']
        if df.empty or 'cell_area' not in df or 'membrane_coverage_ratio' not in df:
            continue

        sns.scatterplot(x='cell_area', y='membrane_coverage_ratio', data=df, alpha=0.5, label=f"{pressure} cells")
        if len(df) > 1:
            sns.regplot(x='cell_area', y='membrane_coverage_ratio', data=df, scatter=False, line_kws={'linewidth': 2}, label=f"{pressure} trend")
            try:
                correlation = df['cell_area'].corr(df['membrane_coverage_ratio'])
                x = df['cell_area'].values.reshape(-1, 1)
                y = df['membrane_coverage_ratio'].values
                model = LinearRegression().fit(x, y)
                r_squared = r2_score(y, model.predict(x))
                annotation = f"{pressure}: r={correlation:.2f}, R²={r_squared:.2f}"
                plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(valid_results.keys()).index(pressure)),
                           xycoords='axes fraction', ha='left', fontsize=10,
                           bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
            except Exception as e:
                print(f"Could not calculate stats for {pressure} Area vs Coverage: {e}")

    plt.title('Cell Area vs Total Membrane Coverage Ratio', fontsize=16)
    plt.xlabel('Cell Area (pixels²)', fontsize=14)
    plt.ylabel('Membrane Coverage Ratio (Membrane Area / Cell Area)', fontsize=14)
    plt.grid(True, alpha=0.4)
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.savefig(os.path.join(corr_dir, "plot_corr_area_vs_membrane_coverage.png"), dpi=300)
    plt.close()

    # --- Create additional correlation plots based on your data ---
    # (You can add other important plots here)

    return all_metrics_df

def analyze_size_categories(all_metrics_df, output_dir):
    """Analyzes metrics grouped by cell size categories."""
    if all_metrics_df is None or all_metrics_df.empty:
        print("\n--- Skipping Size Category Analysis (No data) ---")
        return None

    print("\n--- Analyzing Metrics by Cell Size Categories ---")
    size_dir = os.path.join(output_dir, "size_categories")
    os.makedirs(size_dir, exist_ok=True)

    # Define size categories based on quartiles
    try:
        size_q1 = all_metrics_df['cell_area'].quantile(0.25)
        size_q2 = all_metrics_df['cell_area'].quantile(0.50)
        size_q3 = all_metrics_df['cell_area'].quantile(0.75)

        bins = sorted(list(set([0, size_q1, size_q2, size_q3, float('inf')])))
        labels = ['Smallest', 'Small-Med', 'Med-Large', 'Largest']

        # Adjust labels if bins collapsed
        if len(bins) - 1 < len(labels):
            labels = labels[:len(bins)-1]

        if len(bins) <= 2:  # Not enough variation
            all_metrics_df['size_category'] = 'All Sizes'
            print("Size categories defined: All Sizes > 0 pixels")
        else:
            all_metrics_df['size_category'] = pd.cut(
                all_metrics_df['cell_area'],
                bins=bins,
                labels=labels,
                right=False,
                include_lowest=True
            )
            print("Size categories defined (based on area quartiles):")
            for i in range(len(labels)):
                print(f"  {labels[i]}: {bins[i]:.1f} - {bins[i+1]:.1f} pixels²")

    except Exception as e:
        print(f"Error defining size categories: {e}. Assigning all to one category.")
        all_metrics_df['size_category'] = 'All Sizes'

    # Save categorized dataframe
    categorized_csv_path = os.path.join(size_dir, "cells_with_size_categories.csv")
    all_metrics_df.to_csv(categorized_csv_path, index=False)

    # Create boxplots for key metrics by size category
    metrics_to_plot = {
        'membrane_coverage_ratio': 'Total Membrane Coverage by Size Category',
        'boundary_coverage_ratio': 'Boundary Membrane Coverage by Size Category',
        'interior_coverage_ratio': 'Interior Membrane Coverage by Size Category'
    }

    category_order = labels if 'labels' in locals() and len(bins)>2 else ['All Sizes']

    for metric, title in metrics_to_plot.items():
        if metric not in all_metrics_df.columns or all_metrics_df[metric].isnull().all():
            continue

        plt.figure(figsize=(14, 8))
        try:
            sns.boxplot(
                x='size_category',
                y=metric,
                hue='pressure',
                data=all_metrics_df,
                order=category_order,
                palette='viridis'
            )
            plt.title(title, fontsize=16)
            plt.xlabel('Cell Size Category', fontsize=14)
            plt.ylabel(metric.replace('_', ' ').title(), fontsize=14)
            plt.xticks(rotation=15, ha='right')
            plt.grid(True, axis='y', alpha=0.4)
            plt.legend(title='Pressure')
            plt.tight_layout()
            plt.savefig(os.path.join(size_dir, f"plot_boxplot_{metric}_by_size_category.png"), dpi=300)
        except Exception as e:
            print(f"Error generating boxplot for {metric}: {e}")
        finally:
            plt.close()

    # Add percentage calculations
    for ratio_col in ['membrane_coverage_ratio', 'boundary_coverage_ratio', 'interior_coverage_ratio']:
        if ratio_col in all_metrics_df.columns:
            all_metrics_df[ratio_col.replace('ratio', 'percentage')] = all_metrics_df[ratio_col] * 100

    # Calculate summary statistics
    summary_stats = all_metrics_df.groupby(['pressure', 'size_category'], observed=True).agg(
        cell_count=('cell_id', 'count'),
        mean_cell_area=('cell_area', 'mean'),
        std_cell_area=('cell_area', 'std'),
        mean_total_coverage=('membrane_coverage_ratio', 'mean'),
        std_total_coverage=('membrane_coverage_ratio', 'std'),
        mean_boundary_coverage=('boundary_coverage_ratio', 'mean'),
        mean_interior_coverage=('interior_coverage_ratio', 'mean')
    ).reset_index()

    # Save summary statistics
    summary_csv_path = os.path.join(size_dir, "summary_stats_by_size_category.csv")
    summary_stats.to_csv(summary_csv_path, index=False)

    return summary_stats

def analyze_allometric_scaling(all_metrics_df, output_dir):
    """Performs allometric scaling analysis using log-log plots."""
    if all_metrics_df is None or all_metrics_df.empty:
        print("\n--- Skipping Allometric Scaling Analysis (No data) ---")
        return

    print("\n--- Performing Allometric Scaling Analysis ---")
    allo_dir = os.path.join(output_dir, "allometric_analysis")
    os.makedirs(allo_dir, exist_ok=True)

    # Prepare log-transformed data
    log_constant = 1
    cols_to_log = {
        'cell_area': 'log_cell_area',
        'total_membrane_area': 'log_total_membrane_area',
        'boundary_membrane_area': 'log_boundary_membrane_area',
        'interior_membrane_area': 'log_interior_membrane_area',
        'cell_perimeter': 'log_cell_perimeter'
    }

    for original_col, log_col in cols_to_log.items():
        if original_col in all_metrics_df:
            numeric_col = pd.to_numeric(all_metrics_df[original_col], errors='coerce')
            all_metrics_df[log_col] = np.log(numeric_col + log_constant)
            all_metrics_df[log_col] = all_metrics_df[log_col].replace(-np.inf, np.nan)

    # Create log-log plots and calculate scaling exponents
    scaling_results = []

    # Plot: Log(Total Membrane Area) vs Log(Cell Area)
    if 'log_cell_area' in all_metrics_df and 'log_total_membrane_area' in all_metrics_df:
        plt.figure(figsize=(12, 10))

        for pressure in all_metrics_df['pressure'].unique():
            df = all_metrics_df[all_metrics_df['pressure'] == pressure]
            df_filt = df[['log_cell_area', 'log_total_membrane_area']].dropna()
            if df_filt.empty: continue

            sns.scatterplot(x='log_cell_area', y='log_total_membrane_area', data=df_filt, alpha=0.5, label=f"{pressure} cells")
            if len(df_filt) > 1:
                sns.regplot(x='log_cell_area', y='log_total_membrane_area', data=df_filt, scatter=False, line_kws={'linewidth': 2})

                try:
                    X = df_filt['log_cell_area'].values.reshape(-1, 1)
                    y = df_filt['log_total_membrane_area'].values
                    model = LinearRegression().fit(X, y)
                    scaling_exponent = model.coef_[0]
                    r_squared = r2_score(y, model.predict(X))
                    annotation = f"{pressure}: Exp={scaling_exponent:.3f}, R²={r_squared:.3f}"
                    scaling_results.append({
                        'pressure': pressure,
                        'analysis': 'TotalMembrane_vs_CellArea',
                        'exponent': scaling_exponent,
                        'r_squared': r_squared
                    })
                    plt.annotate(annotation, xy=(0.05, 0.95 - 0.06 * list(all_metrics_df['pressure'].unique()).index(pressure)),
                               xycoords='axes fraction', ha='left', fontsize=10,
                               bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
                except Exception as e:
                    print(f"Could not calculate scaling for {pressure}: {e}")

        plt.title('Allometric Scaling: Total Membrane Area vs Cell Area (Log-Log)', fontsize=16)
        plt.xlabel(f'Log(Cell Area + {log_constant})', fontsize=14)
        plt.ylabel(f'Log(Total Membrane Area + {log_constant})', fontsize=14)
        plt.grid(True, alpha=0.4)
        plt.legend(fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(allo_dir, "plot_loglog_total_membrane_vs_area.png"), dpi=300)
        plt.close()

    # Save scaling results
    if scaling_results:
        scaling_df = pd.DataFrame(scaling_results)
        scaling_csv_path = os.path.join(allo_dir, "summary_allometric_scaling_exponents.csv")
        scaling_df.to_csv(scaling_csv_path, index=False)

# --- Main Execution Block ---
if __name__ == "__main__":
    print("\n======= Starting Cell Membrane Analysis Script =======")

    # 1. Find and organize mask files
    pressure_file_dict = find_mask_files(cell_mask_dir, membrane_dir)

    if pressure_file_dict is None:
        print("\nSCRIPT HALTED: Could not find or access mask files. Please check paths.")
    else:
        # 2. Process files for each pressure group
        all_pressure_results = {}
        for pressure, file_list in pressure_file_dict.items():
            if not file_list:
                print(f"\nNo files found for pressure group: {pressure}. Skipping.")
                all_pressure_results[pressure] = {'pressure': pressure, 'total_samples_processed': 0, 'total_cells': 0, 'all_cell_metrics': [], 'cell_metrics_df': pd.DataFrame()}
                continue
            pressure_results = process_pressure_group(file_list, pressure, output_dir)
            all_pressure_results[pressure] = pressure_results

        # 3. Perform combined analyses only if there's data
        has_data = any(not results.get('cell_metrics_df', pd.DataFrame()).empty for results in all_pressure_results.values())

        if has_data:
            # 4. Create Correlation Analysis Plots and get combined DataFrame
            combined_df = create_correlation_analysis(all_pressure_results, output_dir)

            if combined_df is not None and not combined_df.empty:
                # 5. Analyze by Size Categories
                analyze_size_categories(combined_df, output_dir)

                # 6. Analyze Allometric Scaling
                analyze_allometric_scaling(combined_df, output_dir)
            else:
                print("\nSkipping further analysis as combined data is empty.")
        else:
            print("\nNo cell data was successfully processed. Skipping combined analyses.")

        print("\n======= Analysis Script Finished =======")
        print(f"Outputs saved in: {output_dir}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


--- Finding and Pairing Mask Files ---
Found 8 cell mask files and 8 membrane mask files
Total matching cell-membrane file pairs found: 8
  0Pa: 3 pairs
  1.4Pa: 5 pairs

=== Processing 0Pa Group (3 file pairs) ===
Successfully processed: 3 pairs
Total cells analyzed: 3
Saved 0Pa metrics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation/0Pa_cell_metrics.csv

=== Processing 1.4Pa Group (5 file pairs) ===
Successfully processed: 5 pairs
Total cells analyzed: 5
Saved 1.4Pa metrics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation/1.4Pa_cell_metrics.csv

--- Generating Correlation Analysis Plots ---


  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[:, None]
  c /= stddev[None, :]



--- Analyzing Metrics by Cell Size Categories ---
Size categories defined (based on area quartiles):
  Smallest: 0.0 - 1048575.0 pixels²
  Small-Med: 1048575.0 - inf pixels²

--- Performing Allometric Scaling Analysis ---

Outputs saved in: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation


cell coverage

In [24]:
# Cell Membrane Analysis Script
# Analyzes the relationship between cell size and membrane distribution

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation
from scipy import ndimage, stats
import re
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Set visualization style
sns.set_theme()  # Uses default Seaborn style
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100

# Define input and output directories
cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
membrane_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Membrane"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# --- Helper Functions ---

def extract_pressure(filename):
    """Extracts '0Pa' or '1.4Pa' from a filename."""
    match = re.search(r'(0Pa|1\.4Pa)', str(filename))
    return match.group(1) if match else None

def extract_sample_id(filename):
    """Extracts a unique sample identifier from a filename to match pairs."""
    match = re.search(r'((?:0Pa|1\.4Pa)_U_[^_]+_20x_[^_]+_[^_]+_seq\d+)', str(filename))
    return match.group(1) if match else None

def find_mask_files(cell_dir, membrane_dir):
    """Finds and pairs cell and membrane mask files based on sample ID and pressure."""
    print("\n--- Finding and Pairing Mask Files ---")
    pressure_dict = {'0Pa': [], '1.4Pa': []}

    cell_files = [f for f in os.listdir(cell_dir) if f.endswith(('_cell_mask_merged_conservative.tif', '_cell_mask.tif')) and not f.startswith('.')]
    membrane_files = [f for f in os.listdir(membrane_dir) if f.endswith('.tif') and not f.startswith('.')]

    print(f"Found {len(cell_files)} cell mask files and {len(membrane_files)} membrane mask files")

    # Create lookup dictionary for membrane files
    membrane_lookup = {}
    for membrane_file in membrane_files:
        sample_id = extract_sample_id(membrane_file)
        if sample_id:
            membrane_lookup[sample_id] = membrane_file

    # Match cell files to membrane files
    pairs_found = 0
    processed_cell_ids = set()
    for cell_file in cell_files:
        pressure = extract_pressure(cell_file)
        sample_id = extract_sample_id(cell_file)

        if sample_id in processed_cell_ids:
            continue

        if pressure and pressure in pressure_dict and sample_id:
            if sample_id in membrane_lookup:
                membrane_file = membrane_lookup[sample_id]
                file_pair = {
                    'cell_file': os.path.join(cell_dir, cell_file),
                    'membrane_file': os.path.join(membrane_dir, membrane_file),
                    'sample_id': sample_id
                }
                pressure_dict[pressure].append(file_pair)
                pairs_found += 1
                processed_cell_ids.add(sample_id)

    print(f"Total matching cell-membrane file pairs found: {pairs_found}")
    for pressure, file_list in pressure_dict.items():
        print(f"  {pressure}: {len(file_list)} pairs")

    return pressure_dict

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's binary (0 or 1)."""
    try:
        img = io.imread(filepath)
        # Convert to binary uint8
        if img.dtype == bool:
            img = img.astype(np.uint8)
        elif np.max(img) > 1:
            img = (img > 0).astype(np.uint8)
        else:
            img = img.astype(np.uint8)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]
            img = (img > 0).astype(np.uint8)

        return img
    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def analyze_cell_membrane_size_relation(cell_mask, membrane_mask):
    """Analyzes individual cells from masks to extract size and membrane metrics."""
    # Ensure masks are binary
    binary_cell_mask = (cell_mask > 0).astype(np.uint8)
    binary_membrane_mask = (membrane_mask > 0).astype(np.uint8)

    # Label individual cells in the cell mask
    labeled_cells, num_cells = ndimage.label(binary_cell_mask)
    if num_cells == 0:
        return []

    cell_metrics = []
    cell_slices = ndimage.find_objects(labeled_cells)

    for cell_id in range(1, num_cells + 1):
        # Get the slice for current cell
        current_slice = cell_slices[cell_id - 1]
        sub_labeled = labeled_cells[current_slice]
        sub_cell_mask = (sub_labeled == cell_id).astype(np.uint8)
        sub_membrane_mask = binary_membrane_mask[current_slice]

        # Calculate Cell Metrics
        cell_area = np.sum(sub_cell_mask)
        if cell_area < 5:  # Skip tiny artifacts
            continue

        # Cell Perimeter using OpenCV contours
        contours, _ = cv2.findContours(sub_cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cell_perimeter = cv2.arcLength(contours[0], True) if contours else 0

        # Calculate Membrane Metrics
        cell_membrane_overlap = np.logical_and(sub_cell_mask, sub_membrane_mask)
        total_membrane_area = np.sum(cell_membrane_overlap)

        # Define boundary using erosion
        eroded_cell = ndimage.binary_erosion(sub_cell_mask).astype(sub_cell_mask.dtype)
        cell_boundary = sub_cell_mask - eroded_cell
        cell_boundary_length = np.sum(cell_boundary)

        # Membrane on boundary
        membrane_on_boundary = np.logical_and(cell_boundary, sub_membrane_mask)
        boundary_membrane_area = np.sum(membrane_on_boundary)

        # Cell Interior
        cell_interior = eroded_cell
        interior_area = np.sum(cell_interior)

        # Membrane in interior
        interior_membrane = np.logical_and(cell_interior, sub_membrane_mask)
        interior_membrane_area = np.sum(interior_membrane)

        # Calculate Ratios
        membrane_coverage_ratio = total_membrane_area / cell_area if cell_area > 0 else 0
        boundary_coverage_ratio = boundary_membrane_area / cell_boundary_length if cell_boundary_length > 0 else 0
        interior_coverage_ratio = interior_membrane_area / interior_area if interior_area > 0 else 0

        # Shape Metrics
        circularity = (4 * np.pi * cell_area) / (cell_perimeter ** 2) if cell_perimeter > 0 else 0
        equivalent_diameter = 2 * np.sqrt(cell_area / np.pi)

        # Ratio of membrane on boundary vs interior
        membrane_boundary_to_interior_ratio = boundary_membrane_area / interior_membrane_area if interior_membrane_area > 0 else (boundary_membrane_area if boundary_membrane_area > 0 else 0)

        metrics = {
            'cell_id': cell_id,
            'cell_area': cell_area,
            'cell_perimeter': cell_perimeter,
            'equivalent_diameter': equivalent_diameter,
            'circularity': circularity,
            'total_membrane_area': total_membrane_area,
            'membrane_coverage_ratio': membrane_coverage_ratio,
            'cell_boundary_length': cell_boundary_length,
            'boundary_membrane_area': boundary_membrane_area,
            'boundary_coverage_ratio': boundary_coverage_ratio,
            'interior_area': interior_area,
            'interior_membrane_area': interior_membrane_area,
            'interior_coverage_ratio': interior_coverage_ratio,
            'membrane_boundary_to_interior_ratio': membrane_boundary_to_interior_ratio
        }
        cell_metrics.append(metrics)

    return cell_metrics

def process_file_pair(file_pair, pressure):
    """Loads masks for a pair, handles resizing, runs analysis, returns metrics."""
    sample_id = file_pair['sample_id']

    # Load masks
    cell_mask = load_mask_image(file_pair['cell_file'])
    membrane_mask = load_mask_image(file_pair['membrane_file'])

    if cell_mask is None or membrane_mask is None:
        return None

    # Ensure masks have same dimensions - Resize membrane mask if necessary
    if cell_mask.shape != membrane_mask.shape:
        try:
            membrane_mask = cv2.resize(membrane_mask, (cell_mask.shape[1], cell_mask.shape[0]),
                                     interpolation=cv2.INTER_NEAREST)
        except Exception as e:
            print(f"Error resizing membrane mask for {sample_id}: {e}")
            return None

    # Analyze relationship
    cell_metrics = analyze_cell_membrane_size_relation(cell_mask, membrane_mask)

    # Add sample information
    for metrics in cell_metrics:
        metrics['sample_id'] = sample_id
        metrics['pressure'] = pressure

    return cell_metrics

def process_pressure_group(file_pairs, pressure, output_dir):
    """Processes all file pairs for a specific pressure group."""
    print(f"\n=== Processing {pressure} Group ({len(file_pairs)} file pairs) ===")

    all_cell_metrics = []
    processed_count = 0
    failed_count = 0

    for file_pair in file_pairs:
        cell_metrics = process_file_pair(file_pair, pressure)
        if cell_metrics is not None:
            all_cell_metrics.extend(cell_metrics)
            processed_count += 1
        else:
            failed_count += 1

    print(f"Successfully processed: {processed_count} pairs")
    print(f"Failed/Skipped: {failed_count} pairs")
    print(f"Total cells analyzed: {len(all_cell_metrics)}")

    # Create and save dataframe for this pressure group
    if all_cell_metrics:
        results_df = pd.DataFrame(all_cell_metrics)
        group_csv_path = os.path.join(output_dir, f"{pressure}_cell_metrics.csv")
        results_df.to_csv(group_csv_path, index=False)
        print(f"Saved {pressure} metrics to {group_csv_path}")
    else:
        results_df = pd.DataFrame()

    return {
        'pressure': pressure,
        'total_samples_processed': processed_count,
        'total_cells': len(all_cell_metrics),
        'all_cell_metrics': all_cell_metrics,
        'cell_metrics_df': results_df
    }

# --- Analysis Functions ---

def create_combined_dataframe(pressure_results, output_dir):
    """Combines the data from all pressure groups and computes percentage metrics."""
    print("\n--- Creating Combined Analysis Dataset ---")

    valid_results = {p: r for p, r in pressure_results.items()
                    if 'cell_metrics_df' in r and not r['cell_metrics_df'].empty}

    if not valid_results:
        print("No valid data found for analysis.")
        return None

    # Combine dataframes
    combined_df = pd.concat([valid_results[p]['cell_metrics_df'] for p in valid_results])

    # Calculate percentage metrics
    combined_df['membrane_coverage_percentage'] = combined_df['membrane_coverage_ratio'] * 100
    combined_df['boundary_coverage_percentage'] = combined_df['boundary_coverage_ratio'] * 100
    combined_df['interior_coverage_percentage'] = combined_df['interior_coverage_ratio'] * 100

    # Calculate membrane distribution percentages
    combined_df['total_membrane'] = combined_df['boundary_membrane_area'] + combined_df['interior_membrane_area']
    combined_df['boundary_membrane_percentage'] = combined_df['boundary_membrane_area'] / combined_df['total_membrane'] * 100
    combined_df['interior_membrane_percentage'] = combined_df['interior_membrane_area'] / combined_df['total_membrane'] * 100

    # Calculate normalized metrics for better comparisons
    if len(combined_df) > 0:
        min_area = combined_df['cell_area'].min()
        max_area = combined_df['cell_area'].max()
        if max_area > min_area:
            combined_df['normalized_area'] = (combined_df['cell_area'] - min_area) / (max_area - min_area)

    # Save combined dataframe
    analysis_dir = os.path.join(output_dir, "analysis")
    os.makedirs(analysis_dir, exist_ok=True)

    combined_csv_path = os.path.join(analysis_dir, "all_combined_cell_metrics.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"Saved combined metrics to {combined_csv_path}")

    return combined_df

def create_membrane_coverage_plots(df, output_dir):
    """Creates plots showing membrane coverage by cell size, clustered by pressure."""
    if df is None or df.empty:
        print("No data for membrane coverage plots.")
        return

    viz_dir = os.path.join(output_dir, "improved_visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # Create scatter plot of cell area vs membrane coverage percentage
    if 'cell_area' in df.columns and 'membrane_coverage_percentage' in df.columns:
        plt.figure(figsize=(12, 10))

        # Create scatter plot with pressure as hue
        scatter = sns.scatterplot(
            x='cell_area',
            y='membrane_coverage_percentage',
            hue='pressure',
            style='pressure',
            s=150,  # Larger point size
            alpha=0.8,
            palette=['#2D68C4', '#F2B950'],  # Blue and gold colors
            data=df
        )

        # Add regression lines for each pressure group
        for pressure in df['pressure'].unique():
            subset = df[df['pressure'] == pressure]
            if len(subset) > 1:  # Need at least 2 points for regression
                try:
                    sns.regplot(
                        x='cell_area',
                        y='membrane_coverage_percentage',
                        data=subset,
                        scatter=False,
                        line_kws={'linestyle': '--', 'linewidth': 2},
                        color='#2D68C4' if pressure == '0Pa' else '#F2B950'
                    )

                    # Calculate correlation and display
                    corr, p = stats.pearsonr(subset['cell_area'], subset['membrane_coverage_percentage'])
                    plt.annotate(
                        f"{pressure}: r={corr:.2f}, p={p:.3f}",
                        xy=(0.05, 0.95 - 0.06 * list(df['pressure'].unique()).index(pressure)),
                        xycoords='axes fraction',
                        ha='left',
                        fontsize=12,
                        bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
                    )
                except Exception as e:
                    print(f"Could not calculate correlation for {pressure}: {e}")

        # Calculate and display group means
        for i, pressure in enumerate(df['pressure'].unique()):
            subset = df[df['pressure'] == pressure]
            x_mean = subset['cell_area'].mean()
            y_mean = subset['membrane_coverage_percentage'].mean()

            plt.scatter(
                x_mean, y_mean,
                s=300,
                color='white',
                edgecolors='#2D68C4' if pressure == '0Pa' else '#F2B950',
                linewidths=2,
                marker='X',
                zorder=10
            )
            plt.annotate(
                f"{pressure} mean",
                (x_mean, y_mean),
                xytext=(10, 10),
                textcoords='offset points',
                fontsize=12,
                bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
            )

        # Improve plot aesthetics
        plt.title('Cell Size vs Membrane Coverage Percentage', fontsize=16)
        plt.xlabel('Cell Area (pixels²)', fontsize=14)
        plt.ylabel('Membrane Coverage (%)', fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.legend(title='Pressure', fontsize=12, title_fontsize=13)

        plt.tight_layout()
        plt.savefig(os.path.join(viz_dir, "cell_size_vs_membrane_coverage.png"), dpi=300, bbox_inches='tight')
        print("Saved: cell_size_vs_membrane_coverage.png")
        plt.close()

    # Create interior vs boundary membrane distribution plot
    if 'interior_coverage_percentage' in df.columns and 'boundary_coverage_percentage' in df.columns:
        plt.figure(figsize=(12, 10))

        # Create scatter plot
        scatter = sns.scatterplot(
            x='interior_coverage_percentage',
            y='boundary_coverage_percentage',
            hue='pressure',
            style='pressure',
            s=150,
            alpha=0.8,
            palette=['#2D68C4', '#F2B950'],
            data=df
        )

        # Add diagonal line
        x_vals = np.array(plt.xlim())
        y_vals = np.array(plt.ylim())
        max_val = max(x_vals.max(), y_vals.max())
        min_val = min(x_vals.min(), y_vals.min())
        plt.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.3)

        # Add explanatory annotations
        plt.annotate('Boundary-dominated', xy=(0.25, 0.75), xycoords='axes fraction', fontsize=14)
        plt.annotate('Interior-dominated', xy=(0.75, 0.25), xycoords='axes fraction', fontsize=14)

        # Calculate and display group centroids
        for pressure in df['pressure'].unique():
            subset = df[df['pressure'] == pressure]
            x_mean = subset['interior_coverage_percentage'].mean()
            y_mean = subset['boundary_coverage_percentage'].mean()

            plt.scatter(
                x_mean, y_mean,
                s=300,
                color='white',
                edgecolors='#2D68C4' if pressure == '0Pa' else '#F2B950',
                linewidths=2,
                marker='X',
                zorder=10
            )
            plt.annotate(
                f"{pressure} centroid",
                (x_mean, y_mean),
                xytext=(10, 10),
                textcoords='offset points',
                fontsize=12,
                bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
            )

        plt.title('Membrane Distribution: Interior vs Boundary Coverage', fontsize=16)
        plt.xlabel('Interior Coverage (%)', fontsize=14)
        plt.ylabel('Boundary Coverage (%)', fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.legend(title='Pressure', fontsize=12, title_fontsize=13)

        plt.tight_layout()
        plt.savefig(os.path.join(viz_dir, "interior_vs_boundary_coverage.png"), dpi=300, bbox_inches='tight')
        print("Saved: interior_vs_boundary_coverage.png")
        plt.close()

def create_membrane_composition_plot(df, output_dir):
    """Creates stacked bar chart showing membrane distribution between boundary and interior."""
    if df is None or df.empty:
        print("No data for membrane composition plot.")
        return

    if 'boundary_membrane_area' not in df.columns or 'interior_membrane_area' not in df.columns:
        print("Missing required columns for membrane composition plot.")
        return

    viz_dir = os.path.join(output_dir, "improved_visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # Calculate membrane distribution percentages if they don't exist
    if 'boundary_membrane_percentage' not in df.columns:
        df['total_membrane'] = df['boundary_membrane_area'] + df['interior_membrane_area']
        df['boundary_membrane_percentage'] = df['boundary_membrane_area'] / df['total_membrane'] * 100
        df['interior_membrane_percentage'] = df['interior_membrane_area'] / df['total_membrane'] * 100

    # Group by pressure and calculate means
    grouped = df.groupby('pressure').agg(
        boundary_percent=('boundary_membrane_percentage', 'mean'),
        interior_percent=('interior_membrane_percentage', 'mean'),
        boundary_std=('boundary_membrane_percentage', 'std'),
        interior_std=('interior_membrane_percentage', 'std'),
        cell_count=('cell_id', 'count')
    ).reset_index()

    # Create stacked bar chart
    plt.figure(figsize=(10, 8))
    bar_width = 0.6

    # Create bars
    bars1 = plt.bar(
        grouped['pressure'],
        grouped['boundary_percent'],
        bar_width,
        yerr=grouped['boundary_std'],
        capsize=5,
        label='Boundary Membrane',
        color='#5975A4',
        alpha=0.8
    )

    bars2 = plt.bar(
        grouped['pressure'],
        grouped['interior_percent'],
        bar_width,
        yerr=grouped['interior_std'],
        capsize=5,
        bottom=grouped['boundary_percent'],
        label='Interior Membrane',
        color='#5F9E6E',
        alpha=0.8
    )

    # Add data labels to bars
    for i, (bar1, bar2) in enumerate(zip(bars1, bars2)):
        height1 = bar1.get_height()
        height2 = bar2.get_height()

        plt.text(
            bar1.get_x() + bar1.get_width()/2.,
            height1/2,
            f'{height1:.1f}%',
            ha='center',
            va='center',
            color='white',
            fontweight='bold'
        )

        plt.text(
            bar2.get_x() + bar2.get_width()/2.,
            height1 + height2/2,
            f'{height2:.1f}%',
            ha='center',
            va='center',
            color='white',
            fontweight='bold'
        )

        # Add cell count
        plt.text(
            bar1.get_x() + bar1.get_width()/2.,
            -5,
            f'n={grouped["cell_count"].iloc[i]}',
            ha='center',
            va='top'
        )

    plt.title('Membrane Distribution Between Boundary and Interior', fontsize=16)
    plt.ylabel('Percentage of Total Membrane (%)', fontsize=14)
    plt.ylim(0, 110)  # Leave room for error bars
    plt.grid(axis='y', alpha=0.3)
    plt.legend(fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(viz_dir, "membrane_composition_by_pressure.png"), dpi=300, bbox_inches='tight')
    print("Saved: membrane_composition_by_pressure.png")
    plt.close()

def create_size_category_analysis(df, output_dir):
    """Groups cells by size categories and analyzes membrane metrics by category."""
    if df is None or df.empty:
        print("No data for size category analysis.")
        return

    if 'cell_area' not in df.columns:
        print("Missing cell_area column for size category analysis.")
        return

    viz_dir = os.path.join(output_dir, "improved_visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # Create size categories based on cell area
    try:
        # Determine appropriate number of categories based on data size
        n_categories = min(4, max(2, len(df) // 3))  # At least 2, at most 4, ideally ~3 cells per category

        # Create size bins
        df['size_category'] = pd.qcut(
            df['cell_area'],
            q=n_categories,
            labels=[f'Size {i+1}' for i in range(n_categories)]
        )

        # Get the actual bin ranges for annotation
        size_bins = pd.qcut(df['cell_area'], q=n_categories)
        bin_ranges = [f"{int(b.left)}-{int(b.right)}" for b in size_bins.categories]

        # Calculate average membrane metrics by size category
        size_stats = df.groupby(['pressure', 'size_category']).agg(
            mean_membrane_coverage=('membrane_coverage_percentage', 'mean'),
            std_membrane_coverage=('membrane_coverage_percentage', 'std'),
            mean_boundary_coverage=('boundary_coverage_percentage', 'mean'),
            std_boundary_coverage=('boundary_coverage_percentage', 'std'),
            mean_interior_coverage=('interior_coverage_percentage', 'mean'),
            std_interior_coverage=('interior_coverage_percentage', 'std'),
            cell_count=('cell_id', 'count')
        ).reset_index()

        # Create bar plot of membrane coverage by size and pressure
        plt.figure(figsize=(14, 8))

        # Set positions for bars
        n_pressures = len(df['pressure'].unique())
        n_sizes = len(df['size_category'].unique())
        bar_width = 0.35

        # Prepare x positions
        x_pos = np.arange(n_sizes)

        # Plot bars for each pressure
        for i, pressure in enumerate(df['pressure'].unique()):
            pressure_data = size_stats[size_stats['pressure'] == pressure]

            offset = (i - (n_pressures-1)/2) * bar_width

            plt.bar(
                x_pos + offset,
                pressure_data['mean_membrane_coverage'],
                bar_width,
                yerr=pressure_data['std_membrane_coverage'],
                capsize=5,
                label=pressure,
                color=['#2D68C4', '#F2B950'][i % 2]
            )

            # Add cell count labels
            for j, count in enumerate(pressure_data['cell_count']):
                plt.text(
                    x_pos[j] + offset,
                    pressure_data['mean_membrane_coverage'].iloc[j] +
                    pressure_data['std_membrane_coverage'].iloc[j] + 1,
                    f'n={count}',
                    ha='center',
                    va='bottom',
                    fontsize=9
                )

        plt.title('Membrane Coverage by Cell Size and Pressure', fontsize=16)
        plt.xlabel('Cell Size Category', fontsize=14)
        plt.ylabel('Membrane Coverage (%)', fontsize=14)
        plt.xticks(x_pos, [f'{cat}\n({bin_ranges[i]})' for i, cat in enumerate(df['size_category'].unique())])
        plt.grid(axis='y', alpha=0.3)
        plt.legend(title='Pressure')

        plt.tight_layout()
        plt.savefig(os.path.join(viz_dir, "membrane_coverage_by_size_category.png"), dpi=300, bbox_inches='tight')
        print("Saved: membrane_coverage_by_size_category.png")
        plt.close()

    except Exception as e:
        print(f"Error in size category analysis: {e}")

def create_membrane_distribution_scatter(df, output_dir):
    """Creates enhanced scatter plot showing membrane distribution patterns."""
    if df is None or df.empty:
        print("No data for membrane distribution scatter.")
        return

    if 'cell_area' not in df.columns or 'boundary_membrane_percentage' not in df.columns:
        print("Missing required columns for membrane distribution scatter.")
        return

    viz_dir = os.path.join(output_dir, "improved_visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # Create scatter plot of cell size vs boundary membrane percentage
    plt.figure(figsize=(12, 10))

    # Create scatter with pressure as hue and size proportional to cell area
    scatter = sns.scatterplot(
        x='cell_area',
        y='boundary_membrane_percentage',
        hue='pressure',
        size='membrane_coverage_percentage' if 'membrane_coverage_percentage' in df.columns else None,
        sizes=(100, 400),
        alpha=0.7,
        palette=['#2D68C4', '#F2B950'],
        data=df
    )

    # Add regression lines
    for pressure in df['pressure'].unique():
        subset = df[df['pressure'] == pressure]
        if len(subset) > 1:
            try:
                sns.regplot(
                    x='cell_area',
                    y='boundary_membrane_percentage',
                    data=subset,
                    scatter=False,
                    line_kws={'linestyle': '--', 'linewidth': 2},
                    color='#2D68C4' if pressure == '0Pa' else '#F2B950'
                )

                # Calculate correlation and display
                corr, p = stats.pearsonr(subset['cell_area'], subset['boundary_membrane_percentage'])
                plt.annotate(
                    f"{pressure}: r={corr:.2f}, p={p:.3f}",
                    xy=(0.05, 0.95 - 0.06 * list(df['pressure'].unique()).index(pressure)),
                    xycoords='axes fraction',
                    ha='left',
                    fontsize=12,
                    bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
                )
            except Exception as e:
                print(f"Could not calculate correlation for {pressure}: {e}")

    plt.title('Cell Size vs Boundary Membrane Distribution', fontsize=16)
    plt.xlabel('Cell Area (pixels²)', fontsize=14)
    plt.ylabel('Boundary Membrane (%)', fontsize=14)
    plt.grid(True, alpha=0.3)

    # Add reference line showing typical boundary membrane percentage
    avg_boundary = df['boundary_membrane_percentage'].mean()
    plt.axhline(avg_boundary, color='gray', linestyle='--', alpha=0.5)
    plt.annotate(
        f'Mean: {avg_boundary:.1f}%',
        xy=(0.95, avg_boundary),
        xycoords=('axes fraction', 'data'),
        ha='right',
        bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
    )

    plt.legend(title='Pressure', fontsize=12, title_fontsize=13)
    plt.tight_layout()
    plt.savefig(os.path.join(viz_dir, "cell_size_vs_boundary_membrane.png"), dpi=300, bbox_inches='tight')
    print("Saved: cell_size_vs_boundary_membrane.png")
    plt.close()

def create_pressure_comparison_plots(df, output_dir):
    """Creates plots comparing key metrics between pressure groups."""
    if df is None or df.empty:
        print("No data for pressure comparison plots.")
        return

    viz_dir = os.path.join(output_dir, "improved_visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # Create box plots comparing key metrics between pressure groups
    metrics_to_compare = [
        ('membrane_coverage_percentage', 'Total Membrane Coverage (%)'),
        ('boundary_coverage_percentage', 'Boundary Membrane Coverage (%)'),
        ('interior_coverage_percentage', 'Interior Membrane Coverage (%)'),
        ('boundary_membrane_percentage', 'Boundary Membrane (% of Total Membrane)'),
        ('interior_membrane_percentage', 'Interior Membrane (% of Total Membrane)')
    ]

    for metric, title in metrics_to_compare:
        if metric not in df.columns:
            continue

        plt.figure(figsize=(10, 8))

        # Create box plot
        box_plot = sns.boxplot(
            x='pressure',
            y=metric,
            data=df,
            palette=['#2D68C4', '#F2B950'],
            width=0.5
        )

        # Add individual data points
        sns.stripplot(
            x='pressure',
            y=metric,
            data=df,
            size=7,
            color='black',
            alpha=0.5,
            jitter=True
        )

        # Add mean values as text
        for i, pressure in enumerate(df['pressure'].unique()):
            subset = df[df['pressure'] == pressure]
            mean_val = subset[metric].mean()
            std_val = subset[metric].std()
            count = len(subset)

            plt.text(
                i,
                df[metric].min() - (df[metric].max() - df[metric].min()) * 0.1,
                f"Mean: {mean_val:.1f}%\nStd: {std_val:.1f}%\nn={count}",
                ha='center',
                va='top',
                fontsize=10,
                bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
            )

        # Run t-test if there are two pressure groups
        if len(df['pressure'].unique()) == 2:
            group1 = df[df['pressure'] == df['pressure'].unique()[0]][metric]
            group2 = df[df['pressure'] == df['pressure'].unique()[1]][metric]

            try:
                t_stat, p_val = stats.ttest_ind(group1, group2, equal_var=False)
                significance = "significant" if p_val < 0.05 else "not significant"

                plt.title(f"{title} by Pressure (p={p_val:.3f}, {significance})", fontsize=16)
            except:
                plt.title(f"{title} by Pressure", fontsize=16)
        else:
            plt.title(f"{title} by Pressure", fontsize=16)

        plt.xlabel('Pressure', fontsize=14)
        plt.ylabel(title, fontsize=14)
        plt.grid(True, axis='y', alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(viz_dir, f"pressure_comparison_{metric}.png"), dpi=300, bbox_inches='tight')
        print(f"Saved: pressure_comparison_{metric}.png")
        plt.close()

def create_ratio_plots(df, output_dir):
    """Creates plots showing membrane distribution ratios."""
    if df is None or df.empty:
        print("No data for ratio plots.")
        return

    viz_dir = os.path.join(output_dir, "improved_visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # Calculate membrane boundary-to-interior ratio if not exists
    if 'membrane_boundary_to_interior_ratio' not in df.columns and 'boundary_membrane_area' in df.columns and 'interior_membrane_area' in df.columns:
        df['membrane_boundary_to_interior_ratio'] = df['boundary_membrane_area'] / df['interior_membrane_area'].replace(0, np.nan)

    if 'membrane_boundary_to_interior_ratio' in df.columns:
        # Create box plot of boundary-to-interior ratio
        plt.figure(figsize=(10, 8))

        # Create box plot
        box_plot = sns.boxplot(
            x='pressure',
            y='membrane_boundary_to_interior_ratio',
            data=df,
            palette=['#2D68C4', '#F2B950'],
            width=0.5
        )

        # Add individual data points
        sns.stripplot(
            x='pressure',
            y='membrane_boundary_to_interior_ratio',
            data=df,
            size=7,
            color='black',
            alpha=0.5,
            jitter=True
        )

        # Add reference line at ratio=1 (equal boundary and interior)
        plt.axhline(1, color='gray', linestyle='--', alpha=0.5)
        plt.annotate(
            'Equal boundary & interior',
            xy=(0.5, 1),
            xytext=(0, 10),
            textcoords='offset points',
            ha='center',
            bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7)
        )

        plt.title('Boundary-to-Interior Membrane Ratio by Pressure', fontsize=16)
        plt.xlabel('Pressure', fontsize=14)
        plt.ylabel('Boundary/Interior Ratio', fontsize=14)
        plt.grid(True, axis='y', alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(viz_dir, "boundary_to_interior_ratio.png"), dpi=300, bbox_inches='tight')
        print("Saved: boundary_to_interior_ratio.png")
        plt.close()

# --- Main Execution Block ---

def main():
    """Main execution function."""
    print("\n======= Starting Cell Membrane Analysis Script =======")

    # 1. Find and organize mask files
    pressure_file_dict = find_mask_files(cell_mask_dir, membrane_dir)

    if pressure_file_dict is None:
        print("\nSCRIPT HALTED: Could not find or access mask files. Please check paths.")
        return

    # 2. Process files for each pressure group
    all_pressure_results = {}
    for pressure, file_list in pressure_file_dict.items():
        if not file_list:
            print(f"\nNo files found for pressure group: {pressure}. Skipping.")
            all_pressure_results[pressure] = {
                'pressure': pressure,
                'total_samples_processed': 0,
                'total_cells': 0,
                'all_cell_metrics': [],
                'cell_metrics_df': pd.DataFrame()
            }
            continue

        pressure_results = process_pressure_group(file_list, pressure, output_dir)
        all_pressure_results[pressure] = pressure_results

    # 3. Create combined dataset and calculate percentage metrics
    combined_df = create_combined_dataframe(all_pressure_results, output_dir)

    if combined_df is None or combined_df.empty:
        print("\nNo data was successfully processed. Skipping visualization.")
        return

    # 4. Create visualization plots
    print("\n--- Creating Enhanced Visualizations ---")

    # Cell size vs membrane coverage plots
    create_membrane_coverage_plots(combined_df, output_dir)

    # Membrane composition plots
    create_membrane_composition_plot(combined_df, output_dir)

    # Size category analysis
    create_size_category_analysis(combined_df, output_dir)

    # Membrane distribution scatter
    create_membrane_distribution_scatter(combined_df, output_dir)

    # Pressure comparison plots
    create_pressure_comparison_plots(combined_df, output_dir)

    # Membrane ratio plots
    create_ratio_plots(combined_df, output_dir)

    print("\n======= Analysis Script Finished =======")
    print(f"Outputs saved in: {output_dir}")
    print("Please check the 'improved_visualizations' subfolder for enhanced plots")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


--- Finding and Pairing Mask Files ---
Found 8 cell mask files and 8 membrane mask files
Total matching cell-membrane file pairs found: 8
  0Pa: 3 pairs
  1.4Pa: 5 pairs

=== Processing 0Pa Group (3 file pairs) ===
Successfully processed: 3 pairs
Failed/Skipped: 0 pairs
Total cells analyzed: 3
Saved 0Pa metrics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation/0Pa_cell_metrics.csv

=== Processing 1.4Pa Group (5 file pairs) ===
Successfully processed: 5 pairs
Failed/Skipped: 0 pairs
Total cells analyzed: 5
Saved 1.4Pa metrics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation/1.4Pa_cell_metrics.csv

--- Creating Combined Analysis Dataset ---
Saved combined metrics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Ce

  corr, p = stats.pearsonr(subset['cell_area'], subset['membrane_coverage_percentage'])
  corr, p = stats.pearsonr(subset['cell_area'], subset['membrane_coverage_percentage'])


Saved: cell_size_vs_membrane_coverage.png
Saved: interior_vs_boundary_coverage.png
Saved: membrane_composition_by_pressure.png
Error in size category analysis: Bin edges must be unique: Index([1048575.0, 1048575.0, 1048575.0], dtype='float64', name='cell_area').
You can drop duplicate edges by setting the 'duplicates' kwarg


  corr, p = stats.pearsonr(subset['cell_area'], subset['boundary_membrane_percentage'])
  corr, p = stats.pearsonr(subset['cell_area'], subset['boundary_membrane_percentage'])


Saved: cell_size_vs_boundary_membrane.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  box_plot = sns.boxplot(


Saved: pressure_comparison_membrane_coverage_percentage.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  box_plot = sns.boxplot(


Saved: pressure_comparison_boundary_coverage_percentage.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  box_plot = sns.boxplot(


Saved: pressure_comparison_interior_coverage_percentage.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  box_plot = sns.boxplot(


Saved: pressure_comparison_boundary_membrane_percentage.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  box_plot = sns.boxplot(


Saved: pressure_comparison_interior_membrane_percentage.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  box_plot = sns.boxplot(


Saved: boundary_to_interior_ratio.png

Outputs saved in: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Membrane_Size_Relation
Please check the 'improved_visualizations' subfolder for enhanced plots


In [26]:
# Cell-Nuclei Linking and Feature Analysis Script
# Links cell masks with corresponding nuclei and extracts relevant features

# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from skimage import io, measure, segmentation
from scipy import ndimage, stats
import re
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Set visualization style
sns.set_theme()
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100

# Define input and output directories
cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
nuclei_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Nuclei_Relation"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# --- Helper Functions ---

def extract_pressure(filename):
    """Extracts '0Pa' or '1.4Pa' from a filename."""
    match = re.search(r'(0Pa|1\.4Pa)', str(filename))
    return match.group(1) if match else None

def extract_sample_id(filename):
    """Extracts a unique sample identifier from a filename to match pairs."""
    match = re.search(r'((?:0Pa|1\.4Pa)_U_[^_]+_20x_[^_]+_[^_]+_seq\d+)', str(filename))
    return match.group(1) if match else None

def find_mask_files(cell_dir, nuclei_dir):
    """Finds and pairs cell and nuclei mask files based on sample ID and pressure."""
    print("\n--- Finding and Pairing Mask Files ---")
    pressure_dict = {'0Pa': [], '1.4Pa': []}

    cell_files = [f for f in os.listdir(cell_dir) if f.endswith(('_cell_mask_merged_conservative.tif', '_cell_mask.tif')) and not f.startswith('.')]
    nuclei_files = [f for f in os.listdir(nuclei_dir) if f.endswith(('_filtered_mask.tif', '.tif')) and not f.startswith('.')]

    print(f"Found {len(cell_files)} cell mask files and {len(nuclei_files)} nuclei mask files")

    # Create lookup dictionary for nuclei files
    nuclei_lookup = {}
    for nuclei_file in nuclei_files:
        sample_id = extract_sample_id(nuclei_file)
        if sample_id:
            nuclei_lookup[sample_id] = nuclei_file

    # Match cell files to nuclei files
    pairs_found = 0
    processed_cell_ids = set()
    for cell_file in cell_files:
        pressure = extract_pressure(cell_file)
        sample_id = extract_sample_id(cell_file)

        if sample_id in processed_cell_ids:
            continue

        if pressure and pressure in pressure_dict and sample_id:
            if sample_id in nuclei_lookup:
                nuclei_file = nuclei_lookup[sample_id]
                file_pair = {
                    'cell_file': os.path.join(cell_dir, cell_file),
                    'nuclei_file': os.path.join(nuclei_dir, nuclei_file),
                    'sample_id': sample_id
                }
                pressure_dict[pressure].append(file_pair)
                pairs_found += 1
                processed_cell_ids.add(sample_id)

    print(f"Total matching cell-nuclei file pairs found: {pairs_found}")
    for pressure, file_list in pressure_dict.items():
        print(f"  {pressure}: {len(file_list)} pairs")

    return pressure_dict

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's binary (0 or 1)."""
    try:
        img = io.imread(filepath)
        # Convert to binary uint8
        if img.dtype == bool:
            img = img.astype(np.uint8)
        elif np.max(img) > 1:
            img = (img > 0).astype(np.uint8)
        else:
            img = img.astype(np.uint8)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]
            img = (img > 0).astype(np.uint8)

        return img
    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def link_cells_to_nuclei(cell_mask, nuclei_mask):
    """Links individual cells to nuclei based on spatial overlap."""
    # Label each mask
    labeled_cells, num_cells = ndimage.label(cell_mask)
    labeled_nuclei, num_nuclei = ndimage.label(nuclei_mask)

    if num_cells == 0 or num_nuclei == 0:
        return [], [], 0, 0

    # Find properties of each labeled object
    cell_props = measure.regionprops(labeled_cells)
    nuclei_props = measure.regionprops(labeled_nuclei)

    # Initialize results
    cell_metrics = []
    cell_nuclei_links = []
    cells_with_nuclei = 0
    multi_nuclei_cells = 0

    # For each cell, find associated nuclei
    for cell in cell_props:
        cell_id = cell.label
        cell_mask_binary = (labeled_cells == cell_id)

        # Find nuclei that overlap with this cell
        overlapping_nuclei = []
        for nucleus in nuclei_props:
            nucleus_id = nucleus.label
            nucleus_mask_binary = (labeled_nuclei == nucleus_id)

            # Check overlap
            overlap = np.logical_and(cell_mask_binary, nucleus_mask_binary)
            overlap_area = np.sum(overlap)

            if overlap_area > 0:
                nucleus_area = nucleus.area
                overlap_ratio = overlap_area / nucleus_area

                # Only count if significant overlap (>50% of nucleus in cell)
                if overlap_ratio > 0.5:
                    overlapping_nuclei.append({
                        'nucleus_id': nucleus_id,
                        'nucleus_area': nucleus_area,
                        'nucleus_centroid': nucleus.centroid,
                        'overlap_area': overlap_area,
                        'overlap_ratio': overlap_ratio
                    })

        # Calculate metrics for this cell
        cell_area = cell.area
        cell_perimeter = cell.perimeter
        equivalent_diameter = cell.equivalent_diameter

        # Calculate shape metrics
        circularity = (4 * np.pi * cell_area) / (cell_perimeter ** 2) if cell_perimeter > 0 else 0

        # Count nuclei in this cell
        nuclei_count = len(overlapping_nuclei)
        has_nucleus = nuclei_count > 0
        has_multiple_nuclei = nuclei_count > 1

        if has_nucleus:
            cells_with_nuclei += 1
            if has_multiple_nuclei:
                multi_nuclei_cells += 1

        # Total nuclei area in this cell
        total_nuclei_area = sum(n['nucleus_area'] for n in overlapping_nuclei)

        # Nuclear-to-cytoplasmic ratio
        nuc_cyto_ratio = total_nuclei_area / (cell_area - total_nuclei_area) if cell_area > total_nuclei_area else 0

        # Distance to nucleus (use closest if multiple)
        distance_to_nucleus = None
        if nuclei_count > 0:
            cell_centroid = np.array(cell.centroid)
            distances = [np.linalg.norm(cell_centroid - np.array(n['nucleus_centroid'])) for n in overlapping_nuclei]
            distance_to_nucleus = min(distances)

        metrics = {
            'cell_id': cell_id,
            'cell_area': cell_area,
            'cell_perimeter': cell_perimeter,
            'equivalent_diameter': equivalent_diameter,
            'circularity': circularity,
            'nuclei_count': nuclei_count,
            'has_nucleus': has_nucleus,
            'has_multiple_nuclei': has_multiple_nuclei,
            'total_nuclei_area': total_nuclei_area if has_nucleus else 0,
            'nuclei_to_cell_area_ratio': total_nuclei_area / cell_area if cell_area > 0 and has_nucleus else 0,
            'nuclear_cytoplasmic_ratio': nuc_cyto_ratio,
            'distance_to_nucleus': distance_to_nucleus
        }

        cell_metrics.append(metrics)

        # Record links
        for nucleus in overlapping_nuclei:
            link = {
                'cell_id': cell_id,
                'nucleus_id': nucleus['nucleus_id'],
                'overlap_area': nucleus['overlap_area'],
                'overlap_ratio': nucleus['overlap_ratio']
            }
            cell_nuclei_links.append(link)

    return cell_metrics, cell_nuclei_links, cells_with_nuclei, multi_nuclei_cells

def process_file_pair(file_pair, pressure):
    """Loads masks for a pair, handles resizing, runs analysis, returns metrics."""
    sample_id = file_pair['sample_id']

    # Load masks
    cell_mask = load_mask_image(file_pair['cell_file'])
    nuclei_mask = load_mask_image(file_pair['nuclei_file'])

    if cell_mask is None or nuclei_mask is None:
        return None

    # Ensure masks have same dimensions - Resize nuclei mask if necessary
    if cell_mask.shape != nuclei_mask.shape:
        try:
            nuclei_mask = cv2.resize(nuclei_mask, (cell_mask.shape[1], cell_mask.shape[0]),
                                     interpolation=cv2.INTER_NEAREST)
        except Exception as e:
            print(f"Error resizing nuclei mask for {sample_id}: {e}")
            return None

    # Link cells to nuclei and analyze relationship
    cell_metrics, cell_nuclei_links, cells_with_nuclei, multi_nuclei_cells = link_cells_to_nuclei(cell_mask, nuclei_mask)

    # Add sample information to metrics
    for metrics in cell_metrics:
        metrics['sample_id'] = sample_id
        metrics['pressure'] = pressure

    for link in cell_nuclei_links:
        link['sample_id'] = sample_id
        link['pressure'] = pressure

    return {
        'cell_metrics': cell_metrics,
        'cell_nuclei_links': cell_nuclei_links,
        'cells_with_nuclei': cells_with_nuclei,
        'multi_nuclei_cells': multi_nuclei_cells,
        'total_cells': len(cell_metrics)
    }

def process_pressure_group(file_pairs, pressure, output_dir):
    """Processes all file pairs for a specific pressure group."""
    print(f"\n=== Processing {pressure} Group ({len(file_pairs)} file pairs) ===")

    all_cell_metrics = []
    all_cell_nuclei_links = []
    total_cells = 0
    total_cells_with_nuclei = 0
    total_multi_nuclei_cells = 0

    processed_count = 0
    failed_count = 0

    for file_pair in file_pairs:
        results = process_file_pair(file_pair, pressure)
        if results is not None:
            all_cell_metrics.extend(results['cell_metrics'])
            all_cell_nuclei_links.extend(results['cell_nuclei_links'])
            total_cells += results['total_cells']
            total_cells_with_nuclei += results['cells_with_nuclei']
            total_multi_nuclei_cells += results['multi_nuclei_cells']
            processed_count += 1
        else:
            failed_count += 1

    print(f"Successfully processed: {processed_count} pairs")
    print(f"Failed/Skipped: {failed_count} pairs")
    print(f"Total cells analyzed: {total_cells}")
    print(f"Cells with nuclei: {total_cells_with_nuclei} ({100*total_cells_with_nuclei/total_cells:.1f}% of cells)")
    print(f"Cells with multiple nuclei: {total_multi_nuclei_cells} ({100*total_multi_nuclei_cells/total_cells:.1f}% of cells)")

    # Create and save dataframes for this pressure group
    if all_cell_metrics:
        cell_metrics_df = pd.DataFrame(all_cell_metrics)
        group_csv_path = os.path.join(output_dir, f"{pressure}_cell_metrics.csv")
        cell_metrics_df.to_csv(group_csv_path, index=False)
        print(f"Saved {pressure} cell metrics to {group_csv_path}")
    else:
        cell_metrics_df = pd.DataFrame()

    if all_cell_nuclei_links:
        links_df = pd.DataFrame(all_cell_nuclei_links)
        links_csv_path = os.path.join(output_dir, f"{pressure}_cell_nuclei_links.csv")
        links_df.to_csv(links_csv_path, index=False)
        print(f"Saved {pressure} cell-nuclei links to {links_csv_path}")
    else:
        links_df = pd.DataFrame()

    return {
        'pressure': pressure,
        'total_samples_processed': processed_count,
        'total_cells': total_cells,
        'cells_with_nuclei': total_cells_with_nuclei,
        'multi_nuclei_cells': total_multi_nuclei_cells,
        'cell_metrics_df': cell_metrics_df,
        'links_df': links_df
    }

def create_combined_dataframe(pressure_results, output_dir):
    """Creates a combined dataframe with all results and saves it."""
    print("\n--- Creating Combined Dataset ---")

    # Combine all cell metrics
    all_dfs = []
    for pressure, results in pressure_results.items():
        if not results['cell_metrics_df'].empty:
            all_dfs.append(results['cell_metrics_df'])

    if not all_dfs:
        print("No data available to create combined dataframe.")
        return None

    combined_df = pd.concat(all_dfs, ignore_index=True)

    # Save combined dataset
    combined_csv_path = os.path.join(output_dir, "combined_cell_nuclei_metrics.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"Saved combined metrics to {combined_csv_path}")

    return combined_df

def create_visualizations(combined_df, pressure_results, output_dir):
    """Create comprehensive visualizations of the cell-nuclei analysis."""
    print("\n--- Creating Visualizations ---")

    viz_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    if combined_df is None or combined_df.empty:
        print("No data available for visualization.")
        return

    # Safely create each visualization with error handling
    try:
        # 1. Cell Size vs. Nuclei Count
        plt.figure(figsize=(12, 8))
        # Use countplot if we only have small number of cells
        if len(combined_df) <= 8:
            ax = sns.stripplot(x='nuclei_count', y='cell_area', data=combined_df, jitter=True)
        else:
            # Use hue parameter instead of palette directly
            ax = sns.boxplot(x='nuclei_count', y='cell_area', hue='nuclei_count', data=combined_df,
                           palette='viridis', legend=False)
        ax.set_title('Cell Size vs. Number of Nuclei', fontsize=16)
        ax.set_xlabel('Number of Nuclei', fontsize=14)
        ax.set_ylabel('Cell Area (pixels)', fontsize=14)
        plt.savefig(os.path.join(viz_dir, 'cell_size_vs_nuclei_count.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("Created cell size vs nuclei count plot")
    except Exception as e:
        print(f"Error creating cell size vs nuclei count plot: {str(e)}")

    try:
        # 2. Nuclei-to-Cell Area Ratio Distribution
        plt.figure(figsize=(12, 8))
        for pressure in ['0Pa', '1.4Pa']:
            pressure_data = combined_df[combined_df['pressure'] == pressure]
            if not pressure_data.empty:
                sns.histplot(data=pressure_data, x='nuclei_to_cell_area_ratio',
                             kde=True, alpha=0.6, label=pressure)
        plt.title('Distribution of Nuclei-to-Cell Area Ratio by Pressure', fontsize=16)
        plt.xlabel('Nuclei to Cell Area Ratio', fontsize=14)
        plt.ylabel('Frequency', fontsize=14)
        plt.legend()
        plt.savefig(os.path.join(viz_dir, 'nuclei_cell_ratio_distribution.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("Created nuclei-to-cell ratio distribution plot")
    except Exception as e:
        print(f"Error creating nuclei-to-cell ratio distribution plot: {str(e)}")

    try:
        # 3. Cell-Nuclei Count Distribution by Pressure
        nuclei_count_data = []
        for pressure, results in pressure_results.items():
            if 'cell_metrics_df' in results and not results['cell_metrics_df'].empty:
                df = results['cell_metrics_df']

                total = len(df)
                if total > 0:  # Prevent division by zero
                    no_nuclei = sum(df['nuclei_count'] == 0)
                    one_nucleus = sum(df['nuclei_count'] == 1)
                    multi_nuclei = sum(df['nuclei_count'] > 1)

                    nuclei_count_data.append({
                        'pressure': pressure,
                        'category': 'No Nucleus',
                        'percentage': 100 * no_nuclei / total if total > 0 else 0,
                        'count': no_nuclei
                    })
                    nuclei_count_data.append({
                        'pressure': pressure,
                        'category': 'One Nucleus',
                        'percentage': 100 * one_nucleus / total if total > 0 else 0,
                        'count': one_nucleus
                    })
                    nuclei_count_data.append({
                        'pressure': pressure,
                        'category': 'Multiple Nuclei',
                        'percentage': 100 * multi_nuclei / total if total > 0 else 0,
                        'count': multi_nuclei
                    })

        if nuclei_count_data:
            nuclei_count_df = pd.DataFrame(nuclei_count_data)

            # Create bar plot
            plt.figure(figsize=(14, 8))
            ax = sns.barplot(x='pressure', y='percentage', hue='category', data=nuclei_count_df)

            # Add count labels safely
            for i, p in enumerate(ax.patches):
                if i < len(nuclei_count_df):  # Safety check
                    height = p.get_height()
                    count = nuclei_count_df.iloc[i]['count']
                    ax.text(p.get_x() + p.get_width()/2., height + 1, f'n={count}',
                            ha="center", fontsize=9)

            ax.set_title('Distribution of Cells by Nuclei Count and Pressure', fontsize=16)
            ax.set_xlabel('Pressure', fontsize=14)
            ax.set_ylabel('Percentage of Cells (%)', fontsize=14)
            plt.legend(title='Category')
            plt.savefig(os.path.join(viz_dir, 'cell_nuclei_count_distribution.png'), dpi=300, bbox_inches='tight')
            plt.close()
            print("Created nuclei count distribution plot")
    except Exception as e:
        print(f"Error creating nuclei count distribution plot: {str(e)}")

    try:
        # 4. Nuclear-Cytoplasmic Ratio vs. Cell Size
        plt.figure(figsize=(12, 8))

        # Only use cells that have nuclei
        valid_cells = combined_df[combined_df['has_nucleus']]

        if not valid_cells.empty:
            # Create scatter plot instead if few data points
            if len(valid_cells) <= 10:
                for pressure in ['0Pa', '1.4Pa']:
                    pressure_data = valid_cells[valid_cells['pressure'] == pressure]
                    if not pressure_data.empty:
                        plt.scatter(pressure_data['cell_area'], pressure_data['nuclear_cytoplasmic_ratio'],
                                  label=pressure, alpha=0.7)
            else:
                # Create scatterplot with regression line by pressure
                for pressure in ['0Pa', '1.4Pa']:
                    pressure_data = valid_cells[valid_cells['pressure'] == pressure]
                    if not pressure_data.empty:
                        sns.regplot(x='cell_area', y='nuclear_cytoplasmic_ratio',
                                  data=pressure_data, scatter_kws={'alpha':0.5},
                                  line_kws={'label': f"{pressure} Trend"}, label=pressure)

            plt.title('Nuclear-Cytoplasmic Ratio vs. Cell Size', fontsize=16)
            plt.xlabel('Cell Area (pixels)', fontsize=14)
            plt.ylabel('Nuclear-Cytoplasmic Ratio', fontsize=14)
            plt.legend()
            plt.savefig(os.path.join(viz_dir, 'nuclear_cytoplasmic_vs_size.png'), dpi=300, bbox_inches='tight')
            print("Created nuclear-cytoplasmic ratio plot")
        plt.close()
    except Exception as e:
        print(f"Error creating nuclear-cytoplasmic ratio plot: {str(e)}")

    try:
        # 5. Distance to Nucleus vs. Cell Size
        plt.figure(figsize=(12, 8))
        valid_cells = combined_df[combined_df['has_nucleus']]

        if not valid_cells.empty:
            # Simple scatter for few points
            if len(valid_cells) <= 10:
                for pressure in ['0Pa', '1.4Pa']:
                    pressure_data = valid_cells[valid_cells['pressure'] == pressure]
                    if not pressure_data.empty:
                        plt.scatter(pressure_data['cell_area'], pressure_data['distance_to_nucleus'],
                                  label=pressure, alpha=0.7)
            else:
                for pressure in ['0Pa', '1.4Pa']:
                    pressure_data = valid_cells[valid_cells['pressure'] == pressure]
                    if not pressure_data.empty:
                        sns.regplot(x='cell_area', y='distance_to_nucleus',
                                  data=pressure_data, scatter_kws={'alpha':0.5},
                                  line_kws={'label': f"{pressure} Trend"}, label=pressure)

            plt.title('Distance to Nucleus vs. Cell Size', fontsize=16)
            plt.xlabel('Cell Area (pixels)', fontsize=14)
            plt.ylabel('Distance to Nucleus (pixels)', fontsize=14)
            plt.legend()
            plt.savefig(os.path.join(viz_dir, 'distance_to_nucleus_vs_size.png'), dpi=300, bbox_inches='tight')
            print("Created distance to nucleus plot")
        plt.close()
    except Exception as e:
        print(f"Error creating distance to nucleus plot: {str(e)}")

    try:
        # 6. Cell Shape (Circularity) vs Nuclei Count
        plt.figure(figsize=(12, 8))

        # Determine actual nuclei count values present in the data
        nuclei_counts = sorted(combined_df['nuclei_count'].unique().tolist())

        if nuclei_counts:  # Only proceed if we have data
            # Use stripplot for few data points
            if len(combined_df) <= 8:
                ax = sns.stripplot(x='nuclei_count', y='circularity', data=combined_df, jitter=True)
            else:
                ax = sns.boxplot(x='nuclei_count', y='circularity', data=combined_df,
                               hue='nuclei_count', palette='viridis', legend=False)

            ax.set_title('Cell Circularity vs. Number of Nuclei', fontsize=16)
            ax.set_xlabel('Number of Nuclei', fontsize=14)
            ax.set_ylabel('Circularity (4π·Area/Perimeter²)', fontsize=14)
            plt.savefig(os.path.join(viz_dir, 'circularity_vs_nuclei.png'), dpi=300, bbox_inches='tight')
            print("Created circularity vs nuclei count plot")
        plt.close()
    except Exception as e:
        print(f"Error creating circularity vs nuclei count plot: {str(e)}")

    print(f"Saved visualizations to {viz_dir}")

# --- Main Execution Block ---

def main():
    """Main execution function."""
    print("\n======= Starting Cell-Nuclei Linking and Analysis Script =======")

    # 1. Find and organize mask files
    pressure_file_dict = find_mask_files(cell_mask_dir, nuclei_dir)

    if not pressure_file_dict:
        print("\nSCRIPT HALTED: Could not find or access mask files. Please check paths.")
        return

    # 2. Process files for each pressure group
    all_pressure_results = {}
    for pressure, file_list in pressure_file_dict.items():
        if not file_list:
            print(f"\nNo files found for pressure group: {pressure}. Skipping.")
            continue

        pressure_results = process_pressure_group(file_list, pressure, output_dir)
        all_pressure_results[pressure] = pressure_results

    # 3. Create combined dataset
    combined_df = create_combined_dataframe(all_pressure_results, output_dir)

    # 4. Create visualization plots
    create_visualizations(combined_df, all_pressure_results, output_dir)

    print("\n======= Analysis Script Finished =======")
    print(f"Outputs saved in: {output_dir}")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


--- Finding and Pairing Mask Files ---
Found 8 cell mask files and 8 nuclei mask files
Total matching cell-nuclei file pairs found: 8
  0Pa: 3 pairs
  1.4Pa: 5 pairs

=== Processing 0Pa Group (3 file pairs) ===
Successfully processed: 3 pairs
Failed/Skipped: 0 pairs
Total cells analyzed: 3
Cells with nuclei: 3 (100.0% of cells)
Cells with multiple nuclei: 3 (100.0% of cells)
Saved 0Pa cell metrics to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Nuclei_Relation/0Pa_cell_metrics.csv
Saved 0Pa cell-nuclei links to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Nuclei_Relation/0Pa_cell_nuclei_links.csv

=== Processing 1.4Pa Group (5 file pairs) ===
Successfully processed: 5 pairs
Failed/Skipped: 0 pairs
Total cells analyzed: 5
Cells with nuclei: 5 (100.0% of cells)
Cells with multiple nucle

In [27]:
# Improved Cell-Nuclei Tracking and Size Analysis
# This script properly identifies nuclei inside cell masks and analyzes the relationship
# between cell size and nuclei count across pressure conditions

# Import libraries
from google.colab import drive
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from skimage import io, measure, segmentation
from scipy import ndimage
import cv2
from pathlib import Path

# Mount Google Drive
drive.mount('/content/drive')

# Set visualization style
sns.set_theme(style="whitegrid")
plt.rcParams['figure.figsize'] = (12, 10)
plt.rcParams['figure.dpi'] = 100

# Define input and output directories
cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
nuclei_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Nuclei_Size_Analysis"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# --- Helper Functions ---

def extract_pressure(filename):
    """Extracts '0Pa' or '1.4Pa' from a filename."""
    match = re.search(r'(0Pa|1\.4Pa)', str(filename))
    return match.group(1) if match else None

def extract_sample_id(filename):
    """Extracts a unique sample identifier from a filename to match pairs."""
    match = re.search(r'((?:0Pa|1\.4Pa)_U_[^_]+_20x_[^_]+_[^_]+_seq\d+)', str(filename))
    return match.group(1) if match else None

def find_mask_files(cell_dir, nuclei_dir):
    """Finds and pairs cell and nuclei mask files based on sample ID and pressure."""
    print("\n--- Finding and Pairing Mask Files ---")
    pressure_dict = {'0Pa': [], '1.4Pa': []}

    cell_files = [f for f in os.listdir(cell_dir) if f.endswith(('_cell_mask_merged_conservative.tif', '_cell_mask.tif')) and not f.startswith('.')]
    nuclei_files = [f for f in os.listdir(nuclei_dir) if f.endswith(('_filtered_mask.tif', '.tif')) and not f.startswith('.')]

    print(f"Found {len(cell_files)} cell mask files and {len(nuclei_files)} nuclei mask files")

    # Create lookup dictionary for nuclei files
    nuclei_lookup = {}
    for nuclei_file in nuclei_files:
        sample_id = extract_sample_id(nuclei_file)
        if sample_id:
            nuclei_lookup[sample_id] = nuclei_file

    # Match cell files to nuclei files
    pairs_found = 0
    processed_cell_ids = set()
    for cell_file in cell_files:
        pressure = extract_pressure(cell_file)
        sample_id = extract_sample_id(cell_file)

        if sample_id in processed_cell_ids:
            continue

        if pressure and pressure in pressure_dict and sample_id:
            if sample_id in nuclei_lookup:
                nuclei_file = nuclei_lookup[sample_id]
                file_pair = {
                    'cell_file': os.path.join(cell_dir, cell_file),
                    'nuclei_file': os.path.join(nuclei_dir, nuclei_file),
                    'sample_id': sample_id
                }
                pressure_dict[pressure].append(file_pair)
                pairs_found += 1
                processed_cell_ids.add(sample_id)

    print(f"Total matching cell-nuclei file pairs found: {pairs_found}")
    for pressure, file_list in pressure_dict.items():
        print(f"  {pressure}: {len(file_list)} pairs")

    return pressure_dict

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's binary (0 or 1)."""
    try:
        img = io.imread(filepath)
        # Convert to binary uint8
        if img.dtype == bool:
            img = img.astype(np.uint8)
        elif np.max(img) > 1:
            # If it's a labeled mask, keep the labels
            if np.max(img) <= 255:
                img = img.astype(np.uint8)
            else:
                img = img.astype(np.uint16)
        else:
            img = img.astype(np.uint8)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]
            # If it's a labeled mask with multiple channels, threshold to get binary
            img = (img > 0).astype(np.uint8)

        return img
    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def accurately_track_nuclei_in_cells(cell_mask, nuclei_mask):
    """
    Accurately identifies which nuclei are inside which cells.

    Algorithm:
    1. Label individual cells and nuclei
    2. For each nucleus, find its centroid
    3. Check which cell contains this centroid
    4. If >50% of nucleus area is inside a cell, assign it to that cell
    """
    # Ensure masks are properly labeled
    # For cell_mask, each cell should have a unique ID
    if np.max(cell_mask) <= 1:
        labeled_cells, num_cells = ndimage.label(cell_mask)
    else:
        labeled_cells = cell_mask
        num_cells = np.max(labeled_cells)

    # For nuclei_mask, each nucleus should have a unique ID
    if np.max(nuclei_mask) <= 1:
        labeled_nuclei, num_nuclei = ndimage.label(nuclei_mask)
    else:
        labeled_nuclei = nuclei_mask
        num_nuclei = np.max(labeled_nuclei)

    print(f"Found {num_cells} cells and {num_nuclei} nuclei")

    # Extract properties for cells and nuclei
    cell_props = measure.regionprops(labeled_cells)
    nuclei_props = measure.regionprops(labeled_nuclei)

    # Create a results structure
    results = {
        'cell_data': [],
        'nuclei_data': [],
        'cell_nuclei_mapping': {}
    }

    # Process each cell
    for cell in cell_props:
        cell_id = cell.label
        cell_mask_binary = (labeled_cells == cell_id)

        # Basic cell properties
        cell_data = {
            'cell_id': cell_id,
            'area': cell.area,
            'perimeter': cell.perimeter,
            'eccentricity': cell.eccentricity,
            'orientation': np.degrees(cell.orientation) % 180 if hasattr(cell, 'orientation') else None,
            'major_axis_length': cell.major_axis_length if hasattr(cell, 'major_axis_length') else None,
            'minor_axis_length': cell.minor_axis_length if hasattr(cell, 'minor_axis_length') else None,
            'centroid_y': cell.centroid[0],
            'centroid_x': cell.centroid[1],
            'nuclei_count': 0
        }

        results['cell_data'].append(cell_data)
        results['cell_nuclei_mapping'][cell_id] = []

    # Find which nuclei belong to which cells
    for nucleus in nuclei_props:
        nucleus_id = nucleus.label
        nucleus_mask_binary = (labeled_nuclei == nucleus_id)
        nucleus_area = nucleus.area

        # Find which cell contains this nucleus
        contained_in_cell = None
        max_overlap_ratio = 0

        for cell in cell_props:
            cell_id = cell.label
            cell_mask_binary = (labeled_cells == cell_id)

            # Calculate overlap
            overlap = np.logical_and(cell_mask_binary, nucleus_mask_binary)
            overlap_area = np.sum(overlap)

            # Calculate what percentage of the nucleus is in this cell
            overlap_ratio = overlap_area / nucleus_area

            # If most of the nucleus is in this cell, assign it to this cell
            if overlap_ratio > max_overlap_ratio:
                max_overlap_ratio = overlap_ratio
                contained_in_cell = cell_id

        # Only count the nucleus if a significant portion is inside the cell (>50%)
        if contained_in_cell is not None and max_overlap_ratio > 0.5:
            # Store nucleus data
            nucleus_data = {
                'nucleus_id': nucleus_id,
                'cell_id': contained_in_cell,
                'area': nucleus.area,
                'eccentricity': nucleus.eccentricity if hasattr(nucleus, 'eccentricity') else None,
                'centroid_y': nucleus.centroid[0],
                'centroid_x': nucleus.centroid[1],
                'overlap_ratio': max_overlap_ratio
            }

            results['nuclei_data'].append(nucleus_data)

            # Update cell's nuclei count
            for cell_data in results['cell_data']:
                if cell_data['cell_id'] == contained_in_cell:
                    cell_data['nuclei_count'] += 1
                    break

            # Update cell-nuclei mapping
            results['cell_nuclei_mapping'][contained_in_cell].append(nucleus_id)

    # Count nuclei per cell
    cells_with_nuclei = sum(1 for cell_data in results['cell_data'] if cell_data['nuclei_count'] > 0)
    cells_with_multiple_nuclei = sum(1 for cell_data in results['cell_data'] if cell_data['nuclei_count'] > 1)

    print(f"Cells with nuclei: {cells_with_nuclei}/{num_cells} ({100*cells_with_nuclei/num_cells:.1f}% of cells)")
    print(f"Cells with multiple nuclei: {cells_with_multiple_nuclei}/{num_cells} ({100*cells_with_multiple_nuclei/num_cells:.1f}% of cells)")

    # Create summary of nuclei per cell
    nuclei_counts = [cell_data['nuclei_count'] for cell_data in results['cell_data']]
    unique_counts = sorted(set(nuclei_counts))
    for count in unique_counts:
        cells_with_count = sum(1 for n in nuclei_counts if n == count)
        print(f"  Cells with {count} nuclei: {cells_with_count} ({100*cells_with_count/num_cells:.1f}%)")

    return results

def process_sample(file_pair, pressure):
    """Processes a single sample (cell-nuclei image pair)."""
    sample_id = file_pair['sample_id']
    print(f"\nProcessing sample: {sample_id}")

    # Load masks
    cell_mask = load_mask_image(file_pair['cell_file'])
    nuclei_mask = load_mask_image(file_pair['nuclei_file'])

    if cell_mask is None or nuclei_mask is None:
        print(f"Error loading masks for {sample_id}")
        return None

    # Ensure masks have same dimensions - Resize nuclei mask if necessary
    if cell_mask.shape != nuclei_mask.shape:
        try:
            nuclei_mask = cv2.resize(nuclei_mask, (cell_mask.shape[1], cell_mask.shape[0]),
                                     interpolation=cv2.INTER_NEAREST)
        except Exception as e:
            print(f"Error resizing nuclei mask for {sample_id}: {e}")
            return None

    # Track nuclei inside cells
    results = accurately_track_nuclei_in_cells(cell_mask, nuclei_mask)

    # Add sample info to results
    for cell_data in results['cell_data']:
        cell_data['sample_id'] = sample_id
        cell_data['pressure'] = pressure

    for nucleus_data in results['nuclei_data']:
        nucleus_data['sample_id'] = sample_id
        nucleus_data['pressure'] = pressure

    # Create visualization of the cell-nuclei relationship
    output_sample_dir = os.path.join(output_dir, f"{pressure}_{sample_id}")
    os.makedirs(output_sample_dir, exist_ok=True)

    # Create visualization
    try:
        visualize_cell_nuclei_mapping(cell_mask, nuclei_mask, results,
                                      os.path.join(output_sample_dir, f"{sample_id}_cell_nuclei_mapping.png"))
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")

    return results

def visualize_cell_nuclei_mapping(cell_mask, nuclei_mask, results, output_file=None):
    """Creates a visualization showing which nuclei belong to which cells."""
    plt.figure(figsize=(14, 12))

    # Create a color overlay
    # Background: Black
    # Cells without nuclei: Gray
    # Cells with 1 nucleus: Blue
    # Cells with 2 nuclei: Green
    # Cells with 3+ nuclei: Red

    # Base layer: cell outlines
    cell_outlines = segmentation.find_boundaries(cell_mask)

    # Create an RGB image
    h, w = cell_mask.shape
    rgb_img = np.zeros((h, w, 3), dtype=np.uint8)

    # First, fill all cells with gray
    for cell_data in results['cell_data']:
        cell_id = cell_data['cell_id']
        cell_binary = (cell_mask == cell_id)

        # Choose color based on number of nuclei
        if cell_data['nuclei_count'] == 0:
            color = [130, 130, 130]  # Gray for no nuclei
        elif cell_data['nuclei_count'] == 1:
            color = [100, 100, 255]  # Blue for 1 nucleus
        elif cell_data['nuclei_count'] == 2:
            color = [100, 255, 100]  # Green for 2 nuclei
        else:
            color = [255, 100, 100]  # Red for 3+ nuclei

        # Fill the cell with the appropriate color
        for c in range(3):
            rgb_img[:,:,c] = np.where(cell_binary, color[c], rgb_img[:,:,c])

    # Add cell outlines in white
    for c in range(3):
        rgb_img[:,:,c] = np.where(cell_outlines, 255, rgb_img[:,:,c])

    # Overlay nuclei in yellow
    if nuclei_mask is not None:
        nuclei_outlines = segmentation.find_boundaries(nuclei_mask)
        for c in range(3):
            if c < 2:  # Red and Green channels (for yellow)
                rgb_img[:,:,c] = np.where(nuclei_outlines, 255, rgb_img[:,:,c])
            else:  # Blue channel
                rgb_img[:,:,c] = np.where(nuclei_outlines, 0, rgb_img[:,:,c])

    # Add cell IDs and nuclei counts as text
    plt.imshow(rgb_img)

    for cell_data in results['cell_data']:
        y, x = cell_data['centroid_y'], cell_data['centroid_x']
        nuclei_count = cell_data['nuclei_count']
        plt.text(x, y, f"{nuclei_count}",
                 color='white', fontsize=9, ha='center', va='center',
                 bbox=dict(boxstyle="circle", fc="black", ec="white", alpha=0.6))

    plt.title(f'Cell-Nuclei Mapping (colors: gray=0, blue=1, green=2, red=3+ nuclei)', fontsize=12)
    plt.axis('off')

    if output_file:
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved visualization to {output_file}")

    plt.close()

def process_pressure_group(file_pairs, pressure):
    """Processes all samples for a specific pressure group."""
    print(f"\n=== Processing {pressure} Group ({len(file_pairs)} file pairs) ===")

    all_cell_data = []
    all_nuclei_data = []

    for file_pair in file_pairs:
        results = process_sample(file_pair, pressure)
        if results:
            all_cell_data.extend(results['cell_data'])
            all_nuclei_data.extend(results['nuclei_data'])

    # Create summary DataFrames
    if all_cell_data:
        df_cells = pd.DataFrame(all_cell_data)
        df_cells.to_csv(os.path.join(output_dir, f"{pressure}_cell_data.csv"), index=False)
        print(f"Saved {pressure} cell data with {len(df_cells)} cells")
    else:
        df_cells = pd.DataFrame()

    if all_nuclei_data:
        df_nuclei = pd.DataFrame(all_nuclei_data)
        df_nuclei.to_csv(os.path.join(output_dir, f"{pressure}_nuclei_data.csv"), index=False)
        print(f"Saved {pressure} nuclei data with {len(df_nuclei)} nuclei")
    else:
        df_nuclei = pd.DataFrame()

    return {
        'pressure': pressure,
        'cell_data': df_cells,
        'nuclei_data': df_nuclei
    }

def analyze_cell_size_vs_nuclei_count(all_results):
    """Analyzes the relationship between cell size and nuclei count."""
    print("\n--- Analyzing Cell Size vs. Nuclei Count ---")

    # Combine all cell data
    all_cells = []
    for pressure, results in all_results.items():
        if not results['cell_data'].empty:
            all_cells.append(results['cell_data'])

    if not all_cells:
        print("No cell data available for analysis")
        return

    combined_df = pd.concat(all_cells, ignore_index=True)
    combined_df.to_csv(os.path.join(output_dir, "combined_cell_data.csv"), index=False)
    print(f"Combined data from {len(combined_df)} cells")

    # Create visualization directory
    viz_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    # 1. Distribution of nuclei per cell
    plt.figure(figsize=(12, 8))

    # Get nuclei count distribution by pressure
    nuclei_count_data = []

    for pressure in ['0Pa', '1.4Pa']:
        pressure_data = combined_df[combined_df['pressure'] == pressure]

        if not pressure_data.empty:
            # Group by nuclei count
            nuclei_counts = pressure_data['nuclei_count'].value_counts().sort_index()

            for count, num_cells in nuclei_counts.items():
                nuclei_count_data.append({
                    'pressure': pressure,
                    'nuclei_count': count,
                    'num_cells': num_cells,
                    'percentage': 100 * num_cells / len(pressure_data)
                })

    if nuclei_count_data:
        nuclei_count_df = pd.DataFrame(nuclei_count_data)

        # Create a grouped bar chart
        plt.figure(figsize=(14, 8))
        ax = sns.barplot(x='nuclei_count', y='percentage', hue='pressure', data=nuclei_count_df)

        # Add count annotations
        for i, p in enumerate(ax.patches):
            if i < len(nuclei_count_df):
                height = p.get_height()
                count = nuclei_count_df.iloc[i]['num_cells']
                ax.text(p.get_x() + p.get_width()/2., height + 1, f'n={count}',
                        ha="center", fontsize=9)

        plt.title('Distribution of Nuclei Count per Cell by Pressure', fontsize=16)
        plt.xlabel('Number of Nuclei per Cell', fontsize=14)
        plt.ylabel('Percentage of Cells (%)', fontsize=14)
        plt.savefig(os.path.join(viz_dir, 'nuclei_count_distribution.png'), dpi=300, bbox_inches='tight')
        plt.close()

    # 2. Cell Size vs. Nuclei Count
    plt.figure(figsize=(14, 10))

    # Calculate and show average cell size for each nuclei count
    plt.figure(figsize=(14, 10))

    # Use boxplot to show distribution of cell sizes by nuclei count
    # Split by pressure
    ax = sns.boxplot(x='nuclei_count', y='area', hue='pressure', data=combined_df)

    plt.title('Cell Size vs. Number of Nuclei', fontsize=16)
    plt.xlabel('Number of Nuclei per Cell', fontsize=14)
    plt.ylabel('Cell Area (pixels)', fontsize=14)
    plt.savefig(os.path.join(viz_dir, 'cell_size_vs_nuclei_count_boxplot.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Statistical Analysis
    # Calculate average cell size by nuclei count and pressure
    avg_size_by_count = combined_df.groupby(['pressure', 'nuclei_count'])['area'].agg(['mean', 'std', 'count']).reset_index()
    avg_size_by_count.columns = ['pressure', 'nuclei_count', 'mean_area', 'std_area', 'cell_count']
    avg_size_by_count.to_csv(os.path.join(output_dir, "avg_cell_size_by_nuclei_count.csv"), index=False)

    print("\nAverage Cell Size by Nuclei Count:")
    for pressure in ['0Pa', '1.4Pa']:
        pressure_data = avg_size_by_count[avg_size_by_count['pressure'] == pressure]
        if not pressure_data.empty:
            print(f"\n{pressure}:")
            for _, row in pressure_data.iterrows():
                print(f"  {row['nuclei_count']} nuclei: {row['mean_area']:.1f} ± {row['std_area']:.1f} pixels (n={row['cell_count']})")

    # 4. Visualize relationship using scatterplot with trendline
    plt.figure(figsize=(14, 10))

    for pressure in ['0Pa', '1.4Pa']:
        pressure_data = combined_df[combined_df['pressure'] == pressure]
        if len(pressure_data) > 5:
            ax = sns.regplot(x='nuclei_count', y='area', data=pressure_data,
                           scatter_kws={'alpha': 0.5}, line_kws={'label': f"{pressure} Trend"},
                           label=pressure)

    plt.title('Cell Size vs. Number of Nuclei (with Trendline)', fontsize=16)
    plt.xlabel('Number of Nuclei per Cell', fontsize=14)
    plt.ylabel('Cell Area (pixels)', fontsize=14)
    plt.legend()
    plt.savefig(os.path.join(viz_dir, 'cell_size_vs_nuclei_count_trend.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\nSaved visualizations to {viz_dir}")

def main():
    """Main execution function."""
    print("\n======= Starting Cell Size vs. Nuclei Count Analysis =======")

    # 1. Find and organize mask files
    pressure_file_dict = find_mask_files(cell_mask_dir, nuclei_dir)

    if not pressure_file_dict:
        print("\nSCRIPT HALTED: Could not find or access mask files. Please check paths.")
        return

    # 2. Process files for each pressure group
    all_results = {}
    for pressure, file_list in pressure_file_dict.items():
        if not file_list:
            print(f"\nNo files found for pressure group: {pressure}. Skipping.")
            continue

        results = process_pressure_group(file_list, pressure)
        all_results[pressure] = results

    # 3. Analyze relationship between cell size and nuclei count
    analyze_cell_size_vs_nuclei_count(all_results)

    print("\n======= Analysis Script Finished =======")
    print(f"Outputs saved in: {output_dir}")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


--- Finding and Pairing Mask Files ---
Found 8 cell mask files and 8 nuclei mask files
Total matching cell-nuclei file pairs found: 8
  0Pa: 3 pairs
  1.4Pa: 5 pairs

=== Processing 0Pa Group (3 file pairs) ===

Processing sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq001
Found 329 cells and 367 nuclei
Cells with nuclei: 329/329 (100.0% of cells)
Cells with multiple nuclei: 31/329 (9.4% of cells)
  Cells with 1 nuclei: 298 (90.6%)
  Cells with 2 nuclei: 29 (8.8%)
  Cells with 3 nuclei: 2 (0.6%)
Saved visualization to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Cell_Nuclei_Size_Analysis/0Pa_0Pa_U_05mar19_20x_L2RA_Flat_seq001/0Pa_U_05mar19_20x_L2RA_Flat_seq001_cell_nuclei_mapping.png

Processing sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq002
Found 368 cells and 431 nuclei
Cells with nuclei: 368/368 (100.0% of cells)
Cells with multiple n

<Figure size 1200x800 with 0 Axes>

<Figure size 1400x1000 with 0 Axes>