In [None]:
import os
import numpy as np
import pandas as pd
import cv2
import re
from scipy import ndimage
from skimage import io, measure
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import umap
from tqdm import tqdm

# Use the functions you provided
def extract_sample_id(filename):
    """
    Extract the sample ID from a filename based on the specific naming pattern.

    Examples:
    - '1.4Pa_U_05mar19_20x_L2R_Flat_seq005_cell_mask_merged_conservative.tif' → '1.4Pa_U_05mar19_20x_L2R_Flat_seq005'
    - 'denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq005_Cadherins_filtered_mask.tif' → '1.4Pa_U_05mar19_20x_L2R_Flat_seq005'
    """
    # Remove file extension
    base_name = os.path.splitext(filename)[0]

    # Handle special prefixes like "denoised_"
    if base_name.startswith('denoised_'):
        base_name = base_name[len('denoised_'):]

    # Find the pattern that includes Pa (pressure), followed by identifiers and sequence number
    # Looking for patterns like "1.4Pa_U_05mar19_20x_L2R_Flat_seq005"
    pattern = re.compile(r'([\d\.]+Pa_[^_]+_[^_]+_[^_]+_[^_]+_[^_]+_seq\d+)')
    match = pattern.search(base_name)

    if match:
        # Use only the first match to avoid concatenating multiple matches
        return match.group(1)

    # If the regex didn't work, try a simpler approach with string splitting
    parts = base_name.split('_')

    # Look for the "seq" pattern which is common in your filenames
    for i, part in enumerate(parts):
        if part.startswith('seq') and i >= 2:  # Need at least 3 parts for a meaningful ID
            # Include parts up to and including the seq part (but limit length)
            # Limiting to maximum of 6 parts to avoid overly long IDs
            max_parts = min(i+1, 6)
            return '_'.join(parts[:max_parts])

    # Fallback: truncate to avoid excessively long IDs
    if len(parts) > 3:
        return '_'.join(parts[:min(len(parts)-2, 6)])  # Skip last 2 parts, max 6 parts

    # Last resort: use a shortened version of the filename (max 50 chars)
    shortened_name = os.path.basename(filename)[:50]
    return shortened_name

def find_mask_files(cell_dir, nuclei_dir):
    """Finds and pairs cell and nuclei mask files based on extracted sample ID."""
    print("\n--- Finding and Pairing Mask Files ---")

    # Get all relevant mask files
    cell_files = [f for f in os.listdir(cell_dir) if f.endswith(('.tif', '.tiff')) and not f.startswith('.')]
    nuclei_files = [f for f in os.listdir(nuclei_dir) if f.endswith(('.tif', '.tiff')) and not f.startswith('.')]

    print(f"Found {len(cell_files)} cell mask files and {len(nuclei_files)} nuclei mask files")

    # Create lookup dictionary for nuclei files with extracted sample IDs
    nuclei_lookup = {}
    for nuclei_file in nuclei_files:
        sample_id = extract_sample_id(nuclei_file)
        if sample_id:
            nuclei_lookup[sample_id] = nuclei_file
            print(f"Nuclei file: '{nuclei_file}' → Sample ID: '{sample_id}'")

    # Match cell files to nuclei files
    file_pairs = []
    pairs_found = 0

    for cell_file in cell_files:
        sample_id = extract_sample_id(cell_file)
        print(f"Cell file: '{cell_file}' → Sample ID: '{sample_id}'")

        if sample_id and sample_id in nuclei_lookup:
            print(f"Match found: {sample_id}")
            nuclei_file = nuclei_lookup[sample_id]
            file_pair = {
                'cell_file': os.path.join(cell_dir, cell_file),
                'nuclei_file': os.path.join(nuclei_dir, nuclei_file),
                'sample_id': sample_id
            }
            file_pairs.append(file_pair)
            pairs_found += 1

    print(f"Total matching cell-nuclei file pairs found: {pairs_found}")

    return file_pairs

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's binary (0 or 1)."""
    try:
        img = io.imread(filepath)
        # Convert to binary uint8
        if img.dtype == bool:
            img = img.astype(np.uint8)
        elif np.max(img) > 1:
            # If it's a labeled mask, keep the labels
            if np.max(img) <= 255:
                img = img.astype(np.uint8)
            else:
                img = img.astype(np.uint16)
        else:
            img = img.astype(np.uint8)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]
            # If it's a labeled mask with multiple channels, threshold to get binary
            img = (img > 0).astype(np.uint8)

        return img
    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def accurately_track_nuclei_in_cells(cell_mask, nuclei_mask):
    """
    Accurately identifies which nuclei are inside which cells.

    Algorithm:
    1. Label individual cells and nuclei
    2. For each nucleus, find its centroid
    3. Check which cell contains this centroid
    4. If >50% of nucleus area is inside a cell, assign it to that cell
    """
    # Ensure masks are properly labeled
    # For cell_mask, each cell should have a unique ID
    if np.max(cell_mask) <= 1:
        labeled_cells, num_cells = ndimage.label(cell_mask)
    else:
        labeled_cells = cell_mask
        num_cells = np.max(labeled_cells)

    # For nuclei_mask, each nucleus should have a unique ID
    if np.max(nuclei_mask) <= 1:
        labeled_nuclei, num_nuclei = ndimage.label(nuclei_mask)
    else:
        labeled_nuclei = nuclei_mask
        num_nuclei = np.max(labeled_nuclei)

    print(f"Found {num_cells} cells and {num_nuclei} nuclei")

    # Extract properties for cells and nuclei
    cell_props = measure.regionprops(labeled_cells)
    nuclei_props = measure.regionprops(labeled_nuclei)

    # Create a results structure
    results = {
        'cell_data': [],
        'nuclei_data': [],
        'cell_nuclei_mapping': {}
    }

    # Process each cell
    for cell in cell_props:
        cell_id = cell.label
        cell_mask_binary = (labeled_cells == cell_id)

        # Basic cell properties
        cell_data = {
            'cell_id': cell_id,
            'area': cell.area,
            'perimeter': cell.perimeter,
            'eccentricity': cell.eccentricity,
            'orientation': np.degrees(cell.orientation) % 180 if hasattr(cell, 'orientation') else None,
            'major_axis_length': cell.major_axis_length if hasattr(cell, 'major_axis_length') else None,
            'minor_axis_length': cell.minor_axis_length if hasattr(cell, 'minor_axis_length') else None,
            'centroid_y': cell.centroid[0],
            'centroid_x': cell.centroid[1],
            'nuclei_count': 0
        }

        results['cell_data'].append(cell_data)
        results['cell_nuclei_mapping'][cell_id] = []

    # Find which nuclei belong to which cells
    for nucleus in nuclei_props:
        nucleus_id = nucleus.label
        nucleus_mask_binary = (labeled_nuclei == nucleus_id)
        nucleus_area = nucleus.area

        # Find which cell contains this nucleus
        contained_in_cell = None
        max_overlap_ratio = 0

        for cell in cell_props:
            cell_id = cell.label
            cell_mask_binary = (labeled_cells == cell_id)

            # Calculate overlap
            overlap = np.logical_and(cell_mask_binary, nucleus_mask_binary)
            overlap_area = np.sum(overlap)

            # Calculate what percentage of the nucleus is in this cell
            overlap_ratio = overlap_area / nucleus_area

            # If most of the nucleus is in this cell, assign it to this cell
            if overlap_ratio > max_overlap_ratio:
                max_overlap_ratio = overlap_ratio
                contained_in_cell = cell_id

        # Only count the nucleus if a significant portion is inside the cell (>50%)
        if contained_in_cell is not None and max_overlap_ratio > 0.5:
            # Store nucleus data
            nucleus_data = {
                'nucleus_id': nucleus_id,
                'cell_id': contained_in_cell,
                'area': nucleus.area,
                'eccentricity': nucleus.eccentricity if hasattr(nucleus, 'eccentricity') else None,
                'centroid_y': nucleus.centroid[0],
                'centroid_x': nucleus.centroid[1],
                'overlap_ratio': max_overlap_ratio
            }

            results['nuclei_data'].append(nucleus_data)

            # Update cell's nuclei count
            for cell_data in results['cell_data']:
                if cell_data['cell_id'] == contained_in_cell:
                    cell_data['nuclei_count'] += 1
                    break

            # Update cell-nuclei mapping
            results['cell_nuclei_mapping'][contained_in_cell].append(nucleus_id)

    # Count nuclei per cell
    cells_with_nuclei = sum(1 for cell_data in results['cell_data'] if cell_data['nuclei_count'] > 0)
    cells_with_multiple_nuclei = sum(1 for cell_data in results['cell_data'] if cell_data['nuclei_count'] > 1)

    print(f"Cells with nuclei: {cells_with_nuclei}/{num_cells} ({100*cells_with_nuclei/num_cells:.1f}% of cells)")
    print(f"Cells with multiple nuclei: {cells_with_multiple_nuclei}/{num_cells} ({100*cells_with_multiple_nuclei/num_cells:.1f}% of cells)")

    # Create summary of nuclei per cell
    nuclei_counts = [cell_data['nuclei_count'] for cell_data in results['cell_data']]
    unique_counts = sorted(set(nuclei_counts))
    for count in unique_counts:
        cells_with_count = sum(1 for n in nuclei_counts if n == count)
        print(f"  Cells with {count} nuclei: {cells_with_count} ({100*cells_with_count/num_cells:.1f}%)")

    return results

def extract_features_for_cell(cell_data, cell_nuclei_map, nuclei_data_list):
    """
    Extract comprehensive morphometric features for a cell and its associated nuclei.
    These features will be used for clustering.
    """
    features = {}

    # Basic cell features - directly from cell_data
    features['cell_area'] = cell_data['area']
    features['cell_perimeter'] = cell_data['perimeter']
    features['cell_eccentricity'] = cell_data['eccentricity'] if cell_data['eccentricity'] is not None else 0

    # Calculate additional cell shape features
    if cell_data['perimeter'] > 0:
        features['cell_circularity'] = 4 * np.pi * cell_data['area'] / (cell_data['perimeter'] ** 2)
    else:
        features['cell_circularity'] = 0

    if cell_data['major_axis_length'] is not None and cell_data['minor_axis_length'] is not None:
        if cell_data['minor_axis_length'] > 0:
            features['cell_aspect_ratio'] = cell_data['major_axis_length'] / cell_data['minor_axis_length']
        else:
            features['cell_aspect_ratio'] = 1.0
    else:
        features['cell_aspect_ratio'] = 1.0

    # Nuclear features
    cell_id = cell_data['cell_id']
    nucleus_ids = cell_nuclei_map.get(cell_id, [])
    features['nuclei_count'] = len(nucleus_ids)

    # Initialize nuclear features with defaults
    features['avg_nucleus_area'] = 0
    features['total_nuclear_area'] = 0
    features['max_nucleus_area'] = 0
    features['avg_nucleus_eccentricity'] = 0
    features['nucleus_area_std'] = 0
    features['nucleus_displacement'] = 0  # Distance between nucleus centroid and cell centroid

    if features['nuclei_count'] > 0:
        # Get nuclei associated with this cell
        cell_nuclei = [n for n in nuclei_data_list if n['cell_id'] == cell_id]

        # Calculate nuclear features
        nuclear_areas = [n['area'] for n in cell_nuclei]
        nuclear_eccentricities = [n['eccentricity'] if n['eccentricity'] is not None else 0 for n in cell_nuclei]

        features['avg_nucleus_area'] = np.mean(nuclear_areas) if nuclear_areas else 0
        features['total_nuclear_area'] = sum(nuclear_areas)
        features['max_nucleus_area'] = max(nuclear_areas) if nuclear_areas else 0
        features['avg_nucleus_eccentricity'] = np.mean(nuclear_eccentricities) if nuclear_eccentricities else 0
        features['nucleus_area_std'] = np.std(nuclear_areas) if len(nuclear_areas) > 1 else 0

        # Calculate nucleus-to-cell metrics
        features['nucleus_to_cell_area_ratio'] = features['total_nuclear_area'] / features['cell_area'] if features['cell_area'] > 0 else 0

        # Calculate average displacement of nuclei from cell center
        displacements = []
        for nucleus in cell_nuclei:
            dx = nucleus['centroid_x'] - cell_data['centroid_x']
            dy = nucleus['centroid_y'] - cell_data['centroid_y']
            displacement = np.sqrt(dx**2 + dy**2)
            displacements.append(displacement)

        features['nucleus_displacement'] = np.mean(displacements) if displacements else 0

    # Derived features specifically useful for senescence
    # 1. Polynucleation indicator (more likely to be senescent)
    features['is_polynucleated'] = 1 if features['nuclei_count'] > 1 else 0

    # 2. Nuclear enlargement (common in senescent cells)
    # We'll compare to typical nuclear size - this is a placeholder value
    avg_normal_nucleus_area = 500  # This should be determined from data
    features['nuclear_enlargement'] = features['avg_nucleus_area'] / avg_normal_nucleus_area if avg_normal_nucleus_area > 0 else 1

    # 3. Cell spreading (senescent cells are typically larger)
    avg_normal_cell_area = 2000  # This should be determined from data
    features['cell_enlargement'] = features['cell_area'] / avg_normal_cell_area if avg_normal_cell_area > 0 else 1

    return features

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import umap

def perform_clustering(all_features_df, umap_n_neighbors=30, umap_min_dist=0.1, umap_random_state=42, kmeans_random_state=42):
    """
    Performs UMAP dimensionality reduction and k-Means clustering on the features,
    then assigns cell types based on a nuanced senescent score and a multinucleation rule.

    Args:
        all_features_df (pd.DataFrame): DataFrame containing all extracted features for cells.
        umap_n_neighbors (int): UMAP n_neighbors parameter.
        umap_min_dist (float): UMAP min_dist parameter.
        umap_random_state (int): Random state for UMAP.
        kmeans_random_state (int): Random state for k-Means.

    Returns:
        pd.DataFrame: The input DataFrame with added columns for UMAP coordinates,
                      cluster labels, and final cell_type.
    """
    print("Starting clustering process...")
    if all_features_df.empty:
        print("Input DataFrame is empty. Cannot perform clustering.")
        return pd.DataFrame()

    # Define feature columns (ensure 'cell_id' and 'sample_id' are excluded if they were part of all_features_df)
    # Also exclude columns that will be added by this function like 'umap_x', 'umap_y', 'cluster', 'cell_type'
    potential_id_cols = ['cell_id', 'sample_id', 'original_index'] # Add any other ID columns
    feature_columns = [col for col in all_features_df.columns if col not in potential_id_cols and col not in ['umap_x', 'umap_y', 'cluster', 'cell_type']]

    # Ensure all feature columns are numeric, if not, attempt conversion or drop
    numeric_feature_columns = []
    for col in feature_columns:
        if pd.api.types.is_numeric_dtype(all_features_df[col]):
            numeric_feature_columns.append(col)
        else:
            try:
                all_features_df[col] = pd.to_numeric(all_features_df[col])
                numeric_feature_columns.append(col)
                print(f"Column {col} converted to numeric.")
            except ValueError:
                print(f"Warning: Column {col} is not numeric and could not be converted. It will be excluded from clustering features.")

    feature_columns = numeric_feature_columns
    if not feature_columns:
        print("No valid numeric feature columns found for clustering.")
        return all_features_df # Or an empty DataFrame

    features_df = all_features_df[feature_columns].copy()

    # --- 1. (NEW) Log Transform Area-based Features (before scaling) ---
    # This helps normalize skewed distributions for size-related features.
    area_features_to_log = ['cell_area', 'avg_nucleus_area', 'total_nuclear_area', 'max_nucleus_area', 'cell_perimeter']
    # Add other features like 'cell_enlargement', 'nuclear_enlargement' if their distributions are also highly skewed.
    # However, for derived ratios/enlargement factors, assess if log transform is appropriate or if they are better used as is in the heuristic.
    # For now, let's log primary area/perimeter measures.

    print("\nApplying log transformation to selected area/perimeter features...")
    for col in area_features_to_log:
        if col in features_df.columns:
            features_df[col] = np.log1p(features_df[col]) # log1p handles zeros if any
            print(f"  Log-transformed: {col}")
        else:
            print(f"  Warning: Column {col} not found for log transformation.")

    # --- 2. Standardize Features ---
    print("\nStandardizing features...")
    scaler = StandardScaler()
    features_standardized = scaler.fit_transform(features_df)
    features_standardized_df = pd.DataFrame(features_standardized, columns=feature_columns, index=features_df.index)

    # --- 3. UMAP Dimensionality Reduction ---
    print("\nPerforming UMAP reduction...")
    reducer = umap.UMAP(n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, random_state=umap_random_state, n_components=2)
    embedding = reducer.fit_transform(features_standardized_df)

    clustered_df = all_features_df.copy() # Start with the original df to add new columns
    clustered_df['umap_x'] = embedding[:, 0]
    clustered_df['umap_y'] = embedding[:, 1]

    # --- 4. k-Means Clustering ---
    print("\nPerforming k-Means clustering (k=2)...")
    kmeans = KMeans(n_clusters=2, random_state=kmeans_random_state, n_init='auto')
    clustered_df['cluster'] = kmeans.fit_predict(embedding) # Cluster on UMAP embedding

    # --- 5. (NEW) Identify Senescent Cluster using a More Extensive Score ---
    print("\nIdentifying senescent cluster using a weighted score...")

    # Calculate cluster statistics (mean feature values for each cluster)
    # Use the original feature values (or log-transformed if you want the score to reflect that scale)
    # for interpretability of the score. Here, we'll use the features_df which includes log-transformed values.
    # If you prefer original scale for scoring features that were log-transformed for UMAP,
    # you would calculate means from all_features_df[feature_columns]
    # but be careful with consistency if some are logged and some are not in the score.

    # For calculating cluster_stats, add the 'cluster' column to the dataframe that has the features used for scoring
    # If using log-transformed features for scoring (as in features_df):
    scoring_features_df = features_df.copy()
    scoring_features_df['cluster'] = clustered_df['cluster']
    cluster_stats = scoring_features_df.groupby('cluster').mean()

    # If you want to score based on *original* feature values (before log transform and scaling):
    # temp_df_for_stats = all_features_df[feature_columns].copy()
    # temp_df_for_stats['cluster'] = clustered_df['cluster']
    # cluster_stats = temp_df_for_stats.groupby('cluster').mean()
    # Ensure features in `weights` below exist in this `cluster_stats`

    # --- Define weights for each feature ---
    # Positive weight if a higher value indicates senescence, negative if a lower value does.
    # **TUNE THESE WEIGHTS BASED ON YOUR KNOWLEDGE AND OBSERVATIONS!**
    # Note: 'cell_area', 'avg_nucleus_area' etc. in cluster_stats will be log-scaled if 'features_df' was used above.
    # 'cell_enlargement' and 'nuclear_enlargement' from all_features_df are on original scale.
    # This mixing of scales means weights need careful thought.
    # For simplicity, let's assume cluster_stats are from 'features_df' (so some are log-scaled).
    # If you added 'cell_enlargement' etc. to features_df and log-transformed them, this is consistent.
    # If not, and they are part of `all_features_df` but not `features_df` for UMAP,
    # you'd need to get their means separately or ensure they are in `scoring_features_df`.

    # For a robust score, it's best if all features contributing to it are on a somewhat comparable scale OR
    # weights strongly reflect their differing scales and importance.
    # Let's ensure all features used in weights are actually present in cluster_stats.
    # We will use the `all_features_df` to get means for `cell_enlargement` and `nuclear_enlargement`
    # as they might not have been log-transformed or included in the UMAP features.

    cluster_stats_original_scale = all_features_df.copy()
    cluster_stats_original_scale['cluster'] = clustered_df['cluster']
    cluster_stats_original_scale_means = cluster_stats_original_scale.groupby('cluster')[feature_columns].mean()


    weights = {
        # Features assumed to be log-transformed if coming from 'features_df' based cluster_stats
        'cell_area': 1.0,        # Larger is more senescent (log-scale)
        'avg_nucleus_area': 0.7, # Larger nuclei often seen (log-scale)
        'total_nuclear_area': 0.5, # (log-scale)
        'cell_perimeter': 0.5,   # (log-scale)

        # Features from original scale (using cluster_stats_original_scale_means)
        # Or ensure these are in your `feature_columns` for UMAP if you want them processed that way
        'cell_enlargement': 1.5,    # More enlargement is more senescent (original scale from all_features_df)
        'nuclear_enlargement': 1.0, # More enlargement is more senescent (original scale from all_features_df)
        'nuclei_count': 0.8,        # Higher count is indicative (original scale)
        'cell_circularity': -0.5,   # Often less circular (original scale)
        'nucleus_to_cell_area_ratio': -1.0 # Often decreases (original scale)
        # Add/remove/tune features and weights as needed.
        # 'cell_aspect_ratio': -0.3, # Example, if relevant
    }

    senescent_score_0 = 0
    senescent_score_1 = 0

    print("Calculating senescent scores for clusters using weights:")
    for feature, weight in weights.items():
        if feature in cluster_stats.columns: # From log-transformed features_df
            senescent_score_0 += cluster_stats.loc[0, feature] * weight
            senescent_score_1 += cluster_stats.loc[1, feature] * weight
            print(f"  Using {feature} (log-transformed potentially) from features_df based stats: w={weight}")
        elif feature in cluster_stats_original_scale_means.columns: # From original scale all_features_df
            senescent_score_0 += cluster_stats_original_scale_means.loc[0, feature] * weight
            senescent_score_1 += cluster_stats_original_scale_means.loc[1, feature] * weight
            print(f"  Using {feature} (original scale) from all_features_df based stats: w={weight}")
        else:
            print(f"  Warning: Feature '{feature}' for scoring not found in available cluster stats. Skipping.")

    print(f"Senescent score for cluster 0: {senescent_score_0:.4f}")
    print(f"Senescent score for cluster 1: {senescent_score_1:.4f}")

    senescent_cluster_label = 0 if senescent_score_0 > senescent_score_1 else 1
    non_senescent_cluster_label = 1 - senescent_cluster_label
    print(f"Cluster {senescent_cluster_label} initially identified as 'Senescent'.")

    # --- 6. Initial Cell Type Assignment ---
    clustered_df['cell_type'] = clustered_df['cluster'].apply(
        lambda x: 'Senescent' if x == senescent_cluster_label else 'Non-senescent'
    )

    # --- 7. (NEW) Post-processing: Multinucleation Rule ---
    # Force classify cells with nuclei_count > 2 as Senescent.
    # Ensure 'nuclei_count' is present in clustered_df (it should be from all_features_df).
    if 'nuclei_count' in clustered_df.columns:
        print(f"\nApplying multinucleation rule (nuclei_count > 2)...")
        highly_multinucleated_mask = clustered_df['nuclei_count'] > 2

        # Count how many were Non-senescent before this rule and will be changed
        num_reclassified = clustered_df.loc[highly_multinucleated_mask & (clustered_df['cell_type'] == 'Non-senescent')].shape[0]

        clustered_df.loc[highly_multinucleated_mask, 'cell_type'] = 'Senescent'
        # Optional: use a distinct label like 'Senescent (Rule)' or 'Senescent (Multinucleated)'
        # if you want to track these specifically. If so, you'll need to update visualization palettes.
        print(f"{highly_multinucleated_mask.sum()} cells have >2 nuclei.")
        print(f"{num_reclassified} cells re-classified from 'Non-senescent' to 'Senescent' by this rule.")
    else:
        print("\nWarning: 'nuclei_count' column not found. Cannot apply multinucleation rule.")

    print("\nClustering process complete.")
    return clustered_df

def visualize_clustering_results(clustered_df):
    """
    Create visualizations of the clustering results.
    """
    print("\n--- Creating Visualizations ---")

    # UMAP plot colored by cluster
    plt.figure(figsize=(12, 10))

    # Scatter plot with UMAP results
    scatter = plt.scatter(
        clustered_df['umap_x'],
        clustered_df['umap_y'],
        c=clustered_df['cluster'],
        cmap='viridis',
        s=30,
        alpha=0.8
    )

    # Add legend
    legend_labels = clustered_df['cell_type'].unique()
    handles = [plt.Line2D([0], [0], marker='o', color='w',
                        markerfacecolor=scatter.cmap(scatter.norm(i)),
                        markersize=10)
               for i in range(len(legend_labels))]
    plt.legend(handles, legend_labels, title='Cell Type', loc='upper right')

    plt.title('UMAP Projection of Endothelial Cells', fontsize=14)
    plt.xlabel('UMAP Dimension 1', fontsize=12)
    plt.ylabel('UMAP Dimension 2', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Save the UMAP visualization
    plt.savefig('umap_clustering_results.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Create a pair plot for key features
    key_features = ['cell_area', 'nuclei_count', 'avg_nucleus_area',
                   'nucleus_to_cell_area_ratio', 'cell_circularity']

    plt.figure(figsize=(15, 12))

    # Create custom pair plot
    fig, axes = plt.subplots(len(key_features), len(key_features), figsize=(15, 15))

    for i, feature_x in enumerate(key_features):
        for j, feature_y in enumerate(key_features):
            ax = axes[i, j]

            if i == j:  # Diagonal - histogram
                for cell_type, color in zip(['Non-senescent', 'Senescent'], ['blue', 'red']):
                    subset = clustered_df[clustered_df['cell_type'] == cell_type]
                    ax.hist(subset[feature_x], alpha=0.5, color=color, bins=20)
                ax.set_title(feature_x.replace('_', ' ').title(), fontsize=9)

            else:  # Off-diagonal - scatter plot
                for cell_type, color in zip(['Non-senescent', 'Senescent'], ['blue', 'red']):
                    subset = clustered_df[clustered_df['cell_type'] == cell_type]
                    ax.scatter(subset[feature_x], subset[feature_y], s=10, alpha=0.5, color=color)

                if j == 0:  # First column
                    ax.set_ylabel(feature_y.replace('_', ' ').title(), fontsize=9)
                if i == len(key_features) - 1:  # Last row
                    ax.set_xlabel(feature_x.replace('_', ' ').title(), fontsize=9)

    # Add a common legend
    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=c, markersize=10)
              for c in ['blue', 'red']]
    fig.legend(handles, ['Non-senescent', 'Senescent'],
               loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=2, fontsize=12)

    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.savefig('feature_relationships.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Feature distributions by cell type
    plt.figure(figsize=(15, 15))

    # Create boxplots for key features
    fig, axes = plt.subplots(3, 2, figsize=(15, 15))
    axes = axes.flatten()

    top_features = sorted(
        [(feature, abs(
            clustered_df[clustered_df['cell_type'] == 'Senescent'][feature].mean() -
            clustered_df[clustered_df['cell_type'] == 'Non-senescent'][feature].mean()
        )) for feature in key_features],
        key=lambda x: x[1],
        reverse=True
    )

    for i, (feature, _) in enumerate(top_features[:6]):
        ax = axes[i]
        sns.boxplot(x='cell_type', y=feature, data=clustered_df, ax=ax, palette=['blue', 'red'])
        ax.set_title(feature.replace('_', ' ').title(), fontsize=12)
        ax.set_xlabel('')
        ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=10)

    plt.tight_layout()
    plt.savefig('feature_distributions.png', dpi=300, bbox_inches='tight')
    plt.close()

    print("Visualizations saved to disk: umap_clustering_results.png, feature_relationships.png, feature_distributions.png")

def save_results(clustered_df, output_directory='results'):
    """
    Save the clustering results to CSV files and create additional visualizations.
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # Save full results
    clustered_df.to_csv(os.path.join(output_directory, 'cell_classification_results.csv'), index=False)

    # Create summary by sample
    sample_summary = clustered_df.groupby('sample_id')['cell_type'].value_counts().unstack(fill_value=0)

    # Handle case where some samples might not have both cell types
    if 'Senescent' not in sample_summary.columns:
        sample_summary['Senescent'] = 0
    if 'Non-senescent' not in sample_summary.columns:
        sample_summary['Non-senescent'] = 0

    sample_summary['total_cells'] = sample_summary.sum(axis=1)
    sample_summary['percent_senescent'] = sample_summary['Senescent'] / sample_summary['total_cells'] * 100

    sample_summary.to_csv(os.path.join(output_directory, 'sample_summary.csv'))

    # Create and save additional visualizations

    # 1. Senescent percentage by sample
    plt.figure(figsize=(12, 6))
    sns.barplot(x=sample_summary.index, y=sample_summary['percent_senescent'])
    plt.title('Percentage of Senescent Cells by Sample')
    plt.xlabel('Sample ID')
    plt.ylabel('Senescent Cells (%)')
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(os.path.join(output_directory, 'senescent_percentage_by_sample.png'), dpi=300)
    plt.close()

    # 2. Scatter plot of key senescence indicators
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        clustered_df['cell_area'],
        clustered_df['avg_nucleus_area'],
        c=clustered_df['cell_type'].map({'Senescent': 1, 'Non-senescent': 0}),
        cmap='coolwarm',
        alpha=0.7,
        s=30
    )
    plt.xlabel('Cell Area')
    plt.ylabel('Average Nucleus Area')
    plt.title('Cell Area vs. Nuclear Area by Cell Type')
    plt.colorbar(scatter, label='Cell Type', ticks=[0, 1],
                format=plt.FuncFormatter(lambda x, pos: 'Non-senescent' if x < 0.5 else 'Senescent'))
    plt.tight_layout()
    plt.savefig(os.path.join(output_directory, 'area_relationship.png'), dpi=300)
    plt.close()

    # 3. Nuclei count distribution
    plt.figure(figsize=(10, 6))
    nuclei_counts = clustered_df.groupby('cell_type')['nuclei_count'].value_counts(normalize=True).unstack(fill_value=0) * 100
    nuclei_counts.plot(kind='bar')
    plt.title('Distribution of Nuclei Count by Cell Type')
    plt.xlabel('Number of Nuclei')
    plt.ylabel('Percentage of Cells (%)')
    plt.legend(title='Cell Type')
    plt.tight_layout()
    plt.savefig(os.path.join(output_directory, 'nuclei_count_distribution.png'), dpi=300)
    plt.close()

    print(f"\nResults saved to {output_directory}/")
    print(f"  - Full cell classification: cell_classification_results.csv")
    print(f"  - Sample summary: sample_summary.csv")
    print(f"  - Additional visualizations: senescent_percentage_by_sample.png, area_relationship.png, nuclei_count_distribution.png")

def main(cell_dir, nuclei_dir, output_dir=None):
    """
    Main function to perform the analysis.
    """
    print("=== Senescent Cell Classification Analysis ===")

    # Set default output directory if not provided
    if output_dir is None:
        output_dir = "results"

    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Find and pair mask files (without pressure classification)
    file_pairs = find_mask_files(cell_dir, nuclei_dir)

    # Create a list to store all cell features
    all_cell_features = []

    # Process all image pairs
    print(f"\nProcessing {len(file_pairs)} image pairs")

    # Process each image pair
    for file_pair in tqdm(file_pairs, desc="Processing images"):
        try:
            # Load cell and nuclei masks
            cell_mask = load_mask_image(file_pair['cell_file'])
            nuclei_mask = load_mask_image(file_pair['nuclei_file'])

            if cell_mask is None or nuclei_mask is None:
                print(f"Error: Could not load masks for {file_pair['sample_id']}")
                continue

            # Track nuclei in cells
            results = accurately_track_nuclei_in_cells(cell_mask, nuclei_mask)

            # Extract features for each cell
            for cell_data in results['cell_data']:
                features = extract_features_for_cell(
                    cell_data,
                    results['cell_nuclei_mapping'],
                    results['nuclei_data']
                )

                # Add cell ID and sample ID
                features['cell_id'] = f"{file_pair['sample_id']}_{cell_data['cell_id']}"
                features['sample_id'] = file_pair['sample_id']

                all_cell_features.append(features)

        except Exception as e:
            print(f"Error processing {file_pair['sample_id']}: {str(e)}")

    # Convert to DataFrame
    if all_cell_features:
        all_features_df = pd.DataFrame(all_cell_features)
        print(f"\nTotal cells extracted: {len(all_features_df)}")

        # Perform clustering
        clustered_df = perform_clustering(all_features_df)

        # Visualize results
        visualize_clustering_results(clustered_df)

        # Save results
        save_results(clustered_df, output_dir)

        print("\nAnalysis complete!")
    else:
        print("No cell features were extracted. Please check your input files.")

if __name__ == "__main__":
    # Your specific directories
    cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
    nuclei_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
    output_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence"

    main(cell_mask_dir, nuclei_dir, output_dir)

=== Senescent Cell Classification Analysis ===

--- Finding and Pairing Mask Files ---
Found 8 cell mask files and 8 nuclei mask files
Nuclei file: 'denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Cadherins_filtered_mask.tif' → Sample ID: '0Pa_U_05mar19_20x_L2RA_Flat_seq001'
Nuclei file: 'denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq002_Cadherins_filtered_mask.tif' → Sample ID: '0Pa_U_05mar19_20x_L2RA_Flat_seq002'
Nuclei file: 'denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq003_Cadherins_filtered_mask.tif' → Sample ID: '0Pa_U_05mar19_20x_L2RA_Flat_seq003'
Nuclei file: 'denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq001_Cadherins_filtered_mask.tif' → Sample ID: '1.4Pa_U_05mar19_20x_L2R_Flat_seq001'
Nuclei file: 'denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq002_Cadherins_filtered_mask.tif' → Sample ID: '1.4Pa_U_05mar19_20x_L2R_Flat_seq002'
Nuclei file: 'denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq003_Cadherins_filtered_mask.tif' → Sample ID: '1.4Pa_U_05mar19_20x_L2R_Flat_seq003'
Nuclei file: 'denoised_1.4Pa_U_05mar1

Processing images:   0%|          | 0/8 [00:00<?, ?it/s]

Found 329 cells and 367 nuclei


Processing images:  12%|█▎        | 1/8 [02:13<15:33, 133.38s/it]

Cells with nuclei: 329/329 (100.0% of cells)
Cells with multiple nuclei: 31/329 (9.4% of cells)
  Cells with 1 nuclei: 298 (90.6%)
  Cells with 2 nuclei: 29 (8.8%)
  Cells with 3 nuclei: 2 (0.6%)
Found 368 cells and 431 nuclei


Processing images:  25%|██▌       | 2/8 [05:13<16:04, 160.79s/it]

Cells with nuclei: 368/368 (100.0% of cells)
Cells with multiple nuclei: 55/368 (14.9% of cells)
  Cells with 1 nuclei: 313 (85.1%)
  Cells with 2 nuclei: 50 (13.6%)
  Cells with 3 nuclei: 5 (1.4%)
Found 385 cells and 450 nuclei


Processing images:  38%|███▊      | 3/8 [08:29<14:44, 176.90s/it]

Cells with nuclei: 385/385 (100.0% of cells)
Cells with multiple nuclei: 54/385 (14.0% of cells)
  Cells with 1 nuclei: 331 (86.0%)
  Cells with 2 nuclei: 51 (13.2%)
  Cells with 3 nuclei: 1 (0.3%)
  Cells with 4 nuclei: 2 (0.5%)
Found 264 cells and 285 nuclei


Processing images:  50%|█████     | 4/8 [09:57<09:27, 141.81s/it]

Cells with nuclei: 264/264 (100.0% of cells)
Cells with multiple nuclei: 20/264 (7.6% of cells)
  Cells with 1 nuclei: 244 (92.4%)
  Cells with 2 nuclei: 20 (7.6%)
Found 210 cells and 229 nuclei


Processing images:  62%|██████▎   | 5/8 [10:48<05:27, 109.22s/it]

Cells with nuclei: 210/210 (100.0% of cells)
Cells with multiple nuclei: 15/210 (7.1% of cells)
  Cells with 1 nuclei: 195 (92.9%)
  Cells with 2 nuclei: 14 (6.7%)
  Cells with 3 nuclei: 1 (0.5%)
Found 294 cells and 315 nuclei


Processing images:  75%|███████▌  | 6/8 [12:30<03:33, 106.73s/it]

Cells with nuclei: 294/294 (100.0% of cells)
Cells with multiple nuclei: 15/294 (5.1% of cells)
  Cells with 1 nuclei: 279 (94.9%)
  Cells with 2 nuclei: 14 (4.8%)
  Cells with 5 nuclei: 1 (0.3%)
Found 287 cells and 299 nuclei


Processing images:  88%|████████▊ | 7/8 [14:06<01:43, 103.22s/it]

Cells with nuclei: 287/287 (100.0% of cells)
Cells with multiple nuclei: 9/287 (3.1% of cells)
  Cells with 1 nuclei: 278 (96.9%)
  Cells with 2 nuclei: 9 (3.1%)
Found 335 cells and 370 nuclei


Processing images: 100%|██████████| 8/8 [16:21<00:00, 122.72s/it]

Cells with nuclei: 335/335 (100.0% of cells)
Cells with multiple nuclei: 26/335 (7.8% of cells)
  Cells with 1 nuclei: 309 (92.2%)
  Cells with 2 nuclei: 24 (7.2%)
  Cells with 3 nuclei: 2 (0.6%)

Total cells extracted: 2472
Starting clustering process...

Applying log transformation to selected area/perimeter features...
  Log-transformed: cell_area
  Log-transformed: avg_nucleus_area
  Log-transformed: total_nuclear_area
  Log-transformed: max_nucleus_area
  Log-transformed: cell_perimeter

Standardizing features...

Performing UMAP reduction...



  warn(



Performing k-Means clustering (k=2)...

Identifying senescent cluster using a weighted score...
Calculating senescent scores for clusters using weights:
  Using cell_area (log-transformed potentially) from features_df based stats: w=1.0
  Using avg_nucleus_area (log-transformed potentially) from features_df based stats: w=0.7
  Using total_nuclear_area (log-transformed potentially) from features_df based stats: w=0.5
  Using cell_perimeter (log-transformed potentially) from features_df based stats: w=0.5
  Using cell_enlargement (log-transformed potentially) from features_df based stats: w=1.5
  Using nuclear_enlargement (log-transformed potentially) from features_df based stats: w=1.0
  Using nuclei_count (log-transformed potentially) from features_df based stats: w=0.8
  Using cell_circularity (log-transformed potentially) from features_df based stats: w=-0.5
  Using nucleus_to_cell_area_ratio (log-transformed potentially) from features_df based stats: w=-1.0
Senescent score for clu


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type', y=feature, data=clustered_df, ax=ax, palette=['blue', 'red'])

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type', y=feature, data=clustered_df, ax=ax, palette=['blue', 'red'])

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type', y=feature, data=clustered_df, ax=ax, palette=['blue', 'red'])

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type', y=feature, data=clustered_df, 

Visualizations saved to disk: umap_clustering_results.png, feature_relationships.png, feature_distributions.png

Results saved to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/
  - Full cell classification: cell_classification_results.csv
  - Sample summary: sample_summary.csv
  - Additional visualizations: senescent_percentage_by_sample.png, area_relationship.png, nuclei_count_distribution.png

Analysis complete!


<Figure size 1500x1200 with 0 Axes>

<Figure size 1500x1500 with 0 Axes>

<Figure size 1000x600 with 0 Axes>

Make classifications

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from skimage import io, measure, segmentation
from scipy import ndimage, spatial, stats
from sklearn.neighbors import NearestNeighbors
import cv2
import networkx as nx
from tqdm import tqdm
import re

def analyze_spatial_distribution(cell_mask_dir, nuclei_dir, results_csv, output_dir="spatial_analysis"):
    """
    Analyze the spatial distribution of senescent vs non-senescent cells.

    Parameters:
    -----------
    cell_mask_dir : str
        Directory containing the cell mask files
    nuclei_dir : str
        Directory containing the nuclei mask files
    results_csv : str
        Path to the cell_classification_results.csv file
    output_dir : str
        Directory to save the spatial analysis results
    """
    print("=== Spatial Analysis of Senescent Cell Distribution ===")

    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load classification results
    results_df = pd.read_csv(results_csv)
    print(f"Loaded {len(results_df)} classified cells")

    # Extract sample IDs and cell IDs from the results
    # Assuming cell_id format is "sample_id_cell_id"
    results_df['original_sample_id'] = results_df['cell_id'].apply(lambda x: '_'.join(x.split('_')[:-1]))
    results_df['original_cell_id'] = results_df['cell_id'].apply(lambda x: int(x.split('_')[-1]))

    # Group by sample ID
    sample_groups = results_df.groupby('original_sample_id')

    # Find and process mask files for each sample
    cell_files = [f for f in os.listdir(cell_mask_dir) if f.endswith(('.tif', '.tiff')) and not f.startswith('.')]
    nuclei_files = [f for f in os.listdir(nuclei_dir) if f.endswith(('.tif', '.tiff')) and not f.startswith('.')]

    # Create DataFrames to store results
    spatial_stats = []
    neighbor_stats = []

    # Process each sample in the results
    for sample_id, group in tqdm(sample_groups, desc="Analyzing spatial distribution"):
        # Find the corresponding cell mask file
        cell_file = next((f for f in cell_files if sample_id in f), None)

        if cell_file is None:
            print(f"Warning: Could not find cell mask file for sample {sample_id}")
            continue

        # Load the cell mask
        cell_mask_path = os.path.join(cell_mask_dir, cell_file)

        try:
            # Load mask
            cell_mask = load_mask_image(cell_mask_path)

            if cell_mask is None:
                print(f"Error: Could not load cell mask for {sample_id}")
                continue

            # Extract cell properties with centroids
            cell_props = extract_cell_properties(cell_mask)

            # Map classification to cell properties
            cell_props = map_classification_to_cells(cell_props, group)

            # Calculate nearest neighbor statistics
            nn_stats = calculate_nearest_neighbor_stats(cell_props, sample_id)
            neighbor_stats.append(nn_stats)

            # Calculate spatial statistics (clustering, dispersion)
            spatial_result = calculate_spatial_statistics(cell_props, sample_id)
            spatial_stats.append(spatial_result)

            # Visualize cell positions with color-coded classifications
            visualize_cell_positions(cell_props, sample_id, output_dir)

            # Create neighborhood analysis visualization
            visualize_neighborhoods(cell_props, sample_id, output_dir)

        except Exception as e:
            print(f"Error processing {sample_id}: {str(e)}")

    # Combine results and save to CSV
    if spatial_stats:
        spatial_df = pd.DataFrame(spatial_stats)
        spatial_df.to_csv(os.path.join(output_dir, "spatial_statistics.csv"), index=False)

    if neighbor_stats:
        neighbor_df = pd.DataFrame(neighbor_stats)
        neighbor_df.to_csv(os.path.join(output_dir, "neighbor_statistics.csv"), index=False)

        # Create summary visualizations
        create_summary_visualizations(spatial_df, neighbor_df, output_dir)

    print(f"\nSpatial analysis complete! Results saved to {output_dir}/")

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's properly formatted for analysis."""
    try:
        img = io.imread(filepath)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]

        # If it's already a labeled mask (values > 1), use it directly
        if np.max(img) > 1:
            return img

        # Otherwise, label the connected components
        labeled_img, num_features = ndimage.label(img > 0)
        return labeled_img

    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def extract_cell_properties(cell_mask):
    """
    Extract cell properties including centroids and boundaries.

    Parameters:
    -----------
    cell_mask : ndarray
        Labeled cell mask

    Returns:
    --------
    cell_props : list of dicts
        List of dictionaries containing cell properties
    """
    # Extract region properties
    props = measure.regionprops(cell_mask)

    cell_props = []
    for prop in props:
        cell_id = prop.label

        # Skip background (label 0)
        if cell_id == 0:
            continue

        cell_data = {
            'cell_id': cell_id,
            'centroid': prop.centroid,
            'area': prop.area,
            'perimeter': prop.perimeter,
            'eccentricity': prop.eccentricity if hasattr(prop, 'eccentricity') else 0,
            'classification': None,  # Will be filled later
            'neighbor_ids': [],  # Will be filled later
            'boundary_pixels': []  # Will be extracted below
        }

        # Extract boundary pixels for this cell
        cell_mask_binary = (cell_mask == cell_id)
        boundary = segmentation.find_boundaries(cell_mask_binary, mode='outer')
        boundary_coords = np.where(boundary)
        cell_data['boundary_pixels'] = list(zip(boundary_coords[0], boundary_coords[1]))

        cell_props.append(cell_data)

    return cell_props

def map_classification_to_cells(cell_props, classification_group):
    """
    Map senescent/non-senescent classifications to cells.

    Parameters:
    -----------
    cell_props : list of dicts
        List of cell properties
    classification_group : DataFrame
        DataFrame with classification results for this sample

    Returns:
    --------
    cell_props : list of dicts
        Updated list with classifications added
    """
    # Create mapping from cell ID to classification
    cell_classifications = dict(zip(
        classification_group['original_cell_id'],
        classification_group['cell_type']
    ))

    # Map classifications to cell properties
    for cell in cell_props:
        cell_id = cell['cell_id']
        if cell_id in cell_classifications:
            cell['classification'] = cell_classifications[cell_id]
        else:
            cell['classification'] = 'Unknown'

    return cell_props

def calculate_nearest_neighbor_stats(cell_props, sample_id):
    """
    Calculate nearest neighbor statistics for senescent vs non-senescent cells.

    Parameters:
    -----------
    cell_props : list of dicts
        List of cell properties with classifications
    sample_id : str
        Sample identifier

    Returns:
    --------
    stats : dict
        Dictionary of nearest neighbor statistics
    """
    # Extract cell centroids and classifications
    centroids = np.array([cell['centroid'] for cell in cell_props])
    classifications = np.array([cell['classification'] for cell in cell_props])

    # Skip if too few cells
    if len(centroids) < 5:
        return {
            'sample_id': sample_id,
            'senescent_nn_mean_distance': np.nan,
            'non_senescent_nn_mean_distance': np.nan,
            'senescent_to_senescent_ratio': np.nan,
            'mixed_neighbor_percentage': np.nan,
            'senescent_percentage': np.nan,
            'avg_senescent_per_neighborhood': np.nan
        }

    # Find senescent and non-senescent cells
    senescent_mask = classifications == 'Senescent'
    non_senescent_mask = classifications == 'Non-senescent'

    # Calculate percentage of senescent cells
    senescent_percentage = np.mean(senescent_mask) * 100

    # Calculate nearest neighbors for each cell
    nn = NearestNeighbors(n_neighbors=min(6, len(centroids)))
    nn.fit(centroids)
    distances, indices = nn.kneighbors(centroids)

    # First neighbor is the cell itself (distance=0), so use the second one
    nearest_dists = distances[:, 1]
    nearest_indices = indices[:, 1]

    # Calculate mean nearest neighbor distance for each cell type
    sen_nn_dists = nearest_dists[senescent_mask]
    non_sen_nn_dists = nearest_dists[non_senescent_mask]

    sen_nn_mean = np.mean(sen_nn_dists) if len(sen_nn_dists) > 0 else np.nan
    non_sen_nn_mean = np.mean(non_sen_nn_dists) if len(non_sen_nn_dists) > 0 else np.nan

    # Calculate neighbor relationships
    # Check if senescent cells tend to have senescent neighbors
    sen_to_sen_neighbors = 0
    total_sen_neighbors = 0

    for i, cell in enumerate(cell_props):
        if cell['classification'] == 'Senescent':
            neighbor_idx = nearest_indices[i]
            if classifications[neighbor_idx] == 'Senescent':
                sen_to_sen_neighbors += 1
            total_sen_neighbors += 1

    # Calculate ratio of senescent-to-senescent neighbors
    sen_to_sen_ratio = sen_to_sen_neighbors / total_sen_neighbors if total_sen_neighbors > 0 else np.nan

    # Calculate percentage of mixed neighborhoods
    # (neighborhoods with both senescent and non-senescent cells)
    mixed_neighborhoods = 0
    total_neighborhoods = len(centroids)

    # Consider neighborhood as 5 nearest neighbors
    k = min(5, len(centroids) - 1)
    for i in range(len(centroids)):
        neighborhood = indices[i, 1:k+1]  # Skip first (self)
        neighbor_types = classifications[neighborhood]
        if np.any(neighbor_types == 'Senescent') and np.any(neighbor_types == 'Non-senescent'):
            mixed_neighborhoods += 1

    mixed_percentage = (mixed_neighborhoods / total_neighborhoods) * 100 if total_neighborhoods > 0 else np.nan

    # Calculate average number of senescent cells per neighborhood
    sen_per_neighborhood = []
    for i in range(len(centroids)):
        neighborhood = indices[i, :k+1]  # Include self
        neighbor_types = classifications[neighborhood]
        sen_per_neighborhood.append(np.sum(neighbor_types == 'Senescent'))

    avg_sen_per_neighborhood = np.mean(sen_per_neighborhood)

    # Return statistics
    return {
        'sample_id': sample_id,
        'senescent_nn_mean_distance': sen_nn_mean,
        'non_senescent_nn_mean_distance': non_sen_nn_mean,
        'senescent_to_senescent_ratio': sen_to_sen_ratio,
        'mixed_neighbor_percentage': mixed_percentage,
        'senescent_percentage': senescent_percentage,
        'avg_senescent_per_neighborhood': avg_sen_per_neighborhood
    }

def calculate_spatial_statistics(cell_props, sample_id):
    """
    Calculate spatial statistics for senescent vs non-senescent cells.

    Parameters:
    -----------
    cell_props : list of dicts
        List of cell properties with classifications
    sample_id : str
        Sample identifier

    Returns:
    --------
    stats : dict
        Dictionary of spatial statistics
    """
    # Extract cell centroids and classifications
    centroids = np.array([cell['centroid'] for cell in cell_props])
    classifications = np.array([cell['classification'] for cell in cell_props])

    # Skip if too few cells
    if len(centroids) < 5:
        return {
            'sample_id': sample_id,
            'nn_ratio': np.nan,
            'ripley_k_diff': np.nan,
            'sen_clustering_index': np.nan,
            'sen_dispersion_index': np.nan,
            'moran_i': np.nan,
            'geary_c': np.nan
        }

    # Find senescent and non-senescent cells
    senescent_mask = classifications == 'Senescent'
    non_senescent_mask = classifications == 'Non-senescent'

    senescent_centroids = centroids[senescent_mask]
    non_senescent_centroids = centroids[non_senescent_mask]

    # Calculate nearest neighbor ratio (Clark-Evans R statistic)
    def calculate_clark_evans_r(points):
        if len(points) < 2:
            return np.nan

        # Calculate mean nearest neighbor distance
        nn = NearestNeighbors(n_neighbors=2)  # 2 because first neighbor is self
        nn.fit(points)
        distances, _ = nn.kneighbors(points)
        mean_nn_dist = np.mean(distances[:, 1])

        # Calculate point density using bounding box approximation
        min_x, min_y = np.min(points, axis=0)
        max_x, max_y = np.max(points, axis=0)
        area = (max_x - min_x) * (max_y - min_y)

        if area == 0:  # Handle edge case
            return np.nan

        density = len(points) / area

        # Expected mean distance for random distribution
        expected_mean_dist = 0.5 / np.sqrt(density)

        # Clark-Evans R statistic
        r = mean_nn_dist / expected_mean_dist if expected_mean_dist > 0 else np.nan

        return r

    # Calculate Clark-Evans R for senescent cells
    nn_ratio = calculate_clark_evans_r(senescent_centroids)

    # Calculate a simplified Ripley's K function for detecting clustering
    def simplified_ripley_k(points, r_max, n_steps=10):
        if len(points) < 5:
            return np.zeros(n_steps)

        n_points = len(points)

        # Calculate bounding box area
        min_x, min_y = np.min(points, axis=0)
        max_x, max_y = np.max(points, axis=0)
        area = (max_x - min_x) * (max_y - min_y)

        if area == 0:  # Handle edge case
            return np.zeros(n_steps)

        # Calculate distances between all pairs of points
        distances = spatial.distance.pdist(points)

        # Calculate K function for different radii
        r_values = np.linspace(0, r_max, n_steps)
        k_values = np.zeros(n_steps)

        for i, r in enumerate(r_values):
            # Count pairs with distance <= r
            count = np.sum(distances <= r)
            # Ripley's K formula (simplified)
            k = (area * count) / (n_points * (n_points - 1)) if n_points > 1 else 0
            k_values[i] = k

        return k_values

    # Calculate max radius as half the maximum dimension of the bounding box
    if len(centroids) > 0:
        min_x, min_y = np.min(centroids, axis=0)
        max_x, max_y = np.max(centroids, axis=0)
        r_max = max(max_x - min_x, max_y - min_y) / 2
    else:
        r_max = 100  # Default value

    # Calculate Ripley's K for senescent and non-senescent cells
    k_sen = simplified_ripley_k(senescent_centroids, r_max)
    k_non_sen = simplified_ripley_k(non_senescent_centroids, r_max)

    # Calculate difference in K functions (positive indicates more clustering in senescent)
    ripley_k_diff = np.mean(k_sen - k_non_sen) if len(k_sen) > 0 and len(k_non_sen) > 0 else np.nan

    # Create binary indicator for senescent cells (1 = senescent, 0 = non-senescent)
    values = np.where(classifications == 'Senescent', 1, 0)

    # Calculate spatial autocorrelation (simplified)
    moran_i = np.nan
    geary_c = np.nan

    # Find neighbors for each cell
    k = min(5, len(centroids) - 1)
    if k > 0:
        nn = NearestNeighbors(n_neighbors=k+1)  # +1 to include self
        nn.fit(centroids)
        _, indices = nn.kneighbors(centroids)

        # Calculate senescent clustering index
        observed_sen_neighbors = 0
        total_neighbors = 0
        sen_ratio = np.mean(values)

        for i, cell_type in enumerate(values):
            if cell_type == 1:  # If senescent
                # Get neighbors excluding self
                neighbors = indices[i, 1:]
                observed_sen_neighbors += np.sum(values[neighbors])
                total_neighbors += len(neighbors)

        # Expected number of senescent neighbors based on random distribution
        expected_sen_neighbors = total_neighbors * sen_ratio if sen_ratio > 0 else 0

        # Clustering index: ratio of observed to expected
        sen_clustering_index = observed_sen_neighbors / expected_sen_neighbors if expected_sen_neighbors > 0 else np.nan

        # Dispersion index - coefficient of variation of nearest neighbor distances
        if len(senescent_centroids) >= 2:
            nn_sen = NearestNeighbors(n_neighbors=2)
            nn_sen.fit(senescent_centroids)
            sen_dists, _ = nn_sen.kneighbors(senescent_centroids)
            sen_dists = sen_dists[:, 1]  # Exclude self

            sen_dispersion_index = np.std(sen_dists) / np.mean(sen_dists) if np.mean(sen_dists) > 0 else np.nan
        else:
            sen_dispersion_index = np.nan
    else:
        sen_clustering_index = np.nan
        sen_dispersion_index = np.nan

    # Return results
    return {
        'sample_id': sample_id,
        'nn_ratio': nn_ratio,
        'ripley_k_diff': ripley_k_diff,
        'sen_clustering_index': sen_clustering_index,
        'sen_dispersion_index': sen_dispersion_index,
        'moran_i': moran_i,
        'geary_c': geary_c
    }

def visualize_cell_positions(cell_props, sample_id, output_dir):
    """
    Create a visualization of cell positions with senescent/non-senescent color coding.

    Parameters:
    -----------
    cell_props : list of dicts
        List of cell properties with classifications
    sample_id : str
        Sample identifier
    output_dir : str
        Directory to save visualization
    """
    # Extract cell centroids and classifications
    centroids = np.array([cell['centroid'] for cell in cell_props])
    classifications = np.array([cell['classification'] for cell in cell_props])
    areas = np.array([cell['area'] for cell in cell_props])

    # Skip if too few cells
    if len(centroids) < 2:
        return

    # Create scatter plot of cell positions
    plt.figure(figsize=(10, 10))

    # Create color map for classification
    color_map = {'Senescent': 'red', 'Non-senescent': 'blue', 'Unknown': 'gray'}
    colors = [color_map[c] for c in classifications]

    # Scale point sizes based on cell area
    min_area = np.min(areas)
    max_area = np.max(areas)
    normalized_areas = 50 * (areas - min_area) / (max_area - min_area) + 20

    # Plot cell positions with colors indicating classification
    scatter = plt.scatter(
        centroids[:, 1],  # x coordinate (column 1)
        centroids[:, 0],  # y coordinate (column 0)
        c=colors,
        s=normalized_areas,
        alpha=0.7,
        edgecolor='black',
        linewidth=0.5
    )

    # Invert y-axis to match image coordinates
    plt.gca().invert_yaxis()

    # Add legend
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red',
                  markersize=10, label='Senescent'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue',
                  markersize=10, label='Non-senescent')
    ]
    plt.legend(handles=legend_elements, loc='upper right')

    plt.title(f'Cell Positions - {sample_id}', fontsize=14)
    plt.xlabel('X Position', fontsize=12)
    plt.ylabel('Y Position', fontsize=12)
    plt.grid(True, alpha=0.3)

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{sample_id}_cell_positions.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Create additional visualization showing nearest neighbor connections
    plt.figure(figsize=(12, 12))

    # Plot cell positions again
    plt.scatter(
        centroids[:, 1],  # x coordinate
        centroids[:, 0],  # y coordinate
        c=colors,
        s=normalized_areas,
        alpha=0.7,
        edgecolor='black',
        linewidth=0.5
    )

    # Connect nearest neighbors with lines
    k = min(3, len(centroids) - 1)
    if k > 0:
        nn = NearestNeighbors(n_neighbors=k+1)  # +1 to include self
        nn.fit(centroids)
        distances, indices = nn.kneighbors(centroids)

        # Draw lines between neighbors
        for i, neighbors in enumerate(indices):
            # Skip first neighbor (self)
            for j in neighbors[1:]:
                # Draw line with alpha based on distance
                plt.plot(
                    [centroids[i, 1], centroids[j, 1]],  # x coordinates
                    [centroids[i, 0], centroids[j, 0]],  # y coordinates
                    color='gray',
                    alpha=0.3,
                    linewidth=0.5
                )

    # Invert y-axis
    plt.gca().invert_yaxis()

    # Add legend
    plt.legend(handles=legend_elements, loc='upper right')

    plt.title(f'Cell Neighborhoods - {sample_id}', fontsize=14)
    plt.xlabel('X Position', fontsize=12)
    plt.ylabel('Y Position', fontsize=12)
    plt.grid(True, alpha=0.3)

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{sample_id}_cell_neighborhoods.png'), dpi=300, bbox_inches='tight')
    plt.close()

def visualize_neighborhoods(cell_props, sample_id, output_dir):
    """
    Create a network visualization showing relationships between cells.

    Parameters:
    -----------
    cell_props : list of dicts
        List of cell properties with classifications
    sample_id : str
        Sample identifier
    output_dir : str
        Directory to save visualization
    """
    # Extract cell centroids and classifications
    centroids = np.array([cell['centroid'] for cell in cell_props])
    classifications = np.array([cell['classification'] for cell in cell_props])

    # Skip if too few cells
    if len(centroids) < 5:
        return

    # Build a graph to represent cell neighborhoods
    G = nx.Graph()

    # Add nodes with positions and classifications
    for i, (pos, classification) in enumerate(zip(centroids, classifications)):
        G.add_node(i, pos=(pos[1], -pos[0]), classification=classification)

    # Connect nearest neighbors
    k = min(5, len(centroids) - 1)
    if k > 0:
        nn = NearestNeighbors(n_neighbors=k+1)  # +1 to include self
        nn.fit(centroids)
        distances, indices = nn.kneighbors(centroids)

        # Add edges
        for i, neighbors in enumerate(indices):
            # Skip first neighbor (self)
            for j, d in zip(neighbors[1:], distances[i, 1:]):
                G.add_edge(i, j, weight=1.0/d if d > 0 else 1.0)

    # Create figure
    plt.figure(figsize=(14, 12))

    # Get node positions
    pos = nx.get_node_attributes(G, 'pos')

    # Get node colors based on classification
    node_colors = []
    for node in G.nodes():
        if G.nodes[node]['classification'] == 'Senescent':
            node_colors.append('red')
        elif G.nodes[node]['classification'] == 'Non-senescent':
            node_colors.append('blue')
        else:
            node_colors.append('gray')

    # Draw the graph
    nx.draw_networkx_nodes(
        G, pos,
        node_color=node_colors,
        node_size=80,
        alpha=0.8,
        edgecolors='black',
        linewidths=0.5
    )

    # Categorize edges based on node classifications
    sen_to_sen_edges = []
    non_sen_to_non_sen_edges = []
    mixed_edges = []

    for u, v in G.edges():
        u_class = G.nodes[u]['classification']
        v_class = G.nodes[v]['classification']

        if u_class == 'Senescent' and v_class == 'Senescent':
            sen_to_sen_edges.append((u, v))
        elif u_class == 'Non-senescent' and v_class == 'Non-senescent':
            non_sen_to_non_sen_edges.append((u, v))
        else:
            mixed_edges.append((u, v))

    # Draw edges with different colors
    nx.draw_networkx_edges(
        G, pos,
        edgelist=sen_to_sen_edges,
        width=1.0,
        alpha=0.7,
        edge_color='red'
    )

    nx.draw_networkx_edges(
        G, pos,
        edgelist=non_sen_to_non_sen_edges,
        width=1.0,
        alpha=0.7,
        edge_color='blue'
    )

    nx.draw_networkx_edges(
        G, pos,
        edgelist=mixed_edges,
        width=0.5,
        alpha=0.3,
        edge_color='purple',
        style='dashed'
    )

    # Add legend
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red',
                  markersize=10, label='Senescent Cell'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue',
                  markersize=10, label='Non-senescent Cell'),
        plt.Line2D([0], [0], color='red', lw=2, label='Senescent-Senescent Connection'),
        plt.Line2D([0], [0], color='blue', lw=2, label='Non-senescent-Non-senescent Connection'),
        plt.Line2D([0], [0], color='purple', lw=2, linestyle='--', label='Mixed Connection')
    ]
    plt.legend(handles=legend_elements, loc='upper right')

    plt.title(f'Cell Interaction Network - {sample_id}', fontsize=14)
    plt.axis('off')

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{sample_id}_cell_network.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Create a density visualization showing clustering of senescent cells
    plt.figure(figsize=(12, 10))

    # Get just the senescent cell positions
    senescent_positions = [pos[i] for i, c in enumerate(classifications) if c == 'Senescent']
    non_senescent_positions = [pos[i] for i, c in enumerate(classifications) if c == 'Non-senescent']

    if senescent_positions and len(senescent_positions) > 5:
        senescent_positions = np.array(senescent_positions)

        # Plot heatmap for senescent cells
        try:
            # Create a 2D histogram
            heatmap, xedges, yedges = np.histogram2d(
                [p[1] for p in senescent_positions],
                [-p[0] for p in senescent_positions],
                bins=50
            )

            # Smooth the heatmap
            heatmap = ndimage.gaussian_filter(heatmap, sigma=1.5)

            # Plot the heatmap with transparency
            extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
            plt.imshow(heatmap.T, extent=extent, origin='lower', cmap='Reds', alpha=0.7)

            # Add contour lines
            plt.contour(heatmap.T, extent=extent, colors='red', alpha=0.5, levels=5)
        except Exception as e:
            print(f"Warning: Could not create heatmap for {sample_id}: {e}")

    # Plot all cell positions
    plt.scatter(
        [p[1] for p in centroids],
        [-p[0] for p in centroids],
        c=[{'Senescent': 'red', 'Non-senescent': 'blue', 'Unknown': 'gray'}[c] for c in classifications],
        s=30,
        alpha=0.7,
        edgecolor='black',
        linewidth=0.5
    )

    plt.title(f'Senescent Cell Density - {sample_id}', fontsize=14)
    plt.xlabel('X Position', fontsize=12)
    plt.ylabel('Y Position', fontsize=12)
    plt.colorbar(label='Senescent Cell Density')

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{sample_id}_senescent_density.png'), dpi=300, bbox_inches='tight')
    plt.close()

def create_summary_visualizations(spatial_df, neighbor_df, output_dir):
    """
    Create summary visualizations of spatial statistics across all samples.

    Parameters:
    -----------
    spatial_df : DataFrame
        DataFrame with spatial statistics
    neighbor_df : DataFrame
        DataFrame with neighbor statistics
    output_dir : str
        Directory to save visualizations
    """
    if spatial_df.empty or neighbor_df.empty:
        return

    # Merge the dataframes
    df = pd.merge(spatial_df, neighbor_df, on='sample_id', how='outer')

    # Calculate overall senescent clustering pattern
    # 1. Plot clustering index vs percentage of senescent cells
    plt.figure(figsize=(10, 8))

    plt.scatter(
        df['senescent_percentage'],
        df['sen_clustering_index'],
        s=80,
        alpha=0.7,
        c=df['senescent_percentage'],
        cmap='RdYlBu_r',
        edgecolor='black'
    )

    # Add trendline
    if len(df) > 1:
        try:
            z = np.polyfit(df['senescent_percentage'], df['sen_clustering_index'], 1)
            p = np.poly1d(z)
            plt.plot(
                np.sort(df['senescent_percentage']),
                p(np.sort(df['senescent_percentage'])),
                "r--",
                alpha=0.8
            )

            # Calculate correlation
            corr = df['senescent_percentage'].corr(df['sen_clustering_index'])
            plt.text(
                0.05, 0.95,
                f'Correlation: {corr:.2f}',
                transform=plt.gca().transAxes,
                fontsize=12,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
            )
        except Exception as e:
            print(f"Warning: Could not create trendline: {e}")

    # Add reference line for random distribution
    plt.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

    plt.title('Senescent Cell Clustering vs Percentage', fontsize=14)
    plt.xlabel('Percentage of Senescent Cells (%)', fontsize=12)
    plt.ylabel('Clustering Index\n(>1 = Clustered, <1 = Dispersed)', fontsize=12)
    plt.colorbar(label='Percentage of Senescent Cells (%)')
    plt.grid(True, alpha=0.3)

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'senescent_clustering_vs_percentage.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Create a spatial pattern summary for all samples
    plt.figure(figsize=(12, 10))

    # Prepare data for boxplot
    boxplot_data = [
        df['senescent_nn_mean_distance'].dropna(),
        df['non_senescent_nn_mean_distance'].dropna(),
    ]

    # Create boxplot
    plt.boxplot(
        boxplot_data,
        labels=['Senescent', 'Non-senescent'],
        patch_artist=True,
        boxprops=dict(facecolor='lightblue'),
        medianprops=dict(color='red'),
        showfliers=False
    )

    # Add individual points
    for i, data in enumerate(boxplot_data):
        x = np.random.normal(i+1, 0.04, size=len(data))
        plt.scatter(x, data, alpha=0.6, s=40, edgecolor='black', linewidth=0.5)

    plt.title('Nearest Neighbor Distances by Cell Type', fontsize=14)
    plt.ylabel('Mean Nearest Neighbor Distance', fontsize=12)
    plt.grid(axis='y', alpha=0.3)

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'nearest_neighbor_distances.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Create a heatmap of correlation between spatial metrics
    plt.figure(figsize=(12, 10))

    # Select relevant columns for correlation
    correlation_cols = [
        'senescent_percentage', 'sen_clustering_index', 'nn_ratio',
        'sen_dispersion_index', 'senescent_to_senescent_ratio',
        'mixed_neighbor_percentage'
    ]

    # Calculate correlation matrix
    corr_matrix = df[correlation_cols].corr()

    # Create heatmap
    sns.heatmap(
        corr_matrix,
        annot=True,
        cmap='coolwarm',
        vmin=-1,
        vmax=1,
        linewidths=0.5,
        fmt='.2f'
    )

    plt.title('Correlation Between Spatial Metrics', fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'spatial_metrics_correlation.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 4. Create a distribution of senescent cell clustering index
    plt.figure(figsize=(10, 8))

    # Create histogram with KDE
    sns.histplot(
        df['sen_clustering_index'].dropna(),
        kde=True,
        stat='density',
        color='skyblue',
        edgecolor='black',
        alpha=0.7
    )

    # Add vertical line at random distribution (1.0)
    plt.axvline(x=1.0, color='red', linestyle='--', alpha=0.7, label='Random Distribution')

    # Add mean line
    mean_value = df['sen_clustering_index'].mean()
    plt.axvline(x=mean_value, color='green', linestyle='-', alpha=0.7, label=f'Mean = {mean_value:.2f}')

    plt.title('Distribution of Senescent Cell Clustering Index', fontsize=14)
    plt.xlabel('Clustering Index\n(>1 = Clustered, <1 = Dispersed)', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)

    # Save the visualization
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'clustering_index_distribution.png'), dpi=300, bbox_inches='tight')
    plt.close()

Here it's possible to visualise the classification

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
from skimage import io, measure, segmentation, color
import seaborn as sns
from scipy import ndimage
import cv2
from tqdm import tqdm

def visualize_classification_on_masks(cell_mask_dir, nuclei_dir, results_csv, output_dir="visualization_results"):
    """
    Visualize the senescent cell classification results by overlaying them on the original masks.

    Parameters:
    -----------
    cell_mask_dir : str
        Directory containing the cell mask files
    nuclei_dir : str
        Directory containing the nuclei mask files
    results_csv : str
        Path to the cell_classification_results.csv file
    output_dir : str
        Directory to save the visualization results
    """
    print("=== Visualization of Senescent Cell Classification ===")

    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load classification results
    results_df = pd.read_csv(results_csv)
    print(f"Loaded {len(results_df)} classified cells")

    # Extract sample IDs and cell IDs from the results
    # Assuming cell_id format is "sample_id_cell_id"
    results_df['original_sample_id'] = results_df['cell_id'].apply(lambda x: '_'.join(x.split('_')[:-1]))
    results_df['original_cell_id'] = results_df['cell_id'].apply(lambda x: int(x.split('_')[-1]))

    # Group by sample ID
    sample_groups = results_df.groupby('original_sample_id')

    # Find and process mask files for each sample
    cell_files = [f for f in os.listdir(cell_mask_dir) if f.endswith(('.tif', '.tiff')) and not f.startswith('.')]
    nuclei_files = [f for f in os.listdir(nuclei_dir) if f.endswith(('.tif', '.tiff')) and not f.startswith('.')]

    # Process each sample in the results
    for sample_id, group in tqdm(sample_groups, desc="Visualizing samples"):
        # Find the corresponding cell and nuclei mask files
        cell_file = next((f for f in cell_files if sample_id in f), None)
        nuclei_file = next((f for f in nuclei_files if sample_id in f), None)

        if cell_file is None or nuclei_file is None:
            print(f"Warning: Could not find mask files for sample {sample_id}")
            continue

        # Load the mask images
        cell_mask_path = os.path.join(cell_mask_dir, cell_file)
        nuclei_mask_path = os.path.join(nuclei_dir, nuclei_file)

        try:
            # Load masks
            cell_mask = load_mask_image(cell_mask_path)
            nuclei_mask = load_mask_image(nuclei_mask_path)

            if cell_mask is None or nuclei_mask is None:
                print(f"Error: Could not load masks for {sample_id}")
                continue

            # Create colored overlay based on classification
            colored_mask = create_classification_overlay(cell_mask, nuclei_mask, group)

            # Save the visualization
            output_file = os.path.join(output_dir, f"{sample_id}_classification_overlay.png")
            plt.imsave(output_file, colored_mask)

            # Create a more detailed visualization with boundaries
            detailed_vis = create_detailed_visualization(cell_mask, nuclei_mask, group)
            detailed_output = os.path.join(output_dir, f"{sample_id}_detailed_classification.png")
            plt.imsave(detailed_output, detailed_vis)

        except Exception as e:
            print(f"Error processing {sample_id}: {str(e)}")

    # Create summary visualizations
    create_summary_visualizations(results_df, output_dir)

    print(f"\nVisualization complete! Results saved to {output_dir}/")

def load_mask_image(filepath):
    """Loads a mask image, ensuring it's properly formatted for visualization."""
    try:
        img = io.imread(filepath)

        # Handle multi-channel images
        if img.ndim > 2:
            if img.shape[2] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[2] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = img[:,:,0]

        # If it's already a labeled mask (values > 1), use it directly
        if np.max(img) > 1:
            return img

        # Otherwise, label the connected components
        labeled_img, num_features = ndimage.label(img > 0)
        return labeled_img

    except Exception as e:
        print(f"Error loading image {filepath}: {str(e)}")
        return None

def create_classification_overlay(cell_mask, nuclei_mask, classification_group):
    """
    Create a colored overlay showing senescent vs non-senescent cells.

    Parameters:
    -----------
    cell_mask : ndarray
        Labeled cell mask
    nuclei_mask : ndarray
        Labeled nuclei mask
    classification_group : DataFrame
        DataFrame containing classification results for this sample

    Returns:
    --------
    colored_mask : ndarray
        RGB image with colored overlay
    """
    # Create a blank RGB image
    h, w = cell_mask.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)

    # Define colors for each class
    # Senescent: Red, Non-senescent: Blue, Background: Black
    senescent_color = np.array([220, 50, 50])     # Red
    non_senescent_color = np.array([50, 50, 220]) # Blue
    nuclei_color = np.array([255, 255, 100])      # Yellow

    # Create mapping of cell IDs to classifications
    cell_classifications = dict(zip(
        classification_group['original_cell_id'],
        classification_group['cell_type']
    ))

    # Get unique cell IDs
    cell_ids = np.unique(cell_mask)
    if 0 in cell_ids:  # Skip background (0)
        cell_ids = cell_ids[cell_ids > 0]

    # Color each cell according to its classification
    for cell_id in cell_ids:
        if cell_id in cell_classifications:
            cell_type = cell_classifications[cell_id]
            if cell_type == 'Senescent':
                colored_mask[cell_mask == cell_id] = senescent_color
            else:
                colored_mask[cell_mask == cell_id] = non_senescent_color

    # Overlay nuclei
    nuclei_ids = np.unique(nuclei_mask)
    if 0 in nuclei_ids:
        nuclei_ids = nuclei_ids[nuclei_ids > 0]

    for nuc_id in nuclei_ids:
        colored_mask[nuclei_mask == nuc_id] = nuclei_color

    return colored_mask

def create_detailed_visualization(cell_mask, nuclei_mask, classification_group):
    """
    Create a detailed visualization with cell boundaries and labels.

    Parameters:
    -----------
    cell_mask : ndarray
        Labeled cell mask
    nuclei_mask : ndarray
        Labeled nuclei mask
    classification_group : DataFrame
        DataFrame containing classification results for this sample

    Returns:
    --------
    detailed_vis : ndarray
        RGB image with boundaries, labels, and classification overlay
    """
    # Create a blank RGB image
    h, w = cell_mask.shape
    detailed_vis = np.zeros((h, w, 3), dtype=np.uint8)

    # Define colors
    senescent_color = np.array([220, 100, 100])    # Lighter red
    non_senescent_color = np.array([100, 100, 220]) # Lighter blue
    boundary_color = np.array([255, 255, 255])      # White
    nuclei_color = np.array([200, 200, 50])         # Lighter yellow

    # Create mapping of cell IDs to classifications
    cell_classifications = dict(zip(
        classification_group['original_cell_id'],
        classification_group['cell_type']
    ))

    # Get unique cell IDs
    cell_ids = np.unique(cell_mask)
    if 0 in cell_ids:  # Skip background (0)
        cell_ids = cell_ids[cell_ids > 0]

    # Color each cell according to its classification
    for cell_id in cell_ids:
        if cell_id in cell_classifications:
            cell_type = cell_classifications[cell_id]
            # Fill cell
            if cell_type == 'Senescent':
                detailed_vis[cell_mask == cell_id] = senescent_color
            else:
                detailed_vis[cell_mask == cell_id] = non_senescent_color

    # Find and highlight cell boundaries
    cell_boundaries = segmentation.find_boundaries(cell_mask, mode='outer')
    detailed_vis[cell_boundaries] = boundary_color

    # Overlay nuclei
    nuclei_ids = np.unique(nuclei_mask)
    if 0 in nuclei_ids:
        nuclei_ids = nuclei_ids[nuclei_ids > 0]

    for nuc_id in nuclei_ids:
        detailed_vis[nuclei_mask == nuc_id] = nuclei_color

    return detailed_vis

def create_summary_visualizations(results_df, output_dir):
    """
    Create summary visualizations of the classification results.

    Parameters:
    -----------
    results_df : DataFrame
        DataFrame containing all classification results
    output_dir : str
        Directory to save the visualization results
    """
    # 1. Improved UMAP visualization
    plt.figure(figsize=(12, 10))

    # Scatter plot with UMAP results colored by cell type
    plt.scatter(
        results_df['umap_x'],
        results_df['umap_y'],
        c=results_df['cell_type'].map({'Senescent': 'red', 'Non-senescent': 'blue'}),
        s=30,
        alpha=0.7,
        edgecolor='k',
        linewidth=0.5
    )

    # Add contour lines to show density
    for cell_type, color in zip(['Senescent', 'Non-senescent'], ['red', 'blue']):
        subset = results_df[results_df['cell_type'] == cell_type]
        if len(subset) > 10:  # Need enough points for kernel density
            sns.kdeplot(
                x=subset['umap_x'],
                y=subset['umap_y'],
                levels=5,
                color=color,
                alpha=0.3,
                linewidths=1
            )

    # Add legend
    plt.legend(
        handles=[
            mpatches.Patch(color='red', label='Senescent'),
            mpatches.Patch(color='blue', label='Non-senescent')
        ],
        title='Cell Type',
        loc='upper right'
    )

    plt.title('UMAP Projection of Cells by Senescence Classification', fontsize=14)
    plt.xlabel('UMAP Dimension 1', fontsize=12)
    plt.ylabel('UMAP Dimension 2', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Save the UMAP visualization
    plt.savefig(os.path.join(output_dir, 'enhanced_umap_clustering.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Feature importance visualization
    key_features = [
        'cell_area', 'nuclei_count', 'avg_nucleus_area',
        'cell_perimeter', 'cell_circularity', 'nucleus_to_cell_area_ratio',
        'cell_eccentricity', 'nucleus_displacement'
    ]

    # Calculate feature differences between cell types
    feature_differences = {}
    for feature in key_features:
        if feature in results_df.columns:
            sen_mean = results_df[results_df['cell_type'] == 'Senescent'][feature].mean()
            non_sen_mean = results_df[results_df['cell_type'] == 'Non-senescent'][feature].mean()

            # Calculate ratio or difference
            if non_sen_mean != 0:
                ratio = sen_mean / non_sen_mean
            else:
                ratio = float('inf') if sen_mean != 0 else 1.0

            feature_differences[feature] = (feature, ratio, sen_mean, non_sen_mean)

    # Sort by ratio (most discriminative first)
    sorted_features = sorted(
        feature_differences.values(),
        key=lambda x: abs(x[1] - 1.0),
        reverse=True
    )

    # Create bar chart of feature importance
    fig, ax = plt.subplots(figsize=(12, 8))

    features = [f[0].replace('_', ' ').title() for f in sorted_features]
    ratios = [f[1] for f in sorted_features]

    # Transform ratios for better visualization
    log_ratios = []
    for r in ratios:
        if r > 1:
            log_ratios.append(r - 1)  # For values > 1, show how much greater than 1
        else:
            log_ratios.append(-(1/r - 1))  # For values < 1, show negative of how much less than 1

    bars = ax.bar(
        features,
        log_ratios,
        color=[
            'red' if r > 0 else 'blue' for r in log_ratios
        ]
    )

    # Add exact ratio values as text
    for i, bar in enumerate(bars):
        height = bar.get_height()
        if height >= 0:
            va = 'bottom'
            offset = 0.05
        else:
            va = 'top'
            offset = -0.1
        ax.text(
            bar.get_x() + bar.get_width()/2,
            height + offset,
            f'Ratio: {ratios[i]:.2f}',
            ha='center',
            va=va,
            rotation=45,
            fontsize=9
        )

    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    plt.title('Feature Importance for Senescence Classification', fontsize=14)
    plt.ylabel('Log Ratio (Senescent / Non-senescent)', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Add legend
    plt.legend(
        handles=[
            mpatches.Patch(color='red', label='Higher in Senescent'),
            mpatches.Patch(color='blue', label='Higher in Non-senescent')
        ],
        loc='upper right'
    )

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_importance.png'), dpi=300)
    plt.close()

    # 3. Create a multi-feature visualization
    plt.figure(figsize=(12, 10))

    # Scatter plot with cell area vs nuclear area, colored by classification
    scatter = plt.scatter(
        results_df['cell_area'],
        results_df['avg_nucleus_area'],
        c=results_df['cell_type'].map({'Senescent': 'red', 'Non-senescent': 'blue'}),
        s=results_df['nuclei_count'] * 10,  # Size by nuclei count
        alpha=0.7,
        edgecolor='k',
        linewidth=0.5
    )

    # Add contour lines to show density
    for cell_type, color in zip(['Senescent', 'Non-senescent'], ['red', 'blue']):
        subset = results_df[results_df['cell_type'] == cell_type]
        if len(subset) > 10:  # Need enough points for kernel density
            sns.kdeplot(
                x=subset['cell_area'],
                y=subset['avg_nucleus_area'],
                levels=3,
                color=color,
                alpha=0.3,
                linewidths=1
            )

    # Create custom legend
    handles = [
        mpatches.Patch(color='red', label='Senescent'),
        mpatches.Patch(color='blue', label='Non-senescent'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
                  markersize=5, label='1 Nucleus'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
                  markersize=10, label='2 Nuclei'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
                  markersize=15, label='3 Nuclei')
    ]
    plt.legend(handles=handles, loc='upper left')

    plt.title('Cell Area vs Nuclear Area with Nuclei Count', fontsize=14)
    plt.xlabel('Cell Area', fontsize=12)
    plt.ylabel('Average Nucleus Area', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.savefig(os.path.join(output_dir, 'multi_feature_visualization.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 4. Boxplot visualization with statistical significance
    plt.figure(figsize=(15, 10))

    # Select up to 6 most discriminative features
    top_features = sorted_features[:6]

    # Create subplot grid
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for i, (feature, ratio, sen_mean, non_sen_mean) in enumerate(top_features):
        if i < 6:  # Limit to 6 features
            ax = axes[i]

            # Create violin plot with individual points
            sns.violinplot(
                x='cell_type',
                y=feature,
                data=results_df,
                palette={'Senescent': 'red', 'Non-senescent': 'blue'},
                ax=ax,
                inner='quartile',
                alpha=0.7
            )

            # Add individual points
            sns.stripplot(
                x='cell_type',
                y=feature,
                data=results_df,
                color='black',
                size=2,
                alpha=0.3,
                ax=ax,
                jitter=True
            )

            # Set title and labels
            ax.set_title(feature.replace('_', ' ').title(), fontsize=12)
            ax.set_xlabel('')

            # Add ratio text
            ax.text(
                0.5, 0.95,
                f'Ratio: {ratio:.2f}',
                transform=ax.transAxes,
                ha='center',
                va='top',
                bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5')
            )

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_distribution_comparison.png'), dpi=300)
    plt.close()

def integrate_with_existing_code(main_func, cell_dir, nuclei_dir, results_csv, output_dir):
    """
    This function shows how to integrate the new visualization code with the existing code.

    Parameters:
    -----------
    main_func : function
        The main function from the original code
    cell_dir : str
        Directory containing cell mask files
    nuclei_dir : str
        Directory containing nuclei mask files
    results_csv : str
        Path to the cell_classification_results.csv file
    output_dir : str
        Directory to save visualization results
    """
    # First run the main analysis if results don't exist yet
    if not os.path.exists(results_csv):
        print("Running senescent cell classification analysis...")
        main_func(cell_dir, nuclei_dir, os.path.dirname(results_csv))

    # Then run the visualization
    print("\nCreating enhanced visualizations...")
    visualize_classification_on_masks(cell_dir, nuclei_dir, results_csv, output_dir)

if __name__ == "__main__":
    # Example usage
    cell_mask_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
    nuclei_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
    results_csv = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/cell_classification_results.csv"
    visualization_dir = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/visualization"

    # If you want to run just the visualization:
    visualize_classification_on_masks(cell_mask_dir, nuclei_dir, results_csv, visualization_dir)

    # If you want to integrate with the original code:
    # from original_code import main
    # integrate_with_existing_code(main, cell_mask_dir, nuclei_dir, results_csv, visualization_dir)

=== Visualization of Senescent Cell Classification ===
Loaded 2472 classified cells


Visualizing samples: 100%|██████████| 8/8 [00:40<00:00,  5.07s/it]

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.violinplot(

Passing


Visualization complete! Results saved to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/visualization/


<Figure size 1500x1000 with 0 Axes>

In [7]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import umap
from tqdm import tqdm

# --- Configuration & Parameters ---
INPUT_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/cell_classification_results.csv" # Path to your existing results
OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V3" # New directory for refined results
# User updated this threshold
MULTINUCLEATION_THRESHOLD = 1 # Cells with nuclei_count > this threshold will be marked senescent

# Quantile for preliminary senescence classification based on the normalized score.
PRELIMINARY_SENESCENCE_QUANTILE = 0.90 # Target ~15% initially senescent by score (before multinucleation rule)

FEATURES_FOR_CLUSTERING = [
    'cell_area',
    'cell_perimeter',
    'cell_eccentricity',
    'cell_circularity',
    'cell_aspect_ratio',
    'avg_nucleus_area',
    'max_nucleus_area',
    'avg_nucleus_eccentricity',
    'nucleus_area_std',
    'nucleus_displacement',
    'nucleus_to_cell_area_ratio'
]

SENESCENCE_SCORE_WEIGHTS = {
    'cell_area': 1.5,
    'cell_perimeter': 0.5,
    'cell_eccentricity': 0.5,
    'cell_circularity': -1.0,
    'cell_aspect_ratio': 0.5,
    'avg_nucleus_area': 1.0,
    'avg_nucleus_eccentricity': 0.3,
    'nucleus_to_cell_area_ratio': -1.5,
    'nuclear_enlargement': 1.0,
    'cell_enlargement': 1.5,
    'nucleus_displacement': 0.2,
}

AREA_FEATURES_TO_LOG = ['cell_area', 'avg_nucleus_area', 'max_nucleus_area', 'cell_perimeter']


def load_and_prepare_data(csv_path):
    """Loads the data and checks for required columns."""
    print(f"Loading data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
        print(f"Successfully loaded {len(df)} cells.")
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}")
        return None

    essential_cols_for_operation = list(SENESCENCE_SCORE_WEIGHTS.keys()) + ['nuclei_count', 'cell_id', 'sample_id']
    all_needed_cols = list(set(FEATURES_FOR_CLUSTERING + essential_cols_for_operation))

    missing_cols = [col for col in all_needed_cols if col not in df.columns]
    if missing_cols:
        missing_scoring_or_rules_cols = [col for col in essential_cols_for_operation if col not in df.columns]
        if missing_scoring_or_rules_cols:
            print(f"Error: Critical columns for scoring/rules are missing: {missing_scoring_or_rules_cols}")
            return None
        print(f"Warning: Some columns listed in FEATURES_FOR_CLUSTERING are missing: {missing_cols}. UMAP/KMeans might be affected if these are used.")

    return df

def calculate_senescence_score(df, score_weights):
    """Calculates a per-cell senescence score."""
    print("Calculating per-cell senescence score...")

    features_for_scoring_present = [f for f in score_weights.keys() if f in df.columns]
    if not features_for_scoring_present:
        print("Error: No features for senescence score calculation are present in the DataFrame.")
        df['senescence_score'] = np.nan
        df['senescence_score_normalized'] = np.nan
        return df

    score_df = df[features_for_scoring_present].copy()

    for col in AREA_FEATURES_TO_LOG:
        if col in score_df.columns:
            score_df[col] = np.log1p(score_df[col])
            print(f"  Log-transformed scoring feature: {col}")

    scaler = StandardScaler()
    numeric_score_cols = score_df.select_dtypes(include=np.number).columns
    if not numeric_score_cols.empty:
        score_features_standardized = scaler.fit_transform(score_df[numeric_score_cols])
        score_features_standardized_df = pd.DataFrame(score_features_standardized, columns=numeric_score_cols, index=score_df.index)
    else:
        print("  Warning: No numeric columns found for standardization in scoring features.")
        score_features_standardized_df = pd.DataFrame(index=score_df.index)

    df['senescence_score'] = 0.0
    for feature, weight in score_weights.items():
        if feature in score_features_standardized_df.columns:
            df['senescence_score'] += score_features_standardized_df[feature] * weight
        elif feature in df.columns:
            print(f"  Warning: Scoring feature '{feature}' was not in standardized set, using original value (ensure this is intended).")
        else:
             print(f"  Warning: Feature '{feature}' for scoring not found. Skipping.")

    if df['senescence_score'].isna().all() or (df['senescence_score'].max() == df['senescence_score'].min()):
        print("  Warning: Senescence scores are all NaN or uniform. Normalization will result in NaN or 0.")
        df['senescence_score_normalized'] = np.nan if df['senescence_score'].isna().all() else 0.0
    else:
        df['senescence_score_normalized'] = (df['senescence_score'] - df['senescence_score'].min()) / \
                                           (df['senescence_score'].max() - df['senescence_score'].min())
    print("Senescence score calculation complete.")
    return df

def perform_refined_clustering(df, feature_columns_for_clustering, umap_n_neighbors=30, umap_min_dist=0.1, umap_random_state=42, kmeans_n_clusters=2, kmeans_random_state=42):
    """
    Performs UMAP and k-Means clustering on selected features.
    """
    print(f"\nPerforming refined clustering using features: {feature_columns_for_clustering}...")

    actual_clustering_features = [col for col in feature_columns_for_clustering if col in df.columns]
    if not actual_clustering_features:
        print("Error: None of the specified FEATURES_FOR_CLUSTERING are present in the DataFrame. Skipping UMAP/KMeans.")
        df['umap_x_refined'] = np.nan
        df['umap_y_refined'] = np.nan
        df['cluster_refined'] = np.nan
        return df

    features_for_clustering_df = df[actual_clustering_features].copy()

    for col in actual_clustering_features:
        if not pd.api.types.is_numeric_dtype(features_for_clustering_df[col]):
            try:
                features_for_clustering_df[col] = pd.to_numeric(features_for_clustering_df[col])
                print(f"  Column {col} converted to numeric for clustering.")
            except ValueError:
                print(f"  Warning: Column {col} for clustering is not numeric and could not be converted. It will be excluded.")
                features_for_clustering_df = features_for_clustering_df.drop(columns=[col])
                actual_clustering_features.remove(col)

    features_for_clustering_df = features_for_clustering_df.fillna(features_for_clustering_df.mean())

    if features_for_clustering_df.empty or features_for_clustering_df.shape[1] == 0:
        print("Error: No valid numeric features available for clustering after processing. Skipping UMAP/KMeans.")
        df['umap_x_refined'] = np.nan
        df['umap_y_refined'] = np.nan
        df['cluster_refined'] = np.nan
        return df

    print("\nApplying log transformation to selected area/perimeter features for clustering...")
    for col in AREA_FEATURES_TO_LOG:
        if col in features_for_clustering_df.columns:
            features_for_clustering_df[col] = np.log1p(features_for_clustering_df[col])
            print(f"  Log-transformed clustering feature: {col}")

    print("\nStandardizing features for clustering...")
    scaler = StandardScaler()
    features_standardized = scaler.fit_transform(features_for_clustering_df)

    print("\nPerforming UMAP reduction...")
    actual_umap_n_neighbors = min(umap_n_neighbors, len(features_standardized) - 1)
    if actual_umap_n_neighbors < 2 :
        print(f"  Warning: Not enough samples ({len(features_standardized)}) for UMAP with n_neighbors={umap_n_neighbors}. Skipping UMAP and KMeans.")
        df['umap_x_refined'] = np.nan
        df['umap_y_refined'] = np.nan
        df['cluster_refined'] = np.nan
        return df

    print(f"  Using UMAP n_neighbors: {actual_umap_n_neighbors}")
    reducer = umap.UMAP(n_neighbors=actual_umap_n_neighbors, min_dist=umap_min_dist, random_state=umap_random_state, n_components=2)
    embedding = reducer.fit_transform(features_standardized)
    df['umap_x_refined'] = embedding[:, 0]
    df['umap_y_refined'] = embedding[:, 1]

    print(f"\nPerforming k-Means clustering (k={kmeans_n_clusters})...")
    if len(embedding) >= kmeans_n_clusters:
        kmeans = KMeans(n_clusters=kmeans_n_clusters, random_state=kmeans_random_state, n_init='auto')
        df['cluster_refined'] = kmeans.fit_predict(embedding)
    else:
        print(f"  Warning: Not enough samples ({len(embedding)}) for KMeans with n_clusters={kmeans_n_clusters}. Skipping KMeans.")
        df['cluster_refined'] = np.nan

    print("Refined clustering complete.")
    return df

def classify_cells(df, preliminary_quantile, multinucleation_threshold):
    """Assigns preliminary and final cell types based on senescence score and multinucleation."""
    print("\nClassifying cells...")

    if 'senescence_score_normalized' not in df.columns or df['senescence_score_normalized'].isna().all():
        print("  Error: 'senescence_score_normalized' is missing or all NaN. Cannot perform preliminary classification.")
        df['cell_type_preliminary'] = 'Unknown'
        prelim_sen_count_by_score = 0
    else:
        score_threshold = df['senescence_score_normalized'].quantile(preliminary_quantile)
        print(f"  Using senescence score quantile {preliminary_quantile} (threshold = {score_threshold:.4f}) for preliminary classification.")
        df['cell_type_preliminary'] = df['senescence_score_normalized'].apply(
            lambda x: 'Senescent' if x >= score_threshold else 'Non-senescent'
        )
        prelim_sen_count_by_score = (df['cell_type_preliminary'] == 'Senescent').sum()
        print(f"  {prelim_sen_count_by_score} cells ({prelim_sen_count_by_score/len(df)*100:.2f}%) preliminarily classified as Senescent by score.")

    # --- Detailed breakdown for multinucleation ---
    df['cell_type_final'] = df['cell_type_preliminary'] # Start final classification from preliminary

    if 'nuclei_count' in df.columns:
        multinucleated_mask = df['nuclei_count'] > multinucleation_threshold
        total_multinucleated = multinucleated_mask.sum()

        print(f"\n--- Multinucleation Rule (nuclei_count > {multinucleation_threshold}) ---")
        print(f"  Total cells considered polynucleated: {total_multinucleated} ({total_multinucleated/len(df)*100:.2f}%)")

        if total_multinucleated > 0:
            # Cells that are polynucleated AND were already called Senescent by score
            multinucleated_and_sen_by_score = (multinucleated_mask & (df['cell_type_preliminary'] == 'Senescent')).sum()
            print(f"    Of these {total_multinucleated} polynucleated cells:")
            print(f"      - {multinucleated_and_sen_by_score} were ALREADY 'Senescent' by score.")

            # Cells that are polynucleated AND were Non-senescent by score (these will be reclassified)
            multinucleated_reclassified_to_senescent = (multinucleated_mask & (df['cell_type_preliminary'] == 'Non-senescent')).sum()
            print(f"      - {multinucleated_reclassified_to_senescent} were 'Non-senescent' by score and are NOW RECLASSIFIED to 'Senescent'.")

        # Apply the rule: all multinucleated cells are finally 'Senescent'
        df.loc[multinucleated_mask, 'cell_type_final'] = 'Senescent'
    else:
        print("  Warning: 'nuclei_count' column not found. Cannot apply multinucleation rule.")

    final_sen_count = (df['cell_type_final'] == 'Senescent').sum()
    print(f"\nTotal cells finally classified as Senescent: {final_sen_count} ({final_sen_count/len(df)*100:.2f}%)")
    print("Cell classification complete.")
    return df

def visualize_refined_results(df, output_dir):
    """Generates and saves visualizations for the refined analysis."""
    print("\nGenerating visualizations for refined results...")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 1. UMAP colored by Senescence Score (Continuous)
    if 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns and df['umap_x_refined'].notna().any() and 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
        plt.figure(figsize=(12, 10))
        scatter = plt.scatter(
            df['umap_x_refined'], df['umap_y_refined'],
            c=df['senescence_score_normalized'],
            cmap='viridis', s=15, alpha=0.7
        )
        plt.colorbar(scatter, label='Normalized Senescence Score')
        plt.title('Refined UMAP: Cells Colored by Senescence Score', fontsize=16)
        plt.xlabel('UMAP Dimension 1 (Refined)', fontsize=12)
        plt.ylabel('UMAP Dimension 2 (Refined)', fontsize=12)
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig(os.path.join(output_dir, 'umap_refined_by_senescence_score.png'), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        print("  Skipping UMAP by senescence score plot (UMAP data or score data missing/all NaN).")

    # 2. UMAP colored by Final Cell Type (Binary/Categorical)
    if 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns and df['umap_x_refined'].notna().any() and 'cell_type_final' in df.columns:
        plt.figure(figsize=(12, 10))
        unique_cell_types = df['cell_type_final'].unique()
        palette = {ctype: ('red' if ctype == 'Senescent' else ('blue' if ctype == 'Non-senescent' else 'grey')) for ctype in unique_cell_types}

        for cell_type, color in palette.items():
            subset = df[df['cell_type_final'] == cell_type]
            if not subset.empty:
                 plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=cell_type, color=color, s=15, alpha=0.7)

        plt.title('Refined UMAP: Cells Colored by Final Classification', fontsize=16)
        plt.xlabel('UMAP Dimension 1 (Refined)', fontsize=12)
        plt.ylabel('UMAP Dimension 2 (Refined)', fontsize=12)
        if palette:
            plt.legend(title='Final Cell Type')
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig(os.path.join(output_dir, 'umap_refined_by_final_cell_type.png'), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        print("  Skipping UMAP by final cell type plot (UMAP or classification data missing).")

    # 3. Distribution of Senescence Score
    if 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
        plt.figure(figsize=(10, 6))
        sns.histplot(df['senescence_score_normalized'].dropna(), kde=True, bins=50)
        plt.title('Distribution of Normalized Senescence Score', fontsize=16)
        plt.xlabel('Normalized Senescence Score', fontsize=12)
        plt.ylabel('Frequency', fontsize=12)
        if 'cell_type_preliminary' in df.columns:
            score_threshold_val = df['senescence_score_normalized'].quantile(PRELIMINARY_SENESCENCE_QUANTILE)
            plt.axvline(score_threshold_val, color='r', linestyle='--', label=f'Quantile Threshold ({PRELIMINARY_SENESCENCE_QUANTILE*100:.0f}th percentile)')
            plt.legend()
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig(os.path.join(output_dir, 'senescence_score_distribution.png'), dpi=300, bbox_inches='tight')
        plt.close()
    else:
        print("  Skipping senescence score distribution plot (score data missing or all NaN).")

    # 4. Key Feature Comparison for Final Cell Types
    if 'cell_type_final' in df.columns and df['cell_type_final'].nunique() > 1:
        key_comparison_features = [
            'cell_area', 'avg_nucleus_area', 'cell_circularity',
            'nucleus_to_cell_area_ratio', 'senescence_score_normalized', 'nuclei_count'
        ]
        key_comparison_features = [f for f in key_comparison_features if f in df.columns and df[f].notna().any()]

        if key_comparison_features:
            num_features_to_plot = len(key_comparison_features)
            if num_features_to_plot > 0:
                cols_subplot = 3
                rows_subplot = (num_features_to_plot + cols_subplot - 1) // cols_subplot
                fig, axes = plt.subplots(rows_subplot, cols_subplot, figsize=(5 * cols_subplot, 5 * rows_subplot), squeeze=False)
                axes = axes.flatten()

                current_palette = palette if 'palette' in locals() else 'coolwarm'
                order = sorted(df['cell_type_final'].unique())


                for i, feature in enumerate(key_comparison_features):
                    sns.boxplot(x='cell_type_final', y=feature, data=df, ax=axes[i], palette=current_palette, order=order)
                    axes[i].set_title(feature.replace('_', ' ').title(), fontsize=14)
                    axes[i].set_xlabel('Final Cell Type', fontsize=10)
                    axes[i].set_ylabel(feature.replace('_', ' ').title(), fontsize=10)

                for j in range(i + 1, len(axes)):
                    fig.delaxes(axes[j])

                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'feature_comparison_by_final_type.png'), dpi=300, bbox_inches='tight')
                plt.close()
        else:
            print("  Skipping feature comparison plot (no key features with valid data found).")
    else:
        print("  Skipping feature comparison plot (final cell type data insufficient).")

    print("Visualizations saved.")

def main_refined_analysis():
    """Main function to run the refined senescence analysis."""

    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"Created output directory: {OUTPUT_DIR}")

    df = load_and_prepare_data(INPUT_CSV_PATH)
    if df is None:
        return

    df = calculate_senescence_score(df, SENESCENCE_SCORE_WEIGHTS)
    df = perform_refined_clustering(df, FEATURES_FOR_CLUSTERING)

    # Call classify_cells with the globally defined MULTINUCLEATION_THRESHOLD
    df = classify_cells(df, PRELIMINARY_SENESCENCE_QUANTILE, MULTINUCLEATION_THRESHOLD)

    output_csv_path = os.path.join(OUTPUT_DIR, 'cell_classification_results_refined.csv')
    cols_to_save = [col for col in df.columns if col not in ['umap_x', 'umap_y', 'cluster', 'cell_type']]

    if 'umap_x_refined' in df.columns and df['umap_x_refined'].isna().all():
        cols_to_save = [col for col in cols_to_save if col not in ['umap_x_refined', 'umap_y_refined']]
    if 'cluster_refined' in df.columns and df['cluster_refined'].isna().all():
        cols_to_save = [col for col in cols_to_save if col not in ['cluster_refined']]

    df_to_save = df[cols_to_save]
    df_to_save.to_csv(output_csv_path, index=False)
    print(f"\nRefined results saved to: {output_csv_path}")

    visualize_refined_results(df, OUTPUT_DIR)

    print("\nRefined analysis complete!")

if __name__ == '__main__':
    main_refined_analysis()


Loading data from /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/cell_classification_results.csv...
Successfully loaded 2472 cells.
Calculating per-cell senescence score...
  Log-transformed scoring feature: cell_area
  Log-transformed scoring feature: avg_nucleus_area
  Log-transformed scoring feature: cell_perimeter
Senescence score calculation complete.

Performing refined clustering using features: ['cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity', 'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area', 'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement', 'nucleus_to_cell_area_ratio']...

Applying log transformation to selected area/perimeter features for clustering...
  Log-transformed clustering feature: cell_area
  Log-transformed clustering feature: avg_nucleus_area
  Log-transformed clustering feature: max_nucleus_area
  Log-transformed clustering feature: cell_perimeter

Standardizing featu

  warn(



Performing k-Means clustering (k=2)...
Refined clustering complete.

Classifying cells...
  Using senescence score quantile 0.9 (threshold = 0.5531) for preliminary classification.
  248 cells (10.03%) preliminarily classified as Senescent by score.

--- Multinucleation Rule (nuclei_count > 1) ---
  Total cells considered polynucleated: 225 (9.10%)
    Of these 225 polynucleated cells:
      - 42 were ALREADY 'Senescent' by score.
      - 183 were 'Non-senescent' by score and are NOW RECLASSIFIED to 'Senescent'.

Total cells finally classified as Senescent: 431 (17.44%)
Cell classification complete.

Refined results saved to: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V3/cell_classification_results_refined.csv

Generating visualizations for refined results...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type_final', y=feature, data=df, ax=axes[i], palette=current_palette, order=order)

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type_final', y=feature, data=df, ax=axes[i], palette=current_palette, order=order)

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='cell_type_final', y=feature, data=df, ax=axes[i], palette=current_palette, order=order)

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(x='

Visualizations saved.

Refined analysis complete!


In [8]:
import os
import re # Added for extract_sample_id
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches # Added for legends in mask visualization
import seaborn as sns
from skimage import io, measure, segmentation # Added segmentation for find_boundaries
import cv2 # Added for cvtColor if needed by load_image_as_labeled_mask
from scipy import ndimage # Added for ndimage.label
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import umap
from tqdm import tqdm

# --- Configuration & Parameters ---
INPUT_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/cell_classification_results.csv"
OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4" # New version for this run

# Directories for original mask images - UPDATE THESE PATHS
CELL_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
NUCLEI_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
MASK_VISUALIZATION_SUBDIR = "mask_overlays" # Subdirectory for saving mask visualizations

MULTINUCLEATION_THRESHOLD = 1
PRELIMINARY_SENESCENCE_QUANTILE = 0.85

FEATURES_FOR_CLUSTERING = [
    'cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity',
    'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area',
    'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement',
    'nucleus_to_cell_area_ratio'
]

SENESCENCE_SCORE_WEIGHTS = {
    'cell_area': 1.5, 'cell_perimeter': 0.5, 'cell_eccentricity': 0.5,
    'cell_circularity': -1.0, 'cell_aspect_ratio': 0.5, 'avg_nucleus_area': 1.0,
    'avg_nucleus_eccentricity': 0.3, 'nucleus_to_cell_area_ratio': -1.0,
    'nuclear_enlargement': 1.0, 'cell_enlargement': 1.5, 'nucleus_displacement': 0.2,
}

AREA_FEATURES_TO_LOG = ['cell_area', 'avg_nucleus_area', 'max_nucleus_area', 'cell_perimeter']

# --- Helper Functions ---
def extract_sample_id(filename):
    """
    Extract the sample ID from a filename based on the specific naming pattern.
    (Adapted from user's original notebook)
    """
    base_name = os.path.splitext(filename)[0]
    if base_name.startswith('denoised_'):
        base_name = base_name[len('denoised_'):]
    # Regex to capture the part up to seqXXX
    pattern = re.compile(r'([\d\.]+Pa_[^_]+_[^_]+_[^_]+_[^_]+_[^_]+_seq\d+)')
    match = pattern.search(base_name)
    if match:
        return match.group(1)
    # Fallback if regex doesn't match (simplified)
    parts = base_name.split('_')
    for i, part in enumerate(parts):
        if part.startswith('seq') and i >= 2:
            return '_'.join(parts[:i+1]) # Join parts up to and including seqXXX
    # More general fallback
    common_prefix = "_".join(filename.split('_')[:6]) # Adjust if needed
    return common_prefix if 'seq' in common_prefix else os.path.splitext(os.path.basename(filename))[0]


def load_image_as_labeled_mask(filepath):
    """Loads a mask image, ensuring it's a labeled integer mask."""
    print(f"    Loading mask: {os.path.basename(filepath)}")
    try:
        img = io.imread(filepath)
        # Handle multi-channel images by converting to grayscale
        if img.ndim > 2:
            if img.shape[-1] == 3:  # RGB
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[-1] == 4:  # RGBA
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else: # Take first channel if not RGB/RGBA
                img = img[..., 0]

        # If the image is already a labeled mask (e.g., integer types with max > 1)
        if img.dtype.kind in 'iu' and np.max(img) > 1: # integer unsigned/signed
            print(f"    Mask {os.path.basename(filepath)} appears to be already labeled.")
            return img.astype(np.uint16) # Ensure consistent type for labels

        # If it's binary or float, threshold and label
        if img.dtype.kind == 'f': # float
            img = (img > 0.5).astype(np.uint8) # Threshold for float images
        elif np.max(img) == 1: # Binary
             img = img.astype(np.uint8)

        if np.max(img) <=1 : # If binary after potential conversion
            labeled_img, num_features = ndimage.label(img)
            print(f"    Labeled binary mask {os.path.basename(filepath)}, found {num_features} features.")
            return labeled_img.astype(np.uint16)

        # If it was uint8 with labels (e.g. from CellProfiler)
        print(f"    Mask {os.path.basename(filepath)} treated as pre-labeled uint8/uint16.")
        return img.astype(np.uint16)

    except Exception as e:
        print(f"    Error loading image {filepath}: {str(e)}")
        return None

def load_and_prepare_data(csv_path):
    """Loads the data and checks for required columns."""
    print(f"Loading data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
        # Ensure 'sample_id' is present from the original CSV structure
        if 'sample_id' not in df.columns and 'cell_id' in df.columns:
             # Attempt to derive sample_id if it's embedded in cell_id
             # This assumes cell_id is like 'actual_sample_id_originalcelllabel'
             df['derived_sample_id'] = df['cell_id'].apply(lambda x: '_'.join(x.split('_')[:-1]))
             print("Derived 'derived_sample_id' from 'cell_id'. Will use this for matching masks.")
        elif 'sample_id' in df.columns:
            df['derived_sample_id'] = df['sample_id'] # Use existing if available
            print("Using existing 'sample_id' column for matching masks.")
        else:
            print("Error: Cannot determine sample_id. 'sample_id' column missing and cannot derive from 'cell_id'.")
            return None

        print(f"Successfully loaded {len(df)} cells.")
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}")
        return None

    essential_cols_for_operation = list(SENESCENCE_SCORE_WEIGHTS.keys()) + ['nuclei_count', 'cell_id'] # 'sample_id' checked above
    all_needed_cols = list(set(FEATURES_FOR_CLUSTERING + essential_cols_for_operation))

    missing_cols = [col for col in all_needed_cols if col not in df.columns]
    if missing_cols:
        missing_critical_cols = [col for col in essential_cols_for_operation if col not in df.columns]
        if missing_critical_cols:
            print(f"Error: Critical columns for scoring/rules are missing: {missing_critical_cols}")
            return None
        print(f"Warning: Some columns listed in FEATURES_FOR_CLUSTERING are missing: {missing_cols}.")

    return df

def calculate_senescence_score(df, score_weights):
    """Calculates a per-cell senescence score."""
    print("Calculating per-cell senescence score...")
    features_for_scoring_present = [f for f in score_weights.keys() if f in df.columns]
    if not features_for_scoring_present:
        print("Error: No features for senescence score calculation are present in the DataFrame.")
        df['senescence_score'] = np.nan
        df['senescence_score_normalized'] = np.nan
        return df
    score_df = df[features_for_scoring_present].copy()
    for col in AREA_FEATURES_TO_LOG:
        if col in score_df.columns:
            score_df[col] = np.log1p(score_df[col])
    scaler = StandardScaler()
    numeric_score_cols = score_df.select_dtypes(include=np.number).columns
    if not numeric_score_cols.empty:
        score_features_standardized = scaler.fit_transform(score_df[numeric_score_cols])
        score_features_standardized_df = pd.DataFrame(score_features_standardized, columns=numeric_score_cols, index=score_df.index)
    else:
        score_features_standardized_df = pd.DataFrame(index=score_df.index)
    df['senescence_score'] = 0.0
    for feature, weight in score_weights.items():
        if feature in score_features_standardized_df.columns:
            df['senescence_score'] += score_features_standardized_df[feature] * weight
    if df['senescence_score'].isna().all() or (df['senescence_score'].max() == df['senescence_score'].min()):
        df['senescence_score_normalized'] = np.nan if df['senescence_score'].isna().all() else 0.0
    else:
        df['senescence_score_normalized'] = (df['senescence_score'] - df['senescence_score'].min()) / \
                                           (df['senescence_score'].max() - df['senescence_score'].min())
    print("Senescence score calculation complete.")
    return df

def perform_refined_clustering(df, feature_columns_for_clustering, umap_n_neighbors=30, umap_min_dist=0.1, umap_random_state=42, kmeans_n_clusters=2, kmeans_random_state=42):
    """Performs UMAP and k-Means clustering on selected features."""
    print(f"\nPerforming refined clustering using features: {feature_columns_for_clustering}...")
    actual_clustering_features = [col for col in feature_columns_for_clustering if col in df.columns]
    if not actual_clustering_features:
        df['umap_x_refined'], df['umap_y_refined'], df['cluster_refined'] = np.nan, np.nan, np.nan
        return df
    features_df = df[actual_clustering_features].copy()
    for col in actual_clustering_features:
        if not pd.api.types.is_numeric_dtype(features_df[col]):
            try: features_df[col] = pd.to_numeric(features_df[col])
            except ValueError:
                features_df = features_df.drop(columns=[col])
                actual_clustering_features.remove(col)
    features_df = features_df.fillna(features_df.mean())
    if features_df.empty or features_df.shape[1] == 0:
        df['umap_x_refined'], df['umap_y_refined'], df['cluster_refined'] = np.nan, np.nan, np.nan
        return df
    for col in AREA_FEATURES_TO_LOG:
        if col in features_df.columns: features_df[col] = np.log1p(features_df[col])
    scaler = StandardScaler()
    features_standardized = scaler.fit_transform(features_df)
    actual_umap_n_neighbors = min(umap_n_neighbors, len(features_standardized) - 1)
    if actual_umap_n_neighbors < 2 :
        df['umap_x_refined'], df['umap_y_refined'], df['cluster_refined'] = np.nan, np.nan, np.nan
        return df
    reducer = umap.UMAP(n_neighbors=actual_umap_n_neighbors, min_dist=umap_min_dist, random_state=umap_random_state, n_components=2)
    embedding = reducer.fit_transform(features_standardized)
    df['umap_x_refined'], df['umap_y_refined'] = embedding[:, 0], embedding[:, 1]
    if len(embedding) >= kmeans_n_clusters:
        kmeans = KMeans(n_clusters=kmeans_n_clusters, random_state=kmeans_random_state, n_init='auto')
        df['cluster_refined'] = kmeans.fit_predict(embedding)
    else: df['cluster_refined'] = np.nan
    print("Refined clustering complete.")
    return df

def classify_cells(df, preliminary_quantile, multinucleation_threshold):
    """Assigns preliminary and final cell types."""
    print("\nClassifying cells...")
    if 'senescence_score_normalized' not in df.columns or df['senescence_score_normalized'].isna().all():
        df['cell_type_preliminary'] = 'Unknown'
    else:
        score_threshold = df['senescence_score_normalized'].quantile(preliminary_quantile)
        print(f"  Using senescence score quantile {preliminary_quantile} (threshold = {score_threshold:.4f}) for preliminary classification.")
        df['cell_type_preliminary'] = df['senescence_score_normalized'].apply(
            lambda x: 'Senescent' if x >= score_threshold else 'Non-senescent')
        prelim_sen_count = (df['cell_type_preliminary'] == 'Senescent').sum()
        print(f"  {prelim_sen_count} cells ({prelim_sen_count/len(df)*100:.2f}%) preliminarily classified as Senescent by score.")
    df['cell_type_final'] = df['cell_type_preliminary']
    if 'nuclei_count' in df.columns:
        multinucleated_mask = df['nuclei_count'] > multinucleation_threshold
        total_multinucleated = multinucleated_mask.sum()
        print(f"\n--- Multinucleation Rule (nuclei_count > {multinucleation_threshold}) ---")
        print(f"  Total cells considered polynucleated: {total_multinucleated} ({total_multinucleated/len(df)*100:.2f}%)")
        if total_multinucleated > 0:
            multinucleated_and_sen_by_score = (multinucleated_mask & (df['cell_type_preliminary'] == 'Senescent')).sum()
            print(f"    Of these {total_multinucleated} polynucleated cells:")
            print(f"      - {multinucleated_and_sen_by_score} were ALREADY 'Senescent' by score.")
            multinucleated_reclassified_to_senescent = (multinucleated_mask & (df['cell_type_preliminary'] == 'Non-senescent')).sum()
            print(f"      - {multinucleated_reclassified_to_senescent} were 'Non-senescent' by score and are NOW RECLASSIFIED to 'Senescent'.")
        df.loc[multinucleated_mask, 'cell_type_final'] = 'Senescent'
    else: print("  Warning: 'nuclei_count' column not found. Cannot apply multinucleation rule.")
    final_sen_count = (df['cell_type_final'] == 'Senescent').sum()
    print(f"\nTotal cells finally classified as Senescent: {final_sen_count} ({final_sen_count/len(df)*100:.2f}%)")
    print("Cell classification complete.")
    return df

def visualize_refined_results(df, output_dir):
    """Generates and saves summary visualizations for the refined analysis."""
    # (Implementation from previous response, ensure it's up-to-date)
    print("\nGenerating summary visualizations for refined results...")
    if not os.path.exists(output_dir): os.makedirs(output_dir)
    if 'umap_x_refined' in df.columns and df['umap_x_refined'].notna().any() and 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
        plt.figure(figsize=(12, 10)); scatter = plt.scatter(df['umap_x_refined'], df['umap_y_refined'], c=df['senescence_score_normalized'], cmap='viridis', s=15, alpha=0.7)
        plt.colorbar(scatter, label='Normalized Senescence Score'); plt.title('Refined UMAP: Cells Colored by Senescence Score', fontsize=16)
        plt.xlabel('UMAP Dimension 1 (Refined)'); plt.ylabel('UMAP Dimension 2 (Refined)'); plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig(os.path.join(output_dir, 'umap_refined_by_senescence_score.png'), dpi=300, bbox_inches='tight'); plt.close()
    if 'umap_x_refined' in df.columns and df['umap_x_refined'].notna().any() and 'cell_type_final' in df.columns:
        plt.figure(figsize=(12, 10)); unique_types = df['cell_type_final'].unique(); palette = {t: ('red' if t == 'Senescent' else ('blue' if t == 'Non-senescent' else 'grey')) for t in unique_types}
        for ct, col in palette.items(): subset = df[df['cell_type_final'] == ct]; plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=ct, color=col, s=15, alpha=0.7)
        plt.title('Refined UMAP: Cells Colored by Final Classification', fontsize=16); plt.xlabel('UMAP Dimension 1 (Refined)'); plt.ylabel('UMAP Dimension 2 (Refined)')
        if palette: plt.legend(title='Final Cell Type'); plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig(os.path.join(output_dir, 'umap_refined_by_final_cell_type.png'), dpi=300, bbox_inches='tight'); plt.close()
    if 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
        plt.figure(figsize=(10, 6)); sns.histplot(df['senescence_score_normalized'].dropna(), kde=True, bins=50)
        plt.title('Distribution of Normalized Senescence Score', fontsize=16); plt.xlabel('Normalized Senescence Score'); plt.ylabel('Frequency')
        if 'cell_type_preliminary' in df.columns: score_thresh = df['senescence_score_normalized'].quantile(PRELIMINARY_SENESCENCE_QUANTILE); plt.axvline(score_thresh, color='r', linestyle='--', label=f'Quantile Threshold ({PRELIMINARY_SENESCENCE_QUANTILE*100:.0f}th percentile)'); plt.legend()
        plt.grid(True, linestyle='--', alpha=0.5); plt.savefig(os.path.join(output_dir, 'senescence_score_distribution.png'), dpi=300, bbox_inches='tight'); plt.close()
    if 'cell_type_final' in df.columns and df['cell_type_final'].nunique() > 1:
        features = ['cell_area', 'avg_nucleus_area', 'cell_circularity', 'nucleus_to_cell_area_ratio', 'senescence_score_normalized', 'nuclei_count']
        features = [f for f in features if f in df.columns and df[f].notna().any()]
        if features:
            cols_plot = 3; rows_plot = (len(features) + cols_plot - 1) // cols_plot; fig, axes = plt.subplots(rows_plot, cols_plot, figsize=(5*cols_plot, 5*rows_plot), squeeze=False); axes = axes.flatten()
            pal = palette if 'palette' in locals() else 'coolwarm'; ord_list = sorted(df['cell_type_final'].unique())
            for i, feat in enumerate(features): sns.boxplot(x='cell_type_final', y=feat, data=df, ax=axes[i], palette=pal, order=ord_list); axes[i].set_title(feat.replace('_',' ').title()); axes[i].set_xlabel('Final Cell Type'); axes[i].set_ylabel(feat.replace('_',' ').title())
            for j in range(i + 1, len(axes)): fig.delaxes(axes[j])
            plt.tight_layout(); plt.savefig(os.path.join(output_dir, 'feature_comparison_by_final_type.png'), dpi=300, bbox_inches='tight'); plt.close()
    print("Summary visualizations saved.")


def visualize_classification_on_masks(df_results, cell_mask_dir, nuclei_mask_dir, output_dir_masks):
    """
    Visualizes the final cell classification by overlaying it on the original mask images.
    """
    print(f"\nGenerating classification overlays on original masks in: {output_dir_masks}")
    if not os.path.exists(output_dir_masks):
        os.makedirs(output_dir_masks)

    # Define colors for visualization
    senescent_color = [255, 0, 0]  # Red
    non_senescent_color = [0, 0, 255]  # Blue
    nuclei_overlay_color = [0, 255, 0]  # Green for nuclei outline or fill
    boundary_color = [200, 200, 200] # Light grey for cell boundaries

    # Get unique sample IDs from the results DataFrame
    # Use 'derived_sample_id' which was created in load_and_prepare_data
    if 'derived_sample_id' not in df_results.columns:
        print("Error: 'derived_sample_id' column not found in results. Cannot match to mask files.")
        return

    unique_sample_ids = df_results['derived_sample_id'].unique()

    # Create a lookup dictionary for cell classifications
    # cell_id in df_results is 'derived_sample_id_originalcelllabel'
    classification_lookup = pd.Series(df_results.cell_type_final.values, index=df_results.cell_id).to_dict()

    available_cell_mask_files = [f for f in os.listdir(cell_mask_dir) if f.endswith(('.tif', '.tiff'))]
    available_nuclei_mask_files = [f for f in os.listdir(nuclei_mask_dir) if f.endswith(('.tif', '.tiff'))]

    for sample_id_from_csv in tqdm(unique_sample_ids, desc="Processing samples for mask visualization"):
        # Find corresponding mask files
        cell_mask_file = None
        for f_name in available_cell_mask_files:
            extracted_id = extract_sample_id(f_name)
            if extracted_id == sample_id_from_csv:
                cell_mask_file = f_name
                break

        nuclei_mask_file = None
        for f_name in available_nuclei_mask_files:
            extracted_id = extract_sample_id(f_name)
            if extracted_id == sample_id_from_csv:
                nuclei_mask_file = f_name
                break

        if not cell_mask_file:
            print(f"  Warning: Cell mask file not found for sample ID: {sample_id_from_csv}")
            continue

        print(f"\n  Processing sample: {sample_id_from_csv}")
        cell_mask_path = os.path.join(cell_mask_dir, cell_mask_file)

        # Load cell mask (should be a labeled mask where each integer is a cell ID)
        labeled_cell_mask = load_image_as_labeled_mask(cell_mask_path)
        if labeled_cell_mask is None:
            continue

        # Create an empty RGB image for the overlay
        overlay_image = np.zeros((labeled_cell_mask.shape[0], labeled_cell_mask.shape[1], 3), dtype=np.uint8)

        # Iterate through each cell label in the mask
        cell_props = measure.regionprops(labeled_cell_mask)
        for props in cell_props:
            original_cell_label = props.label # This is the integer ID from the mask

            # Construct the full cell_id as it appears in the CSV results
            # This assumes 'cell_id' in CSV is 'derived_sample_id_originalcelllabel'
            # and 'original_cell_id' column was created as int(full_cell_id.split('_')[-1])
            # For lookup, we need the full ID.
            full_cell_id_for_lookup = f"{sample_id_from_csv}_{original_cell_label}"

            cell_type = classification_lookup.get(full_cell_id_for_lookup, 'Unknown')

            current_color = non_senescent_color
            if cell_type == 'Senescent':
                current_color = senescent_color
            elif cell_type == 'Unknown':
                current_color = [128, 128, 128] # Grey for unknown

            # Color the cell region
            overlay_image[labeled_cell_mask == original_cell_label] = current_color

            # Draw cell boundaries (optional, can make image busy but more informative)
            # cell_boundary = segmentation.find_boundaries(labeled_cell_mask == original_cell_label, mode='inner')
            # overlay_image[cell_boundary] = boundary_color


        # Optionally, overlay nuclei
        if nuclei_mask_file:
            nuclei_mask_path = os.path.join(nuclei_mask_dir, nuclei_mask_file)
            labeled_nuclei_mask = load_image_as_labeled_mask(nuclei_mask_path)
            if labeled_nuclei_mask is not None:
                # Find boundaries of nuclei to make them more visible as outlines
                nuclei_boundaries = segmentation.find_boundaries(labeled_nuclei_mask, mode='inner', background=0)
                overlay_image[nuclei_boundaries] = nuclei_overlay_color
                # Or fill nuclei:
                # overlay_image[labeled_nuclei_mask > 0] = nuclei_overlay_color


        # Save the overlay image
        output_filename = os.path.join(output_dir_masks, f"{sample_id_from_csv}_classification_overlay.png")

        # Add a legend to the image (matplotlib approach)
        fig_legend, ax_legend = plt.subplots(figsize=(overlay_image.shape[1]/100, overlay_image.shape[0]/100), dpi=100) # Adjust size as needed
        ax_legend.imshow(overlay_image)

        sen_patch = mpatches.Patch(color=np.array(senescent_color)/255., label='Senescent')
        non_sen_patch = mpatches.Patch(color=np.array(non_senescent_color)/255., label='Non-senescent')
        nuc_patch = mpatches.Patch(color=np.array(nuclei_overlay_color)/255., label='Nuclei Outline')
        handles = [sen_patch, non_sen_patch]
        if nuclei_mask_file and labeled_nuclei_mask is not None: # only add nuclei legend if nuclei were processed
            handles.append(nuc_patch)

        ax_legend.legend(handles=handles, loc='upper right', fontsize='small', bbox_to_anchor=(1.25, 1)) # Adjust bbox_to_anchor
        ax_legend.axis('off') # Turn off axis numbers and ticks
        plt.tight_layout()
        plt.savefig(output_filename, dpi=150) # Adjust DPI as needed
        plt.close(fig_legend)
        print(f"    Saved overlay for {sample_id_from_csv} to {output_filename}")

    print("Mask overlay visualization complete.")


def main_refined_analysis():
    """Main function to run the refined senescence analysis."""
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"Created output directory: {OUTPUT_DIR}")

    df = load_and_prepare_data(INPUT_CSV_PATH)
    if df is None: return

    df = calculate_senescence_score(df, SENESCENCE_SCORE_WEIGHTS)
    df = perform_refined_clustering(df, FEATURES_FOR_CLUSTERING)
    df = classify_cells(df, PRELIMINARY_SENESCENCE_QUANTILE, MULTINUCLEATION_THRESHOLD)

    output_csv_path = os.path.join(OUTPUT_DIR, 'cell_classification_results_refined.csv')
    cols_to_save = [col for col in df.columns if col not in ['umap_x', 'umap_y', 'cluster', 'cell_type']]
    if 'umap_x_refined' in df.columns and df['umap_x_refined'].isna().all():
        cols_to_save = [col for col in cols_to_save if col not in ['umap_x_refined', 'umap_y_refined']]
    if 'cluster_refined' in df.columns and df['cluster_refined'].isna().all():
        cols_to_save = [col for col in cols_to_save if col not in ['cluster_refined']]
    df_to_save = df[cols_to_save]
    df_to_save.to_csv(output_csv_path, index=False)
    print(f"\nRefined results saved to: {output_csv_path}")

    visualize_refined_results(df, OUTPUT_DIR) # Summary visualizations

    # New call for visualizing on actual masks
    mask_overlay_output_path = os.path.join(OUTPUT_DIR, MASK_VISUALIZATION_SUBDIR)
    visualize_classification_on_masks(df, CELL_MASK_DIR, NUCLEI_MASK_DIR, mask_overlay_output_path)

    print("\nRefined analysis and mask visualization complete!")

if __name__ == '__main__':
    main_refined_analysis()


Loading data from /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence/cell_classification_results.csv...
Using existing 'sample_id' column for matching masks.
Successfully loaded 2472 cells.
Calculating per-cell senescence score...
Senescence score calculation complete.

Performing refined clustering using features: ['cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity', 'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area', 'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement', 'nucleus_to_cell_area_ratio']...


  warn(


Refined clustering complete.

Classifying cells...
  Using senescence score quantile 0.85 (threshold = 0.5115) for preliminary classification.
  371 cells (15.01%) preliminarily classified as Senescent by score.

--- Multinucleation Rule (nuclei_count > 1) ---
  Total cells considered polynucleated: 225 (9.10%)
    Of these 225 polynucleated cells:
      - 54 were ALREADY 'Senescent' by score.
      - 171 were 'Non-senescent' by score and are NOW RECLASSIFIED to 'Senescent'.

Total cells finally classified as Senescent: 542 (21.93%)
Cell classification complete.

Refined results saved to: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/cell_classification_results_refined.csv

Generating summary visualizations for refined results...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  for i, feat in enumerate(features): sns.boxplot(x='cell_type_final', y=feat, data=df, ax=axes[i], palette=pal, order=ord_list); axes[i].set_title(feat.replace('_',' ').title()); axes[i].set_xlabel('Final Cell Type'); axes[i].set_ylabel(feat.replace('_',' ').title())

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  for i, feat in enumerate(features): sns.boxplot(x='cell_type_final', y=feat, data=df, ax=axes[i], palette=pal, order=ord_list); axes[i].set_title(feat.replace('_',' ').title()); axes[i].set_xlabel('Final Cell Type'); axes[i].set_ylabel(feat.replace('_',' ').title())

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and se

Summary visualizations saved.

Generating classification overlays on original masks in: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays


Processing samples for mask visualization:   0%|          | 0/8 [00:00<?, ?it/s]


  Processing sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq001
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq001_cell_mask_merged_conservative.tif
    Mask 0Pa_U_05mar19_20x_L2RA_Flat_seq001_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Cadherins_filtered_mask.tif
    Mask denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  12%|█▎        | 1/8 [00:01<00:10,  1.52s/it]

    Saved overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq001 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/0Pa_U_05mar19_20x_L2RA_Flat_seq001_classification_overlay.png

  Processing sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq002
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq002_cell_mask_merged_conservative.tif
    Mask 0Pa_U_05mar19_20x_L2RA_Flat_seq002_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq002_Cadherins_filtered_mask.tif
    Mask denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq002_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  25%|██▌       | 2/8 [00:03<00:09,  1.57s/it]

    Saved overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq002 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/0Pa_U_05mar19_20x_L2RA_Flat_seq002_classification_overlay.png

  Processing sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq003
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq003_cell_mask_merged_conservative.tif
    Mask 0Pa_U_05mar19_20x_L2RA_Flat_seq003_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq003_Cadherins_filtered_mask.tif
    Mask denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq003_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  38%|███▊      | 3/8 [00:04<00:07,  1.60s/it]

    Saved overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq003 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/0Pa_U_05mar19_20x_L2RA_Flat_seq003_classification_overlay.png

  Processing sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001_cell_mask_merged_conservative.tif
    Mask 1.4Pa_U_05mar19_20x_L2R_Flat_seq001_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq001_Cadherins_filtered_mask.tif
    Mask denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq001_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  50%|█████     | 4/8 [00:06<00:05,  1.49s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq001 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/1.4Pa_U_05mar19_20x_L2R_Flat_seq001_classification_overlay.png

  Processing sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002_cell_mask_merged_conservative.tif
    Mask 1.4Pa_U_05mar19_20x_L2R_Flat_seq002_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq002_Cadherins_filtered_mask.tif
    Mask denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq002_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  62%|██████▎   | 5/8 [00:07<00:04,  1.39s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq002 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/1.4Pa_U_05mar19_20x_L2R_Flat_seq002_classification_overlay.png

  Processing sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003_cell_mask_merged_conservative.tif
    Mask 1.4Pa_U_05mar19_20x_L2R_Flat_seq003_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq003_Cadherins_filtered_mask.tif
    Mask denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq003_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  75%|███████▌  | 6/8 [00:09<00:03,  1.63s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq003 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/1.4Pa_U_05mar19_20x_L2R_Flat_seq003_classification_overlay.png

  Processing sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004_cell_mask_merged_conservative.tif
    Mask 1.4Pa_U_05mar19_20x_L2R_Flat_seq004_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq004_Cadherins_filtered_mask.tif
    Mask denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq004_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization:  88%|████████▊ | 7/8 [00:11<00:01,  1.73s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq004 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/1.4Pa_U_05mar19_20x_L2R_Flat_seq004_classification_overlay.png

  Processing sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005_cell_mask_merged_conservative.tif
    Mask 1.4Pa_U_05mar19_20x_L2R_Flat_seq005_cell_mask_merged_conservative.tif appears to be already labeled.
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq005_Cadherins_filtered_mask.tif
    Mask denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq005_Cadherins_filtered_mask.tif appears to be already labeled.


Processing samples for mask visualization: 100%|██████████| 8/8 [00:12<00:00,  1.61s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq005 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V4/mask_overlays/1.4Pa_U_05mar19_20x_L2R_Flat_seq005_classification_overlay.png
Mask overlay visualization complete.

Refined analysis and mask visualization complete!





try something else

In [12]:
pip install scanpy

Collecting scanpy
  Downloading scanpy-1.11.1-py3-none-any.whl.metadata (9.9 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.11.4-py3-none-any.whl.metadata (9.3 kB)
Collecting legacy-api-wrap>=1.4 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting scikit-learn<1.6.0,>=1.1 (from scanpy)
  Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting array-api-compat!=1.5,>1.4 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.11.2-py3-none-any.whl.metadata (1.9 kB)
Downloading scanpy-1.11.1-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading anndata-0.11.4-py3-none-any.whl (144 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.

In [22]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
import umap

try:
    import scanpy as sc
    SCANPY_AVAILABLE = True
except ImportError:
    print("Scanpy library not found. Diffusion map functionality will be skipped.")
    SCANPY_AVAILABLE = False

# --- Configuration & Parameters ---
INPUT_REFINED_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv"
EXPLORATORY_OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V4_fix" # Incremented version

FEATURES_FOR_ANALYSIS = [ # Features used for UMAP, DiffMap, DBSCAN on features, GMM on features
    'cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity',
    'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area',
    'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement',
    'nucleus_to_cell_area_ratio',
    'nuclear_enlargement', 'cell_enlargement'
]
AREA_FEATURES_TO_LOG = ['cell_area', 'avg_nucleus_area', 'max_nucleus_area', 'cell_perimeter']

# Diffusion Map parameters
N_DIFFUSION_COMPONENTS = 10
N_DCS_TO_PLOT = 3
N_NEIGHBORS_FOR_SCANPY = 15

# DBSCAN Parameters - IMPORTANT: TUNE THESE MANUALLY AFTER FIRST RUN!
# These are now just defaults if specific values aren't passed to run_dbscan_and_plot
DBSCAN_EPS_DEFAULT = 0.75
DBSCAN_MIN_SAMPLES_DEFAULT = 10

ESTIMATE_DBSCAN_EPS = True # Set to False after you've chosen your eps values.
K_FOR_EPS_ESTIMATION = 10
RUN_DBSCAN_ON_DIFFMAP = True
N_DCS_FOR_DBSCAN = 3

# GMM Parameters
GMM_N_COMPONENTS_RANGE = range(2, 5)
GMM_COVARIANCE_TYPE = 'full'

# Rule-Based Gating Parameters (Example - MODIFY THESE RULES)
RULE_BASED_GATES = [
    {
        'name': 'Polynucleated_Large',
        'conditions': [
            ('nuclei_count', '>', 1),
            ('cell_area', '>', 2500) # Example threshold, adjust!
        ],
        'output_label': 'Rule_Sen_Poly_Large'
    },
    {
        'name': 'Very_Large_Cell',
        'conditions': [
            ('cell_area', '>', 3500) # Example threshold, adjust!
        ],
        'output_label': 'Rule_Sen_VeryLarge'
    },
    {
        'name': 'High_Score_Not_Otherwise_Caught',
        'conditions': [
            ('senescence_score_normalized', '>', 0.85) # Example threshold, adjust!
        ],
        'output_label': 'Rule_Sen_HighScore'
    }
]
RULE_BASED_DEFAULT_LABEL = 'Rule_NonSenescent'


def load_data(csv_path):
    """Loads the refined data."""
    print(f"Loading refined data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
        print(f"Successfully loaded {len(df)} cells.")
        for col_check in ['senescence_score_normalized', 'cell_type_final', 'nuclei_count', 'cell_area']:
            if col_check not in df.columns:
                print(f"Warning: Essential column '{col_check}' not found. Some functionalities might be affected.")
        return df
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}")
        return None

def preprocess_features_for_ml(df, feature_columns, log_transform_cols):
    """Prepares features specifically for ML algorithms (scaling)."""
    print(f"\nPreprocessing features for ML. Selected: {feature_columns}")

    actual_features_for_ml = [col for col in feature_columns if col in df.columns]
    if not actual_features_for_ml:
        print("  Error: None of the specified FEATURES_FOR_ANALYSIS are present in the DataFrame for ML.")
        return None, None

    features_for_scaling_df = df[actual_features_for_ml].copy()

    for col in log_transform_cols:
        if col in features_for_scaling_df.columns:
            features_for_scaling_df[col] = np.log1p(features_for_scaling_df[col])
            print(f"  Log-transformed for scaling: {col}")

    if features_for_scaling_df.isnull().sum().any():
        print(f"  Handling NaNs using mean imputation for {features_for_scaling_df.isnull().sum().sum()} values (for scaled features).")
        features_for_scaling_df = features_for_scaling_df.fillna(features_for_scaling_df.mean())

    cols_to_drop_scaled = features_for_scaling_df.columns[features_for_scaling_df.isna().all()].tolist()
    if cols_to_drop_scaled:
        print(f"  Dropping all-NaN columns from scaled set: {cols_to_drop_scaled}")
        features_for_scaling_df = features_for_scaling_df.drop(columns=cols_to_drop_scaled)
        actual_features_for_ml = [f for f in actual_features_for_ml if f not in cols_to_drop_scaled]

    if features_for_scaling_df.empty or not actual_features_for_ml:
        print(" Error: No features remaining for scaling after processing.")
        return None, None

    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features_for_scaling_df)
    print("  Features standardized for ML algorithms.")

    return features_scaled, actual_features_for_ml


def compute_and_plot_diffusion_map(df, scaled_features, feature_names_used, output_dir):
    """Computes and plots diffusion map."""
    if not SCANPY_AVAILABLE: print("Skipping diffusion map: Scanpy not available."); return df
    print("\n--- Computing Diffusion Map ---")
    if scaled_features is None or scaled_features.shape[0] == 0 : print("  No scaled features for Diffusion Map. Skipping."); return df

    adata = sc.AnnData(scaled_features, var=pd.DataFrame(index=feature_names_used))
    adata.obs_names = df.index.astype(str)
    if 'senescence_score_normalized' in df.columns: adata.obs['senescence_score_normalized'] = df['senescence_score_normalized'].values
    if 'cell_type_final' in df.columns: adata.obs['cell_type_final'] = df['cell_type_final'].astype('category').values

    actual_n_neighbors = min(N_NEIGHBORS_FOR_SCANPY, adata.n_obs - 1)
    if actual_n_neighbors < 2: print(f"  Not enough samples for Scanpy neighbors. Skipping."); return df

    print(f"  Computing neighbors (k={actual_n_neighbors})...")
    sc.pp.neighbors(adata, n_neighbors=actual_n_neighbors, use_rep='X')
    print("  Running sc.tl.diffmap...")
    sc.tl.diffmap(adata, n_comps=N_DIFFUSION_COMPONENTS)

    if 'X_diffmap' in adata.obsm:
        num_dc = min(N_DIFFUSION_COMPONENTS, adata.obsm['X_diffmap'].shape[1] - 1)
        for i in range(num_dc): df[f'dc_{i+1}'] = adata.obsm['X_diffmap'][:, i+1]
        print(f"  Added {num_dc} DCs to DataFrame.")

        pairs = [(f'dc_{i}', f'dc_{j}') for i in range(1, N_DCS_TO_PLOT + 1) for j in range(i + 1, N_DCS_TO_PLOT + 1) if f'dc_{i}' in df.columns and f'dc_{j}' in df.columns]
        for dcx, dcy in pairs:
            if df[dcx].notna().any() and df[dcy].notna().any():
                if 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
                    plt.figure(figsize=(10,8)); plt.scatter(df[dcx], df[dcy], c=df['senescence_score_normalized'], cmap='viridis', s=12, alpha=0.7); plt.colorbar(label='Norm. Senescence Score')
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Score'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper()); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_score.png'),dpi=300,bbox_inches='tight'); plt.close()
                if 'cell_type_final' in df.columns:
                    plt.figure(figsize=(10,8)); types=df['cell_type_final'].unique(); pal={t:('red' if t=='Senescent' else ('blue' if t=='Non-senescent' else 'grey')) for t in types}
                    for ct,col in pal.items(): subset=df[df['cell_type_final']==ct]; plt.scatter(subset[dcx],subset[dcy],label=ct,color=col,s=12,alpha=0.7)
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Classif.'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper());
                    if pal: plt.legend(title='Final Cell Type'); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_type.png'),dpi=300,bbox_inches='tight'); plt.close()
        print(f"  DiffMap pair plots for top {N_DCS_TO_PLOT} DCs saved.")
    else: print("  Error: 'X_diffmap' not found in AnnData object after sc.tl.diffmap.")
    return df

def run_dbscan_and_plot(df, data_for_dbscan, data_desc, output_dir, umap_emb=None, current_eps_val=None, current_min_samples_val=None):
    """Runs DBSCAN and plots results. Uses specific eps and min_samples if provided."""
    print(f"\n--- Running DBSCAN on {data_desc} ---")
    if data_for_dbscan is None or data_for_dbscan.shape[0] == 0:
        print(f"  No data for DBSCAN on {data_desc}. Skipping.")
        df[f'dbscan_{data_desc.lower().replace(" ","_")}']=-1
        return df

    eps_to_use = current_eps_val if current_eps_val is not None else DBSCAN_EPS_DEFAULT
    min_s_to_use = current_min_samples_val if current_min_samples_val is not None else DBSCAN_MIN_SAMPLES_DEFAULT

    if ESTIMATE_DBSCAN_EPS and current_eps_val is None :
        k_est = min(K_FOR_EPS_ESTIMATION, data_for_dbscan.shape[0]-1); k_est=max(1,k_est)
        nn=NearestNeighbors(n_neighbors=k_est); nn.fit(data_for_dbscan); dists, _ = nn.kneighbors(data_for_dbscan)
        actual_k_for_dists = min(k_est, dists.shape[1])
        if actual_k_for_dists > 0:
            k_dists = dists[:,actual_k_for_dists-1]
            k_dists_sorted = np.sort(k_dists)
            plt.figure(figsize=(8,6)); plt.plot(k_dists_sorted); plt.title(f'{actual_k_for_dists}-Dist Graph for Eps ({data_desc})');
            plt.xlabel("Points sorted by distance"); plt.ylabel(f"{actual_k_for_dists}-th NN Distance (eps candidate)"); plt.grid(True,alpha=0.3);
            eps_path=os.path.join(output_dir, f'dbscan_eps_est_{data_desc.lower().replace(" ","_")}.png'); plt.savefig(eps_path,dpi=300); plt.close();
            print(f"  Saved k-dist graph: {eps_path}. PLEASE INSPECT THIS PLOT TO SET appropriate DBSCAN_EPS for {data_desc}.")
            if len(k_dists_sorted)>10:
                sug_eps=np.percentile(k_dists_sorted,90);
                print(f"  A percentile-based suggestion for eps for {data_desc} is: {sug_eps:.3f}. The script will use eps={eps_to_use} (default or passed).")
        else:
            print(f"  Could not determine k-distances for eps estimation for {data_desc}. Using eps={eps_to_use}")

    print(f"  Running DBSCAN with eps={eps_to_use}, min_samples={min_s_to_use} on {data_desc}...")
    db=DBSCAN(eps=eps_to_use,min_samples=min_s_to_use).fit(data_for_dbscan)
    clust_col=f'dbscan_{data_desc.lower().replace(" ","_")}'; df[clust_col]=db.labels_
    n_clust=len(set(db.labels_))-(1 if -1 in db.labels_ else 0); n_noise=list(db.labels_).count(-1)
    print(f"  DBSCAN on {data_desc}: {n_clust} clusters, {n_noise} noise ({n_noise/len(df)*100:.2f}%).")

    if umap_emb is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12,10));
        labels_unique_dbscan=sorted(df[clust_col].unique());
        n_actual_clusters = len([l for l in labels_unique_dbscan if l != -1])
        dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters if n_actual_clusters > 0 else 1)
        cdict = {}; cluster_idx = 0
        for lbl in labels_unique_dbscan:
            if lbl == -1: cdict[lbl] = (0.5, 0.5, 0.5, 1)
            else: cdict[lbl] = dbscan_cmap_obj(cluster_idx); cluster_idx += 1
        for k_val in labels_unique_dbscan:
            mask=(df[clust_col]==k_val); xy=umap_emb[mask]
            if xy.shape[0]>0:
                 plt.scatter(xy[:,0],xy[:,1], s=(20 if k_val!=-1 else 10), c=[cdict[k_val]],
                             marker=('o' if k_val!=-1 else 'x'), label=('Noise' if k_val == -1 else f'Cluster {k_val}'))
        plt.title(f'DBSCAN on {data_desc} (UMAP proj.)\neps={eps_to_use:.3f}, min_s={min_s_to_use}');
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2');
        plt.legend(title='DBSCAN Cluster',bbox_to_anchor=(1.05,1),loc='upper left',markerscale=1.5);
        plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1]);
        plt.savefig(os.path.join(output_dir, f'dbscan_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close();
        print(f"  DBSCAN on {data_desc} plotted on UMAP.")
    return df

def run_gmm_and_plot(df, data_for_gmm, data_desc, output_dir, umap_embedding=None):
    """Runs Gaussian Mixture Model clustering and plots results."""
    print(f"\n--- Running Gaussian Mixture Model (GMM) on {data_desc} ---")
    if data_for_gmm is None or data_for_gmm.shape[0] == 0:
        print(f"  No data available for GMM on {data_desc}. Skipping.")
        df[f'gmm_cluster_{data_desc.lower().replace(" ", "_")}'] = -1
        df[f'gmm_prob_max_{data_desc.lower().replace(" ", "_")}'] = np.nan
        return df

    best_gmm = None; lowest_bic = np.inf
    print(f"  Testing GMM with n_components in {list(GMM_N_COMPONENTS_RANGE)} using BIC...")
    for n_components in GMM_N_COMPONENTS_RANGE:
        if n_components > data_for_gmm.shape[0]: continue
        gmm = GaussianMixture(n_components=n_components, covariance_type=GMM_COVARIANCE_TYPE, random_state=42, n_init=5)
        gmm.fit(data_for_gmm); bic = gmm.bic(data_for_gmm)
        print(f"    GMM with {n_components} components: BIC = {bic:.2f}")
        if bic < lowest_bic: lowest_bic = bic; best_gmm = gmm

    if best_gmm is None:
        print("  GMM fitting failed. Skipping GMM."); df[f'gmm_cluster_{data_desc.lower().replace(" ", "_")}'] = -1; df[f'gmm_prob_max_{data_desc.lower().replace(" ", "_")}'] = np.nan
        return df

    print(f"  Best GMM found with {best_gmm.n_components} components (BIC={lowest_bic:.2f}).")
    cluster_col_name = f'gmm_{data_desc.lower().replace(" ", "_")}'; prob_col_name = f'gmm_prob_max_{data_desc.lower().replace(" ", "_")}'
    df[cluster_col_name] = best_gmm.predict(data_for_gmm); df[prob_col_name] = np.max(best_gmm.predict_proba(data_for_gmm), axis=1)

    if 'senescence_score_normalized' in df.columns:
        print(f"  Mean senescence_score_normalized per GMM component (for {data_desc}):\n{df.groupby(cluster_col_name)['senescence_score_normalized'].mean().sort_values()}")

    if umap_embedding is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12, 10)); unique_gmm_labels = sorted(df[cluster_col_name].unique())
        gmm_palette = sns.color_palette("viridis", n_colors=len(unique_gmm_labels))
        for i, label in enumerate(unique_gmm_labels):
            subset = df[df[cluster_col_name] == label]
            plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=f'GMM Comp. {label}', color=gmm_palette[i], s=15, alpha=0.7)
        plt.title(f'GMM ({best_gmm.n_components} comp.) on {data_desc} (UMAP proj.)', fontsize=14)
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='GMM Component', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1])
        plt.savefig(os.path.join(output_dir, f'gmm_on_umap_{data_desc.lower().replace(" ", "_")}.png'), dpi=300); plt.close()
        print(f"  GMM on {data_desc} results plotted on UMAP.")

        plt.figure(figsize=(12, 10)); scatter_gmm_prob = plt.scatter(df['umap_x_refined'], df['umap_y_refined'], c=df[prob_col_name], cmap='magma', s=15, alpha=0.7, vmin=0, vmax=1)
        plt.colorbar(scatter_gmm_prob, label='Max Probability of GMM Assignment')
        plt.title(f'GMM Max Assignment Probability on {data_desc} (UMAP proj.)', fontsize=14)
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.grid(True, alpha=0.3); plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'gmm_prob_on_umap_{data_desc.lower().replace(" ", "_")}.png'), dpi=300); plt.close()
        print(f"  GMM max probability on {data_desc} plotted on UMAP.")
    return df

def apply_rule_based_gating(df_main, rules, default_label, output_dir, umap_embedding=None):
    """Applies a series of defined rules to classify cells. Operates on df_main."""
    print("\n--- Applying Rule-Based Gating ---")
    df_main['rule_based_classification'] = default_label

    # Check if all features required by any rule exist in df_main
    all_rule_features = set()
    for rule in rules:
        for condition in rule['conditions']:
            all_rule_features.add(condition[0])

    missing_features_in_df = [feat for feat in all_rule_features if feat not in df_main.columns]
    if missing_features_in_df:
        print(f"  Error: The following features required by rules are missing from the DataFrame: {missing_features_in_df}. Skipping rule-based gating.")
        return df_main

    for rule_idx, rule in enumerate(rules):
        print(f"  Applying rule: {rule['name']}")
        eligible_mask = (df_main['rule_based_classification'] == default_label)
        if not eligible_mask.any():
            print(f"    No cells eligible for rule '{rule['name']}'.")
            continue

        rule_condition_mask = pd.Series([True] * len(df_main), index=df_main.index)
        for feature, operator, value in rule['conditions']:
            # This check is now redundant due to the one above, but kept for safety per condition
            if feature not in df_main.columns:
                print(f"    Feature '{feature}' not found in DataFrame for rule '{rule['name']}'. Skipping this rule.")
                rule_condition_mask[:] = False
                break
            try:
                # Ensure the column is numeric before comparison, handle potential errors
                feature_series = pd.to_numeric(df_main[feature], errors='coerce')
                if feature_series.isnull().any():
                    print(f"    Warning: Feature '{feature}' contains non-numeric values after coercion for rule '{rule['name']}'. Comparisons may be affected.")

                if   operator == '>':  rule_condition_mask &= (feature_series > value)
                elif operator == '<':  rule_condition_mask &= (feature_series < value)
                elif operator == '>=': rule_condition_mask &= (feature_series >= value)
                elif operator == '<=': rule_condition_mask &= (feature_series <= value)
                elif operator == '==': rule_condition_mask &= (feature_series == value)
                elif operator == '!=': rule_condition_mask &= (feature_series != value)
                else:
                    print(f"    Unknown operator '{operator}' in rule '{rule['name']}'. Skipping condition.")
                    rule_condition_mask[:] = False; break

            except Exception as e: # Catch any other error during comparison
                print(f"    Error comparing feature '{feature}' in rule '{rule['name']}': {e}. Skipping condition.")
                rule_condition_mask[:] = False; break

        if not rule_condition_mask.all() and not rule_condition_mask.any() and rule_condition_mask is not False : # If mask became all False due to an issue
             print(f"    Rule '{rule['name']}' resulted in an invalid condition mask. No cells labeled.")
             continue


        if rule_condition_mask.any():
            cells_to_label_now = eligible_mask & rule_condition_mask
            df_main.loc[cells_to_label_now, 'rule_based_classification'] = rule['output_label']
            print(f"    {cells_to_label_now.sum()} cells labeled as '{rule['output_label']}'.")
        else:
            print(f"    No cells met all conditions for rule '{rule['name']}'.")

    print(f"\nRule-based classification counts:\n{df_main['rule_based_classification'].value_counts()}")

    if umap_embedding is not None and 'umap_x_refined' in df_main.columns and 'umap_y_refined' in df_main.columns:
        plt.figure(figsize=(12, 10)); unique_rule_labels = sorted(df_main['rule_based_classification'].unique())
        # Ensure enough colors if many rule labels
        if len(unique_rule_labels) > 0:
            rule_palette = sns.color_palette("Set3", n_colors=max(10, len(unique_rule_labels)))
            for i, label in enumerate(unique_rule_labels):
                subset = df_main[df_main['rule_based_classification'] == label]
                plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=label, color=rule_palette[i % len(rule_palette)], s=15, alpha=0.7) # Modulo for safety
            plt.title('Rule-Based Gating Classification (UMAP proj.)', fontsize=14)
            plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule-Based Class', bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1])
            plt.savefig(os.path.join(output_dir, 'rule_based_gating_on_umap.png'), dpi=300); plt.close()
            print("  Rule-based gating results plotted on UMAP.")
    return df_main


def main_exploratory_analysis():
    """Main function to run exploratory analysis."""
    if not os.path.exists(EXPLORATORY_OUTPUT_DIR):
        os.makedirs(EXPLORATORY_OUTPUT_DIR)
        print(f"Created output directory: {EXPLORATORY_OUTPUT_DIR}")

    df = load_data(INPUT_REFINED_CSV_PATH)
    if df is None: return

    # scaled_features are for ML algos, feature_names_used_for_scaling are their names
    # The main 'df' is used for rule-based gating as it contains all columns.
    scaled_features, feature_names_used_for_scaling = preprocess_features_for_ml(df, FEATURES_FOR_ANALYSIS, AREA_FEATURES_TO_LOG)

    if scaled_features is None:
        print("Scaled feature preprocessing failed. Some ML-based analyses might be skipped or fail.")

    umap_embedding_for_plotting = None
    if 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns and df['umap_x_refined'].notna().all():
        print("\nUsing existing UMAP coordinates from input CSV for visualizations.")
        umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
    elif scaled_features is not None:
        print("\nRecomputing UMAP for visualization...")
        try:
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42, n_components=2)
            embedding = reducer.fit_transform(scaled_features)
            df['umap_x_refined'] = embedding[:, 0]; df['umap_y_refined'] = embedding[:, 1]
            umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
            print("  UMAP recomputed.")
        except Exception as e: print(f"  Error recomputing UMAP: {e}.")
    else: print("\nSkipping UMAP computation as scaled features are unavailable.")

    if scaled_features is not None and feature_names_used_for_scaling is not None:
        df = compute_and_plot_diffusion_map(df, scaled_features, feature_names_used_for_scaling, EXPLORATORY_OUTPUT_DIR)

        # DBSCAN on Scaled Features - User needs to set DBSCAN_EPS_SCALED_FEATURES
        # based on dbscan_eps_est_scaled_features.png from the previous run.
        # Example: User looked at plot and chose this
        DBSCAN_EPS_SCALED_FEATURES = 2.3
        DBSCAN_MIN_SAMPLES_SCALED_FEATURES = 10 # Default or user-tuned
        print(f"\nNOTE: For DBSCAN on Scaled Features, using DBSCAN_EPS = {DBSCAN_EPS_SCALED_FEATURES}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_SCALED_FEATURES}")
        df = run_dbscan_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR,
                                 umap_emb=umap_embedding_for_plotting,
                                 current_eps_val=DBSCAN_EPS_SCALED_FEATURES,
                                 current_min_samples_val=DBSCAN_MIN_SAMPLES_SCALED_FEATURES)

        df = run_gmm_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

        if RUN_DBSCAN_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols = [f'dc_{i+1}' for i in range(N_DCS_FOR_DBSCAN) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols:
                data_dc = df[dc_cols].values
                # DBSCAN on DCs - User needs to set DBSCAN_EPS_DCS
                # based on dbscan_eps_est_top_X_dcs.png from the previous run.
                DBSCAN_EPS_DCS = 0.01 # Example: User looked at plot and chose this
                DBSCAN_MIN_SAMPLES_DCS = 10 # Default or user-tuned
                print(f"\nNOTE: For DBSCAN on Top DCs, using DBSCAN_EPS = {DBSCAN_EPS_DCS}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_DCS}")
                df = run_dbscan_and_plot(df, data_dc, f"Top_{len(dc_cols)}_DCs", EXPLORATORY_OUTPUT_DIR,
                                         umap_emb=umap_embedding_for_plotting,
                                         current_eps_val=DBSCAN_EPS_DCS,
                                         current_min_samples_val=DBSCAN_MIN_SAMPLES_DCS)
            else: print(f"\nSkipping DBSCAN on DCs: Not enough valid DC columns.")

        global RUN_GMM_ON_DIFFMAP, N_DCS_FOR_GMM
        if RUN_GMM_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols_gmm = [f'dc_{i+1}' for i in range(N_DCS_FOR_GMM) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols_gmm:
                data_dc_gmm = df[dc_cols_gmm].values
                df = run_gmm_and_plot(df, data_dc_gmm, f"Top_{len(dc_cols_gmm)}_DCs", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)
            else: print(f"\nSkipping GMM on DCs: Not enough valid DC columns.")

    # Apply rule-based gating using the main df, which contains all original and calculated columns
    # The first argument to apply_rule_based_gating is the DataFrame it will operate on for checking rules.
    df = apply_rule_based_gating(df, RULE_BASED_GATES, RULE_BASED_DEFAULT_LABEL, EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

    exploratory_csv_path = os.path.join(EXPLORATORY_OUTPUT_DIR, 'exploratory_analysis_results_v4.csv')
    df.to_csv(exploratory_csv_path, index=False)
    print(f"\nExploratory analysis results saved to: {exploratory_csv_path}")
    print("\nExploratory analysis script finished.")

RUN_GMM_ON_DIFFMAP = True
N_DCS_FOR_GMM = 3

if __name__ == '__main__':
    main_exploratory_analysis()


Created output directory: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V4_fix
Loading refined data from /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv...
Successfully loaded 2472 cells.

Preprocessing features for ML. Selected: ['cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity', 'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area', 'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement', 'nucleus_to_cell_area_ratio', 'nuclear_enlargement', 'cell_enlargement']
  Log-transformed for scaling: cell_area
  Log-transformed for scaling: avg_nucleus_area
  Log-transformed for scaling: max_nucleus_area
  Log-transformed for scaling: cell_perimeter
  Features standardized for ML algorithms.

Using existing UMAP coordinates from input CSV for visualizations.

--- Computing Diffusion Map ---
  Computing n

  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters if n_actual_clusters > 0 else 1)


  DBSCAN on Scaled_Features plotted on UMAP.

--- Running Gaussian Mixture Model (GMM) on Scaled_Features ---
  Testing GMM with n_components in [2, 3, 4] using BIC...


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 2 components: BIC = -11876.19


  # that has no feature names.
  # that has no feature names.


    GMM with 3 components: BIC = -20370.16


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 4 components: BIC = -26836.71
  Best GMM found with 4 components (BIC=-26836.71).
  Mean senescence_score_normalized per GMM component (for Scaled_Features):
gmm_scaled_features
1    0.316948
0    0.365966
3    0.500682
2    0.500958
Name: senescence_score_normalized, dtype: float64
  GMM on Scaled_Features results plotted on UMAP.
  GMM max probability on Scaled_Features plotted on UMAP.

NOTE: For DBSCAN on Top DCs, using DBSCAN_EPS = 0.01, MIN_SAMPLES = 10

--- Running DBSCAN on Top_3_DCs ---
  Running DBSCAN with eps=0.01, min_samples=10 on Top_3_DCs...
  DBSCAN on Top_3_DCs: 2 clusters, 42 noise (1.70%).


  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters if n_actual_clusters > 0 else 1)


  DBSCAN on Top_3_DCs plotted on UMAP.

--- Running Gaussian Mixture Model (GMM) on Top_3_DCs ---
  Testing GMM with n_components in [2, 3, 4] using BIC...


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 2 components: BIC = -42744.54
    GMM with 3 components: BIC = -43790.09


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 4 components: BIC = -44590.14
  Best GMM found with 4 components (BIC=-44590.14).
  Mean senescence_score_normalized per GMM component (for Top_3_DCs):
gmm_top_3_dcs
0    0.319583
1    0.398703
2    0.427525
3    0.543944
Name: senescence_score_normalized, dtype: float64
  GMM on Top_3_DCs results plotted on UMAP.
  GMM max probability on Top_3_DCs plotted on UMAP.

--- Applying Rule-Based Gating ---
  Applying rule: Polynucleated_Large
    200 cells labeled as 'Rule_Sen_Poly_Large'.
  Applying rule: Very_Large_Cell
    728 cells labeled as 'Rule_Sen_VeryLarge'.
  Applying rule: High_Score_Not_Otherwise_Caught
    0 cells labeled as 'Rule_Sen_HighScore'.

Rule-based classification counts:
rule_based_classification
Rule_NonSenescent      1544
Rule_Sen_VeryLarge      728
Rule_Sen_Poly_Large     200
Name: count, dtype: int64
  Rule-based gating results plotted on UMAP.

Exploratory analysis results saved to: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/f

In [23]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
import umap

try:
    import scanpy as sc
    SCANPY_AVAILABLE = True
except ImportError:
    print("Scanpy library not found. Diffusion map functionality will be skipped.")
    SCANPY_AVAILABLE = False

# --- Configuration & Parameters ---
INPUT_REFINED_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv"
EXPLORATORY_OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules" # Incremented version

FEATURES_FOR_ANALYSIS = [ # Features used for UMAP, DiffMap, DBSCAN on features, GMM on features
    'cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity',
    'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area',
    'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement',
    'nucleus_to_cell_area_ratio',
    'nuclear_enlargement', 'cell_enlargement'
]
AREA_FEATURES_TO_LOG = ['cell_area', 'avg_nucleus_area', 'max_nucleus_area', 'cell_perimeter']

# Diffusion Map parameters
N_DIFFUSION_COMPONENTS = 10
N_DCS_TO_PLOT = 3
N_NEIGHBORS_FOR_SCANPY = 15

# DBSCAN Parameters - IMPORTANT: TUNE THESE MANUALLY AFTER FIRST RUN!
DBSCAN_EPS_DEFAULT = 0.75
DBSCAN_MIN_SAMPLES_DEFAULT = 10
ESTIMATE_DBSCAN_EPS = True
K_FOR_EPS_ESTIMATION = 10
RUN_DBSCAN_ON_DIFFMAP = True
N_DCS_FOR_DBSCAN = 3

# GMM Parameters
GMM_N_COMPONENTS_RANGE = range(2, 5)
GMM_COVARIANCE_TYPE = 'full'

# Rule-Based Gating Parameters (Updated with user's new rules)
# Rules are applied sequentially. A cell gets the label of the first rule it matches.
RULE_BASED_GATES = [
    {
        'name': 'Polynucleated', # Changed from Polynucleated_Large for broader capture
        'conditions': [
            ('nuclei_count', '>', 1),
        ],
        'output_label': 'Rule_Sen_Poly' # Cells with >1 nucleus
    },
    {
        'name': 'Very_Large_Cell',
        'conditions': [
            ('cell_area', '>', 5000) # User-defined threshold
        ],
        'output_label': 'Rule_Sen_VeryLarge'
    },
    {
        'name': 'Low_Circularity', # New rule
        'conditions': [
            ('cell_circularity', '<', 0.2)
        ],
        'output_label': 'Rule_Sen_LowCirc'
    },
    {
        'name': 'Low_NucToCellRatio', # New rule
        'conditions': [
            ('nucleus_to_cell_area_ratio', '<', 0.1)
        ],
        'output_label': 'Rule_Sen_LowNucRatio'
    },
    {
        'name': 'High_Score_Not_Otherwise_Caught',
        'conditions': [
            ('senescence_score_normalized', '>', 0.75) # User-defined threshold
        ],
        'output_label': 'Rule_Sen_HighScore'
    }
]
RULE_BASED_DEFAULT_LABEL = 'Rule_NonSenescent'


def load_data(csv_path):
    """Loads the refined data."""
    print(f"Loading refined data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
        print(f"Successfully loaded {len(df)} cells.")
        # Check for essential columns for visualization and rules
        # Added all features used in new rules to this check
        for col_check in ['senescence_score_normalized', 'cell_type_final',
                          'nuclei_count', 'cell_area', 'cell_circularity',
                          'nucleus_to_cell_area_ratio']:
            if col_check not in df.columns:
                print(f"Warning: Essential column '{col_check}' not found. Some functionalities might be affected.")
        return df
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}")
        return None

def preprocess_features_for_ml(df, feature_columns, log_transform_cols):
    """Prepares features specifically for ML algorithms (scaling)."""
    print(f"\nPreprocessing features for ML. Selected: {feature_columns}")

    actual_features_for_ml = [col for col in feature_columns if col in df.columns]
    if not actual_features_for_ml:
        print("  Error: None of the specified FEATURES_FOR_ANALYSIS are present in the DataFrame for ML.")
        return None, None

    features_for_scaling_df = df[actual_features_for_ml].copy()

    for col in log_transform_cols:
        if col in features_for_scaling_df.columns:
            features_for_scaling_df[col] = np.log1p(features_for_scaling_df[col])
            print(f"  Log-transformed for scaling: {col}")

    if features_for_scaling_df.isnull().sum().any():
        print(f"  Handling NaNs using mean imputation for {features_for_scaling_df.isnull().sum().sum()} values (for scaled features).")
        features_for_scaling_df = features_for_scaling_df.fillna(features_for_scaling_df.mean())

    cols_to_drop_scaled = features_for_scaling_df.columns[features_for_scaling_df.isna().all()].tolist()
    if cols_to_drop_scaled:
        print(f"  Dropping all-NaN columns from scaled set: {cols_to_drop_scaled}")
        features_for_scaling_df = features_for_scaling_df.drop(columns=cols_to_drop_scaled)
        actual_features_for_ml = [f for f in actual_features_for_ml if f not in cols_to_drop_scaled]

    if features_for_scaling_df.empty or not actual_features_for_ml:
        print(" Error: No features remaining for scaling after processing.")
        return None, None

    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features_for_scaling_df)
    print("  Features standardized for ML algorithms.")

    return features_scaled, actual_features_for_ml


def compute_and_plot_diffusion_map(df, scaled_features, feature_names_used, output_dir):
    """Computes and plots diffusion map."""
    if not SCANPY_AVAILABLE: print("Skipping diffusion map: Scanpy not available."); return df
    print("\n--- Computing Diffusion Map ---")
    if scaled_features is None or scaled_features.shape[0] == 0 : print("  No scaled features for Diffusion Map. Skipping."); return df

    adata = sc.AnnData(scaled_features, var=pd.DataFrame(index=feature_names_used))
    adata.obs_names = df.index.astype(str)
    if 'senescence_score_normalized' in df.columns: adata.obs['senescence_score_normalized'] = df['senescence_score_normalized'].values
    if 'cell_type_final' in df.columns: adata.obs['cell_type_final'] = df['cell_type_final'].astype('category').values

    actual_n_neighbors = min(N_NEIGHBORS_FOR_SCANPY, adata.n_obs - 1)
    if actual_n_neighbors < 2: print(f"  Not enough samples for Scanpy neighbors. Skipping."); return df

    print(f"  Computing neighbors (k={actual_n_neighbors})...")
    sc.pp.neighbors(adata, n_neighbors=actual_n_neighbors, use_rep='X')
    print("  Running sc.tl.diffmap...")
    sc.tl.diffmap(adata, n_comps=N_DIFFUSION_COMPONENTS)

    if 'X_diffmap' in adata.obsm:
        num_dc = min(N_DIFFUSION_COMPONENTS, adata.obsm['X_diffmap'].shape[1] - 1)
        for i in range(num_dc): df[f'dc_{i+1}'] = adata.obsm['X_diffmap'][:, i+1]
        print(f"  Added {num_dc} DCs to DataFrame.")

        pairs = [(f'dc_{i}', f'dc_{j}') for i in range(1, N_DCS_TO_PLOT + 1) for j in range(i + 1, N_DCS_TO_PLOT + 1) if f'dc_{i}' in df.columns and f'dc_{j}' in df.columns]
        for dcx, dcy in pairs:
            if df[dcx].notna().any() and df[dcy].notna().any():
                if 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
                    plt.figure(figsize=(10,8)); plt.scatter(df[dcx], df[dcy], c=df['senescence_score_normalized'], cmap='viridis', s=12, alpha=0.7); plt.colorbar(label='Norm. Senescence Score')
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Score'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper()); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_score.png'),dpi=300,bbox_inches='tight'); plt.close()
                if 'cell_type_final' in df.columns:
                    plt.figure(figsize=(10,8)); types=df['cell_type_final'].unique(); pal={t:('red' if t=='Senescent' else ('blue' if t=='Non-senescent' else 'grey')) for t in types}
                    for ct,col in pal.items(): subset=df[df['cell_type_final']==ct]; plt.scatter(subset[dcx],subset[dcy],label=ct,color=col,s=12,alpha=0.7)
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Classif.'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper());
                    if pal: plt.legend(title='Final Cell Type'); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_type.png'),dpi=300,bbox_inches='tight'); plt.close()
        print(f"  DiffMap pair plots for top {N_DCS_TO_PLOT} DCs saved.")
    else: print("  Error: 'X_diffmap' not found in AnnData object after sc.tl.diffmap.")
    return df

def run_dbscan_and_plot(df, data_for_dbscan, data_desc, output_dir, umap_emb=None, current_eps_val=None, current_min_samples_val=None):
    """Runs DBSCAN and plots results. Uses specific eps and min_samples if provided."""
    print(f"\n--- Running DBSCAN on {data_desc} ---")
    if data_for_dbscan is None or data_for_dbscan.shape[0] == 0:
        print(f"  No data for DBSCAN on {data_desc}. Skipping.")
        df[f'dbscan_{data_desc.lower().replace(" ","_")}']=-1
        return df

    eps_to_use = current_eps_val if current_eps_val is not None else DBSCAN_EPS_DEFAULT
    min_s_to_use = current_min_samples_val if current_min_samples_val is not None else DBSCAN_MIN_SAMPLES_DEFAULT

    if ESTIMATE_DBSCAN_EPS and current_eps_val is None :
        k_est = min(K_FOR_EPS_ESTIMATION, data_for_dbscan.shape[0]-1); k_est=max(1,k_est)
        nn=NearestNeighbors(n_neighbors=k_est); nn.fit(data_for_dbscan); dists, _ = nn.kneighbors(data_for_dbscan)
        actual_k_for_dists = min(k_est, dists.shape[1])
        if actual_k_for_dists > 0:
            k_dists = dists[:,actual_k_for_dists-1]
            k_dists_sorted = np.sort(k_dists)
            plt.figure(figsize=(8,6)); plt.plot(k_dists_sorted); plt.title(f'{actual_k_for_dists}-Dist Graph for Eps ({data_desc})');
            plt.xlabel("Points sorted by distance"); plt.ylabel(f"{actual_k_for_dists}-th NN Distance (eps candidate)"); plt.grid(True,alpha=0.3);
            eps_path=os.path.join(output_dir, f'dbscan_eps_est_{data_desc.lower().replace(" ","_")}.png'); plt.savefig(eps_path,dpi=300); plt.close();
            print(f"  Saved k-dist graph: {eps_path}. PLEASE INSPECT THIS PLOT TO SET appropriate DBSCAN_EPS for {data_desc}.")
            if len(k_dists_sorted)>10:
                sug_eps=np.percentile(k_dists_sorted,90);
                print(f"  A percentile-based suggestion for eps for {data_desc} is: {sug_eps:.3f}. The script will use eps={eps_to_use} (default or passed).")
        else:
            print(f"  Could not determine k-distances for eps estimation for {data_desc}. Using eps={eps_to_use}")

    print(f"  Running DBSCAN with eps={eps_to_use}, min_samples={min_s_to_use} on {data_desc}...")
    db=DBSCAN(eps=eps_to_use,min_samples=min_s_to_use).fit(data_for_dbscan)
    clust_col=f'dbscan_{data_desc.lower().replace(" ","_")}'; df[clust_col]=db.labels_
    n_clust=len(set(db.labels_))-(1 if -1 in db.labels_ else 0); n_noise=list(db.labels_).count(-1)
    print(f"  DBSCAN on {data_desc}: {n_clust} clusters, {n_noise} noise ({n_noise/len(df)*100:.2f}%).")

    if umap_emb is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12,10));
        labels_unique_dbscan=sorted(df[clust_col].unique());
        n_actual_clusters = len([l for l in labels_unique_dbscan if l != -1])
        dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters if n_actual_clusters > 0 else 1)
        cdict = {}; cluster_idx = 0
        for lbl in labels_unique_dbscan:
            if lbl == -1: cdict[lbl] = (0.5, 0.5, 0.5, 1)
            else: cdict[lbl] = dbscan_cmap_obj(cluster_idx); cluster_idx += 1
        for k_val in labels_unique_dbscan:
            mask=(df[clust_col]==k_val); xy=umap_emb[mask]
            if xy.shape[0]>0:
                 plt.scatter(xy[:,0],xy[:,1], s=(20 if k_val!=-1 else 10), c=[cdict[k_val]],
                             marker=('o' if k_val!=-1 else 'x'), label=('Noise' if k_val == -1 else f'Cluster {k_val}'))
        plt.title(f'DBSCAN on {data_desc} (UMAP proj.)\neps={eps_to_use:.3f}, min_s={min_s_to_use}');
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2');
        plt.legend(title='DBSCAN Cluster',bbox_to_anchor=(1.05,1),loc='upper left',markerscale=1.5);
        plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1]);
        plt.savefig(os.path.join(output_dir, f'dbscan_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close();
        print(f"  DBSCAN on {data_desc} plotted on UMAP.")
    return df

def run_gmm_and_plot(df, data_for_gmm, data_desc, output_dir, umap_embedding=None):
    """Runs Gaussian Mixture Model clustering and plots results."""
    print(f"\n--- Running Gaussian Mixture Model (GMM) on {data_desc} ---")
    if data_for_gmm is None or data_for_gmm.shape[0] == 0:
        print(f"  No data available for GMM on {data_desc}. Skipping.")
        df[f'gmm_cluster_{data_desc.lower().replace(" ", "_")}'] = -1
        df[f'gmm_prob_max_{data_desc.lower().replace(" ", "_")}'] = np.nan
        return df

    best_gmm = None; lowest_bic = np.inf
    print(f"  Testing GMM with n_components in {list(GMM_N_COMPONENTS_RANGE)} using BIC...")
    for n_components in GMM_N_COMPONENTS_RANGE:
        if n_components > data_for_gmm.shape[0]: continue
        gmm = GaussianMixture(n_components=n_components, covariance_type=GMM_COVARIANCE_TYPE, random_state=42, n_init=5)
        gmm.fit(data_for_gmm); bic = gmm.bic(data_for_gmm)
        print(f"    GMM with {n_components} components: BIC = {bic:.2f}")
        if bic < lowest_bic: lowest_bic = bic; best_gmm = gmm

    if best_gmm is None:
        print("  GMM fitting failed. Skipping GMM."); df[f'gmm_cluster_{data_desc.lower().replace(" ", "_")}'] = -1; df[f'gmm_prob_max_{data_desc.lower().replace(" ", "_")}'] = np.nan
        return df

    print(f"  Best GMM found with {best_gmm.n_components} components (BIC={lowest_bic:.2f}).")
    cluster_col_name = f'gmm_{data_desc.lower().replace(" ", "_")}'; prob_col_name = f'gmm_prob_max_{data_desc.lower().replace(" ", "_")}'
    df[cluster_col_name] = best_gmm.predict(data_for_gmm); df[prob_col_name] = np.max(best_gmm.predict_proba(data_for_gmm), axis=1)

    if 'senescence_score_normalized' in df.columns:
        print(f"  Mean senescence_score_normalized per GMM component (for {data_desc}):\n{df.groupby(cluster_col_name)['senescence_score_normalized'].mean().sort_values()}")

    if umap_embedding is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12, 10)); unique_gmm_labels = sorted(df[cluster_col_name].unique())
        gmm_palette = sns.color_palette("viridis", n_colors=len(unique_gmm_labels))
        for i, label in enumerate(unique_gmm_labels):
            subset = df[df[cluster_col_name] == label]
            plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=f'GMM Comp. {label}', color=gmm_palette[i], s=15, alpha=0.7)
        plt.title(f'GMM ({best_gmm.n_components} comp.) on {data_desc} (UMAP proj.)', fontsize=14)
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='GMM Component', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1])
        plt.savefig(os.path.join(output_dir, f'gmm_on_umap_{data_desc.lower().replace(" ", "_")}.png'), dpi=300); plt.close()
        print(f"  GMM on {data_desc} results plotted on UMAP.")

        plt.figure(figsize=(12, 10)); scatter_gmm_prob = plt.scatter(df['umap_x_refined'], df['umap_y_refined'], c=df[prob_col_name], cmap='magma', s=15, alpha=0.7, vmin=0, vmax=1)
        plt.colorbar(scatter_gmm_prob, label='Max Probability of GMM Assignment')
        plt.title(f'GMM Max Assignment Probability on {data_desc} (UMAP proj.)', fontsize=14)
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.grid(True, alpha=0.3); plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'gmm_prob_on_umap_{data_desc.lower().replace(" ", "_")}.png'), dpi=300); plt.close()
        print(f"  GMM max probability on {data_desc} plotted on UMAP.")
    return df

def apply_rule_based_gating(df_main, rules, default_label, output_dir, umap_embedding=None):
    """Applies a series of defined rules to classify cells. Operates on df_main."""
    print("\n--- Applying Rule-Based Gating ---")
    # This column will store the specific rule label that a cell matches
    df_main['rule_based_classification_granular'] = default_label

    all_rule_features = set()
    for rule in rules:
        for condition in rule['conditions']:
            all_rule_features.add(condition[0])

    missing_features_in_df = [feat for feat in all_rule_features if feat not in df_main.columns]
    if missing_features_in_df:
        print(f"  Error: The following features required by rules are missing from the DataFrame: {missing_features_in_df}. Skipping rule-based gating.")
        df_main['rule_based_binary_status'] = 'Unknown_Due_To_Missing_Features' # Indicate error
        return df_main

    for rule_idx, rule in enumerate(rules):
        print(f"  Applying rule: {rule['name']}")
        # Cells are eligible if they still have the default label for granular classification
        eligible_mask = (df_main['rule_based_classification_granular'] == default_label)
        if not eligible_mask.any():
            print(f"    No cells eligible for rule '{rule['name']}' (all already classified by prior rules).")
            continue

        rule_condition_mask = pd.Series([True] * len(df_main), index=df_main.index)
        for feature, operator, value in rule['conditions']:
            if feature not in df_main.columns:
                print(f"    Feature '{feature}' not found in DataFrame for rule '{rule['name']}'. Skipping this rule.")
                rule_condition_mask[:] = False; break
            try:
                feature_series = pd.to_numeric(df_main[feature], errors='coerce')
                # Check if coercion introduced NaNs where the original wasn't NaN (means type issue)
                if feature_series.isnull().sum() > df_main[feature].isnull().sum():
                    print(f"    Warning: Feature '{feature}' had values that could not be converted to numeric for rule '{rule['name']}'. These will not meet numeric conditions.")

                # Apply condition, NaNs in feature_series will result in False for comparisons
                if   operator == '>':  rule_condition_mask &= (feature_series > value)
                elif operator == '<':  rule_condition_mask &= (feature_series < value)
                elif operator == '>=': rule_condition_mask &= (feature_series >= value)
                elif operator == '<=': rule_condition_mask &= (feature_series <= value)
                elif operator == '==': rule_condition_mask &= (feature_series == value)
                elif operator == '!=': rule_condition_mask &= (feature_series != value)
                else:
                    print(f"    Unknown operator '{operator}' in rule '{rule['name']}'. Skipping condition.")
                    rule_condition_mask[:] = False; break
            except Exception as e:
                print(f"    Error comparing feature '{feature}' in rule '{rule['name']}': {e}. Skipping condition.")
                rule_condition_mask[:] = False; break

        if not rule_condition_mask.all() and not rule_condition_mask.any() and isinstance(rule_condition_mask, pd.Series) and not rule_condition_mask.empty : # If mask became all False due to an issue
             # This check was a bit problematic, simplifying: if no cells meet the rule after conditions:
             pass # The cells_to_label_now check below will handle it.


        if rule_condition_mask.any():
            cells_to_label_now = eligible_mask & rule_condition_mask
            df_main.loc[cells_to_label_now, 'rule_based_classification_granular'] = rule['output_label']
            print(f"    {cells_to_label_now.sum()} cells labeled as '{rule['output_label']}'.")
        else:
            print(f"    No cells met all conditions for rule '{rule['name']}' among the eligible ones.")

    print(f"\nGranular rule-based classification counts:\n{df_main['rule_based_classification_granular'].value_counts()}")

    # Create the binary 'Senescent' / 'Non-senescent' column based on granular rules
    df_main['rule_based_binary_status'] = np.where(
        df_main['rule_based_classification_granular'] == default_label,
        'Non-senescent', # Or simply default_label if you prefer
        'Senescent'
    )
    print(f"\nBinary rule-based classification counts:\n{df_main['rule_based_binary_status'].value_counts()}")


    # Plot granular rule-based classification
    if umap_embedding is not None and 'umap_x_refined' in df_main.columns and 'umap_y_refined' in df_main.columns:
        plt.figure(figsize=(12, 10)); unique_granular_labels = sorted(df_main['rule_based_classification_granular'].unique())
        if len(unique_granular_labels) > 0:
            granular_palette = sns.color_palette("Paired", n_colors=max(10, len(unique_granular_labels)))
            for i, label in enumerate(unique_granular_labels):
                subset = df_main[df_main['rule_based_classification_granular'] == label]
                plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=label, color=granular_palette[i % len(granular_palette)], s=15, alpha=0.7)
            plt.title('Rule-Based Gating (Granular Labels - UMAP proj.)', fontsize=14)
            plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule-Based Class (Granular)', bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, alpha=0.3); plt.tight_layout(rect=[0,0,0.80,1]) # Adjusted rect for potentially longer legend
            plt.savefig(os.path.join(output_dir, 'rule_based_gating_granular_on_umap.png'), dpi=300); plt.close()
            print("  Granular rule-based gating results plotted on UMAP.")

        # Plot binary rule-based classification
        plt.figure(figsize=(12, 10)); unique_binary_labels = sorted(df_main['rule_based_binary_status'].unique())
        binary_palette_map = {'Senescent': 'red', 'Non-senescent': 'blue', default_label: 'blue'} # Ensure default maps to non-senescent color
        # Add any other specific colors if 'Unknown_Due_To_Missing_Features' occurs
        if 'Unknown_Due_To_Missing_Features' in unique_binary_labels: binary_palette_map['Unknown_Due_To_Missing_Features'] = 'grey'


        for label in unique_binary_labels:
            subset = df_main[df_main['rule_based_binary_status'] == label]
            plt.scatter(subset['umap_x_refined'], subset['umap_y_refined'], label=label, color=binary_palette_map.get(label, 'grey'), s=15, alpha=0.7)
        plt.title('Rule-Based Gating (Binary Status - UMAP proj.)', fontsize=14)
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule-Based Status (Binary)', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1])
        plt.savefig(os.path.join(output_dir, 'rule_based_gating_binary_on_umap.png'), dpi=300); plt.close()
        print("  Binary rule-based gating results plotted on UMAP.")

    return df_main


def main_exploratory_analysis():
    """Main function to run exploratory analysis."""
    if not os.path.exists(EXPLORATORY_OUTPUT_DIR):
        os.makedirs(EXPLORATORY_OUTPUT_DIR)
        print(f"Created output directory: {EXPLORATORY_OUTPUT_DIR}")

    df = load_data(INPUT_REFINED_CSV_PATH)
    if df is None: return

    # scaled_features are for ML algos, feature_names_used_for_scaling are their names
    scaled_features, feature_names_used_for_scaling = preprocess_features_for_ml(df, FEATURES_FOR_ANALYSIS, AREA_FEATURES_TO_LOG)

    if scaled_features is None:
        print("Scaled feature preprocessing failed. Some ML-based analyses might be skipped or fail.")

    umap_embedding_for_plotting = None
    if 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns and df['umap_x_refined'].notna().all():
        print("\nUsing existing UMAP coordinates from input CSV for visualizations.")
        umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
    elif scaled_features is not None:
        print("\nRecomputing UMAP for visualization...")
        try:
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42, n_components=2)
            embedding = reducer.fit_transform(scaled_features)
            df['umap_x_refined'] = embedding[:, 0]; df['umap_y_refined'] = embedding[:, 1]
            umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
            print("  UMAP recomputed.")
        except Exception as e: print(f"  Error recomputing UMAP: {e}.")
    else: print("\nSkipping UMAP computation as scaled features are unavailable.")

    if scaled_features is not None and feature_names_used_for_scaling is not None:
        df = compute_and_plot_diffusion_map(df, scaled_features, feature_names_used_for_scaling, EXPLORATORY_OUTPUT_DIR)

        # DBSCAN on Scaled Features - User needs to set DBSCAN_EPS_SCALED_FEATURES
        DBSCAN_EPS_SCALED_FEATURES = 2.3 # From your previous successful run
        DBSCAN_MIN_SAMPLES_SCALED_FEATURES = 10
        print(f"\nNOTE: For DBSCAN on Scaled Features, using DBSCAN_EPS = {DBSCAN_EPS_SCALED_FEATURES}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_SCALED_FEATURES}")
        df = run_dbscan_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR,
                                 umap_emb=umap_embedding_for_plotting,
                                 current_eps_val=DBSCAN_EPS_SCALED_FEATURES,
                                 current_min_samples_val=DBSCAN_MIN_SAMPLES_SCALED_FEATURES)

        df = run_gmm_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

        if RUN_DBSCAN_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols = [f'dc_{i+1}' for i in range(N_DCS_FOR_DBSCAN) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols:
                data_dc = df[dc_cols].values
                DBSCAN_EPS_DCS = 0.01 # From your previous successful run
                DBSCAN_MIN_SAMPLES_DCS = 10
                print(f"\nNOTE: For DBSCAN on Top DCs, using DBSCAN_EPS = {DBSCAN_EPS_DCS}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_DCS}")
                df = run_dbscan_and_plot(df, data_dc, f"Top_{len(dc_cols)}_DCs", EXPLORATORY_OUTPUT_DIR,
                                         umap_emb=umap_embedding_for_plotting,
                                         current_eps_val=DBSCAN_EPS_DCS,
                                         current_min_samples_val=DBSCAN_MIN_SAMPLES_DCS)
            else: print(f"\nSkipping DBSCAN on DCs: Not enough valid DC columns.")

        global RUN_GMM_ON_DIFFMAP, N_DCS_FOR_GMM
        if RUN_GMM_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols_gmm = [f'dc_{i+1}' for i in range(N_DCS_FOR_GMM) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols_gmm:
                data_dc_gmm = df[dc_cols_gmm].values
                df = run_gmm_and_plot(df, data_dc_gmm, f"Top_{len(dc_cols_gmm)}_DCs", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)
            else: print(f"\nSkipping GMM on DCs: Not enough valid DC columns.")

    # Apply rule-based gating using the main df.
    # The 'df' passed here is the one that has been progressively updated.
    df = apply_rule_based_gating(df, RULE_BASED_GATES, RULE_BASED_DEFAULT_LABEL, EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

    exploratory_csv_path = os.path.join(EXPLORATORY_OUTPUT_DIR, 'exploratory_analysis_results_v5_rules.csv') # Incremented output filename
    df.to_csv(exploratory_csv_path, index=False)
    print(f"\nExploratory analysis results saved to: {exploratory_csv_path}")
    print("\nExploratory analysis script finished.")

RUN_GMM_ON_DIFFMAP = True
N_DCS_FOR_GMM = 3

if __name__ == '__main__':
    main_exploratory_analysis()


Created output directory: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules
Loading refined data from /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv...
Successfully loaded 2472 cells.

Preprocessing features for ML. Selected: ['cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity', 'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area', 'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement', 'nucleus_to_cell_area_ratio', 'nuclear_enlargement', 'cell_enlargement']
  Log-transformed for scaling: cell_area
  Log-transformed for scaling: avg_nucleus_area
  Log-transformed for scaling: max_nucleus_area
  Log-transformed for scaling: cell_perimeter
  Features standardized for ML algorithms.

Using existing UMAP coordinates from input CSV for visualizations.

--- Computing Diffusion Map ---
  Computing

  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters if n_actual_clusters > 0 else 1)


  DBSCAN on Scaled_Features plotted on UMAP.

--- Running Gaussian Mixture Model (GMM) on Scaled_Features ---
  Testing GMM with n_components in [2, 3, 4] using BIC...


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 2 components: BIC = -11876.19


  # that has no feature names.
  # that has no feature names.


    GMM with 3 components: BIC = -20370.16


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 4 components: BIC = -26836.71
  Best GMM found with 4 components (BIC=-26836.71).
  Mean senescence_score_normalized per GMM component (for Scaled_Features):
gmm_scaled_features
1    0.316948
0    0.365966
3    0.500682
2    0.500958
Name: senescence_score_normalized, dtype: float64
  GMM on Scaled_Features results plotted on UMAP.
  GMM max probability on Scaled_Features plotted on UMAP.

NOTE: For DBSCAN on Top DCs, using DBSCAN_EPS = 0.01, MIN_SAMPLES = 10

--- Running DBSCAN on Top_3_DCs ---
  Running DBSCAN with eps=0.01, min_samples=10 on Top_3_DCs...
  DBSCAN on Top_3_DCs: 2 clusters, 42 noise (1.70%).


  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters if n_actual_clusters > 0 else 1)


  DBSCAN on Top_3_DCs plotted on UMAP.

--- Running Gaussian Mixture Model (GMM) on Top_3_DCs ---
  Testing GMM with n_components in [2, 3, 4] using BIC...


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 2 components: BIC = -42744.54


  # that has no feature names.
  # that has no feature names.


    GMM with 3 components: BIC = -43790.09


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM with 4 components: BIC = -44590.14
  Best GMM found with 4 components (BIC=-44590.14).
  Mean senescence_score_normalized per GMM component (for Top_3_DCs):
gmm_top_3_dcs
0    0.319583
1    0.398703
2    0.427525
3    0.543944
Name: senescence_score_normalized, dtype: float64
  GMM on Top_3_DCs results plotted on UMAP.
  GMM max probability on Top_3_DCs plotted on UMAP.

--- Applying Rule-Based Gating ---
  Applying rule: Polynucleated
    225 cells labeled as 'Rule_Sen_Poly'.
  Applying rule: Very_Large_Cell
    307 cells labeled as 'Rule_Sen_VeryLarge'.
  Applying rule: Low_Circularity
    3 cells labeled as 'Rule_Sen_LowCirc'.
  Applying rule: Low_NucToCellRatio
    103 cells labeled as 'Rule_Sen_LowNucRatio'.
  Applying rule: High_Score_Not_Otherwise_Caught
    0 cells labeled as 'Rule_Sen_HighScore'.

Granular rule-based classification counts:
rule_based_classification_granular
Rule_NonSenescent       1834
Rule_Sen_VeryLarge       307
Rule_Sen_Poly            225
Rule_Sen_

In [25]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from skimage import io, measure, segmentation # Added segmentation
import cv2 # Added cv2
from scipy import ndimage # Added ndimage
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
import umap

try:
    import scanpy as sc
    SCANPY_AVAILABLE = True
except ImportError:
    print("Scanpy library not found. Diffusion map functionality will be skipped.")
    SCANPY_AVAILABLE = False

# --- Configuration & Parameters ---
INPUT_REFINED_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv"
EXPLORATORY_OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz" # Incremented version

# !! UPDATE THESE PATHS to your original mask image directories !!
CELL_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
NUCLEI_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
MASK_VISUALIZATION_SUBDIR_RULES = "mask_overlays_rule_based"


FEATURES_FOR_ANALYSIS = [
    'cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity',
    'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area',
    'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement',
    'nucleus_to_cell_area_ratio',
    'nuclear_enlargement', 'cell_enlargement'
]
AREA_FEATURES_TO_LOG = ['cell_area', 'avg_nucleus_area', 'max_nucleus_area', 'cell_perimeter']

N_DIFFUSION_COMPONENTS = 10
N_DCS_TO_PLOT = 3
N_NEIGHBORS_FOR_SCANPY = 15

DBSCAN_EPS_DEFAULT = 0.75
DBSCAN_MIN_SAMPLES_DEFAULT = 10
ESTIMATE_DBSCAN_EPS = True
K_FOR_EPS_ESTIMATION = 10
RUN_DBSCAN_ON_DIFFMAP = True
N_DCS_FOR_DBSCAN = 3

GMM_N_COMPONENTS_RANGE = range(2, 5)
GMM_COVARIANCE_TYPE = 'full'

RULE_BASED_GATES = [
    {   'name': 'Polynucleated',
        'conditions': [('nuclei_count', '>', 1)],
        'output_label': 'Rule_Sen_Poly' },
    {   'name': 'Very_Large_Cell',
        'conditions': [('cell_area', '>', 5000)],
        'output_label': 'Rule_Sen_VeryLarge' },
    {   'name': 'Low_Circularity',
        'conditions': [('cell_circularity', '<', 0.2)],
        'output_label': 'Rule_Sen_LowCirc' },
    {   'name': 'Low_NucToCellRatio',
        'conditions': [('nucleus_to_cell_area_ratio', '<', 0.1)],
        'output_label': 'Rule_Sen_LowNucRatio' },
    {   'name': 'High_Score_Not_Otherwise_Caught',
        'conditions': [('senescence_score_normalized', '>', 0.75)],
        'output_label': 'Rule_Sen_HighScore' }
]
RULE_BASED_DEFAULT_LABEL = 'Rule_NonSenescent'

# --- Helper Functions ---
def extract_sample_id(filename):
    """Extracts sample ID from filename (adapted from user's original notebook)."""
    base_name = os.path.splitext(filename)[0]
    if base_name.startswith('denoised_'):
        base_name = base_name[len('denoised_'):]
    pattern = re.compile(r'([\d\.]+Pa_[^_]+_[^_]+_[^_]+_[^_]+_[^_]+_seq\d+)')
    match = pattern.search(base_name)
    if match:
        return match.group(1)
    parts = base_name.split('_')
    for i, part in enumerate(parts):
        if part.startswith('seq') and i >= 2:
            return '_'.join(parts[:i+1])
    common_prefix = "_".join(filename.split('_')[:6])
    return common_prefix if 'seq' in common_prefix else os.path.splitext(os.path.basename(filename))[0]

def load_image_as_labeled_mask(filepath):
    """Loads a mask image, ensuring it's a labeled integer mask."""
    print(f"    Loading mask: {os.path.basename(filepath)}")
    try:
        img = io.imread(filepath)
        if img.ndim > 2: # Handle multi-channel
            if img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[-1] == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else: img = img[..., 0]

        if img.dtype.kind in 'iu' and np.max(img) > 1: # Already labeled integer mask
            return img.astype(np.uint16)

        # If binary or float, threshold and label
        if img.dtype.kind == 'f': img = (img > 0.5).astype(np.uint8)
        elif np.max(img) == 1: img = img.astype(np.uint8)

        if np.max(img) <=1 : # Binary after potential conversion
            labeled_img, num_features = ndimage.label(img)
            print(f"    Labeled binary mask {os.path.basename(filepath)}, found {num_features} features.")
            return labeled_img.astype(np.uint16)

        return img.astype(np.uint16) # Assume uint8 with labels otherwise
    except Exception as e:
        print(f"    Error loading image {filepath}: {str(e)}"); return None

def load_data(csv_path):
    """Loads the refined data and derives sample_id if needed."""
    print(f"Loading refined data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
        print(f"Successfully loaded {len(df)} cells.")

        if 'cell_id' in df.columns:
            df['derived_sample_id'] = df['cell_id'].apply(lambda x: '_'.join(x.split('_')[:-1]))
            print("  Derived 'derived_sample_id' from 'cell_id' for mask matching.")
        else:
            print("Error: 'cell_id' column missing. Cannot derive sample IDs for mask matching.")
            return None

        essential_cols = ['senescence_score_normalized', 'nuclei_count', 'cell_area',
                          'cell_circularity', 'nucleus_to_cell_area_ratio']
        for col_check in essential_cols:
            if col_check not in df.columns:
                print(f"Warning: Essential column '{col_check}' for rules/scoring not found.")
        return df
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}"); return None

def preprocess_features_for_ml(df, feature_columns, log_transform_cols):
    """Prepares features specifically for ML algorithms (scaling)."""
    print(f"\nPreprocessing features for ML. Selected: {feature_columns}")
    actual_features_for_ml = [col for col in feature_columns if col in df.columns]
    if not actual_features_for_ml: return None, None
    features_for_scaling_df = df[actual_features_for_ml].copy()
    for col in log_transform_cols:
        if col in features_for_scaling_df.columns:
            features_for_scaling_df[col] = np.log1p(features_for_scaling_df[col])
            print(f"  Log-transformed for scaling: {col}")
    if features_for_scaling_df.isnull().sum().any():
        features_for_scaling_df = features_for_scaling_df.fillna(features_for_scaling_df.mean())
    cols_to_drop_scaled = features_for_scaling_df.columns[features_for_scaling_df.isna().all()].tolist()
    if cols_to_drop_scaled:
        features_for_scaling_df = features_for_scaling_df.drop(columns=cols_to_drop_scaled)
        actual_features_for_ml = [f for f in actual_features_for_ml if f not in cols_to_drop_scaled]
    if features_for_scaling_df.empty or not actual_features_for_ml: return None, None
    scaler = StandardScaler(); features_scaled = scaler.fit_transform(features_for_scaling_df)
    print("  Features standardized for ML algorithms.")
    return features_scaled, actual_features_for_ml

def compute_and_plot_diffusion_map(df, scaled_features, feature_names_used, output_dir):
    """Computes and plots diffusion map."""
    if not SCANPY_AVAILABLE: print("Skipping diffusion map: Scanpy not available."); return df
    print("\n--- Computing Diffusion Map ---")
    if scaled_features is None or scaled_features.shape[0] == 0 : print("  No scaled features for Diffusion Map. Skipping."); return df
    adata = sc.AnnData(scaled_features, var=pd.DataFrame(index=feature_names_used))
    adata.obs_names = df.index.astype(str)
    if 'senescence_score_normalized' in df.columns: adata.obs['senescence_score_normalized'] = df['senescence_score_normalized'].values
    if 'cell_type_final' in df.columns: adata.obs['cell_type_final'] = df['cell_type_final'].astype('category').values
    actual_n_neighbors = min(N_NEIGHBORS_FOR_SCANPY, adata.n_obs - 1)
    if actual_n_neighbors < 2: print(f"  Not enough samples for Scanpy neighbors. Skipping."); return df
    print(f"  Computing neighbors (k={actual_n_neighbors})..."); sc.pp.neighbors(adata, n_neighbors=actual_n_neighbors, use_rep='X')
    print("  Running sc.tl.diffmap..."); sc.tl.diffmap(adata, n_comps=N_DIFFUSION_COMPONENTS)
    if 'X_diffmap' in adata.obsm:
        num_dc = min(N_DIFFUSION_COMPONENTS, adata.obsm['X_diffmap'].shape[1] - 1)
        for i in range(num_dc): df[f'dc_{i+1}'] = adata.obsm['X_diffmap'][:, i+1]
        print(f"  Added {num_dc} DCs to DataFrame.")
        pairs = [(f'dc_{i}', f'dc_{j}') for i in range(1, N_DCS_TO_PLOT + 1) for j in range(i + 1, N_DCS_TO_PLOT + 1) if f'dc_{i}' in df.columns and f'dc_{j}' in df.columns]
        for dcx, dcy in pairs:
            if df[dcx].notna().any() and df[dcy].notna().any():
                if 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
                    plt.figure(figsize=(10,8)); plt.scatter(df[dcx], df[dcy], c=df['senescence_score_normalized'], cmap='viridis', s=12, alpha=0.7); plt.colorbar(label='Norm. Senescence Score')
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Score'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper()); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_score.png'),dpi=300,bbox_inches='tight'); plt.close()
                if 'cell_type_final' in df.columns:
                    plt.figure(figsize=(10,8)); types=df['cell_type_final'].unique(); pal={t:('red' if t=='Senescent' else ('blue' if t=='Non-senescent' else 'grey')) for t in types}
                    for ct,col in pal.items(): subset=df[df['cell_type_final']==ct]; plt.scatter(subset[dcx],subset[dcy],label=ct,color=col,s=12,alpha=0.7)
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Prev. Classif.'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper());
                    if pal: plt.legend(title='Previous Final Cell Type'); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_prev_type.png'),dpi=300,bbox_inches='tight'); plt.close()
        print(f"  DiffMap pair plots for top {N_DCS_TO_PLOT} DCs saved.")
    else: print("  Error: 'X_diffmap' not found.")
    return df

def run_dbscan_and_plot(df, data_for_dbscan, data_desc, output_dir, umap_emb=None, current_eps_val=None, current_min_samples_val=None):
    """Runs DBSCAN and plots results. Uses specific eps and min_samples if provided."""
    print(f"\n--- Running DBSCAN on {data_desc} ---")
    if data_for_dbscan is None or data_for_dbscan.shape[0] == 0:
        print(f"  No data for DBSCAN on {data_desc}. Skipping.")
        df[f'dbscan_{data_desc.lower().replace(" ","_")}']=-1
        return df

    eps_to_use = current_eps_val if current_eps_val is not None else DBSCAN_EPS_DEFAULT
    min_s_to_use = current_min_samples_val if current_min_samples_val is not None else DBSCAN_MIN_SAMPLES_DEFAULT

    if ESTIMATE_DBSCAN_EPS and current_eps_val is None :
        k_est = min(K_FOR_EPS_ESTIMATION, data_for_dbscan.shape[0]-1); k_est=max(1,k_est)
        nn=NearestNeighbors(n_neighbors=k_est); nn.fit(data_for_dbscan); dists, _ = nn.kneighbors(data_for_dbscan)
        actual_k_for_dists = min(k_est, dists.shape[1])
        if actual_k_for_dists > 0:
            k_dists = dists[:,actual_k_for_dists-1]
            k_dists_sorted = np.sort(k_dists)
            plt.figure(figsize=(8,6)); plt.plot(k_dists_sorted); plt.title(f'{actual_k_for_dists}-Dist Graph for Eps ({data_desc})');
            plt.xlabel("Points sorted by distance"); plt.ylabel(f"{actual_k_for_dists}-th NN Distance (eps candidate)"); plt.grid(True,alpha=0.3);
            eps_path=os.path.join(output_dir, f'dbscan_eps_est_{data_desc.lower().replace(" ","_")}.png'); plt.savefig(eps_path,dpi=300); plt.close();
            print(f"  Saved k-dist graph: {eps_path}. PLEASE INSPECT THIS PLOT TO SET appropriate DBSCAN_EPS for {data_desc}.")
            if len(k_dists_sorted)>10:
                sug_eps=np.percentile(k_dists_sorted,90);
                print(f"  A percentile-based suggestion for eps for {data_desc} is: {sug_eps:.3f}. The script will use eps={eps_to_use} (default or passed).")
        else:
            print(f"  Could not determine k-distances for eps estimation for {data_desc}. Using eps={eps_to_use}")

    print(f"  Running DBSCAN with eps={eps_to_use}, min_samples={min_s_to_use} on {data_desc}...")
    db=DBSCAN(eps=eps_to_use,min_samples=min_s_to_use).fit(data_for_dbscan)
    clust_col=f'dbscan_{data_desc.lower().replace(" ","_")}'; df[clust_col]=db.labels_
    n_clust=len(set(db.labels_))-(1 if -1 in db.labels_ else 0); n_noise=list(db.labels_).count(-1)
    print(f"  DBSCAN on {data_desc}: {n_clust} clusters, {n_noise} noise ({n_noise/len(df)*100:.2f}%).")

    if umap_emb is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12,10))
        labels_unique_dbscan=sorted(df[clust_col].unique())
        n_actual_clusters = len([l for l in labels_unique_dbscan if l != -1])
        # Ensure cmap is correctly called
        if n_actual_clusters > 0:
            dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters)
        else: # Handle case with no actual clusters (only noise)
            dbscan_cmap_obj = plt.cm.get_cmap('Spectral', 1)

        cdict = {}
        cluster_idx = 0
        for lbl in labels_unique_dbscan:
            if lbl == -1:
                cdict[lbl] = (0.5, 0.5, 0.5, 1) # Grey for noise
            else:
                cdict[lbl] = dbscan_cmap_obj(cluster_idx)
                cluster_idx += 1

        for k_val in labels_unique_dbscan:
            mask=(df[clust_col]==k_val)
            xy=umap_emb[mask]
            if xy.shape[0]>0:
                 # Corrected indentation for the plt.scatter call
                 plt.scatter(xy[:,0],xy[:,1],
                             s=(20 if k_val!=-1 else 10),
                             c=[cdict[k_val]],
                             marker=('o' if k_val!=-1 else 'x'),
                             label=('Noise' if k_val == -1 else f'Cluster {k_val}'))
        plt.title(f'DBSCAN on {data_desc} (UMAP proj.)\neps={eps_to_use:.3f}, min_s={min_s_to_use}')
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2')
        plt.legend(title='DBSCAN Cluster',bbox_to_anchor=(1.05,1),loc='upper left',markerscale=1.5)
        plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1])
        plt.savefig(os.path.join(output_dir, f'dbscan_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close()
        print(f"  DBSCAN on {data_desc} plotted on UMAP.")
    return df

def run_gmm_and_plot(df, data_for_gmm, data_desc, output_dir, umap_embedding=None):
    """Runs GMM and plots results."""
    print(f"\n--- Running GMM on {data_desc} ---")
    if data_for_gmm is None: df[f'gmm_cluster_{data_desc.lower().replace(" ","_")}']=-1; df[f'gmm_prob_max_{data_desc.lower().replace(" ","_")}']=np.nan; return df
    best_gmm=None; lowest_bic=np.inf
    for n_comp in GMM_N_COMPONENTS_RANGE:
        if n_comp > data_for_gmm.shape[0]: continue
        gmm=GaussianMixture(n_components=n_comp,covariance_type=GMM_COVARIANCE_TYPE,random_state=42,n_init=5).fit(data_for_gmm); bic=gmm.bic(data_for_gmm); print(f"    GMM {n_comp} comps: BIC={bic:.2f}")
        if bic<lowest_bic: lowest_bic=bic; best_gmm=gmm
    if best_gmm is None: df[f'gmm_cluster_{data_desc.lower().replace(" ","_")}']=-1; df[f'gmm_prob_max_{data_desc.lower().replace(" ","_")}']=np.nan; return df
    print(f"  Best GMM: {best_gmm.n_components} components (BIC={lowest_bic:.2f}).")
    clust_col,prob_col = f'gmm_{data_desc.lower().replace(" ","_")}',f'gmm_prob_max_{data_desc.lower().replace(" ","_")}'
    df[clust_col]=best_gmm.predict(data_for_gmm); df[prob_col]=np.max(best_gmm.predict_proba(data_for_gmm),axis=1)
    if 'senescence_score_normalized' in df.columns: print(f"  Mean sen_score per GMM comp ({data_desc}):\n{df.groupby(clust_col)['senescence_score_normalized'].mean().sort_values()}")
    if umap_embedding is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12,10)); labels=sorted(df[clust_col].unique()); pal=sns.color_palette("viridis",n_colors=len(labels))
        for i,lbl in enumerate(labels): subset=df[df[clust_col]==lbl]; plt.scatter(subset['umap_x_refined'],subset['umap_y_refined'],label=f'GMM Comp. {lbl}',color=pal[i],s=15,alpha=0.7)
        plt.title(f'GMM ({best_gmm.n_components} comp.) on {data_desc} (UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='GMM Comp.',bbox_to_anchor=(1.05,1),loc='upper left'); plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1]); plt.savefig(os.path.join(output_dir,f'gmm_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close()
        plt.figure(figsize=(12,10)); scatter_prob=plt.scatter(df['umap_x_refined'],df['umap_y_refined'],c=df[prob_col],cmap='magma',s=15,alpha=0.7,vmin=0,vmax=1); plt.colorbar(scatter_prob,label='Max GMM Prob.'); plt.title(f'GMM Max Prob. on {data_desc} (UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.grid(True,alpha=0.3); plt.tight_layout(); plt.savefig(os.path.join(output_dir,f'gmm_prob_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close()
    return df

def apply_rule_based_gating(df_main, rules, default_label, output_dir, umap_embedding=None):
    """Applies rules to classify cells, creates granular and binary status, and plots."""
    print("\n--- Applying Rule-Based Gating ---")
    df_main['rule_based_classification_granular'] = default_label
    all_rule_features = set(cond[0] for rule in rules for cond in rule['conditions'])
    missing_features = [feat for feat in all_rule_features if feat not in df_main.columns]
    if missing_features:
        print(f"  Error: Features for rules missing from DataFrame: {missing_features}. Skipping."); return df_main

    for rule in rules:
        print(f"  Applying rule: {rule['name']}")
        eligible_mask = (df_main['rule_based_classification_granular'] == default_label)
        if not eligible_mask.any(): print(f"    No cells eligible for rule '{rule['name']}'."); continue
        current_condition_mask = pd.Series([True] * len(df_main), index=df_main.index)
        for feature, operator, value in rule['conditions']:
            if feature not in df_main.columns: print(f"    Feature '{feature}' not found. Skipping rule."); current_condition_mask[:]=False; break
            try:
                feat_series = pd.to_numeric(df_main[feature], errors='coerce')
                if feat_series.isnull().any() and not df_main[feature].isnull().all() : print(f"    Warning: Coercion to numeric for '{feature}' created NaNs.")
                if   operator == '>':  current_condition_mask &= (feat_series > value)
                elif operator == '<':  current_condition_mask &= (feat_series < value)
                elif operator == '>=': current_condition_mask &= (feat_series >= value)
                elif operator == '<=': current_condition_mask &= (feat_series <= value)
                elif operator == '==': current_condition_mask &= (feat_series == value)
                elif operator == '!=': current_condition_mask &= (feat_series != value)
                else: print(f"    Unknown operator '{operator}'."); current_condition_mask[:]=False; break
            except Exception as e: print(f"    Error comparing '{feature}': {e}."); current_condition_mask[:]=False; break

        # Check if current_condition_mask is valid before proceeding
        if isinstance(current_condition_mask, pd.Series) and not current_condition_mask.empty:
            if not current_condition_mask.any():
                print(f"    No cells met conditions for rule '{rule['name']}'.")
                continue
            cells_to_label = eligible_mask & current_condition_mask
            df_main.loc[cells_to_label, 'rule_based_classification_granular'] = rule['output_label']
            print(f"    {cells_to_label.sum()} cells labeled as '{rule['output_label']}'.")
        else: # Mask became invalid (e.g. all False due to error in condition)
            print(f"    Rule '{rule['name']}' resulted in an invalid condition mask or no cells met conditions. No cells labeled by this rule.")


    print(f"\nGranular rule-based counts:\n{df_main['rule_based_classification_granular'].value_counts()}")
    df_main['rule_based_binary_status'] = np.where(df_main['rule_based_classification_granular'] == default_label, 'Non-senescent', 'Senescent')
    print(f"\nBinary rule-based counts:\n{df_main['rule_based_binary_status'].value_counts()}")

    if umap_embedding is not None and 'umap_x_refined' in df_main.columns and 'umap_y_refined' in df_main.columns:
        plt.figure(figsize=(12,10)); granular_labels = sorted(df_main['rule_based_classification_granular'].unique())
        if granular_labels:
            pal_gran = sns.color_palette("Paired",n_colors=max(10,len(granular_labels)))
            for i,lbl in enumerate(granular_labels): subset=df_main[df_main['rule_based_classification_granular']==lbl]; plt.scatter(subset['umap_x_refined'],subset['umap_y_refined'],label=lbl,color=pal_gran[i%len(pal_gran)],s=15,alpha=0.7)
            plt.title('Rule-Based Gating (Granular - UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule Class (Granular)',bbox_to_anchor=(1.05,1),loc='upper left'); plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.80,1]); plt.savefig(os.path.join(output_dir,'rule_based_gating_granular_on_umap.png'),dpi=300); plt.close()
        plt.figure(figsize=(12,10)); binary_labels = sorted(df_main['rule_based_binary_status'].unique()); bin_pal={'Senescent':'red','Non-senescent':'blue','Unknown_Due_To_Missing_Features':'grey'}
        for lbl in binary_labels: subset=df_main[df_main['rule_based_binary_status']==lbl]; plt.scatter(subset['umap_x_refined'],subset['umap_y_refined'],label=lbl,color=bin_pal.get(lbl,'grey'),s=15,alpha=0.7)
        plt.title('Rule-Based Gating (Binary - UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule Status (Binary)',bbox_to_anchor=(1.05,1),loc='upper left'); plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1]); plt.savefig(os.path.join(output_dir,'rule_based_gating_binary_on_umap.png'),dpi=300); plt.close()
    return df_main

def visualize_rule_classification_on_masks(df_results, cell_mask_dir, nuclei_mask_dir, output_dir_masks, classification_column='rule_based_binary_status'):
    """Visualizes a specified classification column (e.g., rule-based) on original mask images."""
    print(f"\nGenerating classification overlays on original masks using column '{classification_column}' in: {output_dir_masks}")
    if not os.path.exists(output_dir_masks): os.makedirs(output_dir_masks)
    sen_color, nonsen_color, nuc_color, unknown_color = [255,0,0], [0,0,255], [0,255,0], [128,128,128]
    if 'derived_sample_id' not in df_results.columns or classification_column not in df_results.columns:
        print(f"Error: 'derived_sample_id' or '{classification_column}' not found. Cannot proceed."); return
    unique_samples = df_results['derived_sample_id'].unique()
    classification_lookup = pd.Series(df_results[classification_column].values, index=df_results.cell_id).to_dict()
    available_cell_masks = {extract_sample_id(f): f for f in os.listdir(cell_mask_dir) if f.endswith(('.tif', '.tiff'))}
    available_nuclei_masks = {extract_sample_id(f): f for f in os.listdir(nuclei_mask_dir) if f.endswith(('.tif', '.tiff'))}

    for sample_id_csv in tqdm(unique_samples, desc=f"Mask viz ({classification_column})"):
        cell_mask_filename = available_cell_masks.get(sample_id_csv)
        nuclei_mask_filename = available_nuclei_masks.get(sample_id_csv)
        if not cell_mask_filename: print(f"  Cell mask not found for {sample_id_csv}"); continue
        print(f"\n  Overlaying sample: {sample_id_csv}")
        cell_mask = load_image_as_labeled_mask(os.path.join(cell_mask_dir, cell_mask_filename))
        if cell_mask is None: continue
        overlay = np.zeros((cell_mask.shape[0], cell_mask.shape[1], 3), dtype=np.uint8)
        for props in measure.regionprops(cell_mask):
            full_cell_id = f"{sample_id_csv}_{props.label}"
            status = classification_lookup.get(full_cell_id, 'Unknown')
            color_to_use = unknown_color
            if status == 'Senescent': color_to_use = sen_color
            elif status == 'Non-senescent': color_to_use = nonsen_color
            elif status != 'Unknown' and classification_column == 'rule_based_classification_granular': # For granular, use a fallback if not Sen/NonSen
                 color_to_use = [np.random.randint(50,200) for _ in range(3)] # Randomish color for other granular rules
            overlay[cell_mask == props.label] = color_to_use
        if nuclei_mask_filename:
            nuc_mask = load_image_as_labeled_mask(os.path.join(nuclei_mask_dir, nuclei_mask_filename))
            if nuc_mask is not None:
                nuc_boundaries = segmentation.find_boundaries(nuc_mask, mode='inner', background=0)
                overlay[nuc_boundaries] = nuc_color
        fig_leg, ax_leg = plt.subplots(figsize=(max(10, overlay.shape[1]/100), max(8, overlay.shape[0]/100)), dpi=100) # Ensure min size
        ax_leg.imshow(overlay)
        handles = [mpatches.Patch(color=np.array(sen_color)/255., label='Senescent'),
                   mpatches.Patch(color=np.array(nonsen_color)/255., label='Non-senescent')]
        if 'Unknown' in df_results[classification_column].unique(): handles.append(mpatches.Patch(color=np.array(unknown_color)/255., label='Unknown'))
        if nuclei_mask_filename and nuc_mask is not None: handles.append(mpatches.Patch(color=np.array(nuc_color)/255., label='Nuclei Outline'))
        ax_leg.legend(handles=handles, loc='upper right', fontsize='small', bbox_to_anchor=(1.45, 1)); ax_leg.axis('off'); plt.tight_layout()
        plt.savefig(os.path.join(output_dir_masks, f"{sample_id_csv}_{classification_column}_overlay.png"), dpi=150); plt.close(fig_leg)
    print(f"Mask overlay visualization for '{classification_column}' complete.")


def main_exploratory_analysis():
    """Main function to run exploratory analysis."""
    if not os.path.exists(EXPLORATORY_OUTPUT_DIR):
        os.makedirs(EXPLORATORY_OUTPUT_DIR)
        print(f"Created output directory: {EXPLORATORY_OUTPUT_DIR}")

    df = load_data(INPUT_REFINED_CSV_PATH)
    if df is None: return

    scaled_features, feature_names_used_for_scaling = preprocess_features_for_ml(df, FEATURES_FOR_ANALYSIS, AREA_FEATURES_TO_LOG)

    umap_embedding_for_plotting = None
    if 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns and df['umap_x_refined'].notna().all():
        print("\nUsing existing UMAP coordinates from input CSV for visualizations.")
        umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
    elif scaled_features is not None:
        print("\nRecomputing UMAP for visualization...")
        try:
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42, n_components=2)
            embedding = reducer.fit_transform(scaled_features)
            df['umap_x_refined'] = embedding[:, 0]; df['umap_y_refined'] = embedding[:, 1]
            umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
            print("  UMAP recomputed.")
        except Exception as e: print(f"  Error recomputing UMAP: {e}.")
    else: print("\nSkipping UMAP computation as scaled features are unavailable.")

    if scaled_features is not None and feature_names_used_for_scaling is not None:
        df = compute_and_plot_diffusion_map(df, scaled_features, feature_names_used_for_scaling, EXPLORATORY_OUTPUT_DIR)

        # --- DBSCAN on Scaled Features ---
        # User should set this based on k-distance plot from previous run or set ESTIMATE_DBSCAN_EPS = True
        DBSCAN_EPS_FOR_SCALED_FEATURES = 2.3 # Example from your previous output
        DBSCAN_MIN_SAMPLES_FOR_SCALED_FEATURES = 10
        print(f"\nNOTE: For DBSCAN on Scaled Features, using EPS = {DBSCAN_EPS_FOR_SCALED_FEATURES}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_FOR_SCALED_FEATURES}")
        df = run_dbscan_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR,
                                 umap_emb=umap_embedding_for_plotting,
                                 current_eps_val=DBSCAN_EPS_FOR_SCALED_FEATURES,
                                 current_min_samples_val=DBSCAN_MIN_SAMPLES_FOR_SCALED_FEATURES)

        df = run_gmm_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

        if RUN_DBSCAN_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols = [f'dc_{i+1}' for i in range(N_DCS_FOR_DBSCAN) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols:
                data_dc = df[dc_cols].values
                # User should set this based on k-distance plot for DCs
                DBSCAN_EPS_FOR_DCS = 0.01 # Example from your previous output
                DBSCAN_MIN_SAMPLES_FOR_DCS = 10
                print(f"\nNOTE: For DBSCAN on Top DCs, using EPS = {DBSCAN_EPS_FOR_DCS}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_FOR_DCS}")
                df = run_dbscan_and_plot(df, data_dc, f"Top_{len(dc_cols)}_DCs", EXPLORATORY_OUTPUT_DIR,
                                         umap_emb=umap_embedding_for_plotting,
                                         current_eps_val=DBSCAN_EPS_FOR_DCS,
                                         current_min_samples_val=DBSCAN_MIN_SAMPLES_FOR_DCS)
            else: print(f"\nSkipping DBSCAN on DCs: Not enough valid DC columns.")

        global RUN_GMM_ON_DIFFMAP, N_DCS_FOR_GMM
        if RUN_GMM_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols_gmm = [f'dc_{i+1}' for i in range(N_DCS_FOR_GMM) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols_gmm:
                data_dc_gmm = df[dc_cols_gmm].values
                df = run_gmm_and_plot(df, data_dc_gmm, f"Top_{len(dc_cols_gmm)}_DCs", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)
            else: print(f"\nSkipping GMM on DCs: Not enough valid DC columns.")

    # Apply rule-based gating using the main df.
    df = apply_rule_based_gating(df, RULE_BASED_GATES, RULE_BASED_DEFAULT_LABEL, EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

    # Visualize the rule-based classification on masks
    mask_overlay_output_path = os.path.join(EXPLORATORY_OUTPUT_DIR, MASK_VISUALIZATION_SUBDIR_RULES)
    if 'rule_based_binary_status' in df.columns: # Check if rule-based classification ran successfully
        visualize_rule_classification_on_masks(df, CELL_MASK_DIR, NUCLEI_MASK_DIR, mask_overlay_output_path, classification_column='rule_based_binary_status')
        # Optional: visualize granular rule classification on masks
        # visualize_rule_classification_on_masks(df, CELL_MASK_DIR, NUCLEI_MASK_DIR, mask_overlay_output_path, classification_column='rule_based_classification_granular')
    else:
        print("Skipping mask visualization for rule-based classification as the classification column is missing.")


    exploratory_csv_path = os.path.join(EXPLORATORY_OUTPUT_DIR, 'exploratory_analysis_results_v5_rules_maskviz.csv')
    df.to_csv(exploratory_csv_path, index=False)
    print(f"\nExploratory analysis results saved to: {exploratory_csv_path}")
    print("\nExploratory analysis script finished.")

RUN_GMM_ON_DIFFMAP = True
N_DCS_FOR_GMM = 3

if __name__ == '__main__':
    main_exploratory_analysis()


Created output directory: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz
Loading refined data from /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv...
Successfully loaded 2472 cells.
  Derived 'derived_sample_id' from 'cell_id' for mask matching.

Preprocessing features for ML. Selected: ['cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity', 'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area', 'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement', 'nucleus_to_cell_area_ratio', 'nuclear_enlargement', 'cell_enlargement']
  Log-transformed for scaling: cell_area
  Log-transformed for scaling: avg_nucleus_area
  Log-transformed for scaling: max_nucleus_area
  Log-transformed for scaling: cell_perimeter
  Features standardized for ML algorithms.

Using existing UMAP coordinates from i

  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters)


  DBSCAN on Scaled_Features plotted on UMAP.

--- Running GMM on Scaled_Features ---


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 2 comps: BIC=-11876.19


  # that has no feature names.
  # that has no feature names.


    GMM 3 comps: BIC=-20370.16


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 4 comps: BIC=-26836.71
  Best GMM: 4 components (BIC=-26836.71).
  Mean sen_score per GMM comp (Scaled_Features):
gmm_scaled_features
1    0.316948
0    0.365966
3    0.500682
2    0.500958
Name: senescence_score_normalized, dtype: float64

NOTE: For DBSCAN on Top DCs, using EPS = 0.01, MIN_SAMPLES = 10

--- Running DBSCAN on Top_3_DCs ---
  Running DBSCAN with eps=0.01, min_samples=10 on Top_3_DCs...
  DBSCAN on Top_3_DCs: 2 clusters, 42 noise (1.70%).


  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters)


  DBSCAN on Top_3_DCs plotted on UMAP.

--- Running GMM on Top_3_DCs ---


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 2 comps: BIC=-42744.54
    GMM 3 comps: BIC=-43790.09


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 4 comps: BIC=-44590.14
  Best GMM: 4 components (BIC=-44590.14).
  Mean sen_score per GMM comp (Top_3_DCs):
gmm_top_3_dcs
0    0.319583
1    0.398703
2    0.427525
3    0.543944
Name: senescence_score_normalized, dtype: float64

--- Applying Rule-Based Gating ---
  Applying rule: Polynucleated
    225 cells labeled as 'Rule_Sen_Poly'.
  Applying rule: Very_Large_Cell
    307 cells labeled as 'Rule_Sen_VeryLarge'.
  Applying rule: Low_Circularity
    3 cells labeled as 'Rule_Sen_LowCirc'.
  Applying rule: Low_NucToCellRatio
    103 cells labeled as 'Rule_Sen_LowNucRatio'.
  Applying rule: High_Score_Not_Otherwise_Caught
    0 cells labeled as 'Rule_Sen_HighScore'.

Granular rule-based counts:
rule_based_classification_granular
Rule_NonSenescent       1834
Rule_Sen_VeryLarge       307
Rule_Sen_Poly            225
Rule_Sen_LowNucRatio     103
Rule_Sen_LowCirc           3
Name: count, dtype: int64

Binary rule-based counts:
rule_based_binary_status
Non-senescent    1834
Senescent  

Mask viz (rule_based_binary_status):   0%|          | 0/8 [00:00<?, ?it/s]


  Overlaying sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq001
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq001_cell_mask_merged_conservative.tif
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  12%|█▎        | 1/8 [00:01<00:10,  1.56s/it]


  Overlaying sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq002
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq002_cell_mask_merged_conservative.tif
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq002_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  25%|██▌       | 2/8 [00:03<00:09,  1.61s/it]


  Overlaying sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq003
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq003_cell_mask_merged_conservative.tif
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq003_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  38%|███▊      | 3/8 [00:04<00:07,  1.59s/it]


  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq001_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  50%|█████     | 4/8 [00:06<00:05,  1.48s/it]


  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq002_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  62%|██████▎   | 5/8 [00:07<00:03,  1.33s/it]


  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq003_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  75%|███████▌  | 6/8 [00:08<00:02,  1.32s/it]


  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq004_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  88%|████████▊ | 7/8 [00:10<00:01,  1.48s/it]


  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq005_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status): 100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Mask overlay visualization for 'rule_based_binary_status' complete.

Exploratory analysis results saved to: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/exploratory_analysis_results_v5_rules_maskviz.csv

Exploratory analysis script finished.


In [27]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from skimage import io, measure, segmentation # Added segmentation
import cv2 # Added cv2
from scipy import ndimage # Added ndimage
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
import umap

try:
    import scanpy as sc
    SCANPY_AVAILABLE = True
except ImportError:
    print("Scanpy library not found. Diffusion map functionality will be skipped.")
    SCANPY_AVAILABLE = False

# --- Configuration & Parameters ---
INPUT_REFINED_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv"
EXPLORATORY_OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz"

# !! UPDATE THESE PATHS to your original mask image directories !!
CELL_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
NUCLEI_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"
MASK_VISUALIZATION_SUBDIR_RULES = "mask_overlays_rule_based"


FEATURES_FOR_ANALYSIS = [
    'cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity',
    'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area',
    'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement',
    'nucleus_to_cell_area_ratio',
    'nuclear_enlargement', 'cell_enlargement'
]
AREA_FEATURES_TO_LOG = ['cell_area', 'avg_nucleus_area', 'max_nucleus_area', 'cell_perimeter']

N_DIFFUSION_COMPONENTS = 10
N_DCS_TO_PLOT = 3
N_NEIGHBORS_FOR_SCANPY = 15

DBSCAN_EPS_DEFAULT = 0.75
DBSCAN_MIN_SAMPLES_DEFAULT = 10
ESTIMATE_DBSCAN_EPS = True
K_FOR_EPS_ESTIMATION = 10
RUN_DBSCAN_ON_DIFFMAP = True
N_DCS_FOR_DBSCAN = 3

GMM_N_COMPONENTS_RANGE = range(2, 5)
GMM_COVARIANCE_TYPE = 'full'

RULE_BASED_GATES = [
    {   'name': 'Polynucleated',
        'conditions': [('nuclei_count', '>', 1)],
        'output_label': 'Rule_Sen_Poly' },
    {   'name': 'Very_Large_Cell',
        'conditions': [('cell_area', '>', 5000)],
        'output_label': 'Rule_Sen_VeryLarge' },
    {   'name': 'Low_Circularity',
        'conditions': [('cell_circularity', '<', 0.2)],
        'output_label': 'Rule_Sen_LowCirc' },
    {   'name': 'Low_NucToCellRatio',
        'conditions': [('nucleus_to_cell_area_ratio', '<', 0.1)],
        'output_label': 'Rule_Sen_LowNucRatio' },
    {   'name': 'High_Score_Not_Otherwise_Caught',
        'conditions': [('senescence_score_normalized', '>', 0.75)],
        'output_label': 'Rule_Sen_HighScore' }
]
RULE_BASED_DEFAULT_LABEL = 'Rule_NonSenescent'

# --- Helper Functions ---
def extract_sample_id(filename):
    """Extracts sample ID from filename (adapted from user's original notebook)."""
    base_name = os.path.splitext(filename)[0]
    if base_name.startswith('denoised_'):
        base_name = base_name[len('denoised_'):]
    pattern = re.compile(r'([\d\.]+Pa_[^_]+_[^_]+_[^_]+_[^_]+_[^_]+_seq\d+)')
    match = pattern.search(base_name)
    if match:
        return match.group(1)
    parts = base_name.split('_')
    for i, part in enumerate(parts):
        if part.startswith('seq') and i >= 2:
            return '_'.join(parts[:i+1])
    common_prefix = "_".join(filename.split('_')[:6])
    return common_prefix if 'seq' in common_prefix else os.path.splitext(os.path.basename(filename))[0]

def load_image_as_labeled_mask(filepath):
    """Loads a mask image, ensuring it's a labeled integer mask."""
    print(f"    Loading mask: {os.path.basename(filepath)}")
    try:
        img = io.imread(filepath)
        if img.ndim > 2: # Handle multi-channel
            if img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[-1] == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else: img = img[..., 0]

        if img.dtype.kind in 'iu' and np.max(img) > 1: # Already labeled integer mask
            return img.astype(np.uint16)

        if img.dtype.kind == 'f': img = (img > 0.5).astype(np.uint8)
        elif np.max(img) == 1: img = img.astype(np.uint8)

        if np.max(img) <=1 : # Binary after potential conversion
            labeled_img, num_features = ndimage.label(img)
            print(f"    Labeled binary mask {os.path.basename(filepath)}, found {num_features} features.")
            return labeled_img.astype(np.uint16)

        return img.astype(np.uint16)
    except Exception as e:
        print(f"    Error loading image {filepath}: {str(e)}"); return None

def load_data(csv_path):
    """Loads the refined data and derives sample_id if needed."""
    print(f"Loading refined data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
        print(f"Successfully loaded {len(df)} cells.")

        if 'cell_id' in df.columns:
            df['derived_sample_id'] = df['cell_id'].apply(lambda x: '_'.join(x.split('_')[:-1]))
            print("  Derived 'derived_sample_id' from 'cell_id' for mask matching.")
        else:
            print("Error: 'cell_id' column missing. Cannot derive sample IDs for mask matching.")
            return None

        essential_cols = ['senescence_score_normalized', 'nuclei_count', 'cell_area',
                          'cell_circularity', 'nucleus_to_cell_area_ratio']
        for col_check in essential_cols:
            if col_check not in df.columns:
                print(f"Warning: Essential column '{col_check}' for rules/scoring not found.")
        return df
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}"); return None

def preprocess_features_for_ml(df, feature_columns, log_transform_cols):
    """Prepares features specifically for ML algorithms (scaling)."""
    print(f"\nPreprocessing features for ML. Selected: {feature_columns}")
    actual_features_for_ml = [col for col in feature_columns if col in df.columns]
    if not actual_features_for_ml: return None, None
    features_for_scaling_df = df[actual_features_for_ml].copy()
    for col in log_transform_cols:
        if col in features_for_scaling_df.columns:
            features_for_scaling_df[col] = np.log1p(features_for_scaling_df[col])
            print(f"  Log-transformed for scaling: {col}")
    if features_for_scaling_df.isnull().sum().any():
        features_for_scaling_df = features_for_scaling_df.fillna(features_for_scaling_df.mean())
    cols_to_drop_scaled = features_for_scaling_df.columns[features_for_scaling_df.isna().all()].tolist()
    if cols_to_drop_scaled:
        features_for_scaling_df = features_for_scaling_df.drop(columns=cols_to_drop_scaled)
        actual_features_for_ml = [f for f in actual_features_for_ml if f not in cols_to_drop_scaled]
    if features_for_scaling_df.empty or not actual_features_for_ml: return None, None
    scaler = StandardScaler(); features_scaled = scaler.fit_transform(features_for_scaling_df)
    print("  Features standardized for ML algorithms.")
    return features_scaled, actual_features_for_ml

def compute_and_plot_diffusion_map(df, scaled_features, feature_names_used, output_dir):
    """Computes and plots diffusion map."""
    if not SCANPY_AVAILABLE: print("Skipping diffusion map: Scanpy not available."); return df
    print("\n--- Computing Diffusion Map ---")
    if scaled_features is None or scaled_features.shape[0] == 0 : print("  No scaled features for Diffusion Map. Skipping."); return df
    adata = sc.AnnData(scaled_features, var=pd.DataFrame(index=feature_names_used))
    adata.obs_names = df.index.astype(str)
    if 'senescence_score_normalized' in df.columns: adata.obs['senescence_score_normalized'] = df['senescence_score_normalized'].values
    if 'cell_type_final' in df.columns: adata.obs['cell_type_final'] = df['cell_type_final'].astype('category').values
    actual_n_neighbors = min(N_NEIGHBORS_FOR_SCANPY, adata.n_obs - 1)
    if actual_n_neighbors < 2: print(f"  Not enough samples for Scanpy neighbors. Skipping."); return df
    print(f"  Computing neighbors (k={actual_n_neighbors})..."); sc.pp.neighbors(adata, n_neighbors=actual_n_neighbors, use_rep='X')
    print("  Running sc.tl.diffmap..."); sc.tl.diffmap(adata, n_comps=N_DIFFUSION_COMPONENTS)
    if 'X_diffmap' in adata.obsm:
        num_dc = min(N_DIFFUSION_COMPONENTS, adata.obsm['X_diffmap'].shape[1] - 1)
        for i in range(num_dc): df[f'dc_{i+1}'] = adata.obsm['X_diffmap'][:, i+1]
        print(f"  Added {num_dc} DCs to DataFrame.")
        pairs = [(f'dc_{i}', f'dc_{j}') for i in range(1, N_DCS_TO_PLOT + 1) for j in range(i + 1, N_DCS_TO_PLOT + 1) if f'dc_{i}' in df.columns and f'dc_{j}' in df.columns]
        for dcx, dcy in pairs:
            if df[dcx].notna().any() and df[dcy].notna().any():
                if 'senescence_score_normalized' in df.columns and df['senescence_score_normalized'].notna().any():
                    plt.figure(figsize=(10,8)); plt.scatter(df[dcx], df[dcy], c=df['senescence_score_normalized'], cmap='viridis', s=12, alpha=0.7); plt.colorbar(label='Norm. Senescence Score')
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Score'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper()); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_score.png'),dpi=300,bbox_inches='tight'); plt.close()
                if 'cell_type_final' in df.columns:
                    plt.figure(figsize=(10,8)); types=df['cell_type_final'].unique(); pal={t:('red' if t=='Senescent' else ('blue' if t=='Non-senescent' else 'grey')) for t in types}
                    for ct,col in pal.items(): subset=df[df['cell_type_final']==ct]; plt.scatter(subset[dcx],subset[dcy],label=ct,color=col,s=12,alpha=0.7)
                    plt.title(f'DiffMap ({dcx} vs {dcy}) by Prev. Classif.'); plt.xlabel(dcx.upper()); plt.ylabel(dcy.upper());
                    if pal: plt.legend(title='Previous Final Cell Type'); plt.grid(True,alpha=0.3); plt.savefig(os.path.join(output_dir, f'diffmap_{dcx}_{dcy}_by_prev_type.png'),dpi=300,bbox_inches='tight'); plt.close()
        print(f"  DiffMap pair plots for top {N_DCS_TO_PLOT} DCs saved.")
    else: print("  Error: 'X_diffmap' not found.")
    return df

def run_dbscan_and_plot(df, data_for_dbscan, data_desc, output_dir, umap_emb=None, current_eps_val=None, current_min_samples_val=None):
    """Runs DBSCAN and plots results. Uses specific eps and min_samples if provided."""
    print(f"\n--- Running DBSCAN on {data_desc} ---")
    if data_for_dbscan is None or data_for_dbscan.shape[0] == 0:
        print(f"  No data for DBSCAN on {data_desc}. Skipping.")
        df[f'dbscan_{data_desc.lower().replace(" ","_")}']=-1
        return df

    eps_to_use = current_eps_val if current_eps_val is not None else DBSCAN_EPS_DEFAULT
    min_s_to_use = current_min_samples_val if current_min_samples_val is not None else DBSCAN_MIN_SAMPLES_DEFAULT

    if ESTIMATE_DBSCAN_EPS and current_eps_val is None :
        k_est = min(K_FOR_EPS_ESTIMATION, data_for_dbscan.shape[0]-1); k_est=max(1,k_est)
        nn=NearestNeighbors(n_neighbors=k_est); nn.fit(data_for_dbscan); dists, _ = nn.kneighbors(data_for_dbscan)
        actual_k_for_dists = min(k_est, dists.shape[1])
        if actual_k_for_dists > 0:
            k_dists = dists[:,actual_k_for_dists-1]
            k_dists_sorted = np.sort(k_dists)
            plt.figure(figsize=(8,6)); plt.plot(k_dists_sorted); plt.title(f'{actual_k_for_dists}-Dist Graph for Eps ({data_desc})');
            plt.xlabel("Points sorted by distance"); plt.ylabel(f"{actual_k_for_dists}-th NN Distance (eps candidate)"); plt.grid(True,alpha=0.3);
            eps_path=os.path.join(output_dir, f'dbscan_eps_est_{data_desc.lower().replace(" ","_")}.png'); plt.savefig(eps_path,dpi=300); plt.close();
            print(f"  Saved k-dist graph: {eps_path}. PLEASE INSPECT THIS PLOT TO SET appropriate DBSCAN_EPS for {data_desc}.")
            if len(k_dists_sorted)>10:
                sug_eps=np.percentile(k_dists_sorted,90);
                print(f"  A percentile-based suggestion for eps for {data_desc} is: {sug_eps:.3f}. The script will use eps={eps_to_use} (default or passed).")
        else:
            print(f"  Could not determine k-distances for eps estimation for {data_desc}. Using eps={eps_to_use}")

    print(f"  Running DBSCAN with eps={eps_to_use}, min_samples={min_s_to_use} on {data_desc}...")
    db=DBSCAN(eps=eps_to_use,min_samples=min_s_to_use).fit(data_for_dbscan)
    clust_col=f'dbscan_{data_desc.lower().replace(" ","_")}'; df[clust_col]=db.labels_
    n_clust=len(set(db.labels_))-(1 if -1 in db.labels_ else 0); n_noise=list(db.labels_).count(-1)
    print(f"  DBSCAN on {data_desc}: {n_clust} clusters, {n_noise} noise ({n_noise/len(df)*100:.2f}%).")

    if umap_emb is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12,10))
        labels_unique_dbscan=sorted(df[clust_col].unique())
        n_actual_clusters = len([l for l in labels_unique_dbscan if l != -1])
        if n_actual_clusters > 0:
            dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters)
        else:
            dbscan_cmap_obj = plt.cm.get_cmap('Spectral', 1)

        cdict = {}
        cluster_idx = 0
        for lbl in labels_unique_dbscan:
            if lbl == -1:
                cdict[lbl] = (0.5, 0.5, 0.5, 1)
            else:
                cdict[lbl] = dbscan_cmap_obj(cluster_idx)
                cluster_idx += 1

        for k_val in labels_unique_dbscan:
            mask=(df[clust_col]==k_val)
            xy=umap_emb[mask]
            if xy.shape[0]>0:
                 # Corrected indentation for the plt.scatter call
                 plt.scatter(xy[:,0],xy[:,1],
                             s=(20 if k_val!=-1 else 10),
                             c=[cdict[k_val]],
                             marker=('o' if k_val!=-1 else 'x'),
                             label=('Noise' if k_val == -1 else f'Cluster {k_val}'))
        plt.title(f'DBSCAN on {data_desc} (UMAP proj.)\neps={eps_to_use:.3f}, min_s={min_s_to_use}')
        plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2')
        plt.legend(title='DBSCAN Cluster',bbox_to_anchor=(1.05,1),loc='upper left',markerscale=1.5)
        plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1])
        plt.savefig(os.path.join(output_dir, f'dbscan_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close()
        print(f"  DBSCAN on {data_desc} plotted on UMAP.")
    return df

def run_gmm_and_plot(df, data_for_gmm, data_desc, output_dir, umap_embedding=None):
    """Runs GMM and plots results."""
    print(f"\n--- Running GMM on {data_desc} ---")
    if data_for_gmm is None: df[f'gmm_cluster_{data_desc.lower().replace(" ","_")}']=-1; df[f'gmm_prob_max_{data_desc.lower().replace(" ","_")}']=np.nan; return df
    best_gmm=None; lowest_bic=np.inf
    for n_comp in GMM_N_COMPONENTS_RANGE:
        if n_comp > data_for_gmm.shape[0]: continue
        gmm=GaussianMixture(n_components=n_comp,covariance_type=GMM_COVARIANCE_TYPE,random_state=42,n_init=5).fit(data_for_gmm); bic=gmm.bic(data_for_gmm); print(f"    GMM {n_comp} comps: BIC={bic:.2f}")
        if bic<lowest_bic: lowest_bic=bic; best_gmm=gmm
    if best_gmm is None: df[f'gmm_cluster_{data_desc.lower().replace(" ","_")}']=-1; df[f'gmm_prob_max_{data_desc.lower().replace(" ","_")}']=np.nan; return df
    print(f"  Best GMM: {best_gmm.n_components} components (BIC={lowest_bic:.2f}).")
    clust_col,prob_col = f'gmm_{data_desc.lower().replace(" ","_")}',f'gmm_prob_max_{data_desc.lower().replace(" ","_")}'
    df[clust_col]=best_gmm.predict(data_for_gmm); df[prob_col]=np.max(best_gmm.predict_proba(data_for_gmm),axis=1)
    if 'senescence_score_normalized' in df.columns: print(f"  Mean sen_score per GMM comp ({data_desc}):\n{df.groupby(clust_col)['senescence_score_normalized'].mean().sort_values()}")
    if umap_embedding is not None and 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns:
        plt.figure(figsize=(12,10)); labels=sorted(df[clust_col].unique()); pal=sns.color_palette("viridis",n_colors=len(labels))
        for i,lbl in enumerate(labels): subset=df[df[clust_col]==lbl]; plt.scatter(subset['umap_x_refined'],subset['umap_y_refined'],label=f'GMM Comp. {lbl}',color=pal[i],s=15,alpha=0.7)
        plt.title(f'GMM ({best_gmm.n_components} comp.) on {data_desc} (UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='GMM Comp.',bbox_to_anchor=(1.05,1),loc='upper left'); plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1]); plt.savefig(os.path.join(output_dir,f'gmm_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close()
        plt.figure(figsize=(12,10)); scatter_prob=plt.scatter(df['umap_x_refined'],df['umap_y_refined'],c=df[prob_col],cmap='magma',s=15,alpha=0.7,vmin=0,vmax=1); plt.colorbar(scatter_prob,label='Max GMM Prob.'); plt.title(f'GMM Max Prob. on {data_desc} (UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.grid(True,alpha=0.3); plt.tight_layout(); plt.savefig(os.path.join(output_dir,f'gmm_prob_on_umap_{data_desc.lower().replace(" ","_")}.png'),dpi=300); plt.close()
    return df

def apply_rule_based_gating(df_main, rules, default_label, output_dir, umap_embedding=None):
    """Applies rules to classify cells, creates granular and binary status, and plots."""
    print("\n--- Applying Rule-Based Gating ---")
    df_main['rule_based_classification_granular'] = default_label
    all_rule_features = set(cond[0] for rule in rules for cond in rule['conditions'])
    missing_features = [feat for feat in all_rule_features if feat not in df_main.columns]
    if missing_features:
        print(f"  Error: Features for rules missing from DataFrame: {missing_features}. Skipping."); return df_main

    for rule in rules:
        print(f"  Applying rule: {rule['name']}")
        eligible_mask = (df_main['rule_based_classification_granular'] == default_label)
        if not eligible_mask.any(): print(f"    No cells eligible for rule '{rule['name']}'."); continue
        current_condition_mask = pd.Series([True] * len(df_main), index=df_main.index)
        for feature, operator, value in rule['conditions']:
            if feature not in df_main.columns: print(f"    Feature '{feature}' not found. Skipping rule."); current_condition_mask[:]=False; break
            try:
                feat_series = pd.to_numeric(df_main[feature], errors='coerce')
                if feat_series.isnull().any() and not df_main[feature].isnull().all() : print(f"    Warning: Coercion to numeric for '{feature}' created NaNs.")
                if   operator == '>':  current_condition_mask &= (feat_series > value)
                elif operator == '<':  current_condition_mask &= (feat_series < value)
                elif operator == '>=': current_condition_mask &= (feat_series >= value)
                elif operator == '<=': current_condition_mask &= (feat_series <= value)
                elif operator == '==': current_condition_mask &= (feat_series == value)
                elif operator == '!=': current_condition_mask &= (feat_series != value)
                else: print(f"    Unknown operator '{operator}'."); current_condition_mask[:]=False; break
            except Exception as e: print(f"    Error comparing '{feature}': {e}."); current_condition_mask[:]=False; break

        if isinstance(current_condition_mask, pd.Series) and not current_condition_mask.empty:
            if not current_condition_mask.any():
                print(f"    No cells met conditions for rule '{rule['name']}'.")
                continue
            cells_to_label = eligible_mask & current_condition_mask
            df_main.loc[cells_to_label, 'rule_based_classification_granular'] = rule['output_label']
            print(f"    {cells_to_label.sum()} cells labeled as '{rule['output_label']}'.")
        else:
            print(f"    Rule '{rule['name']}' resulted in an invalid condition mask or no cells met conditions. No cells labeled by this rule.")

    print(f"\nGranular rule-based counts:\n{df_main['rule_based_classification_granular'].value_counts()}")
    df_main['rule_based_binary_status'] = np.where(df_main['rule_based_classification_granular'] == default_label, 'Non-senescent', 'Senescent')
    binary_counts = df_main['rule_based_binary_status'].value_counts(normalize=True) * 100
    print(f"\nBinary rule-based classification counts:\n{df_main['rule_based_binary_status'].value_counts()}")
    print(f"Binary rule-based classification percentages:\n{binary_counts}")


    if umap_embedding is not None and 'umap_x_refined' in df_main.columns and 'umap_y_refined' in df_main.columns:
        plt.figure(figsize=(12,10)); granular_labels = sorted(df_main['rule_based_classification_granular'].unique())
        if granular_labels:
            pal_gran = sns.color_palette("Paired",n_colors=max(10,len(granular_labels)))
            for i,lbl in enumerate(granular_labels): subset=df_main[df_main['rule_based_classification_granular']==lbl]; plt.scatter(subset['umap_x_refined'],subset['umap_y_refined'],label=lbl,color=pal_gran[i%len(pal_gran)],s=15,alpha=0.7)
            plt.title('Rule-Based Gating (Granular - UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule Class (Granular)',bbox_to_anchor=(1.05,1),loc='upper left'); plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.80,1]); plt.savefig(os.path.join(output_dir,'rule_based_gating_granular_on_umap.png'),dpi=300); plt.close()
        plt.figure(figsize=(12,10)); binary_labels = sorted(df_main['rule_based_binary_status'].unique()); bin_pal={'Senescent':'red','Non-senescent':'blue','Unknown_Due_To_Missing_Features':'grey'}
        for lbl in binary_labels: subset=df_main[df_main['rule_based_binary_status']==lbl]; plt.scatter(subset['umap_x_refined'],subset['umap_y_refined'],label=lbl,color=bin_pal.get(lbl,'grey'),s=15,alpha=0.7)
        plt.title('Rule-Based Gating (Binary - UMAP proj.)'); plt.xlabel('UMAP 1'); plt.ylabel('UMAP 2'); plt.legend(title='Rule Status (Binary)',bbox_to_anchor=(1.05,1),loc='upper left'); plt.grid(True,alpha=0.3); plt.tight_layout(rect=[0,0,0.85,1]); plt.savefig(os.path.join(output_dir,'rule_based_gating_binary_on_umap.png'),dpi=300); plt.close()
    return df_main

def visualize_rule_classification_on_masks(df_results, cell_mask_dir, nuclei_mask_dir, output_dir_masks, classification_column='rule_based_binary_status'):
    """Visualizes a specified classification column (e.g., rule-based) on original mask images."""
    print(f"\nGenerating classification overlays on original masks using column '{classification_column}' in: {output_dir_masks}")
    if not os.path.exists(output_dir_masks): os.makedirs(output_dir_masks)

    sen_color = [255, 0, 0]
    nonsen_color = [0, 0, 255]
    nuc_outline_color = [255, 255, 0]
    cell_boundary_color = [255, 255, 255]
    unknown_color = [128, 128, 128]

    if 'derived_sample_id' not in df_results.columns or classification_column not in df_results.columns:
        print(f"Error: 'derived_sample_id' or '{classification_column}' not found. Cannot proceed."); return

    unique_samples = df_results['derived_sample_id'].unique()
    if 'cell_id' not in df_results.columns:
        print("Error: 'cell_id' column not found in df_results."); return
    classification_lookup = pd.Series(df_results[classification_column].values, index=df_results.cell_id).to_dict()

    available_cell_masks = {extract_sample_id(f): f for f in os.listdir(cell_mask_dir) if f.endswith(('.tif', '.tiff'))}
    available_nuclei_masks = {extract_sample_id(f): f for f in os.listdir(nuclei_mask_dir) if f.endswith(('.tif', '.tiff'))}

    for sample_id_csv in tqdm(unique_samples, desc=f"Mask viz ({classification_column})"):
        cell_mask_filename = available_cell_masks.get(sample_id_csv)
        nuclei_mask_filename = available_nuclei_masks.get(sample_id_csv)
        if not cell_mask_filename: print(f"  Cell mask not found for {sample_id_csv}"); continue

        print(f"\n  Overlaying sample: {sample_id_csv}")
        labeled_cell_mask = load_image_as_labeled_mask(os.path.join(cell_mask_dir, cell_mask_filename))
        if labeled_cell_mask is None: continue

        overlay_image = np.zeros((labeled_cell_mask.shape[0], labeled_cell_mask.shape[1], 3), dtype=np.uint8)

        for props in measure.regionprops(labeled_cell_mask):
            full_cell_id = f"{sample_id_csv}_{props.label}"
            status = classification_lookup.get(full_cell_id, 'Unknown')

            current_fill_color = unknown_color
            if classification_column == 'rule_based_binary_status':
                if status == 'Senescent': current_fill_color = sen_color
                elif status == 'Non-senescent': current_fill_color = nonsen_color
            elif classification_column == 'rule_based_classification_granular':
                if status.startswith('Rule_Sen_'): current_fill_color = sen_color
                elif status == RULE_BASED_DEFAULT_LABEL : current_fill_color = nonsen_color
            overlay_image[labeled_cell_mask == props.label] = current_fill_color

        all_cell_boundaries = segmentation.find_boundaries(labeled_cell_mask, mode='outer', background=0)
        overlay_image[all_cell_boundaries] = cell_boundary_color

        if nuclei_mask_filename:
            labeled_nuclei_mask = load_image_as_labeled_mask(os.path.join(nuclei_mask_dir, nuclei_mask_filename))
            if labeled_nuclei_mask is not None:
                nuclei_boundaries = segmentation.find_boundaries(labeled_nuclei_mask, mode='inner', background=0)
                overlay_image[nuclei_boundaries] = nuc_outline_color

        # Initialize fig_leg and ax_leg here to ensure they are defined
        fig_leg, ax_leg = None, None
        try:
            fig_leg, ax_leg = plt.subplots(figsize=(max(10, overlay_image.shape[1]/150), max(8, overlay_image.shape[0]/150)), dpi=100) # Reduced divisor for figsize
            ax_leg.imshow(overlay_image)

            handles = []
            # Simplified legend for binary status
            if classification_column == 'rule_based_binary_status' or classification_column == 'rule_based_classification_granular': # Common legend items
                handles.append(mpatches.Patch(color=np.array(sen_color)/255., label='Senescent (by Rule)'))
                handles.append(mpatches.Patch(color=np.array(nonsen_color)/255., label='Non-senescent (by Rule)'))

            # Specific legend for granular if needed (can be complex if many rule labels)
            # For now, the binary legend is more general.

            if 'Unknown' in df_results[classification_column].unique() or 'Unknown' in classification_lookup.values():
                 handles.append(mpatches.Patch(color=np.array(unknown_color)/255., label='Unknown'))
            if nuclei_mask_filename and labeled_nuclei_mask is not None:
                handles.append(mpatches.Patch(color=np.array(nuc_outline_color)/255., label='Nuclei Outline'))
            handles.append(mpatches.Patch(edgecolor='white', facecolor='none', label='Cell Boundary', linewidth=1))

            ax_leg.legend(handles=handles, loc='upper right', fontsize='xx-small', bbox_to_anchor=(1.65, 1)); # Adjusted bbox, smaller font
            ax_leg.axis('off');
            plt.tight_layout(pad=0.2)
            output_filename = os.path.join(output_dir_masks, f"{sample_id_csv}_{classification_column}_overlay.png")
            plt.savefig(output_filename, dpi=200); # Increased DPI for better quality
            print(f"    Saved overlay for {sample_id_csv} to {output_filename}")
        except Exception as e_plot:
            print(f"    Error during plotting/saving mask overlay for {sample_id_csv}: {e_plot}")
        finally:
            if fig_leg: # Ensure figure is closed even if error occurs after its creation
                plt.close(fig_leg)

    print(f"Mask overlay visualization for '{classification_column}' complete.")


def main_exploratory_analysis():
    """Main function to run exploratory analysis."""
    if not os.path.exists(EXPLORATORY_OUTPUT_DIR):
        os.makedirs(EXPLORATORY_OUTPUT_DIR)
        print(f"Created output directory: {EXPLORATORY_OUTPUT_DIR}")

    df = load_data(INPUT_REFINED_CSV_PATH)
    if df is None: return

    scaled_features, feature_names_used_for_scaling = preprocess_features_for_ml(df, FEATURES_FOR_ANALYSIS, AREA_FEATURES_TO_LOG)

    umap_embedding_for_plotting = None
    if 'umap_x_refined' in df.columns and 'umap_y_refined' in df.columns and df['umap_x_refined'].notna().all():
        print("\nUsing existing UMAP coordinates from input CSV for visualizations.")
        umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
    elif scaled_features is not None:
        print("\nRecomputing UMAP for visualization...")
        try:
            reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42, n_components=2)
            embedding = reducer.fit_transform(scaled_features)
            df['umap_x_refined'] = embedding[:, 0]; df['umap_y_refined'] = embedding[:, 1]
            umap_embedding_for_plotting = df[['umap_x_refined', 'umap_y_refined']].values
            print("  UMAP recomputed.")
        except Exception as e: print(f"  Error recomputing UMAP: {e}.")
    else: print("\nSkipping UMAP computation as scaled features are unavailable.")

    if scaled_features is not None and feature_names_used_for_scaling is not None:
        df = compute_and_plot_diffusion_map(df, scaled_features, feature_names_used_for_scaling, EXPLORATORY_OUTPUT_DIR)

        DBSCAN_EPS_FOR_SCALED_FEATURES = 2.3
        DBSCAN_MIN_SAMPLES_FOR_SCALED_FEATURES = 10
        print(f"\nNOTE: For DBSCAN on Scaled Features, using EPS = {DBSCAN_EPS_FOR_SCALED_FEATURES}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_FOR_SCALED_FEATURES}")
        df = run_dbscan_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR,
                                 umap_emb=umap_embedding_for_plotting,
                                 current_eps_val=DBSCAN_EPS_FOR_SCALED_FEATURES,
                                 current_min_samples_val=DBSCAN_MIN_SAMPLES_FOR_SCALED_FEATURES)

        df = run_gmm_and_plot(df, scaled_features, "Scaled_Features", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

        if RUN_DBSCAN_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols = [f'dc_{i+1}' for i in range(N_DCS_FOR_DBSCAN) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols:
                data_dc = df[dc_cols].values
                DBSCAN_EPS_FOR_DCS = 0.01
                DBSCAN_MIN_SAMPLES_FOR_DCS = 10
                print(f"\nNOTE: For DBSCAN on Top DCs, using EPS = {DBSCAN_EPS_FOR_DCS}, MIN_SAMPLES = {DBSCAN_MIN_SAMPLES_FOR_DCS}")
                df = run_dbscan_and_plot(df, data_dc, f"Top_{len(dc_cols)}_DCs", EXPLORATORY_OUTPUT_DIR,
                                         umap_emb=umap_embedding_for_plotting,
                                         current_eps_val=DBSCAN_EPS_FOR_DCS,
                                         current_min_samples_val=DBSCAN_MIN_SAMPLES_FOR_DCS)
            else: print(f"\nSkipping DBSCAN on DCs: Not enough valid DC columns.")

        global RUN_GMM_ON_DIFFMAP, N_DCS_FOR_GMM
        if RUN_GMM_ON_DIFFMAP and SCANPY_AVAILABLE:
            dc_cols_gmm = [f'dc_{i+1}' for i in range(N_DCS_FOR_GMM) if f'dc_{i+1}' in df.columns and df[f'dc_{i+1}'].notna().any()]
            if dc_cols_gmm:
                data_dc_gmm = df[dc_cols_gmm].values
                df = run_gmm_and_plot(df, data_dc_gmm, f"Top_{len(dc_cols_gmm)}_DCs", EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)
            else: print(f"\nSkipping GMM on DCs: Not enough valid DC columns.")

    # Apply rule-based gating using the main df.
    df = apply_rule_based_gating(df, RULE_BASED_GATES, RULE_BASED_DEFAULT_LABEL, EXPLORATORY_OUTPUT_DIR, umap_embedding=umap_embedding_for_plotting)

    # Visualize the rule-based classification on masks
    mask_overlay_output_path = os.path.join(EXPLORATORY_OUTPUT_DIR, MASK_VISUALIZATION_SUBDIR_RULES)
    if 'rule_based_binary_status' in df.columns:
        visualize_rule_classification_on_masks(df, CELL_MASK_DIR, NUCLEI_MASK_DIR, mask_overlay_output_path, classification_column='rule_based_binary_status')
        # visualize_rule_classification_on_masks(df, CELL_MASK_DIR, NUCLEI_MASK_DIR, mask_overlay_output_path, classification_column='rule_based_classification_granular')
    else:
        print("Skipping mask visualization for rule-based classification as the classification column is missing.")


    exploratory_csv_path = os.path.join(EXPLORATORY_OUTPUT_DIR, 'exploratory_analysis_results_v5_rules_maskviz.csv')
    df.to_csv(exploratory_csv_path, index=False)
    print(f"\nExploratory analysis results saved to: {exploratory_csv_path}")
    print("\nExploratory analysis script finished.")

RUN_GMM_ON_DIFFMAP = True
N_DCS_FOR_GMM = 3

if __name__ == '__main__':
    main_exploratory_analysis()


Loading refined data from /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Senescence_Refined_V5_DiffMap/cell_classification_results_refined.csv...
Successfully loaded 2472 cells.
  Derived 'derived_sample_id' from 'cell_id' for mask matching.

Preprocessing features for ML. Selected: ['cell_area', 'cell_perimeter', 'cell_eccentricity', 'cell_circularity', 'cell_aspect_ratio', 'avg_nucleus_area', 'max_nucleus_area', 'avg_nucleus_eccentricity', 'nucleus_area_std', 'nucleus_displacement', 'nucleus_to_cell_area_ratio', 'nuclear_enlargement', 'cell_enlargement']
  Log-transformed for scaling: cell_area
  Log-transformed for scaling: avg_nucleus_area
  Log-transformed for scaling: max_nucleus_area
  Log-transformed for scaling: cell_perimeter
  Features standardized for ML algorithms.

Using existing UMAP coordinates from input CSV for visualizations.

--- Computing Diffusion Map ---
  Computing neighbors (k=15)...
  Running sc.tl.diffmap...
  Added 9 DCs to Data

  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters)


  DBSCAN on Scaled_Features plotted on UMAP.

--- Running GMM on Scaled_Features ---


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 2 comps: BIC=-11876.19


  # that has no feature names.
  # that has no feature names.


    GMM 3 comps: BIC=-20370.16


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 4 comps: BIC=-26836.71
  Best GMM: 4 components (BIC=-26836.71).
  Mean sen_score per GMM comp (Scaled_Features):
gmm_scaled_features
1    0.316948
0    0.365966
3    0.500682
2    0.500958
Name: senescence_score_normalized, dtype: float64

NOTE: For DBSCAN on Top DCs, using EPS = 0.01, MIN_SAMPLES = 10

--- Running DBSCAN on Top_3_DCs ---
  Running DBSCAN with eps=0.01, min_samples=10 on Top_3_DCs...
  DBSCAN on Top_3_DCs: 2 clusters, 42 noise (1.70%).


  dbscan_cmap_obj = plt.cm.get_cmap('Spectral', n_actual_clusters)


  DBSCAN on Top_3_DCs plotted on UMAP.

--- Running GMM on Top_3_DCs ---


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 2 comps: BIC=-42744.54
    GMM 3 comps: BIC=-43790.09


  # that has no feature names.
  # that has no feature names.
  # that has no feature names.


    GMM 4 comps: BIC=-44590.14
  Best GMM: 4 components (BIC=-44590.14).
  Mean sen_score per GMM comp (Top_3_DCs):
gmm_top_3_dcs
0    0.319583
1    0.398703
2    0.427525
3    0.543944
Name: senescence_score_normalized, dtype: float64

--- Applying Rule-Based Gating ---
  Applying rule: Polynucleated
    225 cells labeled as 'Rule_Sen_Poly'.
  Applying rule: Very_Large_Cell
    307 cells labeled as 'Rule_Sen_VeryLarge'.
  Applying rule: Low_Circularity
    3 cells labeled as 'Rule_Sen_LowCirc'.
  Applying rule: Low_NucToCellRatio
    103 cells labeled as 'Rule_Sen_LowNucRatio'.
  Applying rule: High_Score_Not_Otherwise_Caught
    0 cells labeled as 'Rule_Sen_HighScore'.

Granular rule-based counts:
rule_based_classification_granular
Rule_NonSenescent       1834
Rule_Sen_VeryLarge       307
Rule_Sen_Poly            225
Rule_Sen_LowNucRatio     103
Rule_Sen_LowCirc           3
Name: count, dtype: int64

Binary rule-based classification counts:
rule_based_binary_status
Non-senescent    1

Mask viz (rule_based_binary_status):   0%|          | 0/8 [00:00<?, ?it/s]


  Overlaying sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq001
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq001_cell_mask_merged_conservative.tif
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq001_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  12%|█▎        | 1/8 [00:01<00:12,  1.74s/it]

    Saved overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq001 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/0Pa_U_05mar19_20x_L2RA_Flat_seq001_rule_based_binary_status_overlay.png

  Overlaying sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq002
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq002_cell_mask_merged_conservative.tif
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq002_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  25%|██▌       | 2/8 [00:03<00:11,  1.84s/it]

    Saved overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq002 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/0Pa_U_05mar19_20x_L2RA_Flat_seq002_rule_based_binary_status_overlay.png

  Overlaying sample: 0Pa_U_05mar19_20x_L2RA_Flat_seq003
    Loading mask: 0Pa_U_05mar19_20x_L2RA_Flat_seq003_cell_mask_merged_conservative.tif
    Loading mask: denoised_0Pa_U_05mar19_20x_L2RA_Flat_seq003_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  38%|███▊      | 3/8 [00:05<00:09,  1.85s/it]

    Saved overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq003 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/0Pa_U_05mar19_20x_L2RA_Flat_seq003_rule_based_binary_status_overlay.png

  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq001_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  50%|█████     | 4/8 [00:07<00:07,  1.77s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq001 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/1.4Pa_U_05mar19_20x_L2R_Flat_seq001_rule_based_binary_status_overlay.png

  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq002_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  62%|██████▎   | 5/8 [00:08<00:04,  1.63s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq002 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/1.4Pa_U_05mar19_20x_L2R_Flat_seq002_rule_based_binary_status_overlay.png

  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq003_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  75%|███████▌  | 6/8 [00:10<00:03,  1.65s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq003 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/1.4Pa_U_05mar19_20x_L2R_Flat_seq003_rule_based_binary_status_overlay.png

  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq004_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status):  88%|████████▊ | 7/8 [00:12<00:01,  1.90s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq004 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/1.4Pa_U_05mar19_20x_L2R_Flat_seq004_rule_based_binary_status_overlay.png

  Overlaying sample: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005
    Loading mask: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005_cell_mask_merged_conservative.tif
    Loading mask: denoised_1.4Pa_U_05mar19_20x_L2R_Flat_seq005_Cadherins_filtered_mask.tif


Mask viz (rule_based_binary_status): 100%|██████████| 8/8 [00:14<00:00,  1.86s/it]

    Saved overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq005 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/mask_overlays_rule_based/1.4Pa_U_05mar19_20x_L2R_Flat_seq005_rule_based_binary_status_overlay.png
Mask overlay visualization for 'rule_based_binary_status' complete.

Exploratory analysis results saved to: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/exploratory_analysis_results_v5_rules_maskviz.csv

Exploratory analysis script finished.





In [30]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from skimage import io, measure, segmentation
import cv2
from scipy import ndimage
from tqdm import tqdm

# --- Configuration & Parameters ---
# !! UPDATE THESE PATHS !!
# Path to the CSV file generated by the exploratory_analysis_script
# This CSV must contain 'derived_sample_id', 'cell_id', and 'rule_based_classification_granular'
INPUT_EXPLORATORY_CSV_PATH = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Exploratory_Analysis_V5_rules_mask_viz/exploratory_analysis_results_v5_rules_maskviz.csv"

# Output directory for these new enhanced mask visualizations
ENHANCED_MASK_OUTPUT_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix" # Incremented

# Paths to your original mask image directories
CELL_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Cell_merged_conservative"
NUCLEI_MASK_DIR = "/content/drive/MyDrive/knowledge/University/Master/Thesis/Segmented/flow3-x20/Nuclei"

# Label from your RULE_BASED_GATES that specifically identifies polynucleated senescent cells
POLYNUCLEATED_SENESCENT_RULE_LABEL = 'Rule_Sen_Poly'
# Default label for cells not caught by any "Senescent" rule
NON_SENESCENT_DEFAULT_RULE_LABEL = 'Rule_NonSenescent'


# --- Color Definitions ---
COLOR_NON_SENESCENT = [0, 0, 255]  # Blue
COLOR_SENESCENT_POLYNUCLEATED = [255, 165, 0]  # Orange
COLOR_SENESCENT_OTHER_RULES = [255, 0, 0]  # Red
COLOR_CELL_BOUNDARY = [255, 255, 255]  # White
COLOR_NUCLEI_OUTLINE = [255, 255, 0]  # Yellow
COLOR_UNKNOWN = [128, 128, 128] # Grey

# --- Helper Functions ---
def extract_sample_id(filename):
    """Extracts sample ID from filename."""
    base_name = os.path.splitext(filename)[0]
    if base_name.startswith('denoised_'):
        base_name = base_name[len('denoised_'):]
    pattern = re.compile(r'([\d\.]+Pa_[^_]+_[^_]+_[^_]+_[^_]+_[^_]+_seq\d+)')
    match = pattern.search(base_name)
    if match: return match.group(1)
    parts = base_name.split('_')
    for i, part in enumerate(parts):
        if part.startswith('seq') and i >= 2: return '_'.join(parts[:i+1])
    common_prefix = "_".join(filename.split('_')[:6])
    return common_prefix if 'seq' in common_prefix else os.path.splitext(os.path.basename(filename))[0]

def load_image_as_labeled_mask(filepath):
    """Loads a mask image, ensuring it's a labeled integer mask."""
    # print(f"    Loading mask: {os.path.basename(filepath)}") # Reduced verbosity
    try:
        img = io.imread(filepath)
        if img.ndim > 2:
            if img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[-1] == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else: img = img[..., 0]
        if img.dtype.kind in 'iu' and np.max(img) > 1: return img.astype(np.uint16)
        if img.dtype.kind == 'f': img = (img > 0.5).astype(np.uint8)
        elif np.max(img) == 1: img = img.astype(np.uint8)
        if np.max(img) <=1 :
            labeled_img, num_features = ndimage.label(img)
            # print(f"    Labeled binary mask {os.path.basename(filepath)}, found {num_features} features.")
            return labeled_img.astype(np.uint16)
        return img.astype(np.uint16)
    except Exception as e:
        print(f"    Error loading image {filepath}: {str(e)}"); return None

def visualize_enhanced_classification_on_masks(df_results, cell_mask_dir, nuclei_mask_dir, output_dir):
    """
    Visualizes cell classifications on original mask images, highlighting polynucleated senescent cells.
    Uses 'rule_based_classification_granular' for detailed coloring.
    """
    print(f"\nGenerating enhanced classification overlays in: {output_dir}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    required_cols = ['derived_sample_id', 'cell_id', 'rule_based_classification_granular']
    if not all(col in df_results.columns for col in required_cols):
        missing = [col for col in required_cols if col not in df_results.columns]
        print(f"Error: DataFrame is missing required columns: {missing}. Cannot proceed with visualization.")
        return

    classification_lookup = pd.Series(
        df_results.rule_based_classification_granular.values,
        index=df_results.cell_id
    ).to_dict()

    unique_samples = df_results['derived_sample_id'].unique()
    available_cell_masks = {extract_sample_id(f): f for f in os.listdir(cell_mask_dir) if f.endswith(('.tif', '.tiff'))}
    available_nuclei_masks = {extract_sample_id(f): f for f in os.listdir(nuclei_mask_dir) if f.endswith(('.tif', '.tiff'))}

    for sample_id_csv in tqdm(unique_samples, desc="Generating Enhanced Mask Overlays"):
        cell_mask_filename = available_cell_masks.get(sample_id_csv)
        nuclei_mask_filename = available_nuclei_masks.get(sample_id_csv)

        if not cell_mask_filename:
            print(f"  Warning: Cell mask file not found for sample ID: {sample_id_csv}")
            continue

        print(f"\n  Processing sample for enhanced overlay: {sample_id_csv}")
        labeled_cell_mask = load_image_as_labeled_mask(os.path.join(cell_mask_dir, cell_mask_filename))
        if labeled_cell_mask is None:
            continue

        overlay_image = np.zeros((labeled_cell_mask.shape[0], labeled_cell_mask.shape[1], 3), dtype=np.uint8)

        # 1. Fill cells based on granular rule classification
        for props in measure.regionprops(labeled_cell_mask):
            full_cell_id = f"{sample_id_csv}_{props.label}"
            granular_status = classification_lookup.get(full_cell_id, 'Unknown_Rule')

            current_fill_color = COLOR_UNKNOWN
            if granular_status == POLYNUCLEATED_SENESCENT_RULE_LABEL:
                current_fill_color = COLOR_SENESCENT_POLYNUCLEATED
            elif granular_status.startswith('Rule_Sen_'):
                current_fill_color = COLOR_SENESCENT_OTHER_RULES
            elif granular_status == NON_SENESCENT_DEFAULT_RULE_LABEL:
                current_fill_color = COLOR_NON_SENESCENT

            overlay_image[labeled_cell_mask == props.label] = current_fill_color

        # 2. Draw cell boundaries
        all_cell_boundaries = segmentation.find_boundaries(labeled_cell_mask, mode='outer', background=0)
        overlay_image[all_cell_boundaries] = COLOR_CELL_BOUNDARY

        # 3. Overlay nuclei outlines
        labeled_nuclei_mask = None # Initialize for the finally block
        if nuclei_mask_filename:
            labeled_nuclei_mask = load_image_as_labeled_mask(os.path.join(nuclei_mask_dir, nuclei_mask_filename))
            if labeled_nuclei_mask is not None:
                nuclei_boundaries = segmentation.find_boundaries(labeled_nuclei_mask, mode='inner', background=0)
                overlay_image[nuclei_boundaries] = COLOR_NUCLEI_OUTLINE

        # Save the overlay image with legend
        fig_legend, ax_legend_obj = None, None # Use a different name to avoid confusion if ax_leg was a typo elsewhere
        try:
            img_h, img_w = overlay_image.shape[:2]
            fig_w = max(10, img_w / 100 if img_w > 0 else 10)
            fig_h = max(8, img_h / 100 if img_h > 0 else 8) * (fig_w / (img_w/100 if img_w > 0 else 1))

            fig_legend, ax_legend_obj = plt.subplots(figsize=(fig_w, fig_h), dpi=100)
            ax_legend_obj.imshow(overlay_image)

            handles = [
                mpatches.Patch(color=np.array(COLOR_NON_SENESCENT)/255., label='Non-senescent (Rule)'),
                mpatches.Patch(color=np.array(COLOR_SENESCENT_OTHER_RULES)/255., label='Senescent (Other Rules)'),
                mpatches.Patch(color=np.array(COLOR_SENESCENT_POLYNUCLEATED)/255., label=f'Senescent ({POLYNUCLEATED_SENESCENT_RULE_LABEL})')
            ]
            # Check if 'Unknown_Rule' actually occurred for this specific sample or globally in the lookup
            # This avoids adding 'Unknown' to legend if no cells were actually unknown.
            unknown_present_in_sample = any(classification_lookup.get(f"{sample_id_csv}_{p.label}", 'Unknown_Rule') == 'Unknown_Rule'
                                            for p in measure.regionprops(labeled_cell_mask))
            if unknown_present_in_sample:
                 handles.append(mpatches.Patch(color=np.array(COLOR_UNKNOWN)/255., label='Unknown/Not in CSV'))

            if nuclei_mask_filename and labeled_nuclei_mask is not None:
                handles.append(mpatches.Patch(color=np.array(COLOR_NUCLEI_OUTLINE)/255., label='Nuclei Outline'))
            handles.append(mpatches.Patch(edgecolor=np.array(COLOR_CELL_BOUNDARY)/255., facecolor='none', label='Cell Boundary', linewidth=1))

            # CORRECTED TYPO: ax_legend_obj instead of ax_leg
            ax_legend_obj.legend(handles=handles, loc='center left', bbox_to_anchor=(1.02, 0.5), fontsize='small')
            ax_legend_obj.axis('off')
            plt.tight_layout(rect=[0, 0, 0.85, 1])

            output_filename = os.path.join(output_dir, f"{sample_id_csv}_enhanced_overlay.png")
            plt.savefig(output_filename, dpi=200)
            print(f"    Saved enhanced overlay for {sample_id_csv} to {output_filename}")
        except Exception as e_plot:
            print(f"    Error during plotting/saving mask overlay for {sample_id_csv}: {e_plot}")
        finally:
            if fig_legend:
                plt.close(fig_legend)

    print("\nEnhanced mask overlay visualization complete.")


def main():
    """Main function to run the enhanced mask visualization."""
    if not os.path.exists(ENHANCED_MASK_OUTPUT_DIR):
        os.makedirs(ENHANCED_MASK_OUTPUT_DIR)
        print(f"Created output directory: {ENHANCED_MASK_OUTPUT_DIR}")

    df_results = pd.read_csv(INPUT_EXPLORATORY_CSV_PATH)
    if df_results is None: # Should be if df_results is None, not if pd.read_csv is None
        print(f"Failed to load data from {INPUT_EXPLORATORY_CSV_PATH}")
        return

    if 'derived_sample_id' not in df_results.columns and 'cell_id' in df_results.columns:
        df_results['derived_sample_id'] = df_results['cell_id'].apply(lambda x: '_'.join(x.split('_')[:-1]))
        print("  Derived 'derived_sample_id' from 'cell_id' for mask matching (main).")


    visualize_enhanced_classification_on_masks(
        df_results,
        CELL_MASK_DIR,
        NUCLEI_MASK_DIR,
        ENHANCED_MASK_OUTPUT_DIR
    )

    print("\nScript finished.")

if __name__ == '__main__':
    main()


Created output directory: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix

Generating enhanced classification overlays in: /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix


Generating Enhanced Mask Overlays:   0%|          | 0/8 [00:00<?, ?it/s]


  Processing sample for enhanced overlay: 0Pa_U_05mar19_20x_L2RA_Flat_seq001


Generating Enhanced Mask Overlays:  12%|█▎        | 1/8 [00:01<00:12,  1.77s/it]

    Saved enhanced overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq001 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/0Pa_U_05mar19_20x_L2RA_Flat_seq001_enhanced_overlay.png

  Processing sample for enhanced overlay: 0Pa_U_05mar19_20x_L2RA_Flat_seq002


Generating Enhanced Mask Overlays:  25%|██▌       | 2/8 [00:03<00:10,  1.81s/it]

    Saved enhanced overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq002 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/0Pa_U_05mar19_20x_L2RA_Flat_seq002_enhanced_overlay.png

  Processing sample for enhanced overlay: 0Pa_U_05mar19_20x_L2RA_Flat_seq003


Generating Enhanced Mask Overlays:  38%|███▊      | 3/8 [00:05<00:09,  1.88s/it]

    Saved enhanced overlay for 0Pa_U_05mar19_20x_L2RA_Flat_seq003 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/0Pa_U_05mar19_20x_L2RA_Flat_seq003_enhanced_overlay.png

  Processing sample for enhanced overlay: 1.4Pa_U_05mar19_20x_L2R_Flat_seq001


Generating Enhanced Mask Overlays:  50%|█████     | 4/8 [00:07<00:07,  1.76s/it]

    Saved enhanced overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq001 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/1.4Pa_U_05mar19_20x_L2R_Flat_seq001_enhanced_overlay.png

  Processing sample for enhanced overlay: 1.4Pa_U_05mar19_20x_L2R_Flat_seq002


Generating Enhanced Mask Overlays:  62%|██████▎   | 5/8 [00:09<00:05,  1.84s/it]

    Saved enhanced overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq002 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/1.4Pa_U_05mar19_20x_L2R_Flat_seq002_enhanced_overlay.png

  Processing sample for enhanced overlay: 1.4Pa_U_05mar19_20x_L2R_Flat_seq003


Generating Enhanced Mask Overlays:  75%|███████▌  | 6/8 [00:11<00:04,  2.06s/it]

    Saved enhanced overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq003 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/1.4Pa_U_05mar19_20x_L2R_Flat_seq003_enhanced_overlay.png

  Processing sample for enhanced overlay: 1.4Pa_U_05mar19_20x_L2R_Flat_seq004


Generating Enhanced Mask Overlays:  88%|████████▊ | 7/8 [00:13<00:01,  1.91s/it]

    Saved enhanced overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq004 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/1.4Pa_U_05mar19_20x_L2R_Flat_seq004_enhanced_overlay.png

  Processing sample for enhanced overlay: 1.4Pa_U_05mar19_20x_L2R_Flat_seq005


Generating Enhanced Mask Overlays: 100%|██████████| 8/8 [00:14<00:00,  1.87s/it]

    Saved enhanced overlay for 1.4Pa_U_05mar19_20x_L2R_Flat_seq005 to /content/drive/MyDrive/knowledge/University/Master/Thesis/Analysis/flow3-x20/Enhanced_Mask_Overlays_Poly_Highlight_Fix/1.4Pa_U_05mar19_20x_L2R_Flat_seq005_enhanced_overlay.png

Enhanced mask overlay visualization complete.

Script finished.



